[Mlir-commits] [mlir] [MLIR] Fix ErasedOpsListener false positives for newly created ops/blocks (PR #192291)
Mehdi Amini
llvmlistbot at llvm.org
Wed Apr 15 10:14:21 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/192291
WalkPatternRewriteDriver's ErasedOpsListener incorrectly flagged erasures of ops/blocks that were created during the current pattern application. Since those ops/blocks were never part of the walk schedule, erasing them is safe.
Track newly inserted ops and blocks per visited op; skip the erasure check for them. Also fix two related issues:
- DropUnitDims: use a fresh IRRewriter (not inheriting the walk pattern rewriter's listener) for replaceUnitDimIndexOps on cloned ops.
- TestPatterns CloneRegionBeforeOp: wrap op->setAttr() in modifyOpInPlace to properly notify the rewriter of the in-place change.
Assisted-by: Claude Code
This fixes a failure that occurs when building with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
From 5f0b8bfc943af69f261090302aeed09eda11b183 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:39 -0700
Subject: [PATCH] [MLIR] Fix ErasedOpsListener false positives for newly
created ops/blocks
WalkPatternRewriteDriver's ErasedOpsListener incorrectly flagged erasures
of ops/blocks that were created during the current pattern application.
Since those ops/blocks were never in the walk schedule, erasing them is safe.
Track newly inserted ops and blocks per visited op; skip the erasure check
for them. Also fix two related issues:
- DropUnitDims: use a fresh IRRewriter (not inheriting the walk pattern
rewriter's listener) for replaceUnitDimIndexOps on cloned ops.
- TestPatterns CloneRegionBeforeOp: wrap op->setAttr() in modifyOpInPlace
to properly notify the rewriter of the in-place change.
Assisted-by: Claude Code
Co-Authored-By: Claude Sonnet 4.6 <noreply at anthropic.com>
This fixes a failure that occurs with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
---
.../Linalg/Transforms/DropUnitDims.cpp | 6 ++--
.../Utils/WalkPatternRewriteDriver.cpp | 32 +++++++++++++++++--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 3 +-
3 files changed, 36 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c3dca148b7f94..4bb4ad61cf8da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -578,8 +578,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
replacementOp.getRegion().begin());
// 5a. Replace `linalg.index` operations that refer to the dropped unit
- // dimensions.
- IRRewriter rewriter(b);
+ // dimensions. Use a fresh IRRewriter to avoid inheriting any listener
+ // from the builder (e.g., WalkPatternRewriter's erasure listener),
+ // since the ops being erased here are newly cloned, not the matched op.
+ IRRewriter rewriter(b.getContext());
replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
return replacementOp;
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 1382550e0f7e6..40fcb351ee079 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
@@ -61,16 +62,38 @@ struct WalkAndApplyPatternsAction final
// ops/blocks. Because we use walk-based pattern application, erasing the
// op/block from the *next* iteration (e.g., a user of the visited op) is not
// valid. Note that this is only used with expensive pattern API checks.
+//
+// Ops and blocks that were *created* during the current pattern application are
+// exempt: they were not in the walk schedule before the pattern ran, so erasing
+// them cannot disrupt the walk.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ if (visitedOp)
+ newlyCreatedOps.insert(op);
+ ForwardingListener::notifyOperationInserted(op, previous);
+ }
+
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override {
+ if (visitedOp)
+ newlyCreatedBlocks.insert(block);
+ ForwardingListener::notifyBlockInserted(block, previous, previousIt);
+ }
+
void notifyOperationErased(Operation *op) override {
- checkErasure(op);
+ if (!newlyCreatedOps.contains(op))
+ checkErasure(op);
+ newlyCreatedOps.erase(op);
ForwardingListener::notifyOperationErased(op);
}
void notifyBlockErased(Block *block) override {
- checkErasure(block->getParentOp());
+ if (!newlyCreatedBlocks.contains(block))
+ checkErasure(block->getParentOp());
+ newlyCreatedBlocks.erase(block);
ForwardingListener::notifyBlockErased(block);
}
@@ -86,6 +109,9 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener {
}
Operation *visitedOp = nullptr;
+ // Ops and blocks inserted since visitedOp was last set; may be freely erased.
+ DenseSet<Operation *> newlyCreatedOps;
+ DenseSet<Block *> newlyCreatedBlocks;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
} // namespace
@@ -204,6 +230,8 @@ void walkAndApplyPatterns(Operation *op,
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = op;
+ erasedListener.newlyCreatedOps.clear();
+ erasedListener.newlyCreatedBlocks.clear();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
LDBG() << "\tOp matched and rewritten";
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..1f02c284d924f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -383,7 +383,8 @@ struct CloneRegionBeforeOp : public RewritePattern {
return failure();
for (Region &r : op->getRegions())
rewriter.cloneRegionBefore(r, op->getBlock());
- op->setAttr("was_cloned", rewriter.getUnitAttr());
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("was_cloned", rewriter.getUnitAttr()); });
return success();
}
};
More information about the Mlir-commits mailing list