[llvm-branch-commits] [mlir] [MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver (PR #154038)

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Aug 17 14:34:37 PDT 2025


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/154038

This is similar to the fix to the greedy driver in #153957, except that instead of removing unreachable code, we just ignore it.

Operations like:

```
%add = arith.addi %add, %add : i64
```

are legal in unreachable code (here the op uses its own result as its operands).
Unfortunately, many patterns are unsafe to apply to such IR and can lead to crashes or infinite loops.

From e2e4ae7bc3ee5c91215d10743c607e80729bd28a Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 14:24:35 -0700
Subject: [PATCH] [MLIR] Stop visiting unreachable blocks in the
 walkAndApplyPatterns driver

This is similar to the fix to the greedy driver in #153957, except that
instead of removing unreachable code, we just ignore it.

Operations like:

%add = arith.addi %add, %add : i64

are legal in unreachable code (here the op uses its own result as its operands).
Unfortunately, many patterns are unsafe to apply to such IR and can
lead to crashes or infinite loops.
---
 .../Transforms/WalkPatternRewriteDriver.h     |  2 +
 .../Utils/WalkPatternRewriteDriver.cpp        | 76 ++++++++++++++-----
 .../IR/test-walk-pattern-rewrite-driver.mlir  | 10 +++
 3 files changed, 70 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
index 6d62ae3dd43dc..7d5c1d5cebb26 100644
--- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -27,6 +27,8 @@ namespace mlir {
 /// This is intended as the simplest and most lightweight pattern rewriter in
 /// cases when a simple walk gets the job done.
 ///
+/// The driver will skip unreachable blocks.
+///
 /// Note: Does not apply patterns to the given operation itself.
 void walkAndApplyPatterns(Operation *op,
                           const FrozenRewritePatternSet &patterns,
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 52f8ea5472883..8f26a294f6d9b 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -27,6 +27,26 @@
 
 namespace mlir {
 
+// Find all blocks reachable from the entry block of `region` by following
+// terminator successors, and add them to the `reachableBlocks` set.
+static void findReachableBlocks(Region &region,
+                                DenseSet<Block *> &reachableBlocks) {
+  if (region.empty())
+    return;
+  Block *entryBlock = &region.front();
+  reachableBlocks.insert(entryBlock);
+  // Worklist-driven CFG walk: a block is enqueued only when first inserted.
+  SmallVector<Block *> worklist({entryBlock});
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+    Operation *terminator = &block->back();
+    for (Block *successor : terminator->getSuccessors()) {
+      if (reachableBlocks.insert(successor).second)
+        worklist.push_back(successor);
+    }
+  }
+}
+
 namespace {
 struct WalkAndApplyPatternsAction final
     : tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -90,18 +110,28 @@ void walkAndApplyPatterns(Operation *op,
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();
 
-  // Cursor to track where we're at in the traversal.
-  struct Cursor {
-    Cursor(Region *region) : region(region) {
+  // Iterator on all reachable operations in the region.
+  // Also keep track if we visited the nested regions of the current op
+  // already to drive the post-order traversal.
+  struct RegionReachableOpIterator {
+    RegionReachableOpIterator(Region *region) : region(region) {
       regionIt = region->begin();
       if (regionIt != region->end())
         blockIt = regionIt->begin();
+      if (!llvm::hasSingleElement(*region))
+        findReachableBlocks(*region, reachableBlocks);
     }
-    void next() {
+    // Advance the iterator to the next reachable operation.
+    void advance() {
       assert(regionIt != region->end());
       hasVisitedRegions = false;
       if (blockIt == regionIt->end()) {
         regionIt++;
+        while (regionIt != region->end() &&
+               !reachableBlocks.contains(&*regionIt))
+          regionIt++;
+        if (regionIt != region->end())
+          blockIt = regionIt->begin();
         return;
       }
       blockIt++;
@@ -110,14 +140,23 @@ void walkAndApplyPatterns(Operation *op,
                << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
       }
     }
+
+    // The region we're iterating over.
     Region *region;
+    // The Block currently being iterated over.
     Region::iterator regionIt;
+    // The Operation currently being iterated over.
     Block::iterator blockIt;
+    // The set of blocks that are reachable in the current region.
+    DenseSet<Block *> reachableBlocks;
+    // Whether we've visited the nested regions of the current op already.
     bool hasVisitedRegions = false;
   };
-  SmallVector<Cursor> worklist;
+  SmallVector<RegionReachableOpIterator> worklist;
 
   LDBG() << "Starting walk-based pattern rewrite driver";
+  // Perform a post-order traversal of the region, visiting each reachable
+  // operation.
   ctx->executeAction<WalkAndApplyPatternsAction>(
       [&] {
         for (Region &region : op->getRegions()) {
@@ -128,36 +167,37 @@ void walkAndApplyPatterns(Operation *op,
           // Prime the worklist with the entry block of this region.
           worklist.push_back({&region});
           while (!worklist.empty()) {
-            Cursor &cursor = worklist.back();
-            if (cursor.regionIt == cursor.region->end()) {
+            RegionReachableOpIterator &it = worklist.back();
+            if (it.regionIt == it.region->end()) {
               // We're done with this region.
               worklist.pop_back();
               continue;
             }
-            if (cursor.blockIt == cursor.regionIt->end()) {
+            if (it.blockIt == it.regionIt->end()) {
               // We're done with this block.
-              cursor.regionIt++;
-              if (cursor.regionIt != cursor.region->end())
-                cursor.blockIt = cursor.regionIt->begin();
+              it.advance();
               continue;
             }
-            Operation *op = &*cursor.blockIt;
-            if (!cursor.hasVisitedRegions) {
-              cursor.hasVisitedRegions = true;
+            Operation *op = &*it.blockIt;
+            // If we haven't visited the nested regions of this op yet,
+            // enqueue them.
+            if (!it.hasVisitedRegions) {
+              it.hasVisitedRegions = true;
               for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
                 if (nestedRegion.empty())
                   continue;
                 worklist.push_back({&nestedRegion});
               }
             }
-            // If we're not at the back of the worklist, we're visiting a nested
-            // region first. We'll come back to this op later.
-            if (&cursor != &worklist.back())
+            // If we're not at the back of the worklist, we've enqueued some
+            // nested region for processing. We'll come back to this op later
+            // (post-order)
+            if (&it != &worklist.back())
               continue;
 
             // Premptively increment the cursor, in case the current op
             // would be erased.
-            cursor.next();
+            it.advance();
 
             LDBG() << "Visiting op: "
                    << OpWithFlags(op, OpPrintingFlags().skipRegions());
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index c75c478ec3734..1acff6fdf029e 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -119,3 +119,13 @@ func.func @erase_nested_block() -> i32 {
   }): () -> (i32)
   return %a : i32
 }
+
+// The driver skips blocks unreachable from the entry block: this op must stay.
+// CHECK-LABEL: func.func @unreachable_replace_with_new_op
+// CHECK: "test.replace_with_new_op"
+func.func @unreachable_replace_with_new_op() {
+  return
+^unreachable:
+  %a = "test.replace_with_new_op"() : () -> (i32)
+  return
+}



More information about the llvm-branch-commits mailing list