From e13735b6245af0dab111af3a1c16aefa6ee4db67 Mon Sep 17 00:00:00 2001
From: Wunk
Date: Sun, 5 Nov 2023 12:40:31 -0800
Subject: [PATCH] video_core: Implement an arm64 shader-jit backend (#7002)

* externals: Add oaknut submodule

Used for emitting ARM64 assembly.

* common: Implement aarch64 ABI

Utilize oaknut to implement a stack frame.

* tests: Allow shader-jit tests for x64 and a64

Run the shader-jit tests for both x86_64 and arm64 targets.

* video_core: Initialize arm64 shader-jit backend

Passes all current unit tests!

* shader_jit_a64: protect/unprotect memory when jit-ing

Required on MacOS. Memory must be fully unprotected before writing and then
re-protected afterwards, or there will be memory-access errors on MacOS.

* shader_jit_a64: Fix ARM64-Imm overflow

These conditionals were throwing exceptions since the immediate values were
overflowing the space available in the `EOR` instructions. Instead, the values
are now generated with `MOV` and `EOR`-ed afterwards.

* shader_jit_a64: Fix Geometry shader conditional

* shader_jit_a64: Replace `ADRL` with `MOVP2R`

Fixes some immediate-generation exceptions.

* common/aarch64: Fix CallFarFunction

* shader_jit_a64: Optimize `SanitizedMul`

Co-authored-by: merryhime

* shader_jit_a64: Fix address register offset behavior

Based on https://github.com/citra-emu/citra/pull/6942. Passes unit tests.

* shader_jit_a64: Fix `RET` address offset

The A64 stack is 16-byte aligned rather than 8, so a direct port of the x64
code won't work. Fixes weird branches into invalid memory for any shaders with
subroutines.

* shader_jit_a64: Increase max program size

Tuned for A64 program size.

* shader_jit_a64: Use `UBFX` for extracting loop-state

Co-authored-by: JosJuice

* shader_jit_a64: Optimize `SUB+CMP` to `SUBS`

Co-authored-by: JosJuice

* shader_jit_a64: Optimize `CMP+B` to `CBNZ`

Co-authored-by: JosJuice

* shader_jit_a64: Use `FMOV` for `ONE` vector

Co-authored-by: JosJuice

* shader_jit_a64: Remove x86-specific documentation

* shader_jit_a64: Use `UBFX` to extract exponent

Co-authored-by: JosJuice

* shader_jit_a64: Remove redundant MIN/MAX `SRC2`-NaN check

Special handling only needs to check SRC1 for NaN, not SRC2. It works as
follows in the four possible cases:

- No NaN: No special handling needed.
- Only SRC1 is NaN: The special handling is triggered because SRC1 is NaN, and
  SRC2 is picked.
- Only SRC2 is NaN: FMAX automatically picks SRC2 because it always picks the
  NaN if there is one.
- Both SRC1 and SRC2 are NaN: The special handling is triggered because SRC1
  is NaN, and SRC2 is picked.

Co-authored-by: JosJuice

* shader_jit/tests: Add catch-stringifier for vec2f/vec3f

* shader_jit/tests: Add Dest Mask unit test

* shader_jit_a64: Fix Dest-Mask `BSL` operand order

Passes the dest-mask unit tests now.

* shader_jit_a64: Use `MOVI` for DestEnable mask

Accelerate certain cases of masking with MOVI as well.

Co-authored-by: JosJuice

* shader_jit/tests: Add source-swizzle unit test

This is not exhaustive: generating all `4^4` cases seems to make Catch2 crash.
So I've added some component-masking (non-reordering) tests based on the
Dest-Mask unit test, and some additional ones to test broadcasts/splats and
component re-ordering.

* shader_jit_a64: Fix swizzle index generation

This was still generating `SHUFPS` indices and not the ones that we want for
the `TBL` instruction. Passes all unit tests now.

* shader_jit/tests: Add `ShaderSetup` constructor to `ShaderTest`

Rather than using the direct output of `CompileShaderSetup`, allow a
`ShaderSetup` object to be passed in directly.
This enables emitting assembly that is not directly supported by nihstro.

* shader_jit/tests: Add `CALL` unit-test

Tests nested `CALL` instructions to eventually reach an `EX2` instruction.
EX2 is picked in particular since it is implemented as an even deeper dispatch
and ensures subroutines are properly implemented between `CALL` instructions
and implementation-calls.

* shader_jit_a64: Fix nested `BL` subroutines

`lr` was getting written over by nested calls to `BL`, causing undefined
behavior with mixtures of `CALL`, `EX2`, and `LG2` instructions. Each usage of
`BL` is now protected with a stack push/pop to preserve and restore the `lr`
register, allowing nested subroutines to work properly.

* shader_jit/tests: Allocate generated tests on heap

Each of these generated shader-test objects was causing the stack to overflow.
Allocate each of the generated tests on the heap and use unique_ptr so they
only exist within the lifetime of the `REQUIRE` statement.

* shader_jit_a64: Preserve `lr` register from external function calls

`EMIT` makes an external function call and must preserve `lr`.

* shader_jit/tests: Add `MAD` unit-test

The inline-asm version requires an upstream fix:
https://github.com/neobrain/nihstro/issues/68
Instead, the program code is manually configured and added.

* shader_jit/tests: Fix uninitialized instructions

These `union`-type instruction-types were uninitialized, causing tests to fail
intermittently.

* shader_jit_a64: Remove unneeded `MOV`

Residue from the direct port of the x64 code.

* shader_jit_a64: Use `std::array` for `instr_table`

Add some type-safety and const-correctness around this type as well.

* shader_jit_a64: Avoid c-style offset casting

Add some more const-correctness to this function as well.

* video_core: Add arch preprocessor comments

* common/aarch64: Use X16 as the veneer register

https://developer.arm.com/documentation/102374/0101/Procedure-Call-Standard

* shader_jit/tests: Add uniform reading unit-test

Particularly to ensure that addresses are being properly truncated.

* common/aarch64: Use `X0` as `ABI_RETURN`

`X8` is used as the indirect return-result register in the case that the
result is bigger than 128 bits. Principally, `X0` is the general-case return
register though.

* common/aarch64: Add veneer register note

`LR` is generally overwritten by `BLR` anyways, and would also be a safe
veneer to utilize for far-calls.

* shader_jit_a64: Remove unneeded scratch register from `SanitizedMul`

* shader_jit_a64: Fix CALLU condition

Should be `EQ` not `NE`. Fixes the regression on Kid Icarus. No known
regressions anymore!
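A scalar C++ sketch of the MIN/MAX case analysis above (the helper name
`PicaMax` is illustrative only, not part of this patch; the JIT performs the
same logic lane-wise with `FCMEQ`/`FMAX`/`BIF`):

    #include <algorithm>
    #include <cmath>
    #include <limits>

    // One lane of MAX under PICA rules: whenever SRC1 is NaN, SRC2 is picked.
    float PicaMax(float src1, float src2) {
        // A64 FMAX propagates a NaN from either operand...
        const float fmax_result = (std::isnan(src1) || std::isnan(src2))
                                      ? std::numeric_limits<float>::quiet_NaN()
                                      : std::max(src1, src2);
        // ...so only SRC1 needs an explicit check: if SRC2 alone is NaN,
        // fmax_result is already SRC2's NaN and no fix-up is required.
        return std::isnan(src1) ? src2 : fmax_result;
    }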
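Likewise, a sketch of the address-register offset rules referenced above
(hypothetical helper, mirroring the clamping that `Compile_SwizzleSrc` below
emits with ADD/CMP/CSEL plus a bounds check against index 95):

    #include <cstdint>

    // Offsets outside [-128, 127] behave as an offset of 0; the resulting
    // index wraps to 7 bits, and indices above 95 read as vec4(1.0).
    uint32_t FloatUniformIndex(uint32_t base_index, int32_t address_reg) {
        const int32_t offset =
            (address_reg >= -128 && address_reg <= 127) ? address_reg : 0;
        return static_cast<uint32_t>(base_index + offset) & 0x7f;
    }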
---------

Co-authored-by: merryhime
Co-authored-by: JosJuice
---
 .gitmodules                                   |    3 +
 externals/CMakeLists.txt                      |    5 +
 externals/oaknut                              |    1 +
 src/common/CMakeLists.txt                     |    6 +
 src/common/aarch64/oaknut_abi.h               |  155 +++
 src/common/aarch64/oaknut_util.h              |   43 +
 src/tests/CMakeLists.txt                      |    2 +-
 ...4_compiler.cpp => shader_jit_compiler.cpp} |  221 ++-
 src/video_core/CMakeLists.txt                 |    8 +
 src/video_core/shader/shader.cpp              |   18 +-
 src/video_core/shader/shader_jit_a64.cpp      |   51 +
 src/video_core/shader/shader_jit_a64.h        |   33 +
 .../shader/shader_jit_a64_compiler.cpp        | 1207 +++++++++++++++++
 .../shader/shader_jit_a64_compiler.h          |  146 ++
 14 files changed, 1874 insertions(+), 25 deletions(-)
 create mode 160000 externals/oaknut
 create mode 100644 src/common/aarch64/oaknut_abi.h
 create mode 100644 src/common/aarch64/oaknut_util.h
 rename src/tests/video_core/shader/{shader_jit_x64_compiler.cpp => shader_jit_compiler.cpp} (69%)
 create mode 100644 src/video_core/shader/shader_jit_a64.cpp
 create mode 100644 src/video_core/shader/shader_jit_a64.h
 create mode 100644 src/video_core/shader/shader_jit_a64_compiler.cpp
 create mode 100644 src/video_core/shader/shader_jit_a64_compiler.h

diff --git a/.gitmodules b/.gitmodules
index 1320809b5f..1b72ddb32f 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -88,3 +88,6 @@
 [submodule "libadrenotools"]
 	path = externals/libadrenotools
 	url = https://github.com/bylaws/libadrenotools
+[submodule "oaknut"]
+	path = externals/oaknut
+	url = https://github.com/merryhime/oaknut.git
diff --git a/externals/CMakeLists.txt b/externals/CMakeLists.txt
index 802a6be513..46e8ca13d4 100644
--- a/externals/CMakeLists.txt
+++ b/externals/CMakeLists.txt
@@ -85,6 +85,11 @@ if ("x86_64" IN_LIST ARCHITECTURE)
     endif()
 endif()
 
+# Oaknut
+if ("arm64" IN_LIST ARCHITECTURE)
+    add_subdirectory(oaknut EXCLUDE_FROM_ALL)
+endif()
+
 # Dynarmic
 if ("x86_64" IN_LIST ARCHITECTURE OR "arm64" IN_LIST ARCHITECTURE)
     if(USE_SYSTEM_DYNARMIC)
diff --git a/externals/oaknut b/externals/oaknut
new file mode 160000
index 0000000000..e6eecc3f94
--- /dev/null
+++ b/externals/oaknut
@@ -0,0 +1 @@
+Subproject commit e6eecc3f9460728be0a8d3f63e66d31c0362f472
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt
index b4abd328c6..3b9c51984f 100644
--- a/src/common/CMakeLists.txt
+++ b/src/common/CMakeLists.txt
@@ -53,6 +53,8 @@ add_custom_command(OUTPUT scm_rev.cpp
 add_library(citra_common STATIC
     aarch64/cpu_detect.cpp
     aarch64/cpu_detect.h
+    aarch64/oaknut_abi.h
+    aarch64/oaknut_util.h
     alignment.h
     android_storage.h
     android_storage.cpp
@@ -179,6 +181,10 @@ if ("x86_64" IN_LIST ARCHITECTURE)
     target_link_libraries(citra_common PRIVATE xbyak)
 endif()
 
+if ("arm64" IN_LIST ARCHITECTURE)
+    target_link_libraries(citra_common PRIVATE oaknut)
+endif()
+
 if (CITRA_USE_PRECOMPILED_HEADERS)
     target_precompile_headers(citra_common PRIVATE precompiled_headers.h)
 endif()
diff --git a/src/common/aarch64/oaknut_abi.h b/src/common/aarch64/oaknut_abi.h
new file mode 100644
index 0000000000..7323cfca49
--- /dev/null
+++ b/src/common/aarch64/oaknut_abi.h
@@ -0,0 +1,155 @@
+// Copyright 2023 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#pragma once
+
+#include "common/arch.h"
+#if CITRA_ARCH(arm64)
+
+#include <bitset>
+#include <initializer_list>
+#include <oaknut/oaknut.hpp>
+#include "common/assert.h"
+
+namespace Common::A64 {
+
+constexpr std::size_t RegToIndex(const oaknut::Reg& reg) {
+    ASSERT(reg.index() != 31); // ZR not allowed
+    return reg.index() + (reg.is_vector() ?
32 : 0);
+}
+
+constexpr oaknut::XReg IndexToXReg(std::size_t reg_index) {
+    ASSERT(reg_index <= 30);
+    return oaknut::XReg(static_cast<int>(reg_index));
+}
+
+constexpr oaknut::VReg IndexToVReg(std::size_t reg_index) {
+    ASSERT(reg_index >= 32 && reg_index < 64);
+    return oaknut::QReg(static_cast<int>(reg_index - 32));
+}
+
+constexpr oaknut::Reg IndexToReg(std::size_t reg_index) {
+    if (reg_index < 32) {
+        return IndexToXReg(reg_index);
+    } else {
+        return IndexToVReg(reg_index);
+    }
+}
+
+inline constexpr std::bitset<64> BuildRegSet(std::initializer_list<oaknut::Reg> regs) {
+    std::bitset<64> bits;
+    for (const oaknut::Reg& reg : regs) {
+        bits.set(RegToIndex(reg));
+    }
+    return bits;
+}
+
+constexpr inline std::bitset<64> ABI_ALL_GPRS(0x00000000'7FFFFFFF);
+constexpr inline std::bitset<64> ABI_ALL_FPRS(0xFFFFFFFF'00000000);
+
+constexpr inline oaknut::XReg ABI_RETURN = oaknut::util::X0;
+constexpr inline oaknut::XReg ABI_PARAM1 = oaknut::util::X0;
+constexpr inline oaknut::XReg ABI_PARAM2 = oaknut::util::X1;
+constexpr inline oaknut::XReg ABI_PARAM3 = oaknut::util::X2;
+constexpr inline oaknut::XReg ABI_PARAM4 = oaknut::util::X3;
+
+constexpr std::bitset<64> ABI_ALL_CALLER_SAVED = 0xffffffff'4000ffff;
+constexpr std::bitset<64> ABI_ALL_CALLEE_SAVED = 0x0000ff00'7ff80000;
+
+struct ABIFrameInfo {
+    u32 subtraction;
+    u32 fprs_offset;
+};
+
+inline ABIFrameInfo ABI_CalculateFrameSize(std::bitset<64> regs, std::size_t frame_size) {
+    const size_t gprs_count = (regs & ABI_ALL_GPRS).count();
+    const size_t fprs_count = (regs & ABI_ALL_FPRS).count();
+
+    const size_t gprs_size = (gprs_count + 1) / 2 * 16;
+    const size_t fprs_size = fprs_count * 16;
+
+    size_t total_size = 0;
+    total_size += gprs_size;
+    const size_t fprs_base_subtraction = total_size;
+    total_size += fprs_size;
+    total_size += frame_size;
+
+    return ABIFrameInfo{static_cast<u32>(total_size), static_cast<u32>(fprs_base_subtraction)};
+}
+
+inline void ABI_PushRegisters(oaknut::CodeGenerator& code, std::bitset<64> regs,
+                              std::size_t frame_size = 0) {
+    using namespace oaknut;
+    using namespace oaknut::util;
+    auto frame_info = ABI_CalculateFrameSize(regs, frame_size);
+
+    // Allocate stack-space
+    if (frame_info.subtraction != 0) {
+        code.SUB(SP, SP, frame_info.subtraction);
+    }
+
+    // TODO(wunk): Push pairs of registers at a time with STP
+    std::size_t offset = 0;
+    for (std::size_t i = 0; i < 32; ++i) {
+        if (regs[i] && ABI_ALL_GPRS[i]) {
+            const XReg reg = IndexToXReg(i);
+            code.STR(reg, SP, offset);
+            offset += 8;
+        }
+    }
+
+    offset = 0;
+    for (std::size_t i = 32; i < 64; ++i) {
+        if (regs[i] && ABI_ALL_FPRS[i]) {
+            const VReg reg = IndexToVReg(i);
+            code.STR(reg.toQ(), SP, u16(frame_info.fprs_offset + offset));
+            offset += 16;
+        }
+    }
+
+    // Allocate frame-space
+    if (frame_size != 0) {
+        code.SUB(SP, SP, frame_size);
+    }
+}
+
+inline void ABI_PopRegisters(oaknut::CodeGenerator& code, std::bitset<64> regs,
+                             std::size_t frame_size = 0) {
+    using namespace oaknut;
+    using namespace oaknut::util;
+    auto frame_info = ABI_CalculateFrameSize(regs, frame_size);
+
+    // Free frame-space
+    if (frame_size != 0) {
+        code.ADD(SP, SP, frame_size);
+    }
+
+    // TODO(wunk): Pop pairs of registers at a time with LDP
+    std::size_t offset = 0;
+    for (std::size_t i = 0; i < 32; ++i) {
+        if (regs[i] && ABI_ALL_GPRS[i]) {
+            const XReg reg = IndexToXReg(i);
+            code.LDR(reg, SP, offset);
+            offset += 8;
+        }
+    }
+
+    offset = 0;
+    for (std::size_t i = 32; i < 64; ++i) {
+        if (regs[i] && ABI_ALL_FPRS[i]) {
+            const VReg reg = IndexToVReg(i);
+            code.LDR(reg.toQ(), SP,
frame_info.fprs_offset + offset);
+            offset += 16;
+        }
+    }
+
+    // Free stack-space
+    if (frame_info.subtraction != 0) {
+        code.ADD(SP, SP, frame_info.subtraction);
+    }
+}
+
+} // namespace Common::A64
+
+#endif // CITRA_ARCH(arm64)
diff --git a/src/common/aarch64/oaknut_util.h b/src/common/aarch64/oaknut_util.h
new file mode 100644
index 0000000000..0118fac5ff
--- /dev/null
+++ b/src/common/aarch64/oaknut_util.h
@@ -0,0 +1,43 @@
+// Copyright 2023 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#pragma once
+
+#include "common/arch.h"
+#if CITRA_ARCH(arm64)
+
+#include <type_traits>
+#include <oaknut/oaknut.hpp>
+#include "common/aarch64/oaknut_abi.h"
+
+namespace Common::A64 {
+
+// BL can only reach targets within +-128MiB (26-bit immediate)
+inline bool IsWithin128M(uintptr_t ref, uintptr_t target) {
+    const u64 distance = target - (ref + 4);
+    return !(distance >= 0x800'0000ULL && distance <= ~0x800'0000ULL);
+}
+
+inline bool IsWithin128M(const oaknut::CodeGenerator& code, uintptr_t target) {
+    return IsWithin128M(code.ptr<uintptr_t>(), target);
+}
+
+template <typename T>
+inline void CallFarFunction(oaknut::CodeGenerator& code, const T f) {
+    static_assert(std::is_pointer_v<T>, "Argument must be a (function) pointer.");
+    const std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(f);
+    if (IsWithin128M(code, addr)) {
+        code.BL(reinterpret_cast<const void*>(f));
+    } else {
+        // X16 (IP0) and X17 (IP1) are the standard veneer registers
+        // LR is also available as an intermediate register
+        // https://developer.arm.com/documentation/102374/0101/Procedure-Call-Standard
+        code.MOVP2R(oaknut::util::X16, reinterpret_cast<const void*>(f));
+        code.BLR(oaknut::util::X16);
+    }
+}
+
+} // namespace Common::A64
+
+#endif // CITRA_ARCH(arm64)
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 026d0cf23a..dd6e6942f3 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -15,7 +15,7 @@ add_executable(tests
     audio_core/lle/lle.cpp
     audio_core/audio_fixures.h
     audio_core/decoder_tests.cpp
-    video_core/shader/shader_jit_x64_compiler.cpp
+    video_core/shader/shader_jit_compiler.cpp
 )
 
 create_target_directory_groups(tests)
diff --git a/src/tests/video_core/shader/shader_jit_x64_compiler.cpp b/src/tests/video_core/shader/shader_jit_compiler.cpp
similarity index 69%
rename from src/tests/video_core/shader/shader_jit_x64_compiler.cpp
rename to src/tests/video_core/shader/shader_jit_compiler.cpp
index d0696d2095..c76a2d573d 100644
--- a/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
+++ b/src/tests/video_core/shader/shader_jit_compiler.cpp
@@ -1,9 +1,9 @@
-// Copyright 2017 Citra Emulator Project
+// Copyright 2023 Citra Emulator Project
 // Licensed under GPLv2 or any later version
 // Refer to the license.txt file included.
#include "common/arch.h" -#if CITRA_ARCH(x86_64) +#if CITRA_ARCH(x86_64) || CITRA_ARCH(arm64) #include #include @@ -14,7 +14,11 @@ #include #include #include "video_core/shader/shader_interpreter.h" +#if CITRA_ARCH(x86_64) #include "video_core/shader/shader_jit_x64_compiler.h" +#elif CITRA_ARCH(arm64) +#include "video_core/shader/shader_jit_a64_compiler.h" +#endif using JitShader = Pica::Shader::JitShader; using ShaderInterpreter = Pica::Shader::InterpreterEngine; @@ -31,6 +35,18 @@ static constexpr Common::Vec4f vec4_zero = Common::Vec4f::AssignToAll(0.0f); namespace Catch { template <> +struct StringMaker { + static std::string convert(Common::Vec2f value) { + return fmt::format("({}, {})", value.x, value.y); + } +}; +template <> +struct StringMaker { + static std::string convert(Common::Vec3f value) { + return fmt::format("({}, {}, {})", value.r(), value.g(), value.b()); + } +}; +template <> struct StringMaker { static std::string convert(Common::Vec4f value) { return fmt::format("({}, {}, {}, {})", value.r(), value.g(), value.b(), value.a()); @@ -59,6 +75,11 @@ public: shader_jit.Compile(&shader_setup->program_code, &shader_setup->swizzle_data); } + explicit ShaderTest(std::unique_ptr input_shader_setup) + : shader_setup(std::move(input_shader_setup)) { + shader_jit.Compile(&shader_setup->program_code, &shader_setup->swizzle_data); + } + Common::Vec4f Run(std::span inputs) { Pica::Shader::UnitState shader_unit; RunJit(shader_unit, inputs); @@ -144,6 +165,41 @@ TEST_CASE("ADD", "[video_core][shader][shader_jit]") { REQUIRE(std::isinf(shader.Run({INFINITY, -1.0f}).x)); } +TEST_CASE("CALL", "[video_core][shader][shader_jit]") { + const auto sh_input = SourceRegister::MakeInput(0); + const auto sh_output = DestRegister::MakeOutput(0); + + auto shader_setup = CompileShaderSetup({ + {OpCode::Id::NOP}, // call foo + {OpCode::Id::END}, + // .proc foo + {OpCode::Id::NOP}, // call ex2 + {OpCode::Id::END}, + // .proc ex2 + {OpCode::Id::EX2, sh_output, sh_input}, + {OpCode::Id::END}, + }); + + // nihstro does not support the CALL* instructions, so the instruction-binary must be manually + // inserted here: + nihstro::Instruction CALL = {}; + CALL.opcode = nihstro::OpCode(nihstro::OpCode::Id::CALL); + + // call foo + CALL.flow_control.dest_offset = 2; + CALL.flow_control.num_instructions = 1; + shader_setup->program_code[0] = CALL.hex; + + // call ex2 + CALL.flow_control.dest_offset = 4; + CALL.flow_control.num_instructions = 1; + shader_setup->program_code[2] = CALL.hex; + + auto shader = ShaderTest(std::move(shader_setup)); + + REQUIRE(shader.Run(0.f).x == Catch::Approx(1.f)); +} + TEST_CASE("DP3", "[video_core][shader][shader_jit]") { const auto sh_input1 = SourceRegister::MakeInput(0); const auto sh_input2 = SourceRegister::MakeInput(1); @@ -395,6 +451,39 @@ TEST_CASE("RSQ", "[video_core][shader][shader_jit]") { REQUIRE(shader.Run({0.0625f}).x == Catch::Approx(4.0f).margin(0.004f)); } +TEST_CASE("Uniform Read", "[video_core][shader][shader_jit]") { + const auto sh_input = SourceRegister::MakeInput(0); + const auto sh_c0 = SourceRegister::MakeFloat(0); + const auto sh_output = DestRegister::MakeOutput(0); + + auto shader = ShaderTest({ + // mova a0.x, sh_input.x + {OpCode::Id::MOVA, DestRegister{}, "x", sh_input, "x", SourceRegister{}, "", + nihstro::InlineAsm::RelativeAddress::A1}, + // mov sh_output.xyzw, c0[a0.x].xyzw + {OpCode::Id::MOV, sh_output, "xyzw", sh_c0, "xyzw", SourceRegister{}, "", + nihstro::InlineAsm::RelativeAddress::A1}, + {OpCode::Id::END}, + }); + + // Prepare shader uniforms 
+    std::array<Common::Vec4f, 96> f_uniforms = {};
+    for (u32 i = 0; i < 96; ++i) {
+        const float color = (i * 2.0f) / 255.0f;
+        const auto color_f24 = Pica::f24::FromFloat32(color);
+        shader.shader_setup->uniforms.f[i] = {color_f24, color_f24, color_f24, Pica::f24::One()};
+        f_uniforms[i] = {color, color, color, 1.0f};
+    }
+
+    for (u32 i = 0; i < 96; ++i) {
+        const float index = static_cast<float>(i);
+        // Add some fractional values to test proper float->integer truncation
+        const float fractional = (i % 17) / 17.0f;
+
+        REQUIRE(shader.Run(index + fractional) == f_uniforms[i]);
+    }
+}
+
 TEST_CASE("Address Register Offset", "[video_core][shader][shader_jit]") {
     const auto sh_input = SourceRegister::MakeInput(0);
     const auto sh_c40 = SourceRegister::MakeFloat(40);
@@ -445,23 +534,83 @@ TEST_CASE("Address Register Offset", "[video_core][shader][shader_jit]") {
     REQUIRE(shader.Run(-129.f) == f_uniforms[40]);
 }
 
-// TODO: Requires fix from https://github.com/neobrain/nihstro/issues/68
-// TEST_CASE("MAD", "[video_core][shader][shader_jit]") {
-//     const auto sh_input1 = SourceRegister::MakeInput(0);
-//     const auto sh_input2 = SourceRegister::MakeInput(1);
-//     const auto sh_input3 = SourceRegister::MakeInput(2);
-//     const auto sh_output = DestRegister::MakeOutput(0);
+TEST_CASE("Dest Mask", "[video_core][shader][shader_jit]") {
+    const auto sh_input = SourceRegister::MakeInput(0);
+    const auto sh_output = DestRegister::MakeOutput(0);
 
-//     auto shader = ShaderTest({
-//         {OpCode::Id::MAD, sh_output, sh_input1, sh_input2, sh_input3},
-//         {OpCode::Id::END},
-//     });
+    const auto shader = [&sh_input, &sh_output](const char* dest_mask) {
+        return std::unique_ptr<ShaderTest>(new ShaderTest{
+            {OpCode::Id::MOV, sh_output, dest_mask, sh_input, "xyzw", SourceRegister{}, ""},
+            {OpCode::Id::END},
+        });
+    };
 
-//     REQUIRE(shader.Run({vec4_inf, vec4_zero, vec4_zero}).x == 0.0f);
-//     REQUIRE(std::isnan(shader.Run({vec4_nan, vec4_zero, vec4_zero}).x));
+    const Common::Vec4f iota_vec = {1.0f, 2.0f, 3.0f, 4.0f};
 
-//     REQUIRE(shader.Run({vec4_one, vec4_one, vec4_one}).x == 2.0f);
-// }
+    REQUIRE(shader("x")->Run({iota_vec}).x == iota_vec.x);
+    REQUIRE(shader("y")->Run({iota_vec}).y == iota_vec.y);
+    REQUIRE(shader("z")->Run({iota_vec}).z == iota_vec.z);
+    REQUIRE(shader("w")->Run({iota_vec}).w == iota_vec.w);
+    REQUIRE(shader("xy")->Run({iota_vec}).xy() == iota_vec.xy());
+    REQUIRE(shader("xz")->Run({iota_vec}).xz() == iota_vec.xz());
+    REQUIRE(shader("xw")->Run({iota_vec}).xw() == iota_vec.xw());
+    REQUIRE(shader("yz")->Run({iota_vec}).yz() == iota_vec.yz());
+    REQUIRE(shader("yw")->Run({iota_vec}).yw() == iota_vec.yw());
+    REQUIRE(shader("zw")->Run({iota_vec}).zw() == iota_vec.zw());
+    REQUIRE(shader("xyz")->Run({iota_vec}).xyz() == iota_vec.xyz());
+    REQUIRE(shader("xyw")->Run({iota_vec}).xyw() == iota_vec.xyw());
+    REQUIRE(shader("xzw")->Run({iota_vec}).xzw() == iota_vec.xzw());
+    REQUIRE(shader("yzw")->Run({iota_vec}).yzw() == iota_vec.yzw());
+    REQUIRE(shader("xyzw")->Run({iota_vec}) == iota_vec);
+}
+
+TEST_CASE("MAD", "[video_core][shader][shader_jit]") {
+    const auto sh_input1 = SourceRegister::MakeInput(0);
+    const auto sh_input2 = SourceRegister::MakeInput(1);
+    const auto sh_input3 = SourceRegister::MakeInput(2);
+    const auto sh_output = DestRegister::MakeOutput(0);
+
+    auto shader_setup = CompileShaderSetup({
+        // TODO: Requires fix from https://github.com/neobrain/nihstro/issues/68
+        // {OpCode::Id::MAD, sh_output, sh_input1, sh_input2, sh_input3},
+        {OpCode::Id::NOP},
+        {OpCode::Id::END},
+    });
+
+    // nihstro does not support the MAD*
instructions, so the instruction-binary must be manually
+    // inserted here:
+    nihstro::Instruction MAD = {};
+    MAD.opcode = nihstro::OpCode::Id::MAD;
+    MAD.mad.operand_desc_id = 0;
+    MAD.mad.src1 = sh_input1;
+    MAD.mad.src2 = sh_input2;
+    MAD.mad.src3 = sh_input3;
+    MAD.mad.dest = sh_output;
+    shader_setup->program_code[0] = MAD.hex;
+
+    nihstro::SwizzlePattern swizzle = {};
+    swizzle.dest_mask = 0b1111;
+    swizzle.SetSelectorSrc1(0, SwizzlePattern::Selector::x);
+    swizzle.SetSelectorSrc1(1, SwizzlePattern::Selector::y);
+    swizzle.SetSelectorSrc1(2, SwizzlePattern::Selector::z);
+    swizzle.SetSelectorSrc1(3, SwizzlePattern::Selector::w);
+    swizzle.SetSelectorSrc2(0, SwizzlePattern::Selector::x);
+    swizzle.SetSelectorSrc2(1, SwizzlePattern::Selector::y);
+    swizzle.SetSelectorSrc2(2, SwizzlePattern::Selector::z);
+    swizzle.SetSelectorSrc2(3, SwizzlePattern::Selector::w);
+    swizzle.SetSelectorSrc3(0, SwizzlePattern::Selector::x);
+    swizzle.SetSelectorSrc3(1, SwizzlePattern::Selector::y);
+    swizzle.SetSelectorSrc3(2, SwizzlePattern::Selector::z);
+    swizzle.SetSelectorSrc3(3, SwizzlePattern::Selector::w);
+    shader_setup->swizzle_data[0] = swizzle.hex;
+
+    auto shader = ShaderTest(std::move(shader_setup));
+
+    REQUIRE(shader.Run({vec4_zero, vec4_zero, vec4_zero}) == vec4_zero);
+    REQUIRE(shader.Run({vec4_one, vec4_one, vec4_one}) == (vec4_one * 2.0f));
+    REQUIRE(shader.Run({vec4_inf, vec4_zero, vec4_zero}) == vec4_zero);
+    REQUIRE(std::isnan(shader.Run({vec4_nan, vec4_zero, vec4_zero}).x));
+}
 
 TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
     const auto sh_input = SourceRegister::MakeInput(0);
@@ -518,4 +667,42 @@ TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
     }
 }
 
-#endif // CITRA_ARCH(x86_64)
+TEST_CASE("Source Swizzle", "[video_core][shader][shader_jit]") {
+    const auto sh_input = SourceRegister::MakeInput(0);
+    const auto sh_output = DestRegister::MakeOutput(0);
+
+    const auto shader = [&sh_input, &sh_output](const char* swizzle) {
+        return std::unique_ptr<ShaderTest>(new ShaderTest{
+            {OpCode::Id::MOV, sh_output, "xyzw", sh_input, swizzle, SourceRegister{}, ""},
+            {OpCode::Id::END},
+        });
+    };
+
+    const Common::Vec4f iota_vec = {1.0f, 2.0f, 3.0f, 4.0f};
+
+    REQUIRE(shader("x")->Run({iota_vec}).x == iota_vec.x);
+    REQUIRE(shader("y")->Run({iota_vec}).x == iota_vec.y);
+    REQUIRE(shader("z")->Run({iota_vec}).x == iota_vec.z);
+    REQUIRE(shader("w")->Run({iota_vec}).x == iota_vec.w);
+    REQUIRE(shader("xy")->Run({iota_vec}).xy() == iota_vec.xy());
+    REQUIRE(shader("xz")->Run({iota_vec}).xy() == iota_vec.xz());
+    REQUIRE(shader("xw")->Run({iota_vec}).xy() == iota_vec.xw());
+    REQUIRE(shader("yz")->Run({iota_vec}).xy() == iota_vec.yz());
+    REQUIRE(shader("yw")->Run({iota_vec}).xy() == iota_vec.yw());
+    REQUIRE(shader("zw")->Run({iota_vec}).xy() == iota_vec.zw());
+    REQUIRE(shader("yy")->Run({iota_vec}).xy() == iota_vec.yy());
+    REQUIRE(shader("wx")->Run({iota_vec}).xy() == iota_vec.wx());
+    REQUIRE(shader("xyz")->Run({iota_vec}).xyz() == iota_vec.xyz());
+    REQUIRE(shader("xyw")->Run({iota_vec}).xyz() == iota_vec.xyw());
+    REQUIRE(shader("xzw")->Run({iota_vec}).xyz() == iota_vec.xzw());
+    REQUIRE(shader("yzw")->Run({iota_vec}).xyz() == iota_vec.yzw());
+    REQUIRE(shader("yyy")->Run({iota_vec}).xyz() == iota_vec.yyy());
+    REQUIRE(shader("yxw")->Run({iota_vec}).xyz() == iota_vec.yxw());
+    REQUIRE(shader("xyzw")->Run({iota_vec}) == iota_vec);
+    REQUIRE(shader("wzxy")->Run({iota_vec}) ==
+            Common::Vec4f(iota_vec.w, iota_vec.z, iota_vec.x, iota_vec.y));
REQUIRE(shader("yyyy")->Run({iota_vec}) == + Common::Vec4f(iota_vec.y, iota_vec.y, iota_vec.y, iota_vec.y)); +} + +#endif // CITRA_ARCH(x86_64) || CITRA_ARCH(arm64) diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt index 7951002c89..cdba5cbd31 100644 --- a/src/video_core/CMakeLists.txt +++ b/src/video_core/CMakeLists.txt @@ -149,6 +149,10 @@ add_library(video_core STATIC shader/shader.h shader/shader_interpreter.cpp shader/shader_interpreter.h + shader/shader_jit_a64.cpp + shader/shader_jit_a64_compiler.cpp + shader/shader_jit_a64.h + shader/shader_jit_a64_compiler.h shader/shader_jit_x64.cpp shader/shader_jit_x64_compiler.cpp shader/shader_jit_x64.h @@ -177,6 +181,10 @@ if ("x86_64" IN_LIST ARCHITECTURE) target_link_libraries(video_core PUBLIC xbyak) endif() +if ("arm64" IN_LIST ARCHITECTURE) + target_link_libraries(video_core PUBLIC oaknut) +endif() + if (CITRA_USE_PRECOMPILED_HEADERS) target_precompile_headers(video_core PRIVATE precompiled_headers.h) endif() diff --git a/src/video_core/shader/shader.cpp b/src/video_core/shader/shader.cpp index b8393c379e..6124a64851 100644 --- a/src/video_core/shader/shader.cpp +++ b/src/video_core/shader/shader.cpp @@ -15,7 +15,9 @@ #include "video_core/shader/shader_interpreter.h" #if CITRA_ARCH(x86_64) #include "video_core/shader/shader_jit_x64.h" -#endif // CITRA_ARCH(x86_64) +#elif CITRA_ARCH(arm64) +#include "video_core/shader/shader_jit_a64.h" +#endif #include "video_core/video_core.h" namespace Pica::Shader { @@ -141,27 +143,29 @@ MICROPROFILE_DEFINE(GPU_Shader, "GPU", "Shader", MP_RGB(50, 50, 240)); #if CITRA_ARCH(x86_64) static std::unique_ptr jit_engine; -#endif // CITRA_ARCH(x86_64) +#elif CITRA_ARCH(arm64) +static std::unique_ptr jit_engine; +#endif static InterpreterEngine interpreter_engine; ShaderEngine* GetEngine() { -#if CITRA_ARCH(x86_64) +#if CITRA_ARCH(x86_64) || CITRA_ARCH(arm64) // TODO(yuriks): Re-initialize on each change rather than being persistent if (VideoCore::g_shader_jit_enabled) { if (jit_engine == nullptr) { - jit_engine = std::make_unique(); + jit_engine = std::make_unique(); } return jit_engine.get(); } -#endif // CITRA_ARCH(x86_64) +#endif // CITRA_ARCH(x86_64) || CITRA_ARCH(arm64) return &interpreter_engine; } void Shutdown() { -#if CITRA_ARCH(x86_64) +#if CITRA_ARCH(x86_64) || CITRA_ARCH(arm64) jit_engine = nullptr; -#endif // CITRA_ARCH(x86_64) +#endif // CITRA_ARCH(x86_64) || CITRA_ARCH(arm64) } } // namespace Pica::Shader diff --git a/src/video_core/shader/shader_jit_a64.cpp b/src/video_core/shader/shader_jit_a64.cpp new file mode 100644 index 0000000000..9f6ffbbc93 --- /dev/null +++ b/src/video_core/shader/shader_jit_a64.cpp @@ -0,0 +1,51 @@ +// Copyright 2023 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. 
+
+#include "common/arch.h"
+#if CITRA_ARCH(arm64)
+
+#include "common/assert.h"
+#include "common/microprofile.h"
+#include "video_core/shader/shader.h"
+#include "video_core/shader/shader_jit_a64.h"
+#include "video_core/shader/shader_jit_a64_compiler.h"
+
+namespace Pica::Shader {
+
+JitA64Engine::JitA64Engine() = default;
+JitA64Engine::~JitA64Engine() = default;
+
+void JitA64Engine::SetupBatch(ShaderSetup& setup, unsigned int entry_point) {
+    ASSERT(entry_point < MAX_PROGRAM_CODE_LENGTH);
+    setup.engine_data.entry_point = entry_point;
+
+    u64 code_hash = setup.GetProgramCodeHash();
+    u64 swizzle_hash = setup.GetSwizzleDataHash();
+
+    u64 cache_key = code_hash ^ swizzle_hash;
+    auto iter = cache.find(cache_key);
+    if (iter != cache.end()) {
+        setup.engine_data.cached_shader = iter->second.get();
+    } else {
+        auto shader = std::make_unique<JitShader>();
+        shader->Compile(&setup.program_code, &setup.swizzle_data);
+        setup.engine_data.cached_shader = shader.get();
+        cache.emplace_hint(iter, cache_key, std::move(shader));
+    }
+}
+
+MICROPROFILE_DECLARE(GPU_Shader);
+
+void JitA64Engine::Run(const ShaderSetup& setup, UnitState& state) const {
+    ASSERT(setup.engine_data.cached_shader != nullptr);
+
+    MICROPROFILE_SCOPE(GPU_Shader);
+
+    const JitShader* shader = static_cast<const JitShader*>(setup.engine_data.cached_shader);
+    shader->Run(setup, state, setup.engine_data.entry_point);
+}
+
+} // namespace Pica::Shader
+
+#endif // CITRA_ARCH(arm64)
diff --git a/src/video_core/shader/shader_jit_a64.h b/src/video_core/shader/shader_jit_a64.h
new file mode 100644
index 0000000000..8b19695f0a
--- /dev/null
+++ b/src/video_core/shader/shader_jit_a64.h
@@ -0,0 +1,33 @@
+// Copyright 2023 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#pragma once
+
+#include "common/arch.h"
+#if CITRA_ARCH(arm64)
+
+#include <memory>
+#include <unordered_map>
+#include "common/common_types.h"
+#include "video_core/shader/shader.h"
+
+namespace Pica::Shader {
+
+class JitShader;
+
+class JitA64Engine final : public ShaderEngine {
+public:
+    JitA64Engine();
+    ~JitA64Engine() override;
+
+    void SetupBatch(ShaderSetup& setup, unsigned int entry_point) override;
+    void Run(const ShaderSetup& setup, UnitState& state) const override;
+
+private:
+    std::unordered_map<u64, std::unique_ptr<JitShader>> cache;
+};
+
+} // namespace Pica::Shader
+
+#endif // CITRA_ARCH(arm64)
diff --git a/src/video_core/shader/shader_jit_a64_compiler.cpp b/src/video_core/shader/shader_jit_a64_compiler.cpp
new file mode 100644
index 0000000000..176a14f26a
--- /dev/null
+++ b/src/video_core/shader/shader_jit_a64_compiler.cpp
@@ -0,0 +1,1207 @@
+// Copyright 2023 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#include "common/arch.h"
+#if CITRA_ARCH(arm64)
+
+#include <algorithm>
+#include <array>
+#include <bitset>
+#include <cstddef>
+#include "common/aarch64/cpu_detect.h"
+#include "common/aarch64/oaknut_abi.h"
+#include "common/aarch64/oaknut_util.h"
+#include "common/assert.h"
+#include "common/logging/log.h"
+#include "common/vector_math.h"
+#include "video_core/pica_state.h"
+#include "video_core/pica_types.h"
+#include "video_core/shader/shader.h"
+#include "video_core/shader/shader_jit_a64_compiler.h"
+
+using namespace Common::A64;
+using namespace oaknut;
+using namespace oaknut::util;
+
+using nihstro::DestRegister;
+using nihstro::RegisterType;
+
+namespace Pica::Shader {
+
+typedef void (JitShader::*JitFunction)(Instruction instr);
+
+const std::array<JitFunction, 64> instr_table = {
+    &JitShader::Compile_ADD,    // add
+    &JitShader::Compile_DP3,    // dp3
+    &JitShader::Compile_DP4,    // dp4
+    &JitShader::Compile_DPH,    // dph
+    nullptr,                    // unknown
+    &JitShader::Compile_EX2,    // ex2
+    &JitShader::Compile_LG2,    // lg2
+    nullptr,                    // unknown
+    &JitShader::Compile_MUL,    // mul
+    &JitShader::Compile_SGE,    // sge
+    &JitShader::Compile_SLT,    // slt
+    &JitShader::Compile_FLR,    // flr
+    &JitShader::Compile_MAX,    // max
+    &JitShader::Compile_MIN,    // min
+    &JitShader::Compile_RCP,    // rcp
+    &JitShader::Compile_RSQ,    // rsq
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    &JitShader::Compile_MOVA,   // mova
+    &JitShader::Compile_MOV,    // mov
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    &JitShader::Compile_DPH,    // dphi
+    nullptr,                    // unknown
+    &JitShader::Compile_SGE,    // sgei
+    &JitShader::Compile_SLT,    // slti
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    nullptr,                    // unknown
+    &JitShader::Compile_NOP,    // nop
+    &JitShader::Compile_END,    // end
+    &JitShader::Compile_BREAKC, // breakc
+    &JitShader::Compile_CALL,   // call
+    &JitShader::Compile_CALLC,  // callc
+    &JitShader::Compile_CALLU,  // callu
+    &JitShader::Compile_IF,     // ifu
+    &JitShader::Compile_IF,     // ifc
+    &JitShader::Compile_LOOP,   // loop
+    &JitShader::Compile_EMIT,   // emit
+    &JitShader::Compile_SETE,   // sete
+    &JitShader::Compile_JMP,    // jmpc
+    &JitShader::Compile_JMP,    // jmpu
+    &JitShader::Compile_CMP,    // cmp
+    &JitShader::Compile_CMP,    // cmp
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // madi
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+    &JitShader::Compile_MAD,    // mad
+};
+
+// The following is used to alias some commonly used registers:
+/// Pointer to the uniform memory
+constexpr XReg UNIFORMS = X9;
+/// The two 32-bit VS address offset registers set by the MOVA instruction
+constexpr XReg ADDROFFS_REG_0 = X10;
+constexpr XReg ADDROFFS_REG_1 = X11;
+/// VS loop count register (Multiplied by 16)
+constexpr WReg LOOPCOUNT_REG = W12;
+/// Current VS loop iteration number (we could probably use LOOPCOUNT_REG, but this is quicker)
+constexpr WReg LOOPCOUNT = W6;
+/// Number to increment LOOPCOUNT_REG by on each loop iteration (Multiplied by 16)
+constexpr WReg LOOPINC = W7;
+/// Result of the previous CMP instruction for the X-component comparison
+constexpr XReg COND0 = X13;
+/// Result of the previous CMP instruction for the Y-component comparison
+constexpr XReg COND1 = X14;
+/// Pointer to the UnitState instance for the current VS unit
+constexpr XReg STATE = X15;
+/// Scratch registers
+constexpr XReg XSCRATCH0 = X4;
+constexpr XReg XSCRATCH1 = X5;
+constexpr QReg VSCRATCH0 = Q0;
+constexpr QReg VSCRATCH1 = Q4;
+constexpr QReg VSCRATCH2 = Q15;
+/// Loaded with the first swizzled source register, otherwise can be used as a scratch register
+constexpr QReg SRC1 = Q1;
+/// Loaded with the second swizzled source register, otherwise can be used as a scratch register
+constexpr QReg SRC2 = Q2;
+/// Loaded with the third swizzled source register, otherwise can be used as a scratch register
+constexpr QReg SRC3 = Q3;
+/// Constant vector of [1.0f, 1.0f, 1.0f, 1.0f], used to efficiently set a vector to one
+constexpr QReg ONE = Q14;
+
+// State registers that must not be modified by external function calls
+// Scratch registers, e.g., SRC1 and VSCRATCH0, have to be saved on the side if needed
+static const std::bitset<64> persistent_regs =
+    BuildRegSet({// Pointers to register blocks
+                 UNIFORMS, STATE,
+                 // Cached registers
+                 ADDROFFS_REG_0, ADDROFFS_REG_1, LOOPCOUNT_REG, COND0, COND1,
+                 // Constants
+                 ONE,
+                 // Loop variables
+                 LOOPCOUNT, LOOPINC,
+                 // Link Register
+                 X30});
+
+/// Raw constant for the source register selector that indicates no swizzling is performed
+static const u8 NO_SRC_REG_SWIZZLE = 0x1b;
+/// Raw constant for the destination register enable mask that indicates all components are enabled
+static const u8 NO_DEST_REG_MASK = 0xf;
+
+static void LogCritical(const char* msg) {
+    LOG_CRITICAL(HW_GPU, "{}", msg);
+}
+
+void JitShader::Compile_Assert(bool condition, const char* msg) {}
+
+/**
+ * Loads and swizzles a source register into the specified QReg register.
+ * @param instr VS instruction, used for determining how to load the source register
+ * @param src_num Number indicating which source register to load (1 = src1, 2 = src2, 3 = src3)
+ * @param src_reg SourceRegister object corresponding to the source register to load
+ * @param dest Destination QReg register to store the loaded, swizzled source register
+ */
+void JitShader::Compile_SwizzleSrc(Instruction instr, unsigned src_num, SourceRegister src_reg,
+                                   QReg dest) {
+    XReg src_ptr = XZR;
+    std::size_t src_offset;
+    switch (src_reg.GetRegisterType()) {
+    case RegisterType::FloatUniform:
+        src_ptr = UNIFORMS;
+        src_offset = Uniforms::GetFloatUniformOffset(src_reg.GetIndex());
+        break;
+    case RegisterType::Input:
+        src_ptr = STATE;
+        src_offset = UnitState::InputOffset(src_reg.GetIndex());
+        break;
+    case RegisterType::Temporary:
+        src_ptr = STATE;
+        src_offset = UnitState::TemporaryOffset(src_reg.GetIndex());
+        break;
+    default:
+        UNREACHABLE_MSG("Encountered unknown source register type: {}", src_reg.GetRegisterType());
+        break;
+    }
+
+    const s32 src_offset_disp = static_cast<s32>(src_offset);
+    ASSERT_MSG(src_offset == static_cast<std::size_t>(src_offset_disp),
+               "Source register offset too large for int type");
+
+    u32 operand_desc_id;
+
+    const bool is_inverted =
+        (0 != (instr.opcode.Value().GetInfo().subtype & OpCode::Info::SrcInversed));
+
+    u32 address_register_index;
+    u32 offset_src;
+
+    if (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MAD ||
+        instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MADI) {
+        operand_desc_id = instr.mad.operand_desc_id;
+        offset_src = is_inverted ? 3 : 2;
+        address_register_index = instr.mad.address_register_index;
+    } else {
+        operand_desc_id = instr.common.operand_desc_id;
+        offset_src = is_inverted ? 2 : 1;
+        address_register_index = instr.common.address_register_index;
+    }
+
+    if (src_reg.GetRegisterType() == RegisterType::FloatUniform && src_num == offset_src &&
+        address_register_index != 0) {
+        XReg address_reg = XZR;
+        switch (address_register_index) {
+        case 1:
+            address_reg = ADDROFFS_REG_0;
+            break;
+        case 2:
+            address_reg = ADDROFFS_REG_1;
+            break;
+        case 3:
+            address_reg = LOOPCOUNT_REG.toX();
+            break;
+        default:
+            UNREACHABLE();
+            break;
+        }
+
+        // s32 offset = address_reg >= -128 && address_reg <= 127 ? address_reg : 0;
+        // u32 index = (src_reg.GetIndex() + offset) & 0x7f;
+
+        // First we add 128 to address_reg so the first comparison is turned to
+        // address_reg >= 0 && address_reg < 256
+
+        // offset = ((address_reg + 128) < 256) ? address_reg : 0
+        ADD(XSCRATCH1.toW(), address_reg.toW(), 128);
+        CMP(XSCRATCH1.toW(), 256);
+        CSEL(XSCRATCH0.toW(), address_reg.toW(), WZR, Cond::LO);
+
+        // index = (src_reg.GetIndex() + offset) & 0x7f;
+        ADD(XSCRATCH0.toW(), XSCRATCH0.toW(), src_reg.GetIndex());
+        AND(XSCRATCH0.toW(), XSCRATCH0.toW(), 0x7f);
+
+        // index > 95 ? vec4(1.0) : uniforms.f[index];
+        MOV(dest.B16(), ONE.B16());
+        CMP(XSCRATCH0.toW(), 95);
+        Label load_end;
+        B(Cond::GT, load_end);
+        LDR(dest, src_ptr, XSCRATCH0, IndexExt::LSL, 4);
+        l(load_end);
+    } else {
+        // Load the source
+        LDR(dest, src_ptr, src_offset_disp);
+    }
+
+    const SwizzlePattern swiz = {(*swizzle_data)[operand_desc_id]};
+
+    // Generate instructions for source register swizzling as needed
+    u8 sel = swiz.GetRawSelector(src_num);
+    if (sel != NO_SRC_REG_SWIZZLE) {
+        const int table[] = {
+            ((sel & 0b11'00'00'00) >> 6),
+            ((sel & 0b00'11'00'00) >> 4),
+            ((sel & 0b00'00'11'00) >> 2),
+            ((sel & 0b00'00'00'11) >> 0),
+        };
+
+        // Generate table-vector
+        MOV(XSCRATCH0.toW(), u32(0x03'02'01'00u + (table[0] * 0x04'04'04'04u)));
+        MOV(VSCRATCH0.Selem()[0], XSCRATCH0.toW());
+
+        MOV(XSCRATCH0.toW(), u32(0x03'02'01'00u + (table[1] * 0x04'04'04'04u)));
+        MOV(VSCRATCH0.Selem()[1], XSCRATCH0.toW());
+
+        MOV(XSCRATCH0.toW(), u32(0x03'02'01'00u + (table[2] * 0x04'04'04'04u)));
+        MOV(VSCRATCH0.Selem()[2], XSCRATCH0.toW());
+
+        MOV(XSCRATCH0.toW(), u32(0x03'02'01'00u + (table[3] * 0x04'04'04'04u)));
+        MOV(VSCRATCH0.Selem()[3], XSCRATCH0.toW());
+
+        TBL(dest.B16(), List{dest.B16()}, VSCRATCH0.B16());
+    }
+
+    // If the source register should be negated, flip the sign bit of each lane
+    const bool negate[] = {swiz.negate_src1, swiz.negate_src2, swiz.negate_src3};
+    if (negate[src_num - 1]) {
+        FNEG(dest.S4(), dest.S4());
+    }
+}
+
+void JitShader::Compile_DestEnable(Instruction instr, QReg src) {
+    DestRegister dest;
+    u32 operand_desc_id;
+    if (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MAD ||
+        instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MADI) {
+        operand_desc_id = instr.mad.operand_desc_id;
+        dest = instr.mad.dest.Value();
+    } else {
+        operand_desc_id = instr.common.operand_desc_id;
+        dest = instr.common.dest.Value();
+    }
+
+    SwizzlePattern swiz = {(*swizzle_data)[operand_desc_id]};
+
+    std::size_t dest_offset_disp;
+    switch (dest.GetRegisterType()) {
+    case RegisterType::Output:
+        dest_offset_disp = UnitState::OutputOffset(dest.GetIndex());
+        break;
+    case RegisterType::Temporary:
+        dest_offset_disp = UnitState::TemporaryOffset(dest.GetIndex());
+        break;
+    default:
+        UNREACHABLE_MSG("Encountered unknown destination register type: {}",
+                        dest.GetRegisterType());
+        break;
+    }
+
+    // If all components are enabled, write the result to the destination register
+    if (swiz.dest_mask == NO_DEST_REG_MASK) {
+        // Store dest back to memory
+        STR(src, STATE, dest_offset_disp);
+
+    } else {
+        // Not all components are enabled, so mask the result when storing to the destination
+        // register...
+        LDR(VSCRATCH0, STATE, dest_offset_disp);
+
+        // MOVI with RepImm expands an 8-bit immediate into a 64-bit value by replicating bits.
+        // The 8-bit immediate "a:b:c:d:e:f:g:h" maps to the 64-bit value:
+        // "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhh"
+        if (((swiz.dest_mask & 0b1100) >> 2) == (swiz.dest_mask & 0b11)) {
+            // Upper/Lower halves are the same bit-pattern, broadcast the same mask to both
+            // 64-bit lanes
+            const u8 rep_imm = ((swiz.dest_mask & 4) ? 0b11'11'00'00 : 0) |
+                               ((swiz.dest_mask & 8) ? 0b00'00'11'11 : 0);
+
+            MOVI(VSCRATCH2.D2(), RepImm(rep_imm));
+        } else if ((swiz.dest_mask & 0b0011) == 0) {
+            // Upper elements are zero, create the final mask in the 64-bit lane
+            const u8 rep_imm = ((swiz.dest_mask & 4) ? 0b11'11'00'00 : 0) |
+                               ((swiz.dest_mask & 8) ? 0b00'00'11'11 : 0);
+
+            MOVI(VSCRATCH2.toD(), RepImm(rep_imm));
+        } else {
+            // Create a 16-bit-per-component mask in the low 64 bits, then widen each
+            // element to 32 bits
+            const u8 rep_imm = ((swiz.dest_mask & 1) ? 0b11'00'00'00 : 0) |
+                               ((swiz.dest_mask & 2) ? 0b00'11'00'00 : 0) |
+                               ((swiz.dest_mask & 4) ? 0b00'00'11'00 : 0) |
+                               ((swiz.dest_mask & 8) ? 0b00'00'00'11 : 0);
+
+            MOVI(VSCRATCH2.toD(), RepImm(rep_imm));
+
+            // Widen 16->32
+            ZIP1(VSCRATCH2.H8(), VSCRATCH2.H8(), VSCRATCH2.H8());
+        }
+
+        // Select between src and dst using mask
+        BSL(VSCRATCH2.B16(), src.B16(), VSCRATCH0.B16());
+
+        // Store dest back to memory
+        STR(VSCRATCH2, STATE, dest_offset_disp);
+    }
+}
+
+void JitShader::Compile_SanitizedMul(QReg src1, QReg src2, QReg scratch0) {
+    // 0 * inf and inf * 0 in the PICA should return 0 instead of NaN. This can be implemented by
+    // checking for NaNs before and after the multiplication. If the multiplication result is NaN
+    // where neither source was, this NaN was generated by a 0 * inf multiplication, and so the
+    // result should be transformed to 0 to match PICA fp rules.
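+    // FMULX and FMUL differ only for 0 * ±Inf: FMULX returns ±2.0 where FMUL returns NaN.
+    // CMEQ compares raw bit patterns, so lanes where both multiplies agree (including lanes
+    // where both produced the same NaN) become all-ones, while 0 * Inf lanes become zero;
+    // the AND then clears exactly those 0 * Inf lanes to +0.0.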
+    FMULX(VSCRATCH0.S4(), src1.S4(), src2.S4());
+    FMUL(src1.S4(), src1.S4(), src2.S4());
+    CMEQ(VSCRATCH0.S4(), VSCRATCH0.S4(), src1.S4());
+    AND(src1.B16(), src1.B16(), VSCRATCH0.B16());
+}
+
+void JitShader::Compile_EvaluateCondition(Instruction instr) {
+    // Note: NXOR is used below to check for equality
+    switch (instr.flow_control.op) {
+    case Instruction::FlowControlType::Or:
+        MOV(XSCRATCH0, (instr.flow_control.refx.Value() ^ 1));
+        MOV(XSCRATCH1, (instr.flow_control.refy.Value() ^ 1));
+        EOR(XSCRATCH0, XSCRATCH0, COND0);
+        EOR(XSCRATCH1, XSCRATCH1, COND1);
+        ORR(XSCRATCH0, XSCRATCH0, XSCRATCH1);
+        break;
+
+    case Instruction::FlowControlType::And:
+        MOV(XSCRATCH0, (instr.flow_control.refx.Value() ^ 1));
+        MOV(XSCRATCH1, (instr.flow_control.refy.Value() ^ 1));
+        EOR(XSCRATCH0, XSCRATCH0, COND0);
+        EOR(XSCRATCH1, XSCRATCH1, COND1);
+        AND(XSCRATCH0, XSCRATCH0, XSCRATCH1);
+        break;
+
+    case Instruction::FlowControlType::JustX:
+        MOV(XSCRATCH0, (instr.flow_control.refx.Value() ^ 1));
+        EOR(XSCRATCH0, XSCRATCH0, COND0);
+        break;
+
+    case Instruction::FlowControlType::JustY:
+        MOV(XSCRATCH0, (instr.flow_control.refy.Value() ^ 1));
+        EOR(XSCRATCH0, XSCRATCH0, COND1);
+        break;
+    }
+    CMP(XSCRATCH0, 0);
+}
+
+void JitShader::Compile_UniformCondition(Instruction instr) {
+    const std::size_t offset = Uniforms::GetBoolUniformOffset(instr.flow_control.bool_uniform_id);
+    LDRB(XSCRATCH0.toW(), UNIFORMS, offset);
+    CMP(XSCRATCH0.toW(), 0);
+}
+
+std::bitset<64> JitShader::PersistentCallerSavedRegs() {
+    return persistent_regs & ABI_ALL_CALLER_SAVED;
+}
+
+void JitShader::Compile_ADD(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+    FADD(SRC1.S4(), SRC1.S4(), SRC2.S4());
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_DP3(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+
+    Compile_SanitizedMul(SRC1, SRC2, VSCRATCH0);
+
+    // Set last element to 0.0
+    MOV(SRC1.Selem()[3], WZR);
+
+    FADDP(SRC1.S4(), SRC1.S4(), SRC1.S4());
+    FADDP(SRC1.toS(), SRC1.toD().S2());
+    DUP(SRC1.S4(), SRC1.Selem()[0]);
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_DP4(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+
+    Compile_SanitizedMul(SRC1, SRC2, VSCRATCH0);
+
+    FADDP(SRC1.S4(), SRC1.S4(), SRC1.S4());
+    FADDP(SRC1.toS(), SRC1.toD().S2());
+    DUP(SRC1.S4(), SRC1.Selem()[0]);
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_DPH(Instruction instr) {
+    if (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::DPHI) {
+        Compile_SwizzleSrc(instr, 1, instr.common.src1i, SRC1);
+        Compile_SwizzleSrc(instr, 2, instr.common.src2i, SRC2);
+    } else {
+        Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+        Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+    }
+
+    // Set 4th component to 1.0
+    MOV(SRC1.Selem()[3], ONE.Selem()[0]);
+
+    Compile_SanitizedMul(SRC1, SRC2, VSCRATCH0);
+
+    FADDP(SRC1.S4(), SRC1.S4(), SRC1.S4());
+    FADDP(SRC1.toS(), SRC1.toD().S2());
+    DUP(SRC1.S4(), SRC1.Selem()[0]);
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_EX2(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    STR(X30, SP, POST_INDEXED, -16);
+    BL(exp2_subroutine);
+    LDR(X30, SP, PRE_INDEXED, 16);
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_LG2(Instruction instr) {
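+    // As in EX2 above: BL clobbers LR (X30), and EX2/LG2 may themselves be reached through a
+    // CALL subroutine, so X30 is saved across the prelude call in a 16-byte slot to keep the
+    // A64 stack pointer 16-byte aligned.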
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    STR(X30, SP, POST_INDEXED, -16);
+    BL(log2_subroutine);
+    LDR(X30, SP, PRE_INDEXED, 16);
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_MUL(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+    Compile_SanitizedMul(SRC1, SRC2, VSCRATCH0);
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_SGE(Instruction instr) {
+    if (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::SGEI) {
+        Compile_SwizzleSrc(instr, 1, instr.common.src1i, SRC1);
+        Compile_SwizzleSrc(instr, 2, instr.common.src2i, SRC2);
+    } else {
+        Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+        Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+    }
+
+    FCMGE(SRC2.S4(), SRC1.S4(), SRC2.S4());
+    AND(SRC2.B16(), SRC2.B16(), ONE.B16());
+
+    Compile_DestEnable(instr, SRC2);
+}
+
+void JitShader::Compile_SLT(Instruction instr) {
+    if (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::SLTI) {
+        Compile_SwizzleSrc(instr, 1, instr.common.src1i, SRC1);
+        Compile_SwizzleSrc(instr, 2, instr.common.src2i, SRC2);
+    } else {
+        Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+        Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+    }
+
+    FCMGT(SRC1.S4(), SRC2.S4(), SRC1.S4());
+    AND(SRC1.B16(), SRC1.B16(), ONE.B16());
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_FLR(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    FRINTM(SRC1.S4(), SRC1.S4());
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_MAX(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+
+    // VSCRATCH0 = Ordinal(SRC1)
+    FCMEQ(VSCRATCH0.S4(), SRC1.S4(), SRC1.S4());
+
+    // FMAX will always pick the NaN
+    FMAX(SRC1.S4(), SRC1.S4(), SRC2.S4());
+
+    // In the case of NaN, pick SRC2
+    BIF(SRC1.B16(), SRC2.B16(), VSCRATCH0.B16());
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_MIN(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+
+    // VSCRATCH0 = Ordinal(SRC1)
+    FCMEQ(VSCRATCH0.S4(), SRC1.S4(), SRC1.S4());
+
+    // FMIN will always pick the NaN
+    FMIN(SRC1.S4(), SRC1.S4(), SRC2.S4());
+
+    // In the case of NaN, pick SRC2
+    BIF(SRC1.B16(), SRC2.B16(), VSCRATCH0.B16());
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_MOVA(Instruction instr) {
+    SwizzlePattern swiz = {(*swizzle_data)[instr.common.operand_desc_id]};
+
+    if (!swiz.DestComponentEnabled(0) && !swiz.DestComponentEnabled(1)) {
+        return;
+    }
+
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+
+    // Convert floats to integers using truncation (only care about X and Y components)
+    FCVTZS(SRC1.S4(), SRC1.S4());
+
+    // Get result
+    MOV(XSCRATCH0, SRC1.Delem()[0]);
+
+    // Handle destination enable
+    if (swiz.DestComponentEnabled(0) && swiz.DestComponentEnabled(1)) {
+        // Move and sign-extend low 32 bits
+        SXTW(ADDROFFS_REG_0, XSCRATCH0.toW());
+
+        // Move and sign-extend high 32 bits
+        LSR(XSCRATCH0, XSCRATCH0, 32);
+        SXTW(ADDROFFS_REG_1, XSCRATCH0.toW());
+    } else {
+        if (swiz.DestComponentEnabled(0)) {
+            // Move and sign-extend low 32 bits
+            SXTW(ADDROFFS_REG_0, XSCRATCH0.toW());
+        } else if (swiz.DestComponentEnabled(1)) {
+            // Move and sign-extend high 32 bits
+            LSR(XSCRATCH0, XSCRATCH0, 32);
+            SXTW(ADDROFFS_REG_1, XSCRATCH0.toW());
+        }
+    }
+}
+
+void JitShader::Compile_MOV(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_RCP(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+
+    // FRECPE can be pretty inaccurate
+    // FRECPE(1.0f) = 0.99805f != 1.0f
+    // FRECPE(SRC1.S4(), SRC1.S4());
+    // Just do an exact 1.0f / N
+    FDIV(SRC1.toS(), ONE.toS(), SRC1.toS());
+
+    DUP(SRC1.S4(), SRC1.Selem()[0]); // XYZW -> XXXX
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_RSQ(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+
+    // FRSQRTE can be pretty inaccurate
+    // FRSQRTE(8.0f) = 0.35254f != 0.3535533845
+    // FRSQRTE(SRC1.S4(), SRC1.S4());
+    // Just do an exact 1.0f / sqrt(N)
+    FSQRT(SRC1.toS(), SRC1.toS());
+    FDIV(SRC1.toS(), ONE.toS(), SRC1.toS());
+
+    DUP(SRC1.S4(), SRC1.Selem()[0]); // XYZW -> XXXX
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_NOP(Instruction instr) {}
+
+void JitShader::Compile_END(Instruction instr) {
+    // Save conditional code
+    STRB(COND0.toW(), STATE, u32(offsetof(UnitState, conditional_code[0])));
+    STRB(COND1.toW(), STATE, u32(offsetof(UnitState, conditional_code[1])));
+
+    // Save address/loop registers
+    STR(ADDROFFS_REG_0.toW(), STATE, u32(offsetof(UnitState, address_registers[0])));
+    STR(ADDROFFS_REG_1.toW(), STATE, u32(offsetof(UnitState, address_registers[1])));
+    STR(LOOPCOUNT_REG.toW(), STATE, u32(offsetof(UnitState, address_registers[2])));
+
+    ABI_PopRegisters(*this, ABI_ALL_CALLEE_SAVED, 16);
+    RET();
+}
+
+void JitShader::Compile_BREAKC(Instruction instr) {
+    Compile_Assert(loop_depth, "BREAKC must be inside a LOOP");
+    if (loop_depth) {
+        Compile_EvaluateCondition(instr);
+        ASSERT(!loop_break_labels.empty());
+        B(Cond::NE, loop_break_labels.back());
+    }
+}
+
+void JitShader::Compile_CALL(Instruction instr) {
+    // Push offset of the return and link-register
+    MOV(XSCRATCH0, instr.flow_control.dest_offset + instr.flow_control.num_instructions);
+    STP(XSCRATCH0, X30, SP, POST_INDEXED, -16);
+
+    // Call the subroutine
+    BL(instruction_labels[instr.flow_control.dest_offset]);
+
+    // Restore the link-register
+    // Skip over the return offset that's on the stack
+    LDP(XZR, X30, SP, PRE_INDEXED, 16);
+}
+
+void JitShader::Compile_CALLC(Instruction instr) {
+    Compile_EvaluateCondition(instr);
+    Label b;
+    B(Cond::EQ, b);
+    Compile_CALL(instr);
+    l(b);
+}
+
+void JitShader::Compile_CALLU(Instruction instr) {
+    Compile_UniformCondition(instr);
+    Label b;
+    B(Cond::EQ, b);
+    Compile_CALL(instr);
+    l(b);
+}
+
+void JitShader::Compile_CMP(Instruction instr) {
+    using Op = Instruction::Common::CompareOpType::Op;
+    Op op_x = instr.common.compare_op.x;
+    Op op_y = instr.common.compare_op.y;
+
+    Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
+    Compile_SwizzleSrc(instr, 2, instr.common.src2, SRC2);
+
+    static constexpr Cond cmp[] = {Cond::EQ, Cond::NE, Cond::LT, Cond::LE, Cond::GT, Cond::GE};
+
+    // Compare X-component
+    FCMP(SRC1.toS(), SRC2.toS());
+    CSET(COND0, cmp[op_x]);
+
+    // Compare Y-component
+    MOV(VSCRATCH0.toS(), SRC1.Selem()[1]);
+    MOV(VSCRATCH1.toS(), SRC2.Selem()[1]);
+    FCMP(VSCRATCH0.toS(), VSCRATCH1.toS());
+    CSET(COND1, cmp[op_y]);
+}
+
+void JitShader::Compile_MAD(Instruction instr) {
+    Compile_SwizzleSrc(instr, 1, instr.mad.src1, SRC1);
+
+    if (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MADI) {
+        Compile_SwizzleSrc(instr, 2, instr.mad.src2i, SRC2);
+        Compile_SwizzleSrc(instr, 3, instr.mad.src3i, SRC3);
+    } else {
+        Compile_SwizzleSrc(instr, 2, instr.mad.src2, SRC2);
+        Compile_SwizzleSrc(instr, 3, instr.mad.src3, SRC3);
+    }
+
+    Compile_SanitizedMul(SRC1, SRC2, VSCRATCH0);
+    FADD(SRC1.S4(), SRC1.S4(), SRC3.S4());
+
+    Compile_DestEnable(instr, SRC1);
+}
+
+void JitShader::Compile_IF(Instruction instr) {
+    Compile_Assert(instr.flow_control.dest_offset >= program_counter,
+                   "Backwards if-statements not supported");
+    Label l_else, l_endif;
+
+    // Evaluate the "IF" condition
+    if (instr.opcode.Value() == OpCode::Id::IFU) {
+        Compile_UniformCondition(instr);
+    } else if (instr.opcode.Value() == OpCode::Id::IFC) {
+        Compile_EvaluateCondition(instr);
+    }
+    B(Cond::EQ, l_else);
+
+    // Compile the code that corresponds to the condition evaluating as true
+    Compile_Block(instr.flow_control.dest_offset);
+
+    // If there isn't an "ELSE" condition, we are done here
+    if (instr.flow_control.num_instructions == 0) {
+        l(l_else);
+        return;
+    }
+
+    B(l_endif);
+
+    l(l_else);
+    // This code corresponds to the "ELSE" condition
+    // Compile the code that corresponds to the condition evaluating as false
+    Compile_Block(instr.flow_control.dest_offset + instr.flow_control.num_instructions);
+
+    l(l_endif);
+}
+
+void JitShader::Compile_LOOP(Instruction instr) {
+    Compile_Assert(instr.flow_control.dest_offset >= program_counter,
+                   "Backwards loops not supported");
+    Compile_Assert(loop_depth < 1, "Nested loops may not be supported");
+    if (loop_depth++) {
+        const auto loop_save_regs = BuildRegSet({LOOPCOUNT_REG, LOOPINC, LOOPCOUNT});
+        ABI_PushRegisters(*this, loop_save_regs);
+    }
+
+    // This decodes the fields from the integer uniform at index instr.flow_control.int_uniform_id
+    const std::size_t offset = Uniforms::GetIntUniformOffset(instr.flow_control.int_uniform_id);
+    LDR(LOOPCOUNT, UNIFORMS, offset);
+
+    UBFX(LOOPCOUNT_REG, LOOPCOUNT, 8, 8); // Y-component is the start
+    UBFX(LOOPINC, LOOPCOUNT, 16, 8);      // Z-component is the incrementer
+    UXTB(LOOPCOUNT, LOOPCOUNT);           // X-component is iteration count
+    ADD(LOOPCOUNT, LOOPCOUNT, 1);         // Iteration count is X-component + 1
+
+    Label l_loop_start;
+    l(l_loop_start);
+
+    loop_break_labels.emplace_back(oaknut::Label());
+    Compile_Block(instr.flow_control.dest_offset + 1);
+    ADD(LOOPCOUNT_REG, LOOPCOUNT_REG, LOOPINC); // Increment LOOPCOUNT_REG by Z-component
+    SUBS(LOOPCOUNT, LOOPCOUNT, 1);              // Decrement the iteration counter
+    B(Cond::NE, l_loop_start);                  // Loop while the counter is non-zero
+
+    l(loop_break_labels.back());
+    loop_break_labels.pop_back();
+
+    if (--loop_depth) {
+        const auto loop_save_regs = BuildRegSet({LOOPCOUNT_REG, LOOPINC, LOOPCOUNT});
+        ABI_PopRegisters(*this, loop_save_regs);
+    }
+}
+
+void JitShader::Compile_JMP(Instruction instr) {
+    if (instr.opcode.Value() == OpCode::Id::JMPC) {
+        Compile_EvaluateCondition(instr);
+    } else if (instr.opcode.Value() == OpCode::Id::JMPU) {
+        Compile_UniformCondition(instr);
+    } else {
+        UNREACHABLE();
+    }
+
+    const bool inverted_condition =
+        (instr.opcode.Value() == OpCode::Id::JMPU) && (instr.flow_control.num_instructions & 1);
+
+    Label& b = instruction_labels[instr.flow_control.dest_offset];
+    if (inverted_condition) {
+        B(Cond::EQ, b);
+    } else {
+        B(Cond::NE, b);
+    }
+}
+
+static void Emit(GSEmitter* emitter, Common::Vec4<f24> (*output)[16]) {
+    emitter->Emit(*output);
+}
+
+void JitShader::Compile_EMIT(Instruction instr) {
+    Label have_emitter, end;
+
+    LDR(XSCRATCH0, STATE, u32(offsetof(UnitState, emitter_ptr)));
+    CBNZ(XSCRATCH0, have_emitter);
+
+    ABI_PushRegisters(*this, PersistentCallerSavedRegs());
+    MOVP2R(ABI_PARAM1, reinterpret_cast<const void*>("Execute EMIT on VS"));
reinterpret_cast("Execute EMIT on VS")); + CallFarFunction(*this, LogCritical); + ABI_PopRegisters(*this, PersistentCallerSavedRegs()); + B(end); + + l(have_emitter); + ABI_PushRegisters(*this, PersistentCallerSavedRegs()); + MOV(ABI_PARAM1, XSCRATCH0); + MOV(ABI_PARAM2, STATE); + ADD(ABI_PARAM2, ABI_PARAM2, u32(offsetof(UnitState, registers.output))); + CallFarFunction(*this, Emit); + ABI_PopRegisters(*this, PersistentCallerSavedRegs()); + l(end); +} + +void JitShader::Compile_SETE(Instruction instr) { + Label have_emitter, end; + + LDR(XSCRATCH0, STATE, u32(offsetof(UnitState, emitter_ptr))); + + CBNZ(XSCRATCH0, have_emitter); + + ABI_PushRegisters(*this, PersistentCallerSavedRegs()); + MOVP2R(ABI_PARAM1, reinterpret_cast("Execute SETEMIT on VS")); + CallFarFunction(*this, LogCritical); + ABI_PopRegisters(*this, PersistentCallerSavedRegs()); + B(end); + + l(have_emitter); + + MOV(XSCRATCH1.toW(), instr.setemit.vertex_id); + STRB(XSCRATCH1.toW(), XSCRATCH0, u32(offsetof(GSEmitter, vertex_id))); + MOV(XSCRATCH1.toW(), instr.setemit.prim_emit); + STRB(XSCRATCH1.toW(), XSCRATCH0, u32(offsetof(GSEmitter, prim_emit))); + MOV(XSCRATCH1.toW(), instr.setemit.winding); + STRB(XSCRATCH1.toW(), XSCRATCH0, u32(offsetof(GSEmitter, winding))); + + l(end); +} + +void JitShader::Compile_Block(unsigned end) { + while (program_counter < end) { + Compile_NextInstr(); + } +} + +void JitShader::Compile_Return() { + // Peek return offset on the stack and check if we're at that offset + LDR(XSCRATCH0, SP, 16); + CMP(XSCRATCH0.toW(), program_counter); + + // If so, jump back to before CALL + Label b; + B(Cond::NE, b); + RET(); + l(b); +} + +void JitShader::Compile_NextInstr() { + if (std::binary_search(return_offsets.begin(), return_offsets.end(), program_counter)) { + Compile_Return(); + } + + l(instruction_labels[program_counter]); + + const Instruction instr = {(*program_code)[program_counter++]}; + + const OpCode::Id opcode = instr.opcode.Value(); + const auto instr_func = instr_table[static_cast(opcode)]; + + if (instr_func) { + // JIT the instruction! + ((*this).*instr_func)(instr); + } else { + // Unhandled instruction + LOG_CRITICAL(HW_GPU, "Unhandled instruction: 0x{:02x} (0x{:08x})", + static_cast(instr.opcode.Value().EffectiveOpCode()), instr.hex); + } +} + +void JitShader::FindReturnOffsets() { + return_offsets.clear(); + + for (std::size_t offset = 0; offset < program_code->size(); ++offset) { + Instruction instr = {(*program_code)[offset]}; + + switch (instr.opcode.Value()) { + case OpCode::Id::CALL: + case OpCode::Id::CALLC: + case OpCode::Id::CALLU: + return_offsets.push_back(instr.flow_control.dest_offset + + instr.flow_control.num_instructions); + break; + default: + break; + } + } + + // Sort for efficient binary search later + std::sort(return_offsets.begin(), return_offsets.end()); +} + +void JitShader::Compile(const std::array* program_code_, + const std::array* swizzle_data_) { + program_code = program_code_; + swizzle_data = swizzle_data_; + + // Reset flow control state + program = (CompiledShader*)current_address(); + program_counter = 0; + loop_depth = 0; + instruction_labels.fill(Label()); + + // Find all `CALL` instructions and identify return locations + FindReturnOffsets(); + + // The stack pointer is 8 modulo 16 at the entry of a procedure + // We reserve 16 bytes and assign a dummy value to the first 8 bytes, to catch any potential + // return checks (see Compile_Return) that happen in shader main routine. 
+
+void JitShader::Compile(const std::array<u32, MAX_PROGRAM_CODE_LENGTH>* program_code_,
+                        const std::array<u32, MAX_SWIZZLE_DATA_LENGTH>* swizzle_data_) {
+    program_code = program_code_;
+    swizzle_data = swizzle_data_;
+
+    // Reset flow control state
+    program = reinterpret_cast<CompiledShader*>(current_address());
+    program_counter = 0;
+    loop_depth = 0;
+    instruction_labels.fill(Label());
+
+    // Find all `CALL` instructions and identify return locations
+    FindReturnOffsets();
+
+    // The A64 stack pointer must stay 16-byte aligned, so we reserve a full
+    // 16 bytes and assign a dummy value to the first 8 bytes, to catch any potential
+    // return checks (see Compile_Return) that happen in the shader main routine.
+    ABI_PushRegisters(*this, ABI_ALL_CALLEE_SAVED, 16);
+    MVN(XSCRATCH0, XZR);
+    STR(XSCRATCH0, SP, 0);
+
+    MOV(UNIFORMS, ABI_PARAM1);
+    MOV(STATE, ABI_PARAM2);
+
+    // Load address/loop registers
+    LDR(ADDROFFS_REG_0.toW(), STATE, u32(offsetof(UnitState, address_registers[0])));
+    LDR(ADDROFFS_REG_1.toW(), STATE, u32(offsetof(UnitState, address_registers[1])));
+    LDR(LOOPCOUNT_REG.toW(), STATE, u32(offsetof(UnitState, address_registers[2])));
+
+    // Load conditional code
+    LDRB(COND0.toW(), STATE, u32(offsetof(UnitState, conditional_code[0])));
+    LDRB(COND1.toW(), STATE, u32(offsetof(UnitState, conditional_code[1])));
+
+    // Used to set a register to one
+    FMOV(ONE.S4(), FImm8(false, 7, 0));
+
+    // Jump to start of the shader program
+    BR(ABI_PARAM3);
+
+    // Compile entire program
+    Compile_Block(static_cast<unsigned>(program_code->size()));
+
+    // Free memory that's no longer needed
+    program_code = nullptr;
+    swizzle_data = nullptr;
+    return_offsets.clear();
+    return_offsets.shrink_to_fit();
+
+    // Memory is ready to execute
+    protect();
+    invalidate_all();
+
+    const size_t code_size =
+        current_address() - reinterpret_cast<uintptr_t>(oaknut::CodeBlock::ptr());
+
+    ASSERT_MSG(code_size <= MAX_SHADER_SIZE, "Compiled a shader that exceeds the allocated size!");
+    LOG_DEBUG(HW_GPU, "Compiled shader size={}", code_size);
+}
+
+JitShader::JitShader()
+    : oaknut::CodeBlock(MAX_SHADER_SIZE), oaknut::CodeGenerator(oaknut::CodeBlock::ptr()) {
+    unprotect();
+    CompilePrelude();
+}
+
+void JitShader::CompilePrelude() {
+    log2_subroutine = CompilePrelude_Log2();
+    exp2_subroutine = CompilePrelude_Exp2();
+}
+
+oaknut::Label JitShader::CompilePrelude_Log2() {
+    oaknut::Label subroutine;
+
+    // We compute this approximation by first performing a range reduction into the range
+    // [1.0, 2.0). A minimax polynomial which was fit for the function log2(x) / (x - 1) is then
+    // evaluated. We multiply the result by (x - 1) then restore the result into the appropriate
+    // range.
+
+    // Coefficients for the minimax polynomial.
+    // f(x) computes approximately log2(x) / (x - 1).
+    // f(x) = c4 + x * (c3 + x * (c2 + x * (c1 + x * c0)).
+    oaknut::Label c0;
+    align(16);
+    l(c0);
+    dw(0x3d74552f);
+
+    align(16);
+    oaknut::Label c14;
+    l(c14);
+    dw(0xbeee7397);
+    dw(0x3fbd96dd);
+    dw(0xc02153f6);
+    dw(0x4038d96c);
+
+    align(16);
+    oaknut::Label negative_infinity_vector;
+    l(negative_infinity_vector);
+    dw(0xff800000);
+    dw(0xff800000);
+    dw(0xff800000);
+    dw(0xff800000);
+    oaknut::Label default_qnan_vector;
+    l(default_qnan_vector);
+    dw(0x7fc00000);
+    dw(0x7fc00000);
+    dw(0x7fc00000);
+    dw(0x7fc00000);
+
+    oaknut::Label input_is_nan, input_is_zero, input_out_of_range;
+
+    align(16);
+    l(input_out_of_range);
+    B(Cond::EQ, input_is_zero);
+    MOVP2R(XSCRATCH0, default_qnan_vector.ptr());
+    LDR(SRC1, XSCRATCH0);
+    RET();
+
+    l(input_is_zero);
+    MOVP2R(XSCRATCH0, negative_infinity_vector.ptr());
+    LDR(SRC1, XSCRATCH0);
+    RET();
+
+    align(16);
+    l(subroutine);
+
+    // Here we handle the edge cases: inputs that are NaN, zero, or negative.
+    // Ordered(n) ? 0xFFFFFFFF : 0x0
+    FCMEQ(VSCRATCH0.toS(), SRC1.toS(), SRC1.toS());
+    MOV(XSCRATCH0.toW(), VSCRATCH0.Selem()[0]);
+    CMP(XSCRATCH0.toW(), 0);
+    B(Cond::EQ, input_is_nan); // SRC1 == NaN
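+
+    // (Note: the scalar FCMEQ above compares the input with itself; it
+    // yields all-ones when the value is ordered and zero when it is NaN,
+    // so a zero result routes us to the NaN path.)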
+
+    // (0.0 >= n) ? 0xFFFFFFFF : 0x0
+    MOV(XSCRATCH0.toW(), SRC1.Selem()[0]);
+    CMP(XSCRATCH0.toW(), 0);
+    B(Cond::LE, input_out_of_range); // SRC1 <= 0.0
+
+    // Split input: SRC1=MANT[1,2) VSCRATCH1=Exponent
+    MOV(XSCRATCH0.toW(), SRC1.Selem()[0]);
+    MOV(XSCRATCH1.toW(), XSCRATCH0.toW());
+    AND(XSCRATCH1.toW(), XSCRATCH1.toW(), 0x007fffff);
+    ORR(XSCRATCH1.toW(), XSCRATCH1.toW(), 0x3f800000);
+    MOV(SRC1.Selem()[0], XSCRATCH1.toW());
+    // SRC1 now contains the mantissa of the input.
+    UBFX(XSCRATCH0.toW(), XSCRATCH0.toW(), 23, 8);
+    SUB(XSCRATCH0.toW(), XSCRATCH0.toW(), 0x7F);
+    MOV(VSCRATCH1.Selem()[0], XSCRATCH0.toW());
+    SCVTF(VSCRATCH1.toS(), VSCRATCH1.toS());
+    // VSCRATCH1 now contains the (signed) exponent of the input.
+
+    MOVP2R(XSCRATCH0, c0.ptr());
+    LDR(XSCRATCH0.toW(), XSCRATCH0);
+    MOV(VSCRATCH0.Selem()[0], XSCRATCH0.toW());
+
+    // Complete computation of polynomial.
+    // Load C1, C2, C3, C4 into a single scratch register.
+    const QReg C14 = SRC2;
+    MOVP2R(XSCRATCH0, c14.ptr());
+    LDR(C14, XSCRATCH0);
+    FMUL(VSCRATCH0.toS(), VSCRATCH0.toS(), SRC1.toS());
+    FMLA(VSCRATCH0.toS(), ONE.toS(), C14.Selem()[0]);
+    FMUL(VSCRATCH0.toS(), VSCRATCH0.toS(), SRC1.toS());
+    FMLA(VSCRATCH0.toS(), ONE.toS(), C14.Selem()[1]);
+    FMUL(VSCRATCH0.toS(), VSCRATCH0.toS(), SRC1.toS());
+    FMLA(VSCRATCH0.toS(), ONE.toS(), C14.Selem()[2]);
+    FMUL(VSCRATCH0.toS(), VSCRATCH0.toS(), SRC1.toS());
+
+    FSUB(SRC1.toS(), SRC1.toS(), ONE.toS());
+    FMLA(VSCRATCH0.toS(), ONE.toS(), C14.Selem()[3]);
+
+    FMUL(VSCRATCH0.toS(), VSCRATCH0.toS(), SRC1.toS());
+    FADD(VSCRATCH1.toS(), VSCRATCH0.toS(), VSCRATCH1.toS());
+
+    // Duplicate result across vector
+    MOV(SRC1.Selem()[0], VSCRATCH1.Selem()[0]);
+    l(input_is_nan);
+    DUP(SRC1.S4(), SRC1.Selem()[0]);
+
+    RET();
+
+    return subroutine;
+}
+
+oaknut::Label JitShader::CompilePrelude_Exp2() {
+    oaknut::Label subroutine;
+
+    // This approximation first performs a range reduction into the range [-0.5, 0.5). A minimax
+    // polynomial which was fit for the function exp2(x) is then evaluated. We then restore the
+    // result into the appropriate range.
+
+    align(16);
+    const void* input_max = reinterpret_cast<const void*>(current_address());
+    dw(0x43010000);
+    const void* input_min = reinterpret_cast<const void*>(current_address());
+    dw(0xc2fdffff);
+    const void* c0 = reinterpret_cast<const void*>(current_address());
+    dw(0x3c5dbe69);
+    const void* half = reinterpret_cast<const void*>(current_address());
+    dw(0x3f000000);
+    const void* c1 = reinterpret_cast<const void*>(current_address());
+    dw(0x3d5509f9);
+    const void* c2 = reinterpret_cast<const void*>(current_address());
+    dw(0x3e773cc5);
+    const void* c3 = reinterpret_cast<const void*>(current_address());
+    dw(0x3f3168b3);
+    const void* c4 = reinterpret_cast<const void*>(current_address());
+    dw(0x3f800016);
+
+    oaknut::Label ret_label;
+
+    align(16);
+    l(subroutine);
+
+    // Handle edge cases
+    FCMP(SRC1.toS(), SRC1.toS());
+    B(Cond::NE, ret_label); // branch if NaN
+
+    // Decompose input:
+    // VSCRATCH0 = 2^round(input)
+    // SRC1 = input - round(input) [-0.5, 0.5)
+    // Clamp to the maximum range since we shift the value directly into the exponent.
+    MOVP2R(XSCRATCH0, input_max);
+    LDR(VSCRATCH0.toS(), XSCRATCH0);
+    FMIN(SRC1.toS(), SRC1.toS(), VSCRATCH0.toS());
+
+    MOVP2R(XSCRATCH0, input_min);
+    LDR(VSCRATCH0.toS(), XSCRATCH0);
+    FMAX(SRC1.toS(), SRC1.toS(), VSCRATCH0.toS());
+
+    MOVP2R(XSCRATCH0, half);
+    LDR(VSCRATCH0.toS(), XSCRATCH0);
+    FSUB(VSCRATCH0.toS(), SRC1.toS(), VSCRATCH0.toS());
+
+    FCVTNS(VSCRATCH0.toS(), VSCRATCH0.toS());
+    MOV(XSCRATCH0.toW(), VSCRATCH0.Selem()[0]);
+    SCVTF(VSCRATCH0.toS(), XSCRATCH0.toW());
+
+    // VSCRATCH0 now contains input rounded to the nearest integer.
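+    // (Informal sketch of the reconstruction that follows, with i the
+    // rounded input and f = input - i in [-0.5, 0.5):
+    //   exp2(input) ~= poly(f) * 2^i,
+    // where 2^i is built by biasing i with 0x7F and shifting it into the
+    // exponent field of an IEEE-754 single.)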
+    ADD(XSCRATCH0.toW(), XSCRATCH0.toW(), 0x7F);
+    FSUB(SRC1.toS(), SRC1.toS(), VSCRATCH0.toS());
+    // SRC1 contains input - round(input), which is in [-0.5, 0.5).
+    LSL(XSCRATCH0.toW(), XSCRATCH0.toW(), 23);
+    MOV(VSCRATCH0.Selem()[0], XSCRATCH0.toW());
+    // VSCRATCH0 contains 2^(round(input)).
+
+    // Complete computation of polynomial.
+    ADR(XSCRATCH1, c0);
+    LDR(VSCRATCH1.toS(), XSCRATCH1);
+    FMUL(VSCRATCH1.toS(), SRC1.toS(), VSCRATCH1.toS());
+
+    ADR(XSCRATCH1, c1);
+    LDR(VSCRATCH2.toS(), XSCRATCH1);
+    FADD(VSCRATCH1.toS(), VSCRATCH1.toS(), VSCRATCH2.toS());
+    FMUL(VSCRATCH1.toS(), VSCRATCH1.toS(), SRC1.toS());
+
+    ADR(XSCRATCH1, c2);
+    LDR(VSCRATCH2.toS(), XSCRATCH1);
+    FADD(VSCRATCH1.toS(), VSCRATCH1.toS(), VSCRATCH2.toS());
+    FMUL(VSCRATCH1.toS(), VSCRATCH1.toS(), SRC1.toS());
+
+    ADR(XSCRATCH1, c3);
+    LDR(VSCRATCH2.toS(), XSCRATCH1);
+    FADD(VSCRATCH1.toS(), VSCRATCH1.toS(), VSCRATCH2.toS());
+    FMUL(SRC1.toS(), VSCRATCH1.toS(), SRC1.toS());
+
+    ADR(XSCRATCH1, c4);
+    LDR(VSCRATCH2.toS(), XSCRATCH1);
+    FADD(SRC1.toS(), VSCRATCH2.toS(), SRC1.toS());
+
+    FMUL(SRC1.toS(), SRC1.toS(), VSCRATCH0.toS());
+
+    // Duplicate result across vector
+    l(ret_label);
+    DUP(SRC1.S4(), SRC1.Selem()[0]);
+
+    RET();
+
+    return subroutine;
+}
+
+} // namespace Pica::Shader
+
+#endif // CITRA_ARCH(arm64)
diff --git a/src/video_core/shader/shader_jit_a64_compiler.h b/src/video_core/shader/shader_jit_a64_compiler.h
new file mode 100644
index 0000000000..aa39e4e1df
--- /dev/null
+++ b/src/video_core/shader/shader_jit_a64_compiler.h
@@ -0,0 +1,146 @@
+// Copyright 2023 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#pragma once
+
+#include "common/arch.h"
+#if CITRA_ARCH(arm64)
+
+#include <array>
+#include <bitset>
+#include <cstddef>
+#include <optional>
+#include <utility>
+#include <vector>
+#include <nihstro/shader_bytecode.h>
+#include <oaknut/code_block.hpp>
+#include <oaknut/oaknut.hpp>
+#include "common/common_types.h"
+#include "video_core/shader/shader.h"
+
+using nihstro::Instruction;
+using nihstro::OpCode;
+using nihstro::SourceRegister;
+using nihstro::SwizzlePattern;
+
+namespace Pica::Shader {
+
+/// Memory allocated for each compiled shader
+constexpr std::size_t MAX_SHADER_SIZE = MAX_PROGRAM_CODE_LENGTH * 256;
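+
+// (Sizing note: this reserves 256 bytes of emitted A64 code per PICA
+// instruction slot; MAX_PROGRAM_CODE_LENGTH is defined in
+// video_core/shader/shader.h, included above.)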
+
+/**
+ * This class implements the shader JIT compiler. It recompiles a Pica shader program into
+ * AArch64 code that can be executed on the host machine directly.
+ */
+class JitShader : private oaknut::CodeBlock, public oaknut::CodeGenerator {
+public:
+    JitShader();
+
+    void Run(const ShaderSetup& setup, UnitState& state, unsigned offset) const {
+        program(&setup.uniforms, &state, instruction_labels[offset].ptr());
+    }
+
+    void Compile(const std::array<u32, MAX_PROGRAM_CODE_LENGTH>* program_code,
+                 const std::array<u32, MAX_SWIZZLE_DATA_LENGTH>* swizzle_data);
+
+    void Compile_ADD(Instruction instr);
+    void Compile_DP3(Instruction instr);
+    void Compile_DP4(Instruction instr);
+    void Compile_DPH(Instruction instr);
+    void Compile_EX2(Instruction instr);
+    void Compile_LG2(Instruction instr);
+    void Compile_MUL(Instruction instr);
+    void Compile_SGE(Instruction instr);
+    void Compile_SLT(Instruction instr);
+    void Compile_FLR(Instruction instr);
+    void Compile_MAX(Instruction instr);
+    void Compile_MIN(Instruction instr);
+    void Compile_RCP(Instruction instr);
+    void Compile_RSQ(Instruction instr);
+    void Compile_MOVA(Instruction instr);
+    void Compile_MOV(Instruction instr);
+    void Compile_NOP(Instruction instr);
+    void Compile_END(Instruction instr);
+    void Compile_BREAKC(Instruction instr);
+    void Compile_CALL(Instruction instr);
+    void Compile_CALLC(Instruction instr);
+    void Compile_CALLU(Instruction instr);
+    void Compile_IF(Instruction instr);
+    void Compile_LOOP(Instruction instr);
+    void Compile_JMP(Instruction instr);
+    void Compile_CMP(Instruction instr);
+    void Compile_MAD(Instruction instr);
+    void Compile_EMIT(Instruction instr);
+    void Compile_SETE(Instruction instr);
+
+private:
+    void Compile_Block(unsigned end);
+    void Compile_NextInstr();
+
+    void Compile_SwizzleSrc(Instruction instr, unsigned src_num, SourceRegister src_reg,
+                            oaknut::QReg dest);
+    void Compile_DestEnable(Instruction instr, oaknut::QReg dest);
+
+    /**
+     * Compiles a `MUL src1, src2` operation, properly handling the PICA semantics when
+     * multiplying zero by inf. Clobbers `src2` and `scratch0`.
+     */
+    void Compile_SanitizedMul(oaknut::QReg src1, oaknut::QReg src2, oaknut::QReg scratch0);
+
+    void Compile_EvaluateCondition(Instruction instr);
+    void Compile_UniformCondition(Instruction instr);
+
+    /**
+     * Emits the code to conditionally return from a subroutine invoked by the `CALL` instruction.
+     */
+    void Compile_Return();
+
+    std::bitset<64> PersistentCallerSavedRegs();
+
+    /**
+     * Assertion evaluated at compile-time, but only triggered if executed at runtime.
+     * @param condition Condition to be evaluated.
+     * @param msg Message to be logged if the assertion fails.
+     */
+    void Compile_Assert(bool condition, const char* msg);
+
+    /**
+     * Analyzes the entire shader program for `CALL` instructions before emitting any code,
+     * identifying the locations where a return needs to be inserted.
+     */
+    void FindReturnOffsets();
+
+    /**
+     * Emits data and code for utility functions.
+     */
+    void CompilePrelude();
+    oaknut::Label CompilePrelude_Log2();
+    oaknut::Label CompilePrelude_Exp2();
+
+    const std::array<u32, MAX_PROGRAM_CODE_LENGTH>* program_code = nullptr;
+    const std::array<u32, MAX_SWIZZLE_DATA_LENGTH>* swizzle_data = nullptr;
+
+    /// Mapping of Pica VS instructions to pointers in the emitted code
+    std::array<oaknut::Label, MAX_PROGRAM_CODE_LENGTH> instruction_labels;
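+
+    // (Note: keeping one label per PICA instruction is what lets Run() enter
+    // the compiled program at an arbitrary `offset` instead of only at
+    // instruction 0.)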
+
+    /// Labels pointing to the end of each nested LOOP block. Used by the BREAKC instruction to
+    /// break out of a loop.
+    std::vector<oaknut::Label> loop_break_labels;
+
+    /// Offsets in code where a return needs to be inserted
+    std::vector<unsigned> return_offsets;
+
+    unsigned program_counter = 0; ///< Offset of the next instruction to decode
+    u8 loop_depth = 0;            ///< Depth of the (nested) loops currently compiled
+
+    using CompiledShader = void(const void* setup, void* state, const std::byte* start_addr);
+    CompiledShader* program = nullptr;
+
+    oaknut::Label log2_subroutine;
+    oaknut::Label exp2_subroutine;
+};
+
+} // namespace Pica::Shader
+
+#endif // CITRA_ARCH(arm64)