[llvm] r210644 - [X86] Refactor the logic to select horizontal adds/subs to a helper function.
Andrea Di Biagio
Andrea_DiBiagio at sn.scee.net
Wed Jun 11 00:57:50 PDT 2014
Author: adibiagio
Date: Wed Jun 11 02:57:50 2014
New Revision: 210644
URL: http://llvm.org/viewvc/llvm-project?rev=210644&view=rev
Log:
[X86] Refactor the logic to select horizontal adds/subs to a helper function.
This patch moves part of the logic implemented by the target specific
combine rules added at r210477 to a separate helper function.
This should make it easier to add more rules for matching AVX/AVX2 horizontal
adds/subs.
This patch also fixes a problem caused by an incorrect check on the indices
of the extract_vector_elt dag nodes that feed the scalar adds/subs.
New tests have been added to verify that we correctly check indices of
extract_vector_elt dag nodes when selecting a horizontal operation.
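As an illustration (not part of this patch), a build_vector qualifies for the
new helper when every element is a binop of two extracts of adjacent lanes
from the same source vector, with the first half of the elements drawn from
one vector and the second half from another. A hypothetical IR function in
the style of the existing tests in haddsub-2.ll, which with SSE3 enabled
should lower to a single haddps, would look like:

define <4 x float> @hadd_example(<4 x float> %A, <4 x float> %B) {
  ; lanes 0 and 1 come from adjacent pairs of %A
  %a0 = extractelement <4 x float> %A, i32 0
  %a1 = extractelement <4 x float> %A, i32 1
  %s0 = fadd float %a0, %a1
  %r0 = insertelement <4 x float> undef, float %s0, i32 0
  %a2 = extractelement <4 x float> %A, i32 2
  %a3 = extractelement <4 x float> %A, i32 3
  %s1 = fadd float %a2, %a3
  %r1 = insertelement <4 x float> %r0, float %s1, i32 1
  ; lanes 2 and 3 come from adjacent pairs of %B
  %b0 = extractelement <4 x float> %B, i32 0
  %b1 = extractelement <4 x float> %B, i32 1
  %s2 = fadd float %b0, %b1
  %r2 = insertelement <4 x float> %r1, float %s2, i32 2
  %b2 = extractelement <4 x float> %B, i32 2
  %b3 = extractelement <4 x float> %B, i32 3
  %s3 = fadd float %b2, %b3
  %r3 = insertelement <4 x float> %r2, float %s3, i32 3
  ret <4 x float> %r3
}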
Modified:
llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
llvm/trunk/test/CodeGen/X86/haddsub-2.ll
Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=210644&r1=210643&r2=210644&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Jun 11 02:57:50 2014
@@ -6057,102 +6057,130 @@ X86TargetLowering::LowerBUILD_VECTORvXi1
return DAG.getNode(ISD::BITCAST, dl, VT, Select);
}
-static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget *Subtarget) {
- EVT VT = N->getValueType(0);
-
- // Try to match a horizontal ADD or SUB.
- if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) ||
- ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) ||
- ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
- VT == MVT::v16i16) && Subtarget->hasAVX())) {
- unsigned NumOperands = N->getNumOperands();
- unsigned Opcode = N->getOperand(0)->getOpcode();
- bool isCommutable = false;
- bool CanFold = false;
- switch (Opcode) {
- default : break;
- case ISD::ADD :
- case ISD::FADD :
- isCommutable = true;
- // FALL-THROUGH
- case ISD::SUB :
- case ISD::FSUB :
- CanFold = true;
- }
+/// \brief Return true if \p N implements a horizontal binop and return the
+/// operands for the horizontal binop into V0 and V1.
+///
+/// This is a helper function of PerformBUILD_VECTORCombine.
+/// This function checks that the build_vector \p N in input implements a
+/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
+/// operation to match.
+/// For example, if \p Opcode is equal to ISD::ADD, then this function
+/// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
+/// is equal to ISD::SUB, then this function checks if this is a horizontal
+/// arithmetic sub.
+///
+/// This function only analyzes elements of \p N whose indices are
+/// in range [BaseIdx, LastIdx).
+static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
+ unsigned BaseIdx, unsigned LastIdx,
+ SDValue &V0, SDValue &V1) {
+ assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
+ assert(N->getValueType(0).isVector() &&
+ N->getValueType(0).getVectorNumElements() >= LastIdx &&
+ "Invalid Vector in input!");
+
+ bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
+ bool CanFold = true;
+ unsigned ExpectedVExtractIdx = BaseIdx;
+ unsigned NumElts = LastIdx - BaseIdx;
+
+ // Check if N implements a horizontal binop.
+ for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
+ SDValue Op = N->getOperand(i + BaseIdx);
+ CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
+
+ if (!CanFold)
+ break;
+
+ SDValue Op0 = Op.getOperand(0);
+ SDValue Op1 = Op.getOperand(1);
+
+ // Try to match the following pattern:
+ // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
+ CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ Op0.getOperand(0) == Op1.getOperand(0) &&
+ isa<ConstantSDNode>(Op0.getOperand(1)) &&
+ isa<ConstantSDNode>(Op1.getOperand(1)));
+ if (!CanFold)
+ break;
- // Verify that operands have the same opcode; also, the opcode can only
- // be either of: ADD, FADD, SUB, FSUB.
- SDValue InVec0, InVec1;
- for (unsigned i = 0, e = NumOperands; i != e && CanFold; ++i) {
- SDValue Op = N->getOperand(i);
- CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
-
- if (!CanFold)
- break;
-
- SDValue Op0 = Op.getOperand(0);
- SDValue Op1 = Op.getOperand(1);
-
- // Try to match the following pattern:
- // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
- CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- Op0.getOperand(0) == Op1.getOperand(0) &&
- isa<ConstantSDNode>(Op0.getOperand(1)) &&
- isa<ConstantSDNode>(Op1.getOperand(1)));
- if (!CanFold)
- break;
-
- unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
- unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
- unsigned ExpectedIndex = (i * 2) % NumOperands;
+ unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
+ unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
- if (i == 0)
- InVec0 = Op0.getOperand(0);
- else if (i * 2 == NumOperands)
- InVec1 = Op0.getOperand(0);
-
- SDValue Expected = (i * 2 < NumOperands) ? InVec0 : InVec1;
- if (I0 == ExpectedIndex)
- CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
- else if (isCommutable && I1 == ExpectedIndex) {
- // Try to see if we can match the following dag sequence:
- // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
- CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
- }
+ if (i == 0)
+ V0 = Op0.getOperand(0);
+ else if (i * 2 == NumElts) {
+ V1 = Op0.getOperand(0);
+ ExpectedVExtractIdx = BaseIdx;
}
- if (CanFold) {
- unsigned NewOpcode;
- switch (Opcode) {
- default : llvm_unreachable("Unexpected opcode found!");
- case ISD::ADD : NewOpcode = X86ISD::HADD; break;
- case ISD::FADD : NewOpcode = X86ISD::FHADD; break;
- case ISD::SUB : NewOpcode = X86ISD::HSUB; break;
- case ISD::FSUB : NewOpcode = X86ISD::FHSUB; break;
- }
-
- if (VT.is256BitVector()) {
- SDLoc dl(N);
+ SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
+ if (I0 == ExpectedVExtractIdx)
+ CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
+ else if (IsCommutable && I1 == ExpectedVExtractIdx) {
+ // Try to match the following dag sequence:
+ // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
+ CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
+ } else
+ CanFold = false;
- // Convert this sequence into two horizontal add/sub followed
- // by a concat vector.
- SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, dl);
- SDValue InVec0_HI =
- Extract128BitVector(InVec0, NumOperands/2, DAG, dl);
- SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, dl);
- SDValue InVec1_HI =
- Extract128BitVector(InVec1, NumOperands/2, DAG, dl);
- EVT NewVT = InVec0_LO.getValueType();
-
- SDValue LO = DAG.getNode(NewOpcode, dl, NewVT, InVec0_LO, InVec0_HI);
- SDValue HI = DAG.getNode(NewOpcode, dl, NewVT, InVec1_LO, InVec1_HI);
- return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LO, HI);
- }
+ ExpectedVExtractIdx += 2;
+ }
- return DAG.getNode(NewOpcode, SDLoc(N), VT, InVec0, InVec1);
- }
+ return CanFold;
+}
+
+static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
+ const X86Subtarget *Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ unsigned NumElts = VT.getVectorNumElements();
+ BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N);
+ SDValue InVec0, InVec1;
+
+ // Try to match horizontal ADD/SUB.
+ if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) {
+ // Try to match an SSE3 float HADD/HSUB.
+ if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
+
+ if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
+ } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) {
+ // Try to match an SSSE3 integer HADD/HSUB.
+ if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
+
+ if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+ return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
+ }
+
+ if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
+ VT == MVT::v16i16) && Subtarget->hasAVX()) {
+ unsigned X86Opcode;
+ if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::HADD;
+ else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::HSUB;
+ else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::FHADD;
+ else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+ X86Opcode = X86ISD::FHSUB;
+ else
+ return SDValue();
+
+ // Convert this build_vector into two horizontal add/sub followed by
+ // a concat vector.
+ SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, DL);
+ SDValue InVec0_HI = Extract128BitVector(InVec0, NumElts/2, DAG, DL);
+ SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, DL);
+ SDValue InVec1_HI = Extract128BitVector(InVec1, NumElts/2, DAG, DL);
+ EVT NewVT = InVec0_LO.getValueType();
+
+ SDValue LO = DAG.getNode(X86Opcode, DL, NewVT, InVec0_LO, InVec0_HI);
+ SDValue HI = DAG.getNode(X86Opcode, DL, NewVT, InVec1_LO, InVec1_HI);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
}
return SDValue();
Modified: llvm/trunk/test/CodeGen/X86/haddsub-2.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/haddsub-2.ll?rev=210644&r1=210643&r2=210644&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/haddsub-2.ll (original)
+++ llvm/trunk/test/CodeGen/X86/haddsub-2.ll Wed Jun 11 02:57:50 2014
@@ -86,12 +86,12 @@ define <4 x float> @hsub_ps_test2(<4 x f
%vecext3 = extractelement <4 x float> %A, i32 1
%sub4 = fsub float %vecext2, %vecext3
%vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
- %vecext6 = extractelement <4 x float> %B, i32 3
- %vecext7 = extractelement <4 x float> %B, i32 2
+ %vecext6 = extractelement <4 x float> %B, i32 2
+ %vecext7 = extractelement <4 x float> %B, i32 3
%sub8 = fsub float %vecext6, %vecext7
%vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
- %vecext10 = extractelement <4 x float> %B, i32 1
- %vecext11 = extractelement <4 x float> %B, i32 0
+ %vecext10 = extractelement <4 x float> %B, i32 0
+ %vecext11 = extractelement <4 x float> %B, i32 1
%sub12 = fsub float %vecext10, %vecext11
%vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
ret <4 x float> %vecinit13
@@ -137,12 +137,12 @@ define <4 x i32> @phadd_d_test2(<4 x i32
%vecext3 = extractelement <4 x i32> %A, i32 1
%add4 = add i32 %vecext2, %vecext3
%vecinit5 = insertelement <4 x i32> %vecinit, i32 %add4, i32 0
- %vecext6 = extractelement <4 x i32> %B, i32 2
- %vecext7 = extractelement <4 x i32> %B, i32 3
+ %vecext6 = extractelement <4 x i32> %B, i32 3
+ %vecext7 = extractelement <4 x i32> %B, i32 2
%add8 = add i32 %vecext6, %vecext7
%vecinit9 = insertelement <4 x i32> %vecinit5, i32 %add8, i32 3
- %vecext10 = extractelement <4 x i32> %B, i32 0
- %vecext11 = extractelement <4 x i32> %B, i32 1
+ %vecext10 = extractelement <4 x i32> %B, i32 1
+ %vecext11 = extractelement <4 x i32> %B, i32 0
%add12 = add i32 %vecext10, %vecext11
%vecinit13 = insertelement <4 x i32> %vecinit9, i32 %add12, i32 2
ret <4 x i32> %vecinit13
@@ -191,12 +191,12 @@ define <4 x i32> @phsub_d_test2(<4 x i32
%vecext3 = extractelement <4 x i32> %A, i32 1
%sub4 = sub i32 %vecext2, %vecext3
%vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 0
- %vecext6 = extractelement <4 x i32> %B, i32 3
- %vecext7 = extractelement <4 x i32> %B, i32 2
+ %vecext6 = extractelement <4 x i32> %B, i32 2
+ %vecext7 = extractelement <4 x i32> %B, i32 3
%sub8 = sub i32 %vecext6, %vecext7
%vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 3
- %vecext10 = extractelement <4 x i32> %B, i32 1
- %vecext11 = extractelement <4 x i32> %B, i32 0
+ %vecext10 = extractelement <4 x i32> %B, i32 0
+ %vecext11 = extractelement <4 x i32> %B, i32 1
%sub12 = sub i32 %vecext10, %vecext11
%vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 2
ret <4 x i32> %vecinit13
@@ -258,14 +258,14 @@ define <2 x double> @hsub_pd_test1(<2 x
define <2 x double> @hsub_pd_test2(<2 x double> %A, <2 x double> %B) {
- %vecext = extractelement <2 x double> %A, i32 1
- %vecext1 = extractelement <2 x double> %A, i32 0
+ %vecext = extractelement <2 x double> %B, i32 0
+ %vecext1 = extractelement <2 x double> %B, i32 1
%sub = fsub double %vecext, %vecext1
- %vecinit = insertelement <2 x double> undef, double %sub, i32 0
- %vecext2 = extractelement <2 x double> %B, i32 1
- %vecext3 = extractelement <2 x double> %B, i32 0
+ %vecinit = insertelement <2 x double> undef, double %sub, i32 1
+ %vecext2 = extractelement <2 x double> %A, i32 0
+ %vecext3 = extractelement <2 x double> %A, i32 1
%sub2 = fsub double %vecext2, %vecext3
- %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 1
+ %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
ret <2 x double> %vecinit2
}
; CHECK-LABEL: hsub_pd_test2
@@ -458,3 +458,68 @@ define <16 x i16> @avx2_vphadd_w_test(<1
; CHECK: ret
+; Verify that we don't select horizontal subs in the following functions.
+
+define <4 x i32> @not_a_hsub_1(<4 x i32> %A, <4 x i32> %B) {
+ %vecext = extractelement <4 x i32> %A, i32 0
+ %vecext1 = extractelement <4 x i32> %A, i32 1
+ %sub = sub i32 %vecext, %vecext1
+ %vecinit = insertelement <4 x i32> undef, i32 %sub, i32 0
+ %vecext2 = extractelement <4 x i32> %A, i32 2
+ %vecext3 = extractelement <4 x i32> %A, i32 3
+ %sub4 = sub i32 %vecext2, %vecext3
+ %vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 1
+ %vecext6 = extractelement <4 x i32> %B, i32 1
+ %vecext7 = extractelement <4 x i32> %B, i32 0
+ %sub8 = sub i32 %vecext6, %vecext7
+ %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 2
+ %vecext10 = extractelement <4 x i32> %B, i32 3
+ %vecext11 = extractelement <4 x i32> %B, i32 2
+ %sub12 = sub i32 %vecext10, %vecext11
+ %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 3
+ ret <4 x i32> %vecinit13
+}
+; CHECK-LABEL: not_a_hsub_1
+; CHECK-NOT: phsubd
+; CHECK: ret
+
+
+define <4 x float> @not_a_hsub_2(<4 x float> %A, <4 x float> %B) {
+ %vecext = extractelement <4 x float> %A, i32 2
+ %vecext1 = extractelement <4 x float> %A, i32 3
+ %sub = fsub float %vecext, %vecext1
+ %vecinit = insertelement <4 x float> undef, float %sub, i32 1
+ %vecext2 = extractelement <4 x float> %A, i32 0
+ %vecext3 = extractelement <4 x float> %A, i32 1
+ %sub4 = fsub float %vecext2, %vecext3
+ %vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
+ %vecext6 = extractelement <4 x float> %B, i32 3
+ %vecext7 = extractelement <4 x float> %B, i32 2
+ %sub8 = fsub float %vecext6, %vecext7
+ %vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
+ %vecext10 = extractelement <4 x float> %B, i32 0
+ %vecext11 = extractelement <4 x float> %B, i32 1
+ %sub12 = fsub float %vecext10, %vecext11
+ %vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
+ ret <4 x float> %vecinit13
+}
+; CHECK-LABEL: not_a_hsub_2
+; CHECK-NOT: hsubps
+; CHECK: ret
+
+
+define <2 x double> @not_a_hsub_3(<2 x double> %A, <2 x double> %B) {
+ %vecext = extractelement <2 x double> %B, i32 0
+ %vecext1 = extractelement <2 x double> %B, i32 1
+ %sub = fsub double %vecext, %vecext1
+ %vecinit = insertelement <2 x double> undef, double %sub, i32 1
+ %vecext2 = extractelement <2 x double> %A, i32 1
+ %vecext3 = extractelement <2 x double> %A, i32 0
+ %sub2 = fsub double %vecext2, %vecext3
+ %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
+ ret <2 x double> %vecinit2
+}
+; CHECK-LABEL: not_a_hsub_3
+; CHECK-NOT: hsubpd
+; CHECK: ret
+