[Mlir-commits] [mlir] 16aa283 - [MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (#154037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 18 02:07:23 PDT 2025
Author: Mehdi Amini
Date: 2025-08-18T09:07:19Z
New Revision: 16aa283344d3fec9467e6adb7ea511c0a88afacc
URL: https://github.com/llvm/llvm-project/commit/16aa283344d3fec9467e6adb7ea511c0a88afacc
DIFF: https://github.com/llvm/llvm-project/commit/16aa283344d3fec9467e6adb7ea511c0a88afacc.diff
LOG: [MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (#154037)
This is in preparation for a follow-up change that stops traversing
unreachable blocks.
This is not NFC because of a subtlety in when the iterator is
pre-incremented (the old walk relied on `make_early_inc_range`). On a test
case like:
```
scf.if %cond {
"test.move_after_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
}
```
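We recursively traverse the nested regions, and process an op once its
nested regions are done (post-order).
We need to pre-increment the iterator before processing an operation in
case it gets deleted. However,
we can do this either before or after processing the nested regions. This
implementation does the latter, so an op that gets moved after its parent
while the parent's nested regions are being processed (as in the test case
above) is now visited a second time.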
Added:
Modified:
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642c943c4..2111e29120567 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,12 +13,14 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+ // Iterator on all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op
+ // already to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
+ regionIt = region->begin();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ }
+ // Advance the iterator to the next reachable operation.
+ void advance() {
+ assert(regionIt != region->end());
+ hasVisitedRegions = false;
+ if (blockIt == regionIt->end()) {
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << "Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+ // The region we're iterating over.
+ Region *region;
+ // The Block currently being iterated over.
+ Region::iterator regionIt;
+ // The Operation currently being iterated over.
+ Block::iterator blockIt;
+ // Whether we've visited the nested regions of the current op already.
+ bool hasVisitedRegions = false;
+ };
+
+ // Worklist of regions to visit to drive the post-order traversal.
+ SmallVector<RegionReachableOpIterator> worklist;
+
+ LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
+ // Perform a post-order traversal of the regions, visiting each
+ // reachable operation.
for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
+ assert(worklist.empty());
+ if (region.empty())
+ continue;
+
+ // Prime the worklist with the entry block of this region.
+ worklist.push_back({&region});
+ while (!worklist.empty()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
+ for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion});
+ }
+ }
+ // If we're not at the back of the worklist, we've enqueued some
+ // nested region for processing. We'll come back to this op later
+ // (post-order)
+ if (&it != &worklist.back())
+ continue;
+
+ // Preemptively increment the iterator, in case the current op
+ // would be erased.
+ it.advance();
+
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- erasedListener.visitedOp = visitedOp;
+ erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ LDBG() << "\tOp matched and rewritten";
+ }
}
},
{op});
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60671c9b..c75c478ec3734 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// the moved op to be visited twice.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
More information about the Mlir-commits
mailing list