[Mlir-commits] [mlir] 0301bf9 - [MLIR] Lower `math.powf(x, 3.0)` to `x * x * x`. (#127256)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 14 12:25:05 PST 2025
Author: Benoit Jacob
Date: 2025-02-14T15:25:01-05:00
New Revision: 0301bf977aa9842003462a2a7a3c3ce56abfaae0
URL: https://github.com/llvm/llvm-project/commit/0301bf977aa9842003462a2a7a3c3ce56abfaae0
DIFF: https://github.com/llvm/llvm-project/commit/0301bf977aa9842003462a2a7a3c3ce56abfaae0.diff
LOG: [MLIR] Lower `math.powf(x, 3.0)` to `x * x * x`. (#127256)
`math.powf(x, y)` never really supported negative values of `x`, but
that was unclear (it happened to work for some values of `y`) until
https://github.com/llvm/llvm-project/pull/126338 was merged yesterday
and lowered it to the usual `exp(y * log(x))` outside of a few special
exponent values, such as `y == 2.0` lowering to `x * x`.
It turns out that code in the wild has been relying on `math.powf(x, y)`
with negative `x` for some integral values of `y` for which a lowering
to muls was intended: https://github.com/iree-org/iree/issues/19996
This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and
it is a more efficient lowering anyway.
There needs to be a wider effort to stop using `powf` with negative `x`
altogether; `math.fpowi` should be used for that instead.
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
auto &sem =
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
+ auto mulf = [&](Value x, Value y) -> Value {
+ return b.create<arith::MulFOp>(x, y);
+ };
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
if (valueB.isZero()) {
// a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
}
if (valueB.isExactlyValue(2.0)) {
// a^2 -> a * a
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
- rewriter.replaceOp(op, mul);
+ rewriter.replaceOp(op, mulf(operandA, operandA));
return success();
}
if (valueB.isExactlyValue(-2.0)) {
// a^(-2) -> 1 / (a * a)
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
- Value div = b.create<arith::DivFOp>(one, mul);
+ Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
rewriter.replaceOp(op, div);
return success();
}
+ if (valueB.isExactlyValue(3.0)) {
+ rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+ return success();
+ }
}
Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
return %ret : f64
}
+// CHECK-LABEL: func @powf_func_three
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+ // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+ // CHECK: return %[[MUL2]] : f64
+ %b = arith.constant 3.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
// -----
// CHECK-LABEL: func.func @roundeven64
More information about the Mlir-commits
mailing list