[Mlir-commits] [mlir] [MLIR][Math] Use square-and-multiply sequence instead of linear during power expansion (PR #177631)
llvmlistbot at llvm.org
Fri Jan 23 10:13:20 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-math
Author: foxtran (foxtran)
<details>
<summary>Changes</summary>
Currently, `flang` generates suboptimal code for `a**i` when `i` is known at compile time and lies in [4, 8]: it emits a linear sequence of multiplications. See an example here: https://godbolt.org/z/PT3q8v1GE
Or here, where the generated assembly performs seven multiplications:
```asm
sub_:
movaps xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm0, xmm1
ret
```
After this PR, `flang` generates a square-and-multiply sequence, which needs only three multiplications here:
```asm
sub_:
mulss %xmm0, %xmm0
mulss %xmm0, %xmm0
mulss %xmm0, %xmm0
retq
```
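To make the difference in multiplication counts concrete, here is a minimal standalone sketch of both strategies on plain `double` values. It is not the code from this PR: `PowResult`, `powLinear`, and `powSquareAndMultiply` are hypothetical names used only for illustration, and unlike the loop in the diff this version skips the multiplication by one and the final unused squaring so that the counter reflects only the multiplications that contribute to the result.

```cpp
#include <cstdint>

// Hypothetical helper type (not part of MLIR): the computed power plus
// the number of multiplications spent computing it.
struct PowResult {
  double value;
  unsigned muls;
};

// Linear expansion: x * x * ... * x, which costs (exponent - 1) multiplications.
PowResult powLinear(double x, uint32_t exponent) {
  if (exponent == 0)
    return {1.0, 0};
  PowResult r{x, 0};
  for (uint32_t i = 1; i < exponent; ++i) {
    r.value *= x;
    ++r.muls;
  }
  return r;
}

// Square-and-multiply: scan the exponent bits from least significant to most
// significant. `base` holds x, x^2, x^4, ... and is folded into the result
// whenever the current bit is set.
PowResult powSquareAndMultiply(double x, uint32_t exponent) {
  if (exponent == 0)
    return {1.0, 0};
  PowResult r{1.0, 0};
  bool haveResult = false; // Avoids counting an initial multiplication by 1.
  double base = x;
  while (exponent > 0) {
    if (exponent & 1) {
      if (haveResult) {
        r.value *= base;
        ++r.muls;
      } else {
        r.value = base;
        haveResult = true;
      }
    }
    exponent >>= 1;
    // Square only if another bit remains, so no unused product is computed.
    if (exponent > 0) {
      base *= base;
      ++r.muls;
    }
  }
  return r;
}
```

Under these assumptions, `powLinear(x, 8)` reports seven multiplications and `powSquareAndMultiply(x, 8)` reports three, which matches the number of `mulss` instructions in the two listings above.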
---
Full diff: https://github.com/llvm/llvm-project/pull/177631.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+11-6)
- (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+14)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index ff5f7f685903f..acf063248cb30 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -206,16 +206,16 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
if (exponentValue > exponentThreshold)
return failure();
- Value result = base;
- // Transform to naive sequence of multiplications:
+ Value result = one;
+ // Transform to a square-and-multiply sequence:
// * For positive exponent case replace:
// `[fi]powi(x, positive_exponent)`
// with:
- // x * x * x * ...
+ // a chain of x, x*x, x**4, and so on...
// * For negative exponent case replace:
// `[fi]powi(x, negative_exponent)`
// with:
- // (1 / x) * (1 / x) * (1 / x) * ...
+ // a chain of (1/x), (1/x)*(1/x), and so on...
auto buildMul = [&](Value lhs, Value rhs) {
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
@@ -223,8 +223,13 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
else
return MulOpTy::create(rewriter, loc, lhs, rhs);
};
- for (unsigned i = 1; i < exponentValue; ++i)
- result = buildMul(result, base);
+ while (exponentValue > 0) {
+ if (exponentValue & 1) {
+ result = buildMul(base, result);
+ }
+ exponentValue >>= 1;
+ base = buildMul(base, base);
+ }
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index 7342600748967..e4fc58c010d54 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -62,6 +62,20 @@ func.func @pow_cube_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf3
return %0, %1 : f32, vector<4xf32>
}
+// CHECK-LABEL: @pow_4_int
+func.func @pow_4_int(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0
+ // CHECK: %[[SCALAR:.*]] = arith.mulf %[[TMP_S]], %[[TMP_S]]
+ // CHECK: %[[TMP_V:.*]] = arith.mulf %arg1, %arg1
+ // CHECK: %[[VECTOR:.*]] = arith.mulf %[[TMP_V]], %[[TMP_V]]
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = arith.constant 4 : i32
+ %v = arith.constant dense <4> : vector<4xi32>
+ %0 = math.fpowi %arg0, %c : f32, i32
+ %1 = math.fpowi %arg1, %v : vector<4xf32>, vector<4xi32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
// CHECK-LABEL: @pow_recip
func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32
``````````
</details>
https://github.com/llvm/llvm-project/pull/177631