[llvm] ce356e1 - [DAG] Add BuildVectorSDNode::getRepeatedSequence helper to recognise multi-element splat patterns

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 24 04:23:41 PDT 2020


Author: Simon Pilgrim
Date: 2020-10-24T12:23:09+01:00
New Revision: ce356e1546c9c538134dcdfc9f2d728b8ba0719c

URL: https://github.com/llvm/llvm-project/commit/ce356e1546c9c538134dcdfc9f2d728b8ba0719c
DIFF: https://github.com/llvm/llvm-project/commit/ce356e1546c9c538134dcdfc9f2d728b8ba0719c.diff

LOG: [DAG] Add BuildVectorSDNode::getRepeatedSequence helper to recognise multi-element splat patterns

Replace the X86-specific isSplatZeroExtended helper with a generic BuildVectorSDNode method.

I've just used this to simplify the X86ISD::VBROADCASTM lowering so far (and remove isSplatZeroExtended), but we should be able to use this in more places to lower to complex broadcast patterns.

Differential Revision: https://reviews.llvm.org/D87930

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 632d112dcd6c..fb81718e119b 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1908,6 +1908,33 @@ class BuildVectorSDNode : public SDNode {
   /// the vector width and set the bits where elements are undef.
   SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
 
+  /// Find the shortest repeating sequence of values in the build vector.
+  ///
+  /// e.g. { u, X, u, X, u, u, X, u } -> { X }
+  ///      { X, Y, u, Y, u, u, X, u } -> { X, Y }
+  ///
+  /// Currently the build vector must have a power-of-2 element count.
+  /// The DemandedElts mask indicates the elements that must be present;
+  /// undemanded elements in Sequence may be null (SDValue()). If passed a
+  /// non-null UndefElements bitvector, it will resize it to match the original
+  /// vector width and set the bits where demanded elements are undef. If the
+  /// result is false, Sequence will be empty.
+  bool getRepeatedSequence(const APInt &DemandedElts,
+                           SmallVectorImpl<SDValue> &Sequence,
+                           BitVector *UndefElements = nullptr) const;
+
+  /// Find the shortest repeating sequence of values in the build vector.
+  ///
+  /// e.g. { u, X, u, X, u, u, X, u } -> { X }
+  ///      { X, Y, u, Y, u, u, X, u } -> { X, Y }
+  ///
+  /// All elements are demanded; the element count must be a power of 2.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the original vector width and set the bits where elements are undef.
+  /// If the result is false, Sequence will be empty.
+  bool getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+                           BitVector *UndefElements = nullptr) const;
+
   /// Returns the demanded splatted constant or null if this is not a constant
   /// splat.
   ///

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 24c32d367283..f05c6a6b7c23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9870,6 +9870,58 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
   return getSplatValue(DemandedElts, UndefElements);
 }
 
+bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
+                                            SmallVectorImpl<SDValue> &Sequence,
+                                            BitVector *UndefElements) const {
+  unsigned NumOps = getNumOperands();
+  Sequence.clear(); // Empty on every false return.
+  if (UndefElements) {
+    UndefElements->clear();
+    UndefElements->resize(NumOps);
+  }
+  assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
+  if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
+    return false;
+
+  // Record demanded undefs up front so they are reported even on failure.
+  if (UndefElements)
+    for (unsigned I = 0; I != NumOps; ++I)
+      if (DemandedElts[I] && getOperand(I).isUndef())
+        (*UndefElements)[I] = true;
+
+  // Try power-of-2 sequence lengths, shortest first; length NumOps is trivial.
+  for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
+    Sequence.append(SeqLen, SDValue()); // Sequence is empty here.
+    for (unsigned I = 0; I != NumOps; ++I) {
+      if (!DemandedElts[I])
+        continue;
+      SDValue &SeqOp = Sequence[I % SeqLen]; // Reference into Sequence.
+      SDValue Op = getOperand(I);
+      if (Op.isUndef()) {
+        if (!SeqOp)
+          SeqOp = Op; // Undef only fills a null slot.
+        continue;
+      }
+      if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) { // Clash: try longer.
+        Sequence.clear();
+        break;
+      }
+      SeqOp = Op; // Overwrites a null or undef slot.
+    }
+    if (!Sequence.empty()) // No clash for this SeqLen.
+      return true;
+  }
+
+  assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
+  return false;
+}
+
+bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+                                            BitVector *UndefElements) const {
+  return getRepeatedSequence(APInt::getAllOnesValue(getNumOperands()), Sequence,
+                             UndefElements);
+}
+
 ConstantSDNode *
 BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
                                         BitVector *UndefElements) const {

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6aeb0d9d062c..868adaf61a51 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -8659,43 +8659,6 @@ static bool isFoldableUseOfShuffle(SDNode *N) {
   return false;
 }
 
-// Check if the current node of build vector is a zero extended vector.
-// // If so, return the value extended.
-// // For example: (0,0,0,a,0,0,0,a,0,0,0,a,0,0,0,a) returns a.
-// // NumElt - return the number of zero extended identical values.
-// // EltType - return the type of the value include the zero extend.
-static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op,
-                                   unsigned &NumElt, MVT &EltType) {
-  SDValue ExtValue = Op->getOperand(0);
-  unsigned NumElts = Op->getNumOperands();
-  unsigned Delta = NumElts;
-
-  for (unsigned i = 1; i < NumElts; i++) {
-    if (Op->getOperand(i) == ExtValue) {
-      Delta = i;
-      break;
-    }
-    if (!(Op->getOperand(i).isUndef() || isNullConstant(Op->getOperand(i))))
-      return SDValue();
-  }
-  if (!isPowerOf2_32(Delta) || Delta == 1)
-    return SDValue();
-
-  for (unsigned i = Delta; i < NumElts; i++) {
-    if (i % Delta == 0) {
-      if (Op->getOperand(i) != ExtValue)
-        return SDValue();
-    } else if (!(isNullConstant(Op->getOperand(i)) ||
-                 Op->getOperand(i).isUndef()))
-      return SDValue();
-  }
-  unsigned EltSize = Op->getSimpleValueType(0).getScalarSizeInBits();
-  unsigned ExtVTSize = EltSize * Delta;
-  EltType = MVT::getIntegerVT(ExtVTSize);
-  NumElt = NumElts / Delta;
-  return ExtValue;
-}
-
 /// Attempt to use the vbroadcast instruction to generate a splat value
 /// from a splat BUILD_VECTOR which uses:
 ///  a. A single scalar load, or a constant.
@@ -8713,13 +8676,21 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
     return SDValue();
 
   MVT VT = BVOp->getSimpleValueType(0);
+  unsigned NumElts = VT.getVectorNumElements();
   SDLoc dl(BVOp);
 
   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
          "Unsupported vector type for broadcast.");
 
+  // See if the build vector is a repeating sequence of scalars (inc. splat).
+  SDValue Ld;
   BitVector UndefElements;
-  SDValue Ld = BVOp->getSplatValue(&UndefElements);
+  SmallVector<SDValue, 16> Sequence;
+  if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
+    assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
+    if (Sequence.size() == 1)
+      Ld = Sequence[0];
+  }
 
   // Attempt to use VBROADCASTM
   // From this pattern:
@@ -8727,29 +8698,29 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
   // b. t1 = (build_vector t0 t0)
   //
   // Create (VBROADCASTM v2i1 X)
-  if (Subtarget.hasCDI()) {
-    MVT EltType = VT.getScalarType();
-    unsigned NumElts = VT.getVectorNumElements();
-    SDValue BOperand;
-    SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType);
-    if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) ||
-        (ZeroExtended && ZeroExtended.getOpcode() == ISD::ZERO_EXTEND &&
-         ZeroExtended.getOperand(0).getOpcode() == ISD::BITCAST) ||
-        (Ld && Ld.getOpcode() == ISD::ZERO_EXTEND &&
-         Ld.getOperand(0).getOpcode() == ISD::BITCAST)) {
-      if (ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST)
-        BOperand = ZeroExtended.getOperand(0);
-      else if (ZeroExtended)
-        BOperand = ZeroExtended.getOperand(0).getOperand(0);
-      else
-        BOperand = Ld.getOperand(0).getOperand(0);
+  if (!Sequence.empty() && Subtarget.hasCDI()) {
+    // If not a splat, are the upper sequence values zeroable?
+    unsigned SeqLen = Sequence.size();
+    bool UpperZeroOrUndef =
+        SeqLen == 1 ||
+        llvm::all_of(makeArrayRef(Sequence).drop_front(), [](SDValue V) {
+          return !V || V.isUndef() || isNullConstant(V);
+        });
+    SDValue Op0 = Sequence[0];
+    if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
+                             (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+                              Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
+      SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
+                             ? Op0.getOperand(0)
+                             : Op0.getOperand(0).getOperand(0);
       MVT MaskVT = BOperand.getSimpleValueType();
+      MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
       if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) ||  // for broadcastmb2q
           (EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
-        MVT BcstVT = MVT::getVectorVT(EltType, NumElts);
+        MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
         if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
           unsigned Scale = 512 / VT.getSizeInBits();
-          BcstVT = MVT::getVectorVT(EltType, NumElts * Scale);
+          BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
         }
         SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
         if (BcstVT.getSizeInBits() != VT.getSizeInBits())
@@ -8759,7 +8730,6 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
     }
   }
 
-  unsigned NumElts = VT.getVectorNumElements();
   unsigned NumUndefElts = UndefElements.count();
   if (!Ld || (NumElts - NumUndefElts) <= 1) {
     APInt SplatValue, Undef;
@@ -8833,6 +8803,8 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
       (Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
   bool IsLoad = ISD::isNormalLoad(Ld.getNode());
 
+  // TODO: Handle broadcasts of non-constant sequences.
+
   // Make sure that all of the users of a non-constant load are from the
   // BUILD_VECTOR node.
   // FIXME: Is the use count needed for non-constant, non-load case?

diff  --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index 584ef65b20bc..d601b552af84 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -472,6 +472,127 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
   EXPECT_EQ(SplatIdx, 0);
 }
 
+TEST_F(AArch64SelectionDAGTest, getRepeatedSequence_Patterns) {
+  if (!TM)
+    return;
+
+  // Exercise getRepeatedSequence on v16i8 build vectors.
+
+  SDLoc Loc;
+  unsigned NumElts = 16;
+  MVT IntVT = MVT::i8;
+  MVT VecVT = MVT::getVectorVT(IntVT, NumElts);
+
+  // Base scalar constants.
+  SDValue Val0 = DAG->getConstant(0, Loc, IntVT);
+  SDValue Val1 = DAG->getConstant(1, Loc, IntVT);
+  SDValue Val2 = DAG->getConstant(2, Loc, IntVT);
+  SDValue Val3 = DAG->getConstant(3, Loc, IntVT);
+  SDValue UndefVal = DAG->getUNDEF(IntVT);
+
+  // Build some repeating sequences.
+  SmallVector<SDValue, 16> Pattern1111, Pattern1133, Pattern0123;
+  for (int I = 0; I != 4; ++I) {
+    Pattern1111.append(4, Val1);
+    Pattern1133.append(2, Val1);
+    Pattern1133.append(2, Val3);
+    Pattern0123.push_back(Val0);
+    Pattern0123.push_back(Val1);
+    Pattern0123.push_back(Val2);
+    Pattern0123.push_back(Val3);
+  }
+
+  // Build a sequence that repeats with a non-pow2 period (3).
+  SmallVector<SDValue, 16> Pattern022;
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+
+  // Build a non-repeating sequence.
+  SmallVector<SDValue, 16> Pattern1_3;
+  Pattern1_3.append(8, Val1);
+  Pattern1_3.append(8, Val3);
+
+  // Add some undefs to make it trickier.
+  Pattern1111[1] = Pattern1111[2] = Pattern1111[15] = UndefVal;
+  Pattern1133[0] = Pattern1133[2] = UndefVal;
+
+  auto *BV1111 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1111));
+  auto *BV1133 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1133));
+  auto *BV0123 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern0123));
+  auto *BV022 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern022));
+  auto *BV1_3 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1_3));
+
+  // Check for sequences.
+  SmallVector<SDValue, 16> Seq1111, Seq1133, Seq0123, Seq022, Seq1_3;
+  BitVector Undefs1111, Undefs1133, Undefs0123, Undefs022, Undefs1_3;
+
+  EXPECT_TRUE(BV1111->getRepeatedSequence(Seq1111, &Undefs1111));
+  EXPECT_EQ(Undefs1111.count(), 3u);
+  EXPECT_EQ(Seq1111.size(), 1u);
+  EXPECT_EQ(Seq1111[0], Val1);
+
+  EXPECT_TRUE(BV1133->getRepeatedSequence(Seq1133, &Undefs1133));
+  EXPECT_EQ(Undefs1133.count(), 2u);
+  EXPECT_EQ(Seq1133.size(), 4u);
+  EXPECT_EQ(Seq1133[0], Val1);
+  EXPECT_EQ(Seq1133[1], Val1);
+  EXPECT_EQ(Seq1133[2], Val3);
+  EXPECT_EQ(Seq1133[3], Val3);
+
+  EXPECT_TRUE(BV0123->getRepeatedSequence(Seq0123, &Undefs0123));
+  EXPECT_EQ(Undefs0123.count(), 0u);
+  EXPECT_EQ(Seq0123.size(), 4u);
+  EXPECT_EQ(Seq0123[0], Val0);
+  EXPECT_EQ(Seq0123[1], Val1);
+  EXPECT_EQ(Seq0123[2], Val2);
+  EXPECT_EQ(Seq0123[3], Val3);
+
+  EXPECT_FALSE(BV022->getRepeatedSequence(Seq022, &Undefs022));
+  EXPECT_FALSE(BV1_3->getRepeatedSequence(Seq1_3, &Undefs1_3));
+
+  // Try again with DemandedElts masks.
+  APInt Mask1111_0 = APInt::getOneBitSet(NumElts, 0);
+  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_0, Seq1111, &Undefs1111));
+  EXPECT_EQ(Undefs1111.count(), 0u);
+  EXPECT_EQ(Seq1111.size(), 1u);
+  EXPECT_EQ(Seq1111[0], Val1);
+
+  APInt Mask1111_1 = APInt::getOneBitSet(NumElts, 2);
+  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_1, Seq1111, &Undefs1111));
+  EXPECT_EQ(Undefs1111.count(), 1u);
+  EXPECT_EQ(Seq1111.size(), 1u);
+  EXPECT_EQ(Seq1111[0], UndefVal);
+
+  APInt Mask0123 = APInt(NumElts, 0x7777);
+  EXPECT_TRUE(BV0123->getRepeatedSequence(Mask0123, Seq0123, &Undefs0123));
+  EXPECT_EQ(Undefs0123.count(), 0u);
+  EXPECT_EQ(Seq0123.size(), 4u);
+  EXPECT_EQ(Seq0123[0], Val0);
+  EXPECT_EQ(Seq0123[1], Val1);
+  EXPECT_EQ(Seq0123[2], Val2);
+  EXPECT_EQ(Seq0123[3], SDValue());
+
+  APInt Mask1_3 = APInt::getHighBitsSet(NumElts, 8);
+  EXPECT_TRUE(BV1_3->getRepeatedSequence(Mask1_3, Seq1_3, &Undefs1_3));
+  EXPECT_EQ(Undefs1_3.count(), 0u);
+  EXPECT_EQ(Seq1_3.size(), 1u);
+  EXPECT_EQ(Seq1_3[0], Val3);
+}
+
 TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
   if (!TM)
     return;


        


More information about the llvm-commits mailing list