[PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)

Steven Noonan steven at uplinklabs.net
Sun Mar 2 11:25:17 PST 2014


On Sun, Mar 2, 2014 at 7:57 AM, Hal Finkel <hfinkel at anl.gov> wrote:
> ----- Original Message -----
>> From: "Steven Noonan" <steven at uplinklabs.net>
>> To: llvm-commits at cs.uiuc.edu
>> Sent: Sunday, March 2, 2014 1:52:39 AM
>> Subject: [PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)
>>
>> In an application that uses inverse square roots for distance
>> calculations, I noticed that LLVM and GCC emit different code, with GCC
>> being the faster of the two. One of the most apparent differences was
>> the lack of a lowering for 1.0f / sqrtf(x). While GCC was emitting a
>> reciprocal square root estimate followed by a Newton-Raphson iteration,
>> LLVM was emitting the much more straightforward (but also much slower)
>> sqrt followed by a div.
>>
>> This change adds a y / sqrtf(x) combine: when unsafe math optimizations
>> are enabled, an FDIV whose divisor is a square root is rewritten as y
>> multiplied by a reciprocal square root estimate refined with one
>> Newton-Raphson step.
>>
>> Signed-off-by: Steven Noonan <steven at uplinklabs.net>
>
> I wrote very similar code in the PowerPC backend. Do you think we could refactor these into a common utility? There are several target architectures that have reciprocal estimate functions (X86, PowerPC, ARM, etc.)
>
>  -Hal

Yeah, I was actually planning to make a similar change for ARM at some
point. I agree in principle that this should be a common utility, but in
practice I'm not sure where it would live or what it would look like.
Can we collaborate on this?

>
>> ---
>>  lib/Target/X86/X86ISelLowering.cpp | 133
>>  +++++++++++++++++++++++++++++++++++++
>>  test/CodeGen/X86/rsqrt-fastmath.ll |  64 ++++++++++++++++++
>>  2 files changed, 197 insertions(+)
>>  create mode 100644 test/CodeGen/X86/rsqrt-fastmath.ll
>>
>> diff --git a/lib/Target/X86/X86ISelLowering.cpp
>> b/lib/Target/X86/X86ISelLowering.cpp
>> index 76eeb64..d47a183 100644
>> --- a/lib/Target/X86/X86ISelLowering.cpp
>> +++ b/lib/Target/X86/X86ISelLowering.cpp
>> @@ -1513,6 +1513,7 @@ void X86TargetLowering::resetOperationActions()
>> {
>>    setTargetDAGCombine(ISD::ADD);
>>    setTargetDAGCombine(ISD::FADD);
>>    setTargetDAGCombine(ISD::FSUB);
>> +  setTargetDAGCombine(ISD::FDIV);
>>    setTargetDAGCombine(ISD::FMA);
>>    setTargetDAGCombine(ISD::SUB);
>>    setTargetDAGCombine(ISD::LOAD);
>> @@ -4751,6 +4752,33 @@ static SDValue getZeroVector(EVT VT, const
>> X86Subtarget *Subtarget,
>>    return DAG.getNode(ISD::BITCAST, dl, VT, Vec);
>>  }
>>
>> +/// getSplatVectorFP - Returns a single-precision floating-point
>> vector with
>> +/// all elements set to the same value.
>> +static SDValue getSplatVectorFP(EVT VT, SDValue Cst,
>> +                                const X86Subtarget *Subtarget,
>> +                                SelectionDAG &DAG, SDLoc dl) {
>> +  assert(VT.isVector() && "Expected a vector type");
>> +  assert(Cst.getValueType() == MVT::f32 &&
>> +         "Expected Cst to be a single-precision floating-point
>> value.");
>> +
>> +  SDValue Vec;
>> +  if (VT.is128BitVector()) {  // SSE
>> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v4f32, Cst, Cst,
>> Cst, Cst);
>> +  } else if (VT.is256BitVector()) { // AVX
>> +    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
>> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v8f32, Ops,
>> +                      array_lengthof(Ops));
>> +  } else if (VT.is512BitVector()) { // AVX-512
>> +    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst,
>> +                      Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
>> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v16f32, Ops,
>> +        array_lengthof(Ops));
>> +  } else
>> +    llvm_unreachable("Unexpected vector type");
>> +
>> +  return Vec;
>> +}
>> +
>>  /// getOnesVector - Returns a vector of specified type with all bits
>>  set.
>>  /// Always build ones vectors as <4 x i32> or <8 x i32>. For 256-bit
>>  types with
>>  /// no AVX2 supprt, use two <4 x i32> inserted in a <8 x i32>
>>  appropriately.
>> @@ -18630,6 +18658,110 @@ static SDValue PerformFSUBCombine(SDNode
>> *N, SelectionDAG &DAG,
>>    return SDValue();
>>  }
>>
>> +// PerformFastRecipFSQRT - Do fast reciprical floating-point square
>> root
>> +static SDValue PerformFastRecipFSQRT(SDValue Op,
>> +                                     TargetLowering::DAGCombinerInfo
>> &DCI,
>> +                                     const X86Subtarget *Subtarget)
>> {
>> +  if (DCI.isAfterLegalizeVectorOps())
>> +    return SDValue();
>> +
>> +  EVT VT = Op.getValueType();
>> +
>> +  if (VT == MVT::f32 ||
>> +      (VT == MVT::v4f32 && Subtarget->hasSSE1()) ||
>> +      (VT == MVT::v8f32 && Subtarget->hasFp256()) ||
>> +      (VT == MVT::v16f32 && Subtarget->hasAVX512()) ) {
>> +
>> +    SelectionDAG &DAG = DCI.DAG;
>> +    SDLoc dl(Op);
>> +
>> +    SDValue FPThree =
>> +      DAG.getConstantFP(-3.0, VT.getScalarType());
>> +    if (VT.isVector()) {
>> +      FPThree = getSplatVectorFP(VT, FPThree, Subtarget, DAG, dl);
>> +    }
>> +
>> +    SDValue FPHalf =
>> +      DAG.getConstantFP(-0.5, VT.getScalarType());
>> +    if (VT.isVector()) {
>> +      FPHalf = getSplatVectorFP(VT, FPHalf, Subtarget, DAG, dl);
>> +    }
>> +
>> +    SDValue Est = DAG.getNode(X86ISD::FRSQRT, dl, VT, Op);
>> +    DCI.AddToWorklist(Est.getNode());
>> +
>> +    SDValue E0 = DAG.getNode(ISD::FMUL, dl, VT, Est, Op);
>> +    DCI.AddToWorklist(E0.getNode());
>> +
>> +    SDValue E1 = DAG.getNode(ISD::FMUL, dl, VT, E0, Est);
>> +    DCI.AddToWorklist(E1.getNode());
>> +
>> +    SDValue E2 = DAG.getNode(ISD::FADD, dl, VT, E1, FPThree);
>> +    DCI.AddToWorklist(E2.getNode());
>> +
>> +    SDValue E3 = DAG.getNode(ISD::FMUL, dl, VT, FPHalf, Est);
>> +    DCI.AddToWorklist(E3.getNode());
>> +
>> +    Est = DAG.getNode(ISD::FMUL, dl, VT, E2, E3);
>> +    DCI.AddToWorklist(Est.getNode());
>> +
>> +    return Est;
>> +  }
>> +
>> +  return SDValue();
>> +}
>> +
>> +/// PerformFDIVCombine - Do target-specific dag update on floating
>> point divs.
>> +static SDValue PerformFDIVCombine(SDNode *N, SelectionDAG &DAG,
>> +                                  TargetLowering::DAGCombinerInfo
>> &DCI,
>> +                                  const X86Subtarget *Subtarget) {
>> +
>> +  // These tricks will reduce precision.
>> +  if (!DAG.getTarget().Options.UnsafeFPMath)
>> +    return SDValue();
>> +
>> +  SDLoc dl(N);
>> +
>> +  if (N->getOperand(1).getOpcode() == ISD::FSQRT) {
>> +    SDValue RV =
>> +      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0), DCI,
>> Subtarget);
>> +    if (RV.getNode() != 0) {
>> +      DCI.AddToWorklist(RV.getNode());
>> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
>> +                         N->getOperand(0), RV);
>> +    }
>> +  } else if (N->getOperand(1).getOpcode() == ISD::FP_EXTEND &&
>> +             N->getOperand(1).getOperand(0).getOpcode() ==
>> ISD::FSQRT) {
>> +    SDValue RV =
>> +
>>      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
>> +                            DCI, Subtarget);
>> +    if (RV.getNode() != 0) {
>> +      DCI.AddToWorklist(RV.getNode());
>> +      RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N->getOperand(1)),
>> +                       N->getValueType(0), RV);
>> +      DCI.AddToWorklist(RV.getNode());
>> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
>> +                         N->getOperand(0), RV);
>> +    }
>> +  } else if (N->getOperand(1).getOpcode() == ISD::FP_ROUND &&
>> +             N->getOperand(1).getOperand(0).getOpcode() ==
>> ISD::FSQRT) {
>> +    SDValue RV =
>> +
>>      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
>> +                            DCI, Subtarget);
>> +    if (RV.getNode() != 0) {
>> +      DCI.AddToWorklist(RV.getNode());
>> +      RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N->getOperand(1)),
>> +                       N->getValueType(0), RV,
>> +                       N->getOperand(1).getOperand(1));
>> +      DCI.AddToWorklist(RV.getNode());
>> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
>> +                         N->getOperand(0), RV);
>> +    }
>> +  }
>> +
>> +  return SDValue();
>> +}
>> +
>>  /// PerformFORCombine - Do target-specific dag combines on
>>  X86ISD::FOR and
>>  /// X86ISD::FXOR nodes.
>>  static SDValue PerformFORCombine(SDNode *N, SelectionDAG &DAG) {
>> @@ -19142,6 +19274,7 @@ SDValue
>> X86TargetLowering::PerformDAGCombine(SDNode *N,
>>    case ISD::SINT_TO_FP:     return PerformSINT_TO_FPCombine(N, DAG,
>>    this);
>>    case ISD::FADD:           return PerformFADDCombine(N, DAG,
>>    Subtarget);
>>    case ISD::FSUB:           return PerformFSUBCombine(N, DAG,
>>    Subtarget);
>> +  case ISD::FDIV:           return PerformFDIVCombine(N, DAG, DCI,
>> Subtarget);
>>    case X86ISD::FXOR:
>>    case X86ISD::FOR:         return PerformFORCombine(N, DAG);
>>    case X86ISD::FMIN:
>> diff --git a/test/CodeGen/X86/rsqrt-fastmath.ll
>> b/test/CodeGen/X86/rsqrt-fastmath.ll
>> new file mode 100644
>> index 0000000..5f393e9
>> --- /dev/null
>> +++ b/test/CodeGen/X86/rsqrt-fastmath.ll
>> @@ -0,0 +1,64 @@
>> +; RUN: llc < %s -mcpu=core2 | FileCheck %s
>> +
>> +; generated using "clang -S -O2 -ffast-math -emit-llvm sqrt.c" from
>> +; #include <math.h>
>> +;
>> +; double fd(double d) {
>> +;   return 1.0 / sqrt(d);
>> +; }
>> +;
>> +; float ff(float f) {
>> +;   return 1.0f / sqrtf(f);
>> +; }
>> +;
>> +; long double fld(long double ld) {
>> +;   return 1.0 / sqrtl(ld);
>> +; }
>> +;
>> +; Tests conversion of single-precision floating-point square root
>> instructions
>> +; in the denominator of a floating point divide into rsqrt
>> instructions when
>> +; -ffast-math is in effect.
>> +
>> +; ModuleID = 'rsqrt.c'
>> +target datalayout =
>> "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
>> +target triple = "x86_64-unknown-linux-gnu"
>> +
>> +; Function Attrs: nounwind readnone uwtable
>> +define double @fd(double %d) #0 {
>> +entry:
>> +; CHECK: sqrtsd
>> +  %0 = tail call double @llvm.sqrt.f64(double %d)
>> +  %div = fdiv fast double 1.000000e+00, %0
>> +  ret double %div
>> +}
>> +
>> +; Function Attrs: nounwind readonly
>> +declare double @llvm.sqrt.f64(double) #1
>> +
>> +; Function Attrs: nounwind readnone uwtable
>> +define float @ff(float %f) #0 {
>> +entry:
>> +; CHECK: rsqrtss
>> +  %0 = tail call float @llvm.sqrt.f32(float %f)
>> +  %div = fdiv fast float 1.000000e+00, %0
>> +  ret float %div
>> +}
>> +
>> +; Function Attrs: nounwind readonly
>> +declare float @llvm.sqrt.f32(float) #1
>> +
>> +; Function Attrs: nounwind readnone uwtable
>> +define x86_fp80 @fld(x86_fp80 %ld) #0 {
>> +entry:
>> +; CHECK: fsqrt
>> +  %0 = tail call x86_fp80 @llvm.sqrt.f80(x86_fp80 %ld)
>> +  %div = fdiv fast x86_fp80 0xK3FFF8000000000000000, %0
>> +  ret x86_fp80 %div
>> +}
>> +
>> +; Function Attrs: nounwind readonly
>> +declare x86_fp80 @llvm.sqrt.f80(x86_fp80) #1
>> +
>> +attributes #0 = { nounwind readnone uwtable
>> "less-precise-fpmad"="false" "no-frame-pointer-elim"="false"
>> "no-infs-fp-math"="true" "no-nans-fp-math"="true"
>> "unsafe-fp-math"="true" "use-soft-float"="false" }
>> +attributes #1 = { nounwind readnone "less-precise-fpmad"="false"
>> "no-frame-pointer-elim"="false" "no-infs-fp-math"="true"
>> "no-nans-fp-math"="true" "unsafe-fp-math"="true"
>> "use-soft-float"="false" }
>> +attributes #2 = { nounwind readnone }
>> --
>> 1.9.0
>>
>> _______________________________________________
>> llvm-commits mailing list
>> llvm-commits at cs.uiuc.edu
>> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits
>>
>
> --
> Hal Finkel
> Assistant Computational Scientist
> Leadership Computing Facility
> Argonne National Laboratory



More information about the llvm-commits mailing list