[Mlir-commits] [mlir] de09986 - [mlir][math] `powf(a, b)` drop support when a < 0 (#126338)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 13 08:01:51 PST 2025
Author: Hyunsung Lee
Date: 2025-02-13T08:01:47-08:00
New Revision: de09986596c9bbc89262456dda319715fb49353f
URL: https://github.com/llvm/llvm-project/commit/de09986596c9bbc89262456dda319715fb49353f
DIFF: https://github.com/llvm/llvm-project/commit/de09986596c9bbc89262456dda319715fb49353f.diff
LOG: [mlir][math] `powf(a, b)` drop support when a < 0 (#126338)
Related: #124402
- replace the inefficient implementation of `powf(a, b)`, which handled the
`a < 0` case, with a direct `exp(b * ln(a))` expansion
- as a result, general support for `a < 0` is dropped
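However, some special cases are still in use, such as:
- `a < 0` with `b = 0`, `b = 0.5`, `b = 1`, or `b = 2`
- when `b` is a constant, these special cases (and the negated exponents
`b = -0.5`, `b = -1`, `b = -2`) are now converted into simpler ops.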
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/mlir-runner/test-expand-math-approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f..d7953719d44b5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
using namespace mlir;
@@ -311,40 +312,71 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(a)).
+// When b is a constant, these special cases are rewritten into cheaper ops:
+//   b == 0    -> 1
+//   b == 1    -> a            b == -1   -> 1 / a
+//   b == 0.5  -> sqrt(a)      b == -0.5 -> rsqrt(a)
+//   b == 2    -> a * a        b == -2   -> 1 / (a * a)
+// The general path does not support a < 0: ln(a) of a negative a is NaN under
+// IEEE semantics, so the result is NaN even where libm pow is finite
+// (e.g. pow(-3, 3) == -27).
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
- Type opType = operandA.getType();
- Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
- Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
- Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
- Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
- Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
-
- Value logA = b.create<math::LogOp>(opType, opASquared);
- Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
- Value expResult = b.create<math::ExpOp>(opType, mult);
- Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
- Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value negCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value oddPower =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
-
- // First, we select between the exp value and the adjusted value for odd
- // powers of negatives. Then, we ensure that one is produced if `b` is zero.
- // This corresponds to `libm` behavior, even for `0^0`. Without this check,
- // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
- Value zeroCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
- expResult);
- res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
- rewriter.replaceOp(op, res);
+ auto typeA = operandA.getType();
+ auto typeB = operandB.getType();
+
+ // If the exponent is a constant float, try the cheaper special-case rewrites.
+ auto &sem =
+ cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
+ APFloat valueB(sem);
+ if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
+ if (valueB.isZero()) {
+ // a^0 -> 1 (isZero() matches +0.0 and -0.0; pow(x, +-0) == 1 for all x)
+ Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
+ rewriter.replaceOp(op, one);
+ return success();
+ }
+ if (valueB.isExactlyValue(1.0)) {
+ // a^1 -> a
+ rewriter.replaceOp(op, operandA);
+ return success();
+ }
+ if (valueB.isExactlyValue(-1.0)) {
+ // a^(-1) -> 1 / a
+ Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
+ Value div = b.create<arith::DivFOp>(one, operandA);
+ rewriter.replaceOp(op, div);
+ return success();
+ }
+ // NOTE(review): sqrt/rsqrt differ from pow at a == -0 and a == -inf:
+ // sqrt(-0) = -0 but pow(-0, 0.5) = +0; sqrt(-inf) = nan but
+ // pow(-inf, 0.5) = +inf. Analogous differences exist for b == -0.5.
+ // Confirm these deviations are acceptable.
+ if (valueB.isExactlyValue(0.5)) {
+ // a^(1/2) -> sqrt(a)
+ Value sqrt = b.create<math::SqrtOp>(operandA);
+ rewriter.replaceOp(op, sqrt);
+ return success();
+ }
+ if (valueB.isExactlyValue(-0.5)) {
+ // a^(-1/2) -> 1 / sqrt(a)
+ Value rsqrt = b.create<math::RsqrtOp>(operandA);
+ rewriter.replaceOp(op, rsqrt);
+ return success();
+ }
+ if (valueB.isExactlyValue(2.0)) {
+ // a^2 -> a * a (exact for negative a as well)
+ Value mul = b.create<arith::MulFOp>(operandA, operandA);
+ rewriter.replaceOp(op, mul);
+ return success();
+ }
+ if (valueB.isExactlyValue(-2.0)) {
+ // a^(-2) -> 1 / (a * a)
+ Value mul = b.create<arith::MulFOp>(operandA, operandA);
+ Value one =
+ createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+ Value div = b.create<arith::DivFOp>(one, mul);
+ rewriter.replaceOp(op, div);
+ return success();
+ }
+ }
+
+ // General path: a^b = exp(b * ln(a)).
+ // NOTE(review): unlike the removed implementation above, a runtime b == 0 is
+ // no longer special-cased. For a == 0 this computes
+ // exp(0 * ln(0)) = exp(0 * -inf) = exp(nan) = nan, whereas libm returns 1
+ // for pow(x, 0). a == +inf or nan with a runtime b == 0 is likewise nan.
+ // Confirm this regression is intended.
+ Value logA = b.create<math::LogOp>(operandA);
+ Value mult = b.create<arith::MulFOp>(operandB, logA);
+ Value expResult = b.create<math::ExpOp>(mult);
+ rewriter.replaceOp(op, expResult);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84..f39d1a7a6dc50 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -201,26 +201,86 @@ func.func @roundf_func(%a: f32) -> f32 {
// -----
// CHECK-LABEL: func @powf_func
-// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
- // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
- // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
- // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
- // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
- // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
- // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
- // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
- // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
- // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
- // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
- // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
- // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
- // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
- // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
- // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
- // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
- // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
- // CHECK: return [[SEL1]]
+// CHECK-SAME: (%[[ARG0:.+]]: f64, %[[ARG1:.+]]: f64) -> f64
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
+ // Runtime exponent: plain exp(b * log(a)); a runtime b == 0 gets no select.
+ // CHECK: %[[LOGA:.+]] = math.log %[[ARG0]] : f64
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG1]], %[[LOGA]] : f64
+ // CHECK: %[[EXP:.+]] = math.exp %[[MUL]] : f64
+ // CHECK: return %[[EXP]] : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// The functions below use a constant exponent, so the expansion folds them
+// into the special-case rewrites instead of exp(b * log(a)).
+// CHECK-LABEL: func @powf_func_zero
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_zero(%a: f64) -> f64 {
+ // CHECK: %[[ONE:.+]] = arith.constant 1.000000e+00 : f64
+ // CHECK: return %[[ONE]] : f64
+ %b = arith.constant 0.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// CHECK-LABEL: func @powf_func_one
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_one(%a: f64) -> f64 {
+ // CHECK: return %[[ARG0]] : f64
+ %b = arith.constant 1.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// CHECK-LABEL: func @powf_func_negone
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_negone(%a: f64) -> f64 {
+ // CHECK: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64
+ // CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[ARG0]] : f64
+ // CHECK: return %[[DIV]] : f64
+ %b = arith.constant -1.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// CHECK-LABEL: func @powf_func_half
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_half(%a: f64) -> f64 {
+ // CHECK: %[[SQRT:.+]] = math.sqrt %[[ARG0]] : f64
+ // CHECK: return %[[SQRT]] : f64
+ %b = arith.constant 0.5 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// The math.rsqrt produced for an exponent of -0.5 is itself expanded, so the
+// expected output is 1 / sqrt(a).
+// CHECK-LABEL: func @powf_func_neghalf
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_neghalf(%a: f64) -> f64 {
+ // CHECK: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64
+ // CHECK: %[[SQRT:.+]] = math.sqrt %[[ARG0]] : f64
+ // CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[SQRT]] : f64
+ // CHECK: return %[[DIV]] : f64
+ %b = arith.constant -0.5 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// CHECK-LABEL: func @powf_func_two
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_two(%a: f64) -> f64 {
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+ // CHECK: return %[[MUL]] : f64
+ %b = arith.constant 2.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
+// CHECK-LABEL: func @powf_func_negtwo
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_negtwo(%a: f64) -> f64 {
+ // CHECK-DAG: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+ // CHECK-DAG: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64
+ // CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[MUL]] : f64
+ // CHECK: return %[[DIV]] : f64
+ %b = arith.constant -2.0 : f64
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -602,26 +662,11 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK: return %[[SEL1]] : tensor<8xf32>
-
+// The integer exponent is converted with sitofp and then goes through the
+// generic exp(b * log(a)) expansion; the result is returned with no b == 0
+// select around it.
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: return %[[EXP]]
// -----
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -630,25 +675,11 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
+// Scalar variant: the converted exponent multiplies log(a) directly, and the
+// exp result is returned unchanged.
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK: return %[[SEL1]] : f32
+// CHECK: return %[[EXP]] : f32
// -----
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index 106b48a2daea2..b599c9d8435d4 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,55 +202,62 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
- // CHECK-NEXT: -27
- %b = arith.constant -3.0 : f64
- %b_p = arith.constant 3.0 : f64
- call @func_powff64(%b, %b_p) : (f64, f64) -> ()
-
// CHECK-NEXT: 2.343
- %c = arith.constant 2.343 : f64
- %c_p = arith.constant 1.000 : f64
- call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+ %b = arith.constant 2.343 : f64
+ %b_p = arith.constant 1.000 : f64
+ call @func_powff64(%b, %b_p) : (f64, f64) -> ()
// CHECK-NEXT: 0.176171
- %d = arith.constant 4.25 : f64
- %d_p = arith.constant -1.2 : f64
- call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+ %c = arith.constant 4.25 : f64
+ %c_p = arith.constant -1.2 : f64
+ call @func_powff64(%c, %c_p) : (f64, f64) -> ()
// CHECK-NEXT: 1
- %e = arith.constant 4.385 : f64
- %e_p = arith.constant 0.00 : f64
- call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+ %d = arith.constant 4.385 : f64
+ %d_p = arith.constant 0.00 : f64
+ call @func_powff64(%d, %d_p) : (f64, f64) -> ()
// CHECK-NEXT: 6.62637
- %f = arith.constant 4.835 : f64
- %f_p = arith.constant 1.2 : f64
- call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+ %e = arith.constant 4.835 : f64
+ %e_p = arith.constant 1.2 : f64
+ call @func_powff64(%e, %e_p) : (f64, f64) -> ()
// CHECK-NEXT: nan
- %i = arith.constant 1.0 : f64
- %h = arith.constant 0x7fffffffffffffff : f64
- call @func_powff64(%i, %h) : (f64, f64) -> ()
+ %f = arith.constant 1.0 : f64
+ %f_p = arith.constant 0x7fffffffffffffff : f64
+ call @func_powff64(%f, %f_p) : (f64, f64) -> ()
// CHECK-NEXT: inf
- %j = arith.constant 29385.0 : f64
- %j_p = arith.constant 23598.0 : f64
- call @func_powff64(%j, %j_p) : (f64, f64) -> ()
+ %g = arith.constant 29385.0 : f64
+ %g_p = arith.constant 23598.0 : f64
+ call @func_powff64(%g, %g_p) : (f64, f64) -> ()
// CHECK-NEXT: -nan
- %k = arith.constant 1.0 : f64
- %k_p = arith.constant 0xfff0000001000000 : f64
- call @func_powff64(%k, %k_p) : (f64, f64) -> ()
+ %h = arith.constant 1.0 : f64
+ %h_p = arith.constant 0xfff0000001000000 : f64
+ call @func_powff64(%h, %h_p) : (f64, f64) -> ()
// CHECK-NEXT: -nan
- %l = arith.constant 1.0 : f32
- %l_p = arith.constant 0xffffffff : f32
- call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+ %i = arith.constant 1.0 : f32
+ %i_p = arith.constant 0xffffffff : f32
+ call @func_powff32(%i, %i_p) : (f32, f32) -> ()
// CHECK-NEXT: 1
- %zero = arith.constant 0.0 : f32
- call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+ // 0^0 is computed inline with a constant exponent, so it takes the folded
+ // b == 0 rewrite and yields 1. NOTE(review): a runtime 0^0 (e.g. through
+ // @func_powff32) goes through exp(b * log(a)) and would now produce nan.
+ %j = arith.constant 0.000 : f32
+ %j_r = math.powf %j, %j : f32
+ vector.print %j_r : f32
+ // Negative base with constant exponent 2 uses the a * a rewrite.
+ // CHECK-NEXT: 4
+ %k = arith.constant -2.0 : f32
+ %k_p = arith.constant 2.0 : f32
+ %k_r = math.powf %k, %k_p : f32
+ vector.print %k_r : f32
+
+ // Negative base with constant exponent -2 uses the 1 / (a * a) rewrite.
+ // CHECK-NEXT: 0.25
+ %l = arith.constant -2.0 : f32
+ %l_p = arith.constant -2.0 : f32
+ %l_r = math.powf %l, %l_p : f32
+ vector.print %l_r : f32
return
}
More information about the Mlir-commits
mailing list