[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)
Tobias Gysi
llvmlistbot at llvm.org
Tue May 21 10:34:58 PDT 2024
================
@@ -146,3 +151,135 @@ bool mlir::computeTopologicalSorting(
return allOpsScheduled;
}
+
+/// Returns every block of `region` exactly once. Blocks are collected by
+/// reverse post-order traversals, so any block reachable from an earlier
+/// traversal root is placed after the blocks that dominate it.
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+  // Visit the blocks in region order. A block that has not been collected yet
+  // is not reachable from any previously used traversal root (the first such
+  // root is the region's first block), so start a new reverse post-order
+  // traversal from it. Blocks that were already collected keep their earlier
+  // position because SetVector ignores duplicate insertions.
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.contains(&b))
+      continue;
+    llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+    blocks.insert(traversal.begin(), traversal.end());
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
+
+namespace {
+class TopoSortHelper {
+public:
+ explicit TopoSortHelper(const SetVector<Operation *> &toSort)
+ : toSort(toSort) {}
+
+ /// Executes the topological sort of the operations this instance was
+ /// constructed with. This function will destroy the internal state of the
+ /// instance.
+ SetVector<Operation *> sort() {
+ if (toSort.size() <= 1)
+ // Note: Creates a copy on purpose.
+ return toSort;
+
+ // First, find the root region to start the traversal through the IR. This
+ // additionally enriches the internal caches with all relevant ancestor
+ // regions and blocks.
+ Region *rootRegion = findCommonAncestorRegion();
+ assert(rootRegion && "expected all ops to have a common ancestor");
+
+ // Sort all element in `toSort` by traversing the IR in the appropriate
+ // order.
+ SetVector<Operation *> result = topoSortRegion(*rootRegion);
+ assert(result.size() == toSort.size() &&
+ "expected all operations to be present in the result");
+ return result;
+ }
+
+private:
+ /// Computes the closest common ancestor region of all operations in `toSort`.
+ /// Remembers all the traversed regions in `ancestorRegions`.
+ Region *findCommonAncestorRegion() {
+ // Map to count the number of times a region was encountered.
+ DenseMap<Region *, size_t> regionCounts;
+ size_t expectedCount = toSort.size();
+
+ // Walk the region tree for each operation towards the root and add to the
+ // region count.
+ Region *res = nullptr;
+ for (Operation *op : toSort) {
+ Region *current = op->getParentRegion();
+ // Store the block as an ancestor block.
+ ancestorBlocks.insert(op->getBlock());
+ while (current) {
+
+ // Insert or update the count and compare it.
+ if (++regionCounts[current] == expectedCount) {
+ res = current;
+ break;
+ }
+ ancestorBlocks.insert(current->getParentOp()->getBlock());
+ current = current->getParentRegion();
+ }
+ }
+ auto firstRange = llvm::make_first_range(regionCounts);
+ ancestorRegions.insert(firstRange.begin(), firstRange.end());
+ return res;
+ }
+
+ /// Performs the dominance respecting IR walk to collect the topological order
+ /// of the operation to sort.
+ SetVector<Operation *> topoSortRegion(Region &rootRegion) {
+ using StackT = PointerUnion<Region *, Block *, Operation *>;
+
+ SetVector<Operation *> result;
+ // Stack that stores the different IR constructs to traverse.
+ SmallVector<StackT> stack;
+ stack.push_back(&rootRegion);
+
+ // Traverse the IR in a dominance respecting pre-order walk.
+ while (!stack.empty()) {
+ StackT current = stack.pop_back_val();
+ if (auto *region = dyn_cast<Region *>(current)) {
+ // A region's blocks need to be traversed in dominance order.
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
+ for (Block *block : llvm::reverse(sortedBlocks))
+ // Only add blocks to the stack that are ancestors of the operations
+ // to sort.
+ if (ancestorBlocks.contains(block))
+ stack.push_back(block);
----------------
gysit wrote:
```suggestion
for (Block *block : llvm::reverse(sortedBlocks)) {
// Only add blocks to the stack that are ancestors of the operations
// to sort.
if (ancestorBlocks.contains(block))
stack.push_back(block);
}
```
https://github.com/llvm/llvm-project/pull/92563
More information about the Mlir-commits
mailing list