[Mlir-commits] [mlir] [mlir][spirv] Update math.powf lowering (PR #111388)

Dmitriy Smirnov llvmlistbot at llvm.org
Mon Oct 7 13:03:40 PDT 2024


https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/111388

>From 54191573a1c056856f8969375b301b7fdc22079b Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Mon, 7 Oct 2024 14:29:11 +0100
Subject: [PATCH 1/2] [mlir][spirv] Update math.powf lowering

math.powf lowering now produces a NaN result for a negative base with a fractional
exponent, which matches the behavior of the C/C++ implementation.
---
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    | 26 ++++++++++++++++++-
 .../MathToSPIRV/math-to-gl-spirv.mlir         | 12 ++++++++-
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 6f948e80d5af8f..fb7816b2fc2473 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -389,7 +389,31 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
         spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
     Value lessThan =
         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
-    Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+    // Per C/CPP spec:
+    // "pow(base, exponent) returns NaN (and raises FE_INVALID) if base is "
+    // " finite and negative and exponent is finite and non-integer. "
+    // Calculae calc reminder from exponent and check whether it is zero
+    Value floatOne =
+        spirv::ConstantOp::getOne(adaptor.getRhs().getType(), loc, rewriter);
+    Value expRem =
+        rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
+    Value expRemNonZero =
+        rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
+    Value cmpNegativeWithFractionalExp =
+        rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
+    // Create NaN result and replace base value if conditions meet
+    const auto &floatSemantics = scalarFloatType.getFloatSemantics();
+    const auto nan = APFloat::getNaN(floatSemantics);
+    Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
+    if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType()))
+      nanAttr = DenseElementsAttr::get(vectorType, nan);
+
+    Value NanValue = rewriter.create<spirv::ConstantOp>(
+        loc, adaptor.getRhs().getType(), nanAttr);
+    Value lhs = rewriter.create<spirv::SelectOp>(
+        loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
+    Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
 
     // TODO: The following just forcefully casts y into an integer value in
     // order to properly propagate the sign, assuming integer y cases. It
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index ecbd59e54971ef..5c6561c1043892 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -156,7 +156,13 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
 func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
   // CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
   // CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
-  // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+  // CHECK: %[[F1:.+]] = spirv.Constant 1.000000e+00 : f32
+  // CHECK: %[[REM:.+]] = spirv.FRem %[[RHS]], %[[F1]] : f32
+  // CHECK: %[[IS_FRACTION:.+]] = spirv.FOrdNotEqual %[[REM]], %[[F0]] : f32
+  // CHECK: %[[AND:.+]] = spirv.LogicalAnd %[[IS_FRACTION]], %[[LT]] : i1
+  // CHECK: %[[NAN:.+]] = spirv.Constant 0x7FC00000 : f32
+  // CHECK: %[[NEW_LHS:.+]] = spirv.Select %[[AND]], %[[NAN]], %[[LHS]] : i1, f32
+  // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[NEW_LHS]] : f32
   // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
   // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
   // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
@@ -173,6 +179,10 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
 // CHECK-LABEL: @powf_vector
 func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
   // CHECK: spirv.FOrdLessThan
+  // CHECK: spirv.FRem
+  // CHECK: spirv.FOrdNotEqual
+  // CHECK: spirv.LogicalAnd
+  // CHECK: spirv.Select
   // CHECK: spirv.GL.FAbs
   // CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
   // CHECK: spirv.IEqual %{{.*}} : vector<4xi32>

>From e36f8a2a12250619354ecd4ceb55162e2549c4c0 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Mon, 7 Oct 2024 20:56:35 +0100
Subject: [PATCH 2/2] Addressed comments

---
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index fb7816b2fc2473..9a43894da6da6a 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -377,7 +377,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     // Get int type of the same shape as the float type.
     Type scalarIntType = rewriter.getIntegerType(32);
     Type intType = scalarIntType;
-    if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
+    auto exponentType = adaptor.getRhs().getType();
+    if (auto vectorType = dyn_cast<VectorType>(exponentType)) {
       auto shape = vectorType.getShape();
       intType = VectorType::get(shape, scalarIntType);
     }
@@ -385,32 +386,31 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     // Per GL Pow extended instruction spec:
     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
     Location loc = powfOp.getLoc();
-    Value zero =
-        spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
+    auto baseType = adaptor.getLhs().getType();
+    Value zero = spirv::ConstantOp::getZero(baseType, loc, rewriter);
     Value lessThan =
         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
 
-    // Per C/CPP spec:
-    // "pow(base, exponent) returns NaN (and raises FE_INVALID) if base is "
-    // " finite and negative and exponent is finite and non-integer. "
-    // Calculae calc reminder from exponent and check whether it is zero
-    Value floatOne =
-        spirv::ConstantOp::getOne(adaptor.getRhs().getType(), loc, rewriter);
+    // Per C/C++ spec:
+    // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
+    // > finite and negative and exponent is finite and non-integer.
+    // Calculate the remainder of the exponent modulo 1 and check whether it is zero.
+    Value floatOne = spirv::ConstantOp::getOne(exponentType, loc, rewriter);
     Value expRem =
         rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
     Value expRemNonZero =
         rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
     Value cmpNegativeWithFractionalExp =
         rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
-    // Create NaN result and replace base value if conditions meet
+    // Create NaN result and replace base value if conditions are met
     const auto &floatSemantics = scalarFloatType.getFloatSemantics();
     const auto nan = APFloat::getNaN(floatSemantics);
     Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
-    if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType()))
+    if (auto vectorType = dyn_cast<VectorType>(baseType))
       nanAttr = DenseElementsAttr::get(vectorType, nan);
 
-    Value NanValue = rewriter.create<spirv::ConstantOp>(
-        loc, adaptor.getRhs().getType(), nanAttr);
+    Value NanValue =
+        rewriter.create<spirv::ConstantOp>(loc, exponentType, nanAttr);
     Value lhs = rewriter.create<spirv::SelectOp>(
         loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);



More information about the Mlir-commits mailing list