[llvm] r272920 - [DAG] Remove redundant FMUL in Newton-Raphson SQRT code

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 16 09:58:55 PDT 2016


Author: spatel
Date: Thu Jun 16 11:58:54 2016
New Revision: 272920

URL: http://llvm.org/viewvc/llvm-project?rev=272920&view=rev
Log:
[DAG] Remove redundant FMUL in Newton-Raphson SQRT code

When calculating a square root using Newton-Raphson with two constants,
a naive implementation uses five multiplications (four muls to calculate the
reciprocal square root and another one to calculate the square root itself).
However, after some reassociation and CSE, the same result can be obtained
with only four multiplications. Unfortunately, there's no reliable way to do
such a reassociation in the back-end, so this patch modifies the NR code itself
so that it directly builds optimal code for SQRT and doesn't rely on any
further reassociation.

Patch by Nikolai Bozhenov!

Differential Revision: http://reviews.llvm.org/D21127


Added:
    llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=272920&r1=272919&r2=272920&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Thu Jun 16 11:58:54 2016
@@ -357,11 +357,13 @@ namespace {
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
     SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags);
-    SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
-    SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
-                                 SDNodeFlags *Flags);
-    SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
-                                 SDNodeFlags *Flags);
+    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags, bool Recip);
+    SDValue buildSqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                SDNodeFlags *Flags, bool Reciprocal);
+    SDValue buildSqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                SDNodeFlags *Flags, bool Reciprocal);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -8825,12 +8827,12 @@ SDValue DAGCombiner::visitFDIV(SDNode *N
     // If this FDIV is part of a reciprocal square root, it may be folded
     // into a target-specific square root estimate instruction.
     if (N1.getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) {
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags)) {
         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
       }
     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                           Flags)) {
         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
         AddToWorklist(RV.getNode());
@@ -8838,7 +8840,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N
       }
     } else if (N1.getOpcode() == ISD::FP_ROUND &&
                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                           Flags)) {
         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
         AddToWorklist(RV.getNode());
@@ -8859,7 +8861,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N
       if (SqrtOp.getNode()) {
         // We found a FSQRT, so try to make this fold:
         // x / (y * sqrt(z)) -> x * (rsqrt(z) / y)
-        if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
+        if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
           RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags);
           AddToWorklist(RV.getNode());
           return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
@@ -8916,27 +8918,7 @@ SDValue DAGCombiner::visitFSQRT(SDNode *
   // For now, create a Flags object for use with all unsafe math transforms.
   SDNodeFlags Flags;
   Flags.setUnsafeAlgebra(true);
-
-  // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5)
-  SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags);
-  if (!RV)
-    return SDValue();
-
-  EVT VT = RV.getValueType();
-  SDLoc DL(N);
-  RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags);
-  AddToWorklist(RV.getNode());
-
-  // Unfortunately, RV is now NaN if the input was exactly 0.
-  // Select out this case and force the answer to 0.
-  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-  EVT CCVT = getSetCCResultType(VT);
-  SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ);
-  AddToWorklist(ZeroCmp.getNode());
-  AddToWorklist(RV.getNode());
-
-  return DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
-                     ZeroCmp, Zero, RV);
+  return buildSqrtEstimate(N->getOperand(0), &Flags);
 }
 
 /// copysign(x, fp_extend(y)) -> copysign(x, y)
@@ -14587,9 +14569,9 @@ SDValue DAGCombiner::BuildReciprocalEsti
 ///     =>
 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
 /// As a result, we precompute A/2 prior to the iteration loop.
-SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
-                                          unsigned Iterations,
-                                          SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
+                                         unsigned Iterations,
+                                         SDNodeFlags *Flags, bool Reciprocal) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
@@ -14616,6 +14598,13 @@ SDValue DAGCombiner::BuildRsqrtNROneCons
     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
     AddToWorklist(Est.getNode());
   }
+
+  // If non-reciprocal square root is requested, multiply the result by Arg.
+  if (!Reciprocal) {
+    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
+    AddToWorklist(Est.getNode());
+  }
+
   return Est;
 }
 
@@ -14624,35 +14613,55 @@ SDValue DAGCombiner::BuildRsqrtNROneCons
 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
 ///     =>
 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
-SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est,
-                                          unsigned Iterations,
-                                          SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
+                                         unsigned Iterations,
+                                         SDNodeFlags *Flags, bool Reciprocal) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
 
-  // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est)
+  // This routine must enter the loop below to work correctly
+  // when (Reciprocal == false).
+  assert(Iterations > 0);
+
+  // Newton iterations for reciprocal square root:
+  // E = (E * -0.5) * ((A * E) * E + -3.0)
   for (unsigned i = 0; i < Iterations; ++i) {
-    SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
-    AddToWorklist(HalfEst.getNode());
+    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
+    AddToWorklist(AE.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
-    AddToWorklist(Est.getNode());
+    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
+    AddToWorklist(AEE.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
-    AddToWorklist(Est.getNode());
+    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
+    AddToWorklist(RHS.getNode());
 
-    Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags);
-    AddToWorklist(Est.getNode());
+    // When calculating a square root at the last iteration build:
+    // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
+    // (notice a common subexpression)
+    SDValue LHS;
+    if (Reciprocal || (i + 1) < Iterations) {
+      // RSQRT: LHS = (E * -0.5)
+      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
+    } else {
+      // SQRT: LHS = (A * E) * -0.5
+      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
+    }
+    AddToWorklist(LHS.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags);
+    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
     AddToWorklist(Est.getNode());
   }
+
   return Est;
 }
 
-SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
+/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
+/// Op can be zero.
+SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags,
+                                           bool Reciprocal) {
   if (Level >= AfterLegalizeDAG)
     return SDValue();
 
@@ -14663,9 +14672,9 @@ SDValue DAGCombiner::BuildRsqrtEstimate(
   if (SDValue Est = TLI.getRsqrtEstimate(Op, DCI, Iterations, UseOneConstNR)) {
     AddToWorklist(Est.getNode());
     if (Iterations) {
-      Est = UseOneConstNR ?
-        BuildRsqrtNROneConst(Op, Est, Iterations, Flags) :
-        BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags);
+      Est = UseOneConstNR
+                ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
+                : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
     }
     return Est;
   }
@@ -14673,6 +14682,30 @@ SDValue DAGCombiner::BuildRsqrtEstimate(
   return SDValue();
 }
 
+SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+  return buildSqrtEstimateImpl(Op, Flags, true);
+}
+
+SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+  SDValue Est = buildSqrtEstimateImpl(Op, Flags, false);
+  if (!Est)
+    return SDValue();
+
+  // Unfortunately, Est is now NaN if the input was exactly 0.
+  // Select out this case and force the answer to 0.
+  EVT VT = Est.getValueType();
+  SDLoc DL(Op);
+  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
+  EVT CCVT = getSetCCResultType(VT);
+  SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, Zero, ISD::SETEQ);
+  AddToWorklist(ZeroCmp.getNode());
+
+  Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, ZeroCmp,
+                    Zero, Est);
+  AddToWorklist(Est.getNode());
+  return Est;
+}
+
 /// Return true if base is a frame index, which is known not to alias with
 /// anything but itself.  Provides base object and offset as results.
 static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,

Added: llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll?rev=272920&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll (added)
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll Thu Jun 16 11:58:54 2016
@@ -0,0 +1,52 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx2,fma -recip=sqrt:2 -stop-after=expand-isel-pseudos 2>&1 | FileCheck %s
+
+declare float @llvm.sqrt.f32(float) #0
+
+define float @foo(float %f) #0 {
+; CHECK: {{name: *foo}}
+; CHECK: body:
+; CHECK:     %0 = COPY %xmm0
+; CHECK:     %1 = VRSQRTSSr killed %2, %0
+; CHECK:     %3 = VMULSSrr %0, %1
+; CHECK:     %4 = VMOVSSrm
+; CHECK:     %5 = VFMADDSSr213r %1, killed %3, %4
+; CHECK:     %6 = VMOVSSrm
+; CHECK:     %7 = VMULSSrr %1, %6
+; CHECK:     %8 = VMULSSrr killed %7, killed %5
+; CHECK:     %9 = VMULSSrr %0, %8
+; CHECK:     %10 = VFMADDSSr213r %8, %9, %4
+; CHECK:     %11 = VMULSSrr %9, %6
+; CHECK:     %12 = VMULSSrr killed %11, killed %10
+; CHECK:     %13 = FsFLD0SS
+; CHECK:     %14 = VCMPSSrr %0, killed %13, 0
+; CHECK:     %15 = VFsANDNPSrr killed %14, killed %12
+; CHECK:     %xmm0 = COPY %15
+; CHECK:     RET 0, %xmm0
+  %call = tail call float @llvm.sqrt.f32(float %f) #1
+  ret float %call
+}
+
+define float @rfoo(float %f) #0 {
+; CHECK: {{name: *rfoo}}
+; CHECK: body:             |
+; CHECK:     %0 = COPY %xmm0
+; CHECK:     %1 = VRSQRTSSr killed %2, %0
+; CHECK:     %3 = VMULSSrr %0, %1
+; CHECK:     %4 = VMOVSSrm
+; CHECK:     %5 = VFMADDSSr213r %1, killed %3, %4
+; CHECK:     %6 = VMOVSSrm
+; CHECK:     %7 = VMULSSrr %1, %6
+; CHECK:     %8 = VMULSSrr killed %7, killed %5
+; CHECK:     %9 = VMULSSrr %0, %8
+; CHECK:     %10 = VFMADDSSr213r %8, killed %9, %4
+; CHECK:     %11 = VMULSSrr %8, %6
+; CHECK:     %12 = VMULSSrr killed %11, killed %10
+; CHECK:     %xmm0 = COPY %12
+; CHECK:     RET 0, %xmm0
+  %sqrt = tail call float @llvm.sqrt.f32(float %f)
+  %div = fdiv fast float 1.0, %sqrt
+  ret float %div
+}
+
+attributes #0 = { "unsafe-fp-math"="true" }
+attributes #1 = { nounwind readnone }

Modified: llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll?rev=272920&r1=272919&r2=272920&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll (original)
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll Thu Jun 16 11:58:54 2016
@@ -34,12 +34,11 @@ define float @ff(float %f) #0 {
 ; ESTIMATE-LABEL: ff:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtss %xmm0, %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm3
-; ESTIMATE-NEXT:    vmulss %xmm3, %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm0, %xmm2
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm2, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; ESTIMATE-NEXT:    vxorps %xmm2, %xmm2, %xmm2
 ; ESTIMATE-NEXT:    vcmpeqss %xmm2, %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vandnps %xmm1, %xmm0, %xmm0
@@ -78,11 +77,11 @@ define float @reciprocal_square_root(flo
 ; ESTIMATE-LABEL: reciprocal_square_root:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtss %xmm0, %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm1, %xmm2
 ; ESTIMATE-NEXT:    vmulss %xmm2, %xmm0, %xmm0
+; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call float @llvm.sqrt.f32(float %x)
   %div = fdiv fast float 1.0, %sqrt
@@ -100,11 +99,11 @@ define <4 x float> @reciprocal_square_ro
 ; ESTIMATE-LABEL: reciprocal_square_root_v4f32:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtps %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
+; ESTIMATE-NEXT:    vmulps %xmm1, %xmm1, %xmm2
+; ESTIMATE-NEXT:    vmulps %xmm2, %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vaddps {{.*}}(%rip), %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vmulps {{.*}}(%rip), %xmm1, %xmm1
-; ESTIMATE-NEXT:    vmulps %xmm1, %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
   %div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
@@ -125,11 +124,11 @@ define <8 x float> @reciprocal_square_ro
 ; ESTIMATE-LABEL: reciprocal_square_root_v8f32:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtps %ymm0, %ymm1
-; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
-; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
+; ESTIMATE-NEXT:    vmulps %ymm1, %ymm1, %ymm2
+; ESTIMATE-NEXT:    vmulps %ymm2, %ymm0, %ymm0
 ; ESTIMATE-NEXT:    vaddps {{.*}}(%rip), %ymm0, %ymm0
 ; ESTIMATE-NEXT:    vmulps {{.*}}(%rip), %ymm1, %ymm1
-; ESTIMATE-NEXT:    vmulps %ymm1, %ymm0, %ymm0
+; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
   %div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt




More information about the llvm-commits mailing list