[Mlir-commits] [mlir] 03326f9 - [MLIR][NVVM] Add rsqrt Op (#195854)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 12 23:46:04 PDT 2026
Author: Varad Rahul Kamthe
Date: 2026-05-13T12:15:59+05:30
New Revision: 03326f9f0ad7a32e9ab015687f16ff09a2182f64
URL: https://github.com/llvm/llvm-project/commit/03326f9f0ad7a32e9ab015687f16ff09a2182f64
DIFF: https://github.com/llvm/llvm-project/commit/03326f9f0ad7a32e9ab015687f16ff09a2182f64.diff
LOG: [MLIR][NVVM] Add rsqrt Op (#195854)
Adds the `nvvm.rsqrt` op, a fast approximate reciprocal square root. The op supports f32 and f64 operands and takes an optional `ftz` attribute.
For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt
Added:
mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4e0bce985fdf7..70aad7ac095b9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -501,6 +501,34 @@ def NVVM_Ex2Op : NVVM_F32UnaryApproxOp<"ex2"> {
}];
}
+//===----------------------------------------------------------------------===//
+// NVVM rsqrt op definition (approximate reciprocal square root, f32/f64)
+//===----------------------------------------------------------------------===//
+
+def NVVM_RsqrtOp
+ : NVVM_SingleResultIntrinsicOp<"rsqrt", [Pure, SameOperandsAndResultType]> {
+ let summary = "Reciprocal square root (fast approximation)";
+ let description = [{
+ Computes an approximation of the reciprocal of the square root of the
+ input value: `d = 1 / sqrt(a)`. Supports both f32 and f64. The maximum
+ relative error for the f32 form over the entire positive finite range
+ is 2^-22.9.
+
+ The `ftz` attribute, when set, flushes subnormal inputs and results to
+ sign-preserving zero. For f64 inputs, `ftz=true` selects a coarser
+ approximation that uses only the upper 32 bits of the input (the lower
+ 32 bits of the result are zeroed).
+
+ For more information, see PTX ISA:
+ [rsqrt](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-rsqrt)
+ }];
+
+ let arguments = (ins AnyTypeOf<[F32, F64]>:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$ftz);
+ let results = (outs AnyTypeOf<[F32, F64]>:$res);
+ let assemblyFormat = "$src attr-dict `:` type($src)";
+}
+
//===----------------------------------------------------------------------===//
// NVVM redux op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7770f8f2dba0a..943de71d34a96 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3515,6 +3515,26 @@ Ex2Op::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
return {id, {mt.lookupValue(thisOp.getSrc())}};
}
+mlir::NVVM::IDArgPair
+RsqrtOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::RsqrtOp>(op);
+ Type t = thisOp.getRes().getType();
+ bool isFtz = thisOp.getFtz();
+
+ llvm::Intrinsic::ID id = [&] {
+ if (t.isF32()) {
+ return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_f
+ : llvm::Intrinsic::nvvm_rsqrt_approx_f;
+ }
+ // Not f32, so f64: the op definition only admits F32 and F64 operands.
+ return isFtz ? llvm::Intrinsic::nvvm_rsqrt_approx_ftz_d
+ : llvm::Intrinsic::nvvm_rsqrt_approx_d;
+ }();
+
+ return {id, {mt.lookupValue(thisOp.getSrc())}};
+}
+
mlir::NVVM::IDArgPair
PMEventOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
diff --git a/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
new file mode 100644
index 0000000000000..a055311567216
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/rsqrt/rsqrt.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// f32 rsqrt: the default (non-ftz) form, then the ftz form applied to its result.
+llvm.func @rsqrt_f32(%a : f32) -> f32 {
+ // CHECK-LABEL: define float @rsqrt_f32(float %0) {
+ // CHECK: call float @llvm.nvvm.rsqrt.approx.f(float %{{.*}})
+ // CHECK: call float @llvm.nvvm.rsqrt.approx.ftz.f(float %{{.*}})
+ %r1 = nvvm.rsqrt %a : f32
+ %r2 = nvvm.rsqrt %r1 {ftz = true} : f32
+ llvm.return %r2 : f32
+}
+
+// f64 rsqrt: the default (non-ftz) form, then the ftz form applied to its result.
+llvm.func @rsqrt_f64(%a : f64) -> f64 {
+ // CHECK-LABEL: define double @rsqrt_f64(double %0) {
+ // CHECK: call double @llvm.nvvm.rsqrt.approx.d(double %{{.*}})
+ // CHECK: call double @llvm.nvvm.rsqrt.approx.ftz.d(double %{{.*}})
+ %r1 = nvvm.rsqrt %a : f64
+ %r2 = nvvm.rsqrt %r1 {ftz = true} : f64
+ llvm.return %r2 : f64
+}
More information about the Mlir-commits mailing list