[llvm] [SandboxVectorizer] New class to actually collect and manage seeds (PR #112979)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 21 15:34:15 PDT 2024


https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/112979

>From d11ad2da70a33139b6018d1eb640a545a0c845f0 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Fri, 18 Oct 2024 13:35:05 -0700
Subject: [PATCH 1/5] [SandboxVectorizer] New class to actually collect and
 manage memory seeds

There are many more tests to add, but I would like to get this reviewed
before it grows too big.
---
 llvm/include/llvm/SandboxIR/Utils.h           |  10 ++
 .../SandboxVectorizer/SeedCollector.h         |  30 ++++
 .../SandboxVectorizer/SeedCollector.cpp       |  87 +++++++++
 .../SandboxVectorizer/SeedCollectorTest.cpp   | 168 ++++++++++++++++++
 4 files changed, 295 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/Utils.h b/llvm/include/llvm/SandboxIR/Utils.h
index a73498adea1d59..c416e3497b67e6 100644
--- a/llvm/include/llvm/SandboxIR/Utils.h
+++ b/llvm/include/llvm/SandboxIR/Utils.h
@@ -60,6 +60,16 @@ class Utils {
         getUnderlyingObject(LSI->getPointerOperand()->Val));
   }
 
+  /// \Returns the number of elements in \p Ty, that is the number of lanes
+  /// if a fixed vector or 1 if scalar. ScalableVectors
+  static int getNumElements(Type *Ty) {
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
+  }
+  /// Returns \p Ty if scalar or its element type if vector.
+  static Type *getElementType(Type *Ty) {
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
+  }
+
   /// \Returns the number of bits required to represent the operands or return
   /// value of \p V in \p DL.
   static unsigned getNumBits(Value *V, const DataLayout &DL) {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index a4512862136a8b..1e55fa0f0a5688 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -284,6 +284,36 @@ class SeedContainer {
 #endif // NDEBUG
 };
 
+class SeedCollector {
+  SeedContainer StoreSeeds;
+  SeedContainer LoadSeeds;
+  BasicBlock *BB;
+  Context &Ctx;
+
+  /// \Returns the number of SeedBundle groups for all seed types.
+  /// This is to be used for limiting compilation time.
+  unsigned totalNumSeedGroups() const {
+    return StoreSeeds.size() + LoadSeeds.size();
+  }
+
+public:
+  SeedCollector(BasicBlock *SBBB, ScalarEvolution &SE);
+  ~SeedCollector();
+
+  BasicBlock *getBasicBlock() { return BB; }
+
+  iterator_range<SeedContainer::iterator> getStoreSeeds() {
+    return {StoreSeeds.begin(), StoreSeeds.end()};
+  }
+  iterator_range<SeedContainer::iterator> getLoadSeeds() {
+    return {LoadSeeds.begin(), LoadSeeds.end()};
+  }
+#ifndef NDEBUG
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 } // namespace llvm::sandboxir
 
 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SEEDCOLLECTOR_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 66fac080a7b7cc..806671a10a7d61 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -22,6 +22,23 @@ namespace llvm::sandboxir {
 cl::opt<unsigned> SeedBundleSizeLimit(
     "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
     cl::desc("Limit the size of the seed bundle to cap compilation time."));
+cl::opt<bool>
+    DisableStoreSeeds("sbvec-disable-store-seeds", cl::init(false), cl::Hidden,
+                      cl::desc("Don't collect store seed instructions."));
+cl::opt<bool>
+    DisableLoadSeeds("sbvec-disable-load-seeds", cl::init(true), cl::Hidden,
+                     cl::desc("Don't collect load seed instructions."));
+
+#define LoadSeedsDef "loads"
+#define StoreSeedsDef "stores"
+cl::opt<std::string>
+    ForceSeed("sbvec-force-seeds", cl::init(""), cl::Hidden,
+              cl::desc("Enable only this type of seeds. This can be one "
+                       "of: '" LoadSeedsDef "','" StoreSeedsDef "'."));
+cl::opt<unsigned> SeedGroupsLimit(
+    "sbvec-seed-groups-limit", cl::init(256), cl::Hidden,
+    cl::desc("Limit the number of collected seeds groups in a BB to "
+             "cap compilation time."));
 
 MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
                                                     unsigned MaxVecRegBits,
@@ -131,4 +148,74 @@ void SeedContainer::print(raw_ostream &OS) const {
 LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
 #endif // NDEBUG
 
+/// \Returns true if \p LSI is simple and its expected type is vectorizable.
+template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
+  if (!LSI->isSimple())
+    return false;
+  auto *Ty = Utils::getExpectedType(LSI);
+  // Omit architecturally unvectorizable types and scalable vectors, whose
+  // lane count is not known at compile time.
+  if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty() || isa<ScalableVectorType>(Ty))
+    return false;
+  // For fixed vectors only the element type decides; scalars are checked as-is.
+  if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
+    return VectorType::isValidElementType(VTy->getElementType());
+  return VectorType::isValidElementType(Ty);
+}
+
+template bool isValidMemSeed(LoadInst *LSI);
+template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
+
+SeedCollector::SeedCollector(BasicBlock *SBBB, ScalarEvolution &SE)
+    : StoreSeeds(SE), LoadSeeds(SE), BB(SBBB), Ctx(BB->getContext()) {
+  // TODO: Register a callback for updating the Collector datastructures upon
+  // instr removal
+
+  bool CollectStores = !DisableStoreSeeds;
+  bool CollectLoads = !DisableLoadSeeds;
+  if (LLVM_UNLIKELY(!ForceSeed.empty())) {
+    CollectStores = false;
+    CollectLoads = false;
+    // Enable only the selected one.
+    if (ForceSeed == StoreSeedsDef)
+      CollectStores = true;
+    else if (ForceSeed == LoadSeedsDef)
+      CollectLoads = true;
+    else {
+      errs() << "Bad argument '" << ForceSeed << "' in -" << ForceSeed.ArgStr
+             << "='" << ForceSeed << "'.\n";
+      errs() << "Description: " << ForceSeed.HelpStr << "\n";
+      exit(1);
+    }
+  }
+  // Actually collect the seeds.
+  for (auto &I : *BB) {
+    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
+      if (CollectStores && isValidMemSeed(SI))
+        StoreSeeds.insert(SI);
+    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
+      if (CollectLoads && isValidMemSeed(LI))
+        LoadSeeds.insert(LI);
+    // Cap compilation time.
+    if (totalNumSeedGroups() > SeedGroupsLimit)
+      break;
+  }
+}
+
+SeedCollector::~SeedCollector() {
+  // TODO: Unregister the callback for updating the seed datastructures upon
+  // instr removal
+}
+
+#ifndef NDEBUG
+void SeedCollector::print(raw_ostream &OS) const {
+  OS << "=== StoreSeeds ===\n";
+  StoreSeeds.print(OS);
+  OS << "=== LoadSeeds ===\n";
+  LoadSeeds.print(OS);
+}
+
+void SeedCollector::dump() const { print(dbgs()); }
+#endif
+
 } // namespace llvm::sandboxir
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 82b230d50c4ec9..1dad0a707c73c8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -268,3 +268,171 @@ define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
   }
   EXPECT_EQ(Cnt, 0u);
 }
+
+TEST_F(SeedBundleTest, ConsecutiveStores) {
+  // Where "Consecutive" means the stores address consecutive locations in
+  // memory, but not in program order. Check to see that the collector puts them
+  // in the proper order for vectorization.
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr2 = getelementptr float, ptr %ptr, i32 2
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  // Expect just one vector of store seeds
+  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
+}
+
+TEST_F(SeedBundleTest, StoresWithGaps) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 3
+  %ptr2 = getelementptr float, ptr %ptr, i32 5
+  %ptr3 = getelementptr float, ptr %ptr, i32 7
+  store float %val, ptr %ptr0
+  store float %val, ptr %ptr2
+  store float %val, ptr %ptr1
+  store float %val, ptr %ptr3
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 4);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St2 = &*It++;
+  auto *St1 = &*It++;
+  auto *St3 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  // Expect just one vector of store seeds
+  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
+}
+
+TEST_F(SeedBundleTest, VectorStores) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr2 = getelementptr float, ptr %ptr, i32 2
+  store <2 x float> %val, ptr %ptr2
+  store <2 x float> %val, ptr %ptr0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 2);
+  // StX with X as the order by offset in memory
+  auto *St2 = &*It++;
+  auto *St0 = &*It++;
+
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  auto &SB = *StoreSeedsRange.begin();
+  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St2));
+}
+
+TEST_F(SeedBundleTest, MixedScalarVectors) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
+bb:
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptr3 = getelementptr float, ptr %ptr, i32 3
+  store float %v, ptr %ptr0
+  store float %v, ptr %ptr3
+  store <2 x float> %val, ptr %ptr1
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  DominatorTree DT(LLVMF);
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI(TLII);
+  DataLayout DL(M->getDataLayout());
+  LoopInfo LI(DT);
+  AssumptionCache AC(LLVMF);
+  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto BB = F.begin();
+  sandboxir::SeedCollector SC(&*BB, SE);
+
+  // Find the stores
+  auto It = std::next(BB->begin(), 3);
+  // StX with X as the order by offset in memory
+  auto *St0 = &*It++;
+  auto *St3 = &*It++;
+  auto *St1 = &*It++;
+
+  auto &SB = *SC.getStoreSeeds().begin();
+  EXPECT_TRUE(std::next(SC.getStoreSeeds().begin()) ==
+              SC.getStoreSeeds().end());
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St3));
+}

>From 93a101645e1673e52617b7ff1de4eb141bb519fe Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Fri, 18 Oct 2024 16:53:23 -0700
Subject: [PATCH 2/5] Address comments

---
 llvm/include/llvm/SandboxIR/Utils.h           |  6 ++-
 .../SandboxVectorizer/SeedCollector.cpp       | 44 +++++--------------
 .../SandboxVectorizer/SeedCollectorTest.cpp   | 14 +++---
 3 files changed, 23 insertions(+), 41 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Utils.h b/llvm/include/llvm/SandboxIR/Utils.h
index c416e3497b67e6..144d4255c2818a 100644
--- a/llvm/include/llvm/SandboxIR/Utils.h
+++ b/llvm/include/llvm/SandboxIR/Utils.h
@@ -60,9 +60,11 @@ class Utils {
         getUnderlyingObject(LSI->getPointerOperand()->Val));
   }
 
-  /// \Returns the number of elements in \p Ty, that is the number of lanes
-  /// if a fixed vector or 1 if scalar. ScalableVectors
+  /// \Returns the number of elements in \p Ty. That is the number of lanes if a
+  /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
+  /// therefore are unsupported.
   static int getNumElements(Type *Ty) {
+    assert(!isa<ScalableVectorType>(Ty));
     return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
   }
   /// Returns \p Ty if scalar or its element type if vector.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 806671a10a7d61..5ca9c84877be24 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -22,19 +22,12 @@ namespace llvm::sandboxir {
 cl::opt<unsigned> SeedBundleSizeLimit(
     "sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
     cl::desc("Limit the size of the seed bundle to cap compilation time."));
-cl::opt<bool>
-    DisableStoreSeeds("sbvec-disable-store-seeds", cl::init(false), cl::Hidden,
-                      cl::desc("Don't collect store seed instructions."));
-cl::opt<bool>
-    DisableLoadSeeds("sbvec-disable-load-seeds", cl::init(true), cl::Hidden,
-                     cl::desc("Don't collect load seed instructions."));
-
 #define LoadSeedsDef "loads"
 #define StoreSeedsDef "stores"
-cl::opt<std::string>
-    ForceSeed("sbvec-force-seeds", cl::init(""), cl::Hidden,
-              cl::desc("Enable only this type of seeds. This can be one "
-                       "of: '" LoadSeedsDef "','" StoreSeedsDef "'."));
+cl::opt<std::string> CollectSeeds(
+    "sbvec-collect-seeds", cl::init(LoadSeedsDef "," StoreSeedsDef), cl::Hidden,
+    cl::desc("Collect these seeds. Use empty for none or a comma-separated "
+             "list of '" LoadSeedsDef "' and '" StoreSeedsDef "'."));
 cl::opt<unsigned> SeedGroupsLimit(
     "sbvec-seed-groups-limit", cl::init(256), cl::Hidden,
     cl::desc("Limit the number of collected seeds groups in a BB to "
@@ -163,31 +156,18 @@ template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
   return VectorType::isValidElementType(Ty);
 }
 
-template bool isValidMemSeed(LoadInst *LSI);
+template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
 template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
 
-SeedCollector::SeedCollector(BasicBlock *SBBB, ScalarEvolution &SE)
-    : StoreSeeds(SE), LoadSeeds(SE), BB(SBBB), Ctx(BB->getContext()) {
-  // TODO: Register a callback for updating the Collector datastructures upon
+SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
+    : StoreSeeds(SE), LoadSeeds(SE), BB(BB), Ctx(BB->getContext()) {
+  // TODO: Register a callback for updating the Collector data structures upon
   // instr removal
 
-  bool CollectStores = !DisableStoreSeeds;
-  bool CollectLoads = !DisableLoadSeeds;
-  if (LLVM_UNLIKELY(!ForceSeed.empty())) {
-    CollectStores = false;
-    CollectLoads = false;
-    // Enable only the selected one.
-    if (ForceSeed == StoreSeedsDef)
-      CollectStores = true;
-    else if (ForceSeed == LoadSeedsDef)
-      CollectLoads = true;
-    else {
-      errs() << "Bad argument '" << ForceSeed << "' in -" << ForceSeed.ArgStr
-             << "='" << ForceSeed << "'.\n";
-      errs() << "Description: " << ForceSeed.HelpStr << "\n";
-      exit(1);
-    }
-  }
+  bool CollectStores = CollectSeeds.find(StoreSeedsDef) != std::string::npos;
+  bool CollectLoads = CollectSeeds.find(LoadSeedsDef) != std::string::npos;
+  if (!CollectStores && !CollectLoads)
+    return;
   // Actually collect the seeds.
   for (auto &I : *BB) {
     if (StoreInst *SI = dyn_cast<StoreInst>(&I))
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 1dad0a707c73c8..d5b6d8cb479723 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -365,8 +365,8 @@ TEST_F(SeedBundleTest, VectorStores) {
 define void @foo(ptr noalias %ptr, <2 x float> %val) {
 bb:
   %ptr0 = getelementptr float, ptr %ptr, i32 0
-  %ptr2 = getelementptr float, ptr %ptr, i32 2
-  store <2 x float> %val, ptr %ptr2
+  %ptr1 = getelementptr float, ptr %ptr, i32 2
+  store <2 x float> %val, ptr %ptr1
   store <2 x float> %val, ptr %ptr0
   ret void
 }
@@ -388,12 +388,12 @@ define void @foo(ptr noalias %ptr, <2 x float> %val) {
   // Find the stores
   auto It = std::next(BB->begin(), 2);
   // StX with X as the order by offset in memory
-  auto *St2 = &*It++;
+  auto *St1 = &*It++;
   auto *St0 = &*It++;
 
   auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
   auto &SB = *StoreSeedsRange.begin();
-  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
   EXPECT_THAT(SB, testing::ElementsAre(St0, St2));
 }
 
@@ -431,8 +431,8 @@ define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
   auto *St3 = &*It++;
   auto *St1 = &*It++;
 
-  auto &SB = *SC.getStoreSeeds().begin();
-  EXPECT_TRUE(std::next(SC.getStoreSeeds().begin()) ==
-              SC.getStoreSeeds().end());
+  auto StoreSeedsRange = SC.getStoreSeeds();
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
+  auto &SB = *StoreSeedsRange.begin();
   EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St3));
 }

>From 23cc5e7771b8ecc6ccda77c4a82dfb16ecd98893 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Mon, 21 Oct 2024 11:44:48 -0700
Subject: [PATCH 3/5] Address comments

---
 llvm/include/llvm/SandboxIR/Utils.h           | 12 --------
 .../SandboxVectorizer/SeedCollector.h         |  5 +---
 .../Vectorize/SandboxVectorizer/VecUtils.h    | 30 +++++++++++++++++++
 .../SandboxVectorizer/SeedCollector.cpp       |  2 +-
 .../SandboxVectorizer/SeedCollectorTest.cpp   |  2 +-
 5 files changed, 33 insertions(+), 18 deletions(-)
 create mode 100644 llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h

diff --git a/llvm/include/llvm/SandboxIR/Utils.h b/llvm/include/llvm/SandboxIR/Utils.h
index 144d4255c2818a..a73498adea1d59 100644
--- a/llvm/include/llvm/SandboxIR/Utils.h
+++ b/llvm/include/llvm/SandboxIR/Utils.h
@@ -60,18 +60,6 @@ class Utils {
         getUnderlyingObject(LSI->getPointerOperand()->Val));
   }
 
-  /// \Returns the number of elements in \p Ty. That is the number of lanes if a
-  /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
-  /// therefore are unsupported.
-  static int getNumElements(Type *Ty) {
-    assert(!isa<ScalableVectorType>(Ty));
-    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
-  }
-  /// Returns \p Ty if scalar or its element type if vector.
-  static Type *getElementType(Type *Ty) {
-    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
-  }
-
   /// \Returns the number of bits required to represent the operands or return
   /// value of \p V in \p DL.
   static unsigned getNumBits(Value *V, const DataLayout &DL) {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
index 1e55fa0f0a5688..ed1cb8488c29eb 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h
@@ -287,7 +287,6 @@ class SeedContainer {
 class SeedCollector {
   SeedContainer StoreSeeds;
   SeedContainer LoadSeeds;
-  BasicBlock *BB;
   Context &Ctx;
 
   /// \Returns the number of SeedBundle groups for all seed types.
@@ -297,11 +296,9 @@ class SeedCollector {
   }
 
 public:
-  SeedCollector(BasicBlock *SBBB, ScalarEvolution &SE);
+  SeedCollector(BasicBlock *BB, ScalarEvolution &SE);
   ~SeedCollector();
 
-  BasicBlock *getBasicBlock() { return BB; }
-
   iterator_range<SeedContainer::iterator> getStoreSeeds() {
     return {StoreSeeds.begin(), StoreSeeds.end()};
   }
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
new file mode 100644
index 00000000000000..64f57edb38484e
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -0,0 +1,30 @@
+//===- VecUtils.h -----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Collector for SandboxVectorizer related convenience functions that don't
+// belong in other classes.
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
+
+/// Static-only helpers that treat scalar and fixed-vector types uniformly.
+class Utils {
+public:
+  /// \Returns the number of lanes if \p Ty is a fixed vector, or 1 if scalar.
+  /// Scalable vectors have unknown size and are therefore unsupported.
+  static int getNumElements(Type *Ty) {
+    assert(!isa<ScalableVectorType>(Ty) && "Scalable vectors unsupported");
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
+  }
+  /// \Returns \p Ty if scalar or its element type if vector.
+  static Type *getElementType(Type *Ty) {
+    return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
+  }
+};
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
index 5ca9c84877be24..0d928af1902073 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
@@ -160,7 +160,7 @@ template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
 template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
 
 SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE)
-    : StoreSeeds(SE), LoadSeeds(SE), BB(BB), Ctx(BB->getContext()) {
+    : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
   // TODO: Register a callback for updating the Collector data structures upon
   // instr removal
 
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index d5b6d8cb479723..7ad3d84e4b68a3 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -394,7 +394,7 @@ define void @foo(ptr noalias %ptr, <2 x float> %val) {
   auto StoreSeedsRange = SC.getStoreSeeds();
   EXPECT_EQ(range_size(StoreSeedsRange), 1u);
   auto &SB = *StoreSeedsRange.begin();
-  EXPECT_THAT(SB, testing::ElementsAre(St0, St2));
+  EXPECT_THAT(SB, testing::ElementsAre(St0, St1));
 }
 
 TEST_F(SeedBundleTest, MixedScalarVectors) {

>From d4f53de5820d5ceaa0fab7136d45818b2c997873 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Mon, 21 Oct 2024 14:44:05 -0700
Subject: [PATCH 4/5] address more comments

---
 .../Vectorize/SandboxVectorizer/SeedCollectorTest.cpp         | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 7ad3d84e4b68a3..08002ebf688f4f 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -312,7 +312,7 @@ define void @foo(ptr noalias %ptr, float %val) {
   auto StoreSeedsRange = SC.getStoreSeeds();
   auto &SB = *StoreSeedsRange.begin();
   // Expect just one vector of store seeds
-  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
   EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
 }
 
@@ -365,7 +365,7 @@ TEST_F(SeedBundleTest, VectorStores) {
 define void @foo(ptr noalias %ptr, <2 x float> %val) {
 bb:
   %ptr0 = getelementptr float, ptr %ptr, i32 0
-  %ptr1 = getelementptr float, ptr %ptr, i32 2
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
   store <2 x float> %val, ptr %ptr1
   store <2 x float> %val, ptr %ptr0
   ret void

>From de1203a770cccfc316193a4c5646eea3d118e8b2 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Mon, 21 Oct 2024 15:33:35 -0700
Subject: [PATCH 5/5] Fix missed comment

---
 .../Vectorize/SandboxVectorizer/SeedCollectorTest.cpp           | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
index 08002ebf688f4f..4e28413a931a61 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp
@@ -356,7 +356,7 @@ define void @foo(ptr noalias %ptr, float %val) {
   auto StoreSeedsRange = SC.getStoreSeeds();
   auto &SB = *StoreSeedsRange.begin();
   // Expect just one vector of store seeds
-  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
+  EXPECT_EQ(range_size(StoreSeedsRange), 1u);
   EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
 }
 



More information about the llvm-commits mailing list