[Mlir-commits] [mlir] [MLIR] Lower `math.powf(x, 3.0)` to `x * x * x`. (PR #127256)
llvmlistbot at llvm.org
Fri Feb 14 12:08:57 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Benoit Jacob (bjacob)
<details>
<summary>Changes</summary>
`math.powf(x, y)` never really supported negative values of `x`, but that was unclear (it happened to work for some values of `y`) until https://github.com/llvm/llvm-project/pull/126338 was merged yesterday. That PR lowers `math.powf` to the usual `exp(y * log(x))`, except for a few special exponent values such as `y == 2.0`, which lowers to `x * x`.
It turns out that code in the wild has been relying on `math.powf(x, y)` with negative `x` for some integral values of `y` for which a lowering to muls was intended: https://github.com/iree-org/iree/issues/19996
This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and it is a more efficient lowering anyway.
There needs to be a wider effort to stop using `powf` with negative `x` altogether; `math.fpowi` should be used for that instead.
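To illustrate why the multiplication-based lowering matters for negative `x`, here is a minimal standalone sketch in plain C++ (not MLIR code). The helper names `powViaExpLog`, `cubeViaMul`, and `demo` are hypothetical and exist only for this illustration; the sketch assumes IEEE-754 `double` arithmetic, where `std::log` of a negative value yields NaN.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper mirroring the generic lowering exp(y * log(x)).
// For x < 0, std::log(x) is NaN, so the whole expression is NaN.
inline double powViaExpLog(double x, double y) {
  return std::exp(y * std::log(x));
}

// Hypothetical helper mirroring the y == 3.0 special case, using the
// same association as the patch: (x * x) * x.
inline double cubeViaMul(double x) {
  return (x * x) * x;
}

// Small self-check contrasting the two lowerings for a negative base.
inline void demo() {
  assert(std::isnan(powViaExpLog(-2.0, 3.0)));  // generic path breaks down
  assert(cubeViaMul(-2.0) == -8.0);             // mul path gives the cube
}
```

Calling `demo()` checks both behaviours: the generic path produces NaN for a negative base, while the multiplication path returns the expected cube.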
---
Full diff: https://github.com/llvm/llvm-project/pull/127256.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+9-4)
- (modified) mlir/test/Dialect/Math/expand-math.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
auto &sem =
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
+ auto mulf = [&](Value x, Value y) -> Value {
+ return b.create<arith::MulFOp>(x, y);
+ };
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
if (valueB.isZero()) {
// a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
}
if (valueB.isExactlyValue(2.0)) {
// a^2 -> a * a
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
- rewriter.replaceOp(op, mul);
+ rewriter.replaceOp(op, mulf(operandA, operandA));
return success();
}
if (valueB.isExactlyValue(-2.0)) {
// a^(-2) -> 1 / (a * a)
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
- Value div = b.create<arith::DivFOp>(one, mul);
+ Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
rewriter.replaceOp(op, div);
return success();
}
+ if (valueB.isExactlyValue(3.0)) {
+ rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+ return success();
+ }
}
Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
return %ret : f64
}
+// CHECK-LABEL: func @powf_func_three
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+ // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+ // CHECK: return %[[MUL2]] : f64
+ %b = arith.constant 3.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
// -----
// CHECK-LABEL: func.func @roundeven64
``````````
</details>
https://github.com/llvm/llvm-project/pull/127256