[llvm] [SandboxVectorizer] Add container class to track and manage SeedBundles (PR #112048)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 11 14:22:21 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: None (Sterling-Augustine)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/112048.diff


3 Files Affected:

- (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h (+111-1) 
- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp (+73) 
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp (+63) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index 619c2147f2e5c4..d0013aca5ad485 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -54,6 +54,10 @@ class SeedBundle {
     NumUnusedBits += Utils::getNumBits(I);
   }
 
+  /// Inserts \p I into the bundle. Only subclasses (e.g. MemSeedBundle) know
+  virtual void insert(Instruction *I, ScalarEvolution &SE) {
+    llvm_unreachable("Subclasses must override this function.");
+  }
   unsigned getFirstUnusedElementIdx() const {
     for (unsigned ElmIdx : seq<unsigned>(0, Seeds.size()))
       if (!isUsed(ElmIdx))
@@ -96,6 +100,9 @@ class SeedBundle {
   MutableArrayRef<Instruction *>
   getSlice(unsigned StartIdx, unsigned MaxVecRegBits, bool ForcePowOf2);
 
+  /// \Returns how many seeds the bundle holds, used or not.
+  std::size_t size() const { return std::size(Seeds); }
+
 protected:
   SmallVector<Instruction *> Seeds;
   /// The lanes that we have already vectorized.
@@ -148,7 +155,7 @@ template <typename LoadOrStoreT> class MemSeedBundle : public SeedBundle {
                   "Expected LoadInst or StoreInst!");
     assert(isa<LoadOrStoreT>(MemI) && "Expected Load or Store!");
   }
-  void insert(sandboxir::Instruction *I, ScalarEvolution &SE) {
+  void insert(sandboxir::Instruction *I, ScalarEvolution &SE) override {
     assert(isa<LoadOrStoreT>(I) && "Expected a Store or a Load!");
     auto Cmp = [&SE](Instruction *I0, Instruction *I1) {
       return Utils::atLowerAddress(cast<LoadOrStoreT>(I0),
@@ -162,5 +169,133 @@ template <typename LoadOrStoreT> class MemSeedBundle : public SeedBundle {
 using StoreSeedBundle = MemSeedBundle<sandboxir::StoreInst>;
 using LoadSeedBundle = MemSeedBundle<sandboxir::LoadInst>;
 
+/// Class to conveniently track Seeds within Seedbundles. Saves newly collected
+/// seeds in the proper bundle. Supports constant-time removal, as seeds and
+/// entire bundles are vectorized and marked used to signify removal. Iterators
+/// skip bundles that are completely used.
+class SeedContainer {
+  // Use the same key for different seeds if they are the same type and
+  // reference the same pointer, even if at different offsets. This directs
+  // potentially vectorizable seeds into the same bundle.
+  using KeyT = std::tuple<Value *, Type *, sandboxir::Instruction::Opcode>;
+  // Trying to vectorize too many seeds at once is expensive in
+  // compilation-time. Use a vector of bundles (all with the same key) to
+  // partition the candidate set into more manageable units. Each bundle is
+  // size-limited by sbvec-seed-bundle-size-limit.  TODO: There might be a
+  // better way to divide these than by simple insertion order.
+  using ValT = SmallVector<std::unique_ptr<SeedBundle>>;
+  using BundleMapT = MapVector<KeyT, ValT>;
+  // Map from {pointer, Type, Opcode} to a vector of bundles.
+  BundleMapT Bundles;
+  // Allows finding a particular Instruction's bundle. Values point into the
+  // bundles owned by Bundles, so entries must not outlive those bundles.
+  DenseMap<sandboxir::Instruction *, SeedBundle *> SeedLookupMap;
+
+  ScalarEvolution &SE;
+
+  template <typename LoadOrStoreT> KeyT getKey(LoadOrStoreT *LSI) const;
+
+public:
+  explicit SeedContainer(ScalarEvolution &SE) : SE(SE) {}
+
+  /// Iterates over the bundles of every key, in key-insertion order, skipping
+  /// bundles whose seeds are all marked used.
+  class iterator {
+    BundleMapT *Map = nullptr;
+    BundleMapT::iterator MapIt;
+    ValT *Vec = nullptr;
+    size_t VecIdx = 0;
+
+    // Step to the next bundle slot without skipping used bundles, moving on
+    // to the next key's bundle vector once the current one is exhausted.
+    void advance() {
+      assert(Vec != nullptr && "Already at end!");
+      if (++VecIdx < Vec->size())
+        return;
+      assert(MapIt != Map->end() && "Already at end!");
+      VecIdx = 0;
+      ++MapIt;
+      Vec = MapIt != Map->end() ? &MapIt->second : nullptr;
+    }
+
+  public:
+    using difference_type = std::ptrdiff_t;
+    using value_type = SeedBundle;
+    using pointer = value_type *;
+    using reference = value_type &;
+    using iterator_category = std::input_iterator_tag;
+
+    iterator(BundleMapT &Map, BundleMapT::iterator MapIt, ValT *Vec,
+             size_t VecIdx)
+        : Map(&Map), MapIt(MapIt), Vec(Vec), VecIdx(VecIdx) {}
+    value_type &operator*() {
+      assert(Vec != nullptr && "Already at end!");
+      return *(*Vec)[VecIdx];
+    }
+    // Advance past completely used bundles (and any empty bundle vector).
+    // This loops rather than recursing through operator++(), so a long run
+    // of used bundles cannot exhaust the stack.
+    void skipUsed() {
+      while (Vec && (VecIdx >= Vec->size() || (**this).allUsed()))
+        advance();
+    }
+    iterator &operator++() {
+      advance();
+      skipUsed();
+      return *this;
+    }
+    iterator operator++(int) {
+      auto Copy = *this;
+      ++(*this);
+      return Copy;
+    }
+    bool operator==(const iterator &Other) const {
+      assert(Map == Other.Map && "Iterator of different objects!");
+      return MapIt == Other.MapIt && VecIdx == Other.VecIdx;
+    }
+    bool operator!=(const iterator &Other) const { return !(*this == Other); }
+  };
+  // NOTE(review): unlike iterator, this walks the raw key->bundles map.
+  using const_iterator = BundleMapT::const_iterator;
+
+  /// Adds \p LSI to the last bundle for its key, starting a new bundle when
+  /// that one is full (see sbvec-seed-bundle-size-limit).
+  template <typename LoadOrStoreT> void insert(LoadOrStoreT *LSI);
+  /// Marks \p I used in its bundle rather than removing it, so that erasure
+  /// is constant-time. \Returns false if \p I is not a tracked seed.
+  bool erase(sandboxir::Instruction *I);
+  /// Destroys every bundle stored under \p Key and forgets their seeds.
+  /// Invalidates iterators. \Returns true if \p Key was present.
+  bool erase(const KeyT &Key) {
+    auto It = Bundles.find(Key);
+    if (It == Bundles.end())
+      return false;
+    // Drop lookup entries first: they point into the bundles freed below.
+    for (const auto &Bndl : It->second)
+      for (sandboxir::Instruction *I : *Bndl)
+        SeedLookupMap.erase(I);
+    Bundles.erase(It);
+    return true;
+  }
+  iterator begin() {
+    if (Bundles.empty())
+      return end();
+    auto BeginIt =
+        iterator(Bundles, Bundles.begin(), &Bundles.begin()->second, 0);
+    BeginIt.skipUsed();
+    return BeginIt;
+  }
+  iterator end() { return iterator(Bundles, Bundles.end(), nullptr, 0); }
+  /// \Returns the number of distinct keys, which is neither the number of
+  /// bundles nor the number of unused bundles.
+  unsigned size() const { return Bundles.size(); }
+
+#ifndef NDEBUG
+  void dump(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif // NDEBUG
+};
+
 } // namespace llvm::sandboxir
+
 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SEEDCOLLECTOR_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 00a7dc3fcec93e..88a22807dcede0 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -19,6 +19,10 @@
 using namespace llvm;
 namespace llvm::sandboxir {
 
+static cl::opt<unsigned> SeedBundleSizeLimit(
+    "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
+    cl::desc("Limit the size of the seed bundle to cap compilation time."));
+
 MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
                                                     unsigned MaxVecRegBits,
                                                     bool ForcePowerOf2) {
@@ -61,4 +65,73 @@ MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
     return {};
 }
 
+template <typename LoadOrStoreT>
+SeedContainer::KeyT SeedContainer::getKey(LoadOrStoreT *LSI) const {
+  assert((isa<LoadInst>(LSI) || isa<StoreInst>(LSI)) &&
+         "Expected Load or Store!");
+  // Key vector accesses by their element type so that they share a key with
+  // scalar accesses of that type through the same base pointer.
+  Type *ElemTy = Utils::getExpectedType(LSI);
+  if (auto *VecTy = dyn_cast<VectorType>(ElemTy))
+    ElemTy = VecTy->getElementType();
+  return {Utils::getMemInstructionBase(LSI), ElemTy, LSI->getOpcode()};
+}
+
+// Explicit instantiations
+template SeedContainer::KeyT
+SeedContainer::getKey<LoadInst>(LoadInst *LSI) const;
+template SeedContainer::KeyT
+SeedContainer::getKey<StoreInst>(StoreInst *LSI) const;
+
+bool SeedContainer::erase(Instruction *I) {
+  assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Expected Load or Store!");
+  auto It = SeedLookupMap.find(I);
+  if (It == SeedLookupMap.end())
+    return false;
+  It->second->setUsed(I);
+  SeedLookupMap.erase(It); // So a second erase(I) returns false.
+  return true;
+}
+
+template <typename LoadOrStoreT> void SeedContainer::insert(LoadOrStoreT *LSI) {
+  // Find the bundle containing seeds for this symbol and type-of-access.
+  auto &BundleVec = Bundles[getKey(LSI)];
+  // Fill this vector of bundles front to back so that only the last bundle in
+  // the vector may have available space. This avoids iteration to find one with
+  // space. `>=` (not `==`) keeps a limit of 0 from meaning "unlimited".
+  if (BundleVec.empty() || BundleVec.back()->size() >= SeedBundleSizeLimit)
+    BundleVec.emplace_back(std::make_unique<MemSeedBundle<LoadOrStoreT>>(LSI));
+  else
+    BundleVec.back()->insert(LSI, SE);
+
+  // Remember the owning bundle so erase(Instruction *) is a single lookup.
+  SeedLookupMap[LSI] = BundleVec.back().get();
+}
+
+// Explicit instantiations
+template void SeedContainer::insert<LoadInst>(LoadInst *);
+template void SeedContainer::insert<StoreInst>(StoreInst *);
+
+#ifndef NDEBUG
+void SeedContainer::dump(raw_ostream &OS) const {
+  for (const auto &[Key, SeedsVec] : Bundles) {
+    const auto &[Ptr, Ty, Opc] = Key;
+    // Ptr may be any Value (e.g. an argument), so use Opc for the access kind.
+    const char *RefType = Opc == Instruction::Opcode::Load ? "Load" : "Store";
+    OS << "[Inst=" << *Ptr << " Ty=";
+    Ty->print(OS);
+    OS << " " << RefType << "]\n";
+    for (const auto &SeedPtr : SeedsVec) {
+      SeedPtr->dump(OS);
+      OS << "\n";
+    }
+  }
+}
+
+void SeedContainer::dump() const {
+  dump(dbgs());
+  dbgs() << "\n";
+}
+#endif // NDEBUG
+
 } // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index dd41b0a6605095..0c5523d88ff9fc 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -196,3 +196,66 @@ define void @foo(ptr %ptrA, float %val, ptr %ptr) {
   sandboxir::LoadSeedBundle LB(std::move(Loads), SE);
   EXPECT_THAT(LB, testing::ElementsAre(L0, L1, L2, L3));
 }
+
+TEST_F(SeedBundleTest, Container) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
+bb:
+  %gepA0 = getelementptr float, ptr %ptrA, i32 0
+  %gepA1 = getelementptr float, ptr %ptrA, i32 1
+  %gepB0 = getelementptr float, ptr %ptrB, i32 0
+  %gepB1 = getelementptr float, ptr %ptrB, i32 1
+  store float %val, ptr %gepA0
+  store float %val, ptr %gepA1
+  store float %val, ptr %gepB0
+  store float %val, ptr %gepB1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = std::next(BB.begin(), 4); // Skip the four GEPs.
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  sandboxir::SeedContainer SC(SE);
+  // Check begin() end() when empty.
+  EXPECT_EQ(SC.begin(), SC.end());
+
+  SC.insert(S0);
+  SC.insert(S1);
+  SC.insert(S2);
+  SC.insert(S3);
+  EXPECT_EQ(SC.size(), 2u); // One key for %ptrA, one for %ptrB.
+  unsigned Cnt = 0;
+  SmallVector<sandboxir::SeedBundle *> Bndls;
+  for (auto &SeedBndl : SC) {
+    EXPECT_EQ(SeedBndl.size(), 2u);
+    ++Cnt;
+    Bndls.push_back(&SeedBndl);
+  }
+  EXPECT_EQ(Cnt, 2u);
+
+  // Mark them "Used" to check if operator++ skips them in the next loop.
+  for (auto *SeedBndl : Bndls)
+    for (auto Lane : seq<unsigned>(SeedBndl->size()))
+      SeedBndl->setUsed(Lane);
+  // Check that iteration skips bundles whose lanes are all used.
+  Cnt = 0;
+  for (auto &SeedBndl : SC) {
+    (void)SeedBndl;
+    ++Cnt;
+  }
+  EXPECT_EQ(Cnt, 0u);
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/112048


More information about the llvm-commits mailing list