[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect conversion: Add missing erasure notifications (PR #145030)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jun 20 05:29:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add missing listener notifications when erasing nested blocks/operations.
This commit also moves some of the functionality from `ConversionPatternRewriter` to `ConversionPatternRewriterImpl`. This is in preparation for the One-Shot Dialect Conversion refactoring: The implementations in `ConversionPatternRewriter` should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, `ConversionPatternRewriterImpl` can be bypassed to some degree, and `PatternRewriter::eraseBlock` etc. can be used.)
---
Full diff: https://github.com/llvm/llvm-project/pull/145030.diff
2 Files Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-19)
- (modified) mlir/test/Transforms/test-legalizer.mlir (+16-2)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
// IR rewrites
//===----------------------------------------------------------------------===//
+// Forward declaration: the Block and Operation overloads below are mutually
+// recursive.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+///
+/// Every operation in `b` is reported first, in block order (each one together
+/// with the IR nested inside it), and `notifyBlockErased(&b)` is sent last, so
+/// nested IR is always reported before the block that contains it. This only
+/// sends notifications; nothing is erased or unlinked here. `listener` is
+/// dereferenced without a null check, so callers must pass a non-null listener.
+/// NOTE(review): ops are notified in forward block order, while
+/// EraseBlockRewrite::cleanup erases them in reverse order; confirm listeners
+/// do not rely on the two orders matching.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+ for (Operation &op : b)
+ notifyIRErased(listener, op);
+ listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+///
+/// Recurses into every block of every region of `op` first and sends
+/// `notifyOperationErased(&op)` last, i.e. notifications are issued in
+/// post-order (innermost IR before its parent op). This only sends
+/// notifications; nothing is erased or unlinked here. `listener` must be
+/// non-null.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+ for (Region &r : op.getRegions()) {
+ for (Block &b : r) {
+ notifyIRErased(listener, b);
+ }
+ }
+ listener->notifyOperationErased(&op);
+}
+
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
}
void commit(RewriterBase &rewriter) override {
- // Erase the block.
assert(block && "expected block");
- assert(block->empty() && "expected empty block");
- // Notify the listener that the block is about to be erased.
+ // Notify the listener that the block and its contents are being erased.
if (auto *listener =
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
- listener->notifyBlockErased(block);
+ notifyIRErased(listener, *block);
}
void cleanup(RewriterBase &rewriter) override {
+ // Erase the contents of the block.
+ for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+ rewriter.eraseOp(&op);
+ assert(block->empty() && "expected empty block");
+
// Erase the block.
block->dropAllDefinedValueUses();
delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);
- // Notify the listener that the operation (and its nested operations) was
- // erased.
- if (listener) {
- op->walk<WalkOrder::PostOrder>(
- [&](Operation *op) { listener->notifyOperationErased(op); });
- }
+ // Notify the listener that the operation and its contents are being erased.
+ if (listener)
+ notifyIRErased(listener, *op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
// Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ assert(!wasOpReplaced(block->getParentOp()) &&
+ "attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
// Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
block->getParent()->getBlocks().remove(block);
+
+ // Mark all nested ops as erased.
+ block->walk([&](Operation *op) { replacedOps.insert(op); });
}
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ assert(
+ (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+ "attempting to insert into a region within a replaced/erased op");
LLVM_DEBUG(
{
Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
}
});
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed. No extra bookkeeping is needed.
+ return;
+ }
+
if (!previous) {
// This is a newly created block.
appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- assert(!impl->wasOpReplaced(block->getParentOp()) &&
- "attempting to erase a block within a replaced/erased op");
-
- // Mark all ops for erasure.
- for (Operation &op : *block)
- eraseOp(&op);
-
impl->eraseBlock(block);
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
// -----
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
// CHECK-LABEL: func @circular_mapping()
// CHECK-NEXT: "test.valid"() : () -> ()
func.func @circular_mapping() {
// Regression test that used to crash due to circular
- // unrealized_conversion_cast ops.
- %0 = "test.erase_op"() : () -> (i64)
+ // unrealized_conversion_cast ops.
+ %0 = "test.erase_op"() ({
+ "test.dummy_op_lvl_1"() ({
+ "test.dummy_op_lvl_2"() : () -> ()
+ }) : () -> ()
+ }): () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/145030
More information about the llvm-branch-commits mailing list