[PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)
Steven Noonan
steven at uplinklabs.net
Sat Mar 1 23:52:39 PST 2014
In an application that uses inverse square roots for distance calculations,
I noticed that LLVM and GCC emit different code, with GCC's being the faster of
the two. One of the most apparent differences was that LLVM has no fast lowering
for 1.0f / sqrtf(x): while GCC emits a reciprocal square root estimate followed
by a Newton-Raphson iteration, LLVM emits the much more straightforward (but
also much slower) sqrt followed by a div.
This change adds an X86 DAG combine for FDIV that, when unsafe FP math is
enabled, rewrites y / sqrtf(x) as y * rsqrt(x), where the rsqrtss/rsqrtps
estimate is refined by one Newton-Raphson iteration.
Signed-off-by: Steven Noonan <steven at uplinklabs.net>
---
lib/Target/X86/X86ISelLowering.cpp | 133 +++++++++++++++++++++++++++++++++++++
test/CodeGen/X86/rsqrt-fastmath.ll | 64 ++++++++++++++++++
2 files changed, 197 insertions(+)
create mode 100644 test/CodeGen/X86/rsqrt-fastmath.ll
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 76eeb64..d47a183 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -1513,6 +1513,7 @@ void X86TargetLowering::resetOperationActions() {
setTargetDAGCombine(ISD::ADD);
setTargetDAGCombine(ISD::FADD);
setTargetDAGCombine(ISD::FSUB);
+ setTargetDAGCombine(ISD::FDIV);
setTargetDAGCombine(ISD::FMA);
setTargetDAGCombine(ISD::SUB);
setTargetDAGCombine(ISD::LOAD);
@@ -4751,6 +4752,33 @@ static SDValue getZeroVector(EVT VT, const X86Subtarget *Subtarget,
return DAG.getNode(ISD::BITCAST, dl, VT, Vec);
}
+/// getSplatVectorFP - Returns a BUILD_VECTOR whose f32 lanes all hold Cst,
+/// sized by VT's total bit width: 128 -> v4f32 (SSE), 256 -> v8f32 (AVX),
+/// 512 -> v16f32 (AVX-512). Subtarget is accepted but not consulted.
+static SDValue getSplatVectorFP(EVT VT, SDValue Cst,
+                                const X86Subtarget *Subtarget,
+                                SelectionDAG &DAG, SDLoc dl) {
+  assert(VT.isVector() && "Expected a vector type");
+  assert(Cst.getValueType() == MVT::f32 &&
+         "Expected Cst to be a single-precision floating-point value.");
+
+  MVT SplatVT;
+  switch (VT.getSizeInBits()) {
+  case 128: SplatVT = MVT::v4f32;  break;
+  case 256: SplatVT = MVT::v8f32;  break;
+  case 512: SplatVT = MVT::v16f32; break;
+  default:
+    llvm_unreachable("Unexpected vector type");
+  }
+
+  // Room for the widest (v16f32) splat; only the first NumLanes are used.
+  SDValue Lanes[16];
+  const unsigned NumLanes = SplatVT.getVectorNumElements();
+  for (unsigned Lane = 0; Lane != NumLanes; ++Lane)
+    Lanes[Lane] = Cst;
+  return DAG.getNode(ISD::BUILD_VECTOR, dl, SplatVT, Lanes, NumLanes);
+}
+
/// getOnesVector - Returns a vector of specified type with all bits set.
/// Always build ones vectors as <4 x i32> or <8 x i32>. For 256-bit types with
/// no AVX2 supprt, use two <4 x i32> inserted in a <8 x i32> appropriately.
@@ -18630,6 +18658,110 @@ static SDValue PerformFSUBCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// PerformFastRecipFSQRT - Do fast reciprocal floating-point square root:
+// returns an approximation of 1.0 / sqrt(Op) built from the X86ISD::FRSQRT
+// estimate (rsqrtss/rsqrtps) refined by one Newton-Raphson iteration,
+//   Est' = -0.5 * Est * (Op * Est * Est - 3.0),
+// or an empty SDValue if Op's type is not handled on this subtarget or
+// vector operations have already been legalized.
+// NOTE(review): for Op == 0, rsqrt gives +inf and this yields NaN, not +inf.
+static SDValue PerformFastRecipFSQRT(SDValue Op,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const X86Subtarget *Subtarget) {
+  if (DCI.isAfterLegalizeVectorOps())
+    return SDValue();
+
+  EVT VT = Op.getValueType();
+
+  // rsqrtss is an SSE1 instruction too: without SSE, f32 is handled by x87
+  // and there is no pattern to select X86ISD::FRSQRT for it.
+  if (!((VT == MVT::f32 && Subtarget->hasSSE1()) ||
+        (VT == MVT::v4f32 && Subtarget->hasSSE1()) ||
+        (VT == MVT::v8f32 && Subtarget->hasFp256()) ||
+        (VT == MVT::v16f32 && Subtarget->hasAVX512())))
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc dl(Op);
+
+  // Constants of the refinement step, splatted for vector types.
+  SDValue NegThree = DAG.getConstantFP(-3.0, VT.getScalarType());
+  if (VT.isVector())
+    NegThree = getSplatVectorFP(VT, NegThree, Subtarget, DAG, dl);
+  SDValue NegHalf = DAG.getConstantFP(-0.5, VT.getScalarType());
+  if (VT.isVector())
+    NegHalf = getSplatVectorFP(VT, NegHalf, Subtarget, DAG, dl);
+
+  SDValue Est = DAG.getNode(X86ISD::FRSQRT, dl, VT, Op);
+  DCI.AddToWorklist(Est.getNode());
+
+  // Op * Est * Est - 3.0
+  SDValue OpEst = DAG.getNode(ISD::FMUL, dl, VT, Est, Op);
+  DCI.AddToWorklist(OpEst.getNode());
+  SDValue OpEstSq = DAG.getNode(ISD::FMUL, dl, VT, OpEst, Est);
+  DCI.AddToWorklist(OpEstSq.getNode());
+  SDValue Diff = DAG.getNode(ISD::FADD, dl, VT, OpEstSq, NegThree);
+  DCI.AddToWorklist(Diff.getNode());
+
+  // (-0.5 * Est) * (Op * Est * Est - 3.0)
+  SDValue HalfEst = DAG.getNode(ISD::FMUL, dl, VT, NegHalf, Est);
+  DCI.AddToWorklist(HalfEst.getNode());
+  SDValue Refined = DAG.getNode(ISD::FMUL, dl, VT, Diff, HalfEst);
+  DCI.AddToWorklist(Refined.getNode());
+  return Refined;
+}
+
+/// PerformFDIVCombine - Under UnsafeFPMath, rewrite Num / sqrt(X) as
+/// Num * (fast reciprocal square root of X). A FP_EXTEND or FP_ROUND sitting
+/// between the FSQRT and the FDIV is re-applied to the estimate.
+static SDValue PerformFDIVCombine(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const X86Subtarget *Subtarget) {
+  // The estimate is less precise than a real sqrt followed by a divide.
+  if (!DAG.getTarget().Options.UnsafeFPMath)
+    return SDValue();
+
+  SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  SDValue Num = N->getOperand(0);
+  SDValue Den = N->getOperand(1);
+  unsigned DenOpc = Den.getOpcode();
+
+  // Find the FSQRT: either the denominator itself or the operand of a
+  // precision-changing conversion applied to it.
+  SDValue Sqrt;
+  if (DenOpc == ISD::FSQRT)
+    Sqrt = Den;
+  else if ((DenOpc == ISD::FP_EXTEND || DenOpc == ISD::FP_ROUND) &&
+           Den.getOperand(0).getOpcode() == ISD::FSQRT)
+    Sqrt = Den.getOperand(0);
+  else
+    return SDValue();
+
+  SDValue Recip = PerformFastRecipFSQRT(Sqrt.getOperand(0), DCI, Subtarget);
+  if (!Recip.getNode())
+    return SDValue();
+  DCI.AddToWorklist(Recip.getNode());
+
+  // Re-apply the conversion that sat between the sqrt and the divide.
+  switch (DenOpc) {
+  case ISD::FP_EXTEND:
+    Recip = DAG.getNode(ISD::FP_EXTEND, SDLoc(Den), VT, Recip);
+    DCI.AddToWorklist(Recip.getNode());
+    break;
+  case ISD::FP_ROUND:
+    Recip = DAG.getNode(ISD::FP_ROUND, SDLoc(Den), VT, Recip,
+                        Den.getOperand(1));
+    DCI.AddToWorklist(Recip.getNode());
+    break;
+  default:
+    break;
+  }
+
+  return DAG.getNode(ISD::FMUL, dl, VT, Num, Recip);
+}
+
/// PerformFORCombine - Do target-specific dag combines on X86ISD::FOR and
/// X86ISD::FXOR nodes.
static SDValue PerformFORCombine(SDNode *N, SelectionDAG &DAG) {
@@ -19142,6 +19274,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SINT_TO_FP: return PerformSINT_TO_FPCombine(N, DAG, this);
case ISD::FADD: return PerformFADDCombine(N, DAG, Subtarget);
case ISD::FSUB: return PerformFSUBCombine(N, DAG, Subtarget);
+ case ISD::FDIV: return PerformFDIVCombine(N, DAG, DCI, Subtarget);
case X86ISD::FXOR:
case X86ISD::FOR: return PerformFORCombine(N, DAG);
case X86ISD::FMIN:
diff --git a/test/CodeGen/X86/rsqrt-fastmath.ll b/test/CodeGen/X86/rsqrt-fastmath.ll
new file mode 100644
index 0000000..5f393e9
--- /dev/null
+++ b/test/CodeGen/X86/rsqrt-fastmath.ll
@@ -0,0 +1,83 @@
+; RUN: llc < %s -mcpu=core2 | FileCheck %s
+
+; Generated using "clang -S -O2 -ffast-math -emit-llvm rsqrt.c" from
+; #include <math.h>
+;
+; double fd(double d) {
+;   return 1.0 / sqrt(d);
+; }
+;
+; float ff(float f) {
+;   return 1.0f / sqrtf(f);
+; }
+;
+; long double fld(long double ld) {
+;   return 1.0 / sqrtl(ld);
+; }
+;
+; The <4 x float> function @rsqrt_v4f32 was written by hand to cover the
+; packed (rsqrtps) path.
+;
+; Tests that single-precision square roots (scalar and vector) in the
+; denominator of a floating-point divide are converted into reciprocal
+; square root estimates when unsafe FP math is in effect, while double and
+; x86_fp80 square roots are left alone.
+
+; ModuleID = 'rsqrt.c'
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; Function Attrs: nounwind readnone uwtable
+define double @fd(double %d) #0 {
+; CHECK-LABEL: fd:
+; CHECK: sqrtsd
+entry:
+  %0 = tail call double @llvm.sqrt.f64(double %d)
+  %div = fdiv fast double 1.000000e+00, %0
+  ret double %div
+}
+
+; Function Attrs: nounwind readonly
+declare double @llvm.sqrt.f64(double) #1
+
+; Function Attrs: nounwind readnone uwtable
+define float @ff(float %f) #0 {
+; CHECK-LABEL: ff:
+; CHECK: rsqrtss
+entry:
+  %0 = tail call float @llvm.sqrt.f32(float %f)
+  %div = fdiv fast float 1.000000e+00, %0
+  ret float %div
+}
+
+; Function Attrs: nounwind readonly
+declare float @llvm.sqrt.f32(float) #1
+
+; Function Attrs: nounwind readnone uwtable
+define <4 x float> @rsqrt_v4f32(<4 x float> %v) #0 {
+; CHECK-LABEL: rsqrt_v4f32:
+; CHECK: rsqrtps
+entry:
+  %0 = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %v)
+  %div = fdiv fast <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, %0
+  ret <4 x float> %div
+}
+
+; Function Attrs: nounwind readonly
+declare <4 x float> @llvm.sqrt.v4f32(<4 x float>) #1
+
+; Function Attrs: nounwind readnone uwtable
+define x86_fp80 @fld(x86_fp80 %ld) #0 {
+; CHECK-LABEL: fld:
+; CHECK: fsqrt
+entry:
+  %0 = tail call x86_fp80 @llvm.sqrt.f80(x86_fp80 %ld)
+  %div = fdiv fast x86_fp80 0xK3FFF8000000000000000, %0
+  ret x86_fp80 %div
+}
+
+; Function Attrs: nounwind readonly
+declare x86_fp80 @llvm.sqrt.f80(x86_fp80) #1
+
+attributes #0 = { nounwind readnone uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
+attributes #1 = { nounwind readnone "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
--
1.9.0
More information about the llvm-commits
mailing list