[llvm] SCC iterator for general graph (PR #84268)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 6 19:40:21 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-adt
Author: Yida Zhang (zyd2001)
<details>
<summary>Changes</summary>
The current `scc_iterator` in LLVM performs a DFS from a single entry node. That works fine for a CFG, but for a more general graph that has multiple "entry nodes" or multiple disconnected components, `scc_iterator` does not give the expected result. In my project I need to find SCCs in a kind of dependency graph, and `scc_iterator` doesn't meet that need, so I decided to write a more general SCC iterator.
In this PR I did not change any existing code; I only added a new iterator template, `graph_scc_iterator`, to `llvm/ADT/SCCIterator.h`, which implements Tarjan's algorithm. It has basically the same semantics as `scc_iterator`, but it requires a `nodes_iterator` from the graph.
The code is rough for now; if you think this is worth including, I can add more comments to the code.
---
Full diff: https://github.com/llvm/llvm-project/pull/84268.diff
1 Files Affected:
- (modified) llvm/include/llvm/ADT/SCCIterator.h (+140)
``````````diff
diff --git a/llvm/include/llvm/ADT/SCCIterator.h b/llvm/include/llvm/ADT/SCCIterator.h
index e743ae7c11edbc..57722789492726 100644
--- a/llvm/include/llvm/ADT/SCCIterator.h
+++ b/llvm/include/llvm/ADT/SCCIterator.h
@@ -31,6 +31,7 @@
#include <iterator>
#include <queue>
#include <set>
+#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -377,6 +378,145 @@ scc_member_iterator<GraphT, GT>::scc_member_iterator(
assert(InputNodes.size() == Nodes.size() && "missing nodes in MST");
std::reverse(Nodes.begin(), Nodes.end());
}
+
+/// Enumerates the strongly connected components (SCCs) of a graph with an
+/// iterative (explicit-stack) formulation of Tarjan's algorithm.
+///
+/// A depth-first search is started from every not-yet-visited node in
+/// [GT::nodes_begin, GT::nodes_end), so graphs with several roots or several
+/// disconnected parts are fully covered. Dereferencing yields the nodes of the
+/// current SCC; incrementing resumes the search until the next SCC is complete.
+/// SCCs are produced in the order in which their DFS roots finish.
+///
+/// Requirements on \p GT: NodeRef, nodes_iterator, ChildIteratorType,
+/// nodes_begin/nodes_end and child_begin/child_end. NodeRef must be usable as
+/// a DenseMap key and must be value-initializable.
+template <class GraphT, class GT = GraphTraits<GraphT>>
+class graph_scc_iterator
+    : public iterator_facade_base<
+          graph_scc_iterator<GraphT, GT>, std::forward_iterator_tag,
+          const std::vector<typename GT::NodeRef>, ptrdiff_t> {
+  using NodeRef = typename GT::NodeRef;
+  using NodesIter = typename GT::nodes_iterator;
+  using ChildIter = typename GT::ChildIteratorType;
+  using reference = typename graph_scc_iterator::reference;
+
+  /// Per-node bookkeeping for Tarjan's algorithm.
+  struct NodeEntry {
+    /// DFS visitation number; 0 means the node has not been visited yet.
+    int index = 0;
+    /// Smallest visitation number found so far for this node.
+    int lowlink = 0;
+    /// True while the node is on the SCC stack S.
+    bool onStack = false;
+  };
+
+  /// One frame of the explicit DFS stack that replaces recursion.
+  struct StackEntry {
+    /// Node this frame belongs to.
+    NodeRef n;
+    /// Next successor of n that still has to be examined.
+    ChildIter currentChild;
+    /// Successor whose DFS was most recently started from this frame.
+    NodeRef lastChild;
+    /// False until the frame has been processed for the first time.
+    bool visited;
+  };
+
+  /// Nodes of the SCC the iterator currently points at; empty at the end.
+  std::vector<NodeRef> currentSCC;
+  /// Explicit DFS stack; non-empty while a search is suspended between SCCs.
+  std::stack<StackEntry> recursionStack;
+  /// Tarjan's stack of visited nodes not yet assigned to an SCC.
+  std::stack<NodeRef> S;
+  /// Bookkeeping for every node seen so far.
+  DenseMap<NodeRef, NodeEntry> nodeInfo;
+  /// Next visitation number to hand out; starts at 1 because 0 is "unvisited".
+  int index = 1;
+
+  /// Position of the outer loop over all nodes: the next DFS root candidate.
+  NodesIter currentNode;
+  /// End of the outer node range.
+  NodesIter nodeEndIter;
+
+  /// Builds the begin iterator for \p g, or the end iterator when \p e is true.
+  graph_scc_iterator(const GraphT &g, bool e = false)
+      : currentNode(GT::nodes_begin(&g)), nodeEndIter(GT::nodes_end(&g)) {
+    if (e) {
+      // The end iterator is "all roots consumed" plus an empty currentSCC.
+      currentNode = nodeEndIter;
+      return;
+    }
+    computeNext();
+  }
+
+  /// Advances the search until the next SCC is complete and stores it in
+  /// currentSCC. Leaves currentSCC empty once every node has been visited.
+  void computeNext() {
+    currentSCC.clear();
+    if (recursionStack.empty()) {
+      // No suspended DFS: look for the next unvisited node to use as a root.
+      for (; currentNode != nodeEndIter; ++currentNode) {
+        NodeRef v = *currentNode;
+        if (nodeInfo[v].index == 0) {
+          recursionStack.push(
+              StackEntry{v, GT::child_begin(v), NodeRef(), false});
+          // `break` skips the loop increment, so step past the root here.
+          ++currentNode;
+          break;
+        }
+      }
+    }
+
+    while (!recursionStack.empty()) {
+      StackEntry &stackEntry = recursionStack.top();
+      NodeRef v = stackEntry.n;
+      ChildIter &child = stackEntry.currentChild;
+      if (!stackEntry.visited) {
+        // First time this frame is processed: number the node and put it on S.
+        // v was already inserted when it was found unvisited, so this lookup
+        // does not grow the map.
+        NodeEntry &vEntry = nodeInfo[v];
+        vEntry.index = index;
+        vEntry.lowlink = index;
+        ++index;
+        vEntry.onStack = true;
+        S.push(v);
+        stackEntry.visited = true;
+      } else {
+        // Back from the DFS of lastChild: fold its lowlink into ours.
+        assert(nodeInfo.count(stackEntry.lastChild));
+        const int childLowlink = nodeInfo[stackEntry.lastChild].lowlink;
+        NodeEntry &vEntry = nodeInfo[v];
+        vEntry.lowlink = std::min(vEntry.lowlink, childLowlink);
+      }
+
+      bool inserted = false;
+      for (; child != GT::child_end(v); ++child) {
+        NodeRef w = *child;
+        // Take a copy: operator[] may insert w and grow the map, which
+        // invalidates every reference into it (including one to v's entry).
+        const NodeEntry wEntry = nodeInfo[w];
+        if (wEntry.index == 0) {
+          // Unvisited successor: suspend this frame and descend into w.
+          inserted = true;
+          stackEntry.lastChild = w;
+          // `break` skips the loop increment, so step past w here.
+          ++child;
+          recursionStack.push(
+              StackEntry{w, GT::child_begin(w), NodeRef(), false});
+          break;
+        }
+        if (wEntry.onStack) {
+          // w belongs to the SCC under construction. Look v up again instead
+          // of reusing a reference taken before the lookups above.
+          NodeEntry &vEntry = nodeInfo[v];
+          vEntry.lowlink = std::min(vEntry.lowlink, wEntry.index);
+        }
+      }
+
+      if (inserted)
+        continue;
+
+      // All successors of v are done; retire the frame.
+      recursionStack.pop();
+      const NodeEntry vDone = nodeInfo[v];
+      if (vDone.lowlink == vDone.index) {
+        // v is the root of an SCC: pop S down to and including v.
+        NodeRef w;
+        do {
+          w = S.top();
+          S.pop();
+          nodeInfo[w].onStack = false;
+          currentSCC.push_back(w);
+        } while (w != v);
+        return;
+      }
+    }
+  }
+
+public:
+  /// Returns an iterator pointing at the first SCC of \p g.
+  static graph_scc_iterator begin(const GraphT &g) {
+    return graph_scc_iterator(g);
+  }
+  /// Returns the past-the-end iterator for \p g.
+  static graph_scc_iterator end(const GraphT &g) {
+    return graph_scc_iterator(g, true);
+  }
+
+  /// Two iterators differ when their outer node position or their current
+  /// SCC differs.
+  bool operator!=(const graph_scc_iterator &x) const {
+    return currentNode != x.currentNode || currentSCC != x.currentSCC;
+  }
+
+  /// Equality, defined as the negation of operator!= so that only `!=` is
+  /// required of the graph's nodes_iterator.
+  bool operator==(const graph_scc_iterator &x) const { return !(*this != x); }
+
+  /// Advances to the next SCC.
+  graph_scc_iterator &operator++() {
+    computeNext();
+    return *this;
+  }
+
+  /// Returns the nodes of the current SCC.
+  reference operator*() const { return currentSCC; }
+};
+/// Returns the begin iterator over the SCCs of \p G; the graph type T is
+/// deduced from the argument.
+template <class T> graph_scc_iterator<T> graph_scc_begin(const T &G) {
+  using SCCIter = graph_scc_iterator<T>;
+  return SCCIter::begin(G);
+}
+
+/// Returns the past-the-end iterator over the SCCs of \p G; the graph type T
+/// is deduced from the argument.
+template <class T> graph_scc_iterator<T> graph_scc_end(const T &G) {
+  using SCCIter = graph_scc_iterator<T>;
+  return SCCIter::end(G);
+}
+
} // end namespace llvm
#endif // LLVM_ADT_SCCITERATOR_H
``````````
</details>
https://github.com/llvm/llvm-project/pull/84268
More information about the llvm-commits
mailing list