[Mlir-commits] [mlir] [MLIR][NVVM] Add rsqrt Op (PR #195854)
Varad Rahul Kamthe
llvmlistbot at llvm.org
Tue May 5 06:28:37 PDT 2026
https://github.com/varadk27 created https://github.com/llvm/llvm-project/pull/195854
Adds `nvvm.rsqrt` op for fast approximate reciprocal square root. Supports
f32 and f64 with an optional `ftz` attribute.
For more information, see PTX ISA:
(https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt)
From 1a2706921c626d4f8c0737a1cac6a3cc985f4ec9 Mon Sep 17 00:00:00 2001
From: Varad Rahul Kamthe <vkamthe at nvidia.com>
Date: Mon, 4 May 2026 14:40:46 +0000
Subject: [PATCH] Add rsqrt Op
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 24 +++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 19 +++++++++++++++
mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir | 21 ++++++++++++++++
3 files changed, 64 insertions(+)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 51396947fad4e..e9cf1eea14760 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -578,6 +578,30 @@ def NVVM_Ex2Op : NVVM_F32UnaryApproxOp<"ex2"> {
}];
}
+// Declares the `nvvm.rsqrt` op as a single-result intrinsic op carrying the
+// `Pure` and `SameOperandsAndResultType` traits. It takes one f32/f64 operand
+// (`$src`) plus a boolean `ftz` attribute that defaults to false, and yields
+// one f32/f64 result (`$res`).
+def NVVM_RsqrtApproxOp
+ : NVVM_SingleResultIntrinsicOp<"rsqrt", [Pure, SameOperandsAndResultType]> {
+ let summary = "Reciprocal square root (fast approximation)";
+ let description = [{
+ Computes an approximation of the reciprocal of the square root of the
+ input value: `d = 1 / sqrt(a)`. Supports both f32 and f64. The maximum
+ relative error for the f32 form over the entire positive finite range
+ is 2^-22.9.
+
+ The `ftz` attribute, when set, flushes subnormal inputs and results to
+ sign-preserving zero. For f64 inputs, `ftz=true` selects a coarser
+ approximation that uses only the upper 32 bits of the input (the lower
+ 32 bits of the result are zeroed).
+
+ For more information, see PTX ISA:
+ [rsqrt](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt)
+ }];
+
+ // Operand and result share the same {f32, f64} type constraint; `ftz` is
+ // optional in the textual form because it has a default of "false".
+ let arguments = (ins AnyTypeOf<[F32, F64]>:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$ftz);
+ let results = (outs AnyTypeOf<[F32, F64]>:$res);
+ // The custom syntax spells out only the operand type; `ftz` is written in
+ // the attribute dictionary, e.g. `nvvm.rsqrt %a {ftz = true} : f32`.
+ let assemblyFormat = "$src attr-dict `:` type($src)";
+}
+
//===----------------------------------------------------------------------===//
// NVVM redux op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7770f8f2dba0a..0a49632569f0c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3515,6 +3515,25 @@ Ex2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
return {id, {mt.lookupValue(thisOp.getSrc())}};
}
+/// Picks the LLVM NVVM intrinsic that implements an `nvvm.rsqrt` op and
+/// collects the argument for the intrinsic call.
+///
+/// The intrinsic is chosen from two properties of the op: whether its result
+/// type is f32 (any other type is handled as f64) and whether its `ftz` flag
+/// is set. The single call argument is the value `mt` has recorded for the
+/// op's source operand. `builder` is not used here.
+mlir::NVVM::IDArgPair
+RsqrtApproxOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase &builder) {
+  auto rsqrtOp = cast<NVVM::RsqrtApproxOp>(op);
+  const bool flushToZero = rsqrtOp.getFtz();
+  const bool isSinglePrecision = rsqrtOp.getRes().getType().isF32();
+
+  llvm::Intrinsic::ID id;
+  if (isSinglePrecision) {
+    id = flushToZero ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
+                     : llvm::Intrinsic::nvvm_rsqrt_approx_f;
+  } else {
+    // Every non-f32 result falls through to the double-precision intrinsics.
+    id = flushToZero ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
+                     : llvm::Intrinsic::nvvm_rsqrt_approx_d;
+  }
+
+  return {id, {mt.lookupValue(rsqrtOp.getSrc())}};
+}
+
mlir::NVVM::IDArgPair
PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
diff --git a/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
new file mode 100644
index 0000000000000..a055311567216
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// f32 rsqrt -- non-ftz and ftz forms.
+// The plain op is expected to translate to a call to
+// llvm.nvvm.rsqrt.approx.f, and the `ftz = true` op to a call to
+// llvm.nvvm.rsqrt.approx.ftz.f, in that order.
+llvm.func @rsqrt_f32(%a : f32) -> f32 {
+ // CHECK-LABEL: define float @rsqrt_f32(float %0) {
+ // CHECK: call float @llvm.nvvm.rsqrt.approx.f(float %{{.*}})
+ // CHECK: call float @llvm.nvvm.rsqrt.approx.ftz.f(float %{{.*}})
+ %r1 = nvvm.rsqrt %a : f32
+ // The ftz form takes %r1 as its input, so both ops feed the returned value.
+ %r2 = nvvm.rsqrt %r1 {ftz = true} : f32
+ llvm.return %r2 : f32
+}
+
+// f64 rsqrt -- non-ftz and ftz forms.
+// The plain op is expected to translate to a call to
+// llvm.nvvm.rsqrt.approx.d, and the `ftz = true` op to a call to
+// llvm.nvvm.rsqrt.approx.ftz.d, in that order.
+llvm.func @rsqrt_f64(%a : f64) -> f64 {
+ // CHECK-LABEL: define double @rsqrt_f64(double %0) {
+ // CHECK: call double @llvm.nvvm.rsqrt.approx.d(double %{{.*}})
+ // CHECK: call double @llvm.nvvm.rsqrt.approx.ftz.d(double %{{.*}})
+ %r1 = nvvm.rsqrt %a : f64
+ // The ftz form takes %r1 as its input, so both ops feed the returned value.
+ %r2 = nvvm.rsqrt %r1 {ftz = true} : f64
+ llvm.return %r2 : f64
+}
More information about the Mlir-commits
mailing list