[llvm] df60805 - [NVPTX] Improve support for rsqrt.approx (#89417)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 08:56:43 PDT 2024


Author: Alex MacLean
Date: 2024-04-23T08:56:39-07:00
New Revision: df608051234c256f1dc2c89f30afd034706c2c2e

URL: https://github.com/llvm/llvm-project/commit/df608051234c256f1dc2c89f30afd034706c2c2e
DIFF: https://github.com/llvm/llvm-project/commit/df608051234c256f1dc2c89f30afd034706c2c2e.diff

LOG: [NVPTX] Improve support for rsqrt.approx (#89417)

Complete support for rsqrt.approx by adding the missing
rsqrt.approx.ftz.f64 variant ([PTX ISA 9.7.3.17. Floating Point Instructions:
rsqrt.approx.ftz.f64](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64)).
Additionally, add support for folding `1.0 / sqrt(x)` into `rsqrt.approx`
for f32, with a flag (`-nvptx-rsqrt-approx-opt=0`) to disable the folding.

Added: 
    llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
    llvm/test/CodeGen/NVPTX/rsqrt.ll

Modified: 
    llvm/include/llvm/IR/IntrinsicsNVVM.td
    llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
    llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 726cea004606e2..0a9139e0062ba3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
 
   def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
+      DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
       DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
   def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index a362709c98efd6..595395bb1b4b43 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -30,6 +30,10 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-isel"
 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
 
+static cl::opt<bool>
+    EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
+                   cl::desc("Enable reciprocal sqrt optimization"));
+
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -74,6 +78,8 @@ bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
   return TL->allowUnsafeFPMath(*MF);
 }
 
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
+
 /// Select - Select instructions not customized! Used for
 /// expanded, promoted and normal instructions.
 void NVPTXDAGToDAGISel::Select(SDNode *N) {

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 10822f87cef308..7a7774744bc71d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool useF32FTZ() const;
   bool allowFMA() const;
   bool allowUnsafeFPMath() const;
+  bool doRsqrtOpt() const;
 
 public:
   static char ID;

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 931292c7fd6042..897ee89323f083 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
 
 def doF32FTZ : Predicate<"useF32FTZ()">;
 def doNoF32FTZ : Predicate<"!useF32FTZ()">;
+def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
 
 def doMulWide      : Predicate<"doMulWide">;
 

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ec9170b4e41e5c..5f6e28283c5d25 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
 def INT_NVVM_RSQRT_APPROX_FTZ_F
   : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
     int_nvvm_rsqrt_approx_ftz_f>;
+def INT_NVVM_RSQRT_APPROX_FTZ_D
+  : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+    int_nvvm_rsqrt_approx_ftz_d>;
+
 def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
   Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
 def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
   Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
 
+// 1.0f / sqrt_approx -> rsqrt_approx
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt]>;
+// same for int_nvvm_sqrt_f when non-precision sqrt is requested
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
+
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+         (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+         Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
 //
 // Add
 //

diff  --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
new file mode 100644
index 00000000000000..9dda6075a23c6d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -0,0 +1,75 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
+; RUN: llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+;
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | %ptxas-verify %}
+
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+; CHECK-LABEL: .func{{.*}}test2
+define float @test2(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define float @test4(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test5
+define float @test5(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test6
+define float @test6(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+  %sqrt = tail call float @llvm.sqrt.f32(float %in)
+  %rsqrt = fdiv float 1.0, %sqrt
+  ret float %rsqrt
+}
+
+
+declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.nvvm.sqrt.approx.f(float)
+declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
+declare float @llvm.sqrt.f32(float)
+
+attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }

diff  --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
new file mode 100644
index 00000000000000..c7367245c532e3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -0,0 +1,35 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f32
+  %call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test2
+define double @test2(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f64
+  %call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
+  ret double %call
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f32
+  %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
+  ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define double @test4(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f64
+  %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
+  ret double %call
+}
+
+declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
+declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
+declare float @llvm.nvvm.rsqrt.approx.f(float)
+declare double @llvm.nvvm.rsqrt.approx.d(double)


        


More information about the llvm-commits mailing list