[Mlir-commits] [mlir] [mlir][math] Add constant folding for `math.fpowi` (PR #193761)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 23 07:18:51 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
Author: Longsheng Mou (CoTinker)
<details>
<summary>Changes</summary>
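Adds a constant folder for `math.fpowi` when both operands are constant and the integer exponent is exactly representable in the floating-point type of the base. Folding is currently limited to `f32` and `f64` bases. To support operands of different attribute kinds (a float base with an integer exponent), the binary folding helpers in `CommonFolders.h` now take separate attribute and value template parameters for the left- and right-hand operands.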
---
Full diff: https://github.com/llvm/llvm-project/pull/193761.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/CommonFolders.h (+52-42)
- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+2-4)
- (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+28)
- (modified) mlir/test/Dialect/Math/canonicalize.mlir (+33)
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 113765157946d..736b16ed25d44 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,13 +36,14 @@ class PoisonAttr;
/// Uses `resultType` for the type of the returned attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type resultType,
CalculationT &&calculate) {
@@ -62,11 +63,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!resultType || !operands[0] || !operands[1])
return {};
- if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
- auto lhs = cast<AttrElementT>(operands[0]);
- auto rhs = cast<AttrElementT>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if (isa<LAttrElementT>(operands[0]) && isa<RAttrElementT>(operands[1])) {
+ auto lhs = cast<LAttrElementT>(operands[0]);
+ auto rhs = cast<RAttrElementT>(operands[1]);
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
auto calRes = calculate(lhs.getValue(), rhs.getValue());
@@ -82,11 +84,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// just fold based on the splat value.
auto lhs = cast<SplatElementsAttr>(operands[0]);
auto rhs = cast<SplatElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
- rhs.getSplatValue<ElementValueT>());
+ auto elementResult = calculate(lhs.getSplatValue<LElementValueT>(),
+ rhs.getSplatValue<RElementValueT>());
if (!elementResult)
return {};
@@ -98,11 +101,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
// expanding the values.
auto lhs = cast<ElementsAttr>(operands[0]);
auto rhs = cast<ElementsAttr>(operands[1]);
- if (lhs.getType() != rhs.getType())
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhs.getType() != rhs.getType())
+ return {};
- auto maybeLhsIt = lhs.try_value_begin<ElementValueT>();
- auto maybeRhsIt = rhs.try_value_begin<ElementValueT>();
+ auto maybeLhsIt = lhs.try_value_begin<LElementValueT>();
+ auto maybeRhsIt = rhs.try_value_begin<RElementValueT>();
if (!maybeLhsIt || !maybeRhsIt)
return {};
auto lhsIt = *maybeLhsIt;
@@ -127,13 +131,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
/// attribute.
/// Optional PoisonAttr template argument allows to specify 'poison' attribute
/// which will be directly propagated to result.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
- class CalculationT = function_ref<
- std::optional<ResultElementValueT>(ElementValueT, ElementValueT)>>
+ class CalculationT = function_ref<std::optional<ResultElementValueT>(
+ LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
@@ -159,44 +164,49 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
Type rhsType = getAttrType(operands[1]);
if (!lhsType || !rhsType)
return {};
- if (lhsType != rhsType)
- return {};
+ if constexpr (std::is_same_v<LElementValueT, RElementValueT>)
+ if (lhsType != rhsType)
+ return {};
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT, ResultElementValueT,
- CalculationT>(
+ return constFoldBinaryOpConditional<
+ LAttrElementT, RAttrElementT, LElementValueT, RElementValueT, PoisonAttr,
+ ResultAttrElementT, ResultElementValueT, CalculationT>(
operands, lhsType, std::forward<CalculationT>(calculate));
}
+/// Folds a binary op on constant operands with an infallible `calculate`
+/// callback: its plain ResultElementValueT return value is wrapped into an
+/// engaged std::optional and forwarded to constFoldBinaryOpConditional, which
+/// uses `resultType` as the type of the returned attribute. The lhs and rhs
+/// may use different attribute kinds (LAttrElementT / RAttrElementT); the
+/// result attribute kind defaults to LAttrElementT.
+/// NOTE(review): PoisonAttr defaults to `void` in this overload but to
+/// ub::PoisonAttr in the overload without `resultType`; confirm intended.
-template <class AttrElementT,
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = void, //
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands, resultType,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
+/// Folds a binary op on constant operands with an infallible `calculate`
+/// callback (its ResultElementValueT result is wrapped into an engaged
+/// std::optional). Forwards to the constFoldBinaryOpConditional overload that
+/// takes no result type, which uses the lhs operand's type for the result.
+/// When LElementValueT and RElementValueT differ (e.g. FloatAttr lhs with
+/// IntegerAttr rhs), the lhs and rhs types are not required to be equal.
-template <class AttrElementT, //
- class ElementValueT = typename AttrElementT::ValueType,
+template <class LAttrElementT, class RAttrElementT = LAttrElementT,
+ class LElementValueT = typename LAttrElementT::ValueType,
+ class RElementValueT = typename RAttrElementT::ValueType,
class PoisonAttr = ub::PoisonAttr,
- class ResultAttrElementT = AttrElementT,
+ class ResultAttrElementT = LAttrElementT,
class ResultElementValueT = typename ResultAttrElementT::ValueType,
class CalculationT =
- function_ref<ResultElementValueT(ElementValueT, ElementValueT)>>
+ function_ref<ResultElementValueT(LElementValueT, RElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
- ResultAttrElementT>(
+ return constFoldBinaryOpConditional<LAttrElementT, RAttrElementT,
+ LElementValueT, RElementValueT,
+ PoisonAttr, ResultAttrElementT>(
operands,
- [&](ElementValueT a, ElementValueT b)
+ [&](LElementValueT a, RElementValueT b)
-> std::optional<ResultElementValueT> { return calculate(a, b); });
}
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..90f3f121a16d9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -1148,7 +1148,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
The operation is elementwise for non-scalars, e.g.:
```mlir
- %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32
+ %v = math.fpowi %base, %power : vector<2xf32>, vector<2xi32>
```
The result is a vector of:
@@ -1172,9 +1172,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs) `,` type($rhs) }];
- // TODO: add a constant folder using pow[f] for cases, when
- // the power argument is exactly representable in floating
- // point type of the base.
+ let hasFolder = 1;
}
#endif // MATH_OPS
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 4c0274ddb18a1..bb552bd253b5f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -776,6 +776,34 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// FPowIOp folder
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `math.fpowi` when both the floating-point base and the
+/// integer exponent are constants (scalars, splats, or dense elements).
+/// Does not fold when the exponent is not exactly representable in the base's
+/// semantics, or when the base is neither 32 nor 64 bits wide.
+/// NOTE(review): results come from the host's pow/powf; confirm they agree
+/// with how fpowi is evaluated after lowering.
+OpFoldResult math::FPowIOp::fold(FoldAdaptor adaptor) {
+ return constFoldBinaryOpConditional<FloatAttr, IntegerAttr>(
+ adaptor.getOperands(),
+ [](const APFloat &base, const APInt &exp) -> std::optional<APFloat> {
+ const llvm::fltSemantics &sem = base.getSemantics();
+ // The exponent is treated as signed. Fold only if converting it to the
+ // base's semantics is exact (opOK: no rounding and no overflow).
+ APFloat fExp(sem);
+ if (fExp.convertFromAPInt(exp, /*isSigned=*/true,
+ APFloat::rmNearestTiesToEven) !=
+ APFloat::opOK)
+ return {};
+
+ switch (APFloat::getSizeInBits(sem)) {
+ case 64: // Evaluated with the host's double-precision pow().
+ return APFloat(pow(base.convertToDouble(), fExp.convertToDouble()));
+ case 32: // Evaluated with the host's single-precision powf().
+ return APFloat(powf(base.convertToFloat(), fExp.convertToFloat()));
+ default: // Other widths (e.g. 16-, 80- or 128-bit) are not folded.
+ return {};
+ }
+ });
+}
+
/// Materialize an integer or floating point constant.
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 67235c38e9cdf..228faa31781c4 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -614,3 +614,36 @@ func.func @ipowi_i1_const_neg_exp() -> i1 {
%r = math.ipowi %b, %e : i1
return %r : i1
}
+
+// CHECK-LABEL: @fpowi_fold
+// CHECK: %[[cst:.+]] = arith.constant 4.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant 4.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]] : f64, f32
+func.func @fpowi_fold() -> (f64, f32) {
+ %cst = arith.constant 2.000000e+00 : f64
+ %cst_0 = arith.constant 2.000000e+00 : f32
+ %c2_i32 = arith.constant 2 : i32
+ %0 = math.fpowi %cst, %c2_i32 : f64, i32
+ %1 = math.fpowi %cst_0, %c2_i32 : f32, i32
+ return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @fpowi_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 1.600000e+01, 9.000000e+00, 1.600000e+01]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @fpowi_fold_vec() -> vector<4xf32> {
+ %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %cst_0 = arith.constant dense<[2, 4, 2, 2]> : vector<4xi32>
+ %0 = math.fpowi %cst, %cst_0 : vector<4xf32>, vector<4xi32>
+ return %0 : vector<4xf32>
+}
+
+// 16777217 (2^24 + 1) is not exactly representable in f32, so no folding.
+// CHECK-LABEL: @fpowi_fold_failed
+// CHECK: math.fpowi
+func.func @fpowi_fold_failed() -> f32 {
+ %cst = arith.constant 2.000000e+00 : f32
+ %c16777217_i32 = arith.constant 16777217 : i32
+ %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
+ return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..55e72b57cfd1b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -265,7 +265,8 @@ struct FoldLessThanOpF32ToI1 : public OpRewritePattern<test::LessThanOp> {
Attribute operandAttrs[2] = {lhsAttr, rhsAttr};
TypedAttr res = cast_or_null<TypedAttr>(
- constFoldBinaryOp<FloatAttr, FloatAttr::ValueType, void, IntegerAttr>(
+ constFoldBinaryOp<FloatAttr, FloatAttr, FloatAttr::ValueType,
+ FloatAttr::ValueType, void, IntegerAttr>(
operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt {
return APInt(1, lhs < rhs);
}));
``````````
</details>
https://github.com/llvm/llvm-project/pull/193761
More information about the Mlir-commits
mailing list