[Mlir-commits] [mlir] b9314a8 - [mlir][spirv] Update math.powf lowering (#111388)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 9 01:04:35 PDT 2024
Author: Dmitriy Smirnov
Date: 2024-10-09T09:04:31+01:00
New Revision: b9314a82196a656e2bcc48459123a98ccc02a54d
URL: https://github.com/llvm/llvm-project/commit/b9314a82196a656e2bcc48459123a98ccc02a54d
DIFF: https://github.com/llvm/llvm-project/commit/b9314a82196a656e2bcc48459123a98ccc02a54d.diff
LOG: [mlir][spirv] Update math.powf lowering (#111388)
The PR updates the math.powf lowering to produce a NaN result for a negative
base with a fractional exponent, which matches the actual behaviour of
the C/C++ implementation.
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 6f948e80d5af8f..1b83794b5f4502 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -377,7 +377,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Get int type of the same shape as the float type.
Type scalarIntType = rewriter.getIntegerType(32);
Type intType = scalarIntType;
- if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
+ auto operandType = adaptor.getRhs().getType();
+ if (auto vectorType = dyn_cast<VectorType>(operandType)) {
auto shape = vectorType.getShape();
intType = VectorType::get(shape, scalarIntType);
}
@@ -385,11 +386,33 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Per GL Pow extended instruction spec:
// "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
Location loc = powfOp.getLoc();
- Value zero =
- spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
+ Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
Value lessThan =
rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
- Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
+
+ // Per C/C++ spec:
+ // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
+ // > finite and negative and exponent is finite and non-integer.
+ // Calculate the remainder of the exponent and check whether it is zero.
+ Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
+ Value expRem =
+ rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
+ Value expRemNonZero =
+ rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
+ Value cmpNegativeWithFractionalExp =
+ rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
+ // Create NaN result and replace base value if conditions are met.
+ const auto &floatSemantics = scalarFloatType.getFloatSemantics();
+ const auto nan = APFloat::getNaN(floatSemantics);
+ Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
+ if (auto vectorType = dyn_cast<VectorType>(operandType))
+ nanAttr = DenseElementsAttr::get(vectorType, nan);
+
+ Value NanValue =
+ rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
+ Value lhs = rewriter.create<spirv::SelectOp>(
+ loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
+ Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
// order to properly propagate the sign, assuming integer y cases. It
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index ecbd59e54971ef..5c6561c1043892 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -156,7 +156,13 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
// CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
// CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
- // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
+ // CHECK: %[[F1:.+]] = spirv.Constant 1.000000e+00 : f32
+ // CHECK: %[[REM:.+]] = spirv.FRem %[[RHS]], %[[F1]] : f32
+ // CHECK: %[[IS_FRACTION:.+]] = spirv.FOrdNotEqual %[[REM]], %[[F0]] : f32
+ // CHECK: %[[AND:.+]] = spirv.LogicalAnd %[[IS_FRACTION]], %[[LT]] : i1
+ // CHECK: %[[NAN:.+]] = spirv.Constant 0x7FC00000 : f32
+ // CHECK: %[[NEW_LHS:.+]] = spirv.Select %[[AND]], %[[NAN]], %[[LHS]] : i1, f32
+ // CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[NEW_LHS]] : f32
// CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
// CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
@@ -173,6 +179,10 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
// CHECK-LABEL: @powf_vector
func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
// CHECK: spirv.FOrdLessThan
+ // CHECK: spirv.FRem
+ // CHECK: spirv.FOrdNotEqual
+ // CHECK: spirv.LogicalAnd
+ // CHECK: spirv.Select
// CHECK: spirv.GL.FAbs
// CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
// CHECK: spirv.IEqual %{{.*}} : vector<4xi32>
More information about the Mlir-commits
mailing list