[llvm] c5276f7 - [X86] Combine constant vector inputs for FMA

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Wed May 17 00:21:42 PDT 2023


Author: Evgenii Kudriashov
Date: 2023-05-17T15:21:34+08:00
New Revision: c5276f7728900ac59512e84bce25b8e69c0fd603

URL: https://github.com/llvm/llvm-project/commit/c5276f7728900ac59512e84bce25b8e69c0fd603
DIFF: https://github.com/llvm/llvm-project/commit/c5276f7728900ac59512e84bce25b8e69c0fd603.diff

LOG: [X86] Combine constant vector inputs for FMA

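If a constant vector operand of an FMA is used only by FMAs and its negation
already exists in the DAG, use the negated vector and negate the FMA instead,
so that only one of the two constants has to be materialized.

Inspired by https://discourse.llvm.org/t/folding-memory-into-fma/69217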

Reviewed By: pengfei

Differential Revision: https://reviews.llvm.org/D146494

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
    llvm/test/CodeGen/X86/fma-fneg-combine-2.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index fa3ff0f04e45..455500c8b4b6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54523,6 +54523,59 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Negating a constant FP vector operand of an FMA is profitable only if the
+// original vector can then be eliminated and its negation is already present
+// in the DAG; otherwise a constant is loaded anyway and nothing is saved.
+//
+// Returns the negation of V if the FMA should use it, or an empty SDValue to
+// keep V. If both are eliminable, the one with a negative first constant wins.
+static SDValue getInvertedVectorForFMA(SDValue V, SelectionDAG &DAG) {
+  assert(ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()) &&
+         "ConstantFP build vector expected");
+  // V is eliminable only if every user is a generic (STRICT_)FMA: this combine
+  // runs on each of those FMAs and can rewrite them all to use the negation.
+  // Target-specific X86ISD FMA nodes are not matched and block elimination.
+  auto IsNotFMA = [](SDNode *Use) {
+    return Use->getOpcode() != ISD::FMA && Use->getOpcode() != ISD::STRICT_FMA;
+  };
+  if (llvm::any_of(V->uses(), IsNotFMA))
+    return SDValue();
+
+  SmallVector<SDValue, 8> Ops;
+  EVT VT = V.getValueType();
+  EVT EltVT = VT.getVectorElementType();
+  for (const SDValue &Op : V->op_values()) {
+    if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
+      Ops.push_back(DAG.getConstantFP(-Cst->getValueAPF(), SDLoc(Op), EltVT));
+    } else {
+      assert(Op.isUndef());
+      Ops.push_back(DAG.getUNDEF(EltVT));
+    }
+  }
+
+  SDNode *NV = DAG.getNodeIfExists(ISD::BUILD_VECTOR, DAG.getVTList(VT), Ops);
+  if (!NV)
+    return SDValue();
+
+  // The negation has non-FMA users and stays live regardless; switch to it
+  // so that V can become dead once every FMA has switched.
+  if (llvm::any_of(NV->uses(), IsNotFMA))
+    return SDValue(NV, 0);
+
+  // The negation is eliminable too, so pick one of the pair consistently:
+  // keep whichever vector has a negative first constant (sign bit set, so
+  // -0.0 counts). Undef lanes preceding the first constant are skipped.
+  // Since -c and c differ in sign, exactly one of the pair is ever preferred.
+  for (const SDValue &Op : V->op_values()) {
+    if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
+      if (Cst->isNegative())
+        return SDValue();
+      break;
+    }
+  }
+  return SDValue(NV, 0);
+}
+
 static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
                           TargetLowering::DAGCombinerInfo &DCI,
                           const X86Subtarget &Subtarget) {
@@ -54574,7 +54627,13 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
         return true;
       }
     }
-
+    // Lookup if there is an inverted version of constant vector V in DAG.
+    if (ISD::isBuildVectorOfConstantFPSDNodes(V.getNode())) {
+      if (SDValue NegV = getInvertedVectorForFMA(V, DAG)) {
+        V = NegV;
+        return true;
+      }
+    }
     return false;
   };
 

diff  --git a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
index 9969734c97e9..5d7fc7aaa185 100644
--- a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -154,15 +154,13 @@ define <4 x double> @test9(<4 x double> %a) {
 ; X32-LABEL: test9:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test9:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT:    vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; X64-NEXT:    retq
   %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> <double -5.000000e-01, double -5.000000e-01, double -5.000000e-01, double -5.000000e-01>)
   ret <4 x double> %t
@@ -172,17 +170,19 @@ define <4 x double> @test10(<4 x double> %a, <4 x double> %b) {
 ; X32-LABEL: test10:
 ; X32:       # %bb.0:
 ; X32-NEXT:    vmovapd {{.*#+}} ymm2 = <-9.5E+0,u,-5.5E+0,-2.5E+0>
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X32-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; X32-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; X32-NEXT:    vmovapd %ymm2, %ymm3
+; X32-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X32-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test10:
 ; X64:       # %bb.0:
 ; X64-NEXT:    vmovapd {{.*#+}} ymm2 = <-9.5E+0,u,-5.5E+0,-2.5E+0>
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X64-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; X64-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; X64-NEXT:    vmovapd %ymm2, %ymm3
+; X64-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X64-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; X64-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -95.00000e-01, double undef, double -55.00000e-01, double -25.00000e-01>, <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 95.00000e-01, double undef, double 55.00000e-01, double 25.00000e-01>, <4 x double> %b)
@@ -196,7 +196,7 @@ define <4 x double> @test11(<4 x double> %a) {
 ; X32-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [5.0E-1,2.5E+0,5.0E-1,2.5E+0]
 ; X32-NEXT:    # ymm1 = mem[0,1,0,1]
 ; X32-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; X32-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test11:
@@ -204,7 +204,7 @@ define <4 x double> @test11(<4 x double> %a) {
 ; X64-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [5.0E-1,2.5E+0,5.0E-1,2.5E+0]
 ; X64-NEXT:    # ymm1 = mem[0,1,0,1]
 ; X64-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; X64-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; X64-NEXT:    retq
   %t0 = fadd <4 x double> %a, <double 5.000000e-01, double 25.00000e-01, double 5.000000e-01, double 25.00000e-01>
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> <double 5.000000e-01, double 25.00000e-01, double 5.000000e-01, double 25.00000e-01>, <4 x double> <double -5.000000e-01, double -25.00000e-01, double -5.000000e-01, double -25.00000e-01>)
@@ -214,20 +214,18 @@ define <4 x double> @test11(<4 x double> %a) {
 define <4 x double> @test12(<4 x double> %a, <4 x double> %b) {
 ; X32-LABEL: test12:
 ; X32:       # %bb.0:
-; X32-NEXT:    vmovapd {{.*#+}} ymm2 = [7.5E+0,2.5E+0,5.5E+0,9.5E+0]
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; X32-NEXT:    vmovapd {{.*#+}} ymm0 = <u,2.5E+0,5.5E+0,9.5E+0>
-; X32-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; X32-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; X32-NEXT:    vmovapd {{.*#+}} ymm2 = [-7.5E+0,-2.5E+0,-5.5E+0,-9.5E+0]
+; X32-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; X32-NEXT:    vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; X32-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; X32-NEXT:    retl
 ;
 ; X64-LABEL: test12:
 ; X64:       # %bb.0:
-; X64-NEXT:    vmovapd {{.*#+}} ymm2 = [7.5E+0,2.5E+0,5.5E+0,9.5E+0]
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; X64-NEXT:    vmovapd {{.*#+}} ymm0 = <u,2.5E+0,5.5E+0,9.5E+0>
-; X64-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; X64-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; X64-NEXT:    vmovapd {{.*#+}} ymm2 = [-7.5E+0,-2.5E+0,-5.5E+0,-9.5E+0]
+; X64-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; X64-NEXT:    vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; X64-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; X64-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 75.00000e-01, double 25.00000e-01, double 55.00000e-01, double 95.00000e-01>, <4 x double> <double -75.00000e-01, double undef, double -55.00000e-01, double -95.00000e-01>)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %b, <4 x double> <double undef, double 25.00000e-01, double 55.00000e-01, double 95.00000e-01>, <4 x double> <double -75.00000e-01, double -25.00000e-01, double -55.00000e-01, double -95.00000e-01>)

diff  --git a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
index d3bc7399789d..2a3c3e3c7f4f 100644
--- a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
+++ b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
@@ -129,14 +129,14 @@ define float @negated_constant(float %x) {
 define <4 x double> @negated_constant_v4f64(<4 x double> %a) {
 ; FMA3-LABEL: negated_constant_v4f64:
 ; FMA3:       # %bb.0:
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,2.5E-1,1.25E-1,6.25E-2]
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT:    vmovapd {{.*#+}} ymm1 = [-5.0E-1,-2.5E-1,-1.25E-1,-6.25E-2]
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [5.0E-1,2.5E-1,1.25E-1,6.25E-2]
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm1 = [-5.0E-1,-2.5E-1,-1.25E-1,-6.25E-2]
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm1) + ymm1
 ; FMA4-NEXT:    retq
   %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 2.5000000e-01, double 1.25000000e-01, double 0.62500000e-01>, <4 x double> <double -5.000000e-01, double -2.5000000e-01, double -1.25000000e-01, double -0.62500000e-01>)
   ret <4 x double> %t
@@ -146,16 +146,18 @@ define <4 x double> @negated_constant_v4f64_2fmas(<4 x double> %a, <4 x double>
 ; FMA3-LABEL: negated_constant_v4f64_2fmas:
 ; FMA3:       # %bb.0:
 ; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-2.5E+0,-4.5E+0>
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; FMA3-NEXT:    vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; FMA3-NEXT:    vaddpd %ymm1, %ymm2, %ymm0
+; FMA3-NEXT:    vmovapd %ymm2, %ymm3
+; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; FMA3-NEXT:    vaddpd %ymm2, %ymm3, %ymm0
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_2fmas:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm2 = (ymm0 * mem) + ymm1
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * mem) + ymm1
-; FMA4-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-2.5E+0,-4.5E+0>
+; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm3 = (ymm0 * ymm2) + ymm1
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + ymm1
+; FMA4-NEXT:    vaddpd %ymm0, %ymm3, %ymm0
 ; FMA4-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -5.000000e-01, double undef, double -25.000000e-01, double -45.000000e-01>, <4 x double> %b)
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double undef, double 25.000000e-01, double 45.000000e-01>, <4 x double> %b)
@@ -169,7 +171,7 @@ define <4 x double> @negated_constant_v4f64_fadd(<4 x double> %a) {
 ; FMA3-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [1.5E+0,1.25E-1,1.5E+0,1.25E-1]
 ; FMA3-NEXT:    # ymm1 = mem[0,1,0,1]
 ; FMA3-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT:    vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_fadd:
@@ -177,7 +179,7 @@ define <4 x double> @negated_constant_v4f64_fadd(<4 x double> %a) {
 ; FMA4-NEXT:    vbroadcastf128 {{.*#+}} ymm1 = [1.5E+0,1.25E-1,1.5E+0,1.25E-1]
 ; FMA4-NEXT:    # ymm1 = mem[0,1,0,1]
 ; FMA4-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT:    vfmsubpd {{.*#+}} ymm0 = (ymm0 * ymm1) - ymm1
 ; FMA4-NEXT:    retq
   %t0 = fadd <4 x double> %a, <double 15.000000e-01, double 1.25000000e-01, double 15.000000e-01, double 1.25000000e-01>
   %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> <double 15.000000e-01, double 1.25000000e-01, double 15.000000e-01, double 1.25000000e-01>, <4 x double> <double -15.000000e-01, double -1.25000000e-01, double -15.000000e-01, double -1.25000000e-01>)
@@ -187,19 +189,17 @@ define <4 x double> @negated_constant_v4f64_fadd(<4 x double> %a) {
 define <4 x double> @negated_constant_v4f64_2fma_undefs(<4 x double> %a, <4 x double> %b) {
 ; FMA3-LABEL: negated_constant_v4f64_2fma_undefs:
 ; FMA3:       # %bb.0:
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; FMA3-NEXT:    vmovapd {{.*#+}} ymm0 = <u,5.0E-1,5.0E-1,5.0E-1>
-; FMA3-NEXT:    vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; FMA3-NEXT:    vaddpd %ymm0, %ymm2, %ymm0
+; FMA3-NEXT:    vmovapd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA3-NEXT:    vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; FMA3-NEXT:    vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; FMA3-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: negated_constant_v4f64_2fma_undefs:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm2) + mem
-; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = <u,5.0E-1,5.0E-1,5.0E-1>
-; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm1 = (ymm1 * ymm2) + mem
+; FMA4-NEXT:    vmovapd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA4-NEXT:    vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + mem
+; FMA4-NEXT:    vfmaddpd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
 ; FMA4-NEXT:    vaddpd %ymm1, %ymm0, %ymm0
 ; FMA4-NEXT:    retq
   %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> <double -5.000000e-01, double undef, double -5.000000e-01, double -5.000000e-01>)


        


More information about the llvm-commits mailing list