[Mlir-commits] [mlir] [MLIR] Introduce support for early exits (PR #166688)
Matthias Springer
llvmlistbot at llvm.org
Sun Mar 1 04:24:29 PST 2026
================
@@ -32,3 +37,154 @@ bool mlir::mayBeGraphRegion(Region &region) {
return false;
return !regionKindOp.hasSSADominance(region.getRegionNumber());
}
+
+namespace {
+/// Cursor over the operations of a single region, visiting its blocks in
+/// list order and, within each block, its operations in list order.
+///
+/// The cursor is a plain data holder: the traversal driver reads `regionIt`
+/// and `blockIt` directly and calls `advance()` to step. It also records how
+/// deeply the region is nested relative to the operation the traversal
+/// started from.
+struct NestedOpIterator {
+  /// Position the cursor on the first operation of the first block of
+  /// `region`. If `region` has no blocks, `regionIt` equals `region->end()`
+  /// and `blockIt` stays default-constructed; `blockIt` must not be read in
+  /// that state.
+  NestedOpIterator(Region *region, int nestedLevel)
+      : region(region), regionIt(region->begin()), nestedLevel(nestedLevel) {
+    // Only look into the first block when there is one: dereferencing
+    // `region->end()` is undefined behavior.
+    if (regionIt != region->end())
+      blockIt = regionIt->begin();
+  }
+
+  /// Step the cursor forward by one position. When the current block is
+  /// exhausted this moves to the next block (if any) and returns without
+  /// stepping further; otherwise it moves to the next operation in the
+  /// current block. Must not be called once all blocks have been consumed.
+  void advance() {
+    assert(regionIt != region->end());
+    if (blockIt == regionIt->end()) {
+      ++regionIt;
+      if (regionIt != region->end())
+        blockIt = regionIt->begin();
+      return;
+    }
+    ++blockIt;
+    if (blockIt != regionIt->end()) {
+      LDBG() << this << " - Incrementing block iterator, next op: "
+             << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+    }
+  }
+
+  // The region we're iterating over.
+  Region *region;
+  // The Block currently being iterated over.
+  Region::iterator regionIt;
+  // The Operation currently being iterated over.
+  Block::iterator blockIt;
+  // The nested level of the current region relative to the starting region.
+  int nestedLevel = 0;
+};
+} // namespace
+
+/// Iterative, worklist-driven walk over every region nested under `rootOp`.
+///
+/// `callback` is invoked only for the last operation of each visited block
+/// (the position where a terminator would be), together with the nesting
+/// level of the region containing that block: 1 for the regions of `rootOp`
+/// itself, and one more for each additional level of region nesting. If the
+/// callback returns an interrupted result the whole walk stops immediately.
+///
+/// Empty regions are skipped. The worklist is processed LIFO, so among
+/// sibling regions the one pushed last is traversed first.
+static void walk(Operation *rootOp,
+                 function_ref<WalkResult(Operation *, int)> callback) {
+  // Worklist of regions to visit to drive the traversal.
+  SmallVector<NestedOpIterator> worklist;
+
+  // Seed the worklist with the non-empty regions directly held by `rootOp`.
+  for (Region &region : rootOp->getRegions()) {
+    if (region.empty())
+      continue;
+    worklist.push_back({&region, 1});
+  }
+  while (!worklist.empty()) {
+    NestedOpIterator &it = worklist.back();
+    if (it.regionIt == it.region->end()) {
+      // We're done with this region.
+      worklist.pop_back();
+      continue;
+    }
+    if (it.blockIt == it.regionIt->end()) {
+      // We're done with this block.
+      it.advance();
+      continue;
+    }
+    Operation *op = &*it.blockIt;
+
+    // Only call the callback if we're at the end of the block.
+    if (std::next(it.blockIt) == it.regionIt->end() &&
+        callback(op, it.nestedLevel).wasInterrupted())
+      return;
+
+    // Advance before pushing nested regions: pushing may reallocate the
+    // worklist and invalidate `it`, so everything needed from it is read
+    // (and the cursor is stepped) first.
+    int currentNestedLevel = it.nestedLevel;
+    it.advance();
+
+    // Queue the non-empty regions of `op` one level deeper.
+    for (Region &nestedRegion : op->getRegions()) {
+      if (nestedRegion.empty())
+        continue;
+      worklist.push_back({&nestedRegion, currentNestedLevel + 1});
+    }
+  }
+}
+
+/// Report whether some block-ending operation nested inside `op` breaks out
+/// of exactly as many regions as its nesting depth below `op`, with that
+/// depth being greater than 1.
+///
+/// Operations sitting directly in one of `op`'s own regions (depth 1) are
+/// never counted here. The search stops at the first match.
+bool mlir::hasNestedPredecessors(Operation *op) {
+  bool hasDirectTarget = false;
+  auto inspect = [&hasDirectTarget](Operation *candidate,
+                                    int depth) -> WalkResult {
+    // Depth-1 candidates are excluded before querying the operation.
+    if (depth <= 1)
+      return WalkResult::advance();
+    const int numBroken =
+        static_cast<int>(candidate->getNumBreakingControlRegions());
+    if (depth != numBroken)
+      return WalkResult::advance();
+    // First match is enough; no need to keep walking.
+    hasDirectTarget = true;
+    return WalkResult::interrupt();
+  };
+  walk(op, inspect);
+  return hasDirectTarget;
+}
+
+/// Report whether some block-ending operation nested inside `op` breaks out
+/// of strictly more regions than its nesting depth below `op`, i.e. its
+/// num-breaking-regions value reaches past `op` toward an enclosing
+/// ancestor. The search stops at the first match.
+bool mlir::hasBreakingControlFlowOps(Operation *op) {
+  bool escapesPastOp = false;
+  auto inspect = [&escapesPastOp](Operation *candidate,
+                                  int depth) -> WalkResult {
+    const int numBroken =
+        static_cast<int>(candidate->getNumBreakingControlRegions());
+    // Not enough broken regions to get past `op`: keep looking.
+    if (depth >= numBroken)
+      return WalkResult::advance();
+    escapesPastOp = true;
+    return WalkResult::interrupt();
+  };
+  walk(op, inspect);
+  return escapesPastOp;
+}
+
+/// Forward to `callback` every block-ending operation nested inside `op`
+/// whose num-breaking-regions value is at least its nesting depth below
+/// `op`. The depth handed to the callback is 1-based: operations directly in
+/// one of `op`'s regions are at depth 1. Operations that do not meet the
+/// threshold are skipped, and the callback's result is returned to the walk
+/// unchanged, so an interrupt from the callback ends the traversal.
+void mlir::detail::visitNestedBreakingControlFlowOpsImpl(
+    Operation *op,
+    function_ref<WalkResult(Operation *, int nestedLevel)> callback) {
+  auto forwardIfBreaking = [&callback](Operation *candidate,
+                                       int depth) -> WalkResult {
+    const int numBroken =
+        static_cast<int>(candidate->getNumBreakingControlRegions());
+    // Breaks out of fewer regions than its depth: stays inside `op`.
+    if (depth > numBroken)
+      return WalkResult::advance();
+    return callback(candidate, depth);
+  };
+  ::walk(op, forwardIfBreaking);
+}
+
+/// Collect all RegionTerminator ops nested inside `op` that directly target
+/// `op` as their control-flow destination (num-breaking-regions ==
----------------
matthias-springer wrote:
Is the implementation correct? The documentation of `visitNestedBreakingControlFlowOpsImpl` says: "the terminator either terminates directly into `op` or propagates further upward". But this function should collect only terminators that directly target the given op.
https://github.com/llvm/llvm-project/pull/166688
More information about the Mlir-commits
mailing list