From acf618afbc834ccfd05a33205c035ecb9737b5db Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Tue, 9 Apr 2019 17:33:48 -0300
Subject: [PATCH] renderer_opengl: Implement half float NaN comparisons

---
 .../renderer_opengl/gl_shader_decompiler.cpp  | 60 +++++++++++++------
 src/video_core/shader/shader_ir.cpp           | 17 ++----
 src/video_core/shader/shader_ir.h             | 18 ++++--
 3 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
index 28e490b3cb..cbaa4dcebd 100644
--- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
@@ -1173,34 +1173,46 @@ private:
         return GenerateUnary(operation, "any", Type::Bool, Type::Bool2);
     }
 
+    template <bool with_nan>
+    std::string GenerateHalfComparison(Operation operation, std::string compare_op) {
+        std::string comparison{GenerateBinaryCall(operation, compare_op, Type::Bool2,
+                                                  Type::HalfFloat, Type::HalfFloat)};
+        if constexpr (!with_nan) {
+            return comparison;
+        }
+        return "halfFloatNanComparison(" + comparison + ", " +
+               VisitOperand(operation, 0, Type::HalfFloat) + ", " +
+               VisitOperand(operation, 1, Type::HalfFloat) + ')';
+    }
+
+    template <bool with_nan>
     std::string Logical2HLessThan(Operation operation) {
-        return GenerateBinaryCall(operation, "lessThan", Type::Bool2, Type::HalfFloat,
-                                  Type::HalfFloat);
+        return GenerateHalfComparison<with_nan>(operation, "lessThan");
     }
 
+    template <bool with_nan>
     std::string Logical2HEqual(Operation operation) {
-        return GenerateBinaryCall(operation, "equal", Type::Bool2, Type::HalfFloat,
-                                  Type::HalfFloat);
+        return GenerateHalfComparison<with_nan>(operation, "equal");
     }
 
+    template <bool with_nan>
     std::string Logical2HLessEqual(Operation operation) {
-        return GenerateBinaryCall(operation, "lessThanEqual", Type::Bool2, Type::HalfFloat,
-                                  Type::HalfFloat);
+        return GenerateHalfComparison<with_nan>(operation, "lessThanEqual");
     }
 
+    template <bool with_nan>
     std::string Logical2HGreaterThan(Operation operation) {
-        return GenerateBinaryCall(operation, "greaterThan", Type::Bool2, Type::HalfFloat,
-                                  Type::HalfFloat);
+        return GenerateHalfComparison<with_nan>(operation, "greaterThan");
     }
 
+    template <bool with_nan>
     std::string Logical2HNotEqual(Operation operation) {
-        return GenerateBinaryCall(operation, "notEqual", Type::Bool2, Type::HalfFloat,
-                                  Type::HalfFloat);
+        return GenerateHalfComparison<with_nan>(operation, "notEqual");
     }
 
+    template <bool with_nan>
     std::string Logical2HGreaterEqual(Operation operation) {
-        return GenerateBinaryCall(operation, "greaterThanEqual", Type::Bool2, Type::HalfFloat,
-                                  Type::HalfFloat);
+        return GenerateHalfComparison<with_nan>(operation, "greaterThanEqual");
     }
 
     std::string Texture(Operation operation) {
@@ -1525,12 +1537,18 @@ private:
         &GLSLDecompiler::LogicalNotEqual<Type::Uint>,
         &GLSLDecompiler::LogicalGreaterEqual<Type::Uint>,
 
-        &GLSLDecompiler::Logical2HLessThan,
-        &GLSLDecompiler::Logical2HEqual,
-        &GLSLDecompiler::Logical2HLessEqual,
-        &GLSLDecompiler::Logical2HGreaterThan,
-        &GLSLDecompiler::Logical2HNotEqual,
-        &GLSLDecompiler::Logical2HGreaterEqual,
+        &GLSLDecompiler::Logical2HLessThan<false>,
+        &GLSLDecompiler::Logical2HEqual<false>,
+        &GLSLDecompiler::Logical2HLessEqual<false>,
+        &GLSLDecompiler::Logical2HGreaterThan<false>,
+        &GLSLDecompiler::Logical2HNotEqual<false>,
+        &GLSLDecompiler::Logical2HGreaterEqual<false>,
+        &GLSLDecompiler::Logical2HLessThan<true>,
+        &GLSLDecompiler::Logical2HEqual<true>,
+        &GLSLDecompiler::Logical2HLessEqual<true>,
+        &GLSLDecompiler::Logical2HGreaterThan<true>,
+        &GLSLDecompiler::Logical2HNotEqual<true>,
+        &GLSLDecompiler::Logical2HGreaterEqual<true>,
 
         &GLSLDecompiler::Texture,
         &GLSLDecompiler::TextureLod,
@@ -1633,6 +1651,12 @@ std::string GetCommonDeclarations() {
            "}\n\n"
            "vec2 toHalf2(float value) {\n"
            "    return unpackHalf2x16(ftou(value));\n"
+           "}\n\n"
+           "bvec2 halfFloatNanComparison(bvec2 comparison, vec2 pair1, vec2 pair2) {\n"
+           "    bvec2 is_nan1 = isnan(pair1);\n"
+           "    bvec2 is_nan2 = isnan(pair2);\n"
+           "    return bvec2(comparison.x || is_nan1.x || is_nan2.x, comparison.y || is_nan1.y || "
+           "is_nan2.y);\n"
            "}\n";
 }
 
diff --git a/src/video_core/shader/shader_ir.cpp b/src/video_core/shader/shader_ir.cpp
index 5175f83c67..5c1c591f84 100644
--- a/src/video_core/shader/shader_ir.cpp
+++ b/src/video_core/shader/shader_ir.cpp
@@ -285,13 +285,6 @@ Node ShaderIR::GetPredicateComparisonInteger(PredCondition condition, bool is_si
 
 Node ShaderIR::GetPredicateComparisonHalf(Tegra::Shader::PredCondition condition,
                                           const MetaHalfArithmetic& meta, Node op_a, Node op_b) {
-    UNIMPLEMENTED_IF_MSG(condition == PredCondition::LessThanWithNan ||
-                             condition == PredCondition::NotEqualWithNan ||
-                             condition == PredCondition::LessEqualWithNan ||
-                             condition == PredCondition::GreaterThanWithNan ||
-                             condition == PredCondition::GreaterEqualWithNan,
-                         "Unimplemented NaN comparison for half floats");
-
     const std::unordered_map<PredCondition, OperationCode> PredicateComparisonTable = {
         {PredCondition::LessThan, OperationCode::Logical2HLessThan},
         {PredCondition::Equal, OperationCode::Logical2HEqual},
@@ -299,11 +292,11 @@ Node ShaderIR::GetPredicateComparisonHalf(Tegra::Shader::PredCondition condition
         {PredCondition::GreaterThan, OperationCode::Logical2HGreaterThan},
         {PredCondition::NotEqual, OperationCode::Logical2HNotEqual},
         {PredCondition::GreaterEqual, OperationCode::Logical2HGreaterEqual},
-        {PredCondition::LessThanWithNan, OperationCode::Logical2HLessThan},
-        {PredCondition::NotEqualWithNan, OperationCode::Logical2HNotEqual},
-        {PredCondition::LessEqualWithNan, OperationCode::Logical2HLessEqual},
-        {PredCondition::GreaterThanWithNan, OperationCode::Logical2HGreaterThan},
-        {PredCondition::GreaterEqualWithNan, OperationCode::Logical2HGreaterEqual}};
+        {PredCondition::LessThanWithNan, OperationCode::Logical2HLessThanWithNan},
+        {PredCondition::NotEqualWithNan, OperationCode::Logical2HNotEqualWithNan},
+        {PredCondition::LessEqualWithNan, OperationCode::Logical2HLessEqualWithNan},
+        {PredCondition::GreaterThanWithNan, OperationCode::Logical2HGreaterThanWithNan},
+        {PredCondition::GreaterEqualWithNan, OperationCode::Logical2HGreaterEqualWithNan}};
 
     const auto comparison{PredicateComparisonTable.find(condition)};
     UNIMPLEMENTED_IF_MSG(comparison == PredicateComparisonTable.end(),
diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h
index 4888998d34..0ae51389bc 100644
--- a/src/video_core/shader/shader_ir.h
+++ b/src/video_core/shader/shader_ir.h
@@ -150,12 +150,18 @@ enum class OperationCode {
     LogicalUNotEqual,     /// (uint a, uint b) -> bool
     LogicalUGreaterEqual, /// (uint a, uint b) -> bool
 
-    Logical2HLessThan,     /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
-    Logical2HEqual,        /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
-    Logical2HLessEqual,    /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
-    Logical2HGreaterThan,  /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
-    Logical2HNotEqual,     /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
-    Logical2HGreaterEqual, /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HLessThan,            /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HEqual,               /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HLessEqual,           /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HGreaterThan,         /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HNotEqual,            /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HGreaterEqual,        /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HLessThanWithNan,     /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HEqualWithNan,        /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HLessEqualWithNan,    /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HGreaterThanWithNan,  /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HNotEqualWithNan,     /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
+    Logical2HGreaterEqualWithNan, /// (MetaHalfArithmetic, f16vec2 a, f16vec2) -> bool2
 
     Texture,                /// (MetaTexture, float[N] coords) -> float4
     TextureLod,             /// (MetaTexture, float[N] coords) -> float4