[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)

Mehdi Amini llvmlistbot at llvm.org
Fri May 17 09:40:28 PDT 2024


================
@@ -146,3 +151,93 @@ bool mlir::computeTopologicalSorting(
 
   return allOpsScheduled;
 }
+
+/// Collects every block of `region` into a SetVector. The blocks of `region`
+/// are visited in order; each block that is not yet part of the result seeds
+/// a reverse post-order traversal whose blocks are appended to the result.
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+  SetVector<Block *> sorted;
+  for (Block &block : region) {
+    // Already reached by a traversal seeded from an earlier block.
+    if (sorted.contains(&block))
+      continue;
+
+    // Append this block and everything reachable from it in reverse
+    // post-order; the SetVector drops blocks that were inserted before.
+    llvm::ReversePostOrderTraversal<Block *> rpo(&block);
+    sorted.insert(rpo.begin(), rpo.end());
+  }
+
+  // Every block of the region must have been collected exactly once.
+  assert(sorted.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+  return sorted;
+}
+
+/// Returns the innermost region that encloses every operation in `ops`, or
+/// nullptr if no such region was found (in particular when `ops` is empty).
+/// Every region visited while walking up from the operations is added to
+/// `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+                                      DenseSet<Region *> &traversedRegions) {
+  // Number of operations whose ancestor chain passed through each region.
+  llvm::DenseMap<Region *, size_t> visitCounts;
+  const size_t numOps = ops.size();
+
+  Region *commonRegion = nullptr;
+  for (Operation *op : ops) {
+    // Climb from the operation's parent region towards the root.
+    for (Region *ancestor = op->getParentRegion(); ancestor;
+         ancestor = ancestor->getParentRegion()) {
+      // A region seen for the first time starts with a count of zero.
+      size_t &count = visitCounts[ancestor];
+      if (++count == numOps) {
+        // All operations have passed through this region; stop climbing.
+        commonRegion = ancestor;
+        break;
+      }
+    }
+  }
+
+  // Report every region that was touched during the walk.
+  for (const auto &entry : visitCounts)
+    traversedRegions.insert(entry.first);
+  return commonRegion;
+}
+
+/// Topologically traverses `region` and inserts all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+                           const DenseSet<Region *> &relevantRegions,
+                           const SetVector<Operation *> &toSort,
+                           SetVector<Operation *> &result) {
+  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+  for (Block *block : sortedBlocks) {
+    for (Operation &op : *block) {
+      if (toSort.contains(&op))
+        result.insert(&op);
+      for (Region &subRegion : op.getRegions()) {
+        // Skip regions that do not contain operations from `toSort`.
+        if (!relevantRegions.contains(&region))
+          continue;
+        topoSortRegion(subRegion, relevantRegions, toSort, result);
----------------
joker-eph wrote:

Can you replace this recursion with an iterative algorithm please?
These recursions are ticking bombs...

https://github.com/llvm/llvm-project/pull/92563


More information about the Mlir-commits mailing list