[Mlir-commits] [mlir] [MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (PR #154037)

Mehdi Amini llvmlistbot at llvm.org
Sun Aug 17 14:40:43 PDT 2025


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/154037

>From c074f7495e94334e21d0994c3958afe4921f7158 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 05:32:46 -0700
Subject: [PATCH] [MLIR] Refactor the walkAndApplyPatterns driver to remove the
 recursion (NFC)

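This is in preparation for a follow-up change that stops traversing unreachable blocks.

Note that this is not strictly NFC: an op moved later with `rewriter.moveAfter`
is now visited again at its new position (see the updated `@move_after` test).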
---
 .../Utils/WalkPatternRewriteDriver.cpp        | 99 +++++++++++++++++--
 .../IR/test-walk-pattern-rewrite-driver.mlir  |  4 +-
 2 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642c943c4..a36d96c981603 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,12 +13,14 @@
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();
 
+  // Iterator over the operations of a region, block by block in region
+  // order (unreachable blocks are not skipped yet). It also tracks whether the
+  // nested regions of the current op were already enqueued (post-order).
+  struct RegionReachableOpIterator {
+    RegionReachableOpIterator(Region *region) : region(region) {
+      regionIt = region->begin();
+      if (regionIt != region->end())
+        blockIt = regionIt->begin();
+    }
+    // Step to the next position: the next op of the current block, or the
+    // first op of the next block once the current one is exhausted.
+    void advance() {
+      assert(regionIt != region->end());
+      hasVisitedRegions = false;
+      if (blockIt == regionIt->end()) {
+        if (++regionIt != region->end())
+          blockIt = regionIt->begin();
+        return;
+      }
+      ++blockIt;
+      if (blockIt != regionIt->end()) {
+        LDBG() << "Incrementing block iterator, next op: "
+               << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+      }
+    }
+    // The region being iterated over.
+    Region *region;
+    // The block currently being iterated over.
+    Region::iterator regionIt;
+    // The operation currently being iterated over.
+    Block::iterator blockIt;
+    // Whether the nested regions of the current op were already enqueued.
+    bool hasVisitedRegions = false;
+  };
+
+  // Stack of in-progress region iterators; the back is the innermost region.
+  SmallVector<RegionReachableOpIterator> worklist;
+
+  LDBG() << "Starting walk-based pattern rewrite driver";
   ctx->executeAction<WalkAndApplyPatternsAction>(
       [&] {
+        // Perform a post-order traversal of the regions, visiting every
+        // operation after the operations nested in its regions.
         for (Region &region : op->getRegions()) {
-          region.walk([&](Operation *visitedOp) {
-            LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
-                llvm::dbgs(), OpPrintingFlags().skipRegions());
-                       llvm::dbgs() << "\n";);
+          assert(worklist.empty());
+          if (region.empty())
+            continue;
+
+          // Prime the worklist with this region, starting at its entry block.
+          worklist.push_back({&region});
+          while (!worklist.empty()) {
+            RegionReachableOpIterator &it = worklist.back();
+            if (it.regionIt == it.region->end()) {
+              // We're done with this region.
+              worklist.pop_back();
+              continue;
+            }
+            if (it.blockIt == it.regionIt->end()) {
+              // We're done with this block.
+              it.advance();
+              continue;
+            }
+            Operation *visitedOp = &*it.blockIt;
+            // On first encounter, enqueue the non-empty nested regions in
+            // reverse so that the first region is processed first.
+            // `push_back` may reallocate and invalidate `it`, so compare
+            // worklist sizes rather than element addresses.
+            if (!it.hasVisitedRegions) {
+              it.hasVisitedRegions = true;
+              size_t oldSize = worklist.size();
+              for (Region &nestedRegion :
+                   llvm::reverse(visitedOp->getRegions()))
+                if (!nestedRegion.empty())
+                  worklist.push_back({&nestedRegion});
+              // Visit enqueued regions first, then this op (post-order).
+              if (worklist.size() != oldSize)
+                continue;
+            }
+
+            // Preemptively increment the iterator, in case the current op
+            // would be erased.
+            it.advance();
+
+            LDBG() << "Visiting op: "
+                   << OpWithFlags(visitedOp, OpPrintingFlags().skipRegions());
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            erasedListener.visitedOp = visitedOp;
+            erasedListener.visitedOp = visitedOp;
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
-              LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
-            }
-          });
+            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter)))
+              LDBG() << "\tOp matched and rewritten";
+          }
         }
       },
       {op});
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60671c9b..c75c478ec3734 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
 }
 
 // Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// the moved op to be visited twice: it is reached again at its new position.
 // CHECK-LABEL: func.func @move_after(
 // CHECK: scf.if
 // CHECK: }
 // CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
 // CHECK: return
 func.func @move_after(%cond : i1) {
   scf.if %cond {



More information about the Mlir-commits mailing list