[llvm] 166b2e8 - [SandboxVec][DAG] Update DAG when a new instruction is created (#126124)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 6 14:12:06 PST 2025


Author: vporpo
Date: 2025-02-06T14:12:03-08:00
New Revision: 166b2e88378bae4d74f9bdc56f1521150162fbf1

URL: https://github.com/llvm/llvm-project/commit/166b2e88378bae4d74f9bdc56f1521150162fbf1
DIFF: https://github.com/llvm/llvm-project/commit/166b2e88378bae4d74f9bdc56f1521150162fbf1.diff

LOG: [SandboxVec][DAG] Update DAG when a new instruction is created (#126124)

The DAG will now receive a callback whenever a new instruction is
created and will update itself accordingly.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
    llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index f4e74fdee84c919..fab456d925526c2 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -263,6 +263,7 @@ class MemDGNode final : public DGNode {
   void addMemPred(MemDGNode *PredN) {
     [[maybe_unused]] auto Inserted = MemPreds.insert(PredN).second;
     assert(Inserted && "PredN already exists!");
+    assert(PredN != this && "Trying to add a dependency to self!");
     if (!Scheduled) {
       ++PredN->UnscheduledSuccs;
     }

diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
index 18cd29e9e14ee40..f6c5a204673372f 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h
@@ -108,6 +108,15 @@ template <typename T> class Interval {
     return (Top == I || Top->comesBefore(I)) &&
            (I == Bottom || I->comesBefore(Bottom));
   }
+  /// \Returns true if \p Elm is right before the top or right after the
+  /// bottom. An empty interval has no top or bottom, so it touches nothing.
+  bool touches(T *Elm) const {
+    // An empty interval keeps null Top/Bottom, which would otherwise match the
+    // null next/prev node of an element at either end of its block.
+    if (empty())
+      return false;
+    return Top == Elm->getNextNode() || Bottom == Elm->getPrevNode();
+  }
   T *top() const { return Top; }
   T *bottom() const { return Bottom; }
 

diff  --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index e03cf32be024406..2680667afc4de29 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -368,8 +368,13 @@ MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
 }
 
 void DependencyGraph::notifyCreateInstr(Instruction *I) {
-  auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
-  // TODO: Update the dependencies for the new node.
+  // Nothing to do if the node is not in the focus range of the DAG.
+  if (!(DAGInterval.contains(I) || DAGInterval.touches(I)))
+    return;
+  // Include `I` into the interval.
+  DAGInterval = DAGInterval.getUnionInterval({I, I});
+  auto *N = getOrCreateNode(I);
+  auto *MemN = dyn_cast<MemDGNode>(N);
 
   // Update the MemDGNode chain if this is a memory node.
   if (MemN != nullptr) {
@@ -381,6 +386,21 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
       NextMemN->PrevMemN = MemN;
       MemN->NextMemN = NextMemN;
     }
+
+    // Add Mem dependencies.
+    // 1. Scan for deps above `I` for deps to `I`: AboveN->MemN.
+    if (DAGInterval.top()->comesBefore(I)) {
+      Interval<Instruction> AboveIntvl(DAGInterval.top(), I->getPrevNode());
+      auto SrcInterval = MemDGNodeIntervalBuilder::make(AboveIntvl, *this);
+      scanAndAddDeps(*MemN, SrcInterval);
+    }
+    // 2. Scan for deps below `I` for deps from `I`: MemN->BelowN.
+    if (I->comesBefore(DAGInterval.bottom())) {
+      Interval<Instruction> BelowIntvl(I->getNextNode(), DAGInterval.bottom());
+      for (MemDGNode &BelowN :
+           MemDGNodeIntervalBuilder::make(BelowIntvl, *this))
+        scanAndAddDeps(BelowN, Interval<MemDGNode>(MemN, MemN));
+    }
   }
 }
 

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 263a37ac335d2ae..f1e9afefb45311b 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -832,9 +832,10 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
   }
 }
 
+// Check that the DAG gets updated when we create a new instruction.
 TEST_F(DependencyGraphTest, CreateInstrCallback) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %new1, i8 %new2) {
   store i8 %v1, ptr %ptr
   store i8 %v2, ptr %ptr
   store i8 %v3, ptr %ptr
@@ -851,42 +852,52 @@ define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   auto *S3 = cast<sandboxir::StoreInst>(&*It++);
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  // Check new instruction callback.
   sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
-  DAG.extend({S1, Ret});
-  auto *Arg = F->getArg(3);
+  // Create a DAG spanning S1 to S3.
+  DAG.extend({S1, S3});
+  auto *ArgNew1 = F->getArg(4);
+  auto *ArgNew2 = F->getArg(5);
   auto *Ptr = S1->getPointerOperand();
+
+  auto *S1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+  auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+  sandboxir::MemDGNode *New1MemN = nullptr;
+  sandboxir::MemDGNode *New2MemN = nullptr;
   {
+    // Create a new store before S3 (within the span of the DAG).
     sandboxir::StoreInst *NewS =
-        sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+        sandboxir::StoreInst::create(ArgNew1, Ptr, Align(8), S3->getIterator(),
                                      /*IsVolatile=*/true, Ctx);
-    auto *NewSN = DAG.getNode(NewS);
-    EXPECT_TRUE(NewSN != nullptr);
-
     // Check the MemDGNode chain.
-    auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
-    auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
-    auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
-    EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
-    EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
-    EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
-    EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+    New1MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+    EXPECT_EQ(S2MemN->getNextNode(), New1MemN);
+    EXPECT_EQ(New1MemN->getPrevNode(), S2MemN);
+    EXPECT_EQ(New1MemN->getNextNode(), S3MemN);
+    EXPECT_EQ(S3MemN->getPrevNode(), New1MemN);
+
+    // Check dependencies.
+    EXPECT_TRUE(memDependency(S1MemN, New1MemN));
+    EXPECT_TRUE(memDependency(S2MemN, New1MemN));
+    EXPECT_TRUE(memDependency(New1MemN, S3MemN));
   }
-
   {
-    // Also check if new node is at the end of the BB, after Ret.
+    // Create a new store before Ret (outside the current DAG).
     sandboxir::StoreInst *NewS =
-        sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+        sandboxir::StoreInst::create(ArgNew2, Ptr, Align(8), Ret->getIterator(),
                                      /*IsVolatile=*/true, Ctx);
     // Check the MemDGNode chain.
-    auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
-    auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
-    EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
-    EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
-    EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+    New2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+    EXPECT_EQ(S3MemN->getNextNode(), New2MemN);
+    EXPECT_EQ(New2MemN->getPrevNode(), S3MemN);
+    EXPECT_EQ(New2MemN->getNextNode(), nullptr);
+
+    // Check dependencies.
+    EXPECT_TRUE(memDependency(S1MemN, New2MemN));
+    EXPECT_TRUE(memDependency(S2MemN, New2MemN));
+    EXPECT_TRUE(memDependency(New1MemN, New2MemN));
+    EXPECT_TRUE(memDependency(S3MemN, New2MemN));
   }
-
-  // TODO: Check the dependencies to/from NewSN after they land.
 }
 
 TEST_F(DependencyGraphTest, EraseInstrCallback) {

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
index 32521ed79a314be..59498371b4d73e6 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/IntervalTest.cpp
@@ -87,6 +87,15 @@ define void @foo(i8 %v0) {
   EXPECT_FALSE(One.contains(I1));
   EXPECT_FALSE(One.contains(I2));
   EXPECT_FALSE(One.contains(Ret));
+  // Check touches().
+  {
+    sandboxir::Interval<sandboxir::Instruction> Intvl(I2, I2);
+    EXPECT_TRUE(Intvl.touches(I1));
+    EXPECT_TRUE(Intvl.contains(I2));
+    EXPECT_FALSE(Intvl.touches(I2));
+    EXPECT_TRUE(Intvl.touches(Ret));
+    EXPECT_FALSE(Intvl.touches(I0));
+  }
   // Check iterator.
   auto BBIt = BB->begin();
   for (auto &I : Intvl)


        


More information about the llvm-commits mailing list