[llvm] 702cf93 - [DAGCombiner] allow more folding of fadd + fmul into fma

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 9 07:41:40 PDT 2020


Author: Sanjay Patel
Date: 2020-06-09T10:41:27-04:00
New Revision: 702cf933565ea942c5feb7521c89b237f281c4f3

URL: https://github.com/llvm/llvm-project/commit/702cf933565ea942c5feb7521c89b237f281c4f3
DIFF: https://github.com/llvm/llvm-project/commit/702cf933565ea942c5feb7521c89b237f281c4f3.diff

LOG: [DAGCombiner] allow more folding of fadd + fmul into fma

If fmul and fadd are separated by an fma, we can fold them together
to save an instruction:
fadd (fma A, B, (fmul C, D)), N1 --> fma(A, B, fma(C, D, N1))

The fold implemented here is actually a specialization - we should
be able to peek through >1 fma to find this pattern. That's another
patch if we want to try that enhancement though.

This transform was guarded by the TLI hook enableAggressiveFMAFusion(),
so it was done for some in-tree targets like PowerPC, but not AArch64
or x86. The hook is protecting against forming a potentially more
expensive computation when fma takes longer to execute than a single
fadd. That hook may be needed for other transforms, but in this case,
we are replacing fmul+fadd with fma, and the fma should never take
longer than the 2 individual instructions.

'contract' FMF is all we need to allow this transform. That flag
corresponds to -ffp-contract=fast in Clang, so we are allowed to form
fma ops freely across expressions.

Differential Revision: https://reviews.llvm.org/D80801

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/fadd-combines.ll
    llvm/test/CodeGen/X86/fma_patterns.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e3275caed112..1352a7d64f2a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11919,6 +11919,29 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
                        N1.getOperand(0), N1.getOperand(1), N0, Flags);
   }
 
+  // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
+  // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
+  SDValue FMA, E;
+  if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+      N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
+      N0.getOperand(2).hasOneUse()) {
+    FMA = N0;
+    E = N1;
+  } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+             N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
+             N1.getOperand(2).hasOneUse()) {
+    FMA = N1;
+    E = N0;
+  }
+  if (FMA && E) {
+    SDValue A = FMA.getOperand(0);
+    SDValue B = FMA.getOperand(1);
+    SDValue C = FMA.getOperand(2).getOperand(0);
+    SDValue D = FMA.getOperand(2).getOperand(1);
+    SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E, Flags);
+    return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE, Flags);
+  }
+
   // Look through FP_EXTEND nodes to do more combining.
 
   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
@@ -11952,29 +11975,6 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
   // More folding opportunities when target permits.
   if (Aggressive) {
-    // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
-    // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
-    SDValue FMA, E;
-    if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
-        N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
-        N0.getOperand(2).hasOneUse()) {
-      FMA = N0;
-      E = N1;
-    } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
-               N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
-               N1.getOperand(2).hasOneUse()) {
-      FMA = N1;
-      E = N0;
-    }
-    if (FMA && E) {
-      SDValue A = FMA.getOperand(0);
-      SDValue B = FMA.getOperand(1);
-      SDValue C = FMA.getOperand(2).getOperand(0);
-      SDValue D = FMA.getOperand(2).getOperand(1);
-      SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E, Flags);
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE, Flags);
-    }
-
     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
     auto FoldFAddFMAFPExtFMul = [&] (

diff --git a/llvm/test/CodeGen/AArch64/fadd-combines.ll b/llvm/test/CodeGen/AArch64/fadd-combines.ll
index 3702cc540da3..61d0ecc04e03 100644
--- a/llvm/test/CodeGen/AArch64/fadd-combines.ll
+++ b/llvm/test/CodeGen/AArch64/fadd-combines.ll
@@ -197,9 +197,8 @@ define <2 x double> @fmul2_negated_vec(<2 x double> %a, <2 x double> %b, <2 x do
 define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, double %n1) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmul d2, d2, d3
+; CHECK-NEXT:    fmadd d2, d2, d3, d4
 ; CHECK-NEXT:    fmadd d0, d0, d1, d2
-; CHECK-NEXT:    fadd d0, d0, d4
 ; CHECK-NEXT:    ret
   %m1 = fmul fast double %a, %b
   %m2 = fmul fast double %c, %d
@@ -213,9 +212,8 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmul s2, s2, s3
+; CHECK-NEXT:    fmadd s2, s2, s3, s4
 ; CHECK-NEXT:    fmadd s0, s0, s1, s2
-; CHECK-NEXT:    fadd s0, s4, s0
 ; CHECK-NEXT:    ret
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d
@@ -230,10 +228,10 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
 ; CHECK-LABEL: fadd_fma_fmul_3:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmul v2.2d, v2.2d, v3.2d
-; CHECK-NEXT:    fmul v3.2d, v6.2d, v7.2d
 ; CHECK-NEXT:    fmla v2.2d, v1.2d, v0.2d
-; CHECK-NEXT:    fmla v3.2d, v5.2d, v4.2d
-; CHECK-NEXT:    fadd v0.2d, v2.2d, v3.2d
+; CHECK-NEXT:    fmla v2.2d, v7.2d, v6.2d
+; CHECK-NEXT:    fmla v2.2d, v5.2d, v4.2d
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
   %m1 = fmul fast <2 x double> %x1, %x2
   %m2 = fmul fast <2 x double> %x3, %x4
@@ -245,6 +243,8 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
   ret <2 x double> %a3
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_1:
 ; CHECK:       // %bb.0:
@@ -261,6 +261,8 @@ define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_2:
 ; CHECK:       // %bb.0:
@@ -277,6 +279,8 @@ define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_3(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_3:
 ; CHECK:       // %bb.0:

diff --git a/llvm/test/CodeGen/X86/fma_patterns.ll b/llvm/test/CodeGen/X86/fma_patterns.ll
index 7b9d51147480..a11ddf171830 100644
--- a/llvm/test/CodeGen/X86/fma_patterns.ll
+++ b/llvm/test/CodeGen/X86/fma_patterns.ll
@@ -1799,23 +1799,20 @@ define <4 x double> @test_v4f64_fneg_fmul_no_nsz(<4 x double> %x, <4 x double> %
 define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, double %n1) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_1:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vfmadd231sd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vaddsd %xmm4, %xmm2, %xmm0
+; FMA-NEXT:    vfmadd213sd {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; FMA-NEXT:    vfmadd213sd {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_1:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
+; FMA4-NEXT:    vfmaddsd {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
 ; FMA4-NEXT:    vfmaddsd {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vaddsd %xmm4, %xmm0, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_1:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd231sd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vaddsd %xmm4, %xmm2, %xmm0
+; AVX512-NEXT:    vfmadd213sd {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; AVX512-NEXT:    vfmadd213sd {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; AVX512-NEXT:    retq
   %m1 = fmul fast double %a, %b
   %m2 = fmul fast double %c, %d
@@ -1829,23 +1826,20 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_2:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vaddss %xmm2, %xmm4, %xmm0
+; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_2:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmulss %xmm3, %xmm2, %xmm2
+; FMA4-NEXT:    vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
 ; FMA4-NEXT:    vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vaddss %xmm0, %xmm4, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_2:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vaddss %xmm2, %xmm4, %xmm0
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; AVX512-NEXT:    retq
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d
@@ -1860,28 +1854,27 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
 ; FMA-LABEL: fadd_fma_fmul_3:
 ; FMA:       # %bb.0:
 ; FMA-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm3 = (xmm5 * xmm4) + xmm3
-; FMA-NEXT:    vaddpd %xmm3, %xmm2, %xmm0
+; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm7 * xmm6) + xmm2
+; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm5 * xmm4) + xmm2
+; FMA-NEXT:    vmovapd %xmm2, %xmm0
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_3:
 ; FMA4:       # %bb.0:
 ; FMA4-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; FMA4-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm1 = (xmm4 * xmm5) + xmm3
-; FMA4-NEXT:    vaddpd %xmm1, %xmm0, %xmm0
+; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm6 * xmm7) + xmm0
+; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm4 * xmm5) + xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_3:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm3 = (xmm5 * xmm4) + xmm3
-; AVX512-NEXT:    vaddpd %xmm3, %xmm2, %xmm0
+; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm7 * xmm6) + xmm2
+; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm5 * xmm4) + xmm2
+; AVX512-NEXT:    vmovapd %xmm2, %xmm0
 ; AVX512-NEXT:    retq
   %m1 = fmul fast <2 x double> %x1, %x2
   %m2 = fmul fast <2 x double> %x3, %x4
@@ -1893,6 +1886,8 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
   ret <2 x double> %a3
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_1:
 ; FMA:       # %bb.0:
@@ -1925,6 +1920,8 @@ define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_2:
 ; FMA:       # %bb.0:
@@ -1957,6 +1954,8 @@ define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_3(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_3:
 ; FMA:       # %bb.0:


        


More information about the llvm-commits mailing list