[llvm] [SandboxVec][SeedCollector] Implement collection of seeds with different types (PR #146171)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 27 15:29:24 PDT 2025


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/146171

>From daa4efc68b6b5d461b0272669670abfaf72a6869 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 11 Mar 2025 13:13:12 -0700
Subject: [PATCH] [SandboxVec][SeedCollector] Implement collection of seeds
 with different types

Until now the seed collector could only collect seeds with the same element
type, for example `i32` and `<2 x i32>`.
This patch implements the collection of seeds with different types, like `i32`
and `i8`. The new behavior is opt-in via the `AllowDiffTypes` flag of
`SeedCollector`, which defaults to `false`.
---
 .../SandboxVectorizer/SeedCollector.h         | 13 +++--
 .../SandboxVectorizer/SeedCollector.cpp       | 37 +++++++++------
 .../SandboxVectorizer/SeedCollectorTest.cpp   | 47 +++++++++++++++++--
 3 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index ec70350691abe..b289520fa83af 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -191,7 +191,8 @@ class SeedContainer {
 
   ScalarEvolution &SE;
 
-  template <typename LoadOrStoreT> KeyT getKey(LoadOrStoreT *LSI) const;
+  template <typename LoadOrStoreT>
+  KeyT getKey(LoadOrStoreT *LSI, bool AllowDiffTypes) const;
 
 public:
   SeedContainer(ScalarEvolution &SE) : SE(SE) {}
@@ -267,7 +268,8 @@ class SeedContainer {
     bool operator!=(const iterator &Other) const { return !(*this == Other); }
   };
   using const_iterator = BundleMapT::const_iterator;
-  template <typename LoadOrStoreT> void insert(LoadOrStoreT *LSI);
+  template <typename LoadOrStoreT>
+  void insert(LoadOrStoreT *LSI, bool AllowDiffTypes);
   // To support constant-time erase, these just mark the element used, rather
   // than actually removing them from the bundle.
   LLVM_ABI bool erase(Instruction *I);
@@ -291,9 +293,9 @@ class SeedContainer {
 
 // Explicit instantiations
 extern template LLVM_TEMPLATE_ABI void
-SeedContainer::insert<LoadInst>(LoadInst *);
+SeedContainer::insert<LoadInst>(LoadInst *, bool);
 extern template LLVM_TEMPLATE_ABI void
-SeedContainer::insert<StoreInst>(StoreInst *);
+SeedContainer::insert<StoreInst>(StoreInst *, bool);
 
 class SeedCollector {
   SeedContainer StoreSeeds;
@@ -308,7 +310,8 @@ class SeedCollector {
 
 public:
   LLVM_ABI SeedCollector(BasicBlock *BB, ScalarEvolution &SE,
-                         bool CollectStores, bool CollectLoads);
+                         bool CollectStores, bool CollectLoads,
+                         bool AllowDiffTypes = false);
   LLVM_ABI ~SeedCollector();
 
   iterator_range<SeedContainer::iterator> getStoreSeeds() {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 1934326866cdf..e80dc04e15326 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -73,22 +73,28 @@ ArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
 }
 
 template <typename LoadOrStoreT>
-SeedContainer::KeyT SeedContainer::getKey(LoadOrStoreT *LSI) const {
+SeedContainer::KeyT SeedContainer::getKey(LoadOrStoreT *LSI,
+                                          bool AllowDiffTypes) const {
   assert((isa<LoadInst>(LSI) || isa<StoreInst>(LSI)) &&
          "Expected Load or Store!");
   Value *Ptr = Utils::getMemInstructionBase(LSI);
   Instruction::Opcode Op = LSI->getOpcode();
-  Type *Ty = Utils::getExpectedType(LSI);
-  if (auto *VTy = dyn_cast<VectorType>(Ty))
-    Ty = VTy->getElementType();
+  Type *Ty;
+  if (AllowDiffTypes) {
+    Ty = nullptr;
+  } else {
+    Ty = Utils::getExpectedType(LSI);
+    if (auto *VTy = dyn_cast<VectorType>(Ty))
+      Ty = VTy->getElementType();
+  }
   return {Ptr, Ty, Op};
 }
 
 // Explicit instantiations
 template SeedContainer::KeyT
-SeedContainer::getKey<LoadInst>(LoadInst *LSI) const;
+SeedContainer::getKey<LoadInst>(LoadInst *LSI, bool AllowDiffTypes) const;
 template SeedContainer::KeyT
-SeedContainer::getKey<StoreInst>(StoreInst *LSI) const;
+SeedContainer::getKey<StoreInst>(StoreInst *LSI, bool AllowDiffTypes) const;
 
 bool SeedContainer::erase(Instruction *I) {
   assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Expected Load or Store!");
@@ -100,9 +106,10 @@ bool SeedContainer::erase(Instruction *I) {
   return true;
 }
 
-template <typename LoadOrStoreT> void SeedContainer::insert(LoadOrStoreT *LSI) {
+template <typename LoadOrStoreT>
+void SeedContainer::insert(LoadOrStoreT *LSI, bool AllowDiffTypes) {
   // Find the bundle containing seeds for this symbol and type-of-access.
-  auto &BundleVec = Bundles[getKey(LSI)];
+  auto &BundleVec = Bundles[getKey(LSI, AllowDiffTypes)];
   // Fill this vector of bundles front to back so that only the last bundle in
   // the vector may have available space. This avoids iteration to find one with
   // space.
@@ -115,9 +122,10 @@ template <typename LoadOrStoreT> void SeedContainer::insert(LoadOrStoreT *LSI) {
 }
 
 // Explicit instantiations
-template LLVM_EXPORT_TEMPLATE void SeedContainer::insert<LoadInst>(LoadInst *);
-template LLVM_EXPORT_TEMPLATE void
-SeedContainer::insert<StoreInst>(StoreInst *);
+template LLVM_EXPORT_TEMPLATE void SeedContainer::insert<LoadInst>(LoadInst *,
+                                                                   bool);
+template LLVM_EXPORT_TEMPLATE void SeedContainer::insert<StoreInst>(StoreInst *,
+                                                                    bool);
 
 #ifndef NDEBUG
 void SeedContainer::print(raw_ostream &OS) const {
@@ -158,7 +166,8 @@ template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
 template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
 
 SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE,
-                             bool CollectStores, bool CollectLoads)
+                             bool CollectStores, bool CollectLoads,
+                             bool AllowDiffTypes)
     : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
 
   if (!CollectStores && !CollectLoads)
@@ -175,10 +184,10 @@ SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE,
   for (auto &I : *BB) {
     if (StoreInst *SI = dyn_cast<StoreInst>(&I))
       if (CollectStores && isValidMemSeed(SI))
-        StoreSeeds.insert(SI);
+        StoreSeeds.insert(SI, AllowDiffTypes);
     if (LoadInst *LI = dyn_cast<LoadInst>(&I))
       if (CollectLoads && isValidMemSeed(LI))
-        LoadSeeds.insert(LI);
+        LoadSeeds.insert(LI, AllowDiffTypes);
     // Cap compilation time.
     if (totalNumSeedGroups() > SeedGroupsLimit)
       break;
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 7f9a59bd428a0..31b4a5ec9e391 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -259,10 +259,10 @@ define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
   // Check begin() end() when empty.
   EXPECT_EQ(SC.begin(), SC.end());
 
-  SC.insert(S0);
-  SC.insert(S1);
-  SC.insert(S2);
-  SC.insert(S3);
+  SC.insert(S0, /*AllowDiffTypes=*/false);
+  SC.insert(S1, /*AllowDiffTypes=*/false);
+  SC.insert(S2, /*AllowDiffTypes=*/false);
+  SC.insert(S3, /*AllowDiffTypes=*/false);
   unsigned Cnt = 0;
   SmallVector<sandboxir::SeedBundle *> Bndls;
   for (auto &SeedBndl : SC) {
@@ -480,6 +480,45 @@ define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
   ExpectThatElementsAre(SB, {St0, St1, St3});
 }
 
+TEST_F(SeedBundleTest, DiffTypes) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, i8 %v, i16 %v16) {
+bb:
+  %ptr0 = getelementptr i8, ptr %ptr, i32 0
+  %ptr1 = getelementptr i8, ptr %ptr, i32 1
+  %ptr3 = getelementptr i8, ptr %ptr, i32 3
+  store i8 %v, ptr %ptr0
+  store i8 %v, ptr %ptr3
+  store i16 %v16, ptr %ptr1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII(M->getTargetTriple());
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  auto It = std::next(BB->begin(), 3);
+  auto *St0 = &*It++;
+  auto *St3 = &*It++;
+  auto *St1 = &*It++;
+
+  sandboxir::SeedCollector SC(&*BB, SE, /*CollectStores=*/true,
+                              /*CollectLoads=*/false, /*AllowDiffTypes=*/true);
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
+  ExpectThatElementsAre(SB, {St0, St1, St3});
+}
+
 TEST_F(SeedBundleTest, VectorLoads) {
   parseIR(C, R"IR(
 define void @foo(ptr noalias %ptr, <2 x float> %val0) {



More information about the llvm-commits mailing list