[llvm] [SandboxVec][DAG] MemDGNode for memory-dependency candidate nodes (PR #109684)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 27 10:34:07 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/109684

>From e26186a603b260bd5dc34983f8c6adbbae4e3871 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 20 Sep 2024 14:48:13 -0700
Subject: [PATCH] [SandboxVec][DAG] MemDGNode for memory-dependency candidate
 nodes

This patch implements the MemDGNode class for DAG nodes that are candidates
for memory dependencies. These nodes form a chain that is accessible by
`getPrevNode()` and `getNextNode()`.

The patch also implements the MemDGNodeIterator for iterating over
the MemDGNode chain.

Finally, DGNodeRange::makeMemRange() returns a range of MemDGNodeIterators
over the memory nodes found within a given instruction interval, whose end
points may be memory or non-memory instructions. This will be used in a
follow-up patch for dependency scanning.
---
 .../SandboxVectorizer/DependencyGraph.h       | 139 ++++++++++++++++--
 .../SandboxVectorizer/DependencyGraph.cpp     |  34 +++++
 .../SandboxVectorizer/DependencyGraphTest.cpp | 105 +++++++++++--
 3 files changed, 254 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5437853c366ae6..6cd7646995f8f4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -29,22 +29,42 @@
 
 namespace llvm::sandboxir {
 
+class DependencyGraph;
+class MemDGNode;
+
+/// Subclass IDs of DGNode, compared by classof() for isa/cast/dyn_cast.
+enum class DGNodeID {
+  DGNode,
+  MemDGNode,
+};
+
 /// A DependencyGraph Node that points to an Instruction and contains memory
 /// dependency edges.
 class DGNode {
+protected:
   Instruction *I;
+  // TODO: Use a PointerIntPair for SubclassID and I.
+  /// Identifies the concrete subclass; read by classof() for isa/dyn_cast.
+  DGNodeID SubclassID;
   /// Memory predecessors.
   DenseSet<DGNode *> MemPreds;
-  /// This is true if this may read/write memory, or if it has some ordering
-  /// constraints, like with stacksave/stackrestore and alloca/inalloca.
-  bool IsMem;
+
+  DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
+  friend class MemDGNode; // For the constructor and classof().
 
 public:
-  DGNode(Instruction *I) : I(I) {
-    IsMem = I->isMemDepCandidate() ||
-            (isa<AllocaInst>(I) && cast<AllocaInst>(I)->isUsedWithInAlloca()) ||
-            I->isStackSaveOrRestoreIntrinsic();
+  explicit DGNode(Instruction *I) : I(I), SubclassID(DGNodeID::DGNode) {
+    assert(!isMemDepCandidate(I) && "Expected Non-Mem instruction!");
   }
+  DGNode(const DGNode &Other) = delete;
+  virtual ~DGNode() = default;
+  /// \Returns true if \p I is a mem-dep candidate, i.e. needs a MemDGNode.
+  static bool isMemDepCandidate(Instruction *I) {
+    return I->isMemDepCandidate() ||
+           (isa<AllocaInst>(I) && cast<AllocaInst>(I)->isUsedWithInAlloca()) ||
+           I->isStackSaveOrRestoreIntrinsic();
+  }
+
   Instruction *getInstruction() const { return I; }
   void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); }
   /// \Returns all memory dependency predecessors.
@@ -53,11 +73,9 @@ class DGNode {
   }
   /// \Returns true if there is a memory dependency N->this.
   bool hasMemPred(DGNode *N) const { return MemPreds.count(N); }
-  /// \Returns true if this may read/write memory, or if it has some ordering
-  /// constraints, like with stacksave/stackrestore and alloca/inalloca.
-  bool isMem() const { return IsMem; }
+
 #ifndef NDEBUG
-  void print(raw_ostream &OS, bool PrintDeps = true) const;
+  virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
   friend raw_ostream &operator<<(DGNode &N, raw_ostream &OS) {
     N.print(OS);
     return OS;
@@ -66,9 +84,94 @@ class DGNode {
 #endif // NDEBUG
 };
 
+/// A DependencyGraph Node for instructions that may read/write memory, or have
+/// some ordering constraints, like with stacksave/stackrestore and
+/// alloca/inalloca.
+class MemDGNode final : public DGNode {
+  MemDGNode *PrevMemN = nullptr;
+  MemDGNode *NextMemN = nullptr;
+
+  void setNextNode(MemDGNode *N) { NextMemN = N; }
+  void setPrevNode(MemDGNode *N) { PrevMemN = N; }
+  friend class DependencyGraph; // For setNextNode(), setPrevNode().
+
+public:
+  explicit MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
+    assert(isMemDepCandidate(I) && "Expected Mem instruction!");
+  }
+  static bool classof(const DGNode *Other) {
+    return Other->SubclassID == DGNodeID::MemDGNode;
+  }
+  /// \Returns the previous Mem DGNode in instruction order, or nullptr.
+  MemDGNode *getPrevNode() const { return PrevMemN; }
+  /// \Returns the next Mem DGNode in instruction order, or nullptr.
+  MemDGNode *getNextNode() const { return NextMemN; }
+};
+
+/// Walks the MemDGNode chain in instruction order, skipping non-mem nodes.
+/// Used for building/updating the DAG. Decrementing a nullptr end is invalid.
+class MemDGNodeIterator {
+  MemDGNode *N;
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = DGNode;
+  using pointer = value_type *;
+  using reference = value_type &;
+  using iterator_category = std::bidirectional_iterator_tag;
+  MemDGNodeIterator(MemDGNode *N) : N(N) {}
+  MemDGNodeIterator &operator++() {
+    assert(N != nullptr && "Already at end!");
+    N = N->getNextNode();
+    return *this;
+  }
+  MemDGNodeIterator operator++(int) {
+    auto ItCopy = *this;
+    ++*this;
+    return ItCopy;
+  }
+  MemDGNodeIterator &operator--() {
+    N = N->getPrevNode();
+    return *this;
+  }
+  MemDGNodeIterator operator--(int) {
+    auto ItCopy = *this;
+    --*this;
+    return ItCopy;
+  }
+  pointer operator*() { return N; }
+  const DGNode *operator*() const { return N; }
+  bool operator==(const MemDGNodeIterator &Other) const { return N == Other.N; }
+  bool operator!=(const MemDGNodeIterator &Other) const {
+    return !(*this == Other);
+  }
+};
+
+/// A range of MemDGNodeIterators with convenience builders and dump().
+class DGNodeRange : public iterator_range<MemDGNodeIterator> {
+public:
+  DGNodeRange(MemDGNodeIterator Begin, MemDGNodeIterator End)
+      : iterator_range(Begin, End) {}
+  /// An empty range.
+  DGNodeRange()
+      : iterator_range(MemDGNodeIterator(nullptr), MemDGNodeIterator(nullptr)) {
+  }
+  /// \Returns the mem nodes whose instructions lie within \p Instrs, from the
+  /// top-most to the bottom-most (inclusive), or an empty range if there are
+  /// none. The mem instructions must already have nodes in \p DAG.
+  static DGNodeRange makeMemRange(const Interval<Instruction> &Instrs,
+                                  DependencyGraph &DAG);
+  static DGNodeRange makeEmptyMemRange() { return DGNodeRange(); }
+#ifndef NDEBUG
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 class DependencyGraph {
 private:
   DenseMap<Instruction *, std::unique_ptr<DGNode>> InstrToNodeMap;
+  /// The DAG spans across all instructions in this interval.
+  Interval<Instruction> DAGInterval;
 
 public:
   DependencyGraph() {}
@@ -77,10 +180,20 @@ class DependencyGraph {
     auto It = InstrToNodeMap.find(I);
     return It != InstrToNodeMap.end() ? It->second.get() : nullptr;
   }
+  /// Like getNode() but returns nullptr if \p I is nullptr.
+  DGNode *getNodeOrNull(Instruction *I) const {
+    if (I == nullptr)
+      return nullptr;
+    return getNode(I);
+  }
   DGNode *getOrCreateNode(Instruction *I) {
     auto [It, NotInMap] = InstrToNodeMap.try_emplace(I);
-    if (NotInMap)
-      It->second = std::make_unique<DGNode>(I);
+    if (NotInMap) {
+      if (DGNode::isMemDepCandidate(I))
+        It->second = std::make_unique<MemDGNode>(I);
+      else
+        It->second = std::make_unique<DGNode>(I);
+    }
     return It->second.get();
   }
   /// Build/extend the dependency graph such that it includes \p Instrs. Returns
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 67b56451c7b594..e2e4cc37408690 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -31,6 +31,32 @@ void DGNode::dump() const {
 }
 #endif // NDEBUG
 
+DGNodeRange DGNodeRange::makeMemRange(const Interval<Instruction> &Instrs,
+                                      DependencyGraph &DAG) {
+  if (Instrs.empty())
+    return {};
+  // Walk down/up from the ends of Instrs to the closest mem-dep candidates.
+  Instruction *MemTopI = Instrs.top();
+  Instruction *MemBotI = Instrs.bottom();
+  while (!DGNode::isMemDepCandidate(MemTopI) && MemTopI != MemBotI)
+    MemTopI = MemTopI->getNextNode();
+  while (!DGNode::isMemDepCandidate(MemBotI) && MemBotI != MemTopI)
+    MemBotI = MemBotI->getPrevNode();
+  // If we couldn't find a mem node in range TopN - BotN then it's empty.
+  if (!DGNode::isMemDepCandidate(MemTopI))
+    return {};
+  return DGNodeRange(
+      MemDGNodeIterator(cast<MemDGNode>(DAG.getNode(MemTopI))),
+      MemDGNodeIterator(cast<MemDGNode>(DAG.getNode(MemBotI))->getNextNode()));
+}
+
+#ifndef NDEBUG
+void DGNodeRange::dump() const {
+  for (const DGNode *N : *this)
+    N->dump();
+}
+#endif // NDEBUG
+
 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   if (Instrs.empty())
     return {};
@@ -39,10 +65,18 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   auto *TopI = Interval.top();
   auto *BotI = Interval.bottom();
   DGNode *LastN = getOrCreateNode(TopI);
+  MemDGNode *LastMemN = dyn_cast<MemDGNode>(LastN);
   for (Instruction *I = TopI->getNextNode(), *E = BotI->getNextNode(); I != E;
        I = I->getNextNode()) {
     auto *N = getOrCreateNode(I);
     N->addMemPred(LastN);
+    // Append MemN to the doubly-linked chain of MemDGNodes (prev <-> next).
+    if (auto *MemN = dyn_cast<MemDGNode>(N)) {
+      MemN->setPrevNode(LastMemN);
+      if (LastMemN != nullptr)
+        LastMemN->setNextNode(MemN);
+      LastMemN = MemN;
+    }
     LastN = N;
   }
   return Interval;
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index f6bfd097f20a4e..792170c51d15e2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -27,7 +27,7 @@ struct DependencyGraphTest : public testing::Test {
   }
 };
 
-TEST_F(DependencyGraphTest, DGNode_IsMem) {
+TEST_F(DependencyGraphTest, MemDGNode) {
   parseIR(C, R"IR(
 declare void @llvm.sideeffect()
 declare void @llvm.pseudoprobe(i64, i64, i32, i64)
@@ -64,16 +64,16 @@ define void @foo(i8 %v1, ptr %ptr) {
 
   sandboxir::DependencyGraph DAG;
   DAG.extend({&*BB->begin(), BB->getTerminator()});
-  EXPECT_TRUE(DAG.getNode(Store)->isMem());
-  EXPECT_TRUE(DAG.getNode(Load)->isMem());
-  EXPECT_FALSE(DAG.getNode(Add)->isMem());
-  EXPECT_TRUE(DAG.getNode(StackSave)->isMem());
-  EXPECT_TRUE(DAG.getNode(StackRestore)->isMem());
-  EXPECT_FALSE(DAG.getNode(SideEffect)->isMem());
-  EXPECT_FALSE(DAG.getNode(PseudoProbe)->isMem());
-  EXPECT_TRUE(DAG.getNode(FakeUse)->isMem());
-  EXPECT_TRUE(DAG.getNode(Call)->isMem());
-  EXPECT_FALSE(DAG.getNode(Ret)->isMem());
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Add)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(StackSave)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(StackRestore)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(SideEffect)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(PseudoProbe)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(FakeUse)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Call)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Ret)));
 }
 
 TEST_F(DependencyGraphTest, Basic) {
@@ -113,3 +113,86 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
   EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
 }
+
+TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  add i8 %v0, %v0
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  [[maybe_unused]] auto *Add = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG;
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+
+  EXPECT_EQ(S0N->getPrevNode(), nullptr);
+  EXPECT_EQ(S0N->getNextNode(), S1N);
+
+  EXPECT_EQ(S1N->getPrevNode(), S0N);
+  EXPECT_EQ(S1N->getNextNode(), nullptr);
+}
+
+TEST_F(DependencyGraphTest, DGNodeRange) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  add i8 %v0, %v0
+  store i8 %v0, ptr %ptr
+  add i8 %v0, %v0
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG;
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+
+  // Check empty range.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeEmptyMemRange(),
+              testing::ElementsAre());
+
+  // Both TopN and BotN are memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({S0, S1}, DAG),
+              testing::ElementsAre(S0N, S1N));
+  // Only TopN is memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({S0, Ret}, DAG),
+              testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({S0, Add1}, DAG),
+              testing::ElementsAre(S0N));
+  // Only BotN is memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({Add0, S1}, DAG),
+              testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({Add0, S0}, DAG),
+              testing::ElementsAre(S0N));
+  // Neither TopN nor BotN is memory.
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({Add0, Ret}, DAG),
+              testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(sandboxir::DGNodeRange::makeMemRange({Add0, Add0}, DAG),
+              testing::ElementsAre());
+}



More information about the llvm-commits mailing list