[llvm] [NVPTX] Improve support for rsqrt.approx (PR #89417)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 19 09:48:18 PDT 2024


https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/89417

Complete support for rsqrt.approx by adding rsqrt.approx.ftz.f64 ([PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64)). Additionally, add support for folding `1.0 / sqrt(x)` into `rsqrt.approx`, with an optional flag (`-nvptx-disable-rsqrt-opt`) to disable the fold.

From e9ee87c05c142c01b03cffead1ab23047c30f4c4 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Thu, 18 Apr 2024 16:52:46 +0000
Subject: [PATCH] [NVPTX] Improve support for rsqrt.approx

---
 llvm/include/llvm/IR/IntrinsicsNVVM.td      |  2 +
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp |  6 ++
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |  1 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     |  1 +
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td    | 25 +++++++
 llvm/test/CodeGen/NVPTX/rsqrt-opt.ll        | 75 +++++++++++++++++++++
 llvm/test/CodeGen/NVPTX/rsqrt.ll            | 35 ++++++++++
 7 files changed, 145 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/rsqrt.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 726cea004606e2..0a9139e0062ba3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
 
   def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
+      DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3ff8994602e16b..26b95b396cfebf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -30,6 +30,10 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-isel"
 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
 
+static cl::opt<bool>
+    DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
+                    cl::desc("Disable reciprocal sqrt optimization"));
+
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -78,6 +82,8 @@ bool NVPTXDAGToDAGISel::useShortPointers() const {
   return TM.useShortPointers();
 }
 
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return !DisableRsqrtOpt; }
+
 /// Select - Select instructions not customized! Used for
 /// expanded, promoted and normal instructions.
 void NVPTXDAGToDAGISel::Select(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 84c8432047ca31..c74f8fa884596a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -37,6 +37,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool allowFMA() const;
   bool allowUnsafeFPMath() const;
   bool useShortPointers() const;
+  bool doRsqrtOpt() const;
 
 public:
   static char ID;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index cd8546005c0289..0af7423dfd0b9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
 
 def doF32FTZ : Predicate<"useF32FTZ()">;
 def doNoF32FTZ : Predicate<"!useF32FTZ()">;
+def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
 
 def doMulWide      : Predicate<"doMulWide">;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c0c53380a13e9b..05b22aa41198d1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
 def INT_NVVM_RSQRT_APPROX_FTZ_F
   : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
     int_nvvm_rsqrt_approx_ftz_f>;
+def INT_NVVM_RSQRT_APPROX_FTZ_D
+  : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+    int_nvvm_rsqrt_approx_ftz_d>;
+
 def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
   Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
 def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
   Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
 
+// 1.0f / sqrt_approx -> rsqrt_approx
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+// same for int_nvvm_sqrt_f when non-precision sqrt is requested
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
+
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
 //
 // Add
 //
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
new file mode 100644
index 00000000000000..11bd82a015ce12
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -0,0 +1,75 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
+; RUN: llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+;
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | %ptxas-verify %}
+
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+; CHECK-LABEL: .func{{.*}}test2
+define float @test2(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define float @test4(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test5
+define float @test5(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test6
+define float @test6(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+
+declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.nvvm.sqrt.approx.f(float)
+declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
+declare float @llvm.sqrt.f32(float)
+
+attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
new file mode 100644
index 00000000000000..c7367245c532e3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -0,0 +1,35 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f32
+  %call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test2
+define double @test2(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f64
+  %call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
+  ret double %call
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f32
+  %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define double @test4(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f64
+  %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
+  ret double %call
+}
+
+declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
+declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
+declare float @llvm.nvvm.rsqrt.approx.f(float)
+declare double @llvm.nvvm.rsqrt.approx.d(double)



More information about the llvm-commits mailing list