[llvm] 01d8759 - [IR][ShuffleVector] Introduce `isReplicationMask()` matcher

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 5 06:54:01 PDT 2021


Author: Roman Lebedev
Date: 2021-11-05T16:53:47+03:00
New Revision: 01d8759ac9ad73cf9235f4f85abdfaa26abaaef5

URL: https://github.com/llvm/llvm-project/commit/01d8759ac9ad73cf9235f4f85abdfaa26abaaef5
DIFF: https://github.com/llvm/llvm-project/commit/01d8759ac9ad73cf9235f4f85abdfaa26abaaef5.diff

LOG: [IR][ShuffleVector] Introduce `isReplicationMask()` matcher

Avid readers of this saga may recall from previous installments
that a replication mask replicates (lol) each of the `VF` elements
in a vector `ReplicationFactor` times. For example, the mask for
`ReplicationFactor=3` and `VF=4` is: `<0,0,0,1,1,1,2,2,2,3,3,3>`.
More importantly, the replication mask is used by LoopVectorizer
when using masked interleaved memory operations.
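
To make the definition above concrete, here is a minimal standalone sketch
that builds such a mask. It deliberately avoids LLVM headers so it compiles
on its own; `makeReplicationMask` is a hypothetical name used only for
illustration, not the helper the patch or its tests use.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical standalone helper (an assumption for illustration, not an
// LLVM API): repeats each of the VF source indices ReplicationFactor times.
std::vector<int> makeReplicationMask(int ReplicationFactor, int VF) {
  assert(ReplicationFactor > 0 && VF > 0 && "Parameters must be positive.");
  std::vector<int> Mask;
  Mask.reserve(static_cast<std::size_t>(ReplicationFactor) * VF);
  // Outer loop walks the source elements, inner loop emits the copies.
  for (int Elt = 0; Elt < VF; ++Elt)
    for (int Copy = 0; Copy < ReplicationFactor; ++Copy)
      Mask.push_back(Elt);
  return Mask;
}
```

With `ReplicationFactor=3` and `VF=4` this yields the twelve-element mask
quoted above; `ReplicationFactor=1` degenerates to an identity mask, and
`VF=1` to a broadcast of element 0.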

As discussed in previous installments, while it is used by LV,
and we **seem** to support masked interleaved memory operations on X86,
its support in the cost model leaves a lot to be desired:
until basically yesterday even for AVX512 we had no cost model for it.

As has been witnessed in the recent
AVX2 `X86TTIImpl::getInterleavedMemoryOpCost()`
costmodel patches, while it is hard enough to query the cost
of a particular assembly sequence [from llvm-mca],
afterwards the check lines of the LV costmodel tests must be updated manually.
This is, at the very least, boring.

Okay, now we have decent costmodel coverage for interleaving shuffles,
but now basically the same mind-killing sequence has to be performed
for replication masks. I think we can improve at least the second half
of the problem, by teaching
the `TargetTransformInfoImplCRTPBase::getUserCost()` to recognize
`Instruction::ShuffleVector` that are replication masks, and
adding exhaustive test coverage
using `-cost-model -analyze` + `utils/update_analyze_test_checks.py`.

This way we can have good exhaustive coverage for cost model,
and only basic coverage for the LV costmodel.

This patch adds precise undef-aware `isReplicationMask()`,
with exhaustive test coverage.
* `InstructionsTest.ShuffleMaskIsReplicationMask` shows that
   it correctly detects all the known masks.
* `InstructionsTest.ShuffleMaskIsReplicationMask_undef`
  shows that replacing some mask elements in a known replication mask
  with undef still allows us to recognize it as a replication mask.
  Note, with enough undef elts, we may detect a different tuple.
* `InstructionsTest.ShuffleMaskIsReplicationMask_Exhaustive_Correctness`
  shows that if we detected the replication mask with given params,
  then if we actually generate a true replication mask with said params,
  it matches element-wise ignoring undef mask elements.
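
The "different tuple" caveat is easiest to see in a small standalone sketch.
The code below is a simplified model of the undef-aware search, not the
patch's implementation: it always enumerates candidate replication factors
from largest to smallest and treats `-1` as undef. The names `UndefElt`,
`matchesWithParams` and `findReplicationParams` are hypothetical and exist
only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int UndefElt = -1; // Assumption: -1 marks an undef mask element.

// Checks Mask against one fixed (ReplicationFactor, VF) pair; an undef
// element is allowed to stand in for any expected value.
bool matchesWithParams(const std::vector<int> &Mask, int ReplicationFactor,
                       int VF) {
  assert(ReplicationFactor > 0 && VF > 0 && "Parameters must be positive.");
  if (Mask.size() != static_cast<std::size_t>(ReplicationFactor) * VF)
    return false;
  for (std::size_t I = 0; I < Mask.size(); ++I) {
    int Expected = static_cast<int>(I) / ReplicationFactor;
    if (Mask[I] != UndefElt && Mask[I] != Expected)
      return false;
  }
  return true;
}

// Tries every replication factor that divides the mask size, largest first,
// and reports the first (ReplicationFactor, VF) pair that matches.
bool findReplicationParams(const std::vector<int> &Mask,
                           int &ReplicationFactor, int &VF) {
  const int Size = static_cast<int>(Mask.size());
  for (int RF = Size; RF >= 1; --RF) {
    if (Size % RF != 0)
      continue;
    if (!matchesWithParams(Mask, RF, Size / RF))
      continue;
    ReplicationFactor = RF;
    VF = Size / RF;
    return true;
  }
  return false;
}
```

In this model, `<0,-1,1,-1>` is still recognized as `ReplicationFactor=2`,
`VF=2`, but `<0,-1,-1,-1>` (also obtainable from `<0,0,1,1>` by adding
undefs) is reported as a broadcast with `ReplicationFactor=4`, `VF=1`,
because the larger factor is tried first.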

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D113214

Added: 
    

Modified: 
    llvm/include/llvm/IR/Instructions.h
    llvm/lib/IR/Instructions.cpp
    llvm/unittests/IR/InstructionsTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 59a783dbdf8d8..cf8ef5be66c5f 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -2354,6 +2354,34 @@ class ShuffleVectorInst : public Instruction {
     return isInsertSubvectorMask(ShuffleMask, NumSrcElts, NumSubElts, Index);
   }
 
+  /// Return true if this shuffle mask replicates each of the \p VF elements
+  /// in a vector \p ReplicationFactor times.
+  /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
+  ///   <0,0,0,1,1,1,2,2,2,3,3,3>
+  static bool isReplicationMask(ArrayRef<int> Mask, int &ReplicationFactor,
+                                int &VF);
+  static bool isReplicationMask(const Constant *Mask, int &ReplicationFactor,
+                                int &VF) {
+    assert(Mask->getType()->isVectorTy() && "Shuffle needs vector constant.");
+    // Not possible to express a shuffle mask for a scalable vector for this
+    // case.
+    if (isa<ScalableVectorType>(Mask->getType()))
+      return false;
+    SmallVector<int, 16> MaskAsInts;
+    getShuffleMask(Mask, MaskAsInts);
+    return isReplicationMask(MaskAsInts, ReplicationFactor, VF);
+  }
+
+  /// Return true if this shuffle mask is a replication mask.
+  bool isReplicationMask(int &ReplicationFactor, int &VF) const {
+    // Not possible to express a shuffle mask for a scalable vector for this
+    // case.
+    if (isa<ScalableVectorType>(getType()))
+      return false;
+
+    return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
+  }
+
   /// Change values in a shuffle permute mask assuming the two vector operands
   /// of length InVecNumElts have swapped position.
   static void commuteShuffleMask(MutableArrayRef<int> Mask,

diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 1bc4ebc7ac162..63dd07543f434 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -2436,6 +2436,72 @@ bool ShuffleVectorInst::isConcat() const {
   return isIdentityMaskImpl(getShuffleMask(), NumMaskElts);
 }
 
+static bool isReplicationMaskWithParams(ArrayRef<int> Mask,
+                                        int ReplicationFactor, int VF) {
+  assert(Mask.size() == (unsigned)ReplicationFactor * VF &&
+         "Unexpected mask size.");
+
+  for (int CurrElt : seq(0, VF)) {
+    ArrayRef<int> CurrSubMask = Mask.take_front(ReplicationFactor);
+    assert(CurrSubMask.size() == (unsigned)ReplicationFactor &&
+           "Run out of mask?");
+    Mask = Mask.drop_front(ReplicationFactor);
+    if (!all_of(CurrSubMask, [CurrElt](int MaskElt) {
+          return MaskElt == UndefMaskElem || MaskElt == CurrElt;
+        }))
+      return false;
+  }
+  assert(Mask.empty() && "Did not consume the whole mask?");
+
+  return true;
+}
+
+bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
+                                          int &ReplicationFactor, int &VF) {
+  // undef-less case is trivial.
+  if (none_of(Mask, [](int MaskElt) { return MaskElt == UndefMaskElem; })) {
+    ReplicationFactor =
+        Mask.take_while([](int MaskElt) { return MaskElt == 0; }).size();
+    if (ReplicationFactor == 0 || Mask.size() % ReplicationFactor != 0)
+      return false;
+    VF = Mask.size() / ReplicationFactor;
+    return isReplicationMaskWithParams(Mask, ReplicationFactor, VF);
+  }
+
+  // However, if the mask contains undef's, we have to enumerate possible tuples
+  // and pick one. There are bounds on replication factor: [1, mask size]
+  // (where RF=1 is an identity shuffle, RF=mask size is a broadcast shuffle)
+  // Additionally, mask size is a replication factor multiplied by vector size,
+  // which further significantly reduces the search space.
+
+  // Before doing that, let's perform basic sanity check first.
+  int Largest = -1;
+  for (int MaskElt : Mask) {
+    if (MaskElt == UndefMaskElem)
+      continue;
+    // Elements must be in non-decreasing order.
+    if (MaskElt < Largest)
+      return false;
+    Largest = std::max(Largest, MaskElt);
+  }
+
+  // Prefer larger replication factor if all else equal.
+  for (int PossibleReplicationFactor :
+       reverse(seq_inclusive<unsigned>(1, Mask.size()))) {
+    if (Mask.size() % PossibleReplicationFactor != 0)
+      continue;
+    int PossibleVF = Mask.size() / PossibleReplicationFactor;
+    if (!isReplicationMaskWithParams(Mask, PossibleReplicationFactor,
+                                     PossibleVF))
+      continue;
+    ReplicationFactor = PossibleReplicationFactor;
+    VF = PossibleVF;
+    return true;
+  }
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 //                             InsertValueInst Class
 //===----------------------------------------------------------------------===//

diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index f3fc08ca583c6..a60ee79acfd0a 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/ADT/CombinationGenerator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -1115,6 +1117,92 @@ TEST(InstructionsTest, ShuffleMaskQueries) {
   delete Id15;
 }
 
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
+  for (int ReplicationFactor : seq_inclusive(1, 8)) {
+    for (int VF : seq_inclusive(1, 8)) {
+      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+      EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
+          ReplicatedMask, GuessedReplicationFactor, GuessedVF));
+      EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
+      EXPECT_EQ(GuessedVF, VF);
+    }
+  }
+}
+
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask_undef) {
+  for (int ReplicationFactor : seq_inclusive(1, 6)) {
+    for (int VF : seq_inclusive(1, 4)) {
+      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+
+      // If we change some mask elements to undef, we should still match.
+
+      SmallVector<SmallVector<bool>> ElementChoices(ReplicatedMask.size(),
+                                                    {false, true});
+
+      CombinationGenerator<bool, decltype(ElementChoices)::value_type,
+                           /*variable_smallsize=*/4>
+          G(ElementChoices);
+
+      G.generate([&](ArrayRef<bool> UndefOverrides) -> bool {
+        SmallVector<int> AdjustedMask;
+        AdjustedMask.reserve(ReplicatedMask.size());
+        for (auto I : zip(ReplicatedMask, UndefOverrides))
+          AdjustedMask.emplace_back(std::get<1>(I) ? -1 : std::get<0>(I));
+        assert(AdjustedMask.size() == ReplicatedMask.size() &&
+               "Size misprediction");
+
+        EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
+            AdjustedMask, GuessedReplicationFactor, GuessedVF));
+        // Do not check GuessedReplicationFactor and GuessedVF,
+        // with enough undef's we may deduce a different tuple.
+
+        return /*Abort=*/false;
+      });
+    }
+  }
+}
+
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask_Exhaustive_Correctness) {
+  for (int ShufMaskNumElts : seq_inclusive(1, 8)) {
+    SmallVector<int> PossibleShufMaskElts;
+    PossibleShufMaskElts.reserve(ShufMaskNumElts + 2);
+    for (int PossibleShufMaskElt : seq_inclusive(-1, ShufMaskNumElts))
+      PossibleShufMaskElts.emplace_back(PossibleShufMaskElt);
+    assert(PossibleShufMaskElts.size() == ShufMaskNumElts + 2U &&
+           "Size misprediction");
+
+    SmallVector<SmallVector<int>> ElementChoices(ShufMaskNumElts,
+                                                 PossibleShufMaskElts);
+
+    CombinationGenerator<int, decltype(ElementChoices)::value_type,
+                         /*variable_smallsize=*/4>
+        G(ElementChoices);
+
+    G.generate([&](ArrayRef<int> Mask) -> bool {
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+      bool Match = ShuffleVectorInst::isReplicationMask(
+          Mask, GuessedReplicationFactor, GuessedVF);
+      if (!Match)
+        return /*Abort=*/false;
+
+      const auto ActualMask =
+          createReplicatedMask(GuessedReplicationFactor, GuessedVF);
+      EXPECT_EQ(Mask.size(), ActualMask.size());
+      for (auto I : zip(Mask, ActualMask)) {
+        int Elt = std::get<0>(I);
+        int ActualElt = std::get<1>(I);
+
+        if (Elt != -1)
+          EXPECT_EQ(Elt, ActualElt);
+      }
+
+      return /*Abort=*/false;
+    });
+  }
+}
+
 TEST(InstructionsTest, GetSplat) {
   // Create the elements for various constant vectors.
   LLVMContext Ctx;


        

