[llvm-branch-commits] [mlir] [MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver (PR #154038)
Mehdi Amini via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Aug 17 14:56:31 PDT 2025
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/154038
>From 7260f9263133579e45e704ca2f42883770b0221a Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 14:24:35 -0700
Subject: [PATCH] [MLIR] Stop visiting unreachable blocks in the
walkAndApplyPatterns driver
This is similar to the fix to the greedy driver in #153957, except that
instead of removing unreachable code, this driver simply skips it.
Operations like:
%add = arith.addi %add, %add : i64
are legal in unreachable code, where SSA dominance is not enforced.
Unfortunately, many patterns are unsafe to apply to such IR and can
lead to crashes or infinite loops.
---
.../Transforms/WalkPatternRewriteDriver.h | 2 ++
.../Utils/WalkPatternRewriteDriver.cpp | 27 +++++++++++++++++++
.../IR/test-walk-pattern-rewrite-driver.mlir | 15 +++++++++++
3 files changed, 44 insertions(+)
diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
index 6d62ae3dd43dc..7d5c1d5cebb26 100644
--- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -27,6 +27,8 @@ namespace mlir {
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
+/// The driver skips blocks that are not reachable from the entry block of
+/// their region; operations in such blocks are not visited.
+///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 03a6e59aab4d9..0c2661b9ef561 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -27,6 +27,26 @@
namespace mlir {
+// Insert into `reachableBlocks` every block of `region` that is reachable
+// from the entry block by following terminator successors (the entry block
+// itself included). An empty region has no entry block; nothing is added.
+static void findReachableBlocks(Region &region,
+                                DenseSet<Block *> &reachableBlocks) {
+  // `!llvm::hasSingleElement(region)` also holds for an empty region, where
+  // `region.front()` would be invalid.
+  if (region.empty())
+    return;
+  Block *entryBlock = &region.front();
+  reachableBlocks.insert(entryBlock);
+  // Worklist traversal of the CFG: a block is enqueued only the first time it
+  // is inserted into `reachableBlocks`, so each block is processed once.
+  SmallVector<Block *> worklist({entryBlock});
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+    // An empty block has no terminator; `back()` would be invalid on it.
+    if (block->empty())
+      continue;
+    Operation *terminator = &block->back();
+    for (Block *successor : terminator->getSuccessors())
+      if (reachableBlocks.insert(successor).second)
+        worklist.push_back(successor);
+  }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -98,6 +118,8 @@ void walkAndApplyPatterns(Operation *op,
regionIt = region->begin();
if (regionIt != region->end())
blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
}
// Advance the iterator to the next reachable operation.
void advance() {
@@ -105,6 +127,9 @@ void walkAndApplyPatterns(Operation *op,
hasVisitedRegions = false;
if (blockIt == regionIt->end()) {
regionIt++;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ regionIt++;
if (regionIt != region->end())
blockIt = regionIt->begin();
return;
@@ -121,6 +146,8 @@ void walkAndApplyPatterns(Operation *op,
Region::iterator regionIt;
// The Operation currently being iterated over.
Block::iterator blockIt;
+ // Blocks reachable from the entry block of the current region. Only
+ // populated when the region does not consist of a single block.
+ DenseSet<Block *> reachableBlocks;
// Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false;
};
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index c75c478ec3734..3479108dc28d4 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -119,3 +119,18 @@ func.func @erase_nested_block() -> i32 {
}): () -> (i32)
return %a : i32
}
+
+
+// Patterns apply in the entry block and in ^bb1 (reached via test.br), but not
+// in ^unreachable, which has no predecessors: only the first
+// replace_with_new_op is rewritten, the one in ^unreachable is left as-is.
+// CHECK-LABEL: func.func @unreachable_replace_with_new_op
+// CHECK: "test.new_op"
+// CHECK: "test.replace_with_new_op"
+func.func @unreachable_replace_with_new_op() {
+ "test.br"()[^bb1] : () -> ()
+^bb1:
+ %a = "test.replace_with_new_op"() : () -> (i32)
+ return
+^unreachable:
+ %b = "test.replace_with_new_op"() : () -> (i32)
+ return
+}
+
More information about the llvm-branch-commits
mailing list