[llvm] r293605 - [NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate.

Justin Lebar via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 30 21:58:22 PST 2017


Author: jlebar
Date: Mon Jan 30 23:58:22 2017
New Revision: 293605

URL: http://llvm.org/viewvc/llvm-project?rev=293605&view=rev
Log:
[NVPTX] Implement NVPTXTargetLowering::getSqrtEstimate.

Summary:

This lets us lower to sqrt.approx and rsqrt.approx under more
circumstances.

* Now we emit sqrt.approx and rsqrt.approx for calls to @llvm.sqrt.f32
  when fast-math is enabled.  Previously, we would only emit them for
  calls to @llvm.nvvm.sqrt.f.  (With this patch we no longer emit
  sqrt.approx for calls to @llvm.nvvm.sqrt.f; we rely on instcombine to
  simplify llvm.nvvm.sqrt.f into llvm.sqrt.f32.)

* Now we emit the ftz version of rsqrt.approx when ftz is enabled.
  Previously, we only emitted rsqrt.approx when ftz was disabled.

Reviewers: hfinkel

Subscribers: llvm-commits, tra, jholewinski

Differential Revision: https://reviews.llvm.org/D28508

Added:
    llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll
Removed:
    llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll
Modified:
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
    llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/trunk/test/CodeGen/NVPTX/fast-math.ll

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=293605&r1=293604&r2=293605&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Mon Jan 30 23:58:22 2017
@@ -1043,6 +1043,50 @@ NVPTXTargetLowering::getPreferredVectorA
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
 
+/// Emit sqrt.approx/rsqrt.approx for \p Operand (f32 or f64 only); returns an
+/// rsqrt when \p Reciprocal or refinement is requested, else a plain sqrt.
+SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                                             int Enabled, int &ExtraSteps,
+                                             bool &UseOneConst,
+                                             bool Reciprocal) const {
+  if (!(Enabled == ReciprocalEstimate::Enabled ||
+        (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
+    return SDValue();
+  if (ExtraSteps == ReciprocalEstimate::Unspecified)
+    ExtraSteps = 0;
+
+  SDLoc DL(Operand);
+  EVT VT = Operand.getValueType();
+  // Only scalar f32/f64 have approx intrinsics; bail on e.g. vector types.
+  if (VT != MVT::f32 && VT != MVT::f64)
+    return SDValue();
+  bool Ftz = useF32FTZ(DAG.getMachineFunction());
+  auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                       DAG.getConstant(IID, DL, MVT::i32), Operand);
+  };
+
+  // Refinement (ExtraSteps > 0) must start from an rsqrt approximation; with
+  // no refinement the caller uses our result as-is, so it must be a sqrt.
+  if (Reciprocal || ExtraSteps > 0) {
+    if (VT == MVT::f32)
+      return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
+                                   : Intrinsic::nvvm_rsqrt_approx_f);
+    return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
+  }
+  if (VT == MVT::f32)
+    return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
+                                 : Intrinsic::nvvm_sqrt_approx_f);
+  // No sqrt.approx.f64 exists, so emit x * rsqrt(x).  For x == +/-0 that is
+  // 0 * inf = NaN, so select x itself there, since sqrt(+/-0) == +/-0.
+  SDValue Est = DAG.getNode(ISD::FMUL, DL, VT, Operand,
+                            MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
+  EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsZero = DAG.getSetCC(DL, CCVT, Operand,
+                                DAG.getConstantFP(0.0, DL, VT), ISD::SETEQ);
+  return DAG.getNode(ISD::SELECT, DL, VT, IsZero, Operand, Est);
+}
+
 SDValue
 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   SDLoc dl(Op);

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h?rev=293605&r1=293604&r2=293605&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.h Mon Jan 30 23:58:22 2017
@@ -526,6 +526,10 @@ public:
   // to sign-preserving zero.
   bool useF32FTZ(const MachineFunction &MF) const;
 
+  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                          int Enabled, int &ExtraSteps, bool &UseOneConst,
+                          bool Reciprocal) const override;
+
   bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
   bool allowUnsafeFPMath(MachineFunction &MF) const;
 

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=293605&r1=293604&r2=293605&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Mon Jan 30 23:58:22 2017
@@ -966,18 +966,9 @@ def FDIV32ri_prec :
             Requires<[reqPTX20]>;
 
 //
-// F32 rsqrt
+// FMA
 //
 
-def RSQRTF32approx1r : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$b),
-                       "rsqrt.approx.f32 \t$dst, $b;", []>;
-
-// Convert 1.0f/sqrt(x) to rsqrt.approx.f32.  (There is an rsqrt.approx.f64, but
-// it's emulated in software.)
-def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$b)),
-         (RSQRTF32approx1r Float32Regs:$b)>,
-         Requires<[do_DIVF32_FULL, do_SQRTF32_APPROX, doNoF32FTZ]>;
-
 multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred> {
    def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
                        !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),

Modified: llvm/trunk/test/CodeGen/NVPTX/fast-math.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/fast-math.ll?rev=293605&r1=293604&r2=293605&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/fast-math.ll (original)
+++ llvm/trunk/test/CodeGen/NVPTX/fast-math.ll Mon Jan 30 23:58:22 2017
@@ -1,25 +1,91 @@
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
 
-declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
 
-; CHECK-LABEL: sqrt_div
+; CHECK-LABEL: sqrt_div(
 ; CHECK: sqrt.rn.f32
 ; CHECK: div.rn.f32
 define float @sqrt_div(float %a, float %b) {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a) ; no attributes: IEEE sqrt.rn
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
-; CHECK-LABEL: sqrt_div_fast
+; CHECK-LABEL: sqrt_div_fast(
 ; CHECK: sqrt.approx.f32
 ; CHECK: div.approx.f32
 define float @sqrt_div_fast(float %a, float %b) #0 {
-  %t1 = tail call float @llvm.nvvm.sqrt.f(float %a)
+  %t1 = tail call float @llvm.sqrt.f32(float %a) ; with #0: sqrt.approx
   %t2 = fdiv float %t1, %b
   ret float %t2
 }
 
+; CHECK-LABEL: sqrt_div_ftz(
+; CHECK: sqrt.rn.ftz.f32
+; CHECK: div.rn.ftz.f32
+define float @sqrt_div_ftz(float %x, float %y) #1 {
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  %quot = fdiv float %root, %y
+  ret float %quot
+}
+
+; CHECK-LABEL: sqrt_div_fast_ftz(
+; CHECK: sqrt.approx.ftz.f32
+; CHECK: div.approx.ftz.f32
+define float @sqrt_div_fast_ftz(float %x, float %y) #0 #1 {
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  %quot = fdiv float %root, %y
+  ret float %quot
+}
+
+; f64 lacks fast-math and ftz forms of both sqrt and div, so sqrt(x) is lowered
+; as x * rsqrt(x) and the division is an ordinary div.rn.f64.
+
+; CHECK-LABEL: sqrt_div_fast_ftz_f64(
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+; CHECK: div.rn.f64
+define double @sqrt_div_fast_ftz_f64(double %x, double %y) #0 #1 {
+  %root = tail call double @llvm.sqrt.f64(double %x)
+  %quot = fdiv double %root, %y
+  ret double %quot
+}
+
+; CHECK-LABEL: rsqrt(
+; CHECK-NOT: rsqrt.approx
+; CHECK: sqrt.rn.f32
+; CHECK-NOT: rsqrt.approx
+define float @rsqrt(float %x) {
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  %recip = fdiv float 1.000000e+00, %root
+  ret float %recip
+}
+
+; CHECK-LABEL: rsqrt_fast(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast(float %x) #0 {
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  %recip = fdiv float 1.000000e+00, %root
+  ret float %recip
+}
+
+; CHECK-LABEL: rsqrt_fast_ftz(
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+; CHECK: rsqrt.approx.ftz.f32
+; CHECK-NOT: div.
+; CHECK-NOT: sqrt.
+define float @rsqrt_fast_ftz(float %x) #0 #1 {
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  %recip = fdiv float 1.000000e+00, %root
+  ret float %recip
+}
+
 ; CHECK-LABEL: fadd
 ; CHECK: add.rn.f32
 define float @fadd(float %a, float %b) {

Removed: llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll?rev=293604&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll (original)
+++ llvm/trunk/test/CodeGen/NVPTX/rsqrt.ll (removed)
@@ -1,13 +0,0 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=1 -nvptx-prec-sqrtf32=0 | FileCheck %s
-
-target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
-
-declare float @llvm.nvvm.sqrt.f(float)
-
-define float @foo(float %a) {
-; CHECK: rsqrt.approx.f32
-  %val = tail call float @llvm.nvvm.sqrt.f(float %a)
-  %ret = fdiv float 1.0, %val
-  ret float %ret
-}
-  

Added: llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll?rev=293605&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll (added)
+++ llvm/trunk/test/CodeGen/NVPTX/sqrt-approx.ll Mon Jan 30 23:58:22 2017
@@ -0,0 +1,148 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-prec-divf32=0 -nvptx-prec-sqrtf32=0 \
+; RUN:   | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
+
+declare float @llvm.sqrt.f32(float)
+declare double @llvm.sqrt.f64(double)
+
+; -- reciprocal sqrt --
+
+; CHECK-LABEL: test_rsqrt32(
+define float @test_rsqrt32(float %a) #0 {
+; CHECK: rsqrt.approx.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val ; 1/sqrt(a) should become a single rsqrt
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt_ftz(
+define float @test_rsqrt_ftz(float %a) #0 #1 {
+; CHECK: rsqrt.approx.ftz.f32
+  %val = tail call float @llvm.sqrt.f32(float %a)
+  %ret = fdiv float 1.0, %val ; #1 (nvptx-f32ftz) selects the .ftz form
+  ret float %ret
+}
+
+; CHECK-LABEL: test_rsqrt64(
+define double @test_rsqrt64(double %a) #0 {
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val ; 1/sqrt(a) should become a single rsqrt
+  ret double %ret
+}
+
+; CHECK-LABEL: test_rsqrt64_ftz(
+define double @test_rsqrt64_ftz(double %a) #0 #1 {
+; There's no rsqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+  %val = tail call double @llvm.sqrt.f64(double %a)
+  %ret = fdiv double 1.0, %val
+  ret double %ret
+}
+
+; -- sqrt --
+
+; CHECK-LABEL: test_sqrt32(
+define float @test_sqrt32(float %a) #0 {
+; CHECK: sqrt.approx.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a) ; no #2: unrefined sqrt
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt_ftz(
+define float @test_sqrt_ftz(float %a) #0 #1 {
+; CHECK: sqrt.approx.ftz.f32
+  %ret = tail call float @llvm.sqrt.f32(float %a) ; #1 selects the .ftz form
+  ret float %ret
+}
+
+; CHECK-LABEL: test_sqrt64(
+define double @test_sqrt64(double %a) #0 {
+; There's no sqrt.approx.f64 instruction; we emit x * rsqrt.approx.f64(x).
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; CHECK-LABEL: test_sqrt64_ftz(
+define double @test_sqrt64_ftz(double %a) #0 #1 {
+; There's no sqrt.approx.ftz.f64 instruction; we just use the non-ftz version.
+; CHECK: rsqrt.approx.f64
+; CHECK: mul.f64
+  %ret = tail call double @llvm.sqrt.f64(double %a)
+  ret double %ret
+}
+
+; -- refined sqrt and rsqrt --
+;
+; The sqrt and rsqrt refinement algorithms both emit an rsqrt.approx, followed
+; by some math.
+
+; CHECK-LABEL: test_rsqrt32_refined
+define float @test_rsqrt32_refined(float %x) #0 #2 {
+; CHECK: rsqrt.approx.f32
+  %root = tail call float @llvm.sqrt.f32(float %x) ; #2 requests refinement
+  %recip = fdiv float 1.000000e+00, %root
+  ret float %recip
+}
+
+; CHECK-LABEL: test_sqrt32_refined
+define float @test_sqrt32_refined(float %x) #0 #2 {
+; CHECK: rsqrt.approx.f32
+  %root = tail call float @llvm.sqrt.f32(float %x) ; #2 requests refinement
+  ret float %root
+}
+
+; CHECK-LABEL: test_rsqrt64_refined
+define double @test_rsqrt64_refined(double %x) #0 #2 {
+; CHECK: rsqrt.approx.f64
+  %root = tail call double @llvm.sqrt.f64(double %x) ; #2 requests refinement
+  %recip = fdiv double 1.000000e+00, %root
+  ret double %recip
+}
+
+; CHECK-LABEL: test_sqrt64_refined
+define double @test_sqrt64_refined(double %x) #0 #2 {
+; CHECK: rsqrt.approx.f64
+  %root = tail call double @llvm.sqrt.f64(double %x) ; #2 requests refinement
+  ret double %root
+}
+
+; -- refined sqrt and rsqrt with ftz enabled --
+
+; CHECK-LABEL: test_rsqrt32_refined_ftz
+define float @test_rsqrt32_refined_ftz(float %x) #0 #1 #2 {
+; CHECK: rsqrt.approx.ftz.f32
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  %recip = fdiv float 1.000000e+00, %root
+  ret float %recip
+}
+
+; CHECK-LABEL: test_sqrt32_refined_ftz
+define float @test_sqrt32_refined_ftz(float %x) #0 #1 #2 {
+; CHECK: rsqrt.approx.ftz.f32
+  %root = tail call float @llvm.sqrt.f32(float %x)
+  ret float %root
+}
+
+; CHECK-LABEL: test_rsqrt64_refined_ftz
+define double @test_rsqrt64_refined_ftz(double %x) #0 #1 #2 {
+; f64 has no rsqrt.approx.ftz variant, so the plain rsqrt.approx.f64 is used.
+; CHECK: rsqrt.approx.f64
+  %root = tail call double @llvm.sqrt.f64(double %x)
+  %recip = fdiv double 1.000000e+00, %root
+  ret double %recip
+}
+
+; CHECK-LABEL: test_sqrt64_refined_ftz
+define double @test_sqrt64_refined_ftz(double %x) #0 #1 #2 {
+; CHECK: rsqrt.approx.f64
+  %root = tail call double @llvm.sqrt.f64(double %x)
+  ret double %root
+}
+
+attributes #0 = { "unsafe-fp-math" = "true" }
+attributes #1 = { "nvptx-f32ftz" = "true" }
+attributes #2 = { "reciprocal-estimates" = "rsqrtf:1,rsqrtd:1,sqrtf:1,sqrtd:1" }




More information about the llvm-commits mailing list