[llvm] [NVPTX] Improve support for rsqrt.approx (PR #89417)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 19 16:34:12 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/89417

>From e9ee87c05c142c01b03cffead1ab23047c30f4c4 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Thu, 18 Apr 2024 16:52:46 +0000
Subject: [PATCH 1/2] [NVPTX] Improve support for rsqrt.approx

---
 llvm/include/llvm/IR/IntrinsicsNVVM.td      |  2 +
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp |  6 ++
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |  1 +
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     |  1 +
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td    | 25 +++++++
 llvm/test/CodeGen/NVPTX/rsqrt-opt.ll        | 75 +++++++++++++++++++++
 llvm/test/CodeGen/NVPTX/rsqrt.ll            | 35 ++++++++++
 7 files changed, 145 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/rsqrt.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 726cea004606e2..0a9139e0062ba3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
 
   def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
+      DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3ff8994602e16b..26b95b396cfebf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -30,6 +30,10 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-isel"
 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
 
+static cl::opt<bool>
+    DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
+                    cl::desc("Disable reciprocal sqrt optimization"));
+
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -78,6 +82,8 @@ bool NVPTXDAGToDAGISel::useShortPointers() const {
   return TM.useShortPointers();
 }
 
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return !DisableRsqrtOpt; }
+
 /// Select - Select instructions not customized! Used for
 /// expanded, promoted and normal instructions.
 void NVPTXDAGToDAGISel::Select(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 84c8432047ca31..c74f8fa884596a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -37,6 +37,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool allowFMA() const;
   bool allowUnsafeFPMath() const;
   bool useShortPointers() const;
+  bool doRsqrtOpt() const;
 
 public:
   static char ID;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index cd8546005c0289..0af7423dfd0b9b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
 
 def doF32FTZ : Predicate<"useF32FTZ()">;
 def doNoF32FTZ : Predicate<"!useF32FTZ()">;
+def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
 
 def doMulWide      : Predicate<"doMulWide">;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c0c53380a13e9b..05b22aa41198d1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
 def INT_NVVM_RSQRT_APPROX_FTZ_F
   : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
     int_nvvm_rsqrt_approx_ftz_f>;
+def INT_NVVM_RSQRT_APPROX_FTZ_D
+  : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+    int_nvvm_rsqrt_approx_ftz_d>;
+
 def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
   Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
 def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
   Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
 
+// 1.0f / sqrt_approx -> rsqrt_approx
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+// Same for int_nvvm_sqrt_f when approximate (non-precise) f32 sqrt is requested.
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
+
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
 //
 // Add
 //
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
new file mode 100644
index 00000000000000..11bd82a015ce12
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -0,0 +1,75 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
+; RUN: llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+;
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | %ptxas-verify %}
+
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+; CHECK-LABEL: .func{{.*}}test2
+define float @test2(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define float @test4(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test5
+define float @test5(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test6
+define float @test6(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+
+declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.nvvm.sqrt.approx.f(float)
+declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
+declare float @llvm.sqrt.f32(float)
+
+attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
new file mode 100644
index 00000000000000..c7367245c532e3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -0,0 +1,35 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f32
+  %call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test2
+define double @test2(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f64
+  %call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
+  ret double %call
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f32
+  %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define double @test4(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f64
+  %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
+  ret double %call
+}
+
+declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
+declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
+declare float @llvm.nvvm.rsqrt.approx.f(float)
+declare double @llvm.nvvm.rsqrt.approx.d(double)

>From 42055df40a018db30385b1af0eb483e965f484ab Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Fri, 19 Apr 2024 22:58:01 +0000
Subject: [PATCH 2/2] rename option nvptx-rsqrt-approx-opt

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 6 +++---
 llvm/test/CodeGen/NVPTX/rsqrt-opt.ll        | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 26b95b396cfebf..806431be7f7f2d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -31,8 +31,8 @@ using namespace llvm;
 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
 
 static cl::opt<bool>
-    DisableRsqrtOpt("nvptx-disable-rsqrt-opt", cl::init(false), cl::Hidden,
-                    cl::desc("Disable reciprocal sqrt optimization"));
+    EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
+                   cl::desc("Enable reciprocal sqrt optimization"));
 
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
@@ -82,7 +82,7 @@ bool NVPTXDAGToDAGISel::useShortPointers() const {
   return TM.useShortPointers();
 }
 
-bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return !DisableRsqrtOpt; }
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
 
 /// Select - Select instructions not customized! Used for
 /// expanded, promoted and normal instructions.
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
index 11bd82a015ce12..9dda6075a23c6d 100644
--- a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -1,10 +1,10 @@
 ; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
 ; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
-; RUN: llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
 ;
 ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
 ; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
-; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-disable-rsqrt-opt | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | %ptxas-verify %}
 
 
 ; CHECK-LABEL: .func{{.*}}test1



More information about the llvm-commits mailing list