[llvm] a916e81 - [X86] Various improvements to our vector splitting helpers for lowering. NFC
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 15 11:00:29 PDT 2020
Author: Craig Topper
Date: 2020-04-15T10:57:53-07:00
New Revision: a916e819275922ab9a350283a12647da6f4ad4b1
URL: https://github.com/llvm/llvm-project/commit/a916e819275922ab9a350283a12647da6f4ad4b1
DIFF: https://github.com/llvm/llvm-project/commit/a916e819275922ab9a350283a12647da6f4ad4b1.diff
LOG: [X86] Various improvements to our vector splitting helpers for lowering. NFC
-Consistently name the functions as split*
-Add a helper (splitVector) that does the two extractSubVector calls and determines the size of the split
-Use GetSplitDestVTs to get the result types for the split nodes.
-Move the binary and unary helpers to one place in the file near the extractSubVector functions. Leave the VSETCC one near LowerVSETCC since that's its only caller.
-Remove the 256/512 wrappers that just had asserts. I don't think they provided a lot of value, and now that the routines are named split* it is more obvious at the call sites what they do.
-Make the unary routine support different source and dest types to support D76212.
-Add some weaker asserts into the helpers to make up for losing the very specific asserts from the 256/512 wrappers.
Differential Revision: https://reviews.llvm.org/D78176
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47ef38ab3f4c..81bb829cd3ce 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -5802,6 +5802,71 @@ static bool collectConcatOps(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
return false;
}
+static std::pair<SDValue, SDValue> splitVector(SDValue Op, SelectionDAG &DAG,
+ const SDLoc &dl) {
+ MVT VT = Op.getSimpleValueType();
+ unsigned NumElems = VT.getVectorNumElements();
+ unsigned SizeInBits = VT.getSizeInBits();
+ // Split Op into two half-width pieces: Lo at element 0, Hi at NumElems / 2.
+ SDValue Lo = extractSubVector(Op, 0, DAG, dl, SizeInBits / 2);
+ SDValue Hi = extractSubVector(Op, NumElems / 2, DAG, dl, SizeInBits / 2);
+
+ return std::make_pair(Lo, Hi);
+}
+
+// Split a unary integer op into two half-sized ops and concatenate them.
+static SDValue splitVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
+ EVT VT = Op.getValueType();
+
+ // Only split 256/512-bit source and result types (checked separately, as
+ // they may differ) so that we avoid creating narrow vectors.
+ assert((Op.getOperand(0).getValueType().is256BitVector() ||
+ Op.getOperand(0).getValueType().is512BitVector()) &&
+ (VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
+ assert(Op.getOperand(0).getValueType().getVectorNumElements() ==
+ VT.getVectorNumElements() &&
+ "Unexpected VTs!");
+
+ SDLoc dl(Op);
+
+ // Extract the Lo/Hi halves of the source operand.
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = splitVector(Op.getOperand(0), DAG, dl);
+
+ EVT LoVT, HiVT;
+ std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+ DAG.getNode(Op.getOpcode(), dl, LoVT, Lo),
+ DAG.getNode(Op.getOpcode(), dl, HiVT, Hi));
+}
+
+/// Break a binary integer operation into two half-sized ops and then
+/// concatenate the results back together.
+static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) {
+ EVT VT = Op.getValueType();
+
+ // Check that both operands have the same type as the result.
+ assert(Op.getOperand(0).getValueType() == VT &&
+ Op.getOperand(1).getValueType() == VT && "Unexpected VTs!");
+ assert((VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
+
+ SDLoc dl(Op);
+
+ // Extract the Lo/Hi halves of the LHS operand.
+ SDValue LHS1, LHS2;
+ std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl);
+
+ // Extract the Lo/Hi halves of the RHS operand.
+ SDValue RHS1, RHS2;
+ std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl);
+
+ EVT LoVT, HiVT;
+ std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+ DAG.getNode(Op.getOpcode(), dl, LoVT, LHS1, RHS1),
+ DAG.getNode(Op.getOpcode(), dl, HiVT, LHS2, RHS2));
+}
+
// Helper for splitting operands of an operation to legal target size and
// apply a function on each part.
// Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
@@ -21820,32 +21885,30 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
/// Break a VSETCC 256-bit integer VSETCC into two new 128 ones and then
/// concatenate the result back.
-static SDValue Lower256IntVSETCC(SDValue Op, SelectionDAG &DAG) {
- MVT VT = Op.getSimpleValueType();
+static SDValue splitIntVSETCC(SDValue Op, SelectionDAG &DAG) {
+ EVT VT = Op.getValueType();
- assert(VT.is256BitVector() && Op.getOpcode() == ISD::SETCC &&
- "Unsupported value type for operation");
+ assert(Op.getOpcode() == ISD::SETCC && "Unsupported operation");
+ assert(Op.getOperand(0).getValueType().isInteger() &&
+ VT == Op.getOperand(0).getValueType() && "Unsupported VTs!");
- unsigned NumElems = VT.getVectorNumElements();
SDLoc dl(Op);
SDValue CC = Op.getOperand(2);
- // Extract the LHS vectors
- SDValue LHS = Op.getOperand(0);
- SDValue LHS1 = extract128BitVector(LHS, 0, DAG, dl);
- SDValue LHS2 = extract128BitVector(LHS, NumElems / 2, DAG, dl);
+ // Extract the LHS Lo/Hi vectors
+ SDValue LHS1, LHS2;
+ std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl);
- // Extract the RHS vectors
- SDValue RHS = Op.getOperand(1);
- SDValue RHS1 = extract128BitVector(RHS, 0, DAG, dl);
- SDValue RHS2 = extract128BitVector(RHS, NumElems / 2, DAG, dl);
+ // Extract the RHS Lo/Hi vectors
+ SDValue RHS1, RHS2;
+ std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl);
// Issue the operation on the smaller types and concatenate the result back
- MVT EltVT = VT.getVectorElementType();
- MVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
+ EVT LoVT, HiVT;
+ std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
- DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1, CC),
- DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2, CC));
+ DAG.getNode(ISD::SETCC, dl, LoVT, LHS1, RHS1, CC),
+ DAG.getNode(ISD::SETCC, dl, HiVT, LHS2, RHS2, CC));
}
static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) {
@@ -22187,7 +22250,7 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
// Break 256-bit integer vector compare into smaller ones.
if (VT.is256BitVector() && !Subtarget.hasInt256())
- return Lower256IntVSETCC(Op, DAG);
+ return splitIntVSETCC(Op, DAG);
// If this is a SETNE against the signed minimum value, change it to SETGT.
// If this is a SETNE against the signed maximum value, change it to SETLT.
@@ -25922,43 +25985,6 @@ SDValue X86TargetLowering::LowerFLT_ROUNDS_(SDValue Op,
return DAG.getMergeValues({RetVal, Chain}, DL);
}
-// Split an unary integer op into 2 half sized ops.
-static SDValue LowerVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
- MVT VT = Op.getSimpleValueType();
- unsigned NumElems = VT.getVectorNumElements();
- unsigned SizeInBits = VT.getSizeInBits();
- MVT EltVT = VT.getVectorElementType();
- SDValue Src = Op.getOperand(0);
- assert(EltVT == Src.getSimpleValueType().getVectorElementType() &&
- "Src and Op should have the same element type!");
-
- // Extract the Lo/Hi vectors
- SDLoc dl(Op);
- SDValue Lo = extractSubVector(Src, 0, DAG, dl, SizeInBits / 2);
- SDValue Hi = extractSubVector(Src, NumElems / 2, DAG, dl, SizeInBits / 2);
-
- MVT NewVT = MVT::getVectorVT(EltVT, NumElems / 2);
- return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
- DAG.getNode(Op.getOpcode(), dl, NewVT, Lo),
- DAG.getNode(Op.getOpcode(), dl, NewVT, Hi));
-}
-
-// Decompose 256-bit ops into smaller 128-bit ops.
-static SDValue Lower256IntUnary(SDValue Op, SelectionDAG &DAG) {
- assert(Op.getSimpleValueType().is256BitVector() &&
- Op.getSimpleValueType().isInteger() &&
- "Only handle AVX 256-bit vector integer operation");
- return LowerVectorIntUnary(Op, DAG);
-}
-
-// Decompose 512-bit ops into smaller 256-bit ops.
-static SDValue Lower512IntUnary(SDValue Op, SelectionDAG &DAG) {
- assert(Op.getSimpleValueType().is512BitVector() &&
- Op.getSimpleValueType().isInteger() &&
- "Only handle AVX 512-bit vector integer operation");
- return LowerVectorIntUnary(Op, DAG);
-}
-
/// Lower a vector CTLZ using native supported vector CTLZ instruction.
//
// i8/i16 vector implemented using dword LZCNT vector instruction
@@ -25979,7 +26005,7 @@ static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG,
// Split vector, it's Lo and Hi parts will be handled in next iteration.
if (NumElems > 16 ||
(NumElems == 16 && !Subtarget.canExtendTo512DQ()))
- return LowerVectorIntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
assert((NewVT.is256BitVector() || NewVT.is512BitVector()) &&
@@ -26089,11 +26115,11 @@ static SDValue LowerVectorCTLZ(SDValue Op, const SDLoc &DL,
// Decompose 256-bit ops into smaller 128-bit ops.
if (VT.is256BitVector() && !Subtarget.hasInt256())
- return Lower256IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
// Decompose 512-bit ops into smaller 256-bit ops.
if (VT.is512BitVector() && !Subtarget.hasBWI())
- return Lower512IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
assert(Subtarget.hasSSSE3() && "Expected SSSE3 support for PSHUFB");
return LowerVectorCTLZInRegLUT(Op, DL, Subtarget, DAG);
@@ -26159,48 +26185,6 @@ static SDValue LowerCTTZ(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getNode(X86ISD::CMOV, dl, VT, Ops);
}
-/// Break a binary integer operation into 2 half sized ops and then
-/// concatenate the result back.
-static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) {
- MVT VT = Op.getSimpleValueType();
- unsigned NumElems = VT.getVectorNumElements();
- unsigned SizeInBits = VT.getSizeInBits();
- SDLoc dl(Op);
-
- // Extract the LHS Lo/Hi vectors
- SDValue LHS = Op.getOperand(0);
- SDValue LHS1 = extractSubVector(LHS, 0, DAG, dl, SizeInBits / 2);
- SDValue LHS2 = extractSubVector(LHS, NumElems / 2, DAG, dl, SizeInBits / 2);
-
- // Extract the RHS Lo/Hi vectors
- SDValue RHS = Op.getOperand(1);
- SDValue RHS1 = extractSubVector(RHS, 0, DAG, dl, SizeInBits / 2);
- SDValue RHS2 = extractSubVector(RHS, NumElems / 2, DAG, dl, SizeInBits / 2);
-
- MVT NewVT = MVT::getVectorVT(VT.getVectorElementType(), NumElems / 2);
- return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
- DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1),
- DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2));
-}
-
-/// Break a 256-bit integer operation into two new 128-bit ones and then
-/// concatenate the result back.
-static SDValue split256IntArith(SDValue Op, SelectionDAG &DAG) {
- assert(Op.getSimpleValueType().is256BitVector() &&
- Op.getSimpleValueType().isInteger() &&
- "Unsupported value type for operation");
- return splitVectorIntBinary(Op, DAG);
-}
-
-/// Break a 512-bit integer operation into two new 256-bit ones and then
-/// concatenate the result back.
-static SDValue split512IntArith(SDValue Op, SelectionDAG &DAG) {
- assert(Op.getSimpleValueType().is512BitVector() &&
- Op.getSimpleValueType().isInteger() &&
- "Unsupported value type for operation");
- return splitVectorIntBinary(Op, DAG);
-}
-
static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
MVT VT = Op.getSimpleValueType();
@@ -26214,7 +26198,7 @@ static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
assert(Op.getSimpleValueType().is256BitVector() &&
Op.getSimpleValueType().isInteger() &&
"Only handle AVX 256-bit vector integer operation");
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
}
static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
@@ -26262,7 +26246,7 @@ static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
assert(Op.getSimpleValueType().is256BitVector() &&
Op.getSimpleValueType().isInteger() &&
"Only handle AVX 256-bit vector integer operation");
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
}
static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
@@ -26292,7 +26276,7 @@ static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
if (VT.is256BitVector() && !Subtarget.hasInt256()) {
assert(VT.isInteger() &&
"Only handle AVX 256-bit vector integer operation");
- return Lower256IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
}
// Default to expand.
@@ -26304,7 +26288,7 @@ static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
// For AVX1 cases, split to use legal ops (everything but v4i64).
if (VT.getScalarType() != MVT::i64 && VT.is256BitVector())
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
SDLoc DL(Op);
unsigned Opcode = Op.getOpcode();
@@ -26348,7 +26332,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
// Decompose 256-bit ops into 128-bit ops.
if (VT.is256BitVector() && !Subtarget.hasInt256())
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
SDValue A = Op.getOperand(0);
SDValue B = Op.getOperand(1);
@@ -26494,7 +26478,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
// Decompose 256-bit ops into 128-bit ops.
if (VT.is256BitVector() && !Subtarget.hasInt256())
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) {
assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) ||
@@ -26586,7 +26570,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
// For signed 512-bit vectors, split into 256-bit vectors to allow the
// sign-extension to occur.
if (VT == MVT::v64i8 && IsSigned)
- return split512IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
// Signed AVX2 implementation - extend xmm subvectors to ymm.
if (VT == MVT::v32i8 && IsSigned) {
@@ -27560,7 +27544,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
// Decompose 256-bit shifts into 128-bit shifts.
if (VT.is256BitVector())
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
return SDValue();
}
@@ -27606,7 +27590,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// XOP implicitly uses modulo rotation amounts.
if (Subtarget.hasXOP()) {
if (VT.is256BitVector())
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
// Attempt to rotate by immediate.
@@ -27622,7 +27606,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
// Split 256-bit integers on pre-AVX2 targets.
if (VT.is256BitVector() && !Subtarget.hasAVX2())
- return split256IntArith(Op, DAG);
+ return splitVectorIntBinary(Op, DAG);
assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 ||
((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) &&
@@ -28287,11 +28271,11 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget,
// Decompose 256-bit ops into smaller 128-bit ops.
if (VT.is256BitVector() && !Subtarget.hasInt256())
- return Lower256IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
// Decompose 512-bit ops into smaller 256-bit ops.
if (VT.is512BitVector() && !Subtarget.hasBWI())
- return Lower512IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
// For element types greater than i8, do vXi8 pop counts and a bytesum.
if (VT.getScalarType() != MVT::i8) {
@@ -28335,7 +28319,7 @@ static SDValue LowerBITREVERSE_XOP(SDValue Op, SelectionDAG &DAG) {
// Decompose 256-bit ops into smaller 128-bit ops.
if (VT.is256BitVector())
- return Lower256IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
assert(VT.is128BitVector() &&
"Only 128-bit vector bitreverse lowering supported.");
@@ -28376,7 +28360,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
// lowering.
if (VT == MVT::v8i64 || VT == MVT::v16i32) {
assert(!Subtarget.hasBWI() && "BWI should Expand BITREVERSE");
- return Lower512IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
}
unsigned NumElts = VT.getVectorNumElements();
@@ -28385,7 +28369,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
// Decompose 256-bit ops into smaller 128-bit ops on pre-AVX2.
if (VT.is256BitVector() && !Subtarget.hasInt256())
- return Lower256IntUnary(Op, DAG);
+ return splitVectorIntUnary(Op, DAG);
// Perform BITREVERSE using PSHUFB lookups. Each byte is split into
// two nibbles and a PSHUFB lookup to find the bitreverse of each
@@ -47137,7 +47121,7 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG,
if (isConcatenatedNot(InVecBC.getOperand(0)) ||
isConcatenatedNot(InVecBC.getOperand(1))) {
// extract (and v4i64 X, (not (concat Y1, Y2))), n -> andnp v2i64 X(n), Y1
- SDValue Concat = split256IntArith(InVecBC, DAG);
+ SDValue Concat = splitVectorIntBinary(InVecBC, DAG);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT,
DAG.getBitcast(InVecVT, Concat), N->getOperand(1));
}
More information about the llvm-commits
mailing list