[Mlir-commits] [mlir] [mlir][math] Add constant folding for math.rsqrt (PR #184443)
llvmlistbot at llvm.org
Tue Mar 3 14:05:05 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-math
Author: Ian Wood (IanWood1)
<details>
<summary>Changes</summary>
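Add a fold() method to RsqrtOp, matching the pattern used by SqrtOp and other math unary ops. For f32 and f64 operands, the folder computes `1.0 / sqrt(x)`: the square root comes from the host `sqrt`/`sqrtf`, and the reciprocal is an APFloat division. Negative operands (as reported by `APFloat::isNegative()`, which includes `-0.0`) and floating-point types of any other width are left unfolded.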
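As a rough illustration of what the folder computes, here is a plain C++17 sketch that uses host `float`/`double` arithmetic in place of `APFloat`. The `foldRsqrt` name is hypothetical and not part of MLIR; only the sign-bit guard and the restriction to 32- and 64-bit types are taken from the diff below.

```cpp
#include <cassert>
#include <cmath>
#include <optional>
#include <type_traits>

// Hypothetical helper: models the folder's arithmetic with host types instead
// of APFloat. Only float (f32) and double (f64) are accepted, mirroring the
// 32- and 64-bit cases of the switch in the real folder.
template <typename T>
std::optional<T> foldRsqrt(T x) {
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                "only f32 and f64 are folded");
  // APFloat::isNegative() tests the sign bit, so -0.0 (and negative NaNs)
  // are rejected too; std::signbit has the same semantics.
  if (std::signbit(x))
    return std::nullopt;
  // Two roundings happen here: one for sqrt, one for the division.
  return T(1) / std::sqrt(x);
}
```

With this definition, `foldRsqrt(4.0f)` yields `0.5f`, the value the `@rsqrt_fold` test expects, while `foldRsqrt(-1.0)` yields no value. Because the result passes through two roundings (square root, then division), it is not guaranteed to be the correctly rounded reciprocal square root for every input; the same applies to the folder in the diff.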
---
Full diff: https://github.com/llvm/llvm-project/pull/184443.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+1)
- (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+22)
- (modified) mlir/test/Dialect/Math/canonicalize.mlir (+18)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index df5787dc48403..1265bfb18aaa2 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -995,6 +995,7 @@ def Math_RsqrtOp : Math_FloatUnaryOp<"rsqrt"> {
%a = math.rsqrt %b : f64
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 7e9d4acae6822..4c0274ddb18a1 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -517,6 +517,28 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// RsqrtOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+        if (a.isNegative())
+          return {};
+
+        APFloat one(a.getSemantics(), 1);
+        switch (a.getSizeInBits(a.getSemantics())) {
+        case 64:
+          return one / APFloat(sqrt(a.convertToDouble()));
+        case 32:
+          return one / APFloat(sqrtf(a.convertToFloat()));
+        default:
+          return {};
+        }
+      });
+}
+
//===----------------------------------------------------------------------===//
// SqrtOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index c900a24e821b4..36a8cd99be661 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -102,6 +102,24 @@ func.func @powf_fold_vec() -> (vector<4xf32>) {
return %0 : vector<4xf32>
}
+// CHECK-LABEL: @rsqrt_fold
+// CHECK: %[[cst:.+]] = arith.constant 5.000000e-01 : f32
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold() -> f32 {
+  %c = arith.constant 4.0 : f32
+  %r = math.rsqrt %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @rsqrt_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : vector<2xf32>
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold_vec() -> (vector<2xf32>) {
+  %v1 = arith.constant dense<[1.0, 4.0]> : vector<2xf32>
+  %0 = math.rsqrt %v1 : vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: @sqrt_fold
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
// CHECK: return %[[cst]]
``````````
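The `@rsqrt_fold_vec` test relies on `constFoldUnaryOpConditional` applying the callback to each element of a dense constant. Assuming the helper in MLIR's `CommonFolders.h` abandons the whole fold as soon as one element returns `std::nullopt`, that behavior can be modeled with standard containers as follows. `foldElementwise` and `rsqrtLane` are hypothetical names, not MLIR APIs, and the sketch is not a copy of the real helper.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical all-or-nothing elementwise fold: the callback runs on every
// lane, and if any lane declines (std::nullopt) the whole constant stays
// unfolded. This models an assumed behavior of constFoldUnaryOpConditional.
std::optional<std::vector<float>>
foldElementwise(const std::vector<float> &in,
                const std::function<std::optional<float>(float)> &fn) {
  std::vector<float> out;
  out.reserve(in.size());
  for (float x : in) {
    std::optional<float> r = fn(x);
    if (!r)
      return std::nullopt; // one rejected lane blocks the entire fold
    out.push_back(*r);
  }
  return out;
}

// Hypothetical per-lane callback with the same guard as the scalar folder:
// anything with the sign bit set is refused.
std::optional<float> rsqrtLane(float x) {
  if (std::signbit(x))
    return std::nullopt;
  return 1.0f / std::sqrt(x);
}
```

Under that assumption, `foldElementwise({1.0f, 4.0f}, rsqrtLane)` produces `{1.0f, 0.5f}`, matching the expected `dense<[1.000000e+00, 5.000000e-01]>`, whereas an input such as `{4.0f, -4.0f}` stays unfolded instead of folding partially.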
</details>
https://github.com/llvm/llvm-project/pull/184443