[PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)

Hal Finkel hfinkel at anl.gov
Sun Mar 2 07:57:24 PST 2014


----- Original Message -----
> From: "Steven Noonan" <steven at uplinklabs.net>
> To: llvm-commits at cs.uiuc.edu
> Sent: Sunday, March 2, 2014 1:52:39 AM
> Subject: [PATCH] X86: introduce rsqrtss/rsqrtps for y / sqrtf(x)
> 
> In an application that uses inverse square root for distance calculations,
> I noticed that LLVM and GCC emit different code, with GCC being the faster
> of the two. One of the most apparent differences was the lack of lowering
> for 1.0f / sqrtf(x). While GCC was emitting a reciprocal square root
> followed by a Newton-Raphson iteration, LLVM was emitting the much more
> straightforward (but also much slower) sqrt followed by div.
> 
> This change introduces a y / sqrtf(x) combine by handling FDIV combining
> when unsafe math optimizations are enabled.
> 
> Signed-off-by: Steven Noonan <steven at uplinklabs.net>

I wrote very similar code in the PowerPC backend. Do you think we could refactor these into a common utility? There are several target architectures that have reciprocal estimate functions (X86, PowerPC, ARM, etc.)

 -Hal

> ---
>  lib/Target/X86/X86ISelLowering.cpp | 133 +++++++++++++++++++++++++++++++++++++
>  test/CodeGen/X86/rsqrt-fastmath.ll |  64 ++++++++++++++++++
>  2 files changed, 197 insertions(+)
>  create mode 100644 test/CodeGen/X86/rsqrt-fastmath.ll
> 
> diff --git a/lib/Target/X86/X86ISelLowering.cpp
> b/lib/Target/X86/X86ISelLowering.cpp
> index 76eeb64..d47a183 100644
> --- a/lib/Target/X86/X86ISelLowering.cpp
> +++ b/lib/Target/X86/X86ISelLowering.cpp
> @@ -1513,6 +1513,7 @@ void X86TargetLowering::resetOperationActions()
> {
>    setTargetDAGCombine(ISD::ADD);
>    setTargetDAGCombine(ISD::FADD);
>    setTargetDAGCombine(ISD::FSUB);
> +  setTargetDAGCombine(ISD::FDIV);
>    setTargetDAGCombine(ISD::FMA);
>    setTargetDAGCombine(ISD::SUB);
>    setTargetDAGCombine(ISD::LOAD);
> @@ -4751,6 +4752,33 @@ static SDValue getZeroVector(EVT VT, const
> X86Subtarget *Subtarget,
>    return DAG.getNode(ISD::BITCAST, dl, VT, Vec);
>  }
>  
> +/// getSplatVectorFP - Returns a single-precision floating-point
> vector with
> +/// all elements set to the same value.
> +static SDValue getSplatVectorFP(EVT VT, SDValue Cst,
> +                                const X86Subtarget *Subtarget,
> +                                SelectionDAG &DAG, SDLoc dl) {
> +  assert(VT.isVector() && "Expected a vector type");
> +  assert(Cst.getValueType() == MVT::f32 &&
> +         "Expected Cst to be a single-precision floating-point
> value.");
> +
> +  SDValue Vec;
> +  if (VT.is128BitVector()) {  // SSE
> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v4f32, Cst, Cst,
> Cst, Cst);
> +  } else if (VT.is256BitVector()) { // AVX
> +    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v8f32, Ops,
> +                      array_lengthof(Ops));
> +  } else if (VT.is512BitVector()) { // AVX-512
> +    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst,
> +                      Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
> +    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v16f32, Ops,
> +        array_lengthof(Ops));
> +  } else
> +    llvm_unreachable("Unexpected vector type");
> +
> +  return Vec;
> +}
> +
>  /// getOnesVector - Returns a vector of specified type with all bits
>  set.
>  /// Always build ones vectors as <4 x i32> or <8 x i32>. For 256-bit
>  types with
>  /// no AVX2 supprt, use two <4 x i32> inserted in a <8 x i32>
>  appropriately.
> @@ -18630,6 +18658,110 @@ static SDValue PerformFSUBCombine(SDNode
> *N, SelectionDAG &DAG,
>    return SDValue();
>  }
>  
> +// PerformFastRecipFSQRT - Do fast reciprical floating-point square
> root
> +static SDValue PerformFastRecipFSQRT(SDValue Op,
> +                                     TargetLowering::DAGCombinerInfo
> &DCI,
> +                                     const X86Subtarget *Subtarget)
> {
> +  if (DCI.isAfterLegalizeVectorOps())
> +    return SDValue();
> +
> +  EVT VT = Op.getValueType();
> +
> +  if (VT == MVT::f32 ||
> +      (VT == MVT::v4f32 && Subtarget->hasSSE1()) ||
> +      (VT == MVT::v8f32 && Subtarget->hasFp256()) ||
> +      (VT == MVT::v16f32 && Subtarget->hasAVX512()) ) {
> +
> +    SelectionDAG &DAG = DCI.DAG;
> +    SDLoc dl(Op);
> +
> +    SDValue FPThree =
> +      DAG.getConstantFP(-3.0, VT.getScalarType());
> +    if (VT.isVector()) {
> +      FPThree = getSplatVectorFP(VT, FPThree, Subtarget, DAG, dl);
> +    }
> +
> +    SDValue FPHalf =
> +      DAG.getConstantFP(-0.5, VT.getScalarType());
> +    if (VT.isVector()) {
> +      FPHalf = getSplatVectorFP(VT, FPHalf, Subtarget, DAG, dl);
> +    }
> +
> +    SDValue Est = DAG.getNode(X86ISD::FRSQRT, dl, VT, Op);
> +    DCI.AddToWorklist(Est.getNode());
> +
> +    SDValue E0 = DAG.getNode(ISD::FMUL, dl, VT, Est, Op);
> +    DCI.AddToWorklist(E0.getNode());
> +
> +    SDValue E1 = DAG.getNode(ISD::FMUL, dl, VT, E0, Est);
> +    DCI.AddToWorklist(E1.getNode());
> +
> +    SDValue E2 = DAG.getNode(ISD::FADD, dl, VT, E1, FPThree);
> +    DCI.AddToWorklist(E2.getNode());
> +
> +    SDValue E3 = DAG.getNode(ISD::FMUL, dl, VT, FPHalf, Est);
> +    DCI.AddToWorklist(E3.getNode());
> +
> +    Est = DAG.getNode(ISD::FMUL, dl, VT, E2, E3);
> +    DCI.AddToWorklist(Est.getNode());
> +
> +    return Est;
> +  }
> +
> +  return SDValue();
> +}
> +
> +/// PerformFDIVCombine - Do target-specific dag update on floating
> point divs.
> +static SDValue PerformFDIVCombine(SDNode *N, SelectionDAG &DAG,
> +                                  TargetLowering::DAGCombinerInfo
> &DCI,
> +                                  const X86Subtarget *Subtarget) {
> +
> +  // These tricks will reduce precision.
> +  if (!DAG.getTarget().Options.UnsafeFPMath)
> +    return SDValue();
> +
> +  SDLoc dl(N);
> +
> +  if (N->getOperand(1).getOpcode() == ISD::FSQRT) {
> +    SDValue RV =
> +      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0), DCI,
> Subtarget);
> +    if (RV.getNode() != 0) {
> +      DCI.AddToWorklist(RV.getNode());
> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
> +                         N->getOperand(0), RV);
> +    }
> +  } else if (N->getOperand(1).getOpcode() == ISD::FP_EXTEND &&
> +             N->getOperand(1).getOperand(0).getOpcode() ==
> ISD::FSQRT) {
> +    SDValue RV =
> +
>      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
> +                            DCI, Subtarget);
> +    if (RV.getNode() != 0) {
> +      DCI.AddToWorklist(RV.getNode());
> +      RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N->getOperand(1)),
> +                       N->getValueType(0), RV);
> +      DCI.AddToWorklist(RV.getNode());
> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
> +                         N->getOperand(0), RV);
> +    }
> +  } else if (N->getOperand(1).getOpcode() == ISD::FP_ROUND &&
> +             N->getOperand(1).getOperand(0).getOpcode() ==
> ISD::FSQRT) {
> +    SDValue RV =
> +
>      PerformFastRecipFSQRT(N->getOperand(1).getOperand(0).getOperand(0),
> +                            DCI, Subtarget);
> +    if (RV.getNode() != 0) {
> +      DCI.AddToWorklist(RV.getNode());
> +      RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N->getOperand(1)),
> +                       N->getValueType(0), RV,
> +                       N->getOperand(1).getOperand(1));
> +      DCI.AddToWorklist(RV.getNode());
> +      return DAG.getNode(ISD::FMUL, dl, N->getValueType(0),
> +                         N->getOperand(0), RV);
> +    }
> +  }
> +
> +  return SDValue();
> +}
> +
>  /// PerformFORCombine - Do target-specific dag combines on
>  X86ISD::FOR and
>  /// X86ISD::FXOR nodes.
>  static SDValue PerformFORCombine(SDNode *N, SelectionDAG &DAG) {
> @@ -19142,6 +19274,7 @@ SDValue
> X86TargetLowering::PerformDAGCombine(SDNode *N,
>    case ISD::SINT_TO_FP:     return PerformSINT_TO_FPCombine(N, DAG,
>    this);
>    case ISD::FADD:           return PerformFADDCombine(N, DAG,
>    Subtarget);
>    case ISD::FSUB:           return PerformFSUBCombine(N, DAG,
>    Subtarget);
> +  case ISD::FDIV:           return PerformFDIVCombine(N, DAG, DCI,
> Subtarget);
>    case X86ISD::FXOR:
>    case X86ISD::FOR:         return PerformFORCombine(N, DAG);
>    case X86ISD::FMIN:
> diff --git a/test/CodeGen/X86/rsqrt-fastmath.ll
> b/test/CodeGen/X86/rsqrt-fastmath.ll
> new file mode 100644
> index 0000000..5f393e9
> --- /dev/null
> +++ b/test/CodeGen/X86/rsqrt-fastmath.ll
> @@ -0,0 +1,64 @@
; RUN: llc < %s -mcpu=core2 | FileCheck %s

; Tests conversion of single-precision floating-point square root
; instructions in the denominator of a floating-point divide into rsqrt
; instructions when -ffast-math is in effect. Double and x86_fp80 square
; roots must keep using the exact sqrt instructions.
;
; @fd, @ff and @fld were generated using
; "clang -S -O2 -ffast-math -emit-llvm sqrt.c" from
; #include <math.h>
;
; double fd(double d) {
;   return 1.0 / sqrt(d);
; }
;
; float ff(float f) {
;   return 1.0f / sqrtf(f);
; }
;
; long double fld(long double ld) {
;   return 1.0 / sqrtl(ld);
; }
;
; @v4f was written by hand to cover the packed (rsqrtps) form.

; ModuleID = 'rsqrt.c'
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

; CHECK-LABEL: fd:
; CHECK: sqrtsd
; CHECK-NOT: rsqrt
; Function Attrs: nounwind readnone uwtable
define double @fd(double %d) #0 {
entry:
  %0 = tail call double @llvm.sqrt.f64(double %d)
  %div = fdiv fast double 1.000000e+00, %0
  ret double %div
}

; Function Attrs: nounwind readnone
declare double @llvm.sqrt.f64(double) #1

; CHECK-LABEL: ff:
; CHECK-NOT: divss
; CHECK: rsqrtss
; CHECK-NOT: divss
; Function Attrs: nounwind readnone uwtable
define float @ff(float %f) #0 {
entry:
  %0 = tail call float @llvm.sqrt.f32(float %f)
  %div = fdiv fast float 1.000000e+00, %0
  ret float %div
}

; Function Attrs: nounwind readnone
declare float @llvm.sqrt.f32(float) #1

; CHECK-LABEL: v4f:
; CHECK-NOT: divps
; CHECK: rsqrtps
; CHECK-NOT: divps
; Function Attrs: nounwind readnone uwtable
define <4 x float> @v4f(<4 x float> %v) #0 {
entry:
  %0 = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %v)
  %div = fdiv fast <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, %0
  ret <4 x float> %div
}

; Function Attrs: nounwind readnone
declare <4 x float> @llvm.sqrt.v4f32(<4 x float>) #1

; CHECK-LABEL: fld:
; CHECK: fsqrt
; Function Attrs: nounwind readnone uwtable
define x86_fp80 @fld(x86_fp80 %ld) #0 {
entry:
  %0 = tail call x86_fp80 @llvm.sqrt.f80(x86_fp80 %ld)
  %div = fdiv fast x86_fp80 0xK3FFF8000000000000000, %0
  ret x86_fp80 %div
}

; Function Attrs: nounwind readnone
declare x86_fp80 @llvm.sqrt.f80(x86_fp80) #1

attributes #0 = { nounwind readnone uwtable "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
attributes #1 = { nounwind readnone "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "unsafe-fp-math"="true" "use-soft-float"="false" }
> --
> 1.9.0
> 
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at cs.uiuc.edu
> http://lists.cs.uiuc.edu/mailman/listinfo/llvm-commits
> 

-- 
Hal Finkel
Assistant Computational Scientist
Leadership Computing Facility
Argonne National Laboratory



More information about the llvm-commits mailing list