diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index 5074e352fc..97228f153e 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -16,8 +16,6 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// // Kernel namespace -SERIALIZE_EXPORT_IMPL(Kernel::AddressArbiter) - namespace Kernel { void AddressArbiter::WaitThread(std::shared_ptr thread, VAddr wait_address) { @@ -80,19 +78,36 @@ std::shared_ptr KernelSystem::CreateAddressArbiter(std::string n return address_arbiter; } +class AddressArbiter::Callback : public WakeupCallback { +public: + Callback(AddressArbiter& _parent) : parent(SharedFrom(&_parent)) {} + std::shared_ptr parent; + + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) override { + parent->WakeUp(reason, thread, object); + } + +private: + Callback() = default; + template + void serialize(Archive& ar, const unsigned int) { + ar& boost::serialization::base_object(*this); + ar& parent; + } + friend class boost::serialization::access; +}; + void AddressArbiter::WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, std::shared_ptr object) { ASSERT(reason == ThreadWakeupReason::Timeout); // Remove the newly-awakened thread from the Arbiter's waiting list. waiting_threads.erase(std::remove(waiting_threads.begin(), waiting_threads.end(), thread), - waiting_threads.end()); + waiting_threads.end()); }; ResultCode AddressArbiter::ArbitrateAddress(std::shared_ptr thread, ArbitrationType type, VAddr address, s32 value, u64 nanoseconds) { - - auto timeout_callback = std::dynamic_pointer_cast(shared_from_this()); - switch (type) { // Signal thread(s) waiting for arbitrate address... @@ -114,6 +129,9 @@ ResultCode AddressArbiter::ArbitrateAddress(std::shared_ptr thread, Arbi } break; case ArbitrationType::WaitIfLessThanWithTimeout: + if (!timeout_callback) { + timeout_callback = std::make_shared(*this); + } if ((s32)kernel.memory.Read32(address) < value) { thread->wakeup_callback = timeout_callback; thread->WakeAfterDelay(nanoseconds); @@ -130,6 +148,9 @@ ResultCode AddressArbiter::ArbitrateAddress(std::shared_ptr thread, Arbi break; } case ArbitrationType::DecrementAndWaitIfLessThanWithTimeout: { + if (!timeout_callback) { + timeout_callback = std::make_shared(*this); + } s32 memory_value = kernel.memory.Read32(address); if (memory_value < value) { // Only change the memory value if the thread should wait @@ -157,3 +178,6 @@ ResultCode AddressArbiter::ArbitrateAddress(std::shared_ptr thread, Arbi } } // namespace Kernel + +SERIALIZE_EXPORT_IMPL(Kernel::AddressArbiter) +SERIALIZE_EXPORT_IMPL(Kernel::AddressArbiter::Callback) diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index c7a263d9ef..de76310660 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -59,8 +59,7 @@ public: ResultCode ArbitrateAddress(std::shared_ptr thread, ArbitrationType type, VAddr address, s32 value, u64 nanoseconds); - void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, - std::shared_ptr object); + class Callback; private: KernelSystem& kernel; @@ -78,20 +77,39 @@ private: /// Threads waiting for the address arbiter to be signaled. std::vector> waiting_threads; + std::shared_ptr timeout_callback; + + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object); + + class DummyCallback : public WakeupCallback { + public: + void WakeUp(ThreadWakeupReason reason, std::shared_ptr thread, + std::shared_ptr object) override {} + }; + friend class boost::serialization::access; template void serialize(Archive& ar, const unsigned int file_version) { ar& boost::serialization::base_object(*this); - if (file_version > 0) { - ar& boost::serialization::base_object(*this); + if (file_version == 1) { + // This rigmarole is needed because in past versions, AddressArbiter inherited WakeupCallback + // But it turns out this breaks shared_from_this, so we split it out. + // Using a dummy class to deserialize a base_object allows compatibility to be maintained. + DummyCallback x; + ar& boost::serialization::base_object(x); } ar& name; ar& waiting_threads; + if (file_version > 1) { + ar& timeout_callback; + } } }; } // namespace Kernel BOOST_CLASS_EXPORT_KEY(Kernel::AddressArbiter) -BOOST_CLASS_VERSION(Kernel::AddressArbiter, 1) +BOOST_CLASS_EXPORT_KEY(Kernel::AddressArbiter::Callback) +BOOST_CLASS_VERSION(Kernel::AddressArbiter, 2) CONSTRUCT_KERNEL_OBJECT(Kernel::AddressArbiter)