[Mlir-commits] [mlir] [mlir][math] Add constant folding for `math.fpowi` (PR #193761)
Longsheng Mou
llvmlistbot at llvm.org
Fri Apr 24 00:10:39 PDT 2026
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/193761
>From cd149c2d640f6fa21aea9fd33f19c92a4056a6f4 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 23 Apr 2026 15:53:00 +0800
Subject: [PATCH 1/3] [mlir][math] Fold math.fpowi with constant operands
Adds a constant folder for math.fpowi when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base.
---
mlir/include/mlir/Dialect/CommonFolders.h | 94 +++++++++++---------
mlir/include/mlir/Dialect/Math/IR/MathOps.td | 6 +-
mlir/lib/Dialect/Math/IR/MathOps.cpp | 28 ++++++
mlir/test/Dialect/Math/canonicalize.mlir | 33 +++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 3 +-
5 files changed, 117 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!resultType || !operands[0] || !operands[1])
return {};
- if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
- auto lhs = cast<AttrElementT>(operands[0]);
- auto rhs = cast<AttrElementT>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+ auto lhs = cast<LAttrElementT>(operands[0]);
+ auto rhs = cast<RAttrElementT>(operands[1]);
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
auto calRes = calculate(lhs.getValue(), rhs.getValue());
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// just fold based on the splat value.
auto lhs = cast<SplatElementsAttr>(operands[0]);
auto rhs = cast<SplatElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
- rhs.getSplatValue<ElementValueT>());
+ auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+ rhs.getSplatValue<RElementValueT>());
if (!elementResult)
return {};
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// expanding the values.
auto lhs = cast<ElementsAttr>(operands[0]);
auto rhs = cast<ElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
- auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+ auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+ auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
if (!maybeLhsIt || !maybeRhsIt)
return {};
auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type rhsType = getAttrType(operands[1]);
if (!lhsType || !rhsType)
return {};
- if (lhsType != rhsType)
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhsType != rhsType)
+ return {};
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT, ResultElementValueT,
- CalculationT>(
+ return constFoldBinaryOpConditional<
+ LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+ ResultAttrElementT, ResultElementValueT, CalculationT>(
operands, lhsType, std::forward<CalculationT>(calculate));
}
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = void, //
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands, resultType,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
The operation is elementwise for non-scalars, e.g.:
```mlir
- %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+ %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
```
The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs) `,` type($rhs) }];
- // TODO: add a constant folder using pow[f] for cases, when
- // the power argument is exactly representable in floating
- // point type of the base.
+ let hasFolder = 1;
}
#endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+ return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
+ const llvm::fltSemantics &sem = base.getSemantics();
+ // Fold when the exponent is exactly representable in the
+ // floating-point type of the base.
+ APFloat fExp(sem);
+ if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
+ APFloat::rmNearestTiesToEven) !=
+ APFloat::opOK)
+ return {};
+
+ switch (APFloat::getSizeInBits(sem)) {
+ case 64:
+ return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
+ case 32:
+ return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
/// Materialize an integer or floating point constant.
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
%r = math.ipowi %b, %e : i1
return %r : i1
}
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+ %cst = arith.constant 2.000000e+00 : f64
+ %cst_0 = arith.constant 2.000000e+00 : f32
+ %c2_i32 = arith.constant 2 : i32
+ %0 = math.fpowi %cst, %c2_i32 : f64, i32
+ %1 = math.fpowi %cst_0, %c2_i32 : f32, i32
+ return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+ %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %cst_0 = arith.constant dense<[2, 4, 2, 2]> : vector<4xi32>
+ %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+ return %0 : vector<4xf32>
+}
+
+// 16777217 is not exactly representable in f32.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK: math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+ %cst = arith.constant 2.000000e+00 : f32
+ %c16777217_i32 = arith.constant 16777217 : i32
+ %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
+ return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
TypedAttr res = cast_or_null<TypedAttr>(
- constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+ constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+ FloatAttr::ValueType, void, IntegerAttr>(
operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
return APInt(1, lhs < rhs);
}));
>From c283dc4853e3740fc907f50e204b485c22015bcc Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 23 Apr 2026 23:38:54 +0800
Subject: [PATCH 2/3] use semantics rather than bitwidth
---
mlir/lib/Dialect/Math/IR/MathOps.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index bb552bd253b5f..5aa3a5f0f9f66 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -793,10 +793,10 @@ OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
APFloat::opOK)
return {};
- switch (APFloat::getSizeInBits(sem)) {
- case 64:
+ switch (APFloat::SemanticsToEnum(sem)) {
+ case APFloat::S_IEEEdouble:
return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
- case 32:
+ case APFloat::S_IEEEsingle:
return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
default:
return {};
>From 6557f5b15dace76e116fd94d924ec55875549e82 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 24 Apr 2026 15:10:29 +0800
Subject: [PATCH 3/3] use semantics enum
---
mlir/lib/Dialect/Math/IR/MathOps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 5aa3a5f0f9f66..9ba11c4a71f99 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -794,9 +794,9 @@ OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
return {};
switch (APFloat::SemanticsToEnum(sem)) {
- case APFloat::S_IEEEdouble:
+ case APFloat::Semantics::S_IEEEdouble:
return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
- case APFloat::S_IEEEsingle:
+ case APFloat::Semantics::S_IEEEsingle:
return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
default:
return {};
More information about the Mlir-commits
mailing list