[Mlir-commits] [clang] [llvm] [mlir] [ADT] Refactor post order traversal (PR #191047)

Alexis Engelke llvmlistbot at llvm.org
Thu Apr 9 12:23:50 PDT 2026


================
@@ -120,180 +54,185 @@ using DefaultSet =
 
 } // namespace po_detail
 
-template <class GraphT, class SetType = po_detail::DefaultSet<GraphT>,
-          bool ExtStorage = false, class GT = GraphTraits<GraphT>>
-class po_iterator : public po_iterator_storage<SetType, ExtStorage> {
-public:
-  // When External storage is used we are not multi-pass safe.
-  using iterator_category =
-      std::conditional_t<ExtStorage, std::input_iterator_tag,
-                         std::forward_iterator_tag>;
-  using value_type = typename GT::NodeRef;
-  using difference_type = std::ptrdiff_t;
-  using pointer = value_type *;
-  using reference = const value_type &;
-
-private:
-  using NodeRef = typename GT::NodeRef;
-  using ChildItTy = typename GT::ChildIteratorType;
+/// CRTP base class for post-order traversal. Storage for visited nodes must be
+/// provided by the sub-class, which must implement insertEdge(). Due to CRTP
+/// limitations, the sub-class must call init() with the start node before
+/// traversing; not calling init results in an empty iterator.
+///
+/// Sub-classes can observe the post-order traversal with finishPostorder(),
+/// which is called before the iterator moves to the next node, and also the
+/// pre-order traversal with insertEdge().
+///
+/// Unwanted graph nodes (e.g. from a previous traversal) can be skipped by
+/// returning false from insertEdge().
+///
+/// This class only supports a single traversal of the graph.
+template <typename DerivedT, typename GraphTraits>
+class PostOrderTraversalBase {
+  using NodeRef = typename GraphTraits::NodeRef;
+  using ChildItTy = typename GraphTraits::ChildIteratorType;
 
   /// Used to maintain the ordering.
   /// First element is basic block pointer, second is iterator for the next
   /// child to visit, third is the end iterator.
   SmallVector<std::tuple<NodeRef, ChildItTy, ChildItTy>, 8> VisitStack;
 
-  po_iterator(NodeRef BB) {
-    this->insertEdge(std::optional<NodeRef>(), BB);
-    VisitStack.emplace_back(BB, GT::child_begin(BB), GT::child_end(BB));
-    traverseChild();
-  }
+public:
+  class iterator {
+    friend class PostOrderTraversalBase;
+
+  public:
+    using iterator_category = std::input_iterator_tag;
+    using value_type = NodeRef;
+    using difference_type = std::ptrdiff_t;
+    using pointer = value_type *;
+    using reference = NodeRef;
+
+  private:
+    DerivedT *POT = nullptr;
+    NodeRef V = nullptr;
+
+  public:
+    iterator() = default;
+
+  private:
+    iterator(DerivedT &POT, value_type V) : POT(&POT), V(V) {}
+
+  public:
+    bool operator==(const iterator &X) const { return V == X.V; }
+    bool operator!=(const iterator &X) const { return !(*this == X); }
+
+    reference operator*() const { return V; }
+    pointer operator->() const { return &V; }
+
+    iterator &operator++() { // Preincrement
+      V = POT->next();
+      return *this;
+    }
+
+    iterator operator++(int) { // Postincrement
+      iterator tmp = *this;
+      ++*this;
+      return tmp;
+    }
+  };
+
+protected:
+  PostOrderTraversalBase() = default;
 
-  po_iterator() = default; // End is when stack is empty.
+  DerivedT *derived() { return static_cast<DerivedT *>(this); }
 
-  po_iterator(NodeRef BB, SetType &S)
-      : po_iterator_storage<SetType, ExtStorage>(S) {
-    if (this->insertEdge(std::optional<NodeRef>(), BB)) {
-      VisitStack.emplace_back(BB, GT::child_begin(BB), GT::child_end(BB));
+  /// Initialize post-order traversal at given start node.
+  void init(NodeRef Start) {
+    if (derived()->insertEdge(std::optional<NodeRef>(), Start)) {
+      VisitStack.emplace_back(Start, GraphTraits::child_begin(Start),
+                              GraphTraits::child_end(Start));
       traverseChild();
     }
   }
 
-  po_iterator(SetType &S)
-      : po_iterator_storage<SetType, ExtStorage>(S) {
-  } // End is when stack is empty.
-
+private:
   void traverseChild() {
     while (true) {
       auto &Entry = VisitStack.back();
       if (std::get<1>(Entry) == std::get<2>(Entry))
         break;
       NodeRef BB = *std::get<1>(Entry)++;
-      if (this->insertEdge(std::optional<NodeRef>(std::get<0>(Entry)), BB)) {
+      if (derived()->insertEdge(std::optional<NodeRef>(std::get<0>(Entry)),
+                                BB)) {
         // If the block is not visited...
-        VisitStack.emplace_back(BB, GT::child_begin(BB), GT::child_end(BB));
+        VisitStack.emplace_back(BB, GraphTraits::child_begin(BB),
+                                GraphTraits::child_end(BB));
       }
     }
   }
 
-public:
-  // Provide static "constructors"...
-  static po_iterator begin(const GraphT &G) {
-    return po_iterator(GT::getEntryNode(G));
-  }
-  static po_iterator end(const GraphT &G) { return po_iterator(); }
-
-  static po_iterator begin(const GraphT &G, SetType &S) {
-    return po_iterator(GT::getEntryNode(G), S);
-  }
-  static po_iterator end(const GraphT &G, SetType &S) { return po_iterator(S); }
-
-  bool operator==(const po_iterator &x) const {
-    return VisitStack == x.VisitStack;
-  }
-  bool operator!=(const po_iterator &x) const { return !(*this == x); }
-
-  reference operator*() const { return std::get<0>(VisitStack.back()); }
-
-  // This is a nonstandard operator-> that dereferences the pointer an extra
-  // time... so that you can actually call methods ON the BasicBlock, because
-  // the contained type is a pointer.  This allows BBIt->getTerminator() f.e.
-  //
-  NodeRef operator->() const { return **this; }
-
-  po_iterator &operator++() { // Preincrement
-    this->finishPostorder(std::get<0>(VisitStack.back()));
+  NodeRef next() {
+    derived()->finishPostorder(std::get<0>(VisitStack.back()));
     VisitStack.pop_back();
     if (!VisitStack.empty())
       traverseChild();
-    return *this;
+    return !VisitStack.empty() ? std::get<0>(VisitStack.back()) : nullptr;
   }
 
-  po_iterator operator++(int) { // Postincrement
-    po_iterator tmp = *this;
-    ++*this;
-    return tmp;
+public:
+  iterator begin() {
+    if (VisitStack.empty())
+      return iterator(); // We don't even want to see the start node.
+    return iterator(*derived(), std::get<0>(VisitStack.back()));
   }
-};
+  iterator end() { return iterator(); }
 
-// Provide global constructors that automatically figure out correct types...
-//
-template <class T>
-po_iterator<T> po_begin(const T &G) { return po_iterator<T>::begin(G); }
-template <class T>
-po_iterator<T> po_end  (const T &G) { return po_iterator<T>::end(G); }
+  // Methods that are intended to be overridden by sub-classes.
 
-template <class T> iterator_range<po_iterator<T>> post_order(const T &G) {
-  return make_range(po_begin(G), po_end(G));
-}
+  /// Add edge and return whether To should be visited. From is nullopt for the
+  /// root node.
+  bool insertEdge(std::optional<NodeRef> From, NodeRef To);
 
-// Provide global definitions of external postorder iterators...
-template <class T, class SetType = std::set<typename GraphTraits<T>::NodeRef>>
-struct po_ext_iterator : po_iterator<T, SetType, true> {
-  po_ext_iterator(const po_iterator<T, SetType, true> &V) :
-  po_iterator<T, SetType, true>(V) {}
+  /// Callback just before the iterator moves to the next block.
+  void finishPostorder(NodeRef) {}
 };
 
-template <class T, class SetType>
-po_ext_iterator<T, SetType> po_ext_begin(const T &G, SetType &S) {
-  return po_ext_iterator<T, SetType>::begin(G, S);
-}
-
-template <class T, class SetType>
-po_ext_iterator<T, SetType> po_ext_end(const T &G, SetType &S) {
-  return po_ext_iterator<T, SetType>::end(G, S);
-}
-
-template <class T, class SetType>
-iterator_range<po_ext_iterator<T, SetType>> post_order_ext(const T &G, SetType &S) {
-  return make_range(po_ext_begin(G, S), po_ext_end(G, S));
-}
+/// Post-order traversal of a graph. Note: the traversal state is stored in this
+/// class, not in the iterators -- the lifetime of PostOrderTraversal must
+/// exceed the lifetime of the iterators. Special care must be taken with
+/// range-based for-loops in combination with LLVM ranges:
+///
+///   // Fine:
+///   for (BasicBlock *BB : post_order(F)) { ... }
+///
+///   // Problematic! Lifetime of PostOrderTraversal ends before the loop is
+///   // entered, because make_filter_range only stores the iterators but not
+///   // the range object itself.
+///   for (BasicBlock *BB : make_filter_range(post_order(F), ...)) { ... }
+///   // Fixed:
+///   auto POT = post_order(F);
+///   for (BasicBlock *BB : make_filter_range(POT, ...)) { ... }
+///
+/// This class only supports a single traversal of the graph.
+template <typename GraphT, typename SetType = po_detail::DefaultSet<GraphT>>
+class PostOrderTraversal
+    : public PostOrderTraversalBase<PostOrderTraversal<GraphT, SetType>,
+                                    GraphTraits<GraphT>> {
+  using NodeRef = typename GraphTraits<GraphT>::NodeRef;
 
-// Provide global definitions of inverse post order iterators...
-template <class T, class SetType = std::set<typename GraphTraits<T>::NodeRef>,
-          bool External = false>
-struct ipo_iterator : po_iterator<Inverse<T>, SetType, External> {
-  ipo_iterator(const po_iterator<Inverse<T>, SetType, External> &V) :
-     po_iterator<Inverse<T>, SetType, External> (V) {}
-};
+  SetType Visited;
 
-template <class T>
-ipo_iterator<T> ipo_begin(const T &G) {
-  return ipo_iterator<T>::begin(G);
-}
+public:
+  /// Default constructor for an empty traversal.
+  PostOrderTraversal() = default;
----------------
aengelke wrote:

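MLIR needs a way to support empty regions: a region with no blocks has no entry node to pass to init(), so it must be possible to construct an empty traversal without a start node.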

https://github.com/llvm/llvm-project/pull/191047


More information about the Mlir-commits mailing list