[llvm-branch-commits] [mlir] 94f348b - [mlir][math] Modify math.powf to handle negative bases.

Tobias Hieta via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 30 23:56:12 PDT 2023


Author: Balaji V. Iyer
Date: 2023-08-31T08:54:24+02:00
New Revision: 94f348b7842a2d3a00b5a7d6641b394c95486252

URL: https://github.com/llvm/llvm-project/commit/94f348b7842a2d3a00b5a7d6641b394c95486252
DIFF: https://github.com/llvm/llvm-project/commit/94f348b7842a2d3a00b5a7d6641b394c95486252.diff

LOG: [mlir][math] Modify math.powf to handle negative bases.

Powf expansion currently returns NaN when the base is negative,
because the natural log of a negative number is NaN. This patch
squares the base and halves the exponent, computing
a^b = exp((b / 2) * ln(a^2)), which avoids taking the log of a
negative number. When the base is negative and b mod 2 is nonzero,
the result is negated to restore the sign.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D158797

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index ee8f23cf362b62..98c97fdc2c0905 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -168,11 +168,26 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
+  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
+  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
 
-  Value logA = b.create<math::LogOp>(opType, operandA);
-  Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+  Value logA = b.create<math::LogOp>(opType, opASquared);
+  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  rewriter.replaceOp(op, expResult);
+  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
+  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+  Value negCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value oddPower =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+
+  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
+                                        expResult);
+  rewriter.replaceOp(op, res);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c28e2141db061a..4cd64611020790 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,21 @@ func.func @roundf_func(%a: f32) -> f32 {
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
 func.func @powf_func(%a: f64, %b: f64) ->f64 {
-  // CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
-  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+  // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
+  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
+  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
   // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
-  // CHECK: return [[EXPR]]
+  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
+  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]], [[CST0]]
+  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+  // CHECK: return [[SEL]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }

diff  --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 30f30def56fdd5..847c41fec9135e 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -195,7 +195,7 @@ func.func @powf() {
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: nan
+  // CHECK-NEXT: -27
   %b   = arith.constant -3.0 : f64
   %b_p = arith.constant 3.0 : f64
   call @func_powff64(%b, %b_p) : (f64, f64) -> ()
@@ -220,16 +220,9 @@ func.func @powf() {
   %f_p  = arith.constant 1.2 : f64
   call @func_powff64(%f, %f_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: nan
-  %g    = arith.constant 0xff80000000000000 : f64
-  call @func_powff64(%g, %g) : (f64, f64) -> ()
-
-  // CHECK-NEXT: nan
-  %h = arith.constant 0x7fffffffffffffff : f64
-  call @func_powff64(%h, %h) : (f64, f64) -> ()
-
   // CHECK-NEXT: nan
   %i = arith.constant 1.0 : f64
+  %h = arith.constant 0x7fffffffffffffff : f64
   call @func_powff64(%i, %h) : (f64, f64) -> ()
 
   // CHECK-NEXT: inf


        


More information about the llvm-branch-commits mailing list