[llvm] ea71ba1 - [DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 6 16:12:55 PDT 2020


Author: Sanjay Patel
Date: 2020-07-06T19:12:21-04:00
New Revision: ea71ba11ab1187af03a790dc20967ddd62f68bfe

URL: https://github.com/llvm/llvm-project/commit/ea71ba11ab1187af03a790dc20967ddd62f68bfe
DIFF: https://github.com/llvm/llvm-project/commit/ea71ba11ab1187af03a790dc20967ddd62f68bfe.diff

LOG: [DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division

X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)

In the motivating case from PR46406:
https://bugs.llvm.org/show_bug.cgi?id=46406
...this is restoring the sequence that was originally in the source code.
We extracted a term from within the sqrt because we do not know in
instcombine whether a target will expand a sqrt call.
Note: we could say that the transform in IR should be restricted, but
that would not solve the problem if the source was originally in the
pattern shown here.

This is a gray area for fast-math-flag requirements. I think we should at
least check fast-math-flags on the fdiv and fmul because I view this
transform as 2 pieces: reassociate the fmul operands and form a reciprocal
from the fdiv (as with the existing transform). We could argue that the
sqrt also needs FMF, but that was not required before, so we should change
that in a follow-up patch if that seems better.

We don't currently have a way to check that the target will produce a sqrt
or recip estimate without actually creating nodes (the APIs are SDValue
getSqrtEstimate() and SDValue getRecipEstimate()), so we clean up
speculatively created nodes if we are not able to create an estimate.
The x86 test with doubles verifies that the generated code does not change
when there is no estimate sequence for the type.

Differential Revision: https://reviews.llvm.org/D82716

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/X86/sqrt-fastmath.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 015d78a3e868..c94bbeb60716 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13232,6 +13232,24 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
         Y = N1.getOperand(0);
       }
       if (Sqrt.getNode()) {
+        // If the other multiply operand is known positive, pull it into the
+        // sqrt. That will eliminate the division if we convert to an estimate:
+        // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
+        // TODO: Also fold the case where A == Z (fabs is missing).
+        if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
+            N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
+            Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
+          SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
+                                   Y.getOperand(0), Flags);
+          SDValue AAZ =
+              DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
+          if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
+            return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
+
+          // Estimate creation failed. Clean up speculatively created nodes.
+          recursivelyDeleteUnusedNodes(AAZ.getNode());
+        }
+
         // We found a FSQRT, so try to make this fold:
         // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
         if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {

diff  --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
index b1582d7288c9..29d21f6927bd 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -618,46 +618,47 @@ define <16 x float> @v16f32_estimate(<16 x float> %x) #1 {
   ret <16 x float> %div
 }
 
-; x / (fabs(y) * sqrt(z))
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
 
 define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
 ; SSE-LABEL: div_sqrt_fabs_f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtss %xmm2, %xmm3
-; SSE-NEXT:    mulss %xmm3, %xmm2
-; SSE-NEXT:    mulss %xmm3, %xmm2
-; SSE-NEXT:    addss {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulss {{.*}}(%rip), %xmm3
-; SSE-NEXT:    mulss %xmm2, %xmm3
-; SSE-NEXT:    andps {{.*}}(%rip), %xmm1
-; SSE-NEXT:    divss %xmm1, %xmm3
-; SSE-NEXT:    mulss %xmm3, %xmm0
+; SSE-NEXT:    mulss %xmm1, %xmm1
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    xorps %xmm2, %xmm2
+; SSE-NEXT:    rsqrtss %xmm1, %xmm2
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    addss {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulss {{.*}}(%rip), %xmm2
+; SSE-NEXT:    mulss %xmm0, %xmm2
+; SSE-NEXT:    mulss %xmm1, %xmm2
+; SSE-NEXT:    movaps %xmm2, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX1-LABEL: div_sqrt_fabs_f32:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vrsqrtss %xmm2, %xmm2, %xmm3
-; AVX1-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vaddss {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vmulss {{.*}}(%rip), %xmm3, %xmm3
-; AVX1-NEXT:    vmulss %xmm2, %xmm3, %xmm2
-; AVX1-NEXT:    vandps {{.*}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vdivss %xmm1, %xmm2, %xmm1
-; AVX1-NEXT:    vmulss %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vmulss %xmm1, %xmm1, %xmm1
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vrsqrtss %xmm1, %xmm1, %xmm2
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; AVX1-NEXT:    vmulss %xmm0, %xmm2, %xmm0
+; AVX1-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX512-LABEL: div_sqrt_fabs_f32:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vrsqrtss %xmm2, %xmm2, %xmm3
-; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + mem
-; AVX512-NEXT:    vmulss {{.*}}(%rip), %xmm3, %xmm3
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm4 = [NaN,NaN,NaN,NaN]
-; AVX512-NEXT:    vmulss %xmm2, %xmm3, %xmm2
-; AVX512-NEXT:    vandps %xmm4, %xmm1, %xmm1
-; AVX512-NEXT:    vdivss %xmm1, %xmm2, %xmm1
-; AVX512-NEXT:    vmulss %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmulss %xmm1, %xmm1, %xmm1
+; AVX512-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vrsqrtss %xmm1, %xmm1, %xmm2
+; AVX512-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
+; AVX512-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; AVX512-NEXT:    vmulss %xmm0, %xmm2, %xmm0
+; AVX512-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; AVX512-NEXT:    retq
   %s = call fast float @llvm.sqrt.f32(float %z)
   %a = call fast float @llvm.fabs.f32(float %y)
@@ -666,47 +667,46 @@ define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
   ret float %d
 }
 
-; x / (fabs(y) * sqrt(z))
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
 
 define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
 ; SSE-LABEL: div_sqrt_fabs_v4f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtps %xmm2, %xmm3
-; SSE-NEXT:    mulps %xmm3, %xmm2
-; SSE-NEXT:    mulps %xmm3, %xmm2
-; SSE-NEXT:    addps {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulps {{.*}}(%rip), %xmm3
-; SSE-NEXT:    mulps %xmm2, %xmm3
-; SSE-NEXT:    andps {{.*}}(%rip), %xmm1
-; SSE-NEXT:    divps %xmm1, %xmm3
-; SSE-NEXT:    mulps %xmm3, %xmm0
+; SSE-NEXT:    mulps %xmm1, %xmm1
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    rsqrtps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    addps {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulps {{.*}}(%rip), %xmm2
+; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm2, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX1-LABEL: div_sqrt_fabs_v4f32:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vrsqrtps %xmm2, %xmm3
-; AVX1-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vaddps {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vmulps {{.*}}(%rip), %xmm3, %xmm3
-; AVX1-NEXT:    vmulps %xmm2, %xmm3, %xmm2
-; AVX1-NEXT:    vandps {{.*}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vdivps %xmm1, %xmm2, %xmm1
+; AVX1-NEXT:    vmulps %xmm1, %xmm1, %xmm1
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vrsqrtps %xmm1, %xmm2
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vaddps {{.*}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT:    vmulps {{.*}}(%rip), %xmm2, %xmm2
+; AVX1-NEXT:    vmulps %xmm1, %xmm2, %xmm1
 ; AVX1-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX512-LABEL: div_sqrt_fabs_v4f32:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vrsqrtps %xmm2, %xmm3
-; AVX512-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
-; AVX512-NEXT:    vfmadd231ps {{.*#+}} xmm4 = (xmm3 * xmm2) + xmm4
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; AVX512-NEXT:    vmulps %xmm2, %xmm3, %xmm2
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm3 = [NaN,NaN,NaN,NaN]
-; AVX512-NEXT:    vmulps %xmm4, %xmm2, %xmm2
-; AVX512-NEXT:    vandps %xmm3, %xmm1, %xmm1
-; AVX512-NEXT:    vdivps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT:    vmulps %xmm1, %xmm1, %xmm1
+; AVX512-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vrsqrtps %xmm1, %xmm2
+; AVX512-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
+; AVX512-NEXT:    vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
+; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; AVX512-NEXT:    vmulps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT:    vmulps %xmm3, %xmm1, %xmm1
 ; AVX512-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX512-NEXT:    retq
   %s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %z)
@@ -716,6 +716,11 @@ define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x flo
   ret <4 x float> %d
 }
 
+; This has 'arcp' but does not have 'reassoc' FMF.
+; We allow converting the sqrt to an estimate, but
+; do not pull the divisor into the estimate.
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(z) / fabs(y)
+
 define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
 ; SSE-LABEL: div_sqrt_fabs_v4f32_fmf:
 ; SSE:       # %bb.0:
@@ -765,6 +770,8 @@ define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x
   ret <4 x float> %d
 }
 
+; No estimates for f64, so do not convert fabs into an fmul.
+
 define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
 ; SSE-LABEL: div_sqrt_fabs_f64:
 ; SSE:       # %bb.0:


        


More information about the llvm-commits mailing list