[Mlir-commits] [mlir] [MLIR][Math] Use square-and-multiply sequence instead of linear during power expansion (PR #177631)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 10:13:20 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-math

Author: foxtran (foxtran)

<details>
<summary>Changes</summary>

Currently, `flang` generates suboptimal code for `a**i` when `i` is known at compile time and lies in [4, 8]. For such `i`, a linear sequence of multiplications is generated. See an example here: https://godbolt.org/z/PT3q8v1GE

Or here, where seven multiplications are emitted one after another:
```asm
sub_:
        movaps  xmm1, xmm0
        mulss   xmm1, xmm0
        mulss   xmm1, xmm0
        mulss   xmm1, xmm0
        mulss   xmm1, xmm0
        mulss   xmm1, xmm0
        mulss   xmm1, xmm0
        mulss   xmm0, xmm1
        ret
```

After this PR, `flang` generates only three multiplications (repeated squaring) for the same power:
```asm
sub_:
	mulss	%xmm0, %xmm0
	mulss	%xmm0, %xmm0
	mulss	%xmm0, %xmm0
	retq
```
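
The difference in multiplication counts can be reproduced outside the compiler. The following is a minimal standalone sketch, not the code from this PR: `PowResult`, `powLinear`, and `powSquareAndMultiply` are hypothetical names, it works on plain `double` values rather than MLIR `Value`s, and it assumes a positive exponent. Unlike the loop in the diff, it skips the initial multiplication by one and the final unused squaring so that the counts match the assembly above.

```cpp
#include <cassert>

// Value of x**exponent together with the number of multiplications used.
// (Hypothetical helper type for this illustration only.)
struct PowResult {
  double value;
  unsigned multiplications;
};

// Linear expansion: x * x * ... * x, which needs exponent - 1 multiplications.
PowResult powLinear(double x, unsigned exponent) {
  assert(exponent >= 1 && "sketch handles positive exponents only");
  PowResult r{x, 0};
  for (unsigned i = 1; i < exponent; ++i) {
    r.value *= x;
    ++r.multiplications;
  }
  return r;
}

// Square-and-multiply: walk the exponent bits from least significant to most
// significant, squaring the base at each step and folding it into the result
// whenever the current bit is set.
PowResult powSquareAndMultiply(double x, unsigned exponent) {
  assert(exponent >= 1 && "sketch handles positive exponents only");
  PowResult r{1.0, 0};
  bool haveResult = false; // avoids counting a multiplication by one
  double base = x;
  while (true) {
    if (exponent & 1) {
      if (haveResult) {
        r.value *= base;
        ++r.multiplications;
      } else {
        r.value = base;
        haveResult = true;
      }
    }
    exponent >>= 1;
    if (exponent == 0)
      break; // no further squaring needed once all bits are consumed
    base *= base;
    ++r.multiplications;
  }
  return r;
}
```

For an exponent of 8, `powLinear` performs seven multiplications and `powSquareAndMultiply` performs three, matching the two assembly listings. For exponents with several set bits the saving is smaller: an exponent of 7 takes six multiplications linearly and four with square-and-multiply.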


---
Full diff: https://github.com/llvm/llvm-project/pull/177631.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+11-6) 
- (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index ff5f7f685903f..acf063248cb30 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -206,16 +206,16 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   if (exponentValue > exponentThreshold)
     return failure();
 
-  Value result = base;
-  // Transform to naive sequence of multiplications:
+  Value result = one;
+  // Transform to a square-and-multiply sequence:
   //   * For positive exponent case replace:
   //       `[fi]powi(x, positive_exponent)`
   //     with:
-  //       x * x * x * ...
+  //       a chain of x, x*x, x**4, and so on...
   //   * For negative exponent case replace:
   //       `[fi]powi(x, negative_exponent)`
   //     with:
-  //       (1 / x) * (1 / x) * (1 / x) * ...
+  //       a chain of (1/x), (1/x)*(1/x), and so on...
   auto buildMul = [&](Value lhs, Value rhs) {
     if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
       return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
@@ -223,8 +223,13 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
     else
       return MulOpTy::create(rewriter, loc, lhs, rhs);
   };
-  for (unsigned i = 1; i < exponentValue; ++i)
-    result = buildMul(result, base);
+  while (exponentValue > 0) {
+    if (exponentValue & 1) {
+      result = buildMul(base, result);
+    }
+    exponentValue >>= 1;
+    base = buildMul(base, base);
+  }
 
   // Inverse the base for negative exponent, i.e. for
   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index 7342600748967..e4fc58c010d54 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -62,6 +62,20 @@ func.func @pow_cube_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf3
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_4_int
+func.func @pow_4_int(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0
+  // CHECK: %[[SCALAR:.*]] = arith.mulf %[[TMP_S]], %[[TMP_S]]
+  // CHECK: %[[TMP_V:.*]] = arith.mulf %arg1, %arg1
+  // CHECK: %[[VECTOR:.*]] = arith.mulf %[[TMP_V]], %[[TMP_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 4 : i32
+  %v = arith.constant dense <4> : vector<4xi32>
+  %0 = math.fpowi %arg0, %c : f32, i32
+  %1 = math.fpowi %arg1, %v : vector<4xf32>, vector<4xi32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @pow_recip
 func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/177631

