[llvm] [SandboxVec][DAG] Update MemDGNode chain upon instr creation (PR #116896)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 16:45:27 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

<details>
<summary>Changes</summary>

The DAG maintains a chain of MemDGNodes that links together all the nodes that may touch memory.
Whenever a new instruction gets created we need to make sure that this chain gets updated. If the new instruction touches memory then its corresponding MemDGNode should be inserted into the chain.

---
Full diff: https://github.com/llvm/llvm-project/pull/116896.diff


3 Files Affected:

- (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+12-6) 
- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+45) 
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+33-8) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 68a2daca1403df..911ee3e839521c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -329,13 +329,19 @@ class DependencyGraph {
   /// chain.
   void createNewNodes(const Interval<Instruction> &NewInterval);
 
+  /// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
+  /// before \p N, including or excluding \p N based on \p IncludingN, or
+  /// nullptr if not found.
+  MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+  /// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
+  /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
+  /// if not found.
+  MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+
   /// Called by the callbacks when a new instruction \p I has been created.
-  void notifyCreateInstr(Instruction *I) {
-    getOrCreateNode(I);
-    // TODO: Update the dependencies for the new node.
-    // TODO: Update the MemDGNode chain to include the new node if needed.
-  }
-  /// Called by the callbacks when instruction \p I is about to get deleted.
+  void notifyCreateInstr(Instruction *I);
+  /// Called by the callbacks when instruction \p I is about to get
+  /// deleted.
   void notifyEraseInstr(Instruction *I) {
     InstrToNodeMap.erase(I);
     // TODO: Update the dependencies.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 4b0e12c28f07b7..5cf44ba9dcbaaa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,6 +325,51 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
   setDefUseUnscheduledSuccs(NewInterval);
 }
 
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
+                                               bool IncludingN) const {
+  auto *I = N->getInstruction();
+  for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
+       PrevI = PrevI->getPrevNode()) {
+    auto *PrevN = getNodeOrNull(PrevI);
+    if (PrevN == nullptr)
+      return nullptr;
+    if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+      return PrevMemN;
+  }
+  return nullptr;
+}
+
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
+                                              bool IncludingN) const {
+  auto *I = N->getInstruction();
+  for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
+       NextI = NextI->getNextNode()) {
+    auto *NextN = getNodeOrNull(NextI);
+    if (NextN == nullptr)
+      return nullptr;
+    if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+      return NextMemN;
+  }
+  return nullptr;
+}
+
+void DependencyGraph::notifyCreateInstr(Instruction *I) {
+  auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
+  // TODO: Update the dependencies for the new node.
+
+  // Update the MemDGNode chain if this is a memory node.
+  if (MemN != nullptr) {
+    if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) {
+      PrevMemN->NextMemN = MemN;
+      MemN->PrevMemN = PrevMemN;
+    }
+    if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) {
+      NextMemN->PrevMemN = MemN;
+      MemN->NextMemN = NextMemN;
+    }
+  }
+}
+
 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   if (Instrs.empty())
     return {};
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e6bb4b4684d262..1130c9c63c71da 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -814,21 +814,46 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   auto *BB = &*F->begin();
   auto It = BB->begin();
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
-  [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
   auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   // Check new instruction callback.
   sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
-  DAG.extend({S1, S3});
+  DAG.extend({S1, Ret});
   auto *Arg = F->getArg(3);
   auto *Ptr = S1->getPointerOperand();
-  sandboxir::StoreInst *NewS =
-      sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
-                                   /*IsVolatile=*/true, Ctx);
-  auto *NewSN = DAG.getNode(NewS);
-  EXPECT_TRUE(NewSN != nullptr);
+  {
+    sandboxir::StoreInst *NewS =
+        sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+                                     /*IsVolatile=*/true, Ctx);
+    auto *NewSN = DAG.getNode(NewS);
+    EXPECT_TRUE(NewSN != nullptr);
+
+    // Check the MemDGNode chain.
+    auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+    auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
+    auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
+    EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
+    EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
+    EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+  }
+
+  {
+    // Also check if new node is at the end of the BB, after Ret.
+    sandboxir::StoreInst *NewS =
+        sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+                                     /*IsVolatile=*/true, Ctx);
+    // Check the MemDGNode chain.
+    auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+    EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
+    EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
+    EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+  }
+
   // TODO: Check the dependencies to/from NewSN after they land.
-  // TODO: Check the MemDGNode chain.
 }
 
 TEST_F(DependencyGraphTest, EraseInstrCallback) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/116896


More information about the llvm-commits mailing list