[llvm] 6e48214 - [SandboxVec][DAG] Register callback for erase instr (#116742)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 16:20:41 PST 2024


Author: vporpo
Date: 2024-11-19T16:20:38-08:00
New Revision: 6e4821487fcab23bf9ca7f7c667826956bee4d1b

URL: https://github.com/llvm/llvm-project/commit/6e4821487fcab23bf9ca7f7c667826956bee4d1b
DIFF: https://github.com/llvm/llvm-project/commit/6e4821487fcab23bf9ca7f7c667826956bee4d1b.diff

LOG: [SandboxVec][DAG] Register callback for erase instr (#116742)

This patch adds the erase-instruction callback registration logic in the
DAG's constructor and the corresponding deregistration logic in the
destructor. It also makes sure that SchedBundle and DGNode objects can be
safely destroyed in any order: the DGNode destructor now removes the node
from its SchedBundle, if it belongs to one.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
    llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 765b65c4971bed..68a2daca1403df 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -117,7 +117,7 @@ class DGNode {
     assert(!isMemDepNodeCandidate(I) && "Expected Non-Mem instruction, ");
   }
   DGNode(const DGNode &Other) = delete;
-  virtual ~DGNode() = default;
+  virtual ~DGNode();
   /// \Returns the number of unscheduled successors.
   unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
   void decrUnscheduledSuccs() {
@@ -292,6 +292,7 @@ class DependencyGraph {
 
   Context *Ctx = nullptr;
   std::optional<Context::CallbackID> CreateInstrCB;
+  std::optional<Context::CallbackID> EraseInstrCB;
 
   std::unique_ptr<BatchAAResults> BatchAA;
 
@@ -334,6 +335,12 @@ class DependencyGraph {
     // TODO: Update the dependencies for the new node.
     // TODO: Update the MemDGNode chain to include the new node if needed.
   }
+  /// Called by the callbacks when instruction \p I is about to get deleted.
+  void notifyEraseInstr(Instruction *I) {
+    InstrToNodeMap.erase(I);
+    // TODO: Update the dependencies.
+    // TODO: Update the MemDGNode chain to remove the node if needed.
+  }
 
 public:
   /// This constructor also registers callbacks.
@@ -341,10 +348,14 @@ class DependencyGraph {
       : Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
     CreateInstrCB = Ctx.registerCreateInstrCallback(
         [this](Instruction *I) { notifyCreateInstr(I); });
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
   }
   ~DependencyGraph() {
     if (CreateInstrCB)
       Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
+    if (EraseInstrCB)
+      Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
   }
 
   DGNode *getNode(Instruction *I) const {

diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 022fd71df67dc6..3959f84c601e04 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -69,6 +69,10 @@ class SchedBundle {
 private:
   ContainerTy Nodes;
 
+  /// Called by the DGNode destructor to avoid accessing freed memory.
+  void eraseFromBundle(DGNode *N) { Nodes.erase(find(Nodes, N)); }
+  friend DGNode::~DGNode(); // For eraseFromBundle().
+
 public:
   SchedBundle() = default;
   SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {

diff  --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 6217c9fecf45dd..4b0e12c28f07b7 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Utils.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
 
 namespace llvm::sandboxir {
 
@@ -58,6 +59,12 @@ bool PredIterator::operator==(const PredIterator &Other) const {
   return OpIt == Other.OpIt && MemIt == Other.MemIt;
 }
 
+DGNode::~DGNode() {
+  if (SB == nullptr)
+    return;
+  SB->eraseFromBundle(this);
+}
+
 #ifndef NDEBUG
 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
   OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 206f6c5b4c1359..e6bb4b4684d262 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -830,3 +830,31 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   // TODO: Check the dependencies to/from NewSN after they land.
   // TODO: Check the MemDGNode chain.
 }
+
+TEST_F(DependencyGraphTest, EraseInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  // Check erase instruction callback.
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S1, S3});
+  S2->eraseFromParent();
+  auto *DeletedN = DAG.getNodeOrNull(S2);
+  EXPECT_TRUE(DeletedN == nullptr);
+  // TODO: Check the dependencies to/from NewSN after they land.
+  // TODO: Check the MemDGNode chain.
+}


        


More information about the llvm-commits mailing list