[llvm] c5276f7 - [X86] Combine constant vector inputs for FMA
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed May 17 00:21:42 PDT 2023
Author: Evgenii Kudriashov
Date: 2023-05-17T15:21:34+08:00
New Revision: c5276f7728900ac59512e84bce25b8e69c0fd603
URL: https://github.com/llvm/llvm-project/commit/c5276f7728900ac59512e84bce25b8e69c0fd603
DIFF: https://github.com/llvm/llvm-project/commit/c5276f7728900ac59512e84bce25b8e69c0fd603.diff
LOG: [X86] Combine constant vector inputs for FMA
Inspired by https://discourse.llvm.org/t/folding-memory-into-fma/69217
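In short (as the new code comments explain): when a constant FP build vector
feeding an FMA is used only by FMAs, and its negated form already exists in
the DAG, the combine replaces the operand with the negated vector and flips
the FMA opcode, so only one of the two constants has to be materialized. A
sketch distilled from test9 in the updated avx2-fma-fneg-combine.ll below:

  ; The multiplicand <0.5,...> and the addend <-0.5,...> are negations of each
  ; other and both are FMA-only, so the multiplicand is rewritten to reuse the
  ; addend's register:
  %t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> <double -5.000000e-01, double -5.000000e-01, double -5.000000e-01, double -5.000000e-01>)

  ; before: vbroadcastsd ymm1 = [-0.5,...]   after: vbroadcastsd ymm1 = [-0.5,...]
  ;         vbroadcastsd ymm2 = [0.5,...]           vfnmadd213pd ymm0 = -(ymm1*ymm0)+ymm1
  ;         vfmadd213pd  ymm0 = (ymm2*ymm0)+ymm1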
Reviewed By: pengfei
Differential Revision: https://reviews.llvm.org/D146494
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index fa3ff0f04e45..455500c8b4b6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54523,6 +54523,59 @@ static SDValue combineSext(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+// Inverting a constant vector is profitable if it can be eliminated and the
+// inverted vector is already present in the DAG. Otherwise, it will be loaded
+// anyway.
+//
+// We determine which of the values can be completely eliminated and invert it.
+// If both are eliminable, prefer the one whose leading element is negative.
+static SDValue getInvertedVectorForFMA(SDValue V, SelectionDAG &DAG) {
+ assert(ISD::isBuildVectorOfConstantFPSDNodes(V.getNode()) &&
+ "ConstantFP build vector expected");
+ // Check if we can eliminate V. We assume that if a value is only used in
+ // FMAs, we can eliminate it, since this function is invoked for each FMA
+ // with this vector.
+ auto IsNotFMA = [](SDNode *Use) {
+ return Use->getOpcode() != ISD::FMA && Use->getOpcode() != ISD::STRICT_FMA;
+ };
+ if (llvm::any_of(V->uses(), IsNotFMA))
+ return SDValue();
+
+ SmallVector<SDValue, 8> Ops;
+ EVT VT = V.getValueType();
+ EVT EltVT = VT.getVectorElementType();
+ for (auto Op : V->op_values()) {
+ if (auto *Cst = dyn_cast<ConstantFPSDNode>(Op)) {
+ Ops.push_back(DAG.getConstantFP(-Cst->getValueAPF(), SDLoc(Op), EltVT));
+ } else {
+ assert(Op.isUndef());
+ Ops.push_back(DAG.getUNDEF(EltVT));
+ }
+ }
+
+ SDNode *NV = DAG.getNodeIfExists(ISD::BUILD_VECTOR, DAG.getVTList(VT), Ops);
+ if (!NV)
+ return SDValue();
+
+ // If the inverted version cannot be eliminated, it will be materialized
+ // anyway, so choose it instead of the original version.
+ if (llvm::any_of(NV->uses(), IsNotFMA))
+ return SDValue(NV, 0);
+
+ // If the inverted version can also be eliminated, we have to consistently
+ // prefer one of the two values. We prefer the one whose first (non-undef)
+ // element is negative.
+ // N.B. We need to skip undefs that may precede a value.
+ for (auto op : V->op_values()) {
+ if (auto *Cst = dyn_cast<ConstantFPSDNode>(op)) {
+ if (Cst->isNegative())
+ return SDValue();
+ break;
+ }
+ }
+ return SDValue(NV, 0);
+}
+
static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -54574,7 +54627,13 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
return true;
}
}
-
+ // Look up if there is an inverted version of constant vector V in the DAG.
+ if (ISD::isBuildVectorOfConstantFPSDNodes(V.getNode())) {
+ if (SDValue NegV = getInvertedVectorForFMA(V, DAG)) {
+ V = NegV;
+ return true;
+ }
+ }
return false;
};
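When both a constant and its negation can be eliminated (each is used only by
FMAs), the tie-break above makes every call settle consistently on the vector
whose first non-undef element is negative. A sketch of the effect, distilled
from negated_constant_v4f64_2fmas in the test update below:

  %t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -5.000000e-01, double undef, double -2.500000e+00, double -4.500000e+00>, <4 x double> %b)
  %t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double undef, double 2.500000e+00, double 4.500000e+00>, <4 x double> %b)
  ; Only <-0.5,u,-2.5,-4.5> is materialized; %t1 turns into a VFNMADD on the
  ; same register instead of loading the positive twin from memory.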
diff --git a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
index 9969734c97e9..5d7fc7aaa185 100644
--- a/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
+++ b/llvm/test/CodeGen/X86/avx2-fma-fneg-combine.ll
@@ -154,15 +154,13 @@ define <4 x double> @test9(<4 x double> %a) {
; X32-LABEL: test9:
; X32: # %bb.0:
; X32-NEXT: vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X32-NEXT: vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X32-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X32-NEXT: vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
; X32-NEXT: retl
;
; X64-LABEL: test9:
; X64: # %bb.0:
; X64-NEXT: vbroadcastsd {{.*#+}} ymm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; X64-NEXT: vbroadcastsd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; X64-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm2 * ymm0) + ymm1
+; X64-NEXT: vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
; X64-NEXT: retq
%t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> <double -5.000000e-01, double -5.000000e-01, double -5.000000e-01, double -5.000000e-01>)
ret <4 x double> %t
@@ -172,17 +170,19 @@ define <4 x double> @test10(<4 x double> %a, <4 x double> %b) {
; X32-LABEL: test10:
; X32: # %bb.0:
; X32-NEXT: vmovapd {{.*#+}} ymm2 = <-9.5E+0,u,-5.5E+0,-2.5E+0>
-; X32-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X32-NEXT: vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; X32-NEXT: vaddpd %ymm1, %ymm2, %ymm0
+; X32-NEXT: vmovapd %ymm2, %ymm3
+; X32-NEXT: vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; X32-NEXT: vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X32-NEXT: vaddpd %ymm2, %ymm3, %ymm0
; X32-NEXT: retl
;
; X64-LABEL: test10:
; X64: # %bb.0:
; X64-NEXT: vmovapd {{.*#+}} ymm2 = <-9.5E+0,u,-5.5E+0,-2.5E+0>
-; X64-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; X64-NEXT: vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; X64-NEXT: vaddpd %ymm1, %ymm2, %ymm0
+; X64-NEXT: vmovapd %ymm2, %ymm3
+; X64-NEXT: vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; X64-NEXT: vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; X64-NEXT: vaddpd %ymm2, %ymm3, %ymm0
; X64-NEXT: retq
%t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -95.00000e-01, double undef, double -55.00000e-01, double -25.00000e-01>, <4 x double> %b)
%t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 95.00000e-01, double undef, double 55.00000e-01, double 25.00000e-01>, <4 x double> %b)
@@ -196,7 +196,7 @@ define <4 x double> @test11(<4 x double> %a) {
; X32-NEXT: vbroadcastf128 {{.*#+}} ymm1 = [5.0E-1,2.5E+0,5.0E-1,2.5E+0]
; X32-NEXT: # ymm1 = mem[0,1,0,1]
; X32-NEXT: vaddpd %ymm1, %ymm0, %ymm0
-; X32-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; X32-NEXT: vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
; X32-NEXT: retl
;
; X64-LABEL: test11:
@@ -204,7 +204,7 @@ define <4 x double> @test11(<4 x double> %a) {
; X64-NEXT: vbroadcastf128 {{.*#+}} ymm1 = [5.0E-1,2.5E+0,5.0E-1,2.5E+0]
; X64-NEXT: # ymm1 = mem[0,1,0,1]
; X64-NEXT: vaddpd %ymm1, %ymm0, %ymm0
-; X64-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; X64-NEXT: vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
; X64-NEXT: retq
%t0 = fadd <4 x double> %a, <double 5.000000e-01, double 25.00000e-01, double 5.000000e-01, double 25.00000e-01>
%t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> <double 5.000000e-01, double 25.00000e-01, double 5.000000e-01, double 25.00000e-01>, <4 x double> <double -5.000000e-01, double -25.00000e-01, double -5.000000e-01, double -25.00000e-01>)
@@ -214,20 +214,18 @@ define <4 x double> @test11(<4 x double> %a) {
define <4 x double> @test12(<4 x double> %a, <4 x double> %b) {
; X32-LABEL: test12:
; X32: # %bb.0:
-; X32-NEXT: vmovapd {{.*#+}} ymm2 = [7.5E+0,2.5E+0,5.5E+0,9.5E+0]
-; X32-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; X32-NEXT: vmovapd {{.*#+}} ymm0 = <u,2.5E+0,5.5E+0,9.5E+0>
-; X32-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; X32-NEXT: vaddpd %ymm0, %ymm2, %ymm0
+; X32-NEXT: vmovapd {{.*#+}} ymm2 = [-7.5E+0,-2.5E+0,-5.5E+0,-9.5E+0]
+; X32-NEXT: vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; X32-NEXT: vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; X32-NEXT: vaddpd %ymm1, %ymm0, %ymm0
; X32-NEXT: retl
;
; X64-LABEL: test12:
; X64: # %bb.0:
-; X64-NEXT: vmovapd {{.*#+}} ymm2 = [7.5E+0,2.5E+0,5.5E+0,9.5E+0]
-; X64-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; X64-NEXT: vmovapd {{.*#+}} ymm0 = <u,2.5E+0,5.5E+0,9.5E+0>
-; X64-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; X64-NEXT: vaddpd %ymm0, %ymm2, %ymm0
+; X64-NEXT: vmovapd {{.*#+}} ymm2 = [-7.5E+0,-2.5E+0,-5.5E+0,-9.5E+0]
+; X64-NEXT: vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; X64-NEXT: vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; X64-NEXT: vaddpd %ymm1, %ymm0, %ymm0
; X64-NEXT: retq
%t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 75.00000e-01, double 25.00000e-01, double 55.00000e-01, double 95.00000e-01>, <4 x double> <double -75.00000e-01, double undef, double -55.00000e-01, double -95.00000e-01>)
%t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %b, <4 x double> <double undef, double 25.00000e-01, double 55.00000e-01, double 95.00000e-01>, <4 x double> <double -75.00000e-01, double -25.00000e-01, double -55.00000e-01, double -95.00000e-01>)
diff --git a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
index d3bc7399789d..2a3c3e3c7f4f 100644
--- a/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
+++ b/llvm/test/CodeGen/X86/fma-fneg-combine-2.ll
@@ -129,14 +129,14 @@ define float @negated_constant(float %x) {
define <4 x double> @negated_constant_v4f64(<4 x double> %a) {
; FMA3-LABEL: negated_constant_v4f64:
; FMA3: # %bb.0:
-; FMA3-NEXT: vmovapd {{.*#+}} ymm1 = [5.0E-1,2.5E-1,1.25E-1,6.25E-2]
-; FMA3-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT: vmovapd {{.*#+}} ymm1 = [-5.0E-1,-2.5E-1,-1.25E-1,-6.25E-2]
+; FMA3-NEXT: vfnmadd213pd {{.*#+}} ymm0 = -(ymm1 * ymm0) + ymm1
; FMA3-NEXT: retq
;
; FMA4-LABEL: negated_constant_v4f64:
; FMA4: # %bb.0:
-; FMA4-NEXT: vmovapd {{.*#+}} ymm1 = [5.0E-1,2.5E-1,1.25E-1,6.25E-2]
-; FMA4-NEXT: vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT: vmovapd {{.*#+}} ymm1 = [-5.0E-1,-2.5E-1,-1.25E-1,-6.25E-2]
+; FMA4-NEXT: vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm1) + ymm1
; FMA4-NEXT: retq
%t = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 2.5000000e-01, double 1.25000000e-01, double 0.62500000e-01>, <4 x double> <double -5.000000e-01, double -2.5000000e-01, double -1.25000000e-01, double -0.62500000e-01>)
ret <4 x double> %t
@@ -146,16 +146,18 @@ define <4 x double> @negated_constant_v4f64_2fmas(<4 x double> %a, <4 x double>
; FMA3-LABEL: negated_constant_v4f64_2fmas:
; FMA3: # %bb.0:
; FMA3-NEXT: vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-2.5E+0,-4.5E+0>
-; FMA3-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + ymm1
-; FMA3-NEXT: vfmadd231pd {{.*#+}} ymm1 = (ymm0 * mem) + ymm1
-; FMA3-NEXT: vaddpd %ymm1, %ymm2, %ymm0
+; FMA3-NEXT: vmovapd %ymm2, %ymm3
+; FMA3-NEXT: vfmadd213pd {{.*#+}} ymm3 = (ymm0 * ymm3) + ymm1
+; FMA3-NEXT: vfnmadd213pd {{.*#+}} ymm2 = -(ymm0 * ymm2) + ymm1
+; FMA3-NEXT: vaddpd %ymm2, %ymm3, %ymm0
; FMA3-NEXT: retq
;
; FMA4-LABEL: negated_constant_v4f64_2fmas:
; FMA4: # %bb.0:
-; FMA4-NEXT: vfmaddpd {{.*#+}} ymm2 = (ymm0 * mem) + ymm1
-; FMA4-NEXT: vfmaddpd {{.*#+}} ymm0 = (ymm0 * mem) + ymm1
-; FMA4-NEXT: vaddpd %ymm0, %ymm2, %ymm0
+; FMA4-NEXT: vmovapd {{.*#+}} ymm2 = <-5.0E-1,u,-2.5E+0,-4.5E+0>
+; FMA4-NEXT: vfmaddpd {{.*#+}} ymm3 = (ymm0 * ymm2) + ymm1
+; FMA4-NEXT: vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + ymm1
+; FMA4-NEXT: vaddpd %ymm0, %ymm3, %ymm0
; FMA4-NEXT: retq
%t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double -5.000000e-01, double undef, double -25.000000e-01, double -45.000000e-01>, <4 x double> %b)
%t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double undef, double 25.000000e-01, double 45.000000e-01>, <4 x double> %b)
@@ -169,7 +171,7 @@ define <4 x double> @negated_constant_v4f64_fadd(<4 x double> %a) {
; FMA3-NEXT: vbroadcastf128 {{.*#+}} ymm1 = [1.5E+0,1.25E-1,1.5E+0,1.25E-1]
; FMA3-NEXT: # ymm1 = mem[0,1,0,1]
; FMA3-NEXT: vaddpd %ymm1, %ymm0, %ymm0
-; FMA3-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
+; FMA3-NEXT: vfmsub213pd {{.*#+}} ymm0 = (ymm1 * ymm0) - ymm1
; FMA3-NEXT: retq
;
; FMA4-LABEL: negated_constant_v4f64_fadd:
@@ -177,7 +179,7 @@ define <4 x double> @negated_constant_v4f64_fadd(<4 x double> %a) {
; FMA4-NEXT: vbroadcastf128 {{.*#+}} ymm1 = [1.5E+0,1.25E-1,1.5E+0,1.25E-1]
; FMA4-NEXT: # ymm1 = mem[0,1,0,1]
; FMA4-NEXT: vaddpd %ymm1, %ymm0, %ymm0
-; FMA4-NEXT: vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm1) + mem
+; FMA4-NEXT: vfmsubpd {{.*#+}} ymm0 = (ymm0 * ymm1) - ymm1
; FMA4-NEXT: retq
%t0 = fadd <4 x double> %a, <double 15.000000e-01, double 1.25000000e-01, double 15.000000e-01, double 1.25000000e-01>
%t1 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %t0, <4 x double> <double 15.000000e-01, double 1.25000000e-01, double 15.000000e-01, double 1.25000000e-01>, <4 x double> <double -15.000000e-01, double -1.25000000e-01, double -15.000000e-01, double -1.25000000e-01>)
@@ -187,19 +189,17 @@ define <4 x double> @negated_constant_v4f64_fadd(<4 x double> %a) {
define <4 x double> @negated_constant_v4f64_2fma_undefs(<4 x double> %a, <4 x double> %b) {
; FMA3-LABEL: negated_constant_v4f64_2fma_undefs:
; FMA3: # %bb.0:
-; FMA3-NEXT: vmovapd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA3-NEXT: vfmadd213pd {{.*#+}} ymm2 = (ymm0 * ymm2) + mem
-; FMA3-NEXT: vmovapd {{.*#+}} ymm0 = <u,5.0E-1,5.0E-1,5.0E-1>
-; FMA3-NEXT: vfmadd213pd {{.*#+}} ymm0 = (ymm1 * ymm0) + mem
-; FMA3-NEXT: vaddpd %ymm0, %ymm2, %ymm0
+; FMA3-NEXT: vmovapd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA3-NEXT: vfnmadd213pd {{.*#+}} ymm0 = -(ymm2 * ymm0) + mem
+; FMA3-NEXT: vfmadd132pd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
+; FMA3-NEXT: vaddpd %ymm1, %ymm0, %ymm0
; FMA3-NEXT: retq
;
; FMA4-LABEL: negated_constant_v4f64_2fma_undefs:
; FMA4: # %bb.0:
-; FMA4-NEXT: vmovapd {{.*#+}} ymm2 = [5.0E-1,5.0E-1,5.0E-1,5.0E-1]
-; FMA4-NEXT: vfmaddpd {{.*#+}} ymm0 = (ymm0 * ymm2) + mem
-; FMA4-NEXT: vmovapd {{.*#+}} ymm2 = <u,5.0E-1,5.0E-1,5.0E-1>
-; FMA4-NEXT: vfmaddpd {{.*#+}} ymm1 = (ymm1 * ymm2) + mem
+; FMA4-NEXT: vmovapd {{.*#+}} ymm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; FMA4-NEXT: vfnmaddpd {{.*#+}} ymm0 = -(ymm0 * ymm2) + mem
+; FMA4-NEXT: vfmaddpd {{.*#+}} ymm1 = (ymm1 * mem) + ymm2
; FMA4-NEXT: vaddpd %ymm1, %ymm0, %ymm0
; FMA4-NEXT: retq
%t0 = tail call <4 x double> @llvm.fma.v4f64(<4 x double> %a, <4 x double> <double 5.000000e-01, double 5.000000e-01, double 5.000000e-01, double 5.000000e-01>, <4 x double> <double -5.000000e-01, double undef, double -5.000000e-01, double -5.000000e-01>)
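To reproduce the new lowering locally, running llc with FMA enabled over
either test should be enough (a sketch; the authoritative RUN lines are at
the top of each test file):

  llc -mtriple=x86_64-unknown-unknown -mattr=+fma \
      llvm/test/CodeGen/X86/fma-fneg-combine-2.ll -o -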