[Mlir-commits] [mlir] [mlir][math] Simplify pow(2^n, y) to exp2(y) (PR #166183)

Aleksei Nurmukhametov llvmlistbot at llvm.org
Mon Nov 3 08:11:46 PST 2025


https://github.com/nurmukhametov updated https://github.com/llvm/llvm-project/pull/166183

>From 690b52b46432bfdc39cf3f975e61e726b07bc189 Mon Sep 17 00:00:00 2001
From: Aleksei Nurmukhametov <anurmukh at amd.com>
Date: Mon, 3 Nov 2025 15:02:53 +0000
Subject: [PATCH] [mlir][math] Simplify pow(2^n, y) to exp2(n * y)

---
 .../Transforms/AlgebraicSimplification.cpp    | 34 +++++++++--
 .../Math/algebraic-simplification.mlir        | 60 +++++++++++++++++++
 2 files changed, 88 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..2e5b48ebbb1eb 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -42,20 +42,24 @@ LogicalResult
 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                        PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
+  // pow(x, y)
   Value x = op.getLhs();
+  Value y = op.getRhs();
 
-  FloatAttr scalarExponent;
-  DenseFPElementsAttr vectorExponent;
+  FloatAttr scalarBase, scalarExponent;
+  DenseFPElementsAttr vectorBase, vectorExponent;
 
-  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
-  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+  bool isScalarBase = matchPattern(x, m_Constant(&scalarBase));
+  bool isVectorBase = matchPattern(x, m_Constant(&vectorBase));
+  bool isScalarExponent = matchPattern(y, m_Constant(&scalarExponent));
+  bool isVectorExponent = matchPattern(y, m_Constant(&vectorExponent));
 
   // Returns true if exponent is a constant equal to `value`.
   auto isExponentValue = [&](double value) -> bool {
-    if (isScalar)
+    if (isScalarExponent)
       return scalarExponent.getValue().isExactlyValue(value);
 
-    if (isVector && vectorExponent.isSplat())
+    if (isVectorExponent && vectorExponent.isSplat())
       return vectorExponent.getSplatValue<FloatAttr>()
           .getValue()
           .isExactlyValue(value);
@@ -120,6 +124,24 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     return success();
   }
 
+  // Replace `pow(2.0^n, y)` with `exp2(n * y)` for a base of exactly 2^n, n != 0.
+  if (isScalarBase || (isVectorBase && vectorBase.isSplat())) {
+    APFloat baseValue = isScalarBase
+                            ? scalarBase.getValue()
+                            : vectorBase.getSplatValue<FloatAttr>().getValue();
+    // Skip base 1.0: pow(1, y) == 1 for NaN/inf y, but exp2(0 * y) is NaN.
+    int n = baseValue.getExactLog2(); // INT_MIN unless a positive power of 2
+    if (n != INT_MIN && n != 0) {
+      Type opType = getElementTypeOrSelf(op.getType());
+      Value nValue = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getFloatAttr(opType, n));
+      Value nTimesY =
+          arith::MulFOp::create(rewriter, loc, ValueRange({bcast(nValue), y}));
+      rewriter.replaceOpWithNewOp<math::Exp2Op>(op, nTimesY);
+      return success();
+    }
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index e0e2b9853a2a1..239be5eeeb6ac 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -90,6 +90,66 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_of_two
+func.func @pow_of_two(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %arg0
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %arg1
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 2.0 : f32  // 2^1: exp2 is expected directly on the argument.
+  %v = arith.constant dense <2.0> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_four
+func.func @pow_of_four(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant 2.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 4.0 : f32  // 4.0 == 2^2 -> exp2(2.0 * y)
+  %v = arith.constant dense <4.0> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_half
+func.func @pow_of_half(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-1.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant -1.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.5 : f32  // 0.5 == 2^-1 -> exp2(-1.0 * y)
+  %v = arith.constant dense <0.5> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @pow_of_quarter
+func.func @pow_of_quarter(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<-2.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant -2.000000e+00 : f32
+  // CHECK: %[[MUL_S:.*]] = arith.mulf %arg0, %[[CST_S]]
+  // CHECK: %[[SCALAR:.*]] = math.exp2 %[[MUL_S]]
+  // CHECK: %[[MUL_V:.*]] = arith.mulf %arg1, %[[CST_V]]
+  // CHECK: %[[VECTOR:.*]] = math.exp2 %[[MUL_V]]
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.25 : f32  // 0.25 == 2^-2 -> exp2(-2.0 * y)
+  %v = arith.constant dense <0.25> : vector<4xf32>
+  %0 = math.powf %c, %arg0 : f32
+  %1 = math.powf %v, %arg1 : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @ipowi_zero_exp(
 // CHECK-SAME: %[[ARG0:.+]]: i32
 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>



More information about the Mlir-commits mailing list