[Mlir-commits] [mlir] [mlir][math] Propagate fast math attrs in AlgebraicSimplification (PR #166802)

Aleksei Nurmukhametov llvmlistbot at llvm.org
Mon Nov 10 07:01:39 PST 2025


https://github.com/nurmukhametov updated https://github.com/llvm/llvm-project/pull/166802

>From 5378706533f3202a566d373c7c63410140ee9872 Mon Sep 17 00:00:00 2001
From: Aleksei Nurmukhametov <anurmukh at amd.com>
Date: Thu, 6 Nov 2025 16:44:12 +0000
Subject: [PATCH 1/4] [NFC][mlir][math] Reuse the cached `loc` in
 AlgebraicSimplification

---
 .../Dialect/Math/Transforms/AlgebraicSimplification.cpp  | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 77b10cec48d8e..7fb26f487e3af 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -66,7 +66,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&](Value value) -> Value {
     if (auto vec = dyn_cast<VectorType>(op.getType()))
-      return vector::BroadcastOp::create(rewriter, op.getLoc(), vec, value);
+      return vector::BroadcastOp::create(rewriter, loc, vec, value);
     return value;
   };
 
@@ -84,8 +84,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
-    Value square =
-        arith::MulFOp::create(rewriter, op.getLoc(), ValueRange({x, x}));
+    Value square = arith::MulFOp::create(rewriter, loc, ValueRange({x, x}));
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
     return success();
   }
@@ -113,8 +112,8 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = math::SqrtOp::create(rewriter, op.getLoc(), x);
-    Value powQuarter = math::SqrtOp::create(rewriter, op.getLoc(), powHalf);
+    Value powHalf = math::SqrtOp::create(rewriter, loc, x);
+    Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf);
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                                ValueRange{powHalf, powQuarter});
     return success();

>From 4b463cec6ccaecae403b1aa7b7ac7cfb345d8169 Mon Sep 17 00:00:00 2001
From: Aleksei Nurmukhametov <anurmukh at amd.com>
Date: Thu, 6 Nov 2025 16:51:04 +0000
Subject: [PATCH 2/4] [mlir][math] Propagate fast math attrs in
 AlgebraicSimplification

---
 .../Transforms/AlgebraicSimplification.cpp    | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 7fb26f487e3af..677d7505662a0 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -43,6 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                        PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value x = op.getLhs();
+  auto fmf = op.getFastmathAttr().getValue();
 
   FloatAttr scalarExponent;
   DenseFPElementsAttr vectorExponent;
@@ -78,14 +79,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 2.0)` with `x * x`.
   if (isExponentValue(2.0)) {
-    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, x, fmf);
     return success();
   }
 
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
-    Value square = arith::MulFOp::create(rewriter, loc, ValueRange({x, x}));
-    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
+    Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
     return success();
   }
 
@@ -94,28 +95,27 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
     Value one = arith::ConstantOp::create(
         rewriter, loc,
         rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
-    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
+    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, bcast(one), x, fmf);
     return success();
   }
 
   // Replace `pow(x, 0.5)` with `sqrt(x)`.
   if (isExponentValue(0.5)) {
-    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
+    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x, fmf);
     return success();
   }
 
   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
   if (isExponentValue(-0.5)) {
-    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
+    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x, fmf);
     return success();
   }
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = math::SqrtOp::create(rewriter, loc, x);
-    Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf);
-    rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
-                                               ValueRange{powHalf, powQuarter});
+    Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
+    Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
+    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
     return success();
   }
 

>From bfca113f99f264816ad13a6fee3bdc0cde54ed79 Mon Sep 17 00:00:00 2001
From: Aleksei Nurmukhametov <anurmukh at amd.com>
Date: Fri, 7 Nov 2025 15:28:30 +0000
Subject: [PATCH 3/4] [NFC][mlir][math] Spell out the type of the fmf variable

---
 mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 677d7505662a0..ff5f7f685903f 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -43,7 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                        PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value x = op.getLhs();
-  auto fmf = op.getFastmathAttr().getValue();
+  arith::FastMathFlags fmf = op.getFastmathAttr().getValue();
 
   FloatAttr scalarExponent;
   DenseFPElementsAttr vectorExponent;

>From ea4c56b32717c64937462bcfe2d03095fe3b6ef3 Mon Sep 17 00:00:00 2001
From: Aleksei Nurmukhametov <anurmukh at amd.com>
Date: Fri, 7 Nov 2025 17:51:50 +0000
Subject: [PATCH 4/4] [mlir][math] Drop reassoc/contract/arcp on intermediate ops; add fast math lit tests

---
 .../Transforms/AlgebraicSimplification.cpp    | 10 ++-
 .../Math/algebraic-simplification.mlir        | 80 +++++++++++++++++++
 2 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index ff5f7f685903f..0d800ec6d8d02 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -44,6 +44,9 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   Location loc = op.getLoc();
   Value x = op.getLhs();
   arith::FastMathFlags fmf = op.getFastmathAttr().getValue();
+  arith::FastMathFlags intermediateFmf = arith::bitEnumClear(
+      fmf, arith::FastMathFlags::reassoc | arith::FastMathFlags::contract |
+               arith::FastMathFlags::arcp);
 
   FloatAttr scalarExponent;
   DenseFPElementsAttr vectorExponent;
@@ -85,7 +88,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 3.0)` with `x * x * x`.
   if (isExponentValue(3.0)) {
-    Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
+    Value square = arith::MulFOp::create(rewriter, loc, x, x, intermediateFmf);
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
     return success();
   }
@@ -113,8 +116,9 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
 
   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
   if (isExponentValue(0.75)) {
-    Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
-    Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
+    Value powHalf = math::SqrtOp::create(rewriter, loc, x, intermediateFmf);
+    Value powQuarter =
+        math::SqrtOp::create(rewriter, loc, powHalf, intermediateFmf);
     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
     return success();
   }
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index e0e2b9853a2a1..7310469597504 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -22,6 +22,18 @@ func.func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_square_fast
+func.func @pow_square_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %arg0 fastmath<fast>
+  // CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %arg1 fastmath<fast>
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 2.0 : f32
+  %v = arith.constant dense <2.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c fastmath<fast> : f32
+  %1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @pow_cube
 func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0
@@ -36,6 +48,20 @@ func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_cube_fast
+func.func @pow_cube_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0 fastmath<nnan,ninf,nsz,afn>
+  // CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %[[TMP_S]] fastmath<fast>
+  // CHECK: %[[TMP_V:.*]] = arith.mulf %arg1, %arg1 fastmath<nnan,ninf,nsz,afn>
+  // CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %[[TMP_V]] fastmath<fast>
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 3.0 : f32
+  %v = arith.constant dense <3.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c fastmath<fast> : f32
+  %1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @pow_recip
 func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32
@@ -50,6 +76,20 @@ func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_recip_fast
+func.func @pow_recip_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32
+  // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.0{{.*}}> : vector<4xf32>
+  // CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %arg0 fastmath<fast>
+  // CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %arg1 fastmath<fast>
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant -1.0 : f32
+  %v = arith.constant dense <-1.0> : vector<4xf32>
+  %0 = math.powf %arg0, %c fastmath<fast> : f32
+  %1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @pow_sqrt
 func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0
@@ -62,6 +102,18 @@ func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_sqrt_fast
+func.func @pow_sqrt_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.sqrt %arg0 fastmath<fast>
+  // CHECK: %[[VECTOR:.*]] = math.sqrt %arg1 fastmath<fast>
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.5 : f32
+  %v = arith.constant dense <0.5> : vector<4xf32>
+  %0 = math.powf %arg0, %c fastmath<fast> : f32
+  %1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @pow_rsqrt
 func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0
@@ -74,6 +126,18 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_rsqrt_fast
+func.func @pow_rsqrt_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0 fastmath<fast>
+  // CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1 fastmath<fast>
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant -0.5 : f32
+  %v = arith.constant dense <-0.5> : vector<4xf32>
+  %0 = math.powf %arg0, %c fastmath<fast> : f32
+  %1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @pow_0_75
 func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   // CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0
@@ -90,6 +154,22 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
   return %0, %1 : f32, vector<4xf32>
 }
 
+// CHECK-LABEL: @pow_0_75_fast
+func.func @pow_0_75_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
+  // CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0 fastmath<nnan,ninf,nsz,afn>
+  // CHECK: %[[SQRT2S:.*]] = math.sqrt %[[SQRT1S]] fastmath<nnan,ninf,nsz,afn>
+  // CHECK: %[[SCALAR:.*]] = arith.mulf %[[SQRT1S]], %[[SQRT2S]] fastmath<fast>
+  // CHECK: %[[SQRT1V:.*]] = math.sqrt %arg1 fastmath<nnan,ninf,nsz,afn>
+  // CHECK: %[[SQRT2V:.*]] = math.sqrt %[[SQRT1V]] fastmath<nnan,ninf,nsz,afn>
+  // CHECK: %[[VECTOR:.*]] = arith.mulf %[[SQRT1V]], %[[SQRT2V]] fastmath<fast>
+  // CHECK: return %[[SCALAR]], %[[VECTOR]]
+  %c = arith.constant 0.75 : f32
+  %v = arith.constant dense <0.75> : vector<4xf32>
+  %0 = math.powf %arg0, %c fastmath<fast> : f32
+  %1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
+  return %0, %1 : f32, vector<4xf32>
+}
+
 // CHECK-LABEL: @ipowi_zero_exp(
 // CHECK-SAME: %[[ARG0:.+]]: i32
 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>



More information about the Mlir-commits mailing list