[llvm] [SandboxVec][DAG] Implement PredIterator (PR #111604)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 8 16:24:42 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

<details>
<summary>Changes</summary>

This patch implements an iterator over the predecessors of a DGNode. It visits the node's def-use dependencies and, for MemDGNodes, also its memory dependencies.

---
Full diff: https://github.com/llvm/llvm-project/pull/111604.diff


3 Files Affected:

- (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+73) 
- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+35) 
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+41) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index ab49c3aa27143c..8c466465438194 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -40,6 +40,54 @@ enum class DGNodeID {
   MemDGNode,
 };
 
+class DGNode;
+class MemDGNode;
+class DependencyGraph;
+
+/// While \p OpIt points to a Value that is not an Instruction keep
+/// incrementing it. \Returns the first iterator in [OpIt, OpItE) that points
+/// to an Instruction, or \p OpItE if there is none.
+///
+/// This is `inline` rather than `static`: it is defined in a header and is
+/// called from inline member functions (e.g. DGNode::preds_begin()), which
+/// must refer to the same entity in every translation unit (ODR).
+[[nodiscard]] inline User::op_iterator skipNonInstr(User::op_iterator OpIt,
+                                                    User::op_iterator OpItE) {
+  while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
+    ++OpIt;
+  return OpIt;
+}
+
+/// Iterates over the predecessors of a DGNode. It first visits the nodes of
+/// the operands that are Instructions (def-use dependencies). Then, only if
+/// the node is a MemDGNode, it visits its memory dependency predecessors.
+///
+/// Dereferencing yields a DGNode pointer by value, so this models an input
+/// iterator whose `reference` type is `value_type` itself.
+class PredIterator {
+  /// Current and end positions over the node's operands.
+  User::op_iterator OpIt;
+  User::op_iterator OpItE;
+  /// Current position over the MemDGNode's memory predecessors. It is only
+  /// dereferenced/advanced when N is a MemDGNode; otherwise it stays
+  /// default-constructed.
+  DenseSet<MemDGNode *>::iterator MemIt;
+  /// The node whose predecessors are being visited.
+  DGNode *N = nullptr;
+  /// The DAG used to map operand Instructions to their DGNodes.
+  DependencyGraph *DAG = nullptr;
+
+  PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
+               const DenseSet<MemDGNode *>::iterator &MemIt, DGNode *N,
+               DependencyGraph &DAG)
+      : OpIt(OpIt), OpItE(OpItE), MemIt(MemIt), N(N), DAG(&DAG) {}
+  PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
+               DGNode *N, DependencyGraph &DAG)
+      : OpIt(OpIt), OpItE(OpItE), N(N), DAG(&DAG) {}
+  friend class DGNode;    // For constructor
+  friend class MemDGNode; // For constructor
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = DGNode *;
+  using pointer = value_type *;
+  // operator*() returns by value, so `reference` must not be an lvalue
+  // reference: input iterators require `*It` to have type `reference`.
+  using reference = value_type;
+  using iterator_category = std::input_iterator_tag;
+  value_type operator*();
+  PredIterator &operator++();
+  PredIterator operator++(int) {
+    auto Copy = *this;
+    ++(*this);
+    return Copy;
+  }
+  bool operator==(const PredIterator &Other) const;
+  bool operator!=(const PredIterator &Other) const { return !(*this == Other); }
+};
+
 /// A DependencyGraph Node that points to an Instruction and contains memory
 /// dependency edges.
 class DGNode {
@@ -63,6 +111,23 @@ class DGNode {
   virtual ~DGNode() = default;
   /// \Returns true if this is before \p Other in program order.
   bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
+  using iterator = PredIterator;
+  /// \Returns an iterator to the first predecessor of this node, i.e. the
+  /// node of the first operand that is an Instruction.
+  virtual iterator preds_begin(DependencyGraph &DAG) {
+    auto OpE = I->op_end();
+    auto FirstInstrOp = skipNonInstr(I->op_begin(), OpE);
+    return PredIterator(FirstInstrOp, OpE, this, DAG);
+  }
+  /// \Returns the past-the-end predecessor iterator of this node.
+  virtual iterator preds_end(DependencyGraph &DAG) {
+    auto OpE = I->op_end();
+    return PredIterator(OpE, OpE, this, DAG);
+  }
+  /// Const overloads. They forward to the virtual non-const versions, so
+  /// subclass overrides (e.g. MemDGNode's) are still honored.
+  iterator preds_begin(DependencyGraph &DAG) const {
+    auto *Self = const_cast<DGNode *>(this);
+    return Self->preds_begin(DAG);
+  }
+  iterator preds_end(DependencyGraph &DAG) const {
+    auto *Self = const_cast<DGNode *>(this);
+    return Self->preds_end(DAG);
+  }
+  /// \Returns the range of all predecessors of this node in \p DAG: def-use
+  /// predecessors and, for a MemDGNode, also its memory predecessors.
+  iterator_range<iterator> preds(DependencyGraph &DAG) const {
+    return make_range(preds_begin(DAG), preds_end(DAG));
+  }
 
   static bool isStackSaveOrRestoreIntrinsic(Instruction *I) {
     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
@@ -145,6 +210,14 @@ class MemDGNode final : public DGNode {
   static bool classof(const DGNode *Other) {
     return Other->SubclassID == DGNodeID::MemDGNode;
   }
+  // Re-expose DGNode's const preds_begin()/preds_end() overloads. The
+  // overrides below would otherwise hide them, so calling preds_begin() on a
+  // `const MemDGNode *` would fail to compile.
+  using DGNode::preds_begin;
+  using DGNode::preds_end;
+  /// \Returns an iterator to the first predecessor. That is the node of the
+  /// first Instruction operand or, if there is none, the first memory
+  /// predecessor.
+  iterator preds_begin(DependencyGraph &DAG) override {
+    auto OpEndIt = I->op_end();
+    return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
+                        MemPreds.begin(), this, DAG);
+  }
+  /// \Returns the past-the-end iterator, where both the operand position and
+  /// the memory-predecessor position are at their ends.
+  iterator preds_end(DependencyGraph &DAG) override {
+    auto OpEndIt = I->op_end();
+    return PredIterator(OpEndIt, OpEndIt, MemPreds.end(), this, DAG);
+  }
   /// \Returns the previous Mem DGNode in instruction order.
   MemDGNode *getPrevNode() const { return PrevMemN; }
   /// \Returns the next Mem DGNode in instruction order.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 845fadefc9bf03..1730c4a9e4bc0f 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -8,10 +8,45 @@
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Utils.h"
 
 namespace llvm::sandboxir {
 
+PredIterator::value_type PredIterator::operator*() {
+  // Operands are visited first. This covers every plain DGNode and any
+  // MemDGNode whose operand iterator has not reached its end yet.
+  if (!isa<MemDGNode>(N) || OpIt != OpItE) {
+    assert(OpIt != OpItE && "Can't dereference end iterator!");
+    Value *OpV = *OpIt;
+    // NOTE(review): operands are only filtered by isa<Instruction>, not by DAG
+    // membership. This relies on getNode()'s behavior for Instructions
+    // outside the DAG (presumably returning nullptr); confirm callers handle
+    // that.
+    return DAG->getNode(cast<Instruction>(OpV));
+  }
+  // A MemDGNode whose operands are exhausted: yield its memory predecessors.
+  assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
+         "Can't dereference end iterator!");
+  return *MemIt;
+}
+
+PredIterator &PredIterator::operator++() {
+  bool VisitingOperands = !isa<MemDGNode>(N) || OpIt != OpItE;
+  if (VisitingOperands) {
+    // Still walking the operands (always the case for a plain DGNode).
+    assert(OpIt != OpItE && "Already at end!");
+    // Step past the current operand, then past any non-Instruction operands.
+    ++OpIt;
+    OpIt = skipNonInstr(OpIt, OpItE);
+  } else {
+    // A MemDGNode whose operands are exhausted: advance over its MemPreds.
+    assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
+    ++MemIt;
+  }
+  return *this;
+}
+
+bool PredIterator::operator==(const PredIterator &Other) const {
+  // Only iterators over the same node of the same DAG are comparable.
+  assert(DAG == Other.DAG && "Iterators of different DAGs!");
+  assert(N == Other.N && "Iterators of different nodes!");
+  // Equal iff both the operand and the memory-predecessor positions match.
+  if (OpIt != Other.OpIt)
+    return false;
+  return MemIt == Other.MemIt;
+}
+
 #ifndef NDEBUG
 void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
   I->dumpOS(OS);
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e2f16919a5cddd..6b3d9cc77c9955 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -240,12 +240,53 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
   EXPECT_TRUE(N1->hasMemPred(N0));
   EXPECT_FALSE(N0->hasMemPred(N1));
 
+  // Check preds().
+  EXPECT_TRUE(N0->preds(DAG).empty());
+  EXPECT_THAT(N1->preds(DAG), testing::ElementsAre(N0));
+
   // Check memPreds().
   EXPECT_TRUE(N0->memPreds().empty());
   EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
   EXPECT_TRUE(N2->memPreds().empty());
 }
 
+TEST_F(DependencyGraphTest, Preds) {
+  parseIR(C, R"IR(
+declare ptr @bar(i8)
+define i8 @foo(i8 %v0, i8 %v1) {
+  %add0 = add i8 %v0, %v0
+  %add1 = add i8 %v1, %v1
+  %add2 = add i8 %add0, %add1
+  %ptr = call ptr @bar(i8 %add1)
+  store i8 %add2, ptr %ptr
+  ret i8 %add2
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  // Look up the DAG node of each instruction, walking the block in order.
+  auto InstIt = BB->begin();
+  auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstIt++));
+  auto *AddN1 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstIt++));
+  auto *AddN2 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstIt++));
+  auto *CallN = DAG.getNode(cast<sandboxir::CallInst>(&*InstIt++));
+  auto *StN = DAG.getNode(cast<sandboxir::StoreInst>(&*InstIt++));
+  auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*InstIt++));
+
+  // Adds whose operands are only function arguments have no predecessors.
+  EXPECT_THAT(AddN0->preds(DAG), testing::ElementsAre());
+  EXPECT_THAT(AddN1->preds(DAG), testing::ElementsAre());
+  // Pure def-use predecessors.
+  EXPECT_THAT(AddN2->preds(DAG), testing::ElementsAre(AddN0, AddN1));
+  EXPECT_THAT(CallN->preds(DAG), testing::ElementsAre(AddN1));
+  // The store reaches CallN twice (pointer operand and memory dependency);
+  // AddN2 is the stored value.
+  EXPECT_THAT(StN->preds(DAG),
+              testing::UnorderedElementsAre(CallN, CallN, AddN2));
+  EXPECT_THAT(RetN->preds(DAG), testing::ElementsAre(AddN2));
+}
+
 TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, i8 %v0, i8 %v1) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/111604


More information about the llvm-commits mailing list