[llvm] [SandboxVec][DAG] Refactoring: Outline code that looks for mem nodes (PR #111750)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 12:25:42 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/111750

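Outline the loops that search for the closest mem-dep candidate nodes from MemDGNodeIntervalBuilder::make() into two new static helpers, getMemDGNodeAfter() and getMemDGNodeBefore(), and add unit tests for them. This is a refactoring with no functional change intended.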

>From 1c69a3f66f0c54d077b2150c614f5c15011d0b13 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 9 Oct 2024 12:00:04 -0700
Subject: [PATCH] [SandboxVec][DAG] Refactoring: Outline code that looks for
 mem nodes

---
 .../SandboxVectorizer/DependencyGraph.h       | 10 +++++
 .../SandboxVectorizer/DependencyGraph.cpp     | 39 +++++++++++++------
 .../SandboxVectorizer/DependencyGraphTest.cpp | 19 +++++++++
 3 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 134adc4b21ab12..050e119040c281 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -154,6 +154,16 @@ class MemDGNode final : public DGNode {
 /// Convenience builders for a MemDGNode interval.
 class MemDGNodeIntervalBuilder {
 public:
+  /// Scans from \p I down to \p BeforeI (both inclusive; \p I must not come
+  /// after \p BeforeI) for a mem dependency candidate and returns its
+  /// MemDGNode, or nullptr if none is found.
+  static MemDGNode *getMemDGNodeAfter(Instruction *I, Instruction *BeforeI,
+                                      const DependencyGraph &DAG);
+  /// Scans from \p I up to \p AfterI (both inclusive; \p AfterI must not come
+  /// after \p I) for a mem dependency candidate and returns its MemDGNode, or
+  /// nullptr if none is found.
+  static MemDGNode *getMemDGNodeBefore(Instruction *I, Instruction *AfterI,
+                                       const DependencyGraph &DAG);
   /// Given \p Instrs it finds their closest mem nodes in the interval and
   /// returns the corresponding mem range. Note: BotN (or its neighboring mem
   /// node) is included in the range.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 82f253d4c63231..6266b4155dc253 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -32,23 +32,40 @@ void DGNode::dump() const {
 }
 #endif // NDEBUG
 
+MemDGNode *MemDGNodeIntervalBuilder::getMemDGNodeAfter(
+    Instruction *I, Instruction *BeforeI, const DependencyGraph &DAG) {
+  assert((I == BeforeI || I->comesBefore(BeforeI)) &&
+         "Expected I before BeforeI");
+  // Walk down from I to BeforeI (both inclusive) looking for a mem-dep
+  // candidate. A null I (with a null BeforeI) yields nullptr, not a crash.
+  while (I != nullptr && !DGNode::isMemDepNodeCandidate(I) && I != BeforeI)
+    I = I->getNextNode();
+  bool Found = I != nullptr && DGNode::isMemDepNodeCandidate(I);
+  return Found ? cast<MemDGNode>(DAG.getNode(I)) : nullptr;
+}
+
+MemDGNode *MemDGNodeIntervalBuilder::getMemDGNodeBefore(
+    Instruction *I, Instruction *AfterI, const DependencyGraph &DAG) {
+  assert((I == AfterI || AfterI->comesBefore(I)) && "Expected AfterI before I");
+  // Walk up from I to AfterI (both inclusive) looking for a mem-dep
+  // candidate. A null I (with a null AfterI) yields nullptr, not a crash.
+  while (I != nullptr && !DGNode::isMemDepNodeCandidate(I) && I != AfterI)
+    I = I->getPrevNode();
+  bool Found = I != nullptr && DGNode::isMemDepNodeCandidate(I);
+  return Found ? cast<MemDGNode>(DAG.getNode(I)) : nullptr;
+}
+
 Interval<MemDGNode>
 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
                                DependencyGraph &DAG) {
-  // If top or bottom instructions are not mem-dep candidate nodes we need to
-  // walk down/up the chain and find the mem-dep ones.
-  Instruction *MemTopI = Instrs.top();
-  Instruction *MemBotI = Instrs.bottom();
-  while (!DGNode::isMemDepNodeCandidate(MemTopI) && MemTopI != MemBotI)
-    MemTopI = MemTopI->getNextNode();
-  while (!DGNode::isMemDepNodeCandidate(MemBotI) && MemBotI != MemTopI)
-    MemBotI = MemBotI->getPrevNode();
+  MemDGNode *TopMemN = getMemDGNodeAfter(Instrs.top(), Instrs.bottom(), DAG);
   // If we couldn't find a mem node in range TopN - BotN then it's empty.
-  if (!DGNode::isMemDepNodeCandidate(MemTopI))
+  if (!TopMemN)
     return {};
+  MemDGNode *BotMemN = getMemDGNodeBefore(Instrs.bottom(), Instrs.top(), DAG);
+  assert(BotMemN && "TopMemN should be null too!");
   // Now that we have the mem-dep nodes, create and return the range.
-  return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
-                             cast<MemDGNode>(DAG.getNode(MemBotI)));
+  return Interval<MemDGNode>(TopMemN, BotMemN);
 }
 
 DependencyGraph::DependencyType
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e2f16919a5cddd..3d14da7b9358ec 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -305,6 +305,25 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
   auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
 
+  // Check getMemDGNodeAfter().
+  using B = sandboxir::MemDGNodeIntervalBuilder;
+  EXPECT_EQ(B::getMemDGNodeAfter(S0, S0, DAG), S0N);
+  EXPECT_EQ(B::getMemDGNodeAfter(S0, Ret, DAG), S0N);
+#ifndef NDEBUG
+  EXPECT_DEATH(B::getMemDGNodeAfter(S0, Add0, DAG), ".*before.*");
+#endif
+  EXPECT_EQ(B::getMemDGNodeAfter(Add0, Add1, DAG), S0N);
+  EXPECT_EQ(B::getMemDGNodeAfter(Add0, Add0, DAG), nullptr);
+
+  // Check getMemDGNodeBefore().
+  EXPECT_EQ(B::getMemDGNodeBefore(S1, S1, DAG), S1N);
+  EXPECT_EQ(B::getMemDGNodeBefore(S1, Add0, DAG), S1N);
+#ifndef NDEBUG
+  EXPECT_DEATH(B::getMemDGNodeBefore(S1, Ret, DAG), ".*before.*");
+#endif
+  EXPECT_EQ(B::getMemDGNodeBefore(Ret, Add0, DAG), S1N);
+  EXPECT_EQ(B::getMemDGNodeBefore(Ret, Ret, DAG), nullptr);
+
   // Check empty range.
   EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
               testing::ElementsAre());



More information about the llvm-commits mailing list