[llvm] [SandboxVec][DAG] Cleanup: Move callback registration from Scheduler to DAG (PR #116455)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 19:46:51 PST 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/116455

>From df777524a41b8b5cb5206dd4536d70724025f08e Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 7 Nov 2024 09:25:34 -0800
Subject: [PATCH] [SandboxVec][DAG] Cleanup: Move callback registration from
 Scheduler to DAG

This is a refactoring patch that moves the registration of the callback
that notifies about newly created instructions from the scheduler to the
DAG. This makes sense from both a design and a testing point of view:
- the DAG should not rely on the scheduler to be notified
- the notifiers don't need to be public
- it's easier to test the notifiers directly from within the DAG unit tests
---
 .../SandboxVectorizer/DependencyGraph.h       | 27 ++++++--
 .../Vectorize/SandboxVectorizer/Scheduler.h   |  9 +--
 .../SandboxVectorizer/DependencyGraphTest.cpp | 66 ++++++++++++++-----
 .../SandboxVectorizer/SchedulerTest.cpp       |  2 +-
 4 files changed, 72 insertions(+), 32 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5211c7922ea2fd..765b65c4971bed 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -290,6 +290,9 @@ class DependencyGraph {
   /// The DAG spans across all instructions in this interval.
   Interval<Instruction> DAGInterval;
 
+  Context *Ctx = nullptr;
+  std::optional<Context::CallbackID> CreateInstrCB;
+
   std::unique_ptr<BatchAAResults> BatchAA;
 
   enum class DependencyType {
@@ -325,9 +328,24 @@ class DependencyGraph {
   /// chain.
   void createNewNodes(const Interval<Instruction> &NewInterval);
 
+  /// Called by the callbacks when a new instruction \p I has been created.
+  void notifyCreateInstr(Instruction *I) {
+    getOrCreateNode(I);
+    // TODO: Update the dependencies for the new node.
+    // TODO: Update the MemDGNode chain to include the new node if needed.
+  }
+
 public:
-  DependencyGraph(AAResults &AA)
-      : BatchAA(std::make_unique<BatchAAResults>(AA)) {}
+  /// This constructor also registers callbacks.
+  DependencyGraph(AAResults &AA, Context &Ctx)
+      : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
+    CreateInstrCB = Ctx.registerCreateInstrCallback(
+        [this](Instruction *I) { notifyCreateInstr(I); });
+  }
+  ~DependencyGraph() {
+    if (CreateInstrCB)
+      Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+  }
 
   DGNode *getNode(Instruction *I) const {
     auto It = InstrToNodeMap.find(I);
@@ -354,11 +372,6 @@ class DependencyGraph {
   Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
   /// \Returns the range of instructions included in the DAG.
   Interval<Instruction> getInterval() const { return DAGInterval; }
-  /// Called by the scheduler when a new instruction \p I has been created.
-  void notifyCreateInstr(Instruction *I) {
-    getOrCreateNode(I);
-    // TODO: Update the dependencies for the new node.
-  }
   void clear() {
     InstrToNodeMap.clear();
     DAGInterval = {};
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 9c11b5dbc16432..022fd71df67dc6 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -106,8 +106,6 @@ class Scheduler {
   std::optional<BasicBlock::iterator> ScheduleTopItOpt;
   // TODO: This is wasting memory in exchange for fast removal using a raw ptr.
   DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
-  Context &Ctx;
-  Context::CallbackID CreateInstrCB;
 
   /// \Returns a scheduling bundle containing \p Instrs.
   SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -137,11 +135,8 @@ class Scheduler {
   Scheduler &operator=(const Scheduler &) = delete;
 
 public:
-  Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
-    CreateInstrCB = Ctx.registerCreateInstrCallback(
-        [this](Instruction *I) { DAG.notifyCreateInstr(I); });
-  }
-  ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
+  Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {}
+  ~Scheduler() {}
 
   bool trySchedule(ArrayRef<Instruction *> Instrs);
 
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 061d57c31ce236..206f6c5b4c1359 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -194,7 +194,7 @@ define void @foo(i8 %v1, ptr %ptr) {
   auto *Call = cast<sandboxir::CallInst>(&*It++);
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
   EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
@@ -224,7 +224,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   auto *S0 = cast<sandboxir::StoreInst>(&*It++);
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()});
   // Check extend().
   EXPECT_EQ(Span.top(), &*BB->begin());
@@ -285,7 +285,7 @@ define i8 @foo(i8 %v0, i8 %v1) {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
   auto It = BB->begin();
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
 
   auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
@@ -332,7 +332,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
   [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
 
   auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
@@ -366,7 +366,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
 
   auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
@@ -436,7 +436,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
   auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -461,7 +461,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
   auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -487,7 +487,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
   auto *Ld0N = cast<sandboxir::MemDGNode>(
@@ -512,7 +512,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
   auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -542,7 +542,7 @@ define void @foo(float %v1, float %v2) {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
@@ -574,7 +574,7 @@ define void @foo() {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
@@ -606,7 +606,7 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
@@ -637,7 +637,7 @@ define void @foo(ptr %ptr) {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
@@ -664,7 +664,7 @@ define void @foo(ptr %ptr) {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
@@ -695,7 +695,7 @@ define void @foo() {
   auto *F = Ctx.createFunction(LLVMF);
   auto *BB = &*F->begin();
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
@@ -728,7 +728,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
   auto *S3 = cast<sandboxir::StoreInst>(&*It++);
   auto *S4 = cast<sandboxir::StoreInst>(&*It++);
   auto *S5 = cast<sandboxir::StoreInst>(&*It++);
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   {
     // Scenario 1: Build new DAG
     auto NewIntvl = DAG.extend({S3, S3});
@@ -788,7 +788,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
 
   {
     // Check UnscheduledSuccs when a node is scheduled
-    sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+    sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
     DAG.extend({S2, S2});
     auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
     S2N->setScheduled(true);
@@ -798,3 +798,35 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
     EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
   }
 }
+
+TEST_F(DependencyGraphTest, CreateInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  // Check new instruction callback.
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S1, S3});
+  auto *Arg = F->getArg(3);
+  auto *Ptr = S1->getPointerOperand();
+  sandboxir::StoreInst *NewS =
+      sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+                                   /*IsVolatile=*/true, Ctx);
+  auto *NewSN = DAG.getNode(NewS);
+  EXPECT_TRUE(NewSN != nullptr);
+  // TODO: Check the dependencies to/from NewSN after they land.
+  // TODO: Check the MemDGNode chain.
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
index 94a57914429748..c5e44a97976a72 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
@@ -70,7 +70,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto *SN0 = DAG.getNode(S0);
   auto *SN1 = DAG.getNode(S1);



More information about the llvm-commits mailing list