[llvm] [SandboxVec][DAG] Refactoring: Move MemPreds from DGNode to MemDGNode (PR #111897)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 10 12:15:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/111897.diff


3 Files Affected:

- (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+20-16) 
- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+11-9) 
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+41-32) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index eba6d7562e41de..da50e5326ea069 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -1,4 +1,4 @@
-//===- DependencyGraph.h ----------------------------------*- C++ -*-===//
+//===- DependencyGraph.h ----------------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -96,9 +96,6 @@ class DGNode {
   // TODO: Use a PointerIntPair for SubclassID and I.
   /// For isa/dyn_cast etc.
   DGNodeID SubclassID;
-  // TODO: Move MemPreds to MemDGNode.
-  /// Memory predecessors.
-  DenseSet<MemDGNode *> MemPreds;
 
   DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
   friend class MemDGNode; // For constructor.
@@ -170,17 +167,6 @@ class DGNode {
   }
 
   Instruction *getInstruction() const { return I; }
-  void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
-  /// \Returns all memory dependency predecessors.
-  iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
-    return make_range(MemPreds.begin(), MemPreds.end());
-  }
-  /// \Returns true if there is a memory dependency N->this.
-  bool hasMemPred(DGNode *N) const {
-    if (auto *MN = dyn_cast<MemDGNode>(N))
-      return MemPreds.count(MN);
-    return false;
-  }
 
 #ifndef NDEBUG
   virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
@@ -198,6 +184,9 @@ class DGNode {
 class MemDGNode final : public DGNode {
   MemDGNode *PrevMemN = nullptr;
   MemDGNode *NextMemN = nullptr;
+  /// Memory predecessors.
+  DenseSet<MemDGNode *> MemPreds;
+  friend class PredIterator; // For MemPreds.
 
   void setNextNode(MemDGNode *N) { NextMemN = N; }
   void setPrevNode(MemDGNode *N) { PrevMemN = N; }
@@ -222,6 +211,21 @@ class MemDGNode final : public DGNode {
   MemDGNode *getPrevNode() const { return PrevMemN; }
   /// \Returns the next Mem DGNode in instruction order.
   MemDGNode *getNextNode() const { return NextMemN; }
+  /// Adds the mem dependency edge PredN->this.
+  void addMemPred(MemDGNode *PredN) { (void)MemPreds.insert(PredN); }
+  /// \Returns true if there is a memory dependency N->this.
+  bool hasMemPred(DGNode *N) const {
+    // MemPreds only ever holds MemDGNodes, so other node kinds can't match.
+    auto *MemN = dyn_cast<MemDGNode>(N);
+    return MemN != nullptr && MemPreds.count(MemN) != 0;
+  }
+  /// \Returns all memory dependency predecessors. Used by tests.
+  iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
+    return {MemPreds.begin(), MemPreds.end()};
+  }
+#ifndef NDEBUG
+  void print(raw_ostream &OS, bool PrintDeps = true) const override;
+#endif // NDEBUG
 };
 
 /// Convenience builders for a MemDGNode interval.
@@ -266,7 +270,7 @@ class DependencyGraph {
 
   /// Go through all mem nodes in \p SrcScanRange and try to add dependencies to
   /// \p DstN.
-  void scanAndAddDeps(DGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
+  void scanAndAddDeps(MemDGNode &DstN, const Interval<MemDGNode> &SrcScanRange);
 
 public:
   DependencyGraph(AAResults &AA)
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 7aea466ed6d8db..70843812ff65bc 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -23,7 +23,8 @@ PredIterator::value_type PredIterator::operator*() {
   // or a mem predecessor.
   if (OpIt != OpItE)
     return DAG->getNode(cast<Instruction>((Value *)*OpIt));
-  assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
+  // It's a MemDGNode with OpIt == end, so we need to use MemIt.
+  assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() &&
          "Cant' dereference end iterator!");
   return *MemIt;
 }
@@ -45,7 +46,8 @@ PredIterator &PredIterator::operator++() {
     OpIt = skipNonInstr(OpIt, OpItE);
     return *this;
   }
-  assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
+  // It's a MemDGNode with OpIt == end, so we need to increment MemIt.
+  assert(MemIt != cast<MemDGNode>(N)->MemPreds.end() && "Already at end!");
   ++MemIt;
   return *this;
 }
@@ -57,10 +59,14 @@ bool PredIterator::operator==(const PredIterator &Other) const {
 }
 
 #ifndef NDEBUG
-void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
+void DGNode::print(raw_ostream &OS, bool /*PrintDeps*/) const { I->dumpOS(OS); }
+void DGNode::dump() const {
+  print(dbgs(), /*PrintDeps=*/true);
+  dbgs() << "\n";
+}
+void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
   I->dumpOS(OS);
   if (PrintDeps) {
-    OS << "\n";
     // Print memory preds.
     static constexpr const unsigned Indent = 4;
     for (auto *Pred : MemPreds) {
@@ -70,10 +76,6 @@ void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
     }
   }
 }
-void DGNode::dump() const {
-  print(dbgs());
-  dbgs() << "\n";
-}
 #endif // NDEBUG
 
 Interval<MemDGNode>
@@ -179,7 +181,7 @@ bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) {
   llvm_unreachable("Unknown DependencyType enum");
 }
 
-void DependencyGraph::scanAndAddDeps(DGNode &DstN,
+void DependencyGraph::scanAndAddDeps(MemDGNode &DstN,
                                      const Interval<MemDGNode> &SrcScanRange) {
   assert(isa<MemDGNode>(DstN) &&
          "DstN is the mem dep destination, so it must be mem");
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 6b3d9cc77c9955..5a9c9815ca42fa 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -50,10 +50,10 @@ struct DependencyGraphTest : public testing::Test {
     return *AA;
   }
   /// \Returns true if there is a dependency: SrcN->DstN.
-  bool dependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
-    const auto &Preds = DstN->memPreds();
-    auto It = find(Preds, SrcN);
-    return It != Preds.end();
+  bool memDependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) {
+    // Only MemDGNodes record memory predecessors.
+    auto *MemDstN = dyn_cast<sandboxir::MemDGNode>(DstN);
+    return MemDstN != nullptr && MemDstN->hasMemPred(SrcN);
+  }
 };
 
@@ -230,9 +230,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_EQ(Span.top(), &*BB->begin());
   EXPECT_EQ(Span.bottom(), BB->getTerminator());
 
-  sandboxir::DGNode *N0 = DAG.getNode(S0);
-  sandboxir::DGNode *N1 = DAG.getNode(S1);
-  sandboxir::DGNode *N2 = DAG.getNode(Ret);
+  auto *N0 = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+  auto *N1 = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *N2 = DAG.getNode(Ret);
+
   // Check getInstruction().
   EXPECT_EQ(N0->getInstruction(), S0);
   EXPECT_EQ(N1->getInstruction(), S1);
@@ -247,7 +248,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   // Check memPreds().
   EXPECT_TRUE(N0->memPreds().empty());
   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
-  EXPECT_TRUE(N2->memPreds().empty());
+  EXPECT_TRUE(N2->preds(DAG).empty());
 }
 
 TEST_F(DependencyGraphTest, Preds) {
@@ -399,12 +400,14 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   sandboxir::DependencyGraph DAG(getAA(*LLVMF));
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
-  auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
-  auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
+  auto *Store0N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
+  auto *Store1N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
   auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
   EXPECT_TRUE(Store0N->memPreds().empty());
   EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
-  EXPECT_TRUE(RetN->memPreds().empty());
+  EXPECT_TRUE(RetN->preds(DAG).empty());
 }
 
 TEST_F(DependencyGraphTest, NonAliasingStores) {
@@ -422,13 +425,15 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
   sandboxir::DependencyGraph DAG(getAA(*LLVMF));
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
-  auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
-  auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
+  auto *Store0N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
+  auto *Store1N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
   auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
   // We expect no dependencies because the stores don't alias.
   EXPECT_TRUE(Store0N->memPreds().empty());
   EXPECT_TRUE(Store1N->memPreds().empty());
-  EXPECT_TRUE(RetN->memPreds().empty());
+  EXPECT_TRUE(RetN->preds(DAG).empty());
 }
 
 TEST_F(DependencyGraphTest, VolatileLoads) {
@@ -446,12 +451,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
   sandboxir::DependencyGraph DAG(getAA(*LLVMF));
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
-  auto *Ld0N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
-  auto *Ld1N = DAG.getNode(cast<sandboxir::LoadInst>(&*It++));
+  auto *Ld0N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
+  auto *Ld1N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::LoadInst>(&*It++)));
   auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
   EXPECT_TRUE(Ld0N->memPreds().empty());
   EXPECT_THAT(Ld1N->memPreds(), testing::ElementsAre(Ld0N));
-  EXPECT_TRUE(RetN->memPreds().empty());
+  EXPECT_TRUE(RetN->preds(DAG).empty());
 }
 
 TEST_F(DependencyGraphTest, VolatileSotres) {
@@ -469,12 +476,14 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
   sandboxir::DependencyGraph DAG(getAA(*LLVMF));
   DAG.extend({&*BB->begin(), BB->getTerminator()});
   auto It = BB->begin();
-  auto *Store0N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
-  auto *Store1N = DAG.getNode(cast<sandboxir::StoreInst>(&*It++));
+  auto *Store0N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
+  auto *Store1N = cast<sandboxir::MemDGNode>(
+      DAG.getNode(cast<sandboxir::StoreInst>(&*It++)));
   auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*It++));
   EXPECT_TRUE(Store0N->memPreds().empty());
   EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N));
-  EXPECT_TRUE(RetN->memPreds().empty());
+  EXPECT_TRUE(RetN->preds(DAG).empty());
 }
 
 TEST_F(DependencyGraphTest, Call) {
@@ -498,12 +507,12 @@ define void @foo(float %v1, float %v2) {
   DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
 
   auto It = BB->begin();
-  auto *Call1N = DAG.getNode(&*It++);
+  auto *Call1N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
   auto *AddN = DAG.getNode(&*It++);
-  auto *Call2N = DAG.getNode(&*It++);
+  auto *Call2N = cast<sandboxir::MemDGNode>(DAG.getNode(&*It++));
 
   EXPECT_THAT(Call1N->memPreds(), testing::ElementsAre());
-  EXPECT_THAT(AddN->memPreds(), testing::ElementsAre());
+  EXPECT_THAT(AddN->preds(DAG), testing::ElementsAre());
   EXPECT_THAT(Call2N->memPreds(), testing::ElementsAre(Call1N));
 }
 
@@ -534,8 +543,8 @@ define void @foo() {
   auto *AllocaN = DAG.getNode(&*It++);
   auto *StackRestoreN = DAG.getNode(&*It++);
 
-  EXPECT_TRUE(dependency(AllocaN, StackRestoreN));
-  EXPECT_TRUE(dependency(StackSaveN, AllocaN));
+  EXPECT_TRUE(memDependency(AllocaN, StackRestoreN));
+  EXPECT_TRUE(memDependency(StackSaveN, AllocaN));
 }
 
 // Checks that stacksave and stackrestore depend on other mem instrs.
@@ -567,9 +576,9 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
   auto *StackRestoreN = DAG.getNode(&*It++);
   auto *Store1N = DAG.getNode(&*It++);
 
-  EXPECT_TRUE(dependency(Store0N, StackSaveN));
-  EXPECT_TRUE(dependency(StackSaveN, StackRestoreN));
-  EXPECT_TRUE(dependency(StackRestoreN, Store1N));
+  EXPECT_TRUE(memDependency(Store0N, StackSaveN));
+  EXPECT_TRUE(memDependency(StackSaveN, StackRestoreN));
+  EXPECT_TRUE(memDependency(StackRestoreN, Store1N));
 }
 
 // Make sure there is a dependency between a stackrestore and an alloca.
@@ -596,7 +605,7 @@ define void @foo(ptr %ptr) {
   auto *StackRestoreN = DAG.getNode(&*It++);
   auto *AllocaN = DAG.getNode(&*It++);
 
-  EXPECT_TRUE(dependency(StackRestoreN, AllocaN));
+  EXPECT_TRUE(memDependency(StackRestoreN, AllocaN));
 }
 
 // Make sure there is a dependency between the alloca and stacksave
@@ -623,7 +632,7 @@ define void @foo(ptr %ptr) {
   auto *AllocaN = DAG.getNode(&*It++);
   auto *StackSaveN = DAG.getNode(&*It++);
 
-  EXPECT_TRUE(dependency(AllocaN, StackSaveN));
+  EXPECT_TRUE(memDependency(AllocaN, StackSaveN));
 }
 
 // A non-InAlloca in a stacksave-stackrestore region does not need extra
@@ -655,6 +664,6 @@ define void @foo() {
   auto *AllocaN = DAG.getNode(&*It++);
   auto *StackRestoreN = DAG.getNode(&*It++);
 
-  EXPECT_FALSE(dependency(StackSaveN, AllocaN));
-  EXPECT_FALSE(dependency(AllocaN, StackRestoreN));
+  EXPECT_FALSE(memDependency(StackSaveN, AllocaN));
+  EXPECT_FALSE(memDependency(AllocaN, StackRestoreN));
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/111897


More information about the llvm-commits mailing list