[llvm] [SandboxVec][DAG] Register move instr callback (PR #120146)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 13:06:52 PST 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/120146

This patch implements the move-instruction notifier for the DAG. Whenever an instruction is moved, the notifier updates the DAG accordingly.

>From 26a263ae6232db7cf1592f51cec17cdfef7cd344 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Thu, 7 Nov 2024 12:46:53 -0800
Subject: [PATCH] [SandboxVec][DAG] Register move instr callback

This patch implements the move-instruction notifier for the DAG.
Whenever an instruction is moved, the notifier updates the DAG accordingly.
---
 .../SandboxVectorizer/DependencyGraph.h       | 18 ++++++++
 .../SandboxVectorizer/DependencyGraph.cpp     | 46 +++++++++++++++++++
 .../SandboxVectorizer/DependencyGraphTest.cpp | 35 +++++++++++++-
 3 files changed, 98 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b1cad2421bc0d2..f423e1ee456cd1 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -220,6 +220,14 @@ class MemDGNode final : public DGNode {
   void setNextNode(MemDGNode *N) { NextMemN = N; }
   void setPrevNode(MemDGNode *N) { PrevMemN = N; }
   friend class DependencyGraph; // For setNextNode(), setPrevNode().
+  void detachFromChain() { // Link Prev<->Next, then clear own links.
+    if (PrevMemN != nullptr)
+      PrevMemN->NextMemN = NextMemN;
+    if (NextMemN != nullptr)
+      NextMemN->PrevMemN = PrevMemN;
+    PrevMemN = nullptr;
+    NextMemN = nullptr;
+  }
 
 public:
   MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
@@ -293,6 +301,7 @@ class DependencyGraph {
   Context *Ctx = nullptr;
   std::optional<Context::CallbackID> CreateInstrCB;
   std::optional<Context::CallbackID> EraseInstrCB;
+  std::optional<Context::CallbackID> MoveInstrCB;
 
   std::unique_ptr<BatchAAResults> BatchAA;
 
@@ -343,6 +352,9 @@ class DependencyGraph {
   /// Called by the callbacks when instruction \p I is about to get
   /// deleted.
   void notifyEraseInstr(Instruction *I);
+  /// Called by the callbacks when instruction \p I is about to be moved right
+  /// before \p To, which may be the end() iterator of \p I's BasicBlock.
+  void notifyMoveInstr(Instruction *I, const BBIterator &To);
 
 public:
   /// This constructor also registers callbacks.
@@ -352,12 +364,18 @@ class DependencyGraph {
         [this](Instruction *I) { notifyCreateInstr(I); });
     EraseInstrCB = Ctx.registerEraseInstrCallback(
         [this](Instruction *I) { notifyEraseInstr(I); });
+    MoveInstrCB = Ctx.registerMoveInstrCallback(
+        [this](Instruction *I, const BBIterator &To) {
+          notifyMoveInstr(I, To);
+        });
   }
   ~DependencyGraph() {
     if (CreateInstrCB)
       Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
     if (EraseInstrCB)
       Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
+    if (MoveInstrCB)
+      Ctx->unregisterMoveInstrCallback(*MoveInstrCB);
   }
 
   DGNode *getNode(Instruction *I) const {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 25f2665d450d13..ba62c45a4e704e 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -370,6 +370,69 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
   }
 }
 
+void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
+  // NOTE: This runs *before* `I` moves, so `I` is still at its original
+  // position while we look for its new neighbors below.
+  BasicBlock *BB = To.getNodeParent();
+  bool ToEnd = To == BB->end();
+  // Early return if `I` doesn't actually move, i.e., if `To` already points
+  // right after `I`, or if `I` is already last in the BB and `To` is end().
+  if ((!ToEnd && &*To == I->getNextNode()) ||
+      (ToEnd && I->getNextNode() == nullptr))
+    return;
+
+  // Maintain the DAGInterval.
+  DAGInterval.notifyMoveInstr(I, To);
+
+  // TODO: Perhaps check if this is legal by checking the dependencies?
+
+  // Update the MemDGNode chain to reflect the instr movement if necessary.
+  auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I));
+  if (MemN == nullptr)
+    return;
+  // Remember MemN's current neighbors, then detach it from the chain.
+  MemDGNode *OldPrevMemN = MemN->PrevMemN;
+  MemDGNode *OldNextMemN = MemN->NextMemN;
+  MemN->detachFromChain();
+  // Since `I` has not moved yet, the instruction walks below may run into
+  // MemN itself. In that case use the neighbor MemN had before detaching,
+  // otherwise MemN would end up linked to itself.
+  auto SkipMemN = [MemN](MemDGNode *FoundN, MemDGNode *OldN) {
+    return FoundN == MemN ? OldN : FoundN;
+  };
+  // Now find MemN's neighbors at its new location.
+  MemDGNode *PrevMemN = nullptr;
+  MemDGNode *NextMemN = nullptr;
+  if (!ToEnd) {
+    DGNode *ToN = getNodeOrNull(&*To);
+    // TODO: `To` is outside the DAG, so MemN is left detached from the chain.
+    if (ToN == nullptr)
+      return;
+    PrevMemN = SkipMemN(getMemDGNodeBefore(ToN, /*IncludingN=*/false),
+                        OldPrevMemN);
+    NextMemN = SkipMemN(getMemDGNodeAfter(ToN, /*IncludingN=*/true),
+                        OldNextMemN);
+  } else {
+    // MemN becomes the last instruction in the BB, i.e., it lands after the
+    // terminator: its predecessor is the last mem node up to and including
+    // the terminator, and it has no successor.
+    DGNode *TermN = getNodeOrNull(BB->getTerminator());
+    // TODO: The terminator is outside the DAG, so MemN is left detached.
+    if (TermN == nullptr)
+      return;
+    PrevMemN = SkipMemN(getMemDGNodeBefore(TermN, /*IncludingN=*/true),
+                        OldPrevMemN);
+  }
+  // Finally link MemN between its new neighbors. Either may be null, e.g.,
+  // when MemN becomes the first or the last node of the chain.
+  MemN->PrevMemN = PrevMemN;
+  if (PrevMemN != nullptr)
+    PrevMemN->NextMemN = MemN;
+  MemN->NextMemN = NextMemN;
+  if (NextMemN != nullptr)
+    NextMemN->PrevMemN = MemN;
+}
+
 void DependencyGraph::notifyEraseInstr(Instruction *I) {
   // Update the MemDGNode chain if this is a memory node.
   if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 8c73ee1def8ae1..3fa4de501f3f5d 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -801,7 +801,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
 
 TEST_F(DependencyGraphTest, CreateInstrCallback) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   store i8 %v1, ptr %ptr
   store i8 %v2, ptr %ptr
   store i8 %v3, ptr %ptr
@@ -893,3 +893,71 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
 
   // TODO: Check the dependencies to/from NewSN after they land.
 }
+
+TEST_F(DependencyGraphTest, MoveInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  %ld0 = load i8, ptr %ptr2
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({Ld, S3});
+  auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+  EXPECT_EQ(S1N->getPrevNode(), LdN);
+  S1->moveBefore(Ld);
+  EXPECT_EQ(S1N->getPrevNode(), nullptr);
+  EXPECT_EQ(S1N->getNextNode(), LdN);
+  EXPECT_EQ(LdN->getPrevNode(), S1N);
+  EXPECT_EQ(LdN->getNextNode(), S2N);
+}
+
+TEST_F(DependencyGraphTest, MoveInstrCallbackAcrossNonMemInstr) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1, i8 %v2) {
+  store i8 %v0, ptr %ptr
+  store i8 %v1, ptr %ptr
+  %add = add i8 %v1, %v2
+  store i8 %add, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  ++It; // Skip %add.
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({S0, S2});
+  auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+  // Move S1 across the non-memory %add. When the callback runs S1 is still
+  // between S0 and %add, so a walk back from S2 may find S1 itself; make sure
+  // S1 does not end up linked to itself.
+  S1->moveBefore(S2);
+  EXPECT_EQ(S1N->getPrevNode(), S0N);
+  EXPECT_EQ(S1N->getNextNode(), S2N);
+  EXPECT_EQ(S0N->getNextNode(), S1N);
+  EXPECT_EQ(S2N->getPrevNode(), S1N);
+}



More information about the llvm-commits mailing list