[llvm] ce356e1 - [DAG] Add BuildVectorSDNode::getRepeatedSequence helper to recognise multi-element splat patterns
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sat Oct 24 04:23:41 PDT 2020
Author: Simon Pilgrim
Date: 2020-10-24T12:23:09+01:00
New Revision: ce356e1546c9c538134dcdfc9f2d728b8ba0719c
URL: https://github.com/llvm/llvm-project/commit/ce356e1546c9c538134dcdfc9f2d728b8ba0719c
DIFF: https://github.com/llvm/llvm-project/commit/ce356e1546c9c538134dcdfc9f2d728b8ba0719c.diff
LOG: [DAG] Add BuildVectorSDNode::getRepeatedSequence helper to recognise multi-element splat patterns
Replace the X86 specific isSplatZeroExtended helper with a generic BuildVectorSDNode method.
I've just used this to simplify the X86ISD::VBROADCASTM lowering so far (and remove isSplatZeroExtended), but we should be able to use this in more places to lower to more complex broadcast patterns.
Differential Revision: https://reviews.llvm.org/D87930
Added:
Modified:
llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 632d112dcd6c..fb81718e119b 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1908,6 +1908,33 @@ class BuildVectorSDNode : public SDNode {
/// the vector width and set the bits where elements are undef.
SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
+ /// Find the shortest repeating sequence of values in the build vector.
+ ///
+ /// e.g. { u, X, u, X, u, u, X, u } -> { X }
+ /// { X, Y, u, Y, u, u, X, u } -> { X, Y }
+ ///
+ /// Currently this must be a power-of-2 build vector.
+ /// The DemandedElts mask indicates the elements that must be present,
+ /// undemanded elements in Sequence may be null (SDValue()). If passed a
+ /// non-null UndefElements bitvector, it will resize it to match the original
+ /// vector width and set the bits where elements are undef. If result is
+ /// false, Sequence will be empty.
+ bool getRepeatedSequence(const APInt &DemandedElts,
+ SmallVectorImpl<SDValue> &Sequence,
+ BitVector *UndefElements = nullptr) const;
+
+ /// Find the shortest repeating sequence of values in the build vector.
+ ///
+ /// e.g. { u, X, u, X, u, u, X, u } -> { X }
+ /// { X, Y, u, Y, u, u, X, u } -> { X, Y }
+ ///
+ /// Currently this must be a power-of-2 build vector.
+ /// If passed a non-null UndefElements bitvector, it will resize it to match
+ /// the original vector width and set the bits where elements are undef.
+ /// If result is false, Sequence will be empty.
+ bool getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+ BitVector *UndefElements = nullptr) const;
+
/// Returns the demanded splatted constant or null if this is not a constant
/// splat.
///
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 24c32d367283..f05c6a6b7c23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9870,6 +9870,58 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
return getSplatValue(DemandedElts, UndefElements);
}
+bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
+ SmallVectorImpl<SDValue> &Sequence,
+ BitVector *UndefElements) const {
+ unsigned NumOps = getNumOperands();
+ Sequence.clear();
+ if (UndefElements) {
+ UndefElements->clear();
+ UndefElements->resize(NumOps);
+ }
+ assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
+ if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
+ return false;
+
+ // Set the undefs even if we don't find a sequence (like getSplatValue).
+ if (UndefElements)
+ for (unsigned I = 0; I != NumOps; ++I)
+ if (DemandedElts[I] && getOperand(I).isUndef())
+ (*UndefElements)[I] = true;
+
+ // Iteratively widen the sequence length looking for repetitions.
+ for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
+ Sequence.append(SeqLen, SDValue());
+ for (unsigned I = 0; I != NumOps; ++I) {
+ if (!DemandedElts[I])
+ continue;
+ SDValue &SeqOp = Sequence[I % SeqLen];
+ SDValue Op = getOperand(I);
+ if (Op.isUndef()) {
+ if (!SeqOp)
+ SeqOp = Op;
+ continue;
+ }
+ if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) {
+ Sequence.clear();
+ break;
+ }
+ SeqOp = Op;
+ }
+ if (!Sequence.empty())
+ return true;
+ }
+
+ assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
+ return false;
+}
+
+bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+ BitVector *UndefElements) const {
+ APInt DemandedElts = APInt::getAllOnesValue(getNumOperands());
+ return getRepeatedSequence(DemandedElts, Sequence, UndefElements);
+}
+
ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
BitVector *UndefElements) const {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6aeb0d9d062c..868adaf61a51 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -8659,43 +8659,6 @@ static bool isFoldableUseOfShuffle(SDNode *N) {
return false;
}
-// Check if the current node of build vector is a zero extended vector.
-// // If so, return the value extended.
-// // For example: (0,0,0,a,0,0,0,a,0,0,0,a,0,0,0,a) returns a.
-// // NumElt - return the number of zero extended identical values.
-// // EltType - return the type of the value include the zero extend.
-static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op,
- unsigned &NumElt, MVT &EltType) {
- SDValue ExtValue = Op->getOperand(0);
- unsigned NumElts = Op->getNumOperands();
- unsigned Delta = NumElts;
-
- for (unsigned i = 1; i < NumElts; i++) {
- if (Op->getOperand(i) == ExtValue) {
- Delta = i;
- break;
- }
- if (!(Op->getOperand(i).isUndef() || isNullConstant(Op->getOperand(i))))
- return SDValue();
- }
- if (!isPowerOf2_32(Delta) || Delta == 1)
- return SDValue();
-
- for (unsigned i = Delta; i < NumElts; i++) {
- if (i % Delta == 0) {
- if (Op->getOperand(i) != ExtValue)
- return SDValue();
- } else if (!(isNullConstant(Op->getOperand(i)) ||
- Op->getOperand(i).isUndef()))
- return SDValue();
- }
- unsigned EltSize = Op->getSimpleValueType(0).getScalarSizeInBits();
- unsigned ExtVTSize = EltSize * Delta;
- EltType = MVT::getIntegerVT(ExtVTSize);
- NumElt = NumElts / Delta;
- return ExtValue;
-}
-
/// Attempt to use the vbroadcast instruction to generate a splat value
/// from a splat BUILD_VECTOR which uses:
/// a. A single scalar load, or a constant.
@@ -8713,13 +8676,21 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
return SDValue();
MVT VT = BVOp->getSimpleValueType(0);
+ unsigned NumElts = VT.getVectorNumElements();
SDLoc dl(BVOp);
assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
"Unsupported vector type for broadcast.");
+ // See if the build vector is a repeating sequence of scalars (inc. splat).
+ SDValue Ld;
BitVector UndefElements;
- SDValue Ld = BVOp->getSplatValue(&UndefElements);
+ SmallVector<SDValue, 16> Sequence;
+ if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
+ assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
+ if (Sequence.size() == 1)
+ Ld = Sequence[0];
+ }
// Attempt to use VBROADCASTM
// From this pattern:
@@ -8727,29 +8698,29 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
// b. t1 = (build_vector t0 t0)
//
// Create (VBROADCASTM v2i1 X)
- if (Subtarget.hasCDI()) {
- MVT EltType = VT.getScalarType();
- unsigned NumElts = VT.getVectorNumElements();
- SDValue BOperand;
- SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType);
- if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) ||
- (ZeroExtended && ZeroExtended.getOpcode() == ISD::ZERO_EXTEND &&
- ZeroExtended.getOperand(0).getOpcode() == ISD::BITCAST) ||
- (Ld && Ld.getOpcode() == ISD::ZERO_EXTEND &&
- Ld.getOperand(0).getOpcode() == ISD::BITCAST)) {
- if (ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST)
- BOperand = ZeroExtended.getOperand(0);
- else if (ZeroExtended)
- BOperand = ZeroExtended.getOperand(0).getOperand(0);
- else
- BOperand = Ld.getOperand(0).getOperand(0);
+ if (!Sequence.empty() && Subtarget.hasCDI()) {
+ // If not a splat, are the upper sequence values zeroable?
+ unsigned SeqLen = Sequence.size();
+ bool UpperZeroOrUndef =
+ SeqLen == 1 ||
+ llvm::all_of(makeArrayRef(Sequence).drop_front(), [](SDValue V) {
+ return !V || V.isUndef() || isNullConstant(V);
+ });
+ SDValue Op0 = Sequence[0];
+ if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
+ (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+ Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
+ SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
+ ? Op0.getOperand(0)
+ : Op0.getOperand(0).getOperand(0);
MVT MaskVT = BOperand.getSimpleValueType();
+ MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) || // for broadcastmb2q
(EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
- MVT BcstVT = MVT::getVectorVT(EltType, NumElts);
+ MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
unsigned Scale = 512 / VT.getSizeInBits();
- BcstVT = MVT::getVectorVT(EltType, NumElts * Scale);
+ BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
}
SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
if (BcstVT.getSizeInBits() != VT.getSizeInBits())
@@ -8759,7 +8730,6 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
}
}
- unsigned NumElts = VT.getVectorNumElements();
unsigned NumUndefElts = UndefElements.count();
if (!Ld || (NumElts - NumUndefElts) <= 1) {
APInt SplatValue, Undef;
@@ -8833,6 +8803,8 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
(Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
bool IsLoad = ISD::isNormalLoad(Ld.getNode());
+ // TODO: Handle broadcasts of non-constant sequences.
+
// Make sure that all of the users of a non-constant load are from the
// BUILD_VECTOR node.
// FIXME: Is the use count needed for non-constant, non-load case?
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index 584ef65b20bc..d601b552af84 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -472,6 +472,127 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
EXPECT_EQ(SplatIdx, 0);
}
+TEST_F(AArch64SelectionDAGTest, getRepeatedSequence_Patterns) {
+ if (!TM)
+ return;
+
+ TargetLowering TL(*TM);
+
+ SDLoc Loc;
+ unsigned NumElts = 16;
+ MVT IntVT = MVT::i8;
+ MVT VecVT = MVT::getVectorVT(IntVT, NumElts);
+
+ // Base scalar constants.
+ SDValue Val0 = DAG->getConstant(0, Loc, IntVT);
+ SDValue Val1 = DAG->getConstant(1, Loc, IntVT);
+ SDValue Val2 = DAG->getConstant(2, Loc, IntVT);
+ SDValue Val3 = DAG->getConstant(3, Loc, IntVT);
+ SDValue UndefVal = DAG->getUNDEF(IntVT);
+
+ // Build some repeating sequences.
+ SmallVector<SDValue, 16> Pattern1111, Pattern1133, Pattern0123;
+ for(int I = 0; I != 4; ++I) {
+ Pattern1111.append(4, Val1);
+ Pattern1133.append(2, Val1);
+ Pattern1133.append(2, Val3);
+ Pattern0123.push_back(Val0);
+ Pattern0123.push_back(Val1);
+ Pattern0123.push_back(Val2);
+ Pattern0123.push_back(Val3);
+ }
+
+ // Build a non-pow2 repeating sequence.
+ SmallVector<SDValue, 16> Pattern022;
+ Pattern022.push_back(Val0);
+ Pattern022.append(2, Val2);
+ Pattern022.push_back(Val0);
+ Pattern022.append(2, Val2);
+ Pattern022.push_back(Val0);
+ Pattern022.append(2, Val2);
+ Pattern022.push_back(Val0);
+ Pattern022.append(2, Val2);
+ Pattern022.push_back(Val0);
+ Pattern022.append(2, Val2);
+ Pattern022.push_back(Val0);
+
+ // Build a non-repeating sequence.
+ SmallVector<SDValue, 16> Pattern1_3;
+ Pattern1_3.append(8, Val1);
+ Pattern1_3.append(8, Val3);
+
+ // Add some undefs to make it trickier.
+ Pattern1111[1] = Pattern1111[2] = Pattern1111[15] = UndefVal;
+ Pattern1133[0] = Pattern1133[2] = UndefVal;
+
+ auto *BV1111 =
+ cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1111));
+ auto *BV1133 =
+ cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1133));
+ auto *BV0123=
+ cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern0123));
+ auto *BV022 =
+ cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern022));
+ auto *BV1_3 =
+ cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1_3));
+
+ // Check for sequences.
+ SmallVector<SDValue, 16> Seq1111, Seq1133, Seq0123, Seq022, Seq1_3;
+ BitVector Undefs1111, Undefs1133, Undefs0123, Undefs022, Undefs1_3;
+
+ EXPECT_TRUE(BV1111->getRepeatedSequence(Seq1111, &Undefs1111));
+ EXPECT_EQ(Undefs1111.count(), 3);
+ EXPECT_EQ(Seq1111.size(), 1);
+ EXPECT_EQ(Seq1111[0], Val1);
+
+ EXPECT_TRUE(BV1133->getRepeatedSequence(Seq1133, &Undefs1133));
+ EXPECT_EQ(Undefs1133.count(), 2);
+ EXPECT_EQ(Seq1133.size(), 4);
+ EXPECT_EQ(Seq1133[0], Val1);
+ EXPECT_EQ(Seq1133[1], Val1);
+ EXPECT_EQ(Seq1133[2], Val3);
+ EXPECT_EQ(Seq1133[3], Val3);
+
+ EXPECT_TRUE(BV0123->getRepeatedSequence(Seq0123, &Undefs0123));
+ EXPECT_EQ(Undefs0123.count(), 0);
+ EXPECT_EQ(Seq0123.size(), 4);
+ EXPECT_EQ(Seq0123[0], Val0);
+ EXPECT_EQ(Seq0123[1], Val1);
+ EXPECT_EQ(Seq0123[2], Val2);
+ EXPECT_EQ(Seq0123[3], Val3);
+
+ EXPECT_FALSE(BV022->getRepeatedSequence(Seq022, &Undefs022));
+ EXPECT_FALSE(BV1_3->getRepeatedSequence(Seq1_3, &Undefs1_3));
+
+ // Try again with DemandedElts masks.
+ APInt Mask1111_0 = APInt::getOneBitSet(NumElts, 0);
+ EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_0, Seq1111, &Undefs1111));
+ EXPECT_EQ(Undefs1111.count(), 0);
+ EXPECT_EQ(Seq1111.size(), 1);
+ EXPECT_EQ(Seq1111[0], Val1);
+
+ APInt Mask1111_1 = APInt::getOneBitSet(NumElts, 2);
+ EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_1, Seq1111, &Undefs1111));
+ EXPECT_EQ(Undefs1111.count(), 1);
+ EXPECT_EQ(Seq1111.size(), 1);
+ EXPECT_EQ(Seq1111[0], UndefVal);
+
+ APInt Mask0123 = APInt(NumElts, 0x7777);
+ EXPECT_TRUE(BV0123->getRepeatedSequence(Mask0123, Seq0123, &Undefs0123));
+ EXPECT_EQ(Undefs0123.count(), 0);
+ EXPECT_EQ(Seq0123.size(), 4);
+ EXPECT_EQ(Seq0123[0], Val0);
+ EXPECT_EQ(Seq0123[1], Val1);
+ EXPECT_EQ(Seq0123[2], Val2);
+ EXPECT_EQ(Seq0123[3], SDValue());
+
+ APInt Mask1_3 = APInt::getHighBitsSet(16, 8);
+ EXPECT_TRUE(BV1_3->getRepeatedSequence(Mask1_3, Seq1_3, &Undefs1_3));
+ EXPECT_EQ(Undefs1_3.count(), 0);
+ EXPECT_EQ(Seq1_3.size(), 1);
+ EXPECT_EQ(Seq1_3[0], Val3);
+}
+
TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
if (!TM)
return;
More information about the llvm-commits mailing list