[PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)

Steven Noonan steven at uplinklabs.net
Sat Mar 1 23:52:39 PST 2014


In an application that uses inverse square root for distance calculations
I noticed that LLVM and GCC emit different code, with GCC being the faster of
the two. One of the most apparent differences was the lack of lowering for
1.0f / sqrtf(x). While GCC was emitting a reciprocal square root followed by
a Newton-Raphson iteration, LLVM was implementing the much more straightforward
(but also much slower) sqrt followed by div.

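This change lowers y / sqrtf(x) to y * rsqrt(x), using a reciprocal square
root estimate refined by one Newton-Raphson step, by adding an X86 FDIV DAG
combine that fires only when unsafe FP math optimizations are enabled.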

Signed-off-by: Steven Noonan <steven at uplinklabs.net>
---
 lib/Target/X86/X86ISelLowering.cpp | 133 +++++++++++++++++++++++++++++++++++++
 test/CodeGen/X86/rsqrt-fastmath.ll |  64 ++++++++++++++++++
 2 files changed, 197 insertions(+)
 create mode 100644 test/CodeGen/X86/rsqrt-fastmath.ll

diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 76eeb64..d47a183 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -1513,6 +1513,7 @@ void X86TargetLowering::resetOperationActions() {
   setTargetDAGCombine(ISD::ADD);
   setTargetDAGCombine(ISD::FADD);
   setTargetDAGCombine(ISD::FSUB);
+  setTargetDAGCombine(ISD::FDIV);
   setTargetDAGCombine(ISD::FMA);
   setTargetDAGCombine(ISD::SUB);
   setTargetDAGCombine(ISD::LOAD);
@@ -4751,6 +4752,33 @@ static SDValue getZeroVector(EVT VT, const X86Subtarget *Subtarget,
   return DAG.getNode(ISD::BITCAST, dl, VT, Vec);
 }
 
+/// getSplatVectorFP - Returns a floating-point vector of type VT with all
+/// elements set to the scalar constant Cst.
+///
+/// The BUILD_VECTOR is created with VT itself rather than with a fixed
+/// v4f32/v8f32/v16f32 type chosen from VT's bit width, so the result type
+/// always matches VT and any floating-point vector type is accepted. The
+/// current callers pass f32 constants for v4f32, v8f32 and v16f32.
+///
+/// Cst must have the element type of VT. Subtarget is currently unused.
+static SDValue getSplatVectorFP(EVT VT, SDValue Cst,
+                                const X86Subtarget *Subtarget,
+                                SelectionDAG &DAG, SDLoc dl) {
+  assert(VT.isVector() && "Expected a vector type");
+  assert(VT.isFloatingPoint() && "Expected a floating-point vector type");
+  assert(Cst.getValueType() == VT.getVectorElementType() &&
+         "Expected Cst to have the same type as the elements of VT.");
+
+  // One copy of Cst per lane. The inline capacity of 16 covers the widest
+  // type the callers use (v16f32) without a heap allocation.
+  SmallVector<SDValue, 16> Ops(VT.getVectorNumElements(), Cst);
+
+  // Every operand is the same constant, so this node is a splat of Cst.
+  SDValue Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Ops[0],
+                            Ops.size());
+  return Vec;
+}
+
 /// getOnesVector - Returns a vector of specified type with all bits set.
 /// Always build ones vectors as <4 x i32> or <8 x i32>. For 256-bit types with
 /// no AVX2 supprt, use two <4 x i32> inserted in a <8 x i32> appropriately.
@@ -18630,6 +18658,110 @@ static SDValue PerformFSUBCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// PerformFastRecipFSQRT - Do fast reciprocal floating-point square root: an
+// X86ISD::FRSQRT estimate Est of 1.0 / sqrt(Op) refined by one Newton-Raphson
+// step, Est * (1.5 - 0.5 * Op * Est * Est), which is computed below as
+// -0.5 * Est * (Op * Est * Est - 3.0). Returns a null SDValue if VT is not
+// supported or if vector operations have already been legalized.
+static SDValue PerformFastRecipFSQRT(SDValue Op,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const X86Subtarget *Subtarget) {
+  if (DCI.isAfterLegalizeVectorOps())
+    return SDValue();
+
+  EVT VT = Op.getValueType();
+
+  // rsqrtss/rsqrtps need SSE1 (the wider forms need AVX / AVX-512). Without
+  // SSE1, f32 is kept in x87 registers and there is no rsqrtss to select.
+  bool Supported = (VT == MVT::f32 && Subtarget->hasSSE1()) ||
+                   (VT == MVT::v4f32 && Subtarget->hasSSE1()) ||
+                   (VT == MVT::v8f32 && Subtarget->hasFp256()) ||
+                   (VT == MVT::v16f32 && Subtarget->hasAVX512());
+  if (!Supported)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc dl(Op);
+  EVT ScalarVT = VT.getScalarType();
+
+  // Newton-Raphson constants, splatted across all lanes for vector types.
+  SDValue FPThree = DAG.getConstantFP(-3.0, ScalarVT);
+  SDValue FPHalf = DAG.getConstantFP(-0.5, ScalarVT);
+  if (VT.isVector()) {
+    FPThree = getSplatVectorFP(VT, FPThree, Subtarget, DAG, dl);
+    FPHalf = getSplatVectorFP(VT, FPHalf, Subtarget, DAG, dl);
+  }
+
+  // Initial estimate Est ~= 1.0 / sqrt(Op).
+  SDValue Est = DAG.getNode(X86ISD::FRSQRT, dl, VT, Op);
+  DCI.AddToWorklist(Est.getNode());
+
+  // E1 = Op * Est * Est;  E2 = E1 - 3.0;  E3 = -0.5 * Est.
+  SDValue E0 = DAG.getNode(ISD::FMUL, dl, VT, Est, Op);
+  DCI.AddToWorklist(E0.getNode());
+  SDValue E1 = DAG.getNode(ISD::FMUL, dl, VT, E0, Est);
+  DCI.AddToWorklist(E1.getNode());
+  SDValue E2 = DAG.getNode(ISD::FADD, dl, VT, E1, FPThree);
+  DCI.AddToWorklist(E2.getNode());
+  SDValue E3 = DAG.getNode(ISD::FMUL, dl, VT, FPHalf, Est);
+  DCI.AddToWorklist(E3.getNode());
+  // Refined estimate: E2 * E3 = -0.5 * Est * (Op * Est * Est - 3.0).
+  SDValue Refined = DAG.getNode(ISD::FMUL, dl, VT, E2, E3);
+  DCI.AddToWorklist(Refined.getNode());
+  return Refined;
+}
+
+/// PerformFDIVCombine - Do target-specific dag update on floating point divs.
+///
+/// When unsafe FP math is enabled, a division whose denominator is a square
+/// root is rewritten as a multiplication by an estimated reciprocal square
+/// root:
+///   Y / sqrt(X)               -->  Y * rsqrt(X)
+///   Y / fp_extend(sqrt(X))    -->  Y * fp_extend(rsqrt(X))
+///   Y / fp_round(sqrt(X), T)  -->  Y * fp_round(rsqrt(X), T)
+/// rsqrt(X) is built by PerformFastRecipFSQRT; if that declines (for example
+/// for an unsupported type), no combine is performed.
+static SDValue PerformFDIVCombine(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const X86Subtarget *Subtarget) {
+  // These tricks will reduce precision.
+  if (!DAG.getTarget().Options.UnsafeFPMath)
+    return SDValue();
+
+  SDValue Numerator = N->getOperand(0);
+  SDValue Denominator = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+
+  // Look through a single FP_EXTEND / FP_ROUND wrapped around the sqrt; any
+  // other kind of denominator is left alone.
+  unsigned ConvOpc = Denominator.getOpcode();
+  bool HasConv = ConvOpc == ISD::FP_EXTEND || ConvOpc == ISD::FP_ROUND;
+  SDValue Sqrt = HasConv ? Denominator.getOperand(0) : Denominator;
+  if (Sqrt.getOpcode() != ISD::FSQRT)
+    return SDValue();
+
+  // Build the refined reciprocal square root of the sqrt's operand.
+  SDValue RV =
+      PerformFastRecipFSQRT(Sqrt.getOperand(0), DCI, Subtarget);
+  if (!RV.getNode())
+    return SDValue();
+  DCI.AddToWorklist(RV.getNode());
+
+  // Re-apply the conversion that wrapped the sqrt so that RV has type VT.
+  if (ConvOpc == ISD::FP_EXTEND) {
+    RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(Denominator), VT, RV);
+    DCI.AddToWorklist(RV.getNode());
+  } else if (ConvOpc == ISD::FP_ROUND) {
+    // Forward FP_ROUND's second operand (its TRUNC flag) unchanged.
+    RV = DAG.getNode(ISD::FP_ROUND, SDLoc(Denominator), VT, RV,
+                     Denominator.getOperand(1));
+    DCI.AddToWorklist(RV.getNode());
+  }
+
+  // Y / sqrt(X) == Y * (1.0 / sqrt(X)).
+  return DAG.getNode(ISD::FMUL, SDLoc(N), VT, Numerator, RV);
+}
+
 /// PerformFORCombine - Do target-specific dag combines on X86ISD::FOR and
 /// X86ISD::FXOR nodes.
 static SDValue PerformFORCombine(SDNode *N, SelectionDAG &DAG) {
@@ -19142,6 +19274,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SINT_TO_FP:     return PerformSINT_TO_FPCombine(N, DAG, this);
   case ISD::FADD:           return PerformFADDCombine(N, DAG, Subtarget);
   case ISD::FSUB:           return PerformFSUBCombine(N, DAG, Subtarget);
+  case ISD::FDIV:           return PerformFDIVCombine(N, DAG, DCI, Subtarget);
   case X86ISD::FXOR:
   case X86ISD::FOR:         return PerformFORCombine(N, DAG);
   case X86ISD::FMIN:
diff --git a/test/CodeGen/X86/rsqrt-fastmath.ll b/test/CodeGen/X86/rsqrt-fastmath.ll
new file mode 100644
index 0000000..5f393e9
--- /dev/null
+++ b/test/CodeGen/X86/rsqrt-fastmath.ll
@@ -0,0 +1,64 @@
+; RUN: llc < %s -mcpu=core2 | FileCheck %s
+
+; generated using "clang -S -O2 -ffast-math -emit-llvm sqrt.c" from
+; #include <math.h>
+;
+; double fd(double d) {
+;   return 1.0 / sqrt(d);
+; }
+;
+; float ff(float f) {
+;   return 1.0f / sqrtf(f);
+; }
+;
+; long double fld(long double ld) {
+;   return 1.0 / sqrtl(ld);
+; }
+;
+; Tests that, with fast-math, a single-precision square root in the denominator
+; of a floating-point divide becomes rsqrtss plus a Newton-Raphson step (and no
+; divss), while the double and x86_fp80 versions keep their sqrt instructions.
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; CHECK-LABEL: fd:
+define double @fd(double %d) #0 {
+entry:
+; CHECK: sqrtsd
+; CHECK: divsd
+  %0 = tail call double @llvm.sqrt.f64(double %d)
+  %div = fdiv fast double 1.000000e+00, %0
+  ret double %div
+}
+
+declare double @llvm.sqrt.f64(double) #1
+
+; CHECK-LABEL: ff:
+define float @ff(float %f) #0 {
+entry:
+; CHECK-NOT: {{[[:space:]]sqrtss}}
+; CHECK: rsqrtss
+; CHECK-NOT: {{[[:space:]]sqrtss}}
+; CHECK-NOT: divss
+  %0 = tail call float @llvm.sqrt.f32(float %f)
+  %div = fdiv fast float 1.000000e+00, %0
+  ret float %div
+}
+
+declare float @llvm.sqrt.f32(float) #1
+
+; CHECK-LABEL: fld:
+define x86_fp80 @fld(x86_fp80 %ld) #0 {
+entry:
+; CHECK: fsqrt
+  %0 = tail call x86_fp80 @llvm.sqrt.f80(x86_fp80 %ld)
+  %div = fdiv fast x86_fp80 0xK3FFF8000000000000000, %0
+  ret x86_fp80 %div
+}
+
+declare x86_fp80 @llvm.sqrt.f80(x86_fp80) #1
+
+attributes #0 = { nounwind readnone uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
+attributes #1 = { nounwind readnone "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
+attributes #2 = { nounwind readnone }
-- 
1.9.0




More information about the llvm-commits mailing list