[Mlir-commits] [mlir] [mlir][math] Simplify pow(2^n, y) to exp2(y) (PR #166183)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 3 07:46:40 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Aleksei Nurmukhametov (nurmukhametov)
<details>
<summary>Changes</summary>
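This PR adds a transformation to the algebraic simplifications of the math dialect that rewrites `pow(2^n, y)` as `exp2(n * y)` when the base is a constant (scalar or splat) exact power of two; for a base of exactly 2.0 this reduces to `exp2(y)`.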
---
Full diff: https://github.com/llvm/llvm-project/pull/166183.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+27-6)
- (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+60)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..03e8a0d020919 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -42,20 +42,24 @@ LogicalResult
PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
+ // pow(x, y)
Value x = op.getLhs();
+ Value y = op.getRhs();
- FloatAttr scalarExponent;
- DenseFPElementsAttr vectorExponent;
+ FloatAttr scalarBase, scalarExponent;
+ DenseFPElementsAttr vectorBase, vectorExponent;
- bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
- bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+ bool isScalarBase = matchPattern(x, m_Constant(&scalarBase));
+ bool isVectorBase = matchPattern(x, m_Constant(&vectorBase));
+ bool isScalarExponent = matchPattern(y, m_Constant(&scalarExponent));
+ bool isVectorExponent = matchPattern(y, m_Constant(&vectorExponent));
// Returns true if exponent is a constant equal to `value`.
auto isExponentValue = [&](double value) -> bool {
- if (isScalar)
+ if (isScalarExponent)
return scalarExponent.getValue().isExactlyValue(value);
- if (isVector && vectorExponent.isSplat())
+ if (isVectorExponent && vectorExponent.isSplat())
return vectorExponent.getSplatValue<FloatAttr>()
.getValue()
.isExactlyValue(value);
@@ -120,6 +124,23 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
return success();
}
+ // Replace `pow(2.0^n, y)` with `exp2(n * y)`
+ if (isScalarBase || (isVectorBase && vectorBase.isSplat())) {
+ APFloat baseValue = isScalarBase
+ ? scalarBase.getValue()
+ : vectorBase.getSplatValue<FloatAttr>().getValue();
+ // Check if base is an exact power of 2
+ int n = baseValue.getExactLog2();
+ if (n != INT_MIN) {
+ Value nValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), n));
+ Value nTimesY =
+ rewriter.create<arith::MulFOp>(loc, ValueRange({bcast(nValue), y}));
+ rewriter.replaceOpWithNewOp<math::Exp2Op>(op, nTimesY);
+ return success();
+ }
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index e0e2b9853a2a1..239be5eeeb6ac 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -90,6 +90,66 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
return %0, %1 : f32, vector<4xf32>
}
+// CHECK-LABEL: @pow_of_two
+func.func @pow_of_two(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK: %[[SCALAR:.*]] = math.exp2 %arg0
+ // CHECK: %[[VECTOR:.*]] = math.exp2 %arg1
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = arith.constant 2.0 : f32
+ %v = arith.constant dense <2.0> : vector<4xf32>
+ %0 = math.powf %c, %arg0 : f32
+ %1 = math.powf %v, %arg1 : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_four
+func.func @pow_of_four(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+ // CHECK-DAG: %[[CST_S:.*]] = arith.constant 2.000000e+00 : f32
+ // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+ // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+ // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+ // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = arith.constant 4.0 : f32
+ %v = arith.constant dense <4.0> : vector<4xf32>
+ %0 = math.powf %c, %arg0 : f32
+ %1 = math.powf %v, %arg1 : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_half
+func.func @pow_of_half(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-1.000000e+00> : vector<4xf32>
+ // CHECK-DAG: %[[CST_S:.*]] = arith.constant -1.000000e+00 : f32
+ // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+ // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+ // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+ // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = arith.constant 0.5 : f32
+ %v = arith.constant dense <0.5> : vector<4xf32>
+ %0 = math.powf %c, %arg0 : f32
+ %1 = math.powf %v, %arg1 : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_quarter
+func.func @pow_of_quarter(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+ // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-2.000000e+00> : vector<4xf32>
+ // CHECK-DAG: %[[CST_S:.*]] = arith.constant -2.000000e+00 : f32
+ // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+ // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+ // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+ // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+ // CHECK: return %[[SCALAR]], %[[VECTOR]]
+ %c = arith.constant 0.25 : f32
+ %v = arith.constant dense <0.25> : vector<4xf32>
+ %0 = math.powf %c, %arg0 : f32
+ %1 = math.powf %v, %arg1 : vector<4xf32>
+ return %0, %1 : f32, vector<4xf32>
+}
+
// CHECK-LABEL: @ipowi_zero_exp(
// CHECK-SAME: %[[ARG0:.+]]: i32
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/166183
More information about the Mlir-commits mailing list