[llvm] f22ac1d - [DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division, part 2

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 8 08:10:33 PDT 2020


Author: Sanjay Patel
Date: 2020-08-08T10:38:06-04:00
New Revision: f22ac1d15b1b3c8e890cad4aa126a8239bec61f7

URL: https://github.com/llvm/llvm-project/commit/f22ac1d15b1b3c8e890cad4aa126a8239bec61f7
DIFF: https://github.com/llvm/llvm-project/commit/f22ac1d15b1b3c8e890cad4aa126a8239bec61f7.diff

LOG: [DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division, part 2

Follow-up to D82716 / rGea71ba11ab11
We do not have the fabs removal fold in IR yet for the case
where the sqrt operand is repeated, so that's another potential
improvement.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/sqrt-fastmath.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b2077f47d4a3..105afb55269e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13313,21 +13313,26 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
       }
       if (Sqrt.getNode()) {
         // If the other multiply operand is known positive, pull it into the
-        // sqrt. That will eliminate the division if we convert to an estimate:
-        // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
-        // TODO: Also fold the case where A == Z (fabs is missing).
+        // sqrt. That will eliminate the division if we convert to an estimate.
         if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
-            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
-            Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
-          SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
-                                   Y.getOperand(0), Flags);
-          SDValue AAZ =
-              DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
-          if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
-            return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
-
-          // Estimate creation failed. Clean up speculatively created nodes.
-          recursivelyDeleteUnusedNodes(AAZ.getNode());
+            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse()) {
+          SDValue A;
+          if (Y.getOpcode() == ISD::FABS && Y.hasOneUse())
+            A = Y.getOperand(0);
+          else if (Y == Sqrt.getOperand(0))
+            A = Y;
+          if (A) {
+            // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
+            // X / (A * sqrt(A))       --> X / sqrt(A*A*A) --> X * rsqrt(A*A*A)
+            SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, A, A, Flags);
+            SDValue AAZ =
+                DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
+            if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
+              return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
+
+            // Estimate creation failed. Clean up speculatively created nodes.
+            recursivelyDeleteUnusedNodes(AAZ.getNode());
+          }
         }
 
         // We found a FSQRT, so try to make this fold:

diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
index 3b547c4bb515..d9f56e553332 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -803,38 +803,43 @@ define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
 define float @div_sqrt_f32(float %x, float %y) {
 ; SSE-LABEL: div_sqrt_f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtss %xmm1, %xmm2
-; SSE-NEXT:    movaps %xmm1, %xmm3
-; SSE-NEXT:    mulss %xmm2, %xmm3
-; SSE-NEXT:    mulss %xmm2, %xmm3
-; SSE-NEXT:    addss {{.*}}(%rip), %xmm3
-; SSE-NEXT:    mulss {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulss %xmm3, %xmm2
-; SSE-NEXT:    divss %xmm1, %xmm2
-; SSE-NEXT:    mulss %xmm2, %xmm0
+; SSE-NEXT:    movaps %xmm1, %xmm2
+; SSE-NEXT:    mulss %xmm1, %xmm2
+; SSE-NEXT:    mulss %xmm1, %xmm2
+; SSE-NEXT:    xorps %xmm1, %xmm1
+; SSE-NEXT:    rsqrtss %xmm2, %xmm1
+; SSE-NEXT:    mulss %xmm1, %xmm2
+; SSE-NEXT:    mulss %xmm1, %xmm2
+; SSE-NEXT:    addss {{.*}}(%rip), %xmm2
+; SSE-NEXT:    mulss {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulss %xmm0, %xmm1
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    movaps %xmm1, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX1-LABEL: div_sqrt_f32:
 ; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmulss %xmm1, %xmm1, %xmm2
+; AVX1-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; AVX1-NEXT:    vrsqrtss %xmm1, %xmm1, %xmm2
-; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm3
-; AVX1-NEXT:    vmulss %xmm2, %xmm3, %xmm3
-; AVX1-NEXT:    vaddss {{.*}}(%rip), %xmm3, %xmm3
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
 ; AVX1-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vdivss %xmm1, %xmm2, %xmm1
-; AVX1-NEXT:    vmulss %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vmulss %xmm0, %xmm2, %xmm0
+; AVX1-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX512-LABEL: div_sqrt_f32:
 ; AVX512:       # %bb.0:
+; AVX512-NEXT:    vmulss %xmm1, %xmm1, %xmm2
+; AVX512-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; AVX512-NEXT:    vrsqrtss %xmm1, %xmm1, %xmm2
-; AVX512-NEXT:    vmulss %xmm2, %xmm1, %xmm3
-; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm3 = (xmm2 * xmm3) + mem
+; AVX512-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
 ; AVX512-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
-; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vdivss %xmm1, %xmm2, %xmm1
-; AVX512-NEXT:    vmulss %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmulss %xmm0, %xmm2, %xmm0
+; AVX512-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; AVX512-NEXT:    retq
   %s = call fast float @llvm.sqrt.f32(float %y)
   %m = fmul fast float %s, %y
@@ -850,39 +855,42 @@ define float @div_sqrt_f32(float %x, float %y) {
 define <4 x float> @div_sqrt_v4f32(<4 x float> %x, <4 x float> %y) {
 ; SSE-LABEL: div_sqrt_v4f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtps %xmm1, %xmm2
-; SSE-NEXT:    movaps %xmm1, %xmm3
-; SSE-NEXT:    mulps %xmm2, %xmm3
-; SSE-NEXT:    mulps %xmm2, %xmm3
-; SSE-NEXT:    addps {{.*}}(%rip), %xmm3
-; SSE-NEXT:    mulps {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulps %xmm3, %xmm2
-; SSE-NEXT:    divps %xmm1, %xmm2
-; SSE-NEXT:    mulps %xmm2, %xmm0
+; SSE-NEXT:    movaps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    rsqrtps %xmm2, %xmm1
+; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    addps {{.*}}(%rip), %xmm2
+; SSE-NEXT:    mulps {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    mulps %xmm1, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX1-LABEL: div_sqrt_v4f32:
 ; AVX1:       # %bb.0:
+; AVX1-NEXT:    vmulps %xmm1, %xmm1, %xmm2
+; AVX1-NEXT:    vmulps %xmm1, %xmm2, %xmm1
 ; AVX1-NEXT:    vrsqrtps %xmm1, %xmm2
-; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm3
-; AVX1-NEXT:    vmulps %xmm2, %xmm3, %xmm3
-; AVX1-NEXT:    vaddps {{.*}}(%rip), %xmm3, %xmm3
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vaddps {{.*}}(%rip), %xmm1, %xmm1
 ; AVX1-NEXT:    vmulps {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vdivps %xmm1, %xmm2, %xmm1
+; AVX1-NEXT:    vmulps %xmm1, %xmm2, %xmm1
 ; AVX1-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX512-LABEL: div_sqrt_v4f32:
 ; AVX512:       # %bb.0:
+; AVX512-NEXT:    vmulps %xmm1, %xmm1, %xmm2
+; AVX512-NEXT:    vmulps %xmm1, %xmm2, %xmm1
 ; AVX512-NEXT:    vrsqrtps %xmm1, %xmm2
-; AVX512-NEXT:    vmulps %xmm2, %xmm1, %xmm3
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
-; AVX512-NEXT:    vfmadd231ps {{.*#+}} xmm4 = (xmm2 * xmm3) + xmm4
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm3 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; AVX512-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vmulps %xmm4, %xmm2, %xmm2
-; AVX512-NEXT:    vdivps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
+; AVX512-NEXT:    vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
+; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; AVX512-NEXT:    vmulps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT:    vmulps %xmm3, %xmm1, %xmm1
 ; AVX512-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX512-NEXT:    retq
   %s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %y)


        


More information about the llvm-commits mailing list