[Mlir-commits] [mlir] [MLIR][Math] Use square-and-multiply sequence instead of linear during power expansion (PR #177631)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 10:12:43 PST 2026
https://github.com/foxtran created https://github.com/llvm/llvm-project/pull/177631
Currently, `flang` generates suboptimal code for `a**i` when `i` is a compile-time constant in the range [4, 8]: for such `i`, a linear chain of `i - 1` multiplications is emitted. See an example here: https://godbolt.org/z/PT3q8v1GE
The current output for `a**8` (Intel syntax) is:
```asm
sub_:
movaps xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm1, xmm0
mulss xmm0, xmm1
ret
```
After this PR, `flang` generates much better code, a square-and-multiply sequence; for `a**8` it is just three squarings (AT&T syntax):
```asm
sub_:
mulss %xmm0, %xmm0
mulss %xmm0, %xmm0
mulss %xmm0, %xmm0
retq
```
>From 95e3717360e32385af7f370ed9948cd15121fb97 Mon Sep 17 00:00:00 2001
From: "Igor S. Gerasimov" <foxtranigor at gmail.com>
Date: Fri, 23 Jan 2026 18:47:19 +0100
Subject: [PATCH 1/2] Add a new test for raising to a constant integer power
---
.../Dialect/Math/algebraic-simplification.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index 7342600748967..e4fc58c010d54 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -62,6 +62,20 @@ func.func @pow_cube_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf3
return %0, %1 : f32, vector<4xf32>
}
+// CHECK-LABEL: @pow_4_int
+func.func @pow_4_int(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[X2_S:.*]] = arith.mulf %arg0, %arg0
+ // CHECK: %[[X4_S:.*]] = arith.mulf %[[X2_S]], %[[X2_S]]
+ // CHECK: %[[X2_V:.*]] = arith.mulf %arg1, %arg1
+ // CHECK: %[[X4_V:.*]] = arith.mulf %[[X2_V]], %[[X2_V]]
+ // CHECK: return %[[X4_S]], %[[X4_V]]
+ %exp = arith.constant 4 : i32 // x**4 expected as two squarings: (x*x)*(x*x)
+ %exp_vec = arith.constant dense <4> : vector<4xi32>
+ %pow_s = math.fpowi %arg0, %exp : f32, i32
+ %pow_v = math.fpowi %arg1, %exp_vec : vector<4xf32>, vector<4xi32>
+ return %pow_s, %pow_v : f32, vector<4xf32>
+}
+
// CHECK-LABEL: @pow_recip
func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
// CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32
>From 668236db426303b8695d13f178681177278ff45c Mon Sep 17 00:00:00 2001
From: "Igor S. Gerasimov" <foxtranigor at gmail.com>
Date: Fri, 23 Jan 2026 18:59:33 +0100
Subject: [PATCH 2/2] Implement square-and-multiply algorithm
---
.../Math/Transforms/AlgebraicSimplification.cpp | 17 +++++++++++------
1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index ff5f7f685903f..acf063248cb30 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -206,16 +206,16 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
if (exponentValue > exponentThreshold)
return failure();
- Value result = base;
- // Transform to naive sequence of multiplications:
+ Value result = one;
+ // Transform to a square-and-multiply sequence:
// * For positive exponent case replace:
// `[fi]powi(x, positive_exponent)`
// with:
- // x * x * x * ...
+ // a chain of x, x*x, x**4, and so on...
// * For negative exponent case replace:
// `[fi]powi(x, negative_exponent)`
// with:
- // (1 / x) * (1 / x) * (1 / x) * ...
+ // a chain of (1/x), (1/x)*(1/x), and so on...
auto buildMul = [&](Value lhs, Value rhs) {
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs,
@@ -223,8 +223,13 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
else
return MulOpTy::create(rewriter, loc, lhs, rhs);
};
- for (unsigned i = 1; i < exponentValue; ++i)
- result = buildMul(result, base);
+ while (exponentValue > 0) {
+ if (exponentValue & 1) {
+ result = buildMul(base, result);
+ }
+ exponentValue >>= 1;
+ base = buildMul(base, base);
+ }
// Inverse the base for negative exponent, i.e. for
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
More information about the Mlir-commits
mailing list