[llvm-branch-commits] [mlir] [MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver (PR #154038)
Mehdi Amini via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Aug 17 14:34:37 PDT 2025
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/154038
This is similar to the fix to the greedy driver in #153957 ; except that instead of removing unreachable code, we just ignore it.
Operations like:
```
%add = arith.addi %add, %add : i64
```
are legal in unreachable code.
Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops.
>From e2e4ae7bc3ee5c91215d10743c607e80729bd28a Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 14:24:35 -0700
Subject: [PATCH] [MLIR] Stop visiting unreachable blocks in the
walkAndApplyPatterns driver
This is similar to the fix to the greedy driver in #153957 ; except that
instead of removing unreachable code, we just ignore it.
Operations like:
%add = arith.addi %add, %add : i64
are legal in unreachable code.
Unfortunately many patterns would be unsafe to apply on such IR and can
lead to crashes or infinite loops.
---
.../Transforms/WalkPatternRewriteDriver.h | 2 +
.../Utils/WalkPatternRewriteDriver.cpp | 76 ++++++++++++++-----
.../IR/test-walk-pattern-rewrite-driver.mlir | 10 +++
3 files changed, 70 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
index 6d62ae3dd43dc..7d5c1d5cebb26 100644
--- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -27,6 +27,8 @@ namespace mlir {
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
+/// The driver will skip unreachable blocks.
+///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 52f8ea5472883..8f26a294f6d9b 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -27,6 +27,26 @@
namespace mlir {
+// Find all reachable blocks in the region and add them to the visitedBlocks
+// set.
+static void findReachableBlocks(Region ®ion,
+ DenseSet<Block *> &reachableBlocks) {
+ Block *entryBlock = ®ion.front();
+ reachableBlocks.insert(entryBlock);
+ // Traverse the CFG and add all reachable blocks to the blockList.
+ SmallVector<Block *> worklist({entryBlock});
+ Block *block = worklist.pop_back_val();
+ while (!worklist.empty()) {
+ Operation *terminator = &block->back();
+ for (Block *successor : terminator->getSuccessors()) {
+ if (reachableBlocks.contains(successor))
+ continue;
+ worklist.push_back(successor);
+ reachableBlocks.insert(successor);
+ }
+ }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -90,18 +110,28 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
- // Cursor to track where we're at in the traversal.
- struct Cursor {
- Cursor(Region *region) : region(region) {
+ // Iterator on all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op
+ // already to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
regionIt = region->begin();
if (regionIt != region->end())
blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
}
- void next() {
+ // Advance the iterator to the next reachable operation.
+ void advance() {
assert(regionIt != region->end());
hasVisitedRegions = false;
if (blockIt == regionIt->end()) {
regionIt++;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ regionIt++;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
return;
}
blockIt++;
@@ -110,14 +140,23 @@ void walkAndApplyPatterns(Operation *op,
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
}
}
+
+ // The region we're iterating over.
Region *region;
+ // The Block currently being iterated over.
Region::iterator regionIt;
+ // The Operation currently being iterated over.
Block::iterator blockIt;
+ // The set of blocks that are reachable in the current region.
+ DenseSet<Block *> reachableBlocks;
+ // Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false;
};
- SmallVector<Cursor> worklist;
+ SmallVector<RegionReachableOpIterator> worklist;
LDBG() << "Starting walk-based pattern rewrite driver";
+ // Perform a post-order traversal of the region, visiting each reachable
+ // operation.
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
for (Region ®ion : op->getRegions()) {
@@ -128,36 +167,37 @@ void walkAndApplyPatterns(Operation *op,
// Prime the worklist with the entry block of this region.
worklist.push_back({®ion});
while (!worklist.empty()) {
- Cursor &cursor = worklist.back();
- if (cursor.regionIt == cursor.region->end()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
// We're done with this region.
worklist.pop_back();
continue;
}
- if (cursor.blockIt == cursor.regionIt->end()) {
+ if (it.blockIt == it.regionIt->end()) {
// We're done with this block.
- cursor.regionIt++;
- if (cursor.regionIt != cursor.region->end())
- cursor.blockIt = cursor.regionIt->begin();
+ it.advance();
continue;
}
- Operation *op = &*cursor.blockIt;
- if (!cursor.hasVisitedRegions) {
- cursor.hasVisitedRegions = true;
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
if (nestedRegion.empty())
continue;
worklist.push_back({&nestedRegion});
}
}
- // If we're not at the back of the worklist, we're visiting a nested
- // region first. We'll come back to this op later.
- if (&cursor != &worklist.back())
+ // If we're not at the back of the worklist, we've enqueued some
+ // nested region for processing. We'll come back to this op later
+ // (post-order)
+ if (&it != &worklist.back())
continue;
// Premptively increment the cursor, in case the current op
// would be erased.
- cursor.next();
+ it.advance();
LDBG() << "Visiting op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index c75c478ec3734..1acff6fdf029e 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -119,3 +119,13 @@ func.func @erase_nested_block() -> i32 {
}): () -> (i32)
return %a : i32
}
+
+
+// CHECK-LABEL: func.func @unreachable_replace_with_new_op
+// CHECK: "test.replace_with_new_op"
+func.func @unreachable_replace_with_new_op() {
+ return
+^unreachable:
+ %a = "test.replace_with_new_op"() : () -> (i32)
+ return
+}
More information about the llvm-branch-commits
mailing list