[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)

Tobias Gysi llvmlistbot at llvm.org
Tue May 21 10:34:58 PDT 2024


================
@@ -146,3 +151,135 @@ bool mlir::computeTopologicalSorting(
 
   return allOpsScheduled;
 }
+
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+  // For each block that has not been visited yet (i.e. that was not reached
+  // from an earlier block), add it and all blocks reachable from it.
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
+    }
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
+
+namespace {
+class TopoSortHelper {
+public:
+  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
+      : toSort(toSort) {}
+
+  /// Executes the topological sort of the operations this instance was
+  /// constructed with. This function will destroy the internal state of the
+  /// instance.
+  SetVector<Operation *> sort() {
+    if (toSort.size() <= 1)
+      // Note: Creates a copy on purpose.
+      return toSort;
+
+    // First, find the root region to start the traversal through the IR. This
+    // additionally enriches the internal caches with all relevant ancestor
+    // regions and blocks.
+    Region *rootRegion = findCommonAncestorRegion();
+    assert(rootRegion && "expected all ops to have a common ancestor");
+
+    // Sort all elements in `toSort` by traversing the IR in the appropriate
+    // order.
+    SetVector<Operation *> result = topoSortRegion(*rootRegion);
+    assert(result.size() == toSort.size() &&
+           "expected all operations to be present in the result");
+    return result;
+  }
+
+private:
+  /// Computes the closest common ancestor region of all operations in `toSort`.
+  /// Remembers all the traversed regions in `ancestorRegions`.
+  Region *findCommonAncestorRegion() {
+    // Map to count the number of times a region was encountered.
+    DenseMap<Region *, size_t> regionCounts;
+    size_t expectedCount = toSort.size();
+
+    // Walk the region tree for each operation towards the root and add to the
+    // region count.
+    Region *res = nullptr;
+    for (Operation *op : toSort) {
+      Region *current = op->getParentRegion();
+      // Store the block as an ancestor block.
+      ancestorBlocks.insert(op->getBlock());
+      while (current) {
+
+        // Insert or update the count and compare it.
+        if (++regionCounts[current] == expectedCount) {
+          res = current;
+          break;
+        }
+        ancestorBlocks.insert(current->getParentOp()->getBlock());
+        current = current->getParentRegion();
+      }
+    }
+    auto firstRange = llvm::make_first_range(regionCounts);
+    ancestorRegions.insert(firstRange.begin(), firstRange.end());
+    return res;
+  }
+
+  /// Performs the dominance respecting IR walk to collect the topological order
+  /// of the operations to sort.
+  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
+    using StackT = PointerUnion<Region *, Block *, Operation *>;
+
+    SetVector<Operation *> result;
+    // Stack that stores the different IR constructs to traverse.
+    SmallVector<StackT> stack;
+    stack.push_back(&rootRegion);
+
+    // Traverse the IR in a dominance respecting pre-order walk.
+    while (!stack.empty()) {
+      StackT current = stack.pop_back_val();
+      if (auto *region = dyn_cast<Region *>(current)) {
+        // A region's blocks need to be traversed in dominance order.
+        SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
+        for (Block *block : llvm::reverse(sortedBlocks))
+          // Only add blocks to the stack that are ancestors of the operations
+          // to sort.
+          if (ancestorBlocks.contains(block))
+            stack.push_back(block);
----------------
gysit wrote:

```suggestion
        for (Block *block : llvm::reverse(sortedBlocks)) {
          // Only add blocks to the stack that are ancestors of the operations
          // to sort.
          if (ancestorBlocks.contains(block))
            stack.push_back(block);
        }
```

https://github.com/llvm/llvm-project/pull/92563


More information about the Mlir-commits mailing list