[Mlir-commits] [mlir] [MLIR][NVVM] Add rsqrt Op (PR #195854)

Varad Rahul Kamthe llvmlistbot at llvm.org
Tue May 5 06:28:37 PDT 2026


https://github.com/varadk27 created https://github.com/llvm/llvm-project/pull/195854

Adds the `nvvm.rsqrt` op, which computes a fast approximate reciprocal square
root. The op supports f32 and f64 and takes an optional `ftz` attribute.

For more information, see the PTX ISA documentation for rsqrt:
https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt


>From 1a2706921c626d4f8c0737a1cac6a3cc985f4ec9 Mon Sep 17 00:00:00 2001
From: Varad Rahul Kamthe <vkamthe at nvidia.com>
Date: Mon, 4 May 2026 14:40:46 +0000
Subject: [PATCH] Add rsqrt Op

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 24 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 19 +++++++++++++++
 mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir | 21 ++++++++++++++++
 3 files changed, 64 insertions(+)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 51396947fad4e..e9cf1eea14760 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -578,6 +578,30 @@ def NVVM_Ex2Op : NVVM_F32UnaryApproxOp<"ex2"> {
   }];
 }
 
+def NVVM_RsqrtApproxOp
+    : NVVM_SingleResultIntrinsicOp<"rsqrt", [Pure, SameOperandsAndResultType]> {
+  let summary = "Reciprocal square root (fast approximation)";
+  let description = [{
+    Computes an approximation of the reciprocal of the square root of the
+    input value: `d = 1 / sqrt(a)`. Supports both f32 and f64. The maximum
+    relative error for the f32 form over the entire positive finite range
+    is 2^-22.9.
+
+    The `ftz` attribute, when set, flushes subnormal inputs and results to
+    sign-preserving zero. For f64 inputs, `ftz=true` selects a coarser
+    approximation that uses only the upper 32 bits of the input (the lower
+    32 bits of the result are zeroed).
+
+    For more information, see PTX ISA:
+    [rsqrt](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt)
+  }];
+
+  let arguments = (ins AnyTypeOf<[F32, F64]>:$src,
+      DefaultValuedAttr<BoolAttr, "false">:$ftz);
+  let results = (outs AnyTypeOf<[F32, F64]>:$res);
+  let assemblyFormat = "$src attr-dict `:` type($src)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM redux op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7770f8f2dba0a..0a49632569f0c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3515,6 +3515,25 @@ Ex2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
   return {id, {mt.lookupValue(thisOp.getSrc())}};
 }
 
+mlir::NVVM::IDArgPair
+RsqrtApproxOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase &builder) {
+  auto rsqrtOp = cast<NVVM::RsqrtApproxOp>(op);
+  const bool isF32 = rsqrtOp.getRes().getType().isF32();
+  const bool useFtz = rsqrtOp.getFtz();
+
+  // The op definition restricts operand/result types to f32 and f64, so a
+  // non-f32 result is f64. `ftz` selects the flush-to-zero (`.ftz`) variant
+  // of the approximate rsqrt intrinsic for that type.
+  const llvm::Intrinsic::ID id =
+      isF32 ? (useFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
+                      : llvm::Intrinsic::nvvm_rsqrt_approx_f)
+            : (useFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
+                      : llvm::Intrinsic::nvvm_rsqrt_approx_d);
+
+  return {id, {mt.lookupValue(rsqrtOp.getSrc())}};
+}
+
 mlir::NVVM::IDArgPair
 PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {
diff --git a/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
new file mode 100644
index 0000000000000..a055311567216
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Translation of f32 `nvvm.rsqrt`: first without `ftz`, then with `ftz`.
+llvm.func @rsqrt_f32(%a : f32) -> f32 {
+  // CHECK-LABEL: define float @rsqrt_f32(float %0) {
+  // CHECK: call float @llvm.nvvm.rsqrt.approx.f(float %{{.*}})
+  %approx = nvvm.rsqrt %a : f32
+  // CHECK: call float @llvm.nvvm.rsqrt.approx.ftz.f(float %{{.*}})
+  %approx_ftz = nvvm.rsqrt %approx {ftz = true} : f32
+  llvm.return %approx_ftz : f32
+}
+
+// Translation of f64 `nvvm.rsqrt`: first without `ftz`, then with `ftz`.
+llvm.func @rsqrt_f64(%a : f64) -> f64 {
+  // CHECK-LABEL: define double @rsqrt_f64(double %0) {
+  // CHECK: call double @llvm.nvvm.rsqrt.approx.d(double %{{.*}})
+  %approx = nvvm.rsqrt %a : f64
+  // CHECK: call double @llvm.nvvm.rsqrt.approx.ftz.d(double %{{.*}})
+  %approx_ftz = nvvm.rsqrt %approx {ftz = true} : f64
+  llvm.return %approx_ftz : f64
+}



More information about the Mlir-commits mailing list