[Mlir-commits] [mlir] 3a33775 - [mlir][math]Update `convertPowfOp` `ExpandPatterns.cpp` (#124402)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 19:56:47 PST 2025


Author: Hyunsung Lee
Date: 2025-01-28T22:56:43-05:00
New Revision: 3a3377579f137a0a6e14b60d891a9736707e7e8d

URL: https://github.com/llvm/llvm-project/commit/3a3377579f137a0a6e14b60d891a9736707e7e8d
DIFF: https://github.com/llvm/llvm-project/commit/3a3377579f137a0a6e14b60d891a9736707e7e8d.diff

LOG: [mlir][math]Update `convertPowfOp` `ExpandPatterns.cpp` (#124402)

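The current implementation of `convertPowfOp` computes `a * a`, which
overflows to INF easily in low-precision types. For fp16 (max<fp16> ~=
65,504), `a * a` overflows once |a| exceeds about 256. In fp8 types the
limit is far lower: with max<f8E4M3FNUZ> = 240, an |a| of about 16 already
overflows.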


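Remove support for `a < 0`: handling negative values of `a` adds
significant overhead and is prone to overflow. The expansion now computes
`exp(b * log(a))` directly. Because `log(a)` is NaN for negative `a`, the
result for `a < 0` is NaN, except when `b == 0`, which still yields 1.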

- related issue in iree:
https://github.com/iree-org/iree/issues/15936

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/mlir-runner/test-expand-math-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f67..30bcdfc45837a65 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,7 +311,8 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Restricting a >= 0
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -319,21 +320,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
-  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
-  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
-  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
 
-  Value logA = b.create<math::LogOp>(opType, opASquared);
-  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+  Value logA = b.create<math::LogOp>(opType, operandA);
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
-  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value negCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value oddPower =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -341,10 +331,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
   Value zeroCheck =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
-                                        expResult);
-  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
-  rewriter.replaceOp(op, res);
+  Value finalResult =
+      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
+  rewriter.replaceOp(op, finalResult);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84ca..5b443e9e8d4e78e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 {
 
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
   // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
-  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
-  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
-  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
-  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
-  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
-  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
-  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
-  // CHECK: return [[SEL1]]
+  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
+  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
+  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
+  // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+  // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
+  // CHECK: return [[SEL]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:      return %[[SEL1]] : tensor<8xf32>
-
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL]]
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
+// CHECK:        %[[LOGA:.*]] = math.log %[[ARG0]] : f32
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
 // CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:       return %[[SEL1]] : f32
+// CHECK:        %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK:        %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
+// CHECK:       return %[[SEL]] : f32
 
 // -----
 

diff  --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index 106b48a2daea2e3..d1916c28878b97a 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,11 +202,6 @@ func.func @powf() {
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: -27
-  %b   = arith.constant -3.0 : f64
-  %b_p = arith.constant 3.0 : f64
-  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
-
   // CHECK-NEXT: 2.343
   %c   = arith.constant 2.343 : f64
   %c_p = arith.constant 1.000 : f64


        


More information about the Mlir-commits mailing list