[llvm] [ADT] SCC iterator for general graph (PR #84268)

David Blaikie via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 12:58:25 PST 2024


================
@@ -377,6 +378,145 @@ scc_member_iterator<GraphT, GT>::scc_member_iterator(
   assert(InputNodes.size() == Nodes.size() && "missing nodes in MST");
   std::reverse(Nodes.begin(), Nodes.end());
 }
+
+template <class GraphT, class GT = GraphTraits<GraphT>>
+class graph_scc_iterator
+    : public iterator_facade_base<
+          graph_scc_iterator<GraphT, GT>, std::forward_iterator_tag,
+          const std::vector<typename GT::NodeRef>, ptrdiff_t> {
+  using NodeRef = typename GT::NodeRef;
+  using NodesIter = typename GT::nodes_iterator;
+  using ChildIter = typename GT::ChildIteratorType;
+  using reference = typename graph_scc_iterator::reference;
+
+  /// Tarjan bookkeeping for one graph node.
+  struct NodeEntry {
+    /// DFS discovery index; 0 means the node has not been visited yet.
+    int index = 0;
+    /// Lowest DFS index known to be reachable from this node (through its
+    /// DFS subtree or an edge to a node still on the node stack).
+    /// Default-initialized so the struct never holds an indeterminate value;
+    /// computeNext() assigns it on first visit before reading it.
+    int lowlink = 0;
+    /// Whether the node is currently on Tarjan's node stack.
+    bool onStack = false;
+  };
+
+  /// One frame of the explicit (non-recursive) DFS stack.
+  struct StackEntry {
+    /// Node this frame is exploring.
+    NodeRef n;
+    /// Next child of n to examine when this frame resumes.
+    ChildIter currentChild;
+    /// Child most recently descended into; its lowlink is merged into n's
+    /// when this frame resumes.
+    NodeRef lastChild;
+    /// False until n has been assigned its DFS index.
+    bool visited = false;
+  };
+
+  /// Nodes of the SCC the iterator currently points at. Filled by
+  /// computeNext(); empty for the past-the-end iterator.
+  std::vector<NodeRef> currentSCC;
+  /// Explicit DFS stack replacing recursion; each frame records where the
+  /// child scan of its node should resume.
+  std::stack<StackEntry> recursionStack;
+  /// Tarjan's node stack: visited nodes not yet assigned to an SCC.
+  std::stack<NodeRef> S;
+  /// Per-node DFS bookkeeping (index, lowlink, onStack); entries are created
+  /// on demand by operator[].
+  DenseMap<NodeRef, NodeEntry> nodeInfo;
+  /// Next DFS index to hand out. Starts at 1 because index 0 marks a node
+  /// that has not been visited yet.
+  int index = 1;
+
+  // Cursor of the outer loop over all graph nodes: the search for the next
+  // DFS root resumes from here.
+  NodesIter currentNode;
+  // End of the graph's node range.
+  NodesIter nodeEndIter;
+
+  /// Build an iterator over the SCCs of \p g.
+  ///
+  /// When \p isEnd is false, the first SCC is computed eagerly. When it is
+  /// true, the iterator is the past-the-end sentinel: the outer node cursor
+  /// sits at nodes_end and the current SCC is empty. Iterator comparison
+  /// looks at exactly those two fields.
+  /// Marked explicit so a graph never silently converts into an iterator.
+  explicit graph_scc_iterator(const GraphT &g, bool isEnd = false)
+      : currentNode(GT::nodes_begin(&g)), nodeEndIter(GT::nodes_end(&g)) {
+    if (isEnd) {
+      // Sentinel: skip all traversal and park the cursor at the end.
+      currentNode = nodeEndIter;
+      return;
+    }
+    computeNext();
+  }
+
+  /// Advance to the next strongly connected component.
+  ///
+  /// Runs Tarjan's SCC algorithm with an explicit stack (recursionStack)
+  /// instead of recursion. All DFS state (recursionStack, S, nodeInfo, index)
+  /// persists across calls. Each call resumes the traversal where the
+  /// previous one stopped and returns as soon as one SCC is complete.
+  ///
+  /// On return, currentSCC holds the nodes of that SCC in the order they
+  /// were popped from S. If no unvisited node remains, currentSCC is left
+  /// empty and currentNode equals nodeEndIter, which is the state of the
+  /// past-the-end iterator.
+  void computeNext() {
+    currentSCC.clear();
+    if (recursionStack.empty()) {
+      // No DFS in progress: scan the outer node list for the next root, i.e.
+      // a node whose index is still 0 (never visited). nodeInfo[v] creates a
+      // default entry for nodes seen for the first time.
+      for (; currentNode != nodeEndIter; currentNode++) {
+        NodeRef v = *currentNode;
+        if (nodeInfo[v].index == 0) {
+          // NOTE(review): passing nullptr for lastChild assumes NodeRef is
+          // constructible from nullptr (pointer-like); confirm for all graphs.
+          recursionStack.emplace(
+              StackEntry{v, GT::child_begin(v), nullptr, false});
+          // Step past the chosen root so the next scan resumes after it
+          // (the break skips the loop's own increment).
+          currentNode++;
+          break;
+        }
+      }
+    }
+
+    while (!recursionStack.empty()) {
+      // stackEntry must stay valid across the emplace below. std::stack's
+      // default container is std::deque, whose push_back does not
+      // invalidate references to existing elements.
+      StackEntry &stackEntry = recursionStack.top();
+      NodeRef v = stackEntry.n;
+      ChildIter &child = stackEntry.currentChild;
+      // NOTE(review): vEntry is a reference into the DenseMap. nodeInfo[w]
+      // below may insert (and rehash), which could invalidate it. That only
+      // happens for a never-seen w (index 0), in which case the loop breaks
+      // and `continue`s, re-fetching vEntry. Keep that invariant if editing.
+      NodeEntry &vEntry = nodeInfo[v];
+      if (!stackEntry.visited) {
+        // First time this frame runs: give v its DFS index, start lowlink at
+        // the same value and push v onto Tarjan's node stack S.
+        vEntry.index = index;
+        vEntry.lowlink = index;
+        index++;
+        S.push(v);
+        vEntry.onStack = true;
+        stackEntry.visited = true;
+      } else {
+        // Resuming after the DFS of stackEntry.lastChild finished: fold the
+        // child's lowlink into v's.
+        assert(nodeInfo.count(stackEntry.lastChild));
+        vEntry.lowlink =
+            std::min(vEntry.lowlink, nodeInfo[stackEntry.lastChild].lowlink);
+      }
+
+      // Scan v's remaining children. An unvisited child is pushed as a new
+      // frame and v is suspended; `child` is advanced past it first so that
+      // v resumes at the next child. A child still on S can lower v's
+      // lowlink to that child's index.
+      bool inserted = false;
+      for (; child != GT::child_end(v); child++) {
+        NodeRef w = *child;
+        NodeEntry &wEntry = nodeInfo[w];
+        if (wEntry.index == 0) {
+          inserted = true;
+          recursionStack.emplace(
+              StackEntry{w, GT::child_begin(w), nullptr, false});
+          stackEntry.lastChild = w;
+          child++;
+          break;
+        } else if (wEntry.onStack) {
+          vEntry.lowlink = std::min(vEntry.lowlink, wEntry.index);
+        }
+      }
+
+      // A child frame was pushed: process it before finishing v.
+      if (inserted)
+        continue;
+
+      // All children of v are done. If v is the root of an SCC
+      // (lowlink == index), pop S down to and including v; the popped nodes
+      // form the SCC, which is handed back to the caller.
+      recursionStack.pop();
+      if (vEntry.lowlink == vEntry.index) {
+        NodeRef w;
+        do {
+          w = S.top();
+          S.pop();
+          nodeInfo[w].onStack = false;
+          currentSCC.push_back(w);
+        } while (w != v);
+        return;
+      }
+    }
+  }
+
+public:
+  /// Returns an iterator positioned at the first SCC of \p g.
+  static graph_scc_iterator begin(const GraphT &g) {
+    graph_scc_iterator first(g, false);
+    return first;
+  }
+
+  /// Returns the past-the-end iterator for the SCCs of \p g.
+  static graph_scc_iterator end(const GraphT &g) {
+    graph_scc_iterator sentinel(g, true);
+    return sentinel;
+  }
+
+  /// Two iterators are equal when they have the same outer node cursor and
+  /// hold the same current SCC. The past-the-end iterator has its cursor at
+  /// nodes_end and an empty SCC.
+  ///
+  /// Defined as a hidden friend (non-member) so that conversions are
+  /// considered symmetrically for both operands. iterator_facade_base
+  /// provides operator!= in terms of this operator.
+  friend bool operator==(const graph_scc_iterator &lhs,
+                         const graph_scc_iterator &rhs) {
+    return lhs.currentNode == rhs.currentNode &&
+           lhs.currentSCC == rhs.currentSCC;
+  }
+
----------------
dwblaikie wrote:

Generally, prefer non-member operator overloads when possible. You can still make it an inline friend definition if keeping the definition inside the class is nicer or keeps related things close together, etc. This ensures that implicit type conversions are considered equally for the LHS and RHS of the operator.

https://github.com/llvm/llvm-project/pull/84268


More information about the llvm-commits mailing list