[Mlir-commits] [mlir] [mlir][math] Propagate fast math attrs in AlgebraicSimplification (PR #166802)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 6 09:07:23 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Aleksei Nurmukhametov (nurmukhametov)

<details>
<summary>Changes</summary>

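Fix missing propagation of fast-math flags in the algebraic simplification patterns of the MLIR math dialect: when `PowFStrengthReduction` rewrites a `math.powf` op, the replacement ops it creates (`arith.mulf`, `arith.divf`, `math.sqrt`, `math.rsqrt`) now carry the original op's fast-math flags.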

---
Full diff: https://github.com/llvm/llvm-project/pull/166802.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+11-12) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..677d7505662a0 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -43,6 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                        PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value x = op.getLhs();
+  auto fmf = op.getFastmathAttr().getValue();
 
   FloatAttr scalarExponent;
   DenseFPElementsAttr vectorExponent;
@@ -66,7 +67,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
+      return vector::BroadcastOp::create(rewriter, loc, vec, value);
     return value;
   };
 
@@ -78,15 +79,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 2.0)` with `x * x`.
   if (isExponentValue(2.0)) {
-    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, x, fmf);
     return success();
   }
 
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
-    Value square =
-        arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
-    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
+    Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
     return success();
   }
 
@@ -95,28 +95,27 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     Value one = arith::ConstantOp::create(
         rewriter, loc,
         rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
-    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
+    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, bcast(one), x, fmf);
     return success();
   }
 
   // Replace `pow(x, 0.5)` with `sqrt(x)`.
   if (isExponentValue(0.5)) {
-    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
+    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x, fmf);
     return success();
   }
 
   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
   if (isExponentValue(-0.5)) {
-    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
+    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x, fmf);
     return success();
   }
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
-    Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
-    rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
-                                               ValueRange{powHalf, powQuarter});
+    Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
+    Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
     return success();
   }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/166802


More information about the Mlir-commits mailing list