[Mlir-commits] [mlir] [MLIR] Lower `math.powf(x, 3.0)` to `x * x * x`. (PR #127256)

Benoit Jacob llvmlistbot at llvm.org
Fri Feb 14 12:08:24 PST 2025


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/127256

`math.powf(x, y)` never really supported negative values of `x`, but this was not obvious (it happened to work for some values of `y`) until https://github.com/llvm/llvm-project/pull/126338 was merged yesterday. That PR lowers it to the usual `exp(y * log(x))`, except for a few special exponent values such as `y == 2.0`, which lowers to `x * x`. Since `log(x)` is NaN for `x < 0`, the general lowering yields NaN for every negative `x`.

It turns out that code in the wild has been relying on `math.powf(x, y)` with negative `x` for some integral values of `y`, where a lowering to multiplications was intended: https://github.com/iree-org/iree/issues/19996

This PR adds such a lowering for `y == 3.0` (`x * x * x`). It "fixes" those cases, because it never takes `log(x)`, and it is a more efficient lowering anyway.

A wider effort is still needed to stop using `powf` with negative `x` altogether; `math.fpowi` should be used for that instead.

>From cde72d56cb8a09b4af96c48768fa9482938025fa Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 14 Feb 2025 14:02:43 -0600
Subject: [PATCH] pow3

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 13 +++++++++----
 mlir/test/Dialect/Math/expand-math.mlir             | 11 +++++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   auto &sem =
       cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
   APFloat valueB(sem);
+  auto mulf = [&](Value x, Value y) -> Value {
+    return b.create<arith::MulFOp>(x, y);
+  };
   if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
     if (valueB.isZero()) {
       // a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
     }
     if (valueB.isExactlyValue(2.0)) {
       // a^2 -> a * a
-      Value mul = b.create<arith::MulFOp>(operandA, operandA);
-      rewriter.replaceOp(op, mul);
+      rewriter.replaceOp(op, mulf(operandA, operandA));
       return success();
     }
     if (valueB.isExactlyValue(-2.0)) {
       // a^(-2) -> 1 / (a * a)
-      Value mul = b.create<arith::MulFOp>(operandA, operandA);
       Value one =
           createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-      Value div = b.create<arith::DivFOp>(one, mul);
+      Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
       rewriter.replaceOp(op, div);
       return success();
     }
+    if (valueB.isExactlyValue(3.0)) {
+      rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+      return success();
+    }
   }
 
   Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
   return %ret : f64
 }
 
+// CHECK-LABEL:   func @powf_func_three
+// CHECK-SAME:    (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+  // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+  // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+  // CHECK: return %[[MUL2]] : f64
+  %b = arith.constant 3.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @roundeven64



More information about the Mlir-commits mailing list