[Mlir-commits] [mlir] a92e3df - [mlir][math] Fix `math.powf` expansion case for `pow(x, 0)` (#119015)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 9 13:52:59 PST 2024


Author: Christopher Bate
Date: 2024-12-09T14:52:55-07:00
New Revision: a92e3df3007a323b5434c6cf301e15e01a642a70

URL: https://github.com/llvm/llvm-project/commit/a92e3df3007a323b5434c6cf301e15e01a642a70
DIFF: https://github.com/llvm/llvm-project/commit/a92e3df3007a323b5434c6cf301e15e01a642a70.diff

LOG: [mlir][math] Fix `math.powf` expansion case for `pow(x, 0)` (#119015)

Lowering `math.powf` to `llvm.intr.powf` will result in `pow(x, 0) =
1`, even for `x=0`. When using the Math dialect expansion patterns,
however, `pow(0, 0)` will result in `-nan`. This change adds two
additional instructions to the lowering to ensure the `pow(x, 0)` case
lowers to `1` regardless of the value of `x`.

Resolves https://github.com/llvm/llvm-project/issues/118945.

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 80569d95137c3a..8bcbdb4c9a664a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -343,6 +343,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
   Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
@@ -359,8 +360,15 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
       b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
   Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
+  // First, we select between the exp value and the adjusted value for odd
+  // powers of negatives. Then, we ensure that one is produced if `b` is zero.
+  // This corresponds to `libm` behavior, even for `0^0`. Without this check,
+  // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
+  Value zeroCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
   Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
                                         expResult);
+  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
   rewriter.replaceOp(op, res);
   return success();
 }

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c10a78ca4ae4ca..89413b95703322 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,11 @@ func.func @roundf_func(%a: f32) -> f32 {
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
 func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
   // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
   // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
   // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]] 
+  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
   // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
   // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
   // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
@@ -234,8 +235,10 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
   // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
   // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
   // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK: return [[SEL]]
+  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+  // CHECK: return [[SEL1]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -516,7 +519,7 @@ func.func @roundeven16(%arg: f16) -> f16 {
 
 // CHECK-LABEL:   func.func @math_fpowi_neg_odd_power
 func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<-3> : tensor<8xi64> 
+  %1 = arith.constant dense<-3> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -539,7 +542,7 @@ func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_neg_even_power
 func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<-4> : tensor<8xi64> 
+  %1 = arith.constant dense<-4> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -562,7 +565,7 @@ func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_pos_odd_power
 func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<5> : tensor<8xi64> 
+  %1 = arith.constant dense<5> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -576,7 +579,7 @@ func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_pos_even_power
 func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<4> : tensor<8xi64> 
+  %1 = arith.constant dense<4> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -617,9 +620,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK:        %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK:        %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
 // CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
 // CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
@@ -631,8 +635,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
 // CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
 // CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:      return %[[SEL]] : tensor<8xf32>
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:      return %[[SEL1]] : tensor<8xf32>
 
 // -----
 
@@ -642,9 +648,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK:        %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK:        %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK:        %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
 // CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
 // CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
@@ -656,8 +663,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
 // CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:       return %[[SEL]] : f32
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:       return %[[SEL1]] : f32
 
 // -----
 

diff  --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559cc6f730b..93de767b551769 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -18,7 +18,7 @@ func.func @func_exp2f(%a : f64) {
 func.func @exp2f() {
   // CHECK: 2
   %a = arith.constant 1.0 : f64
-  call @func_exp2f(%a) : (f64) -> ()  
+  call @func_exp2f(%a) : (f64) -> ()
 
   // CHECK-NEXT: 4
   %b = arith.constant 2.0 : f64
@@ -240,13 +240,18 @@ func.func @powf() {
   // CHECK-NEXT: -nan
   %k = arith.constant 1.0 : f64
   %k_p = arith.constant 0xfff0000001000000 : f64
-  call @func_powff64(%k, %k_p) : (f64, f64) -> ()  
+  call @func_powff64(%k, %k_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: -nan
   %l = arith.constant 1.0 : f32
   %l_p = arith.constant 0xffffffff : f32
-  call @func_powff32(%l, %l_p) : (f32, f32) -> ()  
-  return  
+  call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+
+  // CHECK-NEXT: 1
+  %zero = arith.constant 0.0 : f32
+  call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+
+  return
 }
 
 // -------------------------------------------------------------------------- //


        


More information about the Mlir-commits mailing list