[Mlir-commits] [mlir] 03326f9 - [MLIR][NVVM] Add rsqrt Op (#195854)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 12 23:46:04 PDT 2026


Author: Varad Rahul Kamthe
Date: 2026-05-13T12:15:59+05:30
New Revision: 03326f9f0ad7a32e9ab015687f16ff09a2182f64

URL: https://github.com/llvm/llvm-project/commit/03326f9f0ad7a32e9ab015687f16ff09a2182f64
DIFF: https://github.com/llvm/llvm-project/commit/03326f9f0ad7a32e9ab015687f16ff09a2182f64.diff

LOG: [MLIR][NVVM] Add rsqrt Op (#195854)

Adds `nvvm.rsqrt` op for fast approximate reciprocal square root. Supports f32 and f64 with an optional `ftz` attribute.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt

Added: 
    mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4e0bce985fdf7..70aad7ac095b9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -501,6 +501,34 @@ def NVVM_Ex2Op : NVVM_F32UnaryApproxOp<"ex2"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM rsqrt op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RsqrtOp
+    : NVVM_SingleResultIntrinsicOp<"rsqrt", [Pure, SameOperandsAndResultType]> {
+  let summary = "Reciprocal square root (fast approximation)";
+  let description = [{
+    Computes an approximation of the reciprocal of the square root of the
+    input value: `d = 1 / sqrt(a)`. Supports both f32 and f64. The maximum
+    relative error for the f32 form over the entire positive finite range
+    is 2^-22.9.
+
+    The `ftz` attribute, when set, flushes subnormal inputs and results to
+    sign-preserving zero. For f64 inputs, `ftz=true` selects a coarser
+    approximation that uses only the upper 32 bits of the input (the lower
+    32 bits of the result are zeroed).
+
+    For more information, see PTX ISA:
+    [rsqrt](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt)
+  }];
+
+  let arguments = (ins AnyTypeOf<[F32, F64]>:$src,
+      DefaultValuedAttr<BoolAttr, "false">:$ftz);
+  let results = (outs AnyTypeOf<[F32, F64]>:$res);
+  let assemblyFormat = "$src attr-dict `:` type($src)"; // $res has $src's type
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM redux op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7770f8f2dba0a..943de71d34a96 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3515,6 +3515,26 @@ Ex2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
   return {id, {mt.lookupValue(thisOp.getSrc())}};
 }
 
+mlir::NVVM::IDArgPair
+RsqrtOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                               llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::RsqrtOp>(op);
+  Type t = thisOp.getRes().getType();
+  bool isFtz = thisOp.getFtz();
+
+  llvm::Intrinsic::ID id = [&] {
+    if (t.isF32()) {
+      return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
+                   : llvm::Intrinsic::nvvm_rsqrt_approx_f;
+    }
+    // Otherwise f64; the op definition admits only f32 and f64.
+    return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
+                 : llvm::Intrinsic::nvvm_rsqrt_approx_d;
+  }();
+
+  return {id, {mt.lookupValue(thisOp.getSrc())}};
+}
+
 mlir::NVVM::IDArgPair
 PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                  llvm::IRBuilderBase &builder) {

diff  --git a/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
new file mode 100644
index 0000000000000..a055311567216
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// f32 rsqrt: default -> rsqrt.approx.f, ftz = true -> rsqrt.approx.ftz.f.
+llvm.func @rsqrt_f32(%a : f32) -> f32 {
+  // CHECK-LABEL: define float @rsqrt_f32(float %0) {
+  // CHECK: call float @llvm.nvvm.rsqrt.approx.f(float %{{.*}})
+  // CHECK: call float @llvm.nvvm.rsqrt.approx.ftz.f(float %{{.*}})
+  %r1 = nvvm.rsqrt %a : f32
+  %r2 = nvvm.rsqrt %r1 {ftz = true} : f32
+  llvm.return %r2 : f32
+}
+
+// f64 rsqrt: default -> rsqrt.approx.d, ftz = true -> rsqrt.approx.ftz.d.
+llvm.func @rsqrt_f64(%a : f64) -> f64 {
+  // CHECK-LABEL: define double @rsqrt_f64(double %0) {
+  // CHECK: call double @llvm.nvvm.rsqrt.approx.d(double %{{.*}})
+  // CHECK: call double @llvm.nvvm.rsqrt.approx.ftz.d(double %{{.*}})
+  %r1 = nvvm.rsqrt %a : f64
+  %r2 = nvvm.rsqrt %r1 {ftz = true} : f64
+  llvm.return %r2 : f64
+}


        


More information about the Mlir-commits mailing list