[llvm] 39009a8 - [DAGCombiner] tighten fast-math constraints for fma fold

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 12 05:53:14 PDT 2020


Author: Sanjay Patel
Date: 2020-07-12T08:51:49-04:00
New Revision: 39009a8245dae78250081b16fc679ce338af405a

URL: https://github.com/llvm/llvm-project/commit/39009a8245dae78250081b16fc679ce338af405a
DIFF: https://github.com/llvm/llvm-project/commit/39009a8245dae78250081b16fc679ce338af405a.diff

LOG: [DAGCombiner] tighten fast-math constraints for fma fold

fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)

This is only allowed when "reassoc" is present on the fadd (or when
unsafe-fp-math is enabled globally).

As discussed in D80801, this transform goes beyond
what is allowed by "contract" FMF (-ffp-contract=fast).
That is because we are fusing the trailing add of 'E' with a
multiply, but without "reassoc", the code mandates that the
products A*B and C*D are added together before adding in 'E'.

I've added this example to the LangRef to try to clarify the
meaning of "contract". If that seems reasonable, we should
probably do something similar for the clang docs because
there does not appear to be any formal spec for the behavior
of -ffp-contract=fast.

Differential Revision: https://reviews.llvm.org/D82499

Added: 
    

Modified: 
    llvm/docs/LangRef.rst
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/fadd-combines.ll
    llvm/test/CodeGen/X86/fma_patterns.ll

Removed: 
    


################################################################################
diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c2d6200e67fa..86d315be74bc 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -2778,7 +2778,9 @@ floating-point transformations.
 
 ``contract``
    Allow floating-point contraction (e.g. fusing a multiply followed by an
-   addition into a fused multiply-and-add).
+   addition into a fused multiply-and-add). This does not enable reassociating
+   to form arbitrary contractions. For example, ``(a*b) + (c*d) + e`` can not
+   be transformed into ``(a*b) + ((c*d) + e)`` to create two fma operations.
 
 ``afn``
    Approximate functions - Allow substitution of approximate calculations for

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0d84cd89f5ae..42e6e12f3f02 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11986,6 +11986,8 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
   SDNodeFlags Flags = N->getFlags();
   bool CanFuse = Options.UnsafeFPMath || isContractable(N);
+  bool CanReassociate =
+      Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
   bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
                               CanFuse || HasFMAD);
   // If the addition is not contractable, do not combine.
@@ -12028,13 +12030,14 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
   // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
+  // This requires reassociation because it changes the order of operations.
   SDValue FMA, E;
-  if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+  if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
       N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
       N0.getOperand(2).hasOneUse()) {
     FMA = N0;
     E = N1;
-  } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+  } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
              N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
              N1.getOperand(2).hasOneUse()) {
     FMA = N1;

diff  --git a/llvm/test/CodeGen/AArch64/fadd-combines.ll b/llvm/test/CodeGen/AArch64/fadd-combines.ll
index 0e4f2c02c311..2ff485830780 100644
--- a/llvm/test/CodeGen/AArch64/fadd-combines.ll
+++ b/llvm/test/CodeGen/AArch64/fadd-combines.ll
@@ -207,6 +207,10 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
   ret double %a2
 }
 
+; Minimum FMF - the 1st fadd is contracted because that combines
+; fmul+fadd as specified by the order of operations; the 2nd fadd
+; requires reassociation to fuse with c*d.
+
 define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_fmf:
 ; CHECK:       // %bb.0:
@@ -220,13 +224,14 @@ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n
   ret float %a2
 }
 
-; Minimum FMF, commute final add operands, change type.
+; Not minimum FMF.
 
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmadd s2, s2, s3, s4
+; CHECK-NEXT:    fmul s2, s2, s3
 ; CHECK-NEXT:    fmadd s0, s0, s1, s2
+; CHECK-NEXT:    fadd s0, s4, s0
 ; CHECK-NEXT:    ret
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d

diff  --git a/llvm/test/CodeGen/X86/fma_patterns.ll b/llvm/test/CodeGen/X86/fma_patterns.ll
index 3049365b6f32..43b1f4a79aff 100644
--- a/llvm/test/CodeGen/X86/fma_patterns.ll
+++ b/llvm/test/CodeGen/X86/fma_patterns.ll
@@ -1821,6 +1821,10 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
   ret double %a2
 }
 
+; Minimum FMF - the 1st fadd is contracted because that combines
+; fmul+fadd as specified by the order of operations; the 2nd fadd
+; requires reassociation to fuse with c*d.
+
 define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_fmf:
 ; FMA:       # %bb.0:
@@ -1846,25 +1850,28 @@ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n
   ret float %a2
 }
 
-; Minimum FMF, commute final add operands, change type.
+; Not minimum FMF.
 
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_2:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
-; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
+; FMA-NEXT:    vmulss %xmm3, %xmm2, %xmm2
+; FMA-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
+; FMA-NEXT:    vaddss %xmm2, %xmm4, %xmm0
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_2:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
+; FMA4-NEXT:    vmulss %xmm3, %xmm2, %xmm2
 ; FMA4-NEXT:    vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
+; FMA4-NEXT:    vaddss %xmm0, %xmm4, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_2:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
-; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
+; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
+; AVX512-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
+; AVX512-NEXT:    vaddss %xmm2, %xmm4, %xmm0
 ; AVX512-NEXT:    retq
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d


        


More information about the llvm-commits mailing list