[llvm] [SandboxVec][DAG] Extend DAG (PR #111908)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 10 17:59:37 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/111908

>From 7ddf808ee8bd09a7fca5cc18dcc7025072266578 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 7 Oct 2024 16:06:38 -0700
Subject: [PATCH] [SandboxVec][DAG] Extend DAG

This patch implements growing the DAG towards the top or bottom.
This does the necessary dependency checks and adds new mem dependencies.
---
 .../SandboxVectorizer/DependencyGraph.h       |   8 +-
 .../SandboxVectorizer/DependencyGraph.cpp     | 117 +++++++++++++++---
 .../SandboxVectorizer/DependencyGraphTest.cpp |  67 ++++++++++
 3 files changed, 176 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5fa57efc1462e8..0da52c4236d77e 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -284,6 +284,10 @@ class DependencyGraph {
   /// \p DstN.
   void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
 
+  /// Create DAG nodes for the instrs in \p NewInterval and link their
+  /// MemDGNodes into the existing MemDGNode chain (if the DAG is non-empty).
+  void createNewNodes(const Interval<Instruction> &NewInterval);
+
 public:
   DependencyGraph(AAResults &AA)
       : BatchAA(std::make_unique<BatchAAResults>(AA)) {}
@@ -309,8 +313,10 @@ class DependencyGraph {
     return It->second.get();
   }
   /// Build/extend the dependency graph such that it includes \p Instrs. Returns
-  /// the interval spanning \p Instrs.
+  /// the range of newly added instructions, or an empty interval if none.
   Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
+  /// \Returns the range of instructions currently covered by the DAG.
+  Interval<Instruction> getInterval() const { return DAGInterval; }
 #ifndef NDEBUG
   void print(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 0cd2240e7ff1b3..db58069de47051 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
   }
 }
 
-Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
-  if (Instrs.empty())
-    return {};
-
-  Interval<Instruction> InstrInterval(Instrs);
-
-  DGNode *LastN = getOrCreateNode(InstrInterval.top());
-  // Create DGNodes for all instrs in Interval to avoid future Instruction to
-  // DGNode lookups.
+void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
+  // Create nodes only for the new section of the DAG (NewInterval).
+  DGNode *LastN = getOrCreateNode(NewInterval.top());
   MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
-  for (Instruction &I : drop_begin(InstrInterval)) {
+  for (Instruction &I : drop_begin(NewInterval)) {
     auto *N = getOrCreateNode(&I);
     // Build the Mem node chain.
     if (auto *MemN = dyn_cast<MemDGNode>(N)) {
@@ -235,16 +229,109 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
       LastMemN = MemN;
     }
   }
+  // Link the new MemDGNode chain with the old one, if any.
+  if (!DAGInterval.empty()) {
+    // TODO: Implement Interval::comesBefore() to replace this check.
+    bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.top());
+    assert(
+        (NewIsAbove || DAGInterval.bottom()->comesBefore(NewInterval.top())) &&
+        "Expected NewInterval above or below DAGInterval.");
+    const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval;
+    const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval;
+    MemDGNode *LinkTopN =
+        MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
+    MemDGNode *LinkBotN =
+        MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
+    if (LinkTopN != nullptr && LinkBotN != nullptr) {
+      assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
+      LinkTopN->setNextNode(LinkBotN);
+      LinkBotN->setPrevNode(LinkTopN);
+    }
+#ifndef NDEBUG
+    // TODO: Remove this once we've done enough testing.
+    // Check that the chain is well formed: ordered and doubly linked.
+    auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval);
+    MemDGNode *ChainTopN =
+        MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this);
+    MemDGNode *ChainBotN =
+        MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this);
+    if (ChainTopN != nullptr && ChainBotN != nullptr) {
+      for (auto *N = ChainTopN->getNextNode(), *PrevN = ChainTopN; N != nullptr;
+           PrevN = N, N = N->getNextNode()) {
+        assert(PrevN->comesBefore(N) && "Bad chain!");
+        assert(N->getPrevNode() == PrevN && "Bad chain!");
+      }
+    }
+#endif // NDEBUG
+  }
+}
+
+Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
+  if (Instrs.empty())
+    return {};
+
+  Interval<Instruction> InstrsInterval(Instrs);
+  Interval<Instruction> Union = DAGInterval.getUnionInterval(InstrsInterval);
+  auto NewInterval = Union.getSingleDiff(DAGInterval);
+  if (NewInterval.empty())
+    return {};
+
+  createNewNodes(NewInterval);
+
   // Create the dependencies.
-  auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this);
-  if (!DstRange.empty()) {
-    for (MemDGNode &DstN : drop_begin(DstRange)) {
-      auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+  //
+  // 1. DAGInterval empty      2. New is below Old     3. New is above old
+  // ------------------------  -------------------      -------------------
+  //                                         Scan:           DstN:    Scan:
+  //                           +---+         -ScanTopN  +---+DstTopN  -ScanTopN
+  //                           |   |         |          |New|         |
+  //                           |Old|         |          +---+         -ScanBotN
+  //                           |   |         |          +---+
+  //      DstN:    Scan:       +---+DstN:    |          |   |
+  // +---+DstTopN  -ScanTopN   +---+DstTopN  |          |Old|
+  // |New|         |           |New|         |          |   |
+  // +---+DstBotN  -ScanBotN   +---+DstBotN  -ScanBotN  +---+DstBotN
+
+  // 1. This is a new DAG.
+  if (DAGInterval.empty()) {
+    assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
+    auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    if (!DstRange.empty()) {
+      for (MemDGNode &DstN : drop_begin(DstRange)) {
+        auto SrcRange = Interval<MemDGNode>(DstRange.top(), DstN.getPrevNode());
+        scanAndAddDeps(DstN, SrcRange);
+      }
+    }
+  }
+  // 2. The new section is below the old section.
+  else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
+    auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(Union, *this);
+    // DstN has no prev node if no mem node precedes it in the DAG.
+    for (MemDGNode &DstN : DstRange) {
+      auto SrcRange = DstN.getPrevNode() == nullptr ? Interval<MemDGNode>()
+          : Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
       scanAndAddDeps(DstN, SrcRange);
     }
   }
+  // 3. The new section is above the old section.
+  else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) {
+    auto DstRange = MemDGNodeIntervalBuilder::make(Union, *this);
+    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this);
+    // No new deps are possible if the new section contains no mem nodes.
+    if (!SrcRangeFull.empty()) {
+      for (MemDGNode &DstN : drop_begin(DstRange)) {
+        auto SrcRange =
+            Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
+        scanAndAddDeps(DstN, SrcRange);
+      }
+    }
+  } else {
+    llvm_unreachable("We don't expect extending in both directions!");
+  }
 
-  return InstrInterval;
+  DAGInterval = Union;
+  return NewInterval;
 }
 
 #ifndef NDEBUG
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 7e2be25fa25ae6..3dbf03e4ba44e2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -681,3 +681,70 @@ define void @foo() {
   EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
   EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
 }
+
+TEST_F(DependencyGraphTest, Extend) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  store i8 %v4, ptr %ptr
+  store i8 %v5, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S4 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S5 = cast<sandboxir::StoreInst>(&*It++);
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  {
+    // Scenario 1: Build a new DAG that contains only S3.
+    auto NewIntvl = DAG.extend({S3, S3});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S3, S3));
+    EXPECT_EQ(DAG.getInterval().top(), S3);
+    EXPECT_EQ(DAG.getInterval().bottom(), S3);
+    [[maybe_unused]] auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+  }
+  {
+    // Scenario 2: Extend below to S5; S4 is added too, keeping it contiguous.
+    auto NewIntvl = DAG.extend({S5, S5});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S4, S5));
+    auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+    auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+    EXPECT_TRUE(S4N->hasMemPred(S3N));
+    EXPECT_TRUE(S5N->hasMemPred(S4N));
+    EXPECT_TRUE(S5N->hasMemPred(S3N));
+  }
+  {
+    // Scenario 3: Extend above with S1-S2; each store depends on all above.
+    auto NewIntvl = DAG.extend({S1, S2});
+    EXPECT_EQ(NewIntvl, sandboxir::Interval<sandboxir::Instruction>(S1, S2));
+    auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+    auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+    auto *S3N = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *S4N = cast<sandboxir::MemDGNode>(DAG.getNode(S4));
+    auto *S5N = cast<sandboxir::MemDGNode>(DAG.getNode(S5));
+
+    EXPECT_TRUE(S2N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S3N->hasMemPred(S2N));
+    EXPECT_TRUE(S3N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S4N->hasMemPred(S3N));
+    EXPECT_TRUE(S4N->hasMemPred(S2N));
+    EXPECT_TRUE(S4N->hasMemPred(S1N));
+
+    EXPECT_TRUE(S5N->hasMemPred(S4N));
+    EXPECT_TRUE(S5N->hasMemPred(S3N));
+    EXPECT_TRUE(S5N->hasMemPred(S2N));
+    EXPECT_TRUE(S5N->hasMemPred(S1N));
+  }
+}



More information about the llvm-commits mailing list