[llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)
Christian Ulmann via llvm-commits
llvm-commits at lists.llvm.org
Tue May 21 01:35:04 PDT 2024
================
@@ -146,3 +151,93 @@ bool mlir::computeTopologicalSorting(
return allOpsScheduled;
}
+
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+ // For each block that has not been visited yet (i.e. that has no
+ // predecessors), add it to the list as well as its successors.
+ SetVector<Block *> blocks;
+ for (Block &b : region) {
+ if (blocks.count(&b) == 0) {
+ llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+ blocks.insert(traversal.begin(), traversal.end());
+ }
+ }
+ assert(blocks.size() == region.getBlocks().size() &&
+ "some blocks are not sorted");
+
+ return blocks;
+}
+
+/// Computes the common ancestor region of all operations in `ops`. Remembers
+/// all the traversed regions in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+ DenseSet<Region *> &traversedRegions) {
+ // Map to count the number of times a region was encountered.
+ llvm::DenseMap<Region *, size_t> regionCounts;
+ size_t expectedCount = ops.size();
+
+ // Walk the region tree for each operation towards the root and add to the
+ // region count.
+ Region *res = nullptr;
+ for (Operation *op : ops) {
+ Region *current = op->getParentRegion();
+ while (current) {
+ // Insert or get the count.
+ auto it = regionCounts.try_emplace(current, 0).first;
+ size_t count = ++it->getSecond();
+ if (count == expectedCount) {
+ res = current;
+ break;
+ }
+ current = current->getParentRegion();
+ }
+ }
+ auto firstRange = llvm::make_first_range(regionCounts);
+ traversedRegions.insert(firstRange.begin(), firstRange.end());
+ return res;
+}
+
+/// Topologically traverses `region` and inserts all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+ const DenseSet<Region *> &relevantRegions,
+ const SetVector<Operation *> &toSort,
+ SetVector<Operation *> &result) {
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+ for (Block *block : sortedBlocks) {
+ for (Operation &op : *block) {
+ if (toSort.contains(&op))
+ result.insert(&op);
+ for (Region &subRegion : op.getRegions()) {
+ // Skip regions that do not contain operations from `toSort`.
+ if (!relevantRegions.contains(&region))
+ continue;
+ topoSortRegion(subRegion, relevantRegions, toSort, result);
+ }
+ }
+ }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+ if (toSort.size() <= 1)
+ return toSort;
+
+ assert(llvm::all_of(toSort,
+ [&](Operation *op) { return toSort.count(op) == 1; }) &&
----------------
Dinistro wrote:
That was there before, but you are right 🙂
https://github.com/llvm/llvm-project/pull/92563
More information about the llvm-commits
mailing list