diff --git a/TODO b/TODO index c4f46e3126..097c310a79 100644 --- a/TODO +++ b/TODO @@ -46,6 +46,7 @@ ✔ Replace SERIALIZE_AS_POD with BOOST_IS_BITWISE_SERIALIZABLE @started(20-01-03 13:47) @done(20-01-03 13:58) @lasted(11m22s) ☐ Review constructor/initialization code ☐ Review core timing events +☐ Review base class serialization everywhere ✔ Fix CI @done(19-12-31 21:32) ✔ HW @done(19-08-13 15:41) ✔ GPU regs @done(19-08-13 15:41) @@ -87,7 +88,7 @@ ✔ Shared page @done(20-01-04 21:09) ✔ SVC @done(19-12-22 21:32) Nothing to do - all data is constant - ☐ Thread @started(19-08-13 16:45) + ✔ Thread @started(19-08-13 16:45) @done(20-01-06 20:01) @lasted(20w6d4h16m22s) This requires refactoring wakeup_callback to be an object ref ✔ Timer @done(19-08-13 16:45) ✔ VM Manager @started(19-08-13 16:46) @done(20-01-04 21:09) @lasted(20w4d5h23m42s) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index ea117d4409..7ab54242d5 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -96,6 +96,7 @@ add_library(common STATIC serialization/atomic.h serialization/boost_discrete_interval.hpp serialization/boost_flat_set.h + serialization/boost_small_vector.hpp serialization/boost_vector.hpp string_util.cpp string_util.h diff --git a/src/common/serialization/boost_small_vector.hpp b/src/common/serialization/boost_small_vector.hpp new file mode 100644 index 0000000000..b4e07a8962 --- /dev/null +++ b/src/common/serialization/boost_small_vector.hpp @@ -0,0 +1,144 @@ +#ifndef BOOST_SERIALIZATION_BOOST_SMALL_VECTOR_HPP +#define BOOST_SERIALIZATION_BOOST_SMALL_VECTOR_HPP + +// MS compatible compilers support #pragma once +#if defined(_MSC_VER) +#pragma once +#endif + +/////////1/////////2/////////3/////////4/////////5/////////6/////////7/////////8 +// boost_vector.hpp: serialization for boost vector templates + +// (C) Copyright 2002 Robert Ramey - http://www.rrsd.com . +// fast array serialization (C) Copyright 2005 Matthias Troyer +// Use, modification and distribution is subject to the Boost Software +// License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +// See http://www.boost.org for updates, documentation, and revision history. + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// default is being compatible with version 1.34.1 files, not 1.35 files +#ifndef BOOST_SERIALIZATION_VECTOR_VERSIONED +#define BOOST_SERIALIZATION_VECTOR_VERSIONED(V) (V == 4 || V == 5) +#endif + +namespace boost { +namespace serialization { + +/////////1/////////2/////////3/////////4/////////5/////////6/////////7/////////8 +// vector< T > + +// the default versions + +template +inline void save(Archive& ar, const boost::container::small_vector& t, + const unsigned int /* file_version */, mpl::false_) { + boost::serialization::stl::save_collection>(ar, + t); +} + +template +inline void load(Archive& ar, boost::container::small_vector& t, + const unsigned int /* file_version */, mpl::false_) { + const boost::archive::library_version_type library_version(ar.get_library_version()); + // retrieve number of elements + item_version_type item_version(0); + collection_size_type count; + ar >> BOOST_SERIALIZATION_NVP(count); + if (boost::archive::library_version_type(3) < library_version) { + ar >> BOOST_SERIALIZATION_NVP(item_version); + } + t.reserve(count); + stl::collection_load_impl(ar, t, count, item_version); +} + +// the optimized versions + +template +inline void save(Archive& ar, const boost::container::small_vector& t, + const unsigned int /* file_version */, mpl::true_) { + const collection_size_type count(t.size()); + ar << BOOST_SERIALIZATION_NVP(count); + if (!t.empty()) + // explict template arguments to pass intel C++ compiler + ar << serialization::make_array(static_cast(&t[0]), + count); +} + +template +inline void load(Archive& ar, boost::container::small_vector& t, + const unsigned int /* file_version */, mpl::true_) { + collection_size_type count(t.size()); + ar >> BOOST_SERIALIZATION_NVP(count); + t.resize(count); + unsigned int item_version = 0; + if (BOOST_SERIALIZATION_VECTOR_VERSIONED(ar.get_library_version())) { + ar >> BOOST_SERIALIZATION_NVP(item_version); + } + if (!t.empty()) + // explict template arguments to pass intel C++ compiler + ar >> serialization::make_array(static_cast(&t[0]), count); +} + +// dispatch to either default or optimized versions + +template +inline void save(Archive& ar, const boost::container::small_vector& t, + const unsigned int file_version) { + typedef typename boost::serialization::use_array_optimization::template apply< + typename remove_const::type>::type use_optimized; + save(ar, t, file_version, use_optimized()); +} + +template +inline void load(Archive& ar, boost::container::small_vector& t, + const unsigned int file_version) { +#ifdef BOOST_SERIALIZATION_VECTOR_135_HPP + if (ar.get_library_version() == boost::archive::library_version_type(5)) { + load(ar, t, file_version, boost::is_arithmetic()); + return; + } +#endif + typedef typename boost::serialization::use_array_optimization::template apply< + typename remove_const::type>::type use_optimized; + load(ar, t, file_version, use_optimized()); +} + +// split non-intrusive serialization function member into separate +// non intrusive save/load member functions +template +inline void serialize(Archive& ar, boost::container::small_vector& t, + const unsigned int file_version) { + boost::serialization::split_free(ar, t, file_version); +} + +// split non-intrusive serialization function member into separate +// non intrusive save/load member functions +template +inline void serialize(Archive& ar, boost::container::small_vector& t, + const unsigned int file_version) { + boost::serialization::split_free(ar, t, file_version); +} + +} // namespace serialization +} // namespace boost + +#endif // BOOST_SERIALIZATION_BOOST_SMALL_VECTOR_HPP diff --git a/src/common/serialization/boost_vector.hpp b/src/common/serialization/boost_vector.hpp index d97ebd2085..55a5b9eaee 100644 --- a/src/common/serialization/boost_vector.hpp +++ b/src/common/serialization/boost_vector.hpp @@ -1,9 +1,9 @@ -#ifndef BOOST_SERIALIZATION_BOOST_VECTOR_HPP +#ifndef BOOST_SERIALIZATION_BOOST_VECTOR_HPP #define BOOST_SERIALIZATION_BOOST_VECTOR_HPP // MS compatible compilers support #pragma once #if defined(_MSC_VER) -# pragma once +#pragma once #endif /////////1/////////2/////////3/////////4/////////5/////////6/////////7/////////8 @@ -24,20 +24,20 @@ #include #include -#include #include #include +#include -#include -#include -#include -#include #include #include +#include +#include +#include +#include // default is being compatible with version 1.34.1 files, not 1.35 files #ifndef BOOST_SERIALIZATION_VECTOR_VERSIONED -#define BOOST_SERIALIZATION_VECTOR_VERSIONED(V) (V==4 || V==5) +#define BOOST_SERIALIZATION_VECTOR_VERSIONED(V) (V == 4 || V == 5) #endif namespace boost { @@ -48,33 +48,23 @@ namespace serialization { // the default versions -template -inline void save( - Archive & ar, - const boost::container::vector &t, - const unsigned int /* file_version */, - mpl::false_ -){ - boost::serialization::stl::save_collection >( - ar, t - ); +template +inline void save(Archive& ar, const boost::container::vector& t, + const unsigned int /* file_version */, mpl::false_) { + boost::serialization::stl::save_collection>(ar, + t); } -template -inline void load( - Archive & ar, - boost::container::vector &t, - const unsigned int /* file_version */, - mpl::false_ -){ - const boost::archive::library_version_type library_version( - ar.get_library_version() - ); +template +inline void load(Archive& ar, boost::container::vector& t, + const unsigned int /* file_version */, mpl::false_) { + const boost::archive::library_version_type library_version(ar.get_library_version()); // retrieve number of elements item_version_type item_version(0); collection_size_type count; ar >> BOOST_SERIALIZATION_NVP(count); - if(boost::archive::library_version_type(3) < library_version){ + if (boost::archive::library_version_type(3) < library_version) { ar >> BOOST_SERIALIZATION_NVP(item_version); } t.reserve(count); @@ -83,107 +73,77 @@ inline void load( // the optimized versions -template -inline void save( - Archive & ar, - const boost::container::vector &t, - const unsigned int /* file_version */, - mpl::true_ -){ +template +inline void save(Archive& ar, const boost::container::vector& t, + const unsigned int /* file_version */, mpl::true_) { const collection_size_type count(t.size()); ar << BOOST_SERIALIZATION_NVP(count); if (!t.empty()) // explict template arguments to pass intel C++ compiler - ar << serialization::make_array( - static_cast(&t[0]), - count - ); + ar << serialization::make_array(static_cast(&t[0]), + count); } -template -inline void load( - Archive & ar, - boost::container::vector &t, - const unsigned int /* file_version */, - mpl::true_ -){ +template +inline void load(Archive& ar, boost::container::vector& t, + const unsigned int /* file_version */, mpl::true_) { collection_size_type count(t.size()); ar >> BOOST_SERIALIZATION_NVP(count); t.resize(count); - unsigned int item_version=0; - if(BOOST_SERIALIZATION_VECTOR_VERSIONED(ar.get_library_version())) { + unsigned int item_version = 0; + if (BOOST_SERIALIZATION_VECTOR_VERSIONED(ar.get_library_version())) { ar >> BOOST_SERIALIZATION_NVP(item_version); } if (!t.empty()) // explict template arguments to pass intel C++ compiler - ar >> serialization::make_array( - static_cast(&t[0]), - count - ); - } + ar >> serialization::make_array(static_cast(&t[0]), count); +} // dispatch to either default or optimized versions -template -inline void save( - Archive & ar, - const boost::container::vector &t, - const unsigned int file_version -){ - typedef typename - boost::serialization::use_array_optimization::template apply< - typename remove_const::type - >::type use_optimized; - save(ar,t,file_version, use_optimized()); +template +inline void save(Archive& ar, const boost::container::vector& t, + const unsigned int file_version) { + typedef typename boost::serialization::use_array_optimization::template apply< + typename remove_const::type>::type use_optimized; + save(ar, t, file_version, use_optimized()); } -template -inline void load( - Archive & ar, - boost::container::vector &t, - const unsigned int file_version -){ +template +inline void load(Archive& ar, boost::container::vector& t, + const unsigned int file_version) { #ifdef BOOST_SERIALIZATION_VECTOR_135_HPP - if (ar.get_library_version()==boost::archive::library_version_type(5)) - { - load(ar,t,file_version, boost::is_arithmetic()); - return; + if (ar.get_library_version() == boost::archive::library_version_type(5)) { + load(ar, t, file_version, boost::is_arithmetic()); + return; } #endif - typedef typename - boost::serialization::use_array_optimization::template apply< - typename remove_const::type - >::type use_optimized; - load(ar,t,file_version, use_optimized()); + typedef typename boost::serialization::use_array_optimization::template apply< + typename remove_const::type>::type use_optimized; + load(ar, t, file_version, use_optimized()); } // split non-intrusive serialization function member into separate // non intrusive save/load member functions -template -inline void serialize( - Archive & ar, - boost::container::vector & t, - const unsigned int file_version -){ +template +inline void serialize(Archive& ar, boost::container::vector& t, + const unsigned int file_version) { boost::serialization::split_free(ar, t, file_version); } // split non-intrusive serialization function member into separate // non intrusive save/load member functions -template -inline void serialize( - Archive & ar, - boost::container::vector & t, - const unsigned int file_version -){ +template +inline void serialize(Archive& ar, boost::container::vector& t, + const unsigned int file_version) { boost::serialization::split_free(ar, t, file_version); } -} // serialization +} // namespace serialization } // namespace boost #include BOOST_SERIALIZATION_COLLECTION_TRAITS(boost::container::vector) -#endif // BOOST_SERIALIZATION_VECTOR_HPP +#endif // BOOST_SERIALIZATION_BOOST_VECTOR_HPP diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index 7b8329c4d4..5074e352fc 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -80,16 +80,18 @@ std::shared_ptr KernelSystem::CreateAddressArbiter(std::string n return address_arbiter; } +void AddressArbiter::WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) { + ASSERT(reason == ThreadWakeupReason::Timeout); + // Remove the newly-awakened thread from the Arbiter's waiting list. + waiting_threads.erase(std::remove(waiting_threads.begin(), waiting_threads.end(), thread), + waiting_threads.end()); +}; + ResultCode AddressArbiter::ArbitrateAddress(std::shared_ptr thread, ArbitrationType type, VAddr address, s32 value, u64 nanoseconds) { - auto timeout_callback = [this](ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object) { - ASSERT(reason == ThreadWakeupReason::Timeout); - // Remove the newly-awakened thread from the Arbiter's waiting list. - waiting_threads.erase(std::remove(waiting_threads.begin(), waiting_threads.end(), thread), - waiting_threads.end()); - }; + auto timeout_callback = std::dynamic_pointer_cast(shared_from_this()); switch (type) { diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index 19b80315f8..85a2a065a5 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -12,6 +12,7 @@ #include #include "common/common_types.h" #include "core/hle/kernel/object.h" +#include "core/hle/kernel/thread.h" #include "core/hle/result.h" // Address arbiters are an underlying kernel synchronization object that can be created/used via @@ -34,7 +35,7 @@ enum class ArbitrationType : u32 { DecrementAndWaitIfLessThanWithTimeout, }; -class AddressArbiter final : public Object { +class AddressArbiter final : public Object, public WakeupCallback { public: explicit AddressArbiter(KernelSystem& kernel); ~AddressArbiter() override; @@ -56,6 +57,9 @@ public: ResultCode ArbitrateAddress(std::shared_ptr thread, ArbitrationType type, VAddr address, s32 value, u64 nanoseconds); + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object); + private: KernelSystem& kernel; diff --git a/src/core/hle/kernel/hle_ipc.cpp b/src/core/hle/kernel/hle_ipc.cpp index ab4ecfd059..b1e2f7d8b0 100644 --- a/src/core/hle/kernel/hle_ipc.cpp +++ b/src/core/hle/kernel/hle_ipc.cpp @@ -16,6 +16,46 @@ namespace Kernel { +class HLERequestContext::ThreadCallback : public Kernel::WakeupCallback { + +public: + ThreadCallback(std::shared_ptr context_, + std::shared_ptr callback_) + : context(context_), callback(callback_) {} + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) { + ASSERT(thread->status == ThreadStatus::WaitHleEvent); + if (callback) { + callback->WakeUp(thread, *context, reason); + } + + auto& process = thread->owner_process; + // We must copy the entire command buffer *plus* the entire static buffers area, since + // the translation might need to read from it in order to retrieve the StaticBuffer + // target addresses. + std::array cmd_buff; + Memory::MemorySystem& memory = context->kernel.memory; + memory.ReadBlock(*process, thread->GetCommandBufferAddress(), cmd_buff.data(), + cmd_buff.size() * sizeof(u32)); + context->WriteToOutgoingCommandBuffer(cmd_buff.data(), *process); + // Copy the translated command buffer back into the thread's command buffer area. + memory.WriteBlock(*process, thread->GetCommandBufferAddress(), cmd_buff.data(), + cmd_buff.size() * sizeof(u32)); + } + +private: + ThreadCallback() = default; + std::shared_ptr callback{}; + std::shared_ptr context{}; + + template + void serialize(Archive& ar, const unsigned int) { + ar& callback; + ar& context; + } + friend class boost::serialization::access; +}; + SessionRequestHandler::SessionInfo::SessionInfo(std::shared_ptr session, std::unique_ptr data) : session(std::move(session)), data(std::move(data)) {} @@ -33,34 +73,16 @@ void SessionRequestHandler::ClientDisconnected(std::shared_ptr se connected_sessions.end()); } -std::shared_ptr HLERequestContext::SleepClientThread(const std::string& reason, - std::chrono::nanoseconds timeout, - WakeupCallback&& callback) { +std::shared_ptr HLERequestContext::SleepClientThread( + const std::string& reason, std::chrono::nanoseconds timeout, + std::shared_ptr callback) { // Put the client thread to sleep until the wait event is signaled or the timeout expires. - thread->wakeup_callback = [context = *this, - callback](ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object) mutable { - ASSERT(thread->status == ThreadStatus::WaitHleEvent); - callback(thread, context, reason); - - auto& process = thread->owner_process; - // We must copy the entire command buffer *plus* the entire static buffers area, since - // the translation might need to read from it in order to retrieve the StaticBuffer - // target addresses. - std::array cmd_buff; - Memory::MemorySystem& memory = context.kernel.memory; - memory.ReadBlock(*process, thread->GetCommandBufferAddress(), cmd_buff.data(), - cmd_buff.size() * sizeof(u32)); - context.WriteToOutgoingCommandBuffer(cmd_buff.data(), *process); - // Copy the translated command buffer back into the thread's command buffer area. - memory.WriteBlock(*process, thread->GetCommandBufferAddress(), cmd_buff.data(), - cmd_buff.size() * sizeof(u32)); - }; + thread->wakeup_callback = std::make_shared(shared_from_this(), callback); auto event = kernel.CreateEvent(Kernel::ResetType::OneShot, "HLE Pause Event: " + reason); thread->status = ThreadStatus::WaitHleEvent; thread->wait_objects = {event}; - event->AddWaitingThread(SharedFrom(thread)); + event->AddWaitingThread(thread); if (timeout.count() > 0) thread->WakeAfterDelay(timeout.count()); @@ -68,8 +90,10 @@ std::shared_ptr HLERequestContext::SleepClientThread(const std::string& r return event; } +HLERequestContext::HLERequestContext() : kernel(Core::Global()) {} + HLERequestContext::HLERequestContext(KernelSystem& kernel, std::shared_ptr session, - Thread* thread) + std::shared_ptr thread) : kernel(kernel), session(std::move(session)), thread(thread) { cmd_buf[0] = 0; } @@ -98,8 +122,9 @@ void HLERequestContext::AddStaticBuffer(u8 buffer_id, std::vector data) { static_buffers[buffer_id] = std::move(data); } -ResultCode HLERequestContext::PopulateFromIncomingCommandBuffer(const u32_le* src_cmdbuf, - Process& src_process) { +ResultCode HLERequestContext::PopulateFromIncomingCommandBuffer( + const u32_le* src_cmdbuf, std::shared_ptr src_process_) { + auto& src_process = *src_process_; IPC::Header header{src_cmdbuf[0]}; std::size_t untranslated_size = 1u + header.normal_params_size; @@ -158,7 +183,7 @@ ResultCode HLERequestContext::PopulateFromIncomingCommandBuffer(const u32_le* sr } case IPC::DescriptorType::MappedBuffer: { u32 next_id = static_cast(request_mapped_buffers.size()); - request_mapped_buffers.emplace_back(kernel.memory, src_process, descriptor, + request_mapped_buffers.emplace_back(kernel.memory, src_process_, descriptor, src_cmdbuf[i], next_id); cmd_buf[i++] = next_id; break; @@ -170,7 +195,7 @@ ResultCode HLERequestContext::PopulateFromIncomingCommandBuffer(const u32_le* sr if (should_record) { std::vector translated_cmdbuf{cmd_buf.begin(), cmd_buf.begin() + command_size}; - kernel.GetIPCRecorder().SetRequestInfo(SharedFrom(thread), std::move(untranslated_cmdbuf), + kernel.GetIPCRecorder().SetRequestInfo(thread, std::move(untranslated_cmdbuf), std::move(translated_cmdbuf)); } @@ -248,7 +273,7 @@ ResultCode HLERequestContext::WriteToOutgoingCommandBuffer(u32_le* dst_cmdbuf, if (should_record) { std::vector translated_cmdbuf{dst_cmdbuf, dst_cmdbuf + command_size}; - kernel.GetIPCRecorder().SetReplyInfo(SharedFrom(thread), std::move(untranslated_cmdbuf), + kernel.GetIPCRecorder().SetReplyInfo(thread, std::move(untranslated_cmdbuf), std::move(translated_cmdbuf)); } @@ -262,13 +287,15 @@ MappedBuffer& HLERequestContext::GetMappedBuffer(u32 id_from_cmdbuf) { void HLERequestContext::ReportUnimplemented() const { if (kernel.GetIPCRecorder().IsEnabled()) { - kernel.GetIPCRecorder().SetHLEUnimplemented(SharedFrom(thread)); + kernel.GetIPCRecorder().SetHLEUnimplemented(thread); } } -MappedBuffer::MappedBuffer(Memory::MemorySystem& memory, const Process& process, u32 descriptor, - VAddr address, u32 id) - : memory(&memory), id(id), address(address), process(&process) { +MappedBuffer::MappedBuffer() : memory(&Core::Global().Memory()) {} + +MappedBuffer::MappedBuffer(Memory::MemorySystem& memory, std::shared_ptr process, + u32 descriptor, VAddr address, u32 id) + : memory(&memory), id(id), address(address), process(process) { IPC::MappedBufferDescInfo desc{descriptor}; size = desc.size; perms = desc.perms; @@ -287,3 +314,5 @@ void MappedBuffer::Write(const void* src_buffer, std::size_t offset, std::size_t } } // namespace Kernel + +SERIALIZE_EXPORT_IMPL(Kernel::HLERequestContext::ThreadCallback) diff --git a/src/core/hle/kernel/hle_ipc.h b/src/core/hle/kernel/hle_ipc.h index 2177b733ea..56c6d8ce1c 100644 --- a/src/core/hle/kernel/hle_ipc.h +++ b/src/core/hle/kernel/hle_ipc.h @@ -16,6 +16,7 @@ #include #include #include "common/common_types.h" +#include "common/serialization/boost_small_vector.hpp" #include "common/swap.h" #include "core/hle/ipc.h" #include "core/hle/kernel/object.h" @@ -127,7 +128,7 @@ private: class MappedBuffer { public: - MappedBuffer(Memory::MemorySystem& memory, const Process& process, u32 descriptor, + MappedBuffer(Memory::MemorySystem& memory, std::shared_ptr process, u32 descriptor, VAddr address, u32 id); // interface for service @@ -151,9 +152,21 @@ private: Memory::MemorySystem* memory; u32 id; VAddr address; - const Process* process; - std::size_t size; + std::shared_ptr process; + u32 size; IPC::MappedBufferPermissions perms; + + MappedBuffer(); + + template + void serialize(Archive& ar, const unsigned int) { + ar& id; + ar& address; + ar& process; + ar& size; + ar& perms; + } + friend class boost::serialization::access; }; /** @@ -185,9 +198,10 @@ private: * id of the memory interface and let kernel convert it back to client vaddr. No real unmapping is * needed in this case, though. */ -class HLERequestContext { +class HLERequestContext : std::enable_shared_from_this { public: - HLERequestContext(KernelSystem& kernel, std::shared_ptr session, Thread* thread); + HLERequestContext(KernelSystem& kernel, std::shared_ptr session, + std::shared_ptr thread); ~HLERequestContext(); /// Returns a pointer to the IPC command buffer for this request. @@ -203,8 +217,12 @@ public: return session; } - using WakeupCallback = std::function thread, HLERequestContext& context, ThreadWakeupReason reason)>; + class WakeupCallback { + public: + virtual ~WakeupCallback() = default; + virtual void WakeUp(std::shared_ptr thread, HLERequestContext& context, + ThreadWakeupReason reason) = 0; + }; /** * Puts the specified guest thread to sleep until the returned event is signaled or until the @@ -219,7 +237,7 @@ public: */ std::shared_ptr SleepClientThread(const std::string& reason, std::chrono::nanoseconds timeout, - WakeupCallback&& callback); + std::shared_ptr callback); /** * Resolves a object id from the request command buffer into a pointer to an object. See the @@ -259,26 +277,43 @@ public: MappedBuffer& GetMappedBuffer(u32 id_from_cmdbuf); /// Populates this context with data from the requesting process/thread. - ResultCode PopulateFromIncomingCommandBuffer(const u32_le* src_cmdbuf, Process& src_process); + ResultCode PopulateFromIncomingCommandBuffer(const u32_le* src_cmdbuf, + std::shared_ptr src_process); /// Writes data from this context back to the requesting process/thread. ResultCode WriteToOutgoingCommandBuffer(u32_le* dst_cmdbuf, Process& dst_process) const; /// Reports an unimplemented function. void ReportUnimplemented() const; + class ThreadCallback; + friend class ThreadCallback; + private: KernelSystem& kernel; std::array cmd_buf; std::shared_ptr session; - Thread* thread; + std::shared_ptr thread; // TODO(yuriks): Check common usage of this and optimize size accordingly boost::container::small_vector, 8> request_handles; // The static buffers will be created when the IPC request is translated. std::array, IPC::MAX_STATIC_BUFFERS> static_buffers; // The mapped buffers will be created when the IPC request is translated boost::container::small_vector request_mapped_buffers; + + HLERequestContext(); + template + void serialize(Archive& ar, const unsigned int) { + ar& cmd_buf; + ar& session; + ar& thread; + ar& request_handles; + ar& static_buffers; + ar& request_mapped_buffers; + } + friend class boost::serialization::access; }; } // namespace Kernel BOOST_CLASS_EXPORT_KEY(Kernel::SessionRequestHandler::SessionDataBase) +BOOST_CLASS_EXPORT_KEY(Kernel::HLERequestContext::ThreadCallback) diff --git a/src/core/hle/kernel/ipc.cpp b/src/core/hle/kernel/ipc.cpp index e39732c149..eb14888404 100644 --- a/src/core/hle/kernel/ipc.cpp +++ b/src/core/hle/kernel/ipc.cpp @@ -72,7 +72,7 @@ ResultCode TranslateCommandBuffer(Kernel::KernelSystem& kernel, Memory::MemorySy if (handle == CurrentThread) { object = src_thread; } else if (handle == CurrentProcess) { - object = SharedFrom(src_process); + object = src_process; } else if (handle != 0) { object = src_process->handle_table.GetGeneric(handle); if (descriptor == IPC::DescriptorType::MoveHandle) { diff --git a/src/core/hle/kernel/ipc_debugger/recorder.cpp b/src/core/hle/kernel/ipc_debugger/recorder.cpp index 968815c5bd..f1e4a09f16 100644 --- a/src/core/hle/kernel/ipc_debugger/recorder.cpp +++ b/src/core/hle/kernel/ipc_debugger/recorder.cpp @@ -52,7 +52,7 @@ void Recorder::RegisterRequest(const std::shared_ptr& cli RequestRecord record = {/* id */ ++record_count, /* status */ RequestStatus::Sent, - /* client_process */ GetObjectInfo(client_thread->owner_process), + /* client_process */ GetObjectInfo(client_thread->owner_process.get()), /* client_thread */ GetObjectInfo(client_thread.get()), /* client_session */ GetObjectInfo(client_session.get()), /* client_port */ GetObjectInfo(client_session->parent->port.get()), @@ -82,7 +82,7 @@ void Recorder::SetRequestInfo(const std::shared_ptr& client_thre record.translated_request_cmdbuf = std::move(translated_cmdbuf); if (server_thread) { - record.server_process = GetObjectInfo(server_thread->owner_process); + record.server_process = GetObjectInfo(server_thread->owner_process.get()); record.server_thread = GetObjectInfo(server_thread.get()); } else { record.is_hle = true; diff --git a/src/core/hle/kernel/kernel.h b/src/core/hle/kernel/kernel.h index c07bfdbcaf..e7b7314d62 100644 --- a/src/core/hle/kernel/kernel.h +++ b/src/core/hle/kernel/kernel.h @@ -134,7 +134,7 @@ public: */ ResultVal> CreateThread(std::string name, VAddr entry_point, u32 priority, u32 arg, s32 processor_id, - VAddr stack_top, Process& owner_process); + VAddr stack_top, std::shared_ptr owner_process); /** * Creates a semaphore. diff --git a/src/core/hle/kernel/server_session.cpp b/src/core/hle/kernel/server_session.cpp index 4b393b63d9..8bb82fc839 100644 --- a/src/core/hle/kernel/server_session.cpp +++ b/src/core/hle/kernel/server_session.cpp @@ -71,12 +71,12 @@ ResultCode ServerSession::HandleSyncRequest(std::shared_ptr thread) { // If this ServerSession has an associated HLE handler, forward the request to it. if (hle_handler != nullptr) { std::array cmd_buf; - Kernel::Process* current_process = thread->owner_process; + auto current_process = thread->owner_process; kernel.memory.ReadBlock(*current_process, thread->GetCommandBufferAddress(), cmd_buf.data(), cmd_buf.size() * sizeof(u32)); - Kernel::HLERequestContext context(kernel, SharedFrom(this), thread.get()); - context.PopulateFromIncomingCommandBuffer(cmd_buf.data(), *current_process); + Kernel::HLERequestContext context(kernel, SharedFrom(this), thread); + context.PopulateFromIncomingCommandBuffer(cmd_buf.data(), current_process); hle_handler->HandleSyncRequest(context); diff --git a/src/core/hle/kernel/svc.cpp b/src/core/hle/kernel/svc.cpp index b5ebaf936e..7f7cc12725 100644 --- a/src/core/hle/kernel/svc.cpp +++ b/src/core/hle/kernel/svc.cpp @@ -282,7 +282,7 @@ void SVC::ExitProcess() { // Stop all the process threads that are currently waiting for objects. auto& thread_list = kernel.GetThreadManager().GetThreadList(); for (auto& thread : thread_list) { - if (thread->owner_process != current_process.get()) + if (thread->owner_process != current_process) continue; if (thread.get() == kernel.GetThreadManager().GetCurrentThread()) @@ -403,6 +403,73 @@ ResultCode SVC::CloseHandle(Handle handle) { return kernel.GetCurrentProcess()->handle_table.Close(handle); } +static ResultCode ReceiveIPCRequest(Kernel::KernelSystem& kernel, Memory::MemorySystem& memory, + std::shared_ptr server_session, + std::shared_ptr thread); + +class SVC_SyncCallback : public Kernel::WakeupCallback { +public: + SVC_SyncCallback(bool do_output_) : do_output(do_output_) {} + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) { + + if (reason == ThreadWakeupReason::Timeout) { + thread->SetWaitSynchronizationResult(RESULT_TIMEOUT); + return; + } + + ASSERT(reason == ThreadWakeupReason::Signal); + + thread->SetWaitSynchronizationResult(RESULT_SUCCESS); + + // The wait_all case does not update the output index. + if (do_output) { + thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(object.get())); + } + } + +private: + bool do_output; + + SVC_SyncCallback() = default; + template + void serialize(Archive& ar, const unsigned int) { + ar& do_output; + } + friend class boost::serialization::access; +}; + +class SVC_IPCCallback : public Kernel::WakeupCallback { +public: + SVC_IPCCallback(Core::System& system_) : system(system_) {} + + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) { + + ASSERT(thread->status == ThreadStatus::WaitSynchAny); + ASSERT(reason == ThreadWakeupReason::Signal); + + ResultCode result = RESULT_SUCCESS; + + if (object->GetHandleType() == HandleType::ServerSession) { + auto server_session = DynamicObjectCast(object); + result = ReceiveIPCRequest(system.Kernel(), system.Memory(), server_session, thread); + } + + thread->SetWaitSynchronizationResult(result); + thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(object.get())); + } + +private: + Core::System& system; + + SVC_IPCCallback() : system(Core::Global()) {} + + template + void serialize(Archive& ar, const unsigned int) {} + friend class boost::serialization::access; +}; + /// Wait for a handle to synchronize, timeout after the specified nanoseconds ResultCode SVC::WaitSynchronization1(Handle handle, s64 nano_seconds) { auto object = kernel.GetCurrentProcess()->handle_table.Get(handle); @@ -426,21 +493,7 @@ ResultCode SVC::WaitSynchronization1(Handle handle, s64 nano_seconds) { // Create an event to wake the thread up after the specified nanosecond delay has passed thread->WakeAfterDelay(nano_seconds); - thread->wakeup_callback = [](ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object) { - ASSERT(thread->status == ThreadStatus::WaitSynchAny); - - if (reason == ThreadWakeupReason::Timeout) { - thread->SetWaitSynchronizationResult(RESULT_TIMEOUT); - return; - } - - ASSERT(reason == ThreadWakeupReason::Signal); - thread->SetWaitSynchronizationResult(RESULT_SUCCESS); - - // WaitSynchronization1 doesn't have an output index like WaitSynchronizationN, so we - // don't have to do anything else here. - }; + thread->wakeup_callback = std::make_shared(false); system.PrepareReschedule(); @@ -515,20 +568,7 @@ ResultCode SVC::WaitSynchronizationN(s32* out, VAddr handles_address, s32 handle // Create an event to wake the thread up after the specified nanosecond delay has passed thread->WakeAfterDelay(nano_seconds); - thread->wakeup_callback = [](ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object) { - ASSERT(thread->status == ThreadStatus::WaitSynchAll); - - if (reason == ThreadWakeupReason::Timeout) { - thread->SetWaitSynchronizationResult(RESULT_TIMEOUT); - return; - } - - ASSERT(reason == ThreadWakeupReason::Signal); - - thread->SetWaitSynchronizationResult(RESULT_SUCCESS); - // The wait_all case does not update the output index. - }; + thread->wakeup_callback = std::make_shared(false); system.PrepareReschedule(); @@ -575,20 +615,7 @@ ResultCode SVC::WaitSynchronizationN(s32* out, VAddr handles_address, s32 handle // Create an event to wake the thread up after the specified nanosecond delay has passed thread->WakeAfterDelay(nano_seconds); - thread->wakeup_callback = [](ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object) { - ASSERT(thread->status == ThreadStatus::WaitSynchAny); - - if (reason == ThreadWakeupReason::Timeout) { - thread->SetWaitSynchronizationResult(RESULT_TIMEOUT); - return; - } - - ASSERT(reason == ThreadWakeupReason::Signal); - - thread->SetWaitSynchronizationResult(RESULT_SUCCESS); - thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(object.get())); - }; + thread->wakeup_callback = std::make_shared(true); system.PrepareReschedule(); @@ -730,22 +757,7 @@ ResultCode SVC::ReplyAndReceive(s32* index, VAddr handles_address, s32 handle_co thread->wait_objects = std::move(objects); - thread->wakeup_callback = [& kernel = this->kernel, &memory = this->memory]( - ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object) { - ASSERT(thread->status == ThreadStatus::WaitSynchAny); - ASSERT(reason == ThreadWakeupReason::Signal); - - ResultCode result = RESULT_SUCCESS; - - if (object->GetHandleType() == HandleType::ServerSession) { - auto server_session = DynamicObjectCast(object); - result = ReceiveIPCRequest(kernel, memory, server_session, thread); - } - - thread->SetWaitSynchronizationResult(result); - thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(object.get())); - }; + thread->wakeup_callback = std::make_shared(system); system.PrepareReschedule(); @@ -911,7 +923,7 @@ ResultCode SVC::CreateThread(Handle* out_handle, u32 entry_point, u32 arg, VAddr CASCADE_RESULT(std::shared_ptr thread, kernel.CreateThread(name, entry_point, priority, arg, processor_id, stack_top, - *current_process)); + current_process)); thread->context->SetFpscr(FPSCR_DEFAULT_NAN | FPSCR_FLUSH_TO_ZERO | FPSCR_ROUND_TOZERO); // 0x03C00000 @@ -1020,7 +1032,7 @@ ResultCode SVC::GetProcessIdOfThread(u32* process_id, Handle thread_handle) { if (thread == nullptr) return ERR_INVALID_HANDLE; - const std::shared_ptr process = SharedFrom(thread->owner_process); + const std::shared_ptr process = thread->owner_process; ASSERT_MSG(process != nullptr, "Invalid parent process for thread={:#010X}", thread_handle); @@ -1611,3 +1623,6 @@ void SVCContext::CallSVC(u32 immediate) { } } // namespace Kernel + +SERIALIZE_EXPORT_IMPL(Kernel::SVC_SyncCallback) +SERIALIZE_EXPORT_IMPL(Kernel::SVC_IPCCallback) diff --git a/src/core/hle/kernel/svc.h b/src/core/hle/kernel/svc.h index efddff9e88..4cf2654001 100644 --- a/src/core/hle/kernel/svc.h +++ b/src/core/hle/kernel/svc.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include "common/common_types.h" namespace Core { @@ -25,4 +26,10 @@ private: std::unique_ptr impl; }; +class SVC_SyncCallback; +class SVC_IPCCallback; + } // namespace Kernel + +BOOST_CLASS_EXPORT_KEY(Kernel::SVC_SyncCallback) +BOOST_CLASS_EXPORT_KEY(Kernel::SVC_IPCCallback) diff --git a/src/core/hle/kernel/thread.cpp b/src/core/hle/kernel/thread.cpp index 1af2d4f4d6..13a3017a53 100644 --- a/src/core/hle/kernel/thread.cpp +++ b/src/core/hle/kernel/thread.cpp @@ -48,7 +48,7 @@ void Thread::serialize(Archive& ar, const unsigned int file_version) { ar& wait_objects; ar& wait_address; ar& name; - // TODO: How the hell to do wakeup_callback + ar& wakeup_callback; } SERIALIZE_IMPL(Thread) @@ -138,8 +138,8 @@ void ThreadManager::SwitchContext(Thread* new_thread) { ready_queue.remove(new_thread->current_priority, new_thread); new_thread->status = ThreadStatus::Running; - if (previous_process.get() != current_thread->owner_process) { - kernel.SetCurrentProcess(SharedFrom(current_thread->owner_process)); + if (previous_process != current_thread->owner_process) { + kernel.SetCurrentProcess(current_thread->owner_process); } cpu->LoadContext(new_thread->context); @@ -196,7 +196,7 @@ void ThreadManager::ThreadWakeupCallback(u64 thread_id, s64 cycles_late) { // Invoke the wakeup callback before clearing the wait objects if (thread->wakeup_callback) - thread->wakeup_callback(ThreadWakeupReason::Timeout, thread, nullptr); + thread->wakeup_callback->WakeUp(ThreadWakeupReason::Timeout, thread, nullptr); // Remove the thread from each of its waiting objects' waitlists for (auto& object : thread->wait_objects) @@ -313,10 +313,9 @@ static void ResetThreadContext(const std::unique_ptrSetCpsr(USER32MODE | ((entry_point & 1) << 5)); // Usermode and THUMB mode } -ResultVal> KernelSystem::CreateThread(std::string name, VAddr entry_point, - u32 priority, u32 arg, - s32 processor_id, VAddr stack_top, - Process& owner_process) { +ResultVal> KernelSystem::CreateThread( + std::string name, VAddr entry_point, u32 priority, u32 arg, s32 processor_id, VAddr stack_top, + std::shared_ptr owner_process) { // Check if priority is in ranged. Lowest priority -> highest priority id. if (priority > ThreadPrioLowest) { LOG_ERROR(Kernel_SVC, "Invalid thread priority: {}", priority); @@ -330,7 +329,7 @@ ResultVal> KernelSystem::CreateThread(std::string name, // TODO(yuriks): Other checks, returning 0xD9001BEA - if (!Memory::IsValidVirtualAddress(owner_process, entry_point)) { + if (!Memory::IsValidVirtualAddress(*owner_process, entry_point)) { LOG_ERROR(Kernel_SVC, "(name={}): invalid entry {:08x}", name, entry_point); // TODO: Verify error return ResultCode(ErrorDescription::InvalidAddress, ErrorModule::Kernel, @@ -353,10 +352,10 @@ ResultVal> KernelSystem::CreateThread(std::string name, thread->wait_address = 0; thread->name = std::move(name); thread_manager->wakeup_callback_table[thread->thread_id] = thread.get(); - thread->owner_process = &owner_process; + thread->owner_process = owner_process; // Find the next available TLS index, and mark it as used - auto& tls_slots = owner_process.tls_slots; + auto& tls_slots = owner_process->tls_slots; auto [available_page, available_slot, needs_allocation] = GetFreeThreadLocalSlot(tls_slots); @@ -372,13 +371,13 @@ ResultVal> KernelSystem::CreateThread(std::string name, "Not enough space in region to allocate a new TLS page for thread"); return ERR_OUT_OF_MEMORY; } - owner_process.memory_used += Memory::PAGE_SIZE; + owner_process->memory_used += Memory::PAGE_SIZE; tls_slots.emplace_back(0); // The page is completely available at the start available_page = tls_slots.size() - 1; available_slot = 0; // Use the first slot in the new page - auto& vm_manager = owner_process.vm_manager; + auto& vm_manager = owner_process->vm_manager; // Map the page to the current process' address space. vm_manager.MapBackingMemory(Memory::TLS_AREA_VADDR + available_page * Memory::PAGE_SIZE, @@ -391,7 +390,7 @@ ResultVal> KernelSystem::CreateThread(std::string name, thread->tls_address = Memory::TLS_AREA_VADDR + available_page * Memory::PAGE_SIZE + available_slot * Memory::TLS_ENTRY_SIZE; - memory.ZeroBlock(owner_process, thread->tls_address, Memory::TLS_ENTRY_SIZE); + memory.ZeroBlock(*owner_process, thread->tls_address, Memory::TLS_ENTRY_SIZE); // TODO(peachum): move to ScheduleThread() when scheduler is added so selected core is used // to initialize the context @@ -438,7 +437,7 @@ std::shared_ptr SetupMainThread(KernelSystem& kernel, u32 entry_point, u // Initialize new "main" thread auto thread_res = kernel.CreateThread("main", entry_point, priority, 0, owner_process->ideal_processor, - Memory::HEAP_VADDR_END, *owner_process); + Memory::HEAP_VADDR_END, owner_process); std::shared_ptr thread = std::move(thread_res).Unwrap(); diff --git a/src/core/hle/kernel/thread.h b/src/core/hle/kernel/thread.h index 219db9f2d9..9941c76db5 100644 --- a/src/core/hle/kernel/thread.h +++ b/src/core/hle/kernel/thread.h @@ -59,6 +59,15 @@ enum class ThreadWakeupReason { Timeout // The thread was woken up due to a wait timeout. }; +class Thread; + +class WakeupCallback { +public: + virtual ~WakeupCallback() = default; + virtual void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) = 0; +}; + class ThreadManager { public: explicit ThreadManager(Kernel::KernelSystem& kernel); @@ -300,7 +309,7 @@ public: /// Mutexes that this thread is currently waiting for. boost::container::flat_set> pending_mutexes; - Process* owner_process; ///< Process that owns this thread + std::shared_ptr owner_process; ///< Process that owns this thread /// Objects that the thread is waiting on, in the same order as they were // passed to WaitSynchronization1/N. @@ -310,12 +319,10 @@ public: std::string name; - using WakeupCallback = void(ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object); // Callback that will be invoked when the thread is resumed from a waiting state. If the thread // was waiting via WaitSynchronizationN then the object will be the last object that became // available. In case of a timeout, the object will be nullptr. - std::function wakeup_callback; + std::shared_ptr wakeup_callback; private: ThreadManager& thread_manager; diff --git a/src/core/hle/kernel/wait_object.cpp b/src/core/hle/kernel/wait_object.cpp index 94ae632cde..6fe75b6d82 100644 --- a/src/core/hle/kernel/wait_object.cpp +++ b/src/core/hle/kernel/wait_object.cpp @@ -80,7 +80,7 @@ void WaitObject::WakeupAllWaitingThreads() { // Invoke the wakeup callback before clearing the wait objects if (thread->wakeup_callback) - thread->wakeup_callback(ThreadWakeupReason::Signal, thread, SharedFrom(this)); + thread->wakeup_callback->WakeUp(ThreadWakeupReason::Signal, thread, SharedFrom(this)); for (auto& object : thread->wait_objects) object->RemoveWaitingThread(thread.get()); diff --git a/src/core/hle/service/fs/file.cpp b/src/core/hle/service/fs/file.cpp index 946e27211c..c8727314b8 100644 --- a/src/core/hle/service/fs/file.cpp +++ b/src/core/hle/service/fs/file.cpp @@ -71,12 +71,7 @@ void File::Read(Kernel::HLERequestContext& ctx) { rb.PushMappedBuffer(buffer); std::chrono::nanoseconds read_timeout_ns{backend->GetReadDelayNs(length)}; - ctx.SleepClientThread("file::read", read_timeout_ns, - [](std::shared_ptr /*thread*/, - Kernel::HLERequestContext& /*ctx*/, - Kernel::ThreadWakeupReason /*reason*/) { - // Nothing to do here - }); + ctx.SleepClientThread("file::read", read_timeout_ns, nullptr); } void File::Write(Kernel::HLERequestContext& ctx) { diff --git a/src/core/hle/service/fs/fs_user.cpp b/src/core/hle/service/fs/fs_user.cpp index e027e837fe..010600ee2d 100644 --- a/src/core/hle/service/fs/fs_user.cpp +++ b/src/core/hle/service/fs/fs_user.cpp @@ -76,12 +76,7 @@ void FS_USER::OpenFile(Kernel::HLERequestContext& ctx) { LOG_ERROR(Service_FS, "failed to get a handle for file {}", file_path.DebugStr()); } - ctx.SleepClientThread("fs_user::open", open_timeout_ns, - [](std::shared_ptr /*thread*/, - Kernel::HLERequestContext& /*ctx*/, - Kernel::ThreadWakeupReason /*reason*/) { - // Nothing to do here - }); + ctx.SleepClientThread("fs_user::open", open_timeout_ns, nullptr); } void FS_USER::OpenFileDirectly(Kernel::HLERequestContext& ctx) { @@ -134,12 +129,7 @@ void FS_USER::OpenFileDirectly(Kernel::HLERequestContext& ctx) { file_path.DebugStr(), mode.hex, attributes); } - ctx.SleepClientThread("fs_user::open_directly", open_timeout_ns, - [](std::shared_ptr /*thread*/, - Kernel::HLERequestContext& /*ctx*/, - Kernel::ThreadWakeupReason /*reason*/) { - // Nothing to do here - }); + ctx.SleepClientThread("fs_user::open_directly", open_timeout_ns, nullptr); } void FS_USER::DeleteFile(Kernel::HLERequestContext& ctx) { diff --git a/src/core/hle/service/nwm/nwm_uds.cpp b/src/core/hle/service/nwm/nwm_uds.cpp index dddf55333e..eef8cbd1d3 100644 --- a/src/core/hle/service/nwm/nwm_uds.cpp +++ b/src/core/hle/service/nwm/nwm_uds.cpp @@ -1170,6 +1170,29 @@ void NWM_UDS::GetChannel(Kernel::HLERequestContext& ctx) { LOG_DEBUG(Service_NWM, "called"); } +class NWM_UDS::ThreadCallback : public Kernel::HLERequestContext::WakeupCallback { +public: + ThreadCallback(u16 command_id_) : command_id(command_id_) {} + + void WakeUp(std::shared_ptr thread, Kernel::HLERequestContext& ctx, + Kernel::ThreadWakeupReason reason) { + // TODO(B3N30): Add error handling for host full and timeout + IPC::RequestBuilder rb(ctx, command_id, 1, 0); + rb.Push(RESULT_SUCCESS); + LOG_DEBUG(Service_NWM, "connection sequence finished"); + } + +private: + ThreadCallback() = default; + u16 command_id; + + template + void serialize(Archive& ar, const unsigned int) { + ar& command_id; + } + friend class boost::serialization::access; +}; + void NWM_UDS::ConnectToNetwork(Kernel::HLERequestContext& ctx, u16 command_id, const u8* network_info_buffer, std::size_t network_info_size, u8 connection_type, std::vector passphrase) { @@ -1183,15 +1206,8 @@ void NWM_UDS::ConnectToNetwork(Kernel::HLERequestContext& ctx, u16 command_id, // Since this timing is handled by core_timing it could differ from the 'real world' time static constexpr std::chrono::nanoseconds UDSConnectionTimeout{300000000}; - connection_event = ctx.SleepClientThread( - "uds::ConnectToNetwork", UDSConnectionTimeout, - [command_id](std::shared_ptr thread, Kernel::HLERequestContext& ctx, - Kernel::ThreadWakeupReason reason) { - // TODO(B3N30): Add error handling for host full and timeout - IPC::RequestBuilder rb(ctx, command_id, 1, 0); - rb.Push(RESULT_SUCCESS); - LOG_DEBUG(Service_NWM, "connection sequence finished"); - }); + connection_event = ctx.SleepClientThread("uds::ConnectToNetwork", UDSConnectionTimeout, + std::make_shared(command_id)); } void NWM_UDS::ConnectToNetwork(Kernel::HLERequestContext& ctx) { @@ -1418,3 +1434,5 @@ NWM_UDS::~NWM_UDS() { } } // namespace Service::NWM + +SERIALIZE_EXPORT_IMPL(Service::NWM::NWM_UDS::ThreadCallback) diff --git a/src/core/hle/service/nwm/nwm_uds.h b/src/core/hle/service/nwm/nwm_uds.h index 07dd7e9ba1..b2ab1d76a6 100644 --- a/src/core/hle/service/nwm/nwm_uds.h +++ b/src/core/hle/service/nwm/nwm_uds.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "common/common_types.h" #include "common/swap.h" #include "core/hle/service/service.h" @@ -127,6 +128,8 @@ public: explicit NWM_UDS(Core::System& system); ~NWM_UDS(); + class ThreadCallback; + private: Core::System& system; @@ -560,3 +563,4 @@ private: SERVICE_CONSTRUCT(Service::NWM::NWM_UDS) BOOST_CLASS_EXPORT_KEY(Service::NWM::NWM_UDS) +BOOST_CLASS_EXPORT_KEY(Service::NWM::NWM_UDS::ThreadCallback) diff --git a/src/core/hle/service/sm/srv.cpp b/src/core/hle/service/sm/srv.cpp index 396bd3559e..10179f4163 100644 --- a/src/core/hle/service/sm/srv.cpp +++ b/src/core/hle/service/sm/srv.cpp @@ -3,6 +3,7 @@ // Refer to the license.txt file included. #include +#include "common/archives.h" #include "common/common_types.h" #include "common/logging/log.h" #include "core/core.h" @@ -71,6 +72,46 @@ void SRV::EnableNotification(Kernel::HLERequestContext& ctx) { LOG_WARNING(Service_SRV, "(STUBBED) called"); } +class SRV::ThreadCallback : public Kernel::HLERequestContext::WakeupCallback { + +public: + ThreadCallback(Core::System& system_, std::string name_) : system(system_), name(name_) {} + + void WakeUp(std::shared_ptr thread, Kernel::HLERequestContext& ctx, + Kernel::ThreadWakeupReason reason) { + LOG_ERROR(Service_SRV, "called service={} wakeup", name); + auto client_port = system.ServiceManager().GetServicePort(name); + + auto session = client_port.Unwrap()->Connect(); + if (session.Succeeded()) { + LOG_DEBUG(Service_SRV, "called service={} -> session={}", name, + (*session)->GetObjectId()); + IPC::RequestBuilder rb(ctx, 0x5, 1, 2); + rb.Push(session.Code()); + rb.PushMoveObjects(std::move(session).Unwrap()); + } else if (session.Code() == Kernel::ERR_MAX_CONNECTIONS_REACHED) { + LOG_ERROR(Service_SRV, "called service={} -> ERR_MAX_CONNECTIONS_REACHED", name); + UNREACHABLE(); + } else { + LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, session.Code().raw); + IPC::RequestBuilder rb(ctx, 0x5, 1, 0); + rb.Push(session.Code()); + } + } + +private: + Core::System& system; + std::string name; + + ThreadCallback() : system(Core::Global()) {} + + template + void serialize(Archive& ar, const unsigned int) { + ar& name; + } + friend class boost::serialization::access; +}; + /** * SRV::GetServiceHandle service function * Inputs: @@ -100,28 +141,7 @@ void SRV::GetServiceHandle(Kernel::HLERequestContext& ctx) { // TODO(yuriks): Permission checks go here - auto get_handle = [name, this](std::shared_ptr thread, - Kernel::HLERequestContext& ctx, - Kernel::ThreadWakeupReason reason) { - LOG_ERROR(Service_SRV, "called service={} wakeup", name); - auto client_port = system.ServiceManager().GetServicePort(name); - - auto session = client_port.Unwrap()->Connect(); - if (session.Succeeded()) { - LOG_DEBUG(Service_SRV, "called service={} -> session={}", name, - (*session)->GetObjectId()); - IPC::RequestBuilder rb(ctx, 0x5, 1, 2); - rb.Push(session.Code()); - rb.PushMoveObjects(std::move(session).Unwrap()); - } else if (session.Code() == Kernel::ERR_MAX_CONNECTIONS_REACHED) { - LOG_ERROR(Service_SRV, "called service={} -> ERR_MAX_CONNECTIONS_REACHED", name); - UNREACHABLE(); - } else { - LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, session.Code().raw); - IPC::RequestBuilder rb(ctx, 0x5, 1, 0); - rb.Push(session.Code()); - } - }; + auto get_handle = std::make_shared(system, name); auto client_port = system.ServiceManager().GetServicePort(name); if (client_port.Failed()) { @@ -266,3 +286,5 @@ SRV::SRV(Core::System& system) : ServiceFramework("srv:", 4), system(system) { SRV::~SRV() = default; } // namespace Service::SM + +SERIALIZE_EXPORT_IMPL(Service::SM::SRV::ThreadCallback) diff --git a/src/core/hle/service/sm/srv.h b/src/core/hle/service/sm/srv.h index 2382f48425..7d17f87a58 100644 --- a/src/core/hle/service/sm/srv.h +++ b/src/core/hle/service/sm/srv.h @@ -6,6 +6,7 @@ #include #include +#include #include "core/hle/service/service.h" namespace Core { @@ -25,6 +26,8 @@ public: explicit SRV(Core::System& system); ~SRV(); + class ThreadCallback; + private: void RegisterClient(Kernel::HLERequestContext& ctx); void EnableNotification(Kernel::HLERequestContext& ctx); @@ -40,3 +43,5 @@ private: }; } // namespace Service::SM + +BOOST_CLASS_EXPORT_KEY(Service::SM::SRV::ThreadCallback) diff --git a/src/tests/core/hle/kernel/hle_ipc.cpp b/src/tests/core/hle/kernel/hle_ipc.cpp index 414d64021f..a4f7c8062c 100644 --- a/src/tests/core/hle/kernel/hle_ipc.cpp +++ b/src/tests/core/hle/kernel/hle_ipc.cpp @@ -37,7 +37,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel IPC::MakeHeader(0x1234, 0, 0), }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); REQUIRE(context.CommandBuffer()[0] == 0x12340000); } @@ -50,7 +50,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel 0xAABBCCDD, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); auto* output = context.CommandBuffer(); REQUIRE(output[1] == 0x12345678); @@ -67,7 +67,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel a_handle, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); auto* output = context.CommandBuffer(); REQUIRE(context.GetIncomingHandle(output[2]) == a); @@ -83,7 +83,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel a_handle, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); auto* output = context.CommandBuffer(); REQUIRE(context.GetIncomingHandle(output[2]) == a); @@ -103,7 +103,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel process->handle_table.Create(c).Unwrap(), }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); auto* output = context.CommandBuffer(); REQUIRE(context.GetIncomingHandle(output[2]) == a); @@ -118,7 +118,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel 0, }; - auto result = context.PopulateFromIncomingCommandBuffer(input, *process); + auto result = context.PopulateFromIncomingCommandBuffer(input, process); REQUIRE(result == RESULT_SUCCESS); auto* output = context.CommandBuffer(); @@ -132,7 +132,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel 0x98989898, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); REQUIRE(context.CommandBuffer()[2] == process->process_id); } @@ -153,7 +153,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel target_address, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); CHECK(context.GetStaticBuffer(0) == mem->Vector()); @@ -175,7 +175,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel target_address, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); std::vector other_buffer(buffer.GetSize()); context.GetMappedBuffer(0).Read(other_buffer.data(), 0, buffer.GetSize()); @@ -219,7 +219,7 @@ TEST_CASE("HLERequestContext::PopulateFromIncomingCommandBuffer", "[core][kernel target_address_mapped, }; - context.PopulateFromIncomingCommandBuffer(input, *process); + context.PopulateFromIncomingCommandBuffer(input, process); auto* output = context.CommandBuffer(); CHECK(output[1] == 0x12345678); @@ -365,7 +365,7 @@ TEST_CASE("HLERequestContext::WriteToOutgoingCommandBuffer", "[core][kernel]") { target_address, }; - context.PopulateFromIncomingCommandBuffer(input_cmdbuff, *process); + context.PopulateFromIncomingCommandBuffer(input_cmdbuff, process); context.GetMappedBuffer(0).Write(input_buffer.data(), 0, input_buffer.size());