[PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)

Hal Finkel hfinkel at anl.gov
Sun Mar 2 11:40:59 PST 2014


----- Original Message -----
> From: "Steven Noonan" <steven at uplinklabs.net>
> To: "Hal Finkel" <hfinkel at anl.gov>
> Cc: llvm-commits at cs.uiuc.edu
> Sent: Sunday, March 2, 2014 1:25:17 PM
> Subject: Re: [PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)
> 
> On Sun, Mar 2, 2014 at 7:57 AM, Hal Finkel <hfinkel at anl.gov> wrote:
> > ----- Original Message -----
> >> From: "Steven Noonan" <steven at uplinklabs.net>
> >> To: llvm-commits at cs.uiuc.edu
> >> Sent: Sunday, March 2, 2014 1:52:39 AM
> >> Subject: [PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)
> >>
> >> In an application that uses inverse square root for distance
> >> calculations, I noticed that LLVM and GCC emit different code, with
> >> GCC's being the faster of the two. One of the most apparent
> >> differences was the lack of a lowering for 1.0f / sqrtf(x). While GCC
> >> was emitting a reciprocal square root estimate followed by a
> >> Newton-Raphson iteration, LLVM was emitting the much more
> >> straightforward (but also much slower) sqrt followed by div.
> >>
> >> This change rewrites y / sqrtf(x) as y * rsqrt(x), refined by one
> >> Newton-Raphson step, by adding an X86 FDIV DAG combine that is only
> >> enabled when unsafe math optimizations are enabled.
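
A minimal, portable C++ sketch of the arithmetic described above, following the same node order that PerformFastRecipFSQRT in the patch builds (E0 through E3). Because rsqrtss itself is not portable, it assumes a stand-in for the hardware estimate: `crude_rsqrt_estimate` (a hypothetical helper) computes the exact reciprocal square root and clears its low 11 mantissa bits, which stays within the 1.5 * 2^-12 relative-error bound Intel documents for rsqrtss. `refine_rsqrt` and `fast_div_sqrt` are likewise illustrative names, not part of the patch.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for the rsqrtss estimate: the exact 1/sqrt(x) with
// the low 11 of its 23 mantissa bits cleared, leaving roughly 12 good bits.
float crude_rsqrt_estimate(float x) {
  assert(x > 0.0f && "only positive inputs are modeled");
  float est = 1.0f / std::sqrt(x);
  std::uint32_t bits;
  std::memcpy(&bits, &est, sizeof bits);
  bits &= ~((std::uint32_t{1} << 11) - 1u);  // truncate low mantissa bits
  std::memcpy(&est, &bits, sizeof est);
  return est;
}

// One Newton-Raphson step, evaluated in the same order as the patch's nodes:
//   E0 = Est * x, E1 = E0 * Est, E2 = E1 + (-3.0), E3 = (-0.5) * Est,
//   Est' = E2 * E3, which equals 0.5 * Est * (3 - x * Est * Est).
float refine_rsqrt(float x, float est) {
  const float e0 = est * x;
  const float e1 = e0 * est;
  const float e2 = e1 + -3.0f;
  const float e3 = -0.5f * est;
  return e2 * e3;
}

// y / sqrt(x) rewritten as y * rsqrt(x), the same shape the FDIV combine
// produces.
float fast_div_sqrt(float y, float x) {
  return y * refine_rsqrt(x, crude_rsqrt_estimate(x));
}
```

For example, `fast_div_sqrt(3.0f, 4.0f)` returns exactly 1.5f. In general one step roughly squares the estimate's relative error e (the new error is about 1.5 * e^2), so an error of at most about 2.4e-4 drops to the order of 1e-7, close to full single precision.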
> >>
> >> Signed-off-by: Steven Noonan <steven at uplinklabs.net>
> >
> > I wrote very similar code in the PowerPC backend. Do you think we
> > could refactor these into a common utility? There are several
> > target architectures that have reciprocal estimate functions (X86,
> > PowerPC, ARM, etc.)
> >
> >  -Hal
> 
> Yeah, I was actually going to target ARM with a change like this in a
> while. I agree in principle that it should be a common utility, but in
> practice I'm not sure where that would sit or what it would look like.
> Can we collaborate on this?

Yes. The PowerPC implementation can generate a variable number of iterations, and I think it could serve as the base for this. Can you look at the implementation and note any roadblocks you see to generalizing it to handle X86 (and ARM) as well? I think that the common code should live in DAGCombine, and we might need to add some new callbacks (and maybe a pair of new ISD nodes?)
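
The shape described here, a per-target estimate hook plus shared refinement code with a target-chosen iteration count, can be modeled in plain C++. This is a hypothetical sketch, not LLVM's API: `RsqrtEstimateHook`, `buildRsqrt`, `truncatedRsqrt`, and the example targets are invented for illustration, and each estimate is simulated by truncating mantissa bits of the exact result.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Simulates a hardware estimate: exact 1/sqrt(x) with the low `dropBits`
// mantissa bits cleared. Purely illustrative.
inline float truncatedRsqrt(float x, int dropBits) {
  assert(x > 0.0f && dropBits >= 0 && dropBits < 23);
  float est = 1.0f / std::sqrt(x);
  std::uint32_t bits;
  std::memcpy(&bits, &est, sizeof bits);
  bits &= ~((std::uint32_t{1} << dropBits) - 1u);
  std::memcpy(&est, &bits, sizeof est);
  return est;
}

// Hypothetical per-target hook; not an existing LLVM interface.
struct RsqrtEstimateHook {
  virtual ~RsqrtEstimateHook() = default;
  virtual bool hasEstimate() const = 0;       // any estimate instruction?
  virtual float estimate(float x) const = 0;  // raw estimate of 1/sqrt(x)
  virtual int refinementSteps() const = 0;    // Newton-Raphson steps wanted
};

// Shared, target-independent part: refine the target's estimate as many
// times as the target asks. Returns false if the target has no estimate,
// in which case the caller keeps sqrt followed by div.
bool buildRsqrt(const RsqrtEstimateHook &hook, float x, float &result) {
  if (!hook.hasEstimate())
    return false;
  float est = hook.estimate(x);
  for (int i = 0; i < hook.refinementSteps(); ++i)
    est = 0.5f * est * (3.0f - x * est * est);  // est' = est*(3 - x*est^2)/2
  result = est;
  return true;
}

// Roughly 12-bit estimate (like rsqrtss): one step suffices.
struct TwelveBitTarget : RsqrtEstimateHook {
  bool hasEstimate() const override { return true; }
  float estimate(float x) const override { return truncatedRsqrt(x, 11); }
  int refinementSteps() const override { return 1; }
};

// Hypothetical coarse ~5-bit estimate: needs three steps for float.
struct CoarseTarget : RsqrtEstimateHook {
  bool hasEstimate() const override { return true; }
  float estimate(float x) const override { return truncatedRsqrt(x, 18); }
  int refinementSteps() const override { return 3; }
};

// Target without an estimate instruction.
struct NoEstimateTarget : RsqrtEstimateHook {
  bool hasEstimate() const override { return false; }
  float estimate(float) const override { return 0.0f; }
  int refinementSteps() const override { return 0; }
};
```

Keeping the step count in the hook lets a target with a coarse estimate ask for more iterations while one with a roughly 12-bit estimate, such as rsqrtss, asks for one; returning false lets the shared code keep the plain sqrt followed by div.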

Thanks again,
Hal

> 
> >
> >> ---
> >>  lib/Target/X86/X86ISelLowering.cpp | 133 +++++++++++++++++++++++++++++++++++++
> >>  test/CodeGen/X86/rsqrt-fastmath.ll |  64 ++++++++++++++++++
> >>  2 files changed, 197 insertions(+)
> >>  create mode 100644 test/CodeGen/X86/rsqrt-fastmath.ll
> >>
> >> diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
> >> index 76eeb64..d47a183 100644
> >> --- a/lib/Target/X86/X86ISelLowering.cpp
> >> +++ b/lib/Target/X86/X86ISelLowering.cpp
> >> @@ -1513,6 +1513,7 @@ void X86TargetLowering::resetOperationActions() {
> >>    setTargetDAGCombine(ISD::ADD);
> >>    setTargetDAGCombine(ISD::FADD);
> >>    setTargetDAGCombine(ISD::FSUB);
> >> +  setTargetDAGCombine(ISD::FDIV);
> >>    setTargetDAGCombine(ISD::FMA);
> >>    setTargetDAGCombine(ISD::SUB);
> >>    setTargetDAGCombine(ISD::LOAD);
> >> @@ -4751,6 +4752,33 @@ static SDValue getZeroVector(EVT VT, const X86Subtarget *Subtarget,
> >>    return DAG.getNode(ISD::BITCAST, dl, VT, Vec);
> >>  }
> >>
> >> +/// getSplatVectorFP - Returns a single-precision floating-point vector with
> >> +/// all elements set to the same value.
> >> +static SDValue getSplatVectorFP(EVT VT, SDValue Cst,
> >> +                                const X86Subtarget *Subtarget,
> >> +                                SelectionDAG &DAG, SDLoc dl) {
> >> +  assert(VT.isVector() && "Expected a vector type");
> >> +  assert(Cst.getValueType() == MVT::f32 &&
> >> +         "Expected Cst to be a single-precision floating-point value.");
> >> +
> >> +  SDValue Vec;
> >> +  if (VT.is128BitVector()) {  // SSE
> >> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v4f32, Cst, Cst, Cst, Cst);
> >> +  } else if (VT.is256BitVector()) { // AVX
> >> +    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
> >> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v8f32, Ops,
> >> +                      array_lengthof(Ops));
> >> +  } else if (VT.is512BitVector()) { // AVX-512
> >> +    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst,
> >> +                      Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
> >> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v16f32, Ops,
> >> +        array_lengthof(Ops));
> >> +  } else
> >> +    llvm_unreachable("Unexpected vector type");
> >> +
> >> +  return Vec;
> >> +}
> >> +
> >>  /// getOnesVector - Returns a vector of specified type with all bits set.
> >>  /// Always build ones vectors as <4 x i32> or <8 x i32>. For 256-bit types with
> >>  /// no AVX2 supprt, use two <4 x i32> inserted in a <8 x i32> appropriately.
> >> @@ -18630,6 +18658,110 @@ static SDValue PerformFSUBCombine(SDNode *N, SelectionDAG &DAG,
> >>    return SDValue();
> >>  }
> >>
> >> +// PerformFastRecipFSQRT - Do fast reciprocal floating-point square root
> >> +static SDValue PerformFastRecipFSQRT(SDValue Op,
> >> +                                     TargetLowering::DAGCombinerInfo &DCI,
> >> +                                     const X86Subtarget *Subtarget) {
> >> +  if (DCI.isAfterLegalizeVectorOps())
> >> +    return SDValue();
> >> +
> >> +  EVT VT = Op.getValueType();
> >> +
> >> +  if ((VT == MVT::f32 && Subtarget->hasSSE1()) ||
> >> +      (VT == MVT::v4f32 && Subtarget->hasSSE1()) ||
> >> +      (VT == MVT::v8f32 && Subtarget->hasFp256()) ||
> >> +      (VT == MVT::v16f32 && Subtarget->hasAVX512()) ) {
> >> +
> >> +    SelectionDAG &DAG = DCI.DAG;
> >> +    SDLoc dl(Op);
> >> +
> >> +    SDValue FPThree =
> >> +      DAG.getConstantFP(-3.0, VT.getScalarType());
> >> +    if (VT.isVector()) {
> >> +      FPThree = getSplatVectorFP(VT, FPThree, Subtarget, DAG, dl);
> >> +    }
> >> +
> >> +    SDValue FPHalf =
> >> +      DAG.getConstantFP(-0.5, VT.getScalarType());
> >> +    if (VT.isVector()) {
> >> +      FPHalf = getSplatVectorFP(VT, FPHalf, Subtarget, DAG, dl);
> >> +    }
> >> +
> >> +    SDValue Est = DAG.getNode(X86ISD::FRSQRT, dl, VT, Op);
> >> +    DCI.AddToWorklist(Est.getNode());
> >> +
> >> +    SDValue E0 = DAG.getNode(ISD::FMUL, dl, VT, Est, Op);
> >> +    DCI.AddToWorklist(E0.getNode());
> >> +
> >> +    SDValue E1 = DAG.getNode(ISD::FMUL, dl, VT, E0, Est);
> >> +    DCI.AddToWorklist(E1.getNode());
> >> +
> >> +    SDValue E2 = DAG.getNode(ISD::FADD, dl, VT, E1, FPThree);
> >> +    DCI.AddToWorklist(E2.getNode());
> >> +
> >> +    SDValue E3 = DAG.getNode(ISD::FMUL, dl, VT, FPHalf, Est);
> >> +    DCI.AddToWorklist(E3.getNode());
> >> +
> >> +    Est = DAG.getNode(ISD::FMUL, dl, VT, E2, E3);
> >> +    DCI.AddToWorklist(Est.getNode());
> >> +
> >> +    return Est;
> >> +  }
> >> +
> >> +  return SDValue();
> >> +}
> >> +
> >> +/// PerformFDIVCombine - Do target-specific dag update on floating point divs.
> >> +static SDValue PerformFDIVCombine(SDNode *N, SelectionDAG &DAG,
> >> +                                  TargetLowering::DAGCombinerInfo &DCI,
> >> +                                  const X86Subtarget *Subtarget) {
> >> +
> >> +  // These tricks will reduce precision.
> >> +  if (!DAG.getTarget().Options.UnsafeFPMath)
> >> +    return SDValue();
> >> +
> >> +  SDLoc dl(N);
> >> +
> >> +  if (N->getOperand(1).getOpcode() == ISD::FSQRT) {
> >> +    SDValue RV =
> >> +      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0), DCI, Subtarget);
> >> +    if (RV.getNode() != 0) {
> >> +      DCI.AddToWorklist(RV.getNode());
> >> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
> >> +                         N->getOperand(0), RV);
> >> +    }
> >> +  } else if (N->getOperand(1).getOpcode() == ISD::FP_EXTEND &&
> >> +             N->getOperand(1).getOperand(0).getOpcode() == ISD::FSQRT) {
> >> +    SDValue RV =
> >> +      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
> >> +                            DCI, Subtarget);
> >> +    if (RV.getNode() != 0) {
> >> +      DCI.AddToWorklist(RV.getNode());
> >> +      RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N->getOperand(1)),
> >> +                       N->getValueType(0), RV);
> >> +      DCI.AddToWorklist(RV.getNode());
> >> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
> >> +                         N->getOperand(0), RV);
> >> +    }
> >> +  } else if (N->getOperand(1).getOpcode() == ISD::FP_ROUND &&
> >> +             N->getOperand(1).getOperand(0).getOpcode() == ISD::FSQRT) {
> >> +    SDValue RV =
> >> +      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
> >> +                            DCI, Subtarget);
> >> +    if (RV.getNode() != 0) {
> >> +      DCI.AddToWorklist(RV.getNode());
> >> +      RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N->getOperand(1)),
> >> +                       N->getValueType(0), RV,
> >> +                       N->getOperand(1).getOperand(1));
> >> +      DCI.AddToWorklist(RV.getNode());
> >> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
> >> +                         N->getOperand(0), RV);
> >> +    }
> >> +  }
> >> +
> >> +  return SDValue();
> >> +}
> >> +
> >>  /// PerformFORCombine - Do target-specific dag combines on X86ISD::FOR and
> >>  /// X86ISD::FXOR nodes.
> >>  static SDValue PerformFORCombine(SDNode *N, SelectionDAG &DAG) {
> >> @@ -19142,6 +19274,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
> >>    case ISD::SINT_TO_FP:     return PerformSINT_TO_FPCombine(N, DAG, this);
> >>    case ISD::FADD:           return PerformFADDCombine(N, DAG, Subtarget);
> >>    case ISD::FSUB:           return PerformFSUBCombine(N, DAG, Subtarget);
> >> +  case ISD::FDIV:           return PerformFDIVCombine(N, DAG, DCI, Subtarget);
> >>    case X86ISD::FXOR:
> >>    case X86ISD::FOR:         return PerformFORCombine(N, DAG);
> >>    case X86ISD::FMIN:
> >> diff --git a/test/CodeGen/X86/rsqrt-fastmath.ll b/test/CodeGen/X86/rsqrt-fastmath.ll
> >> new file mode 100644
> >> index 0000000..5f393e9
> >> --- /dev/null
> >> +++ b/test/CodeGen/X86/rsqrt-fastmath.ll
> >> @@ -0,0 +1,64 @@
> >> +; RUN: llc < %s -mcpu=core2 | FileCheck %s
> >> +
> >> +; generated using "clang -S -O2 -ffast-math -emit-llvm sqrt.c" from
> >> +; #include <math.h>
> >> +;
> >> +; double fd(double d) {
> >> +;   return 1.0 / sqrt(d);
> >> +; }
> >> +;
> >> +; float ff(float f) {
> >> +;   return 1.0f / sqrtf(f);
> >> +; }
> >> +;
> >> +; long double fld(long double ld) {
> >> +;   return 1.0 / sqrtl(ld);
> >> +; }
> >> +;
> >> +; Tests conversion of single-precision floating-point square root instructions
> >> +; in the denominator of a floating point divide into rsqrt instructions when
> >> +; -ffast-math is in effect.
> >> +
> >> +; ModuleID = 'rsqrt.c'
> >> +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
> >> +target triple = "x86_64-unknown-linux-gnu"
> >> +
> >> +; Function Attrs: nounwind readnone uwtable
> >> +define double @fd(double %d) #0 {
> >> +entry:
> >> +; CHECK: sqrtsd
> >> +  %0 = tail call double @llvm.sqrt.f64(double %d)
> >> +  %div = fdiv fast double 1.000000e+00, %0
> >> +  ret double %div
> >> +}
> >> +
> >> +; Function Attrs: nounwind readonly
> >> +declare double @llvm.sqrt.f64(double) #1
> >> +
> >> +; Function Attrs: nounwind readnone uwtable
> >> +define float @ff(float %f) #0 {
> >> +entry:
> >> +; CHECK: rsqrtss
> >> +  %0 = tail call float @llvm.sqrt.f32(float %f)
> >> +  %div = fdiv fast float 1.000000e+00, %0
> >> +  ret float %div
> >> +}
> >> +
> >> +; Function Attrs: nounwind readonly
> >> +declare float @llvm.sqrt.f32(float) #1
> >> +
> >> +; Function Attrs: nounwind readnone uwtable
> >> +define x86_fp80 @fld(x86_fp80 %ld) #0 {
> >> +entry:
> >> +; CHECK: fsqrt
> >> +  %0 = tail call x86_fp80 @llvm.sqrt.f80(x86_fp80 %ld)
> >> +  %div = fdiv fast x86_fp80 0xK3FFF8000000000000000, %0
> >> +  ret x86_fp80 %div
> >> +}
> >> +
> >> +; Function Attrs: nounwind readonly
> >> +declare x86_fp80 @llvm.sqrt.f80(x86_fp80) #1
> >> +
> >> +attributes #0 = { nounwind readnone uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
> >> +attributes #1 = { nounwind readnone "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
> >> +attributes #2 = { nounwind readnone }
> >> --
> >> 1.9.0
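
The divisor shapes PerformFDIVCombine matches, together with the type gate in PerformFastRecipFSQRT, can be modeled as a tiny expression-tree check in C++. `Opc`, `Ty`, `Node`, `make`, `leaf`, and `combinesToRsqrt` are hypothetical stand-ins rather than LLVM types, only a subset of the patch's types is modeled, and the required subtarget features are assumed present. One consequence the model makes visible: the estimate is requested for the type of the square root's input, so in the FP_ROUND shape (for example `1.0f / (float)sqrt(d)` with a double `d`) that input is f64 and, as far as the checks in the patch go, no rsqrt is formed.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical, simplified stand-ins for SelectionDAG opcodes and types.
enum class Opc { FSQRT, FP_EXTEND, FP_ROUND, Other };
enum class Ty { f32, f64, f80, v4f32, v8f32, v4f64 };

struct Node {
  Opc opc;
  Ty ty;                          // result type of this node
  std::shared_ptr<Node> operand;  // first operand, or null for a leaf
};
using NodePtr = std::shared_ptr<Node>;

NodePtr leaf(Ty ty) { return std::make_shared<Node>(Node{Opc::Other, ty, nullptr}); }

NodePtr make(Opc opc, Ty ty, NodePtr op) {
  assert(op && "make() expects an operand");
  return std::make_shared<Node>(Node{opc, ty, std::move(op)});
}

// Mirrors the type test in PerformFastRecipFSQRT for the modeled types,
// assuming SSE1/AVX are available: only f32 element types get an estimate.
bool hasRsqrtEstimate(Ty ty) {
  return ty == Ty::f32 || ty == Ty::v4f32 || ty == Ty::v8f32;
}

// Mirrors the three branches of PerformFDIVCombine: the divisor is FSQRT(x)
// or an FP_EXTEND/FP_ROUND around FSQRT(x); the estimate is built for x.
bool combinesToRsqrt(const Node &divisor) {
  const Node *sqrtNode = nullptr;
  if (divisor.opc == Opc::FSQRT)
    sqrtNode = &divisor;
  else if ((divisor.opc == Opc::FP_EXTEND || divisor.opc == Opc::FP_ROUND) &&
           divisor.operand && divisor.operand->opc == Opc::FSQRT)
    sqrtNode = divisor.operand.get();
  if (sqrtNode == nullptr || !sqrtNode->operand)
    return false;
  return hasRsqrtEstimate(sqrtNode->operand->ty);
}
```

The f64 case returning false matches the test's expectation that `fd` still emits sqrtsd.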
> >>
> >> _______________________________________________
> >> llvm-commits mailing list
> >> llvm-commits at cs.uiuc.edu
> >> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits
> >>
> >
> > --
> > Hal Finkel
> > Assistant Computational Scientist
> > Leadership Computing Facility
> > Argonne National Laboratory
> 



