[llvm] [SandboxVec][DAG] Fix MemDGNode chain maintenance when move destination is non mem (PR #124227)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 23 21:31:16 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/124227

…on is non mem

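This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever we move a memory instruction, the DAG gets notified about the move and updates the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, the moved memory node's next node could end up pointing to itself: the search for the next memory node started at the non-memory destination and could reach the moved node at its original position. The fix skips the moved node when searching for its new neighbors. It also handles moves to right after the bottom of the DAGInterval, where there is no DAG node to insert before.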

>From b71e3ff8a6367a026812ce1b988d500952f4f625 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 14 Jan 2025 10:49:22 -0800
Subject: [PATCH] [SandboxVec][DAG] Fix MemDGNode chain maintenance when move
 destination is non mem

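This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG.
Whenever we move a memory instruction, the DAG gets notified about the move
and updates the chain of memory nodes. The bug was that if the destination of
the move was not a memory instruction, the moved memory node's next node could
end up pointing to itself: the search for the next memory node started at the
non-memory destination and could reach the moved node at its original
position. The fix skips the moved node when searching for its new neighbors.
It also handles moves to right after the bottom of the DAGInterval, where there
is no DAG node to insert before.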
---
 .../SandboxVectorizer/DependencyGraph.h       | 16 +++--
 .../SandboxVectorizer/DependencyGraph.cpp     | 70 +++++++++++++------
 .../SandboxVectorizer/DependencyGraphTest.cpp | 43 ++++++++++++
 3 files changed, 103 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b2d7c9b8aa8bbc..6e3f99d78b9329 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -218,12 +218,14 @@ class MemDGNode final : public DGNode {
   friend class PredIterator; // For MemPreds.
   /// Creates both edges: this<->N.
   void setNextNode(MemDGNode *N) {
+    assert(N != this && "About to point to self!");
     NextMemN = N;
     if (NextMemN != nullptr)
       NextMemN->PrevMemN = this;
   }
   /// Creates both edges: N<->this.
   void setPrevNode(MemDGNode *N) {
+    assert(N != this && "About to point to self!");
     PrevMemN = N;
     if (PrevMemN != nullptr)
       PrevMemN->NextMemN = this;
@@ -348,13 +350,15 @@ class DependencyGraph {
   void createNewNodes(const Interval<Instruction> &NewInterval);
 
   /// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
-  /// before \p N, including or excluding \p N based on \p IncludingN, or
-  /// nullptr if not found.
-  MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+  /// before \p N, skipping \p SkipN, including or excluding \p N based on
+  /// \p IncludingN, or nullptr if not found.
+  MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN,
+                                MemDGNode *SkipN = nullptr) const;
   /// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
-  /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
-  /// if not found.
-  MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+  /// after \p N, skipping \p SkipN, including or excluding \p N based on \p
+  /// IncludingN, or nullptr if not found.
+  MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN,
+                               MemDGNode *SkipN = nullptr) const;
 
   /// Called by the callbacks when a new instruction \p I has been created.
   void notifyCreateInstr(Instruction *I);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index f080111f08d45e..390a5e9688cc78 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,29 +325,31 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
   setDefUseUnscheduledSuccs(NewInterval);
 }
 
-MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
-                                               bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
+                                               MemDGNode *SkipN) const {
   auto *I = N->getInstruction();
   for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
        PrevI = PrevI->getPrevNode()) {
     auto *PrevN = getNodeOrNull(PrevI);
     if (PrevN == nullptr)
       return nullptr;
-    if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+    auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
+    if (PrevMemN != nullptr && PrevMemN != SkipN)
       return PrevMemN;
   }
   return nullptr;
 }
 
-MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
-                                              bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
+                                              MemDGNode *SkipN) const {
   auto *I = N->getInstruction();
   for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
        NextI = NextI->getNextNode()) {
     auto *NextN = getNodeOrNull(NextI);
     if (NextN == nullptr)
       return nullptr;
-    if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+    auto *NextMemN = dyn_cast<MemDGNode>(NextN);
+    if (NextMemN != nullptr && NextMemN != SkipN)
       return NextMemN;
   }
   return nullptr;
@@ -377,6 +379,20 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
          !(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
          "Should not have been called if destination is same as origin.");
 
+  // TODO: We can only handle fully internal movements within DAGInterval or at
+  // the borders, i.e., right before the top or right after the bottom.
+  assert(To.getNodeParent() == I->getParent() &&
+         "TODO: We don't support movement across BBs!");
+  assert(
+      (To == std::next(DAGInterval.bottom()->getIterator()) ||
+       (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
+       (To != BB->end() && DAGInterval.contains(&*To))) &&
+      "TODO: To should be either within the DAGInterval or right "
+      "before/after it.");
+
+  // Make a copy of the DAGInterval before we update it.
+  auto OrigDAGInterval = DAGInterval;
+
   // Maintain the DAGInterval.
   DAGInterval.notifyMoveInstr(I, To);
 
@@ -389,23 +405,37 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
   MemDGNode *MemN = dyn_cast<MemDGNode>(N);
   if (MemN == nullptr)
     return;
-  // First detach it from the existing chain.
+
+  // First safely detach it from the existing chain.
   MemN->detachFromChain();
+
   // Now insert it back into the chain at the new location.
-  if (To != BB->end()) {
-    DGNode *ToN = getNodeOrNull(&*To);
-    if (ToN != nullptr) {
-      MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
-      MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
-    }
+  //
+  // We won't always have a DGNode to insert before it. If `To` is BB->end() or
+  // if it points to an instr after DAGInterval.bottom() then we will have to
+  // find a node to insert *after*.
+  //
+  // BB:                              BB:
+  //  I1                               I1 ^
+  //  I2                               I2 | DAGInteval [I1 to I3]
+  //  I3                               I3 V
+  //  I4                               I4   <- `To` == right after DAGInterval
+  //    <- `To` == BB->end()
+  //
+  if (To == BB->end() ||
+      To == std::next(OrigDAGInterval.bottom()->getIterator())) {
+    // If we don't have a node to insert before, find a node to insert after and
+    // update the chain.
+    DGNode *InsertAfterN = getNode(&*std::prev(To));
+    MemN->setPrevNode(
+        getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
   } else {
-    // MemN becomes the last instruction in the BB.
-    auto *TermN = getNodeOrNull(BB->getTerminator());
-    if (TermN != nullptr) {
-      MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false));
-    } else {
-      // The terminator is outside the DAG interval so do nothing.
-    }
+    // We have a node to insert before, so update the chain.
+    DGNode *BeforeToN = getNode(&*To);
+    MemN->setPrevNode(
+        getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
+    MemN->setNextNode(
+        getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
   }
 }
 
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 3fa4de501f3f5d..29fc05a7f256a2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -926,3 +926,46 @@ define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   EXPECT_EQ(LdN->getPrevNode(), S1N);
   EXPECT_EQ(LdN->getNextNode(), S2N);
 }
+
+// Check that the mem chain is maintained correctly when the move destination is
+// not a mem node.
+TEST_F(DependencyGraphTest, MoveInstrCallbackWithNonMemInstrs) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
+  %ld = load i8, ptr %ptr
+  %zext1 = zext i8 %arg to i32
+  %zext2 = zext i8 %arg to i32
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+  [[maybe_unused]] auto *Zext1 = cast<sandboxir::CastInst>(&*It++);
+  auto *Zext2 = cast<sandboxir::CastInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({Ld, S2});
+  auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+  EXPECT_EQ(LdN->getNextNode(), S1N);
+  EXPECT_EQ(S1N->getNextNode(), S2N);
+
+  S1->moveBefore(Zext2);
+  EXPECT_EQ(LdN->getNextNode(), S1N);
+  EXPECT_EQ(S1N->getNextNode(), S2N);
+
+  // Try move right after the end of the DAGInterval.
+  S1->moveBefore(Ret);
+  EXPECT_EQ(S2N->getNextNode(), S1N);
+  EXPECT_EQ(S1N->getNextNode(), nullptr);
+}



More information about the llvm-commits mailing list