diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index 67a43be007..74d6fdad2e 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -27,6 +28,12 @@ namespace Service { namespace NWM { +namespace ErrCodes { +enum { + NotInitialized = 2, +}; +} // namespace ErrCodes + // Event that is signaled every time the connection status changes. static Kernel::SharedPtr connection_status_event; @@ -37,6 +44,8 @@ static Kernel::SharedPtr recv_buffer_memory; // Connection status of this 3DS. static ConnectionStatus connection_status{}; +static std::atomic initialized(false); + /* Node information about the current network. * The amount of elements in this vector is always the maximum number * of nodes specified in the network configuration. @@ -155,7 +164,7 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { "Could not join network"); { std::lock_guard lock(connection_status_mutex); - ASSERT(connection_status.status == static_cast(NetworkStatus::NotConnected)); + ASSERT(connection_status.status == static_cast(NetworkStatus::Connecting)); } // Send the EAPoL-Start packet to the server. @@ -171,8 +180,9 @@ void HandleAssociationResponseFrame(const Network::WifiPacket& packet) { } static void HandleEAPoLPacket(const Network::WifiPacket& packet) { - std::lock_guard hle_lock(HLE::g_hle_lock); - std::lock_guard lock(connection_status_mutex); + std::unique_lock hle_lock(HLE::g_hle_lock, std::defer_lock); + std::unique_lock lock(connection_status_mutex, std::defer_lock); + std::lock(hle_lock, lock); if (GetEAPoLFrameType(packet.data) == EAPoLStartMagic) { if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { @@ -220,7 +230,7 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { // The 3ds does this presumably to support spectators. connection_status_event->Signal(); } else { - if (connection_status.status != static_cast(NetworkStatus::NotConnected)) { + if (connection_status.status != static_cast(NetworkStatus::Connecting)) { LOG_DEBUG(Service_NWM, "Connection sequence aborted, because connection status is %u", connection_status.status); return; @@ -249,15 +259,15 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { // Some games require ConnectToNetwork to block, for now it doesn't // If blocking is implemented this lock needs to be changed, // otherwise it might cause deadlocks - std::lock_guard lock(HLE::g_hle_lock); connection_status_event->Signal(); } } static void HandleSecureDataPacket(const Network::WifiPacket& packet) { auto secure_data = ParseSecureDataHeader(packet.data); - std::lock_guard hle_lock(HLE::g_hle_lock); - std::lock_guard lock(connection_status_mutex); + std::unique_lock hle_lock(HLE::g_hle_lock, std::defer_lock); + std::unique_lock lock(connection_status_mutex, std::defer_lock); + std::lock(hle_lock, lock); if (secure_data.src_node_id == connection_status.network_node_id) { // Ignore packets that came from ourselves. @@ -315,7 +325,7 @@ void StartConnectionSequence(const MacAddress& server) { WifiPacket auth_request; { std::lock_guard lock(connection_status_mutex); - ASSERT(connection_status.status == static_cast(NetworkStatus::NotConnected)); + connection_status.status = static_cast(NetworkStatus::Connecting); // TODO(Subv): Handle timeout. @@ -546,6 +556,8 @@ static void InitializeWithVersion(Interface* self) { recv_buffer_memory = Kernel::g_handle_table.Get(sharedmem_handle); + initialized = true; + ASSERT_MSG(recv_buffer_memory->size == sharedmem_size, "Invalid shared memory size."); { @@ -614,8 +626,12 @@ static void GetNodeInformation(Interface* self) { IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0xD, 1, 0); u16 network_node_id = rp.Pop(); - IPC::RequestBuilder rb = rp.MakeBuilder(11, 0); - rb.Push(RESULT_SUCCESS); + if (!initialized) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotInitialized, ErrorModule::UDS, + ErrorSummary::StatusChanged, ErrorLevel::Status)); + return; + } { std::lock_guard lock(connection_status_mutex); @@ -623,7 +639,15 @@ static void GetNodeInformation(Interface* self) { [network_node_id](const NodeInfo& node) { return node.network_node_id == network_node_id; }); - ASSERT(itr != node_info.end()); + if (itr == node_info.end()) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotFound, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Status)); + return; + } + + IPC::RequestBuilder rb = rp.MakeBuilder(11, 0); + rb.Push(RESULT_SUCCESS); rb.PushRaw(*itr); } LOG_DEBUG(Service_NWM, "called"); @@ -653,13 +677,29 @@ static void Bind(Interface* self) { LOG_DEBUG(Service_NWM, "called"); - if (data_channel == 0) { + if (data_channel == 0 || bind_node_id == 0) { IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, ErrorSummary::WrongArgument, ErrorLevel::Usage)); return; } + constexpr size_t MaxBindNodes = 16; + if (channel_data.size() >= MaxBindNodes) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::OutOfMemory, ErrorModule::UDS, + ErrorSummary::OutOfResource, ErrorLevel::Status)); + return; + } + + constexpr u32 MinRecvBufferSize = 0x5F4; + if (recv_buffer_size < MinRecvBufferSize) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); + return; + } + // Create a new event for this bind node. auto event = Kernel::Event::Create(Kernel::ResetType::OneShot, "NWM::BindNodeEvent" + std::to_string(bind_node_id)); @@ -687,6 +727,12 @@ static void Unbind(Interface* self) { IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x12, 1, 0); u32 bind_node_id = rp.Pop(); + if (bind_node_id == 0) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); + return; + } std::lock_guard lock(connection_status_mutex); @@ -699,8 +745,13 @@ static void Unbind(Interface* self) { channel_data.erase(itr); } - IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + IPC::RequestBuilder rb = rp.MakeBuilder(5, 0); rb.Push(RESULT_SUCCESS); + rb.Push(bind_node_id); + // TODO(B3N30): Find out what the other return values are + rb.Push(0); + rb.Push(0); + rb.Push(0); } /** @@ -729,13 +780,14 @@ static void BeginHostingNetwork(Interface* self) { LOG_DEBUG(Service_NWM, "called"); - Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo)); - - // The real UDS module throws a fatal error if this assert fails. - ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member."); - { std::lock_guard lock(connection_status_mutex); + + Memory::ReadBlock(network_info_address, &network_info, sizeof(NetworkInfo)); + + // The real UDS module throws a fatal error if this assert fails. + ASSERT_MSG(network_info.max_nodes > 1, "Trying to host a network of only one member."); + connection_status.status = static_cast(NetworkStatus::ConnectedAsHost); // Ensure the application data size is less than the maximum value. @@ -749,11 +801,13 @@ static void BeginHostingNetwork(Interface* self) { connection_status.max_nodes = network_info.max_nodes; // Resize the nodes list to hold max_nodes. + node_info.clear(); node_info.resize(network_info.max_nodes); // There's currently only one node in the network (the host). connection_status.total_nodes = 1; network_info.total_nodes = 1; + // The host is always the first node connection_status.network_node_id = 1; current_node.network_node_id = 1; @@ -762,12 +816,22 @@ static void BeginHostingNetwork(Interface* self) { connection_status.node_bitmask |= 1; // Notify the application that the first node was set. connection_status.changed_nodes |= 1; - node_info[0] = current_node; - } - // If the game has a preferred channel, use that instead. - if (network_info.channel != 0) - network_channel = network_info.channel; + if (auto room_member = Network::GetRoomMember().lock()) { + if (room_member->IsConnected()) { + network_info.host_mac_address = room_member->GetMacAddress(); + } else { + network_info.host_mac_address = {{0x0, 0x0, 0x0, 0x0, 0x0, 0x0}}; + } + } + node_info[0] = current_node; + + // If the game has a preferred channel, use that instead. + if (network_info.channel != 0) + network_channel = network_info.channel; + else + network_info.channel = DefaultNetworkChannel; + } connection_status_event->Signal(); @@ -775,8 +839,7 @@ static void BeginHostingNetwork(Interface* self) { CoreTiming::ScheduleEvent(msToCycles(DefaultBeaconInterval * MillisecondsPerTU), beacon_broadcast_event, 0); - LOG_WARNING(Service_NWM, - "An UDS network has been created, but broadcasting it is unimplemented."); + LOG_DEBUG(Service_NWM, "An UDS network has been created."); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); rb.Push(RESULT_SUCCESS); @@ -929,6 +992,14 @@ static void PullPacket(Interface* self) { ASSERT(desc_size == max_out_buff_size); std::lock_guard lock(connection_status_mutex); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsClient) && + connection_status.status != static_cast(NetworkStatus::ConnectedAsSpectator)) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::InvalidState, ErrorLevel::Status)); + return; + } auto channel = std::find_if(channel_data.begin(), channel_data.end(), [bind_node_id](const auto& data) { @@ -937,8 +1008,8 @@ static void PullPacket(Interface* self) { if (channel == channel_data.end()) { IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); - // TODO(B3N30): Find the right error code - rb.Push(-1); + rb.Push(ResultCode(ErrorDescription::NotAuthorized, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); return; } @@ -959,7 +1030,8 @@ static void PullPacket(Interface* self) { if (data_size > max_out_buff_size) { IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); - rb.Push(0xE10113E9); + rb.Push(ResultCode(ErrorDescription::TooLarge, ErrorModule::UDS, + ErrorSummary::WrongArgument, ErrorLevel::Usage)); return; } @@ -1225,6 +1297,7 @@ NWM_UDS::~NWM_UDS() { channel_data.clear(); connection_status_event = nullptr; recv_buffer_memory = nullptr; + initialized = false; { std::lock_guard lock(connection_status_mutex); diff --git a/src/core/hle/service/nwm/nwm_uds.h b/src/core/hle/service/nwm/nwm_uds.h index f95ae25aad..f1caaf9746 100644 --- a/src/core/hle/service/nwm/nwm_uds.h +++ b/src/core/hle/service/nwm/nwm_uds.h @@ -32,7 +32,7 @@ struct NodeInfo { std::array username; INSERT_PADDING_BYTES(4); u16_le network_node_id; - std::array address; + INSERT_PADDING_BYTES(6); }; static_assert(sizeof(NodeInfo) == 40, "NodeInfo has incorrect size."); @@ -42,6 +42,7 @@ using NodeList = std::vector; enum class NetworkStatus { NotConnected = 3, ConnectedAsHost = 6, + Connecting = 7, ConnectedAsClient = 9, ConnectedAsSpectator = 10, }; diff --git a/src/core/hle/service/nwm/uds_data.h b/src/core/hle/service/nwm/uds_data.h index 4161025a95..59906f677e 100644 --- a/src/core/hle/service/nwm/uds_data.h +++ b/src/core/hle/service/nwm/uds_data.h @@ -52,7 +52,7 @@ struct SecureDataHeader { u16_be dest_node_id; u16_be src_node_id; - u32 GetActualDataSize() { + u32 GetActualDataSize() const { return protocol_size - sizeof(SecureDataHeader); } };