[llvm] [SandboxVec][DAG] Cleanup: Move callback registration from Scheduler to DAG (PR #116455)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 15 16:17:00 PST 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/116455

This is a refactoring patch that moves the registration of the callback that notifies about newly created instructions from the Scheduler to the DAG. This makes sense from both a design and a testing point of view:
- the DAG should not rely on the scheduler to get notified
- the notifiers no longer need to be public
- it is easier to test the notifiers directly from within the DAG unit tests

From 56284621c8941b63627a93592c940fa7e779da5f Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 7 Nov 2024 09:25:34 -0800
Subject: [PATCH] [SandboxVec][DAG] Cleanup: Move callback registration from
 Scheduler to DAG

This is a refactoring patch that moves the registration of the callback
that notifies about newly created instructions from the Scheduler to the
DAG. This makes sense from both a design and a testing point of view:
- the DAG should not rely on the scheduler to get notified
- the notifiers no longer need to be public
- it is easier to test the notifiers directly from within the DAG unit tests
---
 .../SandboxVectorizer/DependencyGraph.h       | 25 ++++++++++++---
 .../Vectorize/SandboxVectorizer/Scheduler.h   |  9 ++----
 .../SandboxVectorizer/DependencyGraphTest.cpp | 32 +++++++++++++++++++
 3 files changed, 54 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5211c7922ea2fd..c2c7a6697f173e 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -290,6 +290,9 @@ class DependencyGraph {
   /// The DAG spans across all instructions in this interval.
   Interval<Instruction> DAGInterval;
 
+  Context *Ctx = nullptr;
+  std::optional<Context::CallbackID> CreateInstrCB;
+
   std::unique_ptr<BatchAAResults> BatchAA;
 
   enum class DependencyType {
@@ -325,9 +328,33 @@ class DependencyGraph {
   /// chain.
   void createNewNodes(const Interval<Instruction> &NewInterval);
 
+  /// Called by the callback registered in the constructor when a new
+  /// instruction \p I has been created. For now this only makes sure that
+  /// \p I has a DAG node.
+  void notifyCreateInstr(Instruction *I) {
+    getOrCreateNode(I);
+    // TODO: Update the dependencies for the new node.
+    // TODO: Update the MemDGNode chain to include the new node if needed.
+  }
+
 public:
   DependencyGraph(AAResults &AA)
       : BatchAA(std::make_unique<BatchAAResults>(AA)) {}
+  /// This constructor also registers a callback with \p Ctx that notifies the
+  /// DAG about newly created instructions. The destructor unregisters it.
+  DependencyGraph(AAResults &AA, Context &Ctx)
+      : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
+    CreateInstrCB = Ctx.registerCreateInstrCallback(
+        [this](Instruction *I) { notifyCreateInstr(I); });
+  }
+  // The registered callback captures `this`, so the DAG must stay at a fixed
+  // address: disable copying (which also suppresses the implicit moves).
+  DependencyGraph(const DependencyGraph &) = delete;
+  DependencyGraph &operator=(const DependencyGraph &) = delete;
+  ~DependencyGraph() {
+    if (CreateInstrCB)
+      Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+  }
 
   DGNode *getNode(Instruction *I) const {
     auto It = InstrToNodeMap.find(I);
@@ -354,11 +381,6 @@ class DependencyGraph {
   Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
   /// \Returns the range of instructions included in the DAG.
   Interval<Instruction> getInterval() const { return DAGInterval; }
-  /// Called by the scheduler when a new instruction \p I has been created.
-  void notifyCreateInstr(Instruction *I) {
-    getOrCreateNode(I);
-    // TODO: Update the dependencies for the new node.
-  }
   void clear() {
     InstrToNodeMap.clear();
     DAGInterval = {};
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 9c11b5dbc16432..022fd71df67dc6 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -106,8 +106,6 @@ class Scheduler {
   std::optional<BasicBlock::iterator> ScheduleTopItOpt;
   // TODO: This is wasting memory in exchange for fast removal using a raw ptr.
   DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
-  Context &Ctx;
-  Context::CallbackID CreateInstrCB;
 
   /// \Returns a scheduling bundle containing \p Instrs.
   SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -137,11 +135,9 @@ class Scheduler {
   Scheduler &operator=(const Scheduler &) = delete;
 
 public:
-  Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
-    CreateInstrCB = Ctx.registerCreateInstrCallback(
-        [this](Instruction *I) { DAG.notifyCreateInstr(I); });
-  }
-  ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
+  /// \p Ctx is forwarded to the DAG, which registers (and later unregisters)
+  /// its own callback for newly created instructions.
+  Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {}
 
   bool trySchedule(ArrayRef<Instruction *> Instrs);
 
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 061d57c31ce236..6129664c6e429a 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -798,3 +798,37 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
     EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
   }
 }
+
+TEST_F(DependencyGraphTest, CreateInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  // Constructing the DAG with a Context registers the create-instruction
+  // callback, so an instruction created after the DAG was built should get
+  // a DAG node without any explicit notification from the test.
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S1, S3});
+  auto *Arg = F->getArg(4); // %arg
+  auto *Ptr = S1->getPointerOperand();
+  sandboxir::StoreInst *NewS =
+      sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+                                   /*IsVolatile=*/true, Ctx);
+  auto *NewSN = DAG.getNode(NewS);
+  EXPECT_NE(NewSN, nullptr);
+  // TODO: Check the dependencies to/from NewSN after they land.
+  // TODO: Check the MemDGNode chain.
+}



More information about the llvm-commits mailing list