[llvm] [SandboxVec][DAG] Implement functions for iterating through DGNodes (PR #109684)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 23 09:37:12 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/109684

This patch adds DGNode member functions for visiting DGNodes in program order.
- DGNode::getPrev() and getNext() return the prev/next node.
- DGNode::getPrevMem() and getNextMem() return the prev/next node that is a memory dependency candidate.
- MemDGNodeIterator iterates through memory candidate nodes.
- DGNodeRange::makeMemRange() returns the range (a DGNodeRange of MemDGNodeIterators) of the memory candidate nodes between two given nodes.

>From a5707e6addcaf1a2ab6c6be7eac06da28472590f Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 20 Sep 2024 14:48:13 -0700
Subject: [PATCH] [SandboxVec][DAG] Implement functions for iterating through
 DGNodes

This patch adds DGNode member functions for visiting DGNodes in program order.
- DGNode::getPrev() and getNext() return the prev/next node.
- DGNode::getPrevMem() and getNextMem() return the prev/next node that is
a memory dependency candidate.
- MemDGNodeIterator iterates through memory candidate nodes.
- DGNodeRange::makeMemRange() returns the range (a DGNodeRange of
  MemDGNodeIterators) of the memory candidate nodes between two given nodes.
---
 .../SandboxVectorizer/DependencyGraph.h       |  86 +++++++++++++++
 .../SandboxVectorizer/DependencyGraph.cpp     |  61 +++++++++++
 .../SandboxVectorizer/DependencyGraphTest.cpp | 102 ++++++++++++++++++
 3 files changed, 249 insertions(+)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 75b2073d0557c5..c040c00b68220e 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -29,6 +29,8 @@
 
 namespace llvm::sandboxir {
 
+class DependencyGraph;
+
 /// A DependencyGraph Node that points to an Instruction and contains memory
 /// dependency edges.
 class DGNode {
@@ -45,6 +47,7 @@ class DGNode {
             (isa<AllocaInst>(I) && cast<AllocaInst>(I)->isUsedWithInAlloca()) ||
             I->isStackSaveOrRestoreIntrinsic();
   }
+  DGNode(const DGNode &Other) = delete;
   Instruction *getInstruction() const { return I; }
   void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); }
   /// \Returns all memory dependency predecessors.
@@ -56,6 +59,19 @@ class DGNode {
   /// \Returns true if this may read/write memory, or if it has some ordering
   /// constraings, like with stacksave/stackrestore and alloca/inalloca.
   bool isMem() const { return IsMem; }
+  /// \Returns the previous DGNode in program order, or null if none in \p DAG.
+  DGNode *getPrev(DependencyGraph &DAG) const;
+  /// \Returns the next DGNode in program order, or null if none in \p DAG.
+  DGNode *getNext(DependencyGraph &DAG) const;
+  /// Walks up the instruction chain looking for the previous memory dependency
+  /// candidate instruction.
+  /// \Returns the corresponding DAG Node, or null if none is found in \p DAG.
+  DGNode *getPrevMem(DependencyGraph &DAG) const;
+  /// Walks down the instruction chain looking for the next memory dependency
+  /// candidate instruction.
+  /// \Returns the corresponding DAG Node, or null if none is found in \p DAG.
+  DGNode *getNextMem(DependencyGraph &DAG) const;
+
 #ifndef NDEBUG
   void print(raw_ostream &OS, bool PrintDeps = true) const;
   friend raw_ostream &operator<<(DGNode &N, raw_ostream &OS) {
@@ -66,9 +82,77 @@ class DGNode {
 #endif // NDEBUG
 };
 
+/// Walks in the order of the instruction chain but skips non-mem Nodes.
+/// This is used for building/updating the DAG.
+/// Note: Once there are no more mem nodes the iterator holds a null node,
+/// which can be neither incremented nor decremented.
+class MemDGNodeIterator {
+  DGNode *N;
+  DependencyGraph *DAG;
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = DGNode;
+  using pointer = value_type *;
+  using reference = value_type &;
+  using iterator_category = std::bidirectional_iterator_tag;
+  MemDGNodeIterator(DGNode *N, DependencyGraph *DAG) : N(N), DAG(DAG) {
+    assert((N == nullptr || N->isMem()) && "Expects mem node!");
+  }
+  MemDGNodeIterator &operator++() {
+    assert(N != nullptr && "Already at end!");
+    N = N->getNextMem(*DAG);
+    return *this;
+  }
+  MemDGNodeIterator operator++(int) {
+    auto ItCopy = *this;
+    ++*this;
+    return ItCopy;
+  }
+  MemDGNodeIterator &operator--() {
+    assert(N != nullptr && "Can't decrement a null iterator!");
+    N = N->getPrevMem(*DAG);
+    return *this;
+  }
+  MemDGNodeIterator operator--(int) {
+    auto ItCopy = *this;
+    --*this;
+    return ItCopy;
+  }
+  /// \Returns the current mem node, or nullptr if there are no more mem nodes.
+  pointer operator*() { return N; }
+  const DGNode *operator*() const { return N; }
+  bool operator==(const MemDGNodeIterator &Other) const { return N == Other.N; }
+  bool operator!=(const MemDGNodeIterator &Other) const {
+    return !(*this == Other);
+  }
+};
+
+/// An iterator_range of MemDGNodeIterator with convenience builders and dump().
+class DGNodeRange : public iterator_range<MemDGNodeIterator> {
+public:
+  DGNodeRange(MemDGNodeIterator Begin, MemDGNodeIterator End)
+      : iterator_range(Begin, End) {}
+  /// An empty range.
+  DGNodeRange()
+      : iterator_range(MemDGNodeIterator(nullptr, nullptr),
+                       MemDGNodeIterator(nullptr, nullptr)) {}
+  /// \Returns the range of mem nodes in TopN..BotN (both inclusive). Non-mem
+  /// \p TopN / \p BotN are narrowed to their closest mem nodes in the range.
+  /// If there is no mem node between them, the returned range is empty.
+  static DGNodeRange makeMemRange(DGNode *TopN, DGNode *BotN,
+                                  DependencyGraph &DAG);
+  static DGNodeRange makeEmptyMemRange() { return DGNodeRange(); }
+#ifndef NDEBUG
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 class DependencyGraph {
 private:
   DenseMap<Instruction *, std::unique_ptr<DGNode>> InstrToNodeMap;
+  /// The DAG spans across all instructions in this interval.
+  InstrInterval DAGInterval;
 
 public:
   DependencyGraph() {}
@@ -77,6 +157,12 @@ class DependencyGraph {
     auto It = InstrToNodeMap.find(I);
     return It != InstrToNodeMap.end() ? It->second.get() : nullptr;
   }
+  /// Like getNode() but also accepts a null \p I.
+  /// \Returns nullptr if \p I is nullptr or has no node in the DAG, otherwise
+  /// the DGNode that corresponds to \p I.
+  DGNode *getNodeOrNull(Instruction *I) const {
+    return I != nullptr ? getNode(I) : nullptr;
+  }
   DGNode *getOrCreateNode(Instruction *I) {
     auto [It, NotInMap] = InstrToNodeMap.try_emplace(I);
     if (NotInMap)
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 139e581ce03d96..58545df8f456e7 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -11,6 +11,48 @@
 
 using namespace llvm::sandboxir;
 
+// TODO: Move this to Utils once it lands.
+/// \Returns true if the DGNode of \p I would be a mem node (DGNode::isMem()).
+/// NOTE: This must be kept in sync with how DGNode computes IsMem in
+/// DependencyGraph.h, otherwise the mem walks skip or visit the wrong nodes.
+static bool isMemDepNodeCandidate(llvm::sandboxir::Instruction *I) {
+  return I->isMemDepCandidate() ||
+         (llvm::isa<llvm::sandboxir::AllocaInst>(I) &&
+          llvm::cast<llvm::sandboxir::AllocaInst>(I)->isUsedWithInAlloca()) ||
+         I->isStackSaveOrRestoreIntrinsic();
+}
+/// \Returns the previous memory-dependency-candidate instruction before \p I in
+/// the instruction stream, or nullptr if there is none.
+static llvm::sandboxir::Instruction *
+getPrevMemDepInst(llvm::sandboxir::Instruction *I) {
+  for (I = I->getPrevNode(); I != nullptr; I = I->getPrevNode())
+    if (isMemDepNodeCandidate(I))
+      return I;
+  return nullptr;
+}
+/// \Returns the next memory-dependency-candidate instruction after \p I in the
+/// instruction stream, or nullptr if there is none.
+static llvm::sandboxir::Instruction *
+getNextMemDepInst(llvm::sandboxir::Instruction *I) {
+  for (I = I->getNextNode(); I != nullptr; I = I->getNextNode())
+    if (isMemDepNodeCandidate(I))
+      return I;
+  return nullptr;
+}
+
+DGNode *DGNode::getPrev(DependencyGraph &DAG) const {
+  return DAG.getNodeOrNull(I->getPrevNode());
+}
+DGNode *DGNode::getNext(DependencyGraph &DAG) const {
+  return DAG.getNodeOrNull(I->getNextNode());
+}
+DGNode *DGNode::getPrevMem(DependencyGraph &DAG) const {
+  return DAG.getNodeOrNull(getPrevMemDepInst(I));
+}
+DGNode *DGNode::getNextMem(DependencyGraph &DAG) const {
+  return DAG.getNodeOrNull(getNextMemDepInst(I));
+}
+
 #ifndef NDEBUG
 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
   I->dumpOS(OS);
@@ -31,6 +64,40 @@ void DGNode::dump() const {
 }
 #endif // NDEBUG
 
+DGNodeRange DGNodeRange::makeMemRange(DGNode *TopN, DGNode *BotN,
+                                      DependencyGraph &DAG) {
+  assert(TopN != nullptr && BotN != nullptr && "Expected non-null nodes!");
+  assert((TopN == BotN ||
+          TopN->getInstruction()->comesBefore(BotN->getInstruction())) &&
+         "Expected TopN before BotN!");
+  // If TopN/BotN are not mem-dep candidate nodes we need to walk down/up the
+  // chain and find the mem-dep ones. The walk expects every instruction from
+  // TopN to BotN to have a node in \p DAG.
+  DGNode *MemTopN = TopN;
+  DGNode *MemBotN = BotN;
+  while (!MemTopN->isMem() && MemTopN != MemBotN) {
+    MemTopN = MemTopN->getNext(DAG);
+    assert(MemTopN != nullptr && "Range TopN-BotN is not fully in the DAG!");
+  }
+  while (!MemBotN->isMem() && MemBotN != MemTopN) {
+    MemBotN = MemBotN->getPrev(DAG);
+    assert(MemBotN != nullptr && "Range TopN-BotN is not fully in the DAG!");
+  }
+  // If we couldn't find a mem node in range TopN - BotN then it's empty.
+  if (!MemTopN->isMem())
+    return {};
+  // Now that we have the mem-dep nodes, create and return the range.
+  return DGNodeRange(MemDGNodeIterator(MemTopN, &DAG),
+                     MemDGNodeIterator(MemBotN->getNextMem(DAG), &DAG));
+}
+
+#ifndef NDEBUG
+void DGNodeRange::dump() const {
+  for (auto It = begin(), ItE = end(); It != ItE; ++It)
+    (*It)->dump();
+}
+#endif // NDEBUG
+
 InstrInterval DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   if (Instrs.empty())
     return {};
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index f6bfd097f20a4e..3c8e217264481c 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -113,3 +113,105 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
   EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
 }
+
+TEST_F(DependencyGraphTest, DGNode_getPrev_getNext_getPrevMem_getNextMem) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  add i8 %v0, %v0
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++); // Mem node.
+  auto *Add = cast<sandboxir::BinaryOperator>(&*It++); // Not a mem node.
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++); // Mem node.
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++); // Not a mem node.
+
+  sandboxir::DependencyGraph DAG;
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  sandboxir::DGNode *S0N = DAG.getNode(S0);
+  sandboxir::DGNode *AddN = DAG.getNode(Add);
+  sandboxir::DGNode *S1N = DAG.getNode(S1);
+  sandboxir::DGNode *RetN = DAG.getNode(Ret);
+  // Check getPrev()/getNext()/getPrevMem()/getNextMem() on every node.
+  EXPECT_EQ(S0N->getPrev(DAG), nullptr); // S0 is the first instruction.
+  EXPECT_EQ(S0N->getNext(DAG), AddN);
+  EXPECT_EQ(S0N->getPrevMem(DAG), nullptr);
+  EXPECT_EQ(S0N->getNextMem(DAG), S1N);
+
+  EXPECT_EQ(AddN->getPrev(DAG), S0N);
+  EXPECT_EQ(AddN->getNext(DAG), S1N);
+  EXPECT_EQ(AddN->getPrevMem(DAG), S0N);
+  EXPECT_EQ(AddN->getNextMem(DAG), S1N);
+
+  EXPECT_EQ(S1N->getPrev(DAG), AddN);
+  EXPECT_EQ(S1N->getNext(DAG), RetN);
+  EXPECT_EQ(S1N->getPrevMem(DAG), S0N);
+  EXPECT_EQ(S1N->getNextMem(DAG), nullptr); // Ret is not a mem node.
+
+  EXPECT_EQ(RetN->getPrev(DAG), S1N);
+  EXPECT_EQ(RetN->getNext(DAG), nullptr); // Ret is the last instruction.
+  EXPECT_EQ(RetN->getPrevMem(DAG), S1N);
+  EXPECT_EQ(RetN->getNextMem(DAG), nullptr);
+}
+
+TEST_F(DependencyGraphTest, DGNodeRange) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  add i8 %v0, %v0
+  store i8 %v0, ptr %ptr
+  add i8 %v0, %v0
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++); // Not a mem node.
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++); // Mem node.
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++); // Not a mem node.
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++); // Mem node.
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++); // Not a mem node.
+
+  sandboxir::DependencyGraph DAG;
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  sandboxir::DGNode *Add0N = DAG.getNode(Add0);
+  sandboxir::DGNode *S0N = DAG.getNode(S0);
+  sandboxir::DGNode *Add1N = DAG.getNode(Add1);
+  sandboxir::DGNode *S1N = DAG.getNode(S1);
+  sandboxir::DGNode *RetN = DAG.getNode(Ret);
+
+  // Check empty range.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeEmptyMemRange(),
+              testing::ElementsAre());
+
+  // Both TopN and BotN are memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(S0N, S1N, DAG),
+              testing::ElementsAre(S0N, S1N));
+  // Only TopN is memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(S0N, RetN, DAG),
+              testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(S0N, Add1N, DAG),
+              testing::ElementsAre(S0N));
+  // Only BotN is memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(Add0N, S1N, DAG),
+              testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(Add0N, S0N, DAG),
+              testing::ElementsAre(S0N));
+  // Neither TopN nor BotN is memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(Add0N, RetN, DAG),
+              testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange(Add0N, Add0N, DAG),
+              testing::ElementsAre());
+}



More information about the llvm-commits mailing list