[Mlir-commits] [mlir] [MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (PR #154037)
Mehdi Amini
llvmlistbot at llvm.org
Sun Aug 17 14:40:43 PDT 2025
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/154037
>From c074f7495e94334e21d0994c3958afe4921f7158 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sun, 17 Aug 2025 05:32:46 -0700
Subject: [PATCH] [MLIR] Refactor the walkAndApplyPatterns driver to remove the
recursion (NFC)
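This is in preparation for a follow-up change to stop traversing unreachable blocks.
One observable difference from the previous `walk`-based traversal: an op moved after its parent op is now visited again, as reflected in the updated test.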
---
.../Utils/WalkPatternRewriteDriver.cpp | 99 +++++++++++++++++--
.../IR/test-walk-pattern-rewrite-driver.mlir | 4 +-
2 files changed, 91 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642c943c4..a36d96c981603 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,12 +13,14 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+  // Cursor over the operations of a region. Blocks are visited in the order
+  // of the region's block list and every block is traversed: no reachability
+  // filtering is performed here (yet). The cursor also records whether the
+  // nested regions of the current operation have already been scheduled,
+  // which the driver relies on to produce a post-order traversal.
+  struct RegionReachableOpIterator {
+    RegionReachableOpIterator(Region *region)
+        : region(region), regionIt(region->begin()) {
+      // For an empty region `blockIt` is left unset;
+      // `regionIt == region->end()` already signals there is nothing to visit.
+      if (regionIt != region->end())
+        blockIt = regionIt->begin();
+    }
+
+    // Move the cursor forward. Inside a block this steps to the next
+    // operation; once the current block is exhausted it steps to the next
+    // block of the region and points `blockIt` at its first operation (which
+    // is that block's end if the block is empty). The new position has not
+    // had its nested regions scheduled, so `hasVisitedRegions` is reset.
+    // Must not be called after every block of the region has been consumed.
+    void advance() {
+      assert(regionIt != region->end());
+      hasVisitedRegions = false;
+      if (blockIt != regionIt->end()) {
+        ++blockIt;
+        if (blockIt != regionIt->end())
+          LDBG() << "Incrementing block iterator, next op: "
+                 << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+        return;
+      }
+      ++regionIt;
+      if (regionIt != region->end())
+        blockIt = regionIt->begin();
+    }
+
+    // The region being traversed.
+    Region *region;
+    // Position in the region's block list, i.e. the block being traversed.
+    Region::iterator regionIt;
+    // Position in the current block's operation list, i.e. the current op.
+    Block::iterator blockIt;
+    // True once the nested regions of the current op have been scheduled.
+    bool hasVisitedRegions = false;
+  };
+
+  // Worklist of region cursors driving the post-order traversal; the back of
+  // the worklist is the innermost region currently being visited.
+  SmallVector<RegionReachableOpIterator> worklist;
+
+  LDBG() << "Starting walk-based pattern rewrite driver";
   ctx->executeAction<WalkAndApplyPatternsAction>(
       [&] {
+        // Perform a post-order traversal of the regions, visiting every
+        // operation of every block (no reachability filtering is done here).
         for (Region &region : op->getRegions()) {
-          region.walk([&](Operation *visitedOp) {
-            LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
-                llvm::dbgs(), OpPrintingFlags().skipRegions());
-            llvm::dbgs() << "\n";);
+          assert(worklist.empty());
+          if (region.empty())
+            continue;
+
+          // Prime the worklist with a cursor on the first block of this
+          // region.
+          worklist.push_back({&region});
+          while (!worklist.empty()) {
+            RegionReachableOpIterator &it = worklist.back();
+            if (it.regionIt == it.region->end()) {
+              // We're done with this region.
+              worklist.pop_back();
+              continue;
+            }
+            if (it.blockIt == it.regionIt->end()) {
+              // We're done with this block.
+              it.advance();
+              continue;
+            }
+            // Named `visitedOp` rather than `op` so that it does not shadow
+            // the root operation captured by this lambda.
+            Operation *visitedOp = &*it.blockIt;
+            // On the first encounter, enqueue the nested regions of this op
+            // (in reverse, so the first region ends up on top and is
+            // processed first) and come back to the op once they are done
+            // (post-order).
+            if (!it.hasVisitedRegions) {
+              it.hasVisitedRegions = true;
+              size_t worklistSize = worklist.size();
+              for (Region &nestedRegion :
+                   llvm::reverse(visitedOp->getRegions())) {
+                if (nestedRegion.empty())
+                  continue;
+                worklist.push_back({&nestedRegion});
+              }
+              // `push_back` may have reallocated the worklist, leaving `it`
+              // dangling, so decide based on the size and never on `&it`.
+              // The cursor is fetched again from `worklist.back()` once the
+              // nested regions have been popped.
+              if (worklist.size() != worklistSize)
+                continue;
+            }
+
+            // Preemptively increment the iterator, in case the current op
+            // would be erased.
+            it.advance();
+
+            LDBG() << "Visiting op: "
+                   << OpWithFlags(visitedOp, OpPrintingFlags().skipRegions());
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             erasedListener.visitedOp = visitedOp;
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
-              LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
-            }
-          });
+            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter)))
+              LDBG() << "\tOp matched and rewritten";
+          }
         }
       },
       {op});
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60671c9b..c75c478ec3734 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
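+// the moved op to be visited twice: the driver only advances past the parent
+// op after the parent's nested regions have been processed, so an op moved
+// right after its parent is visited again at its new location.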
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
More information about the Mlir-commits
mailing list