[llvm] 0baace5 - [DAGCombine] Add node level checks for fp-contract and fp-ninf in visitFMULForFMADistributiveCombine().

Abinav Puthan Purayil via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 1 23:03:41 PDT 2021


Author: Abinav Puthan Purayil
Date: 2021-09-02T11:33:14+05:30
New Revision: 0baace53799424bd1e3cfe9068ee254fae0ca677

URL: https://github.com/llvm/llvm-project/commit/0baace53799424bd1e3cfe9068ee254fae0ca677
DIFF: https://github.com/llvm/llvm-project/commit/0baace53799424bd1e3cfe9068ee254fae0ca677.diff

LOG: [DAGCombine] Add node level checks for fp-contract and fp-ninf in visitFMULForFMADistributiveCombine().

Differential Revision: https://reviews.llvm.org/D107551

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AMDGPU/fma.ll
    llvm/test/CodeGen/X86/fma-scalar-combine.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index aae9eea994bb8..9c8febcedc59e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13015,6 +13015,20 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
   return DAG.getBuildVector(VT, DL, Ops);
 }
 
+// Returns true if the FMUL `N` may be contracted: fusion is enabled globally
+// (FPOpFusion::Fast or UnsafeFPMath) or `N` itself has the contract flag.
+static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {
+  assert(N.getOpcode() == ISD::FMUL);
+
+  return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
+         N->getFlags().hasAllowContract();
+}
+
+// Return true if `N` can assume no infinities are involved in its computation.
+static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
+  return Options.NoInfsFPMath || N.getNode()->getFlags().hasNoInfs();
+}
+
 /// Try to perform FMA combining on a given FADD node.
 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDValue N0 = N->getOperand(0);
@@ -13557,12 +13571,13 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
 
   // The transforms below are incorrect when x == 0 and y == inf, because the
   // intermediate multiplication produces a nan.
-  if (!Options.NoInfsFPMath)
+  SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1;
+  if (!hasNoInfs(Options, FAdd))
     return SDValue();
 
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
-      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
+      isContractableFMUL(Options, SDValue(N, 0)) &&
       TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 

diff  --git a/llvm/test/CodeGen/AMDGPU/fma.ll b/llvm/test/CodeGen/AMDGPU/fma.ll
index 6fadd91baf671..dcf3254242f4c 100644
--- a/llvm/test/CodeGen/AMDGPU/fma.ll
+++ b/llvm/test/CodeGen/AMDGPU/fma.ll
@@ -144,3 +144,13 @@ bb:
   store float %tmp10, float addrspace(1)* %gep.out
   ret void
 }
+
+; Fold (fmul (fadd x, 1.0), y) -> (fma x, y, y) without FP-specific command-line
+; options.
+; FUNC-LABEL: {{^}}fold_fmul_distributive:
+; GFX906: v_fmac_f32_e32 v0, v1, v0
+define float @fold_fmul_distributive(float %x, float %y) {
+  %fadd = fadd ninf float %y, 1.0
+  %fmul = fmul contract float %fadd, %x
+  ret float %fmul
+}

diff  --git a/llvm/test/CodeGen/X86/fma-scalar-combine.ll b/llvm/test/CodeGen/X86/fma-scalar-combine.ll
index 08ae3430ea20b..7f7e21bf81d51 100644
--- a/llvm/test/CodeGen/X86/fma-scalar-combine.ll
+++ b/llvm/test/CodeGen/X86/fma-scalar-combine.ll
@@ -558,3 +558,16 @@ define float @fma_const_fmul(float %x) {
   %add1 = fadd contract float %mul1, %mul2
   ret float %add1
 }
+
+; Fold (fmul (fadd x, 1.0), y) -> (fma x, y, y) without FP-specific command-line
+; options.
+define float @combine_fmul_distributive(float %x, float %y) {
+; CHECK-LABEL: combine_fmul_distributive:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vfmadd231ss %xmm0, %xmm1, %xmm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x71,0xb9,0xc0]
+; CHECK-NEXT:    # xmm0 = (xmm1 * xmm0) + xmm0
+; CHECK-NEXT:    retq # encoding: [0xc3]
+  %fadd = fadd ninf float %y, 1.0
+  %fmul = fmul contract float %fadd, %x
+  ret float %fmul
+}


        


More information about the llvm-commits mailing list