[Mlir-commits] [mlir] 06c6758 - [mlir][spirv] Handle corner cases for math.powf conversion
Lei Zhang
llvmlistbot at llvm.org
Tue Jun 14 20:02:58 PDT 2022
Author: Lei Zhang
Date: 2022-06-14T23:02:44-04:00
New Revision: 06c6758a98161262ac97fad42248139d78d39581
URL: https://github.com/llvm/llvm-project/commit/06c6758a98161262ac97fad42248139d78d39581
DIFF: https://github.com/llvm/llvm-project/commit/06c6758a98161262ac97fad42248139d78d39581.diff
LOG: [mlir][spirv] Handle corner cases for math.powf conversion
Per GLSL Pow extended instruction spec: "Result is undefined if
x < 0. Result is undefined if x = 0 and y <= 0." So we need to
handle negative `x` values specifically.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D127816
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 6b792124d269d..07c99f06ab2c3 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -201,6 +201,69 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
return success();
}
};
+
+/// Converts math.powf to SPIR-V ops.
+///
+/// GLSL.std.450 Pow is undefined for x < 0 (and for x == 0 with y <= 0), so
+/// this lowering computes Pow(|x|, y) and then restores the sign: for
+/// negative x the result is negated only when y is an odd integer, because
+/// (-a)^y == -(a^y) for odd integer y but (-a)^y == a^y for even integer y.
+///
+/// NOTE: for negative x and non-integer y, C's pow() returns NaN; this
+/// lowering returns Pow(|x|, y) instead.
+struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(powfOp.getType());
+    if (!dstType)
+      return failure();
+
+    Location loc = powfOp.getLoc();
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    // Both math.powf operands share one (scalar or vector) float type.
+    Type operandType = rhs.getType();
+
+    // Per GLSL Pow extended instruction spec:
+    // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
+    // So Pow is always given |x| and the sign is fixed up afterwards.
+    Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
+    Value isNegative = rewriter.create<spirv::FOrdLessThanOp>(loc, lhs, zero);
+
+    // y is an odd integer iff floor(y) == y and floor(y * 0.5) != y * 0.5.
+    // Staying in floating point avoids converting y to an integer type, which
+    // could overflow for large |y|.
+    Type elementType = operandType;
+    if (auto vectorType = operandType.dyn_cast<VectorType>())
+      elementType = vectorType.getElementType();
+    Attribute halfAttr = rewriter.getFloatAttr(elementType, 0.5);
+    if (auto vectorType = operandType.dyn_cast<VectorType>())
+      halfAttr = DenseElementsAttr::get(vectorType, halfAttr);
+    Value half =
+        rewriter.create<spirv::ConstantOp>(loc, operandType, halfAttr);
+    Value floorRhs = rewriter.create<spirv::GLSLFloorOp>(loc, rhs);
+    Value isInteger = rewriter.create<spirv::FOrdEqualOp>(loc, floorRhs, rhs);
+    Value halfRhs = rewriter.create<spirv::FMulOp>(loc, rhs, half);
+    Value floorHalfRhs = rewriter.create<spirv::GLSLFloorOp>(loc, halfRhs);
+    Value isNotEven =
+        rewriter.create<spirv::FOrdNotEqualOp>(loc, floorHalfRhs, halfRhs);
+    Value isOddInteger =
+        rewriter.create<spirv::LogicalAndOp>(loc, isInteger, isNotEven);
+    Value shouldNegate =
+        rewriter.create<spirv::LogicalAndOp>(loc, isNegative, isOddInteger);
+
+    Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, lhs);
+    Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, rhs);
+    Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+                                                 pow);
+    return success();
+  }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -216,7 +279,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// GLSL patterns
patterns
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
- ExpM1OpPattern<spirv::GLSLExpOp>,
+ ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -224,7 +287,6 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
- spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 7940daab10d34..d8126d4e956c6 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -64,20 +64,6 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
return
}
-// CHECK-LABEL: @float32_binary_scalar
-func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
- // CHECK: spv.GLSL.Pow %{{.*}}: f32
- %0 = math.powf %lhs, %rhs : f32
- return
-}
-
-// CHECK-LABEL: @float32_binary_vector
-func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
- // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
- %0 = math.powf %lhs, %rhs : vector<4xf32>
- return
-}
-
// CHECK-LABEL: @float32_ternary_scalar
func.func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
// CHECK: spv.GLSL.Fma %{{.*}}: f32
@@ -133,6 +119,47 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
return %0 : vector<2xi32>
}
+// CHECK-LABEL: @powf_scalar
+// CHECK-SAME: (%[[LHS:.+]]: f32, %[[RHS:.+]]: f32)
+func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
+  // CHECK: %[[F0:.+]] = spv.Constant 0.000000e+00 : f32
+  // CHECK: %[[LT:.+]] = spv.FOrdLessThan %[[LHS]], %[[F0]] : f32
+  // CHECK: %[[HALF:.+]] = spv.Constant 5.000000e-01 : f32
+  // CHECK: %[[FLOOR:.+]] = spv.GLSL.Floor %[[RHS]] : f32
+  // CHECK: %[[IS_INT:.+]] = spv.FOrdEqual %[[FLOOR]], %[[RHS]] : f32
+  // CHECK: %[[HALF_Y:.+]] = spv.FMul %[[RHS]], %[[HALF]] : f32
+  // CHECK: %[[FLOOR_HALF:.+]] = spv.GLSL.Floor %[[HALF_Y]] : f32
+  // CHECK: %[[NOT_EVEN:.+]] = spv.FOrdNotEqual %[[FLOOR_HALF]], %[[HALF_Y]] : f32
+  // CHECK: %[[IS_ODD:.+]] = spv.LogicalAnd %[[IS_INT]], %[[NOT_EVEN]] : i1
+  // CHECK: %[[SHOULD_NEG:.+]] = spv.LogicalAnd %[[LT]], %[[IS_ODD]] : i1
+  // CHECK: %[[ABS:.+]] = spv.GLSL.FAbs %[[LHS]] : f32
+  // CHECK: %[[POW:.+]] = spv.GLSL.Pow %[[ABS]], %[[RHS]] : f32
+  // CHECK: %[[NEG:.+]] = spv.FNegate %[[POW]] : f32
+  // CHECK: %[[SEL:.+]] = spv.Select %[[SHOULD_NEG]], %[[NEG]], %[[POW]] : i1, f32
+  %0 = math.powf %lhs, %rhs : f32
+  // CHECK: return %[[SEL]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @powf_vector
+func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spv.FOrdLessThan
+  // CHECK: spv.Constant dense<5.000000e-01> : vector<4xf32>
+  // CHECK: spv.GLSL.Floor
+  // CHECK: spv.FOrdEqual
+  // CHECK: spv.FMul
+  // CHECK: spv.GLSL.Floor
+  // CHECK: spv.FOrdNotEqual
+  // CHECK: spv.LogicalAnd
+  // CHECK: spv.LogicalAnd
+  // CHECK: spv.GLSL.FAbs
+  // CHECK: spv.GLSL.Pow %{{.*}}: vector<4xf32>
+  // CHECK: spv.FNegate
+  // CHECK: spv.Select
+  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
} // end module
// -----
More information about the Mlir-commits
mailing list