[llvm] 39009a8 - [DAGCombiner] tighten fast-math constraints for fma fold
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 12 05:53:14 PDT 2020
Author: Sanjay Patel
Date: 2020-07-12T08:51:49-04:00
New Revision: 39009a8245dae78250081b16fc679ce338af405a
URL: https://github.com/llvm/llvm-project/commit/39009a8245dae78250081b16fc679ce338af405a
DIFF: https://github.com/llvm/llvm-project/commit/39009a8245dae78250081b16fc679ce338af405a.diff
LOG: [DAGCombiner] tighten fast-math constraints for fma fold
fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
This fold is only allowed when the "reassoc" fast-math flag is present on the fadd.
As discussed in D80801, this transform goes beyond
what is allowed by "contract" FMF (-ffp-contract=fast).
That is because we are fusing the trailing add of 'E' with a
multiply, but without "reassoc", the code mandates that the
products A*B and C*D are added together before adding in 'E'.
I've added this example to the LangRef to try to clarify the
meaning of "contract". If that seems reasonable, we should
probably do something similar for the clang docs because
there does not appear to be any formal spec for the behavior
of -ffp-contract=fast.
Differential Revision: https://reviews.llvm.org/D82499
Added:
Modified:
llvm/docs/LangRef.rst
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/fadd-combines.ll
llvm/test/CodeGen/X86/fma_patterns.ll
Removed:
################################################################################
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c2d6200e67fa..86d315be74bc 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -2778,7 +2778,9 @@ floating-point transformations.
``contract``
Allow floating-point contraction (e.g. fusing a multiply followed by an
- addition into a fused multiply-and-add).
+ addition into a fused multiply-and-add). This does not enable reassociating
+ to form arbitrary contractions. For example, ``(a*b) + (c*d) + e`` can not
+ be transformed into ``(a*b) + ((c*d) + e)`` to create two fma operations.
``afn``
Approximate functions - Allow substitution of approximate calculations for
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0d84cd89f5ae..42e6e12f3f02 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11986,6 +11986,8 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
SDNodeFlags Flags = N->getFlags();
bool CanFuse = Options.UnsafeFPMath || isContractable(N);
+ bool CanReassociate =
+ Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
CanFuse || HasFMAD);
// If the addition is not contractable, do not combine.
@@ -12028,13 +12030,14 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
// fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
// fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
+ // This requires reassociation because it changes the order of operations.
SDValue FMA, E;
- if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+ if (CanReassociate && N0.getOpcode() == PreferredFusedOpcode &&
N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
N0.getOperand(2).hasOneUse()) {
FMA = N0;
E = N1;
- } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+ } else if (CanReassociate && N1.getOpcode() == PreferredFusedOpcode &&
N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
N1.getOperand(2).hasOneUse()) {
FMA = N1;
diff --git a/llvm/test/CodeGen/AArch64/fadd-combines.ll b/llvm/test/CodeGen/AArch64/fadd-combines.ll
index 0e4f2c02c311..2ff485830780 100644
--- a/llvm/test/CodeGen/AArch64/fadd-combines.ll
+++ b/llvm/test/CodeGen/AArch64/fadd-combines.ll
@@ -207,6 +207,10 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
ret double %a2
}
+; Minimum FMF - the 1st fadd is contracted because that combines
+; fmul+fadd as specified by the order of operations; the 2nd fadd
+; requires reassociation to fuse with c*d.
+
define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind {
; CHECK-LABEL: fadd_fma_fmul_fmf:
; CHECK: // %bb.0:
@@ -220,13 +224,14 @@ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n
ret float %a2
}
-; Minimum FMF, commute final add operands, change type.
+; Not minimum FMF.
define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
; CHECK-LABEL: fadd_fma_fmul_2:
; CHECK: // %bb.0:
-; CHECK-NEXT: fmadd s2, s2, s3, s4
+; CHECK-NEXT: fmul s2, s2, s3
; CHECK-NEXT: fmadd s0, s0, s1, s2
+; CHECK-NEXT: fadd s0, s4, s0
; CHECK-NEXT: ret
%m1 = fmul float %a, %b
%m2 = fmul float %c, %d
diff --git a/llvm/test/CodeGen/X86/fma_patterns.ll b/llvm/test/CodeGen/X86/fma_patterns.ll
index 3049365b6f32..43b1f4a79aff 100644
--- a/llvm/test/CodeGen/X86/fma_patterns.ll
+++ b/llvm/test/CodeGen/X86/fma_patterns.ll
@@ -1821,6 +1821,10 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
ret double %a2
}
+; Minimum FMF - the 1st fadd is contracted because that combines
+; fmul+fadd as specified by the order of operations; the 2nd fadd
+; requires reassociation to fuse with c*d.
+
define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n0) nounwind {
; FMA-LABEL: fadd_fma_fmul_fmf:
; FMA: # %bb.0:
@@ -1846,25 +1850,28 @@ define float @fadd_fma_fmul_fmf(float %a, float %b, float %c, float %d, float %n
ret float %a2
}
-; Minimum FMF, commute final add operands, change type.
+; Not minimum FMF.
define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
; FMA-LABEL: fadd_fma_fmul_2:
; FMA: # %bb.0:
-; FMA-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
-; FMA-NEXT: vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
+; FMA-NEXT: vmulss %xmm3, %xmm2, %xmm2
+; FMA-NEXT: vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
+; FMA-NEXT: vaddss %xmm2, %xmm4, %xmm0
; FMA-NEXT: retq
;
; FMA4-LABEL: fadd_fma_fmul_2:
; FMA4: # %bb.0:
-; FMA4-NEXT: vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
+; FMA4-NEXT: vmulss %xmm3, %xmm2, %xmm2
; FMA4-NEXT: vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
+; FMA4-NEXT: vaddss %xmm0, %xmm4, %xmm0
; FMA4-NEXT: retq
;
; AVX512-LABEL: fadd_fma_fmul_2:
; AVX512: # %bb.0:
-; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
-; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
+; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2
+; AVX512-NEXT: vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
+; AVX512-NEXT: vaddss %xmm2, %xmm4, %xmm0
; AVX512-NEXT: retq
%m1 = fmul float %a, %b
%m2 = fmul float %c, %d
More information about the llvm-commits mailing list