[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)
Mehdi Amini
llvmlistbot at llvm.org
Fri May 17 09:40:28 PDT 2024
================
@@ -146,3 +151,93 @@ bool mlir::computeTopologicalSorting(
return allOpsScheduled;
}
+
+/// Returns every block of `region`, ordered by reverse post-order (RPO)
+/// traversals. The first traversal starts at the region's first (entry) block,
+/// so each block reachable from it is emitted after its dominators. Blocks not
+/// reachable from the entry are appended by later traversals, in region order
+/// of the block each traversal starts from.
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+  // Start an RPO traversal from every block that has not been emitted yet.
+  // After the first (entry) traversal, the remaining start blocks are merely
+  // unreachable from the entry; they may still have predecessors (e.g. an
+  // unreachable cycle).
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.contains(&b))
+      continue;
+    llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+    // SetVector::insert skips blocks already emitted by an earlier traversal,
+    // so they keep their original position.
+    blocks.insert(traversal.begin(), traversal.end());
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
+
+/// Returns the innermost region that (transitively) contains every operation
+/// in `ops`. Returns nullptr when there is no such region, e.g. when `ops` is
+/// empty. Every region visited while walking up from the operations is
+/// recorded in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+                                      DenseSet<Region *> &traversedRegions) {
+  // Number of operations in `ops` that have each region as an ancestor. An
+  // operation contributes at most once per region, because the regions on a
+  // single parent chain are distinct.
+  llvm::DenseMap<Region *, size_t> hitsPerRegion;
+  const size_t numOps = ops.size();
+
+  Region *commonRegion = nullptr;
+  for (Operation *op : ops) {
+    for (Region *ancestor = op->getParentRegion(); ancestor;
+         ancestor = ancestor->getParentRegion()) {
+      // Only the walk of the last operation can bring a count up to
+      // `numOps`. The first region where that happens is the innermost
+      // common ancestor.
+      if (++hitsPerRegion[ancestor] != numOps)
+        continue;
+      commonRegion = ancestor;
+      break;
+    }
+  }
+
+  for (const auto &entry : hitsPerRegion)
+    traversedRegions.insert(entry.first);
+  return commonRegion;
+}
+
+/// Topologically traverses `region` and inserts all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region ®ion,
+ const DenseSet<Region *> &relevantRegions,
+ const SetVector<Operation *> &toSort,
+ SetVector<Operation *> &result) {
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+ for (Block *block : sortedBlocks) {
+ for (Operation &op : *block) {
+ if (toSort.contains(&op))
+ result.insert(&op);
+ for (Region &subRegion : op.getRegions()) {
+ // Skip regions that do not contain operations from `toSort`.
+ if (!relevantRegions.contains(®ion))
+ continue;
+ topoSortRegion(subRegion, relevantRegions, toSort, result);
----------------
joker-eph wrote:
Can you replace this recursion with an iterative algorithm please?
These recursions are ticking bombs...
https://github.com/llvm/llvm-project/pull/92563
More information about the Mlir-commits
mailing list