[Mlir-commits] [mlir] f66e4bd - [mlir][math] Modify math.powf to handle negative bases.
Rob Suderman
llvmlistbot at llvm.org
Fri Aug 25 15:35:41 PDT 2023
Author: Balaji V. Iyer
Date: 2023-08-25T15:35:05-07:00
New Revision: f66e4bd67adf0b0aaecd94154c38f02253bf7190
URL: https://github.com/llvm/llvm-project/commit/f66e4bd67adf0b0aaecd94154c38f02253bf7190
DIFF: https://github.com/llvm/llvm-project/commit/f66e4bd67adf0b0aaecd94154c38f02253bf7190.diff
LOG: [mlir][math] Modify math.powf to handle negative bases.
Powf expansion currently returns NaN when the base is negative.
This is because taking the natural log of a negative number gives
NaN. This patch squares the base and halves the exponent, thereby
getting around the negative base problem.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D158797
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index aa5fd1db528e69..9c46a4ca10a8ec 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -168,11 +168,26 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
+ Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+ Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+ Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+ Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
+ Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
- Value logA = b.create<math::LogOp>(opType, operandA);
- Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+ Value logA = b.create<math::LogOp>(opType, opASquared);
+ Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
- rewriter.replaceOp(op, expResult);
+ Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
+ Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+ Value negCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value oddPower =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+ Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+
+ Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
+ expResult);
+ rewriter.replaceOp(op, res);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 51821e3f099a0c..3e1c3462fc8fc6 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,21 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
func.func @powf_func(%a: f64, %b: f64) ->f64 {
- // CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
- // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+ // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+ // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+ // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
+ // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
+ // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
+ // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
- // CHECK: return [[EXPR]]
+ // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
+ // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
+ // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+ // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+ // CHECK: return [[SEL]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 6ca25edef59e79..3bf474ea47f37f 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -196,7 +196,7 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
- // CHECK-NEXT: nan
+ // CHECK-NEXT: -27
%b = arith.constant -3.0 : f64
%b_p = arith.constant 3.0 : f64
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
@@ -221,16 +221,9 @@ func.func @powf() {
%f_p = arith.constant 1.2 : f64
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
- // CHECK-NEXT: nan
- %g = arith.constant 0xff80000000000000 : f64
- call @func_powff64(%g, %g) : (f64, f64) -> ()
-
- // CHECK-NEXT: nan
- %h = arith.constant 0x7fffffffffffffff : f64
- call @func_powff64(%h, %h) : (f64, f64) -> ()
-
// CHECK-NEXT: nan
%i = arith.constant 1.0 : f64
+ %h = arith.constant 0x7fffffffffffffff : f64
call @func_powff64(%i, %h) : (f64, f64) -> ()
// CHECK-NEXT: inf
More information about the Mlir-commits
mailing list