From ed0319cfed2c99e6366aaf725d96bb28a9332e4d Mon Sep 17 00:00:00 2001
From: Liam <byteslice@airmail.cc>
Date: Sat, 2 Jul 2022 12:33:49 -0400
Subject: [PATCH] common/fiber: make fibers easier to use

---
 src/common/fiber.cpp                |  21 ++---
 src/common/fiber.h                  |   7 +-
 src/core/cpu_manager.cpp            |  51 ++++--------
 src/core/cpu_manager.h              |  21 +++--
 src/core/hle/kernel/k_scheduler.cpp |   7 +-
 src/core/hle/kernel/k_scheduler.h   |   1 -
 src/core/hle/kernel/k_thread.cpp    |  15 ++--
 src/core/hle/kernel/k_thread.h      |   3 +-
 src/tests/common/fibers.cpp         | 123 ++++++++--------------------
 9 files changed, 79 insertions(+), 170 deletions(-)

diff --git a/src/common/fiber.cpp b/src/common/fiber.cpp
index f9aeb692a4..bc92b360b4 100644
--- a/src/common/fiber.cpp
+++ b/src/common/fiber.cpp
@@ -20,10 +20,8 @@ struct Fiber::FiberImpl {
     VirtualBuffer<u8> rewind_stack;
 
     std::mutex guard;
-    std::function<void(void*)> entry_point;
-    std::function<void(void*)> rewind_point;
-    void* rewind_parameter{};
-    void* start_parameter{};
+    std::function<void()> entry_point;
+    std::function<void()> rewind_point;
     std::shared_ptr<Fiber> previous_fiber;
     bool is_thread_fiber{};
     bool released{};
@@ -34,13 +32,8 @@ struct Fiber::FiberImpl {
     boost::context::detail::fcontext_t rewind_context{};
 };
 
-void Fiber::SetStartParameter(void* new_parameter) {
-    impl->start_parameter = new_parameter;
-}
-
-void Fiber::SetRewindPoint(std::function<void(void*)>&& rewind_func, void* rewind_param) {
+void Fiber::SetRewindPoint(std::function<void()>&& rewind_func) {
     impl->rewind_point = std::move(rewind_func);
-    impl->rewind_parameter = rewind_param;
 }
 
 void Fiber::Start(boost::context::detail::transfer_t& transfer) {
@@ -48,7 +41,7 @@ void Fiber::Start(boost::context::detail::transfer_t& transfer) {
     impl->previous_fiber->impl->context = transfer.fctx;
     impl->previous_fiber->impl->guard.unlock();
     impl->previous_fiber.reset();
-    impl->entry_point(impl->start_parameter);
+    impl->entry_point();
     UNREACHABLE();
 }
 
@@ -59,7 +52,7 @@ void Fiber::OnRewind([[maybe_unused]] boost::context::detail::transfer_t& transf
     u8* tmp = impl->stack_limit;
     impl->stack_limit = impl->rewind_stack_limit;
     impl->rewind_stack_limit = tmp;
-    impl->rewind_point(impl->rewind_parameter);
+    impl->rewind_point();
     UNREACHABLE();
 }
 
@@ -73,10 +66,8 @@ void Fiber::RewindStartFunc(boost::context::detail::transfer_t transfer) {
     fiber->OnRewind(transfer);
 }
 
-Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter)
-    : impl{std::make_unique<FiberImpl>()} {
+Fiber::Fiber(std::function<void()>&& entry_point_func) : impl{std::make_unique<FiberImpl>()} {
     impl->entry_point = std::move(entry_point_func);
-    impl->start_parameter = start_parameter;
     impl->stack_limit = impl->stack.data();
     impl->rewind_stack_limit = impl->rewind_stack.data();
     u8* stack_base = impl->stack_limit + default_stack_size;
diff --git a/src/common/fiber.h b/src/common/fiber.h
index 873604bc6d..f24d333a30 100644
--- a/src/common/fiber.h
+++ b/src/common/fiber.h
@@ -29,7 +29,7 @@ namespace Common {
  */
 class Fiber {
 public:
-    Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter);
+    Fiber(std::function<void()>&& entry_point_func);
     ~Fiber();
 
     Fiber(const Fiber&) = delete;
@@ -43,16 +43,13 @@ public:
     static void YieldTo(std::weak_ptr<Fiber> weak_from, Fiber& to);
     [[nodiscard]] static std::shared_ptr<Fiber> ThreadToFiber();
 
-    void SetRewindPoint(std::function<void(void*)>&& rewind_func, void* rewind_param);
+    void SetRewindPoint(std::function<void()>&& rewind_func);
 
     void Rewind();
 
     /// Only call from main thread's fiber
     void Exit();
 
-    /// Changes the start parameter of the fiber. Has no effect if the fiber already started
-    void SetStartParameter(void* new_parameter);
-
 private:
     Fiber();
 
diff --git a/src/core/cpu_manager.cpp b/src/core/cpu_manager.cpp
index fd6928105a..f184b904b2 100644
--- a/src/core/cpu_manager.cpp
+++ b/src/core/cpu_manager.cpp
@@ -41,51 +41,32 @@ void CpuManager::Shutdown() {
     }
 }
 
-std::function<void(void*)> CpuManager::GetGuestThreadStartFunc() {
-    return GuestThreadFunction;
-}
-
-std::function<void(void*)> CpuManager::GetIdleThreadStartFunc() {
-    return IdleThreadFunction;
-}
-
-std::function<void(void*)> CpuManager::GetShutdownThreadStartFunc() {
-    return ShutdownThreadFunction;
-}
-
-void CpuManager::GuestThreadFunction(void* cpu_manager_) {
-    CpuManager* cpu_manager = static_cast<CpuManager*>(cpu_manager_);
-    if (cpu_manager->is_multicore) {
-        cpu_manager->MultiCoreRunGuestThread();
+void CpuManager::GuestThreadFunction() {
+    if (is_multicore) {
+        MultiCoreRunGuestThread();
     } else {
-        cpu_manager->SingleCoreRunGuestThread();
+        SingleCoreRunGuestThread();
     }
 }
 
-void CpuManager::GuestRewindFunction(void* cpu_manager_) {
-    CpuManager* cpu_manager = static_cast<CpuManager*>(cpu_manager_);
-    if (cpu_manager->is_multicore) {
-        cpu_manager->MultiCoreRunGuestLoop();
+void CpuManager::GuestRewindFunction() {
+    if (is_multicore) {
+        MultiCoreRunGuestLoop();
     } else {
-        cpu_manager->SingleCoreRunGuestLoop();
+        SingleCoreRunGuestLoop();
     }
 }
 
-void CpuManager::IdleThreadFunction(void* cpu_manager_) {
-    CpuManager* cpu_manager = static_cast<CpuManager*>(cpu_manager_);
-    if (cpu_manager->is_multicore) {
-        cpu_manager->MultiCoreRunIdleThread();
+void CpuManager::IdleThreadFunction() {
+    if (is_multicore) {
+        MultiCoreRunIdleThread();
     } else {
-        cpu_manager->SingleCoreRunIdleThread();
+        SingleCoreRunIdleThread();
     }
 }
 
-void CpuManager::ShutdownThreadFunction(void* cpu_manager) {
-    static_cast<CpuManager*>(cpu_manager)->ShutdownThread();
-}
-
-void* CpuManager::GetStartFuncParameter() {
-    return this;
+void CpuManager::ShutdownThreadFunction() {
+    ShutdownThread();
 }
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -97,7 +78,7 @@ void CpuManager::MultiCoreRunGuestThread() {
     kernel.CurrentScheduler()->OnThreadStart();
     auto* thread = kernel.CurrentScheduler()->GetSchedulerCurrentThread();
     auto& host_context = thread->GetHostContext();
-    host_context->SetRewindPoint(GuestRewindFunction, this);
+    host_context->SetRewindPoint([this] { GuestRewindFunction(); });
     MultiCoreRunGuestLoop();
 }
 
@@ -134,7 +115,7 @@ void CpuManager::SingleCoreRunGuestThread() {
     kernel.CurrentScheduler()->OnThreadStart();
     auto* thread = kernel.CurrentScheduler()->GetSchedulerCurrentThread();
     auto& host_context = thread->GetHostContext();
-    host_context->SetRewindPoint(GuestRewindFunction, this);
+    host_context->SetRewindPoint([this] { GuestRewindFunction(); });
     SingleCoreRunGuestLoop();
 }
 
diff --git a/src/core/cpu_manager.h b/src/core/cpu_manager.h
index f0751fc588..76dc58ee1e 100644
--- a/src/core/cpu_manager.h
+++ b/src/core/cpu_manager.h
@@ -50,10 +50,15 @@ public:
     void Initialize();
     void Shutdown();
 
-    static std::function<void(void*)> GetGuestThreadStartFunc();
-    static std::function<void(void*)> GetIdleThreadStartFunc();
-    static std::function<void(void*)> GetShutdownThreadStartFunc();
-    void* GetStartFuncParameter();
+    std::function<void()> GetGuestThreadStartFunc() {
+        return [this] { GuestThreadFunction(); };
+    }
+    std::function<void()> GetIdleThreadStartFunc() {
+        return [this] { IdleThreadFunction(); };
+    }
+    std::function<void()> GetShutdownThreadStartFunc() {
+        return [this] { ShutdownThreadFunction(); };
+    }
 
     void PreemptSingleCore(bool from_running_enviroment = true);
 
@@ -62,10 +67,10 @@ public:
     }
 
 private:
-    static void GuestThreadFunction(void* cpu_manager);
-    static void GuestRewindFunction(void* cpu_manager);
-    static void IdleThreadFunction(void* cpu_manager);
-    static void ShutdownThreadFunction(void* cpu_manager);
+    void GuestThreadFunction();
+    void GuestRewindFunction();
+    void IdleThreadFunction();
+    void ShutdownThreadFunction();
 
     void MultiCoreRunGuestThread();
     void MultiCoreRunGuestLoop();
diff --git a/src/core/hle/kernel/k_scheduler.cpp b/src/core/hle/kernel/k_scheduler.cpp
index d586b3f5c7..d599d2bcb9 100644
--- a/src/core/hle/kernel/k_scheduler.cpp
+++ b/src/core/hle/kernel/k_scheduler.cpp
@@ -622,7 +622,7 @@ void KScheduler::YieldToAnyThread(KernelCore& kernel) {
 }
 
 KScheduler::KScheduler(Core::System& system_, s32 core_id_) : system{system_}, core_id{core_id_} {
-    switch_fiber = std::make_shared<Common::Fiber>(OnSwitch, this);
+    switch_fiber = std::make_shared<Common::Fiber>([this] { SwitchToCurrent(); });
     state.needs_scheduling.store(true);
     state.interrupt_task_thread_runnable = false;
     state.should_count_idle = false;
@@ -778,11 +778,6 @@ void KScheduler::ScheduleImpl() {
     next_scheduler.SwitchContextStep2();
 }
 
-void KScheduler::OnSwitch(void* this_scheduler) {
-    KScheduler* sched = static_cast<KScheduler*>(this_scheduler);
-    sched->SwitchToCurrent();
-}
-
 void KScheduler::SwitchToCurrent() {
     while (true) {
         {
diff --git a/src/core/hle/kernel/k_scheduler.h b/src/core/hle/kernel/k_scheduler.h
index 3f90656eee..bd66bffc4f 100644
--- a/src/core/hle/kernel/k_scheduler.h
+++ b/src/core/hle/kernel/k_scheduler.h
@@ -165,7 +165,6 @@ private:
      */
     void UpdateLastContextSwitchTime(KThread* thread, KProcess* process);
 
-    static void OnSwitch(void* this_scheduler);
     void SwitchToCurrent();
 
     KThread* prev_thread{};
diff --git a/src/core/hle/kernel/k_thread.cpp b/src/core/hle/kernel/k_thread.cpp
index 8d7faa6623..23bf7425ab 100644
--- a/src/core/hle/kernel/k_thread.cpp
+++ b/src/core/hle/kernel/k_thread.cpp
@@ -246,14 +246,12 @@ Result KThread::Initialize(KThreadFunction func, uintptr_t arg, VAddr user_stack
 
 Result KThread::InitializeThread(KThread* thread, KThreadFunction func, uintptr_t arg,
                                  VAddr user_stack_top, s32 prio, s32 core, KProcess* owner,
-                                 ThreadType type, std::function<void(void*)>&& init_func,
-                                 void* init_func_parameter) {
+                                 ThreadType type, std::function<void()>&& init_func) {
     // Initialize the thread.
     R_TRY(thread->Initialize(func, arg, user_stack_top, prio, core, owner, type));
 
     // Initialize emulation parameters.
-    thread->host_context =
-        std::make_shared<Common::Fiber>(std::move(init_func), init_func_parameter);
+    thread->host_context = std::make_shared<Common::Fiber>(std::move(init_func));
     thread->is_single_core = !Settings::values.use_multi_core.GetValue();
 
     return ResultSuccess;
@@ -265,15 +263,13 @@ Result KThread::InitializeDummyThread(KThread* thread) {
 
 Result KThread::InitializeIdleThread(Core::System& system, KThread* thread, s32 virt_core) {
     return InitializeThread(thread, {}, {}, {}, IdleThreadPriority, virt_core, {}, ThreadType::Main,
-                            Core::CpuManager::GetIdleThreadStartFunc(),
-                            system.GetCpuManager().GetStartFuncParameter());
+                            system.GetCpuManager().GetIdleThreadStartFunc());
 }
 
 Result KThread::InitializeHighPriorityThread(Core::System& system, KThread* thread,
                                              KThreadFunction func, uintptr_t arg, s32 virt_core) {
     return InitializeThread(thread, func, arg, {}, {}, virt_core, nullptr, ThreadType::HighPriority,
-                            Core::CpuManager::GetShutdownThreadStartFunc(),
-                            system.GetCpuManager().GetStartFuncParameter());
+                            system.GetCpuManager().GetShutdownThreadStartFunc());
 }
 
 Result KThread::InitializeUserThread(Core::System& system, KThread* thread, KThreadFunction func,
@@ -281,8 +277,7 @@ Result KThread::InitializeUserThread(Core::System& system, KThread* thread, KThr
                                      KProcess* owner) {
     system.Kernel().GlobalSchedulerContext().AddThread(thread);
     return InitializeThread(thread, func, arg, user_stack_top, prio, virt_core, owner,
-                            ThreadType::User, Core::CpuManager::GetGuestThreadStartFunc(),
-                            system.GetCpuManager().GetStartFuncParameter());
+                            ThreadType::User, system.GetCpuManager().GetGuestThreadStartFunc());
 }
 
 void KThread::PostDestroy(uintptr_t arg) {
diff --git a/src/core/hle/kernel/k_thread.h b/src/core/hle/kernel/k_thread.h
index 94c4cd1c86..28cd7ecb09 100644
--- a/src/core/hle/kernel/k_thread.h
+++ b/src/core/hle/kernel/k_thread.h
@@ -729,8 +729,7 @@ private:
     [[nodiscard]] static Result InitializeThread(KThread* thread, KThreadFunction func,
                                                  uintptr_t arg, VAddr user_stack_top, s32 prio,
                                                  s32 core, KProcess* owner, ThreadType type,
-                                                 std::function<void(void*)>&& init_func,
-                                                 void* init_func_parameter);
+                                                 std::function<void()>&& init_func);
 
     static void RestorePriority(KernelCore& kernel_ctx, KThread* thread);
 
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp
index cfc84d423a..4e29f91996 100644
--- a/src/tests/common/fibers.cpp
+++ b/src/tests/common/fibers.cpp
@@ -43,7 +43,15 @@ class TestControl1 {
 public:
     TestControl1() = default;
 
-    void DoWork();
+    void DoWork() {
+        const u32 id = thread_ids.Get();
+        u32 value = items[id];
+        for (u32 i = 0; i < id; i++) {
+            value++;
+        }
+        results[id] = value;
+        Fiber::YieldTo(work_fibers[id], *thread_fibers[id]);
+    }
 
     void ExecuteThread(u32 id);
 
@@ -54,35 +62,16 @@ public:
     std::vector<u32> results;
 };
 
-static void WorkControl1(void* control) {
-    auto* test_control = static_cast<TestControl1*>(control);
-    test_control->DoWork();
-}
-
-void TestControl1::DoWork() {
-    const u32 id = thread_ids.Get();
-    u32 value = items[id];
-    for (u32 i = 0; i < id; i++) {
-        value++;
-    }
-    results[id] = value;
-    Fiber::YieldTo(work_fibers[id], *thread_fibers[id]);
-}
-
 void TestControl1::ExecuteThread(u32 id) {
     thread_ids.Register(id);
     auto thread_fiber = Fiber::ThreadToFiber();
     thread_fibers[id] = thread_fiber;
-    work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
+    work_fibers[id] = std::make_shared<Fiber>([this] { DoWork(); });
     items[id] = rand() % 256;
     Fiber::YieldTo(thread_fibers[id], *work_fibers[id]);
     thread_fibers[id]->Exit();
 }
 
-static void ThreadStart1(u32 id, TestControl1& test_control) {
-    test_control.ExecuteThread(id);
-}
-
 /** This test checks for fiber setup configuration and validates that fibers are
  *  doing all the work required.
  */
@@ -95,7 +84,7 @@ TEST_CASE("Fibers::Setup", "[common]") {
     test_control.results.resize(num_threads, 0);
     std::vector<std::thread> threads;
     for (u32 i = 0; i < num_threads; i++) {
-        threads.emplace_back(ThreadStart1, i, std::ref(test_control));
+        threads.emplace_back([&test_control, i] { test_control.ExecuteThread(i); });
     }
     for (u32 i = 0; i < num_threads; i++) {
         threads[i].join();
@@ -167,21 +156,6 @@ public:
     std::shared_ptr<Common::Fiber> fiber3;
 };
 
-static void WorkControl2_1(void* control) {
-    auto* test_control = static_cast<TestControl2*>(control);
-    test_control->DoWork1();
-}
-
-static void WorkControl2_2(void* control) {
-    auto* test_control = static_cast<TestControl2*>(control);
-    test_control->DoWork2();
-}
-
-static void WorkControl2_3(void* control) {
-    auto* test_control = static_cast<TestControl2*>(control);
-    test_control->DoWork3();
-}
-
 void TestControl2::ExecuteThread(u32 id) {
     thread_ids.Register(id);
     auto thread_fiber = Fiber::ThreadToFiber();
@@ -193,18 +167,6 @@ void TestControl2::Exit() {
     thread_fibers[id]->Exit();
 }
 
-static void ThreadStart2_1(u32 id, TestControl2& test_control) {
-    test_control.ExecuteThread(id);
-    test_control.CallFiber1();
-    test_control.Exit();
-}
-
-static void ThreadStart2_2(u32 id, TestControl2& test_control) {
-    test_control.ExecuteThread(id);
-    test_control.CallFiber2();
-    test_control.Exit();
-}
-
 /** This test checks for fiber thread exchange configuration and validates that fibers are
  *  that a fiber has been successfully transferred from one thread to another and that the TLS
  *  region of the thread is kept while changing fibers.
@@ -212,14 +174,19 @@ static void ThreadStart2_2(u32 id, TestControl2& test_control) {
 TEST_CASE("Fibers::InterExchange", "[common]") {
     TestControl2 test_control{};
     test_control.thread_fibers.resize(2);
-    test_control.fiber1 =
-        std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_1}, &test_control);
-    test_control.fiber2 =
-        std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_2}, &test_control);
-    test_control.fiber3 =
-        std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_3}, &test_control);
-    std::thread thread1(ThreadStart2_1, 0, std::ref(test_control));
-    std::thread thread2(ThreadStart2_2, 1, std::ref(test_control));
+    test_control.fiber1 = std::make_shared<Fiber>([&test_control] { test_control.DoWork1(); });
+    test_control.fiber2 = std::make_shared<Fiber>([&test_control] { test_control.DoWork2(); });
+    test_control.fiber3 = std::make_shared<Fiber>([&test_control] { test_control.DoWork3(); });
+    std::thread thread1{[&test_control] {
+        test_control.ExecuteThread(0);
+        test_control.CallFiber1();
+        test_control.Exit();
+    }};
+    std::thread thread2{[&test_control] {
+        test_control.ExecuteThread(1);
+        test_control.CallFiber2();
+        test_control.Exit();
+    }};
     thread1.join();
     thread2.join();
     REQUIRE(test_control.assert1);
@@ -270,16 +237,6 @@ public:
     std::shared_ptr<Common::Fiber> fiber2;
 };
 
-static void WorkControl3_1(void* control) {
-    auto* test_control = static_cast<TestControl3*>(control);
-    test_control->DoWork1();
-}
-
-static void WorkControl3_2(void* control) {
-    auto* test_control = static_cast<TestControl3*>(control);
-    test_control->DoWork2();
-}
-
 void TestControl3::ExecuteThread(u32 id) {
     thread_ids.Register(id);
     auto thread_fiber = Fiber::ThreadToFiber();
@@ -291,12 +248,6 @@ void TestControl3::Exit() {
     thread_fibers[id]->Exit();
 }
 
-static void ThreadStart3(u32 id, TestControl3& test_control) {
-    test_control.ExecuteThread(id);
-    test_control.CallFiber1();
-    test_control.Exit();
-}
-
 /** This test checks for one two threads racing for starting the same fiber.
  *  It checks execution occurred in an ordered manner and by no time there were
  *  two contexts at the same time.
@@ -304,12 +255,15 @@ static void ThreadStart3(u32 id, TestControl3& test_control) {
 TEST_CASE("Fibers::StartRace", "[common]") {
     TestControl3 test_control{};
     test_control.thread_fibers.resize(2);
-    test_control.fiber1 =
-        std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_1}, &test_control);
-    test_control.fiber2 =
-        std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_2}, &test_control);
-    std::thread thread1(ThreadStart3, 0, std::ref(test_control));
-    std::thread thread2(ThreadStart3, 1, std::ref(test_control));
+    test_control.fiber1 = std::make_shared<Fiber>([&test_control] { test_control.DoWork1(); });
+    test_control.fiber2 = std::make_shared<Fiber>([&test_control] { test_control.DoWork2(); });
+    const auto race_function{[&test_control](u32 id) {
+        test_control.ExecuteThread(id);
+        test_control.CallFiber1();
+        test_control.Exit();
+    }};
+    std::thread thread1([&] { race_function(0); });
+    std::thread thread2([&] { race_function(1); });
     thread1.join();
     thread2.join();
     REQUIRE(test_control.value1 == 1);
@@ -319,12 +273,10 @@ TEST_CASE("Fibers::StartRace", "[common]") {
 
 class TestControl4;
 
-static void WorkControl4(void* control);
-
 class TestControl4 {
 public:
     TestControl4() {
-        fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl4}, this);
+        fiber1 = std::make_shared<Fiber>([this] { DoWork(); });
         goal_reached = false;
         rewinded = false;
     }
@@ -336,7 +288,7 @@ public:
     }
 
     void DoWork() {
-        fiber1->SetRewindPoint(std::function<void(void*)>{WorkControl4}, this);
+        fiber1->SetRewindPoint([this] { DoWork(); });
         if (rewinded) {
             goal_reached = true;
             Fiber::YieldTo(fiber1, *thread_fiber);
@@ -351,11 +303,6 @@ public:
     bool rewinded;
 };
 
-static void WorkControl4(void* control) {
-    auto* test_control = static_cast<TestControl4*>(control);
-    test_control->DoWork();
-}
-
 TEST_CASE("Fibers::Rewind", "[common]") {
     TestControl4 test_control{};
     test_control.Execute();