[Mlir-commits] [mlir] [mlir][math] Add constant folding for `math.fpowi` (PR #193761)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 23 07:18:51 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-math

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

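Adds a constant folder for `math.fpowi` when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base. Only `f32` and `f64` bases are folded (evaluated with `powf`/`pow`). To let a float base be paired with an integer exponent, the binary helpers in `CommonFolders.h` now take separate LHS/RHS attribute and value template parameters.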

---
Full diff: https://github.com/llvm/llvm-project/pull/193761.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/CommonFolders.h (+52-42) 
- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+2-4) 
- (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+28) 
- (modified) mlir/test/Dialect/Math/canonicalize.mlir (+33) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
 /// Uses `resultType` for the type of the returned attribute.
 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
 /// which will be directly propagated to result.
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
-          class CalculationT = function_ref<
-              std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+          class CalculationT = function_ref<std::optional<ResultElementValueT>(
+              LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        Type resultType,
                                        CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   if (!resultType || !operands[0] || !operands[1])
     return {};
 
-  if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
-    auto lhs = cast<AttrElementT>(operands[0]);
-    auto rhs = cast<AttrElementT>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+  if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+    auto lhs = cast<LAttrElementT>(operands[0]);
+    auto rhs = cast<RAttrElementT>(operands[1]);
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
     auto calRes = calculate(lhs.getValue(), rhs.getValue());
 
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // just fold based on the splat value.
     auto lhs = cast<SplatElementsAttr>(operands[0]);
     auto rhs = cast<SplatElementsAttr>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
-    auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
-                                   rhs.getSplatValue<ElementValueT>());
+    auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+                                   rhs.getSplatValue<RElementValueT>());
     if (!elementResult)
       return {};
 
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     // expanding the values.
     auto lhs = cast<ElementsAttr>(operands[0]);
     auto rhs = cast<ElementsAttr>(operands[1]);
-    if (lhs.getType() != rhs.getType())
-      return {};
+    if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+      if (lhs.getType() != rhs.getType())
+        return {};
 
-    auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
-    auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+    auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+    auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
     if (!maybeLhsIt || !maybeRhsIt)
       return {};
     auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
 /// attribute.
 /// Optional PoisonAttr template argument allows to specify 'poison' attribute
 /// which will be directly propagated to result.
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
-          class CalculationT = function_ref<
-              std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+          class CalculationT = function_ref<std::optional<ResultElementValueT>(
+              LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                        CalculationT &&calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
   Type rhsType = getAttrType(operands[1]);
   if (!lhsType || !rhsType)
     return {};
-  if (lhsType != rhsType)
-    return {};
+  if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+    if (lhsType != rhsType)
+      return {};
 
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT, ResultElementValueT,
-                                      CalculationT>(
+  return constFoldBinaryOpConditional<
+      LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+      ResultAttrElementT, ResultElementValueT, CalculationT>(
       operands, lhsType, std::forward<CalculationT>(calculate));
 }
 
-template <class AttrElementT,
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = void, //
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
           class CalculationT =
-              function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+              function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
                             CalculationT &&calculate) {
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT>(
+  return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+                                      LElementValueT, RElementValueT,
+                                      PoisonAttr, ResultAttrElementT>(
       operands, resultType,
-      [&](ElementValueT a, ElementValueT b)
+      [&](LElementValueT a, RElementValueT b)
           -> std::optional<ResultElementValueT> { return calculate(a, b); });
 }
 
-template <class AttrElementT, //
-          class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+          class LElementValueT = typename LAttrElementT::ValueType,
+          class RElementValueT = typename RAttrElementT::ValueType,
           class PoisonAttr = ub::PoisonAttr,
-          class ResultAttrElementT = AttrElementT,
+          class ResultAttrElementT = LAttrElementT,
           class ResultElementValueT = typename ResultAttrElementT::ValueType,
           class CalculationT =
-              function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+              function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
                             CalculationT &&calculate) {
-  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
-                                      ResultAttrElementT>(
+  return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+                                      LElementValueT, RElementValueT,
+                                      PoisonAttr, ResultAttrElementT>(
       operands,
-      [&](ElementValueT a, ElementValueT b)
+      [&](LElementValueT a, RElementValueT b)
           -> std::optional<ResultElementValueT> { return calculate(a, b); });
 }
 
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
     The operation is elementwise for non-scalars, e.g.:
 
     ```mlir
-    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+    %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
     ```
 
     The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
   let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs) `,` type($rhs) }];
 
-  // TODO: add a constant folder using pow[f] for cases, when
-  //       the power argument is exactly representable in floating
-  //       point type of the base.
+  let hasFolder = 1;
 }
 
 #endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+  // Evaluates base^exp on the host. The integer exponent must convert to the
+  // base's floating-point type without rounding; only 32- and 64-bit bases
+  // (f32 / f64) are folded, via powf / pow respectively.
+  auto evaluate = [](const APFloat &base,
+                     const APInt &exp) -> std::optional<APFloat> {
+    const llvm::fltSemantics &sem = base.getSemantics();
+    APFloat exponent(sem);
+    APFloat::opStatus status = exponent.convertFromAPInt(
+        exp, /*isSigned=*/true, APFloat::rmNearestTiesToEven);
+    if (status != APFloat::opOK)
+      return std::nullopt;
+
+    unsigned width = APFloat::getSizeInBits(sem);
+    if (width == 64)
+      return APFloat(pow(base.convertToDouble(), exponent.convertToDouble()));
+    if (width == 32)
+      return APFloat(powf(base.convertToFloat(), exponent.convertToFloat()));
+    return std::nullopt;
+  };
+  return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+      adaptor.getOperands(), evaluate);
+}
+
 /// Materialize an integer or floating point constant.
 Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                   Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
   %r = math.ipowi %b, %e : i1
   return %r : i1
 }
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[F64:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[F32:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[F64]], %[[F32]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+  %base_f64 = arith.constant 2.000000e+00 : f64
+  %base_f32 = arith.constant 2.000000e+00 : f32
+  %two = arith.constant 2 : i32
+  %sq_f64 = math.fpowi %base_f64, %two : f64, i32
+  %sq_f32 = math.fpowi %base_f32, %two : f32, i32
+  return %sq_f64, %sq_f32 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 6.250000e-02]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+  %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %cst_0 = arith.constant dense<[2, 4, 2, -2]> : vector<4xi32> // -2 checks the exponent is read as signed.
+  %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+  return %0 : vector<4xf32>
+}
+
+// 16777217 (2^24 + 1) needs 25 significand bits, so it is not exact in f32.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK:       math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+  %base = arith.constant 2.000000e+00 : f32
+  %exp = arith.constant 16777217 : i32
+  %r = math.fpowi %base, %exp : f32, i32
+  return %r : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
 
     Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
     TypedAttr res = cast_or_null<TypedAttr>(
-        constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+        constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+                          FloatAttr::ValueType, void, IntegerAttr>(
             operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
               return APInt(1, lhs < rhs);
             }));

``````````

</details>


https://github.com/llvm/llvm-project/pull/193761


More information about the Mlir-commits mailing list