[Mlir-commits] [mlir] [mlir][math] Propagate fast math attrs in AlgebraicSimplification (PR #166802)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 6 09:07:23 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
@llvm/pr-subscribers-mlir
Author: Aleksei Nurmukhametov (nurmukhametov)
<details>
<summary>Changes</summary>
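Propagate the fast-math flags of the original `math.powf` op to the replacement ops (`arith.mulf`, `arith.divf`, `math.sqrt`, `math.rsqrt`) that the `PowFStrengthReduction` pattern creates during algebraic simplification in the MLIR math dialect. Previously these flags were dropped by the rewrite. The patch also reuses the cached `loc` value instead of calling `op.getLoc()` repeatedly.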
---
Full diff: https://github.com/llvm/llvm-project/pull/166802.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+11-12)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..677d7505662a0 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -43,6 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value x = op.getLhs();
+ auto fmf = op.getFastmathAttr().getValue();
FloatAttr scalarExponent;
DenseFPElementsAttr vectorExponent;
@@ -66,7 +67,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
if (auto vec = dyn_cast<VectorType>(op.getType()))
- return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
+ return vector::BroadcastOp::create(rewriter, loc, vec, value);
return value;
};
@@ -78,15 +79,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Replace `pow(x, 2.0)` with `x * x`.
if (isExponentValue(2.0)) {
- rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
+ rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, x, fmf);
return success();
}
// Replace `pow(x, 3.0)` with `x * x * x`.
if (isExponentValue(3.0)) {
- Value square =
- arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
- rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
+ Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
+ rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
return success();
}
@@ -95,28 +95,27 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
Value one = arith::ConstantOp::create(
rewriter, loc,
rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
- rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
+ rewriter.replaceOpWithNewOp<arith::DivFOp>(op, bcast(one), x, fmf);
return success();
}
// Replace `pow(x, 0.5)` with `sqrt(x)`.
if (isExponentValue(0.5)) {
- rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
+ rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x, fmf);
return success();
}
// Replace `pow(x, -0.5)` with `rsqrt(x)`.
if (isExponentValue(-0.5)) {
- rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
+ rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x, fmf);
return success();
}
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
if (isExponentValue(0.75)) {
- Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
- Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
- rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
- ValueRange{powHalf, powQuarter});
+ Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
+ Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
+ rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
return success();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/166802
More information about the Mlir-commits mailing list