diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index 8cd2fe04b3..bcffae4ae6 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -75,13 +76,22 @@ static u8 network_channel = DefaultNetworkChannel; // Information about the network that we're currently connected to. static NetworkInfo network_info; +// Mapping of mac addresses to their respective node_ids. +static std::map node_map; + // Event that will generate and send the 802.11 beacon frames. static CoreTiming::EventType* beacon_broadcast_event; +// Callback identifier for the OnWifiPacketReceived event. +static Network::RoomMember::CallbackHandle wifi_packet_received; + // Mutex to synchronize access to the connection status between the emulation thread and the // network thread. static std::mutex connection_status_mutex; +// token for the blocking ConnectToNetwork +static ThreadContinuationToken connection_token; + // Mutex to synchronize access to the list of received beacons between the emulation thread and the // network thread. static std::mutex beacon_mutex; @@ -119,7 +129,12 @@ std::list GetReceivedBeacons(const MacAddress& sender) { /// Sends a WifiPacket to the room we're currently connected to. void SendPacket(Network::WifiPacket& packet) { - // TODO(Subv): Implement. + if (auto room_member = Network::GetRoomMember().lock()) { + if (room_member->GetState() == Network::RoomMember::State::Joined) { + packet.transmitter_address = room_member->GetMacAddress(); + room_member->SendWifiPacket(packet); + } + } } /* @@ -214,6 +229,8 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { network_info.total_nodes++; + node_map[packet.transmitter_address] = node_id; + // Send the EAPoL-Logoff packet. using Network::WifiPacket; WifiPacket eapol_logoff; @@ -237,6 +254,7 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { } auto logoff = ParseEAPoLLogoffFrame(packet.data); + network_info.host_mac_address = packet.transmitter_address; network_info.total_nodes = logoff.connected_nodes; network_info.max_nodes = logoff.max_nodes; @@ -260,6 +278,9 @@ static void HandleEAPoLPacket(const Network::WifiPacket& packet) { // If blocking is implemented this lock needs to be changed, // otherwise it might cause deadlocks connection_status_event->Signal(); + if (connection_token.IsValid()) { + ContinueClientThread(connection_token); + } } } @@ -397,6 +418,36 @@ void HandleAuthenticationFrame(const Network::WifiPacket& packet) { } } +/// Handles the deauthentication frames sent from clients to hosts, when they leave a session +void HandleDeauthenticationFrame(const Network::WifiPacket& packet) { + LOG_DEBUG(Service_NWM, "called"); + std::unique_lock hle_lock(HLE::g_hle_lock, std::defer_lock); + std::unique_lock lock(connection_status_mutex, std::defer_lock); + std::lock(hle_lock, lock); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { + LOG_ERROR(Service_NWM, "Got deauthentication frame but we are not the host"); + return; + } + if (node_map.find(packet.transmitter_address) == node_map.end()) { + LOG_ERROR(Service_NWM, "Got deauthentication frame from unknown node"); + return; + } + + u16 node_id = node_map[packet.transmitter_address]; + auto node = std::find_if(node_info.begin(), node_info.end(), [&node_id](const NodeInfo& info) { + return info.network_node_id == node_id + 1; + }); + ASSERT(node != node_info.end()); + + connection_status.node_bitmask &= ~(1 << node_id); + connection_status.changed_nodes |= 1 << node_id; + connection_status.total_nodes--; + + network_info.total_nodes--; + node_info.erase(node); + connection_status_event->Signal(); +} + static void HandleDataFrame(const Network::WifiPacket& packet) { switch (GetFrameEtherType(packet.data)) { case EtherType::EAPoL: @@ -423,6 +474,9 @@ void OnWifiPacketReceived(const Network::WifiPacket& packet) { case Network::WifiPacket::PacketType::Data: HandleDataFrame(packet); break; + case Network::WifiPacket::PacketType::Deauthentication: + HandleDeauthenticationFrame(packet); + break; } } @@ -435,13 +489,22 @@ void OnWifiPacketReceived(const Network::WifiPacket& packet) { * 1 : Result of function, 0 on success, otherwise error code */ static void Shutdown(Interface* self) { - u32* cmd_buff = Kernel::GetCommandBuffer(); + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x03, 0, 0); - // TODO(purpasmart): Verify return header on HW + if (auto room_member = Network::GetRoomMember().lock()) + room_member->Unbind(wifi_packet_received); - cmd_buff[1] = RESULT_SUCCESS.raw; + // TODO(B3N30): Check on HW if Shutdown signals those events + for (auto bind_node : channel_data) { + bind_node.second.event->Signal(); + } + channel_data.clear(); - LOG_WARNING(Service_NWM, "(STUBBED) called"); + recv_buffer_memory.reset(); + + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(RESULT_SUCCESS); + LOG_DEBUG(Service_NWM, "called"); } /** @@ -494,7 +557,6 @@ static void RecvBeaconBroadcastData(Interface* self) { Memory::WriteBlock(current_buffer_pos, &data_reply_header, sizeof(BeaconDataReplyHeader)); current_buffer_pos += sizeof(BeaconDataReplyHeader); - // Write each of the received beacons into the buffer for (const auto& beacon : beacons) { BeaconEntryHeader entry{}; @@ -560,6 +622,16 @@ static void InitializeWithVersion(Interface* self) { ASSERT_MSG(recv_buffer_memory->size == sharedmem_size, "Invalid shared memory size."); + if (auto room_member = Network::GetRoomMember().lock()) { + wifi_packet_received = room_member->BindOnWifiPacketReceived(OnWifiPacketReceived); + } else { + LOG_ERROR(Service_NWM, "Network isn't initalized"); + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + // TODO(B3N30): Find the correct error code and return it; + rb.Push(-1); + return; + } + { std::lock_guard lock(connection_status_mutex); @@ -567,15 +639,14 @@ static void InitializeWithVersion(Interface* self) { // except for the actual status value. connection_status = {}; connection_status.status = static_cast(NetworkStatus::NotConnected); + node_info.clear(); + node_info.push_back(NodeInfo{}); } IPC::RequestBuilder rb = rp.MakeBuilder(1, 2); rb.Push(RESULT_SUCCESS); rb.PushCopyHandles(Kernel::g_handle_table.Create(connection_status_event).Unwrap()); - // TODO(Subv): Connect the OnWifiPacketReceived function to the wifi packet received callback of - // the room we're currently in. - LOG_DEBUG(Service_NWM, "called sharedmem_size=0x%08X, version=0x%08X, sharedmem_handle=0x%08X", sharedmem_size, version, sharedmem_handle); } @@ -857,26 +928,79 @@ static void BeginHostingNetwork(Interface* self) { static void DestroyNetwork(Interface* self) { IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x08, 0, 0); - // TODO(Subv): Find out what happens if this is called while - // no network is being hosted. - // Unschedule the beacon broadcast event. CoreTiming::UnscheduleEvent(beacon_broadcast_event, 0); - { - std::lock_guard lock(connection_status_mutex); - - // TODO(Subv): Check if connection_status is indeed reset after this call. - connection_status = {}; - connection_status.status = static_cast(NetworkStatus::NotConnected); + // Only a host can destroy + std::lock_guard lock(connection_status_mutex); + if (connection_status.status != static_cast(NetworkStatus::ConnectedAsHost)) { + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(RESULT_SUCCESS); + LOG_WARNING(Service_NWM, "called with status %u", connection_status.status); + return; } + + // TODO(B3N30): Send 3 Deauth packets + + u16_le tmp_node_id = connection_status.network_node_id; + connection_status = {}; + connection_status.status = static_cast(NetworkStatus::NotConnected); + connection_status.network_node_id = tmp_node_id; connection_status_event->Signal(); IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + // TODO(B3N30): HW test if events get signaled here. + for (auto bind_node : channel_data) { + bind_node.second.event->Signal(); + } + channel_data.clear(); + rb.Push(RESULT_SUCCESS); - LOG_WARNING(Service_NWM, "called"); + LOG_DEBUG(Service_NWM, "called"); +} + +static void DisconnectNetwork(Interface* self) { + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0xA, 0, 0); + IPC::RequestBuilder rb = rp.MakeBuilder(1, 0); + rb.Push(RESULT_SUCCESS); + + using Network::WifiPacket; + WifiPacket deauth; + { + std::lock_guard lock(connection_status_mutex); + if (connection_status.status == static_cast(NetworkStatus::ConnectedAsHost)) { + // A real 3ds makes strange things here. We do the same + u16_le tmp_node_id = connection_status.network_node_id; + connection_status = {}; + connection_status.status = static_cast(NetworkStatus::ConnectedAsHost); + connection_status.network_node_id = tmp_node_id; + LOG_DEBUG(Service_NWM, "called as a host"); + return; + } + u16_le tmp_node_id = connection_status.network_node_id; + connection_status = {}; + connection_status.status = static_cast(NetworkStatus::NotConnected); + connection_status.network_node_id = tmp_node_id; + connection_status_event->Signal(); + + deauth.channel = network_channel; + // TODO(B3N30): Add disconnect reason + deauth.data = {}; + deauth.destination_address = network_info.host_mac_address; + deauth.type = WifiPacket::PacketType::Deauthentication; + } + + SendPacket(deauth); + + // TODO(B3N30): Check on HW if Shutdown signals those events + for (auto bind_node : channel_data) { + bind_node.second.event->Signal(); + } + channel_data.clear(); + + LOG_DEBUG(Service_NWM, "called"); } /** @@ -1075,6 +1199,55 @@ static void GetChannel(Interface* self) { LOG_DEBUG(Service_NWM, "called"); } +/** + * NWM_UDS::ConnectToNetwork service function. + * This connects to the specified network + * Inputs: + * 0 : Command header + * 1 : Connection type: 0x1 = Client, 0x2 = Spectator. + * 2 : Passphrase buffer size + * 3 : (NetworkStructSize<<12) | 0x402 + * 4 : Network struct buffer ptr + * 5 : (PassphraseSize<<12) | 2 + * 6 : Input passphrase buffer ptr + * Outputs: + * 0 : Return header + * 1 : Result of function, 0 on success, otherwise error code + */ +static void ConnectToNetwork(Interface* self) { + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x1E, 2, 4); + + u8 connection_type = rp.Pop(); + u32 passphrase_size = rp.Pop(); + + size_t desc_size; + const VAddr network_struct_addr = rp.PopStaticBuffer(&desc_size); + ASSERT(desc_size == sizeof(NetworkInfo)); + + size_t passphrase_desc_size; + const VAddr passphrase_addr = rp.PopStaticBuffer(&passphrase_desc_size); + + Memory::ReadBlock(network_struct_addr, &network_info, sizeof(network_info)); + + // Start the connection sequence + StartConnectionSequence(network_info.host_mac_address); + + connection_token = + SleepClientThread("uds::ConnectToNetwork", [](Kernel::SharedPtr thread) { + VAddr address = thread->GetCommandBufferAddress(); + std::array buffer; + IPC::RequestBuilder rb(buffer.data(), 0x1E, 1, 0); + // TODO(B3N30): Add error handling for host full and timeout + rb.Push(RESULT_SUCCESS); + Memory::WriteBlock(address, &*thread->owner_process, *buffer.data()); + + LOG_DEBUG(Service_NWM, "connection sequence finished"); + }); + + // TODO(B3N30): Add a timout for the connection sequence + LOG_DEBUG(Service_NWM, "called"); +} + /** * NWM_UDS::SetApplicationData service function. * Updates the application data that is being broadcast in the beacon frames @@ -1088,7 +1261,7 @@ static void GetChannel(Interface* self) { * 2 : Channel of the current WiFi network connection. */ static void SetApplicationData(Interface* self) { - IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x1A, 1, 2); + IPC::RequestParser rp(Kernel::GetCommandBuffer(), 0x10, 1, 2); u32 size = rp.Pop(); @@ -1147,8 +1320,8 @@ static void DecryptBeaconData(Interface* self) { // This size is hardcoded in the 3DS UDS code. ASSERT(output_buffer_size == sizeof(NodeInfo) * UDSMaxNodes); - LOG_WARNING(Service_NWM, "called in0=%08X in1=%08X out=%08X", encrypted_data0_addr, - encrypted_data1_addr, output_buffer_addr); + LOG_DEBUG(Service_NWM, "called in0=%08X in1=%08X out=%08X", encrypted_data0_addr, + encrypted_data1_addr, output_buffer_addr); NetworkInfo net_info; Memory::ReadBlock(network_struct_addr, &net_info, sizeof(net_info)); @@ -1158,15 +1331,10 @@ static void DecryptBeaconData(Interface* self) { std::array oui; Memory::ReadBlock(encrypted_data0_addr, oui.data(), oui.size()); ASSERT_MSG(oui == NintendoOUI, "Unexpected OUI"); - Memory::ReadBlock(encrypted_data1_addr, oui.data(), oui.size()); - ASSERT_MSG(oui == NintendoOUI, "Unexpected OUI"); ASSERT_MSG(Memory::Read8(encrypted_data0_addr + 3) == static_cast(NintendoTagId::EncryptedData0), "Unexpected tag id"); - ASSERT_MSG(Memory::Read8(encrypted_data1_addr + 3) == - static_cast(NintendoTagId::EncryptedData1), - "Unexpected tag id"); std::vector beacon_data(data0_size + data1_size); Memory::ReadBlock(encrypted_data0_addr + 4, beacon_data.data(), data0_size); @@ -1230,26 +1398,6 @@ static void BeaconBroadcastCallback(u64 userdata, int cycles_late) { beacon_broadcast_event, 0); } -/* - * Called when a client connects to an UDS network we're hosting, - * updates the connection status and signals the update event. - * @param network_node_id Network Node Id of the connecting client. - */ -void OnClientConnected(u16 network_node_id) { - std::lock_guard lock(connection_status_mutex); - ASSERT_MSG(connection_status.status == static_cast(NetworkStatus::ConnectedAsHost), - "Can not accept clients if we're not hosting a network"); - ASSERT_MSG(connection_status.total_nodes < connection_status.max_nodes, - "Can not accept connections on a full network"); - - u32 node_id = GetNextAvailableNodeId(); - connection_status.node_bitmask |= 1 << node_id; - connection_status.changed_nodes |= 1 << node_id; - connection_status.nodes[node_id] = network_node_id; - connection_status.total_nodes++; - connection_status_event->Signal(); -} - const Interface::FunctionInfo FunctionTable[] = { {0x000102C2, nullptr, "Initialize (deprecated)"}, {0x00020000, nullptr, "Scrap"}, @@ -1260,7 +1408,7 @@ const Interface::FunctionInfo FunctionTable[] = { {0x00070080, nullptr, "UpdateNetworkAttribute"}, {0x00080000, DestroyNetwork, "DestroyNetwork"}, {0x00090442, nullptr, "ConnectNetwork (deprecated)"}, - {0x000A0000, nullptr, "DisconnectNetwork"}, + {0x000A0000, DisconnectNetwork, "DisconnectNetwork"}, {0x000B0000, GetConnectionStatus, "GetConnectionStatus"}, {0x000D0040, GetNodeInformation, "GetNodeInformation"}, {0x000E0006, nullptr, "DecryptBeaconData (deprecated)"}, @@ -1275,7 +1423,7 @@ const Interface::FunctionInfo FunctionTable[] = { {0x001A0000, GetChannel, "GetChannel"}, {0x001B0302, InitializeWithVersion, "InitializeWithVersion"}, {0x001D0044, BeginHostingNetwork, "BeginHostingNetwork"}, - {0x001E0084, nullptr, "ConnectToNetwork"}, + {0x001E0084, ConnectToNetwork, "ConnectToNetwork"}, {0x001F0006, DecryptBeaconData, "DecryptBeaconData"}, {0x00200040, nullptr, "Flush"}, {0x00210080, nullptr, "SetProbeResponseParam"}, @@ -1305,6 +1453,9 @@ NWM_UDS::~NWM_UDS() { connection_status.status = static_cast(NetworkStatus::NotConnected); } + if (auto room_member = Network::GetRoomMember().lock()) + room_member->Unbind(wifi_packet_received); + CoreTiming::UnscheduleEvent(beacon_broadcast_event, 0); }