[Mlir-commits] [mlir] 06c6758 - [mlir][spirv] Handle corner cases for math.powf conversion

Lei Zhang llvmlistbot at llvm.org
Tue Jun 14 20:02:58 PDT 2022


Author: Lei Zhang
Date: 2022-06-14T23:02:44-04:00
New Revision: 06c6758a98161262ac97fad42248139d78d39581

URL: https://github.com/llvm/llvm-project/commit/06c6758a98161262ac97fad42248139d78d39581
DIFF: https://github.com/llvm/llvm-project/commit/06c6758a98161262ac97fad42248139d78d39581.diff

LOG: [mlir][spirv] Handle corner cases for math.powf conversion

Per GLSL Pow extended instruction spec: "Result is undefined if
x < 0. Result is undefined if x = 0 and y <= 0." So we need to
handle negative `x` values specifically.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127816

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 6b792124d269d..07c99f06ab2c3 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -201,6 +201,62 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
     return success();
   }
 };
+
+/// Converts math.powf to SPIRV-Ops.
+///
+/// GLSL Pow is undefined for a negative base, so pow(x, y) is lowered to
+/// pow(|x|, y), negated when x < 0 and y is an odd integer.
+struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(powfOp.getType());
+    if (!dstType)
+      return failure();
+
+    // Per GLSL Pow extended instruction spec:
+    // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
+    // So compute pow(|x|, y) and fix up the sign for negative x: for an
+    // integral y, (-a)^y is -(a^y) when y is odd and a^y when y is even.
+    // TODO: A negative x with a non-integral y should produce NaN, and x = 0
+    // with y <= 0 is still left undefined.
+    Location loc = powfOp.getLoc();
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    Type operandType = lhs.getType();
+    auto vectorType = operandType.dyn_cast<VectorType>();
+    Type elementType = vectorType ? vectorType.getElementType() : operandType;
+
+    Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
+    Attribute halfAttr = rewriter.getFloatAttr(elementType, 0.5);
+    if (vectorType)
+      halfAttr = DenseElementsAttr::get(vectorType, halfAttr);
+    Value oneHalf =
+        rewriter.create<spirv::ConstantOp>(loc, operandType, halfAttr);
+    Value isNegative = rewriter.create<spirv::FOrdLessThanOp>(loc, lhs, zero);
+
+    // y is an odd integer iff y * 0.5 has a fractional part of exactly 0.5.
+    // For integral y both the scaling and the subtraction are exact, so large
+    // exponents are classified correctly; NaN/Inf exponents compare unequal.
+    Value halfRhs = rewriter.create<spirv::FMulOp>(loc, rhs, oneHalf);
+    Value floorHalfRhs = rewriter.create<spirv::GLSLFloorOp>(loc, halfRhs);
+    Value fraction = rewriter.create<spirv::FSubOp>(loc, halfRhs, floorHalfRhs);
+    Value isOddInteger =
+        rewriter.create<spirv::FOrdEqualOp>(loc, fraction, oneHalf);
+
+    Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, lhs);
+    Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, rhs);
+    Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
+    Value shouldNegate =
+        rewriter.create<spirv::LogicalAndOp>(loc, isNegative, isOddInteger);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+                                                 pow);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -216,7 +272,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   // GLSL patterns
   patterns
       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
-           ExpM1OpPattern<spirv::GLSLExpOp>,
+           ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern,
            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -224,7 +280,6 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
-           spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 7940daab10d34..d8126d4e956c6 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -64,20 +64,6 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   return
 }
 
-// CHECK-LABEL: @float32_binary_scalar
-func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
-  // CHECK: spv.GLSL.Pow %{{.*}}: f32
-  %0 = math.powf %lhs, %rhs : f32
-  return
-}
-
-// CHECK-LABEL: @float32_binary_vector
-func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
-  // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
-  %0 = math.powf %lhs, %rhs : vector<4xf32>
-  return
-}
-
 // CHECK-LABEL: @float32_ternary_scalar
 func.func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
   // CHECK: spv.GLSL.Fma %{{.*}}: f32
@@ -133,6 +119,41 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
   return %0 : vector<2xi32>
 }
 
+// CHECK-LABEL: @powf_scalar
+//  CHECK-SAME: (%[[LHS:.+]]: f32, %[[RHS:.+]]: f32)
+func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
+  // CHECK: %[[F0:.+]] = spv.Constant 0.000000e+00 : f32
+  // CHECK: %[[HALF:.+]] = spv.Constant 5.000000e-01 : f32
+  // CHECK: %[[LT:.+]] = spv.FOrdLessThan %[[LHS]], %[[F0]] : f32
+  // CHECK: %[[YHALF:.+]] = spv.FMul %[[RHS]], %[[HALF]] : f32
+  // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[YHALF]] : f32
+  // CHECK: %[[FRAC:.+]] = spv.FSub %[[YHALF]], %[[FLOOR]] : f32
+  // CHECK: %[[ODD:.+]] = spv.FOrdEqual %[[FRAC]], %[[HALF]] : f32
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %[[LHS]] : f32
+  // CHECK: %[[POW:.+]] = spv.GLSL.Pow %[[ABS]], %[[RHS]] : f32
+  // CHECK: %[[NEG:.+]] = spv.FNegate %[[POW]] : f32
+  // CHECK: %[[COND:.+]] = spv.LogicalAnd %[[LT]], %[[ODD]] : i1
+  // CHECK: %[[SEL:.+]] = spv.Select %[[COND]], %[[NEG]], %[[POW]] : i1, f32
+  %0 = math.powf %lhs, %rhs : f32
+  // CHECK: return %[[SEL]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @powf_vector
+func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spv.Constant dense<5.000000e-01> : vector<4xf32>
+  // CHECK: spv.FOrdLessThan
+  // CHECK: spv.GLSL.Floor
+  // CHECK: spv.FOrdEqual
+  // CHECK: spv.GLSL.FAbs
+  // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
+  // CHECK: spv.FNegate
+  // CHECK: spv.LogicalAnd
+  // CHECK: spv.Select
+  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list