[llvm] fd5e220 - [SandboxVec][DAG] MemDGNode for memory-dependency candidate nodes (#109684)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 1 15:40:57 PDT 2024


Author: vporpo
Date: 2024-10-01T15:40:54-07:00
New Revision: fd5e220fa63bf9142e65be1b553af1100501c4bc

URL: https://github.com/llvm/llvm-project/commit/fd5e220fa63bf9142e65be1b553af1100501c4bc
DIFF: https://github.com/llvm/llvm-project/commit/fd5e220fa63bf9142e65be1b553af1100501c4bc.diff

LOG: [SandboxVec][DAG] MemDGNode for memory-dependency candidate nodes (#109684)

This patch implements the MemDGNode class for DAG nodes that are
candidates for memory dependencies. These nodes form a chain that is
accessible by `getPrevNode()` and `getNextNode()`.

It also implements a builder class, MemDGNodeIntervalBuilder, that creates
MemDGNode intervals from Instruction intervals.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
    llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 7f6e6d11e5f53a..0ddc227e3a02b4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -29,35 +29,64 @@
 
 namespace llvm::sandboxir {
 
+class DependencyGraph;
+class MemDGNode;
+
+/// SubclassIDs for isa/dyn_cast etc.
+enum class DGNodeID {
+  DGNode,
+  MemDGNode,
+};
+
 /// A DependencyGraph Node that points to an Instruction and contains memory
 /// dependency edges.
 class DGNode {
+protected:
   Instruction *I;
+  // TODO: Use a PointerIntPair for SubclassID and I.
+  /// For isa/dyn_cast etc.
+  DGNodeID SubclassID;
   /// Memory predecessors.
-  DenseSet<DGNode *> MemPreds;
-  /// This is true if this may read/write memory, or if it has some ordering
-  /// constraints, like with stacksave/stackrestore and alloca/inalloca.
-  bool IsMem;
+  DenseSet<MemDGNode *> MemPreds;
+
+  DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
+  friend class MemDGNode; // For constructor and classof().
 
 public:
-  DGNode(Instruction *I) : I(I) {
-    IsMem = I->isMemDepCandidate() ||
-            (isa<AllocaInst>(I) && cast<AllocaInst>(I)->isUsedWithInAlloca()) ||
-            I->isStackSaveOrRestoreIntrinsic();
+  DGNode(Instruction *I) : I(I), SubclassID(DGNodeID::DGNode) {
+    assert(!isMemDepCandidate(I) && "Expected Non-Mem instruction!");
+  }
+  DGNode(const DGNode &Other) = delete;
+  virtual ~DGNode() = default;
+  /// \Returns true if this is before \p Other in program order.
+  bool comesBefore(const DGNode *Other) const {
+    return I->comesBefore(Other->I);
+  }
+  /// \Returns true if \p I is a memory dependency candidate instruction.
+  static bool isMemDepCandidate(Instruction *I) {
+    // Besides instructions that may read/write memory, this covers those with
+    // ordering constraints: inalloca allocas and stacksave/stackrestore.
+    auto *Alloca = dyn_cast<AllocaInst>(I);
+    return I->isMemDepCandidate() ||
+           (Alloca != nullptr && Alloca->isUsedWithInAlloca()) ||
+           I->isStackSaveOrRestoreIntrinsic();
   }
+
   Instruction *getInstruction() const { return I; }
-  void addMemPred(DGNode *PredN) { MemPreds.insert(PredN); }
+  void addMemPred(MemDGNode *PredN) { MemPreds.insert(PredN); }
   /// \Returns all memory dependency predecessors.
-  iterator_range<DenseSet<DGNode *>::const_iterator> memPreds() const {
+  iterator_range<DenseSet<MemDGNode *>::const_iterator> memPreds() const {
     return make_range(MemPreds.begin(), MemPreds.end());
   }
   /// \Returns true if there is a memory dependency N->this.
-  bool hasMemPred(DGNode *N) const { return MemPreds.count(N); }
-  /// \Returns true if this may read/write memory, or if it has some ordering
-  /// constraints, like with stacksave/stackrestore and alloca/inalloca.
-  bool isMem() const { return IsMem; }
+  bool hasMemPred(DGNode *N) const {
+    if (auto *MN = dyn_cast<MemDGNode>(N))
+      return MemPreds.count(MN);
+    return false;
+  }
+
 #ifndef NDEBUG
-  void print(raw_ostream &OS, bool PrintDeps = true) const;
+  virtual void print(raw_ostream &OS, bool PrintDeps = true) const;
   friend raw_ostream &operator<<(DGNode &N, raw_ostream &OS) {
     N.print(OS);
     return OS;
@@ -66,9 +95,46 @@ class DGNode {
 #endif // NDEBUG
 };
 
+/// A DependencyGraph Node for instructions that may read/write memory, or have
+/// some ordering constraints, like with stacksave/stackrestore and
+/// alloca/inalloca.
+class MemDGNode final : public DGNode {
+  MemDGNode *PrevMemN = nullptr;
+  MemDGNode *NextMemN = nullptr;
+
+  void setNextNode(MemDGNode *N) { NextMemN = N; }
+  void setPrevNode(MemDGNode *N) { PrevMemN = N; }
+  friend class DependencyGraph; // For setNextNode(), setPrevNode().
+
+public:
+  MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
+    assert(isMemDepCandidate(I) && "Expected Mem instruction!");
+  }
+  static bool classof(const DGNode *Other) {
+    return Other->SubclassID == DGNodeID::MemDGNode;
+  }
+  /// \Returns the previous Mem DGNode in instruction order.
+  MemDGNode *getPrevNode() const { return PrevMemN; }
+  /// \Returns the next Mem DGNode in instruction order.
+  MemDGNode *getNextNode() const { return NextMemN; }
+};
+
+/// Convenience builders for a MemDGNode interval.
+class MemDGNodeIntervalBuilder {
+public:
+  /// Given \p Instrs, returns the interval spanning the top-most and
+  /// bottom-most mem nodes within it (both ends inclusive), or an empty
+  /// interval if it contains none. Mem nodes must already exist in \p DAG.
+  static Interval<MemDGNode> make(const Interval<Instruction> &Instrs,
+                                  DependencyGraph &DAG);
+  static Interval<MemDGNode> makeEmpty() { return {}; }
+};
+
 class DependencyGraph {
 private:
   DenseMap<Instruction *, std::unique_ptr<DGNode>> InstrToNodeMap;
+  /// The DAG spans across all instructions in this interval.
+  Interval<Instruction> DAGInterval;
 
 public:
   DependencyGraph() {}
@@ -77,10 +143,20 @@ class DependencyGraph {
     auto It = InstrToNodeMap.find(I);
     return It != InstrToNodeMap.end() ? It->second.get() : nullptr;
   }
+  /// Like getNode() but returns nullptr if \p I is nullptr.
+  DGNode *getNodeOrNull(Instruction *I) const {
+    if (I == nullptr)
+      return nullptr;
+    return getNode(I);
+  }
   DGNode *getOrCreateNode(Instruction *I) {
     auto [It, NotInMap] = InstrToNodeMap.try_emplace(I);
-    if (NotInMap)
-      It->second = std::make_unique<DGNode>(I);
+    if (NotInMap) {
+      if (DGNode::isMemDepCandidate(I))
+        It->second = std::make_unique<MemDGNode>(I);
+      else
+        It->second = std::make_unique<DGNode>(I);
+    }
     return It->second.get();
   }
   /// Build/extend the dependency graph such that it includes \p Instrs. Returns

diff  --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 67b56451c7b594..ce295e8bf5df3f 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -31,6 +31,29 @@ void DGNode::dump() const {
 }
 #endif // NDEBUG
 
+Interval<MemDGNode>
+MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
+                               DependencyGraph &DAG) {
+  // An empty interval has no top/bottom instruction to start the walk from.
+  if (Instrs.empty())
+    return {};
+  // If top or bottom instructions are not mem-dep candidate nodes we need to
+  // walk down/up the chain and find the mem-dep ones.
+  Instruction *MemTopI = Instrs.top();
+  Instruction *MemBotI = Instrs.bottom();
+  while (!DGNode::isMemDepCandidate(MemTopI) && MemTopI != MemBotI)
+    MemTopI = MemTopI->getNextNode();
+  while (!DGNode::isMemDepCandidate(MemBotI) && MemBotI != MemTopI)
+    MemBotI = MemBotI->getPrevNode();
+  // If we couldn't find a mem node in range TopN - BotN then it's empty.
+  if (!DGNode::isMemDepCandidate(MemTopI))
+    return {};
+  // Now that we have the mem-dep instructions, create and return the range.
+  // Their nodes must already exist in DAG (cast<> rejects nullptr).
+  return Interval<MemDGNode>(cast<MemDGNode>(DAG.getNode(MemTopI)),
+                             cast<MemDGNode>(DAG.getNode(MemBotI)));
+}
+
 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   if (Instrs.empty())
     return {};
@@ -39,10 +62,18 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   auto *TopI = Interval.top();
   auto *BotI = Interval.bottom();
-  DGNode *LastN = getOrCreateNode(TopI);
+  MemDGNode *LastMemN = dyn_cast<MemDGNode>(getOrCreateNode(TopI));
   for (Instruction *I = TopI->getNextNode(), *E = BotI->getNextNode(); I != E;
        I = I->getNextNode()) {
     auto *N = getOrCreateNode(I);
-    N->addMemPred(LastN);
+    // Only mem nodes can be mem preds, and there is none above the first one.
+    if (LastMemN != nullptr)
+      N->addMemPred(LastMemN);
+    // Build the Mem node chain.
+    if (auto *MemN = dyn_cast<MemDGNode>(N)) {
+      MemN->setPrevNode(LastMemN);
+      if (LastMemN != nullptr)
+        LastMemN->setNextNode(MemN);
+      LastMemN = MemN;
+    }
-    LastN = N;
   }
   return Interval;

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index d8b6f519982eb1..28ab38ce3d3536 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -29,7 +29,7 @@ struct DependencyGraphTest : public testing::Test {
   }
 };
 
-TEST_F(DependencyGraphTest, DGNode_IsMem) {
+TEST_F(DependencyGraphTest, MemDGNode) {
   parseIR(C, R"IR(
 declare void @llvm.sideeffect()
 declare void @llvm.pseudoprobe(i64, i64, i32, i64)
@@ -66,16 +66,16 @@ define void @foo(i8 %v1, ptr %ptr) {
 
   sandboxir::DependencyGraph DAG;
   DAG.extend({&*BB->begin(), BB->getTerminator()});
-  EXPECT_TRUE(DAG.getNode(Store)->isMem());
-  EXPECT_TRUE(DAG.getNode(Load)->isMem());
-  EXPECT_FALSE(DAG.getNode(Add)->isMem());
-  EXPECT_TRUE(DAG.getNode(StackSave)->isMem());
-  EXPECT_TRUE(DAG.getNode(StackRestore)->isMem());
-  EXPECT_FALSE(DAG.getNode(SideEffect)->isMem());
-  EXPECT_FALSE(DAG.getNode(PseudoProbe)->isMem());
-  EXPECT_TRUE(DAG.getNode(FakeUse)->isMem());
-  EXPECT_TRUE(DAG.getNode(Call)->isMem());
-  EXPECT_FALSE(DAG.getNode(Ret)->isMem());
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Add)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(StackSave)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(StackRestore)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(SideEffect)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(PseudoProbe)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(FakeUse)));
+  EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Call)));
+  EXPECT_FALSE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Ret)));
 }
 
 TEST_F(DependencyGraphTest, Basic) {
@@ -115,3 +115,103 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
   EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1));
 }
+
+TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  store i8 %v0, ptr %ptr
+  add i8 %v0, %v0
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  [[maybe_unused]] auto *Add = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG;
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+
+  EXPECT_EQ(S0N->getPrevNode(), nullptr);
+  EXPECT_EQ(S0N->getNextNode(), S1N);
+
+  EXPECT_EQ(S1N->getPrevNode(), S0N);
+  EXPECT_EQ(S1N->getNextNode(), nullptr);
+
+  EXPECT_TRUE(S0N->comesBefore(S1N));
+  EXPECT_FALSE(S1N->comesBefore(S0N));
+}
+
+TEST_F(DependencyGraphTest, DGNodeRange) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
+  add i8 %v0, %v0
+  store i8 %v0, ptr %ptr
+  add i8 %v0, %v0
+  store i8 %v1, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG;
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+
+  // Check empty range.
+  EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(),
+              testing::ElementsAre());
+
+  // Returns the pointers in Range.
+  auto getPtrVec = [](const auto &Range) {
+    SmallVector<const sandboxir::DGNode *> Vec;
+    for (const sandboxir::DGNode &N : Range)
+      Vec.push_back(&N);
+    return Vec;
+  };
+  // Both TopN and BotN are memory.
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({S0, S1}, DAG)),
+      testing::ElementsAre(S0N, S1N));
+  // Only TopN is memory.
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({S0, Ret}, DAG)),
+      testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({S0, Add1}, DAG)),
+      testing::ElementsAre(S0N));
+  // Only BotN is memory.
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, S1}, DAG)),
+      testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, S0}, DAG)),
+      testing::ElementsAre(S0N));
+  // Neither TopN nor BotN is memory.
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, Ret}, DAG)),
+      testing::ElementsAre(S0N, S1N));
+  EXPECT_THAT(
+      getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, Add0}, DAG)),
+      testing::ElementsAre());
+}


        


More information about the llvm-commits mailing list