[Mlir-commits] [mlir] [MLIR] Fix ErasedOpsListener false positives for newly created ops/blocks (PR #192291)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 10:14:57 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

WalkPatternRewriteDriver's ErasedOpsListener incorrectly flagged erasures of ops/blocks that were created during the current pattern application. Since those ops and blocks were never in the walk schedule, erasing them is safe.

Track newly inserted ops and blocks per visited op; skip the erasure check for them. Also fix two related issues:
- DropUnitDims: use a fresh IRRewriter (not inheriting the walk pattern rewriter's listener) for replaceUnitDimIndexOps on cloned ops.
- TestPatterns CloneRegionBeforeOp: wrap op->setAttr() in modifyOpInPlace to properly notify the rewriter of the in-place change.

Assisted-by: Claude Code

This fixes a failure that occurs when MLIR is built with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.

---
Full diff: https://github.com/llvm/llvm-project/pull/192291.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+4-2) 
- (modified) mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp (+30-2) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c3dca148b7f94..4bb4ad61cf8da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -578,8 +578,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
     b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                         replacementOp.getRegion().begin());
     // 5a. Replace `linalg.index` operations that refer to the dropped unit
-    //     dimensions.
-    IRRewriter rewriter(b);
+    //     dimensions. Use a fresh IRRewriter to avoid inheriting any listener
+    //     from the builder (e.g., WalkPatternRewriter's erasure listener),
+    //     since the ops being erased here are newly cloned, not the matched op.
+    IRRewriter rewriter(b.getContext());
     replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);
 
     return replacementOp;
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 1382550e0f7e6..40fcb351ee079 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -61,16 +62,38 @@ struct WalkAndApplyPatternsAction final
 // ops/blocks. Because we use walk-based pattern application, erasing the
 // op/block from the *next* iteration (e.g., a user of the visited op) is not
 // valid. Note that this is only used with expensive pattern API checks.
+//
+// Ops and blocks that were *created* during the current pattern application are
+// exempt: they were not in the walk schedule before the pattern ran, so erasing
+// them cannot disrupt the walk.
 struct ErasedOpsListener final : RewriterBase::ForwardingListener {
   using RewriterBase::ForwardingListener::ForwardingListener;
 
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    if (visitedOp)
+      newlyCreatedOps.insert(op);
+    ForwardingListener::notifyOperationInserted(op, previous);
+  }
+
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override {
+    if (visitedOp)
+      newlyCreatedBlocks.insert(block);
+    ForwardingListener::notifyBlockInserted(block, previous, previousIt);
+  }
+
   void notifyOperationErased(Operation *op) override {
-    checkErasure(op);
+    if (!newlyCreatedOps.contains(op))
+      checkErasure(op);
+    newlyCreatedOps.erase(op);
     ForwardingListener::notifyOperationErased(op);
   }
 
   void notifyBlockErased(Block *block) override {
-    checkErasure(block->getParentOp());
+    if (!newlyCreatedBlocks.contains(block))
+      checkErasure(block->getParentOp());
+    newlyCreatedBlocks.erase(block);
     ForwardingListener::notifyBlockErased(block);
   }
 
@@ -86,6 +109,9 @@ struct ErasedOpsListener final : RewriterBase::ForwardingListener {
   }
 
   Operation *visitedOp = nullptr;
+  // Ops and blocks inserted since visitedOp was last set; may be freely erased.
+  DenseSet<Operation *> newlyCreatedOps;
+  DenseSet<Block *> newlyCreatedBlocks;
 };
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 } // namespace
@@ -204,6 +230,8 @@ void walkAndApplyPatterns(Operation *op,
                    << OpWithFlags(op, OpPrintingFlags().skipRegions());
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             erasedListener.visitedOp = op;
+            erasedListener.newlyCreatedOps.clear();
+            erasedListener.newlyCreatedBlocks.clear();
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             if (succeeded(applicator.matchAndRewrite(op, rewriter)))
               LDBG() << "\tOp matched and rewritten";
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c8be4bf3f0f8d..1f02c284d924f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -383,7 +383,8 @@ struct CloneRegionBeforeOp : public RewritePattern {
       return failure();
     for (Region &r : op->getRegions())
       rewriter.cloneRegionBefore(r, op->getBlock());
-    op->setAttr("was_cloned", rewriter.getUnitAttr());
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setAttr("was_cloned", rewriter.getUnitAttr()); });
     return success();
   }
 };

``````````

</details>


https://github.com/llvm/llvm-project/pull/192291


More information about the Mlir-commits mailing list