[Mlir-commits] [mlir] [mlir][math] Add constant folding for math.rsqrt (PR #184443)
Ian Wood
llvmlistbot at llvm.org
Tue Mar 3 14:04:33 PST 2026
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/184443
Add a fold() method to RsqrtOp, matching the pattern used by SqrtOp and other math unary ops. The fold computes `1.0 / sqrt(x)` using APFloat division.
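For readers without an LLVM checkout, the folding rule can be approximated in standalone C++. This is only a sketch under stated assumptions: `foldRsqrtSketch` is a hypothetical name, plain `float` stands in for `APFloat`, and `std::signbit` stands in for `APFloat::isNegative()`. It is not the code from the patch.

```cpp
#include <cassert>
#include <cmath>
#include <optional>

// Hypothetical stand-in for the folder's lambda; not part of the patch.
// Returns std::nullopt when the fold is declined, so the op would be left as-is.
std::optional<float> foldRsqrtSketch(float x) {
  // Decline to fold when the sign bit is set. Note that this also
  // declines -0.0, because the check looks only at the sign bit.
  if (std::signbit(x))
    return std::nullopt;
  // Same shape as the patch: 1.0 / sqrt(x).
  return 1.0f / std::sqrt(x);
}
```

With this sketch, an input of 4.0 yields 0.5, matching the `@rsqrt_fold` test in the patch, and a negative input yields no folded value.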
From 62eda90c7a69e3a6be7b8f09358038607cff1a05 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood at u.northwestern.edu>
Date: Tue, 3 Mar 2026 14:00:26 -0800
Subject: [PATCH] [mlir][math] Add constant folding for math.rsqrt
Add a fold() method to RsqrtOp, matching the pattern used by SqrtOp
and other math unary ops. This allows the canonicalizer to evaluate
math.rsqrt on constant operands at compile time.
The fold computes 1.0 / sqrt(x) using APFloat division for
IEEE 754-compliant results, rejecting negative inputs.
---
mlir/include/mlir/Dialect/Math/IR/MathOps.td | 1 +
mlir/lib/Dialect/Math/IR/MathOps.cpp | 22 ++++++++++++++++++++
mlir/test/Dialect/Math/canonicalize.mlir | 18 ++++++++++++++++
3 files changed, 41 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index df5787dc48403..1265bfb18aaa2 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -995,6 +995,7 @@ def Math_RsqrtOp : Math_FloatUnaryOp<"rsqrt"> {
%a = math.rsqrt %b : f64
```
}];
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 7e9d4acae6822..4c0274ddb18a1 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -517,6 +517,28 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// RsqrtOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
+ return constFoldUnaryOpConditional<FloatAttr>(
+ adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+ if (a.isNegative())
+ return {};
+
+ APFloat one(a.getSemantics(), 1);
+ switch (a.getSizeInBits(a.getSemantics())) {
+ case 64:
+ return one / APFloat(sqrt(a.convertToDouble()));
+ case 32:
+ return one / APFloat(sqrtf(a.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
//===----------------------------------------------------------------------===//
// SqrtOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index c900a24e821b4..36a8cd99be661 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -102,6 +102,24 @@ func.func @powf_fold_vec() -> (vector<4xf32>) {
return %0 : vector<4xf32>
}
+// CHECK-LABEL: @rsqrt_fold
+// CHECK: %[[cst:.+]] = arith.constant 5.000000e-01 : f32
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold() -> f32 {
+ %c = arith.constant 4.0 : f32
+ %r = math.rsqrt %c : f32
+ return %r : f32
+}
+
+// CHECK-LABEL: @rsqrt_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : vector<2xf32>
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold_vec() -> (vector<2xf32>) {
+ %v1 = arith.constant dense<[1.0, 4.0]> : vector<2xf32>
+ %0 = math.rsqrt %v1 : vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: @sqrt_fold
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
// CHECK: return %[[cst]]