[Mlir-commits] [mlir] 58839f2 - [spirv][math] Fix sign propagation for math.powf conversion
Lei Zhang
llvmlistbot at llvm.org
Tue May 9 21:51:02 PDT 2023
Author: Daniel Garvey
Date: 2023-05-09T21:44:09-07:00
New Revision: 58839f2e2913c2b7c0bb11a2292b8e666835cdb2
URL: https://github.com/llvm/llvm-project/commit/58839f2e2913c2b7c0bb11a2292b8e666835cdb2
DIFF: https://github.com/llvm/llvm-project/commit/58839f2e2913c2b7c0bb11a2292b8e666835cdb2.diff
LOG: [spirv][math] Fix sign propagation for math.powf conversion
For `x^y` with negative `x`, the sign of the result depends on whether `y`
is an integer and, if so, whether it is odd or even.
This still does not cover all corner cases regarding `x^y`
but it's an improvement over the current implementation.
Reviewed By: antiagainst, qedawkins
Differential Revision: https://reviews.llvm.org/D150234
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 80b22576e61c8..412f99ce042e9 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -305,6 +305,24 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (!dstType)
return failure();
+ // Get the scalar float type.
+ FloatType scalarFloatType;
+ if (auto scalarType = powfOp.getType().dyn_cast<FloatType>()) {
+ scalarFloatType = scalarType;
+ } else if (auto vectorType = powfOp.getType().dyn_cast<VectorType>()) {
+ scalarFloatType = vectorType.getElementType().cast<FloatType>();
+ } else {
+ return failure();
+ }
+
+ // Get int type of the same shape as the float type.
+ Type scalarIntType = rewriter.getIntegerType(32);
+ Type intType = scalarIntType;
+ if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>()) {
+ auto shape = vectorType.getShape();
+ intType = VectorType::get(shape, scalarIntType);
+ }
+
// Per GL Pow extended instruction spec:
// "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
Location loc = powfOp.getLoc();
@@ -313,9 +331,27 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
Value lessThan =
rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+ // TODO: The following just forcefully casts y into an integer value in
+ // order to properly propagate the sign, assuming integer y cases. It
+ // doesn't cover other cases and should be fixed.
+
+ // Cast exponent to integer and calculate exponent % 2 != 0.
+ Value intRhs =
+ rewriter.create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
+ Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
+ Value bitwiseAndOne =
+ rewriter.create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
+ Value isOdd = rewriter.create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
+
+ // calculate pow based on abs(lhs)^rhs.
Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
+ // if the exponent is odd and lhs < 0, negate the result.
+ Value shouldNegate =
+ rewriter.create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
+ pow);
return success();
}
};
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 125478e2cb214..4d0ef06d7e92f 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -137,9 +137,14 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
// CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
// CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
// CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+ // CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
+ // CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
+ // CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
+ // CHECK: %[[ODD:.+]] = spirv.IEqual %[[REM]], %[[CST1]] : i32
// CHECK: %[[POW:.+]] = spirv.GL.Pow %[[ABS]], %[[RHS]] : f32
// CHECK: %[[NEG:.+]] = spirv.FNegate %[[POW]] : f32
- // CHECK: %[[SEL:.+]] = spirv.Select %[[LT]], %[[NEG]], %[[POW]] : i1, f32
+ // CHECK: %[[SNEG:.+]] = spirv.LogicalAnd %[[LT]], %[[ODD]] : i1
+ // CHECK: %[[SEL:.+]] = spirv.Select %[[SNEG]], %[[NEG]], %[[POW]] : i1, f32
%0 = math.powf %lhs, %rhs : f32
// CHECK: return %[[SEL]]
return %0: f32
@@ -149,6 +154,8 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
// CHECK: spirv.FOrdLessThan
// CHECK: spirv.GL.FAbs
+ // CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
+ // CHECK: spirv.IEqual %{{.*}} : vector<4xi32>
// CHECK: spirv.GL.Pow %{{.*}}: vector<4xf32>
// CHECK: spirv.FNegate
// CHECK: spirv.Select
More information about the Mlir-commits mailing list