[Mlir-commits] [mlir] [mlir][math] Add constant folding for math.rsqrt (PR #184443)

llvmlistbot at llvm.org
Tue Mar 3 14:05:05 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-math

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

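Add a `fold()` method to `RsqrtOp`, matching the pattern used by `SqrtOp` and other math unary ops. The folder takes the host `sqrt`/`sqrtf` of the constant operand and then computes `1.0 / sqrt(x)` with APFloat division. It folds only f32 and f64 values and leaves negative inputs (including `-0.0`) unfolded.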
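To make the arithmetic concrete, here is a minimal host-side sketch of what the fold callback computes, using plain `double`/`float` instead of `llvm::APFloat`. It is an illustration, not the PR's code: the helper names `foldRsqrtF64` and `foldRsqrtF32` are hypothetical, and `std::signbit` stands in for `APFloat::isNegative()` (both report the sign bit, so `-0.0` is rejected as well).

```cpp
#include <cmath>
#include <optional>

// Hypothetical mirror of the RsqrtOp fold callback (illustration only; the
// real folder works on llvm::APFloat, not on host double/float).

// f64 path: refuse to fold when the sign bit is set (this includes -0.0),
// otherwise compute 1.0 / sqrt(x). The sqrt and the division are rounded
// separately, as in the PR's code.
std::optional<double> foldRsqrtF64(double x) {
  if (std::signbit(x))
    return std::nullopt;
  return 1.0 / std::sqrt(x);
}

// f32 path: same guard, using the single-precision std::sqrt overload and a
// single-precision division.
std::optional<float> foldRsqrtF32(float x) {
  if (std::signbit(x))
    return std::nullopt;
  return 1.0f / std::sqrt(x);
}
```

Because the square root and the reciprocal are rounded in two separate steps, the result is not guaranteed to be the correctly rounded value of 1/√x for every input; for exact cases such as 1.0 and 4.0 used in the tests, the two agree.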

---
Full diff: https://github.com/llvm/llvm-project/pull/184443.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+1) 
- (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+22) 
- (modified) mlir/test/Dialect/Math/canonicalize.mlir (+18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index df5787dc48403..1265bfb18aaa2 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -995,6 +995,7 @@ def Math_RsqrtOp : Math_FloatUnaryOp<"rsqrt"> {
     %a = math.rsqrt %b : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 7e9d4acae6822..4c0274ddb18a1 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -517,6 +517,28 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// RsqrtOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+        if (a.isNegative())
+          return {};
+
+        APFloat one(a.getSemantics(), 1);
+        switch (a.getSizeInBits(a.getSemantics())) {
+        case 64:
+          return one / APFloat(sqrt(a.convertToDouble()));
+        case 32:
+          return one / APFloat(sqrtf(a.convertToFloat()));
+        default:
+          return {};
+        }
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // SqrtOp folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index c900a24e821b4..36a8cd99be661 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -102,6 +102,24 @@ func.func @powf_fold_vec() -> (vector<4xf32>) {
   return %0 : vector<4xf32>
 }
 
+// CHECK-LABEL: @rsqrt_fold
+// CHECK: %[[cst:.+]] = arith.constant 5.000000e-01 : f32
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold() -> f32 {
+  %c = arith.constant 4.0 : f32
+  %r = math.rsqrt %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @rsqrt_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : vector<2xf32>
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold_vec() -> (vector<2xf32>) {
+  %v1 = arith.constant dense<[1.0, 4.0]> : vector<2xf32>
+  %0 = math.rsqrt %v1 : vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @sqrt_fold
 // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
 // CHECK: return %[[cst]]

``````````
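The `@rsqrt_fold_vec` test depends on the callback being applied to each element of the `dense<...>` constant. The sketch below models that behavior under one stated assumption: as in `constFoldUnaryOpConditional`, a single element that cannot be folded abandons the fold for the whole vector. `foldRsqrtVecF32` is a hypothetical name used only for illustration.

```cpp
#include <cmath>
#include <optional>
#include <vector>

// Hypothetical model of folding math.rsqrt over a vector<Nxf32> constant.
// Assumption: one lane that cannot be folded means the whole op stays
// unfolded, mirroring how constFoldUnaryOpConditional handles elements.
std::optional<std::vector<float>> foldRsqrtVecF32(const std::vector<float> &in) {
  std::vector<float> out;
  out.reserve(in.size());
  for (float x : in) {
    if (std::signbit(x))
      return std::nullopt; // one negative lane -> no fold at all
    out.push_back(1.0f / std::sqrt(x));
  }
  return out;
}
```

With the test's input `[1.0, 4.0]` this yields `[1.0, 0.5]`, matching the expected `dense<[1.000000e+00, 5.000000e-01]>` constant.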

</details>


https://github.com/llvm/llvm-project/pull/184443

