[Mlir-commits] [mlir] [mlir][IR][NFC] Rename `notify*Removed` to `notify*Erased` (PR #82253)
Matthias Springer
llvmlistbot at llvm.org
Tue Feb 20 00:00:14 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82253
>From 584af073c376edfeb622177400e08da9d068eada Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 19 Feb 2024 14:52:33 +0000
Subject: [PATCH] [mlir][IR] Rename `notify*Removed` to `notify*Erased`
Rename the following listener callbacks:
* `notifyOperationRemoved` -> `notifyOperationErased`
* `notifyBlockRemoved` -> `notifyBlockErased`
The current callback names are misnomers: the callbacks are triggered when an operation/block is erased, not when it is merely removed (unlinked).
E.g.:
```c++
/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
///
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationErased(Operation *op) {}
```
This change is in preparation for adding listener support to the dialect conversion. The dialect conversion internally unlinks IR before erasing it at a later point in time. There is an important difference between "remove" (unlink the IR from its parent while keeping it alive) and "erase" (unlink and destroy it). Listener callback names should be accurate to avoid confusion.
---
.../Transform/IR/TransformInterfaces.h | 2 +-
mlir/include/mlir/IR/PatternMatch.h | 12 +--
.../Bufferization/Transforms/Bufferize.cpp | 2 +-
.../TransformOps/LinalgTransformOps.cpp | 4 +-
.../Transform/IR/TransformInterfaces.cpp | 2 +-
mlir/lib/IR/PatternMatch.cpp | 4 +-
.../Utils/GreedyPatternRewriteDriver.cpp | 24 +++---
.../test-strict-pattern-driver.mlir | 84 +++++++++----------
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 4 +-
mlir/test/lib/Transforms/TestConstantFold.cpp | 2 +-
10 files changed, 70 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2e096e1f552924..313cdc27f780a7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1012,7 +1012,7 @@ class TrackingListener : public RewriterBase::Listener,
private:
friend class TransformRewriter;
- void notifyOperationRemoved(Operation *op) override;
+ void notifyOperationErased(Operation *op) override;
void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
using Listener::notifyOperationReplaced;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b8aeea0d23475b..2ce3bc3fc2e783 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -404,7 +404,7 @@ class RewriterBase : public OpBuilder {
/// Notify the listener that the specified block is about to be erased.
/// At this point, the block has zero uses.
- virtual void notifyBlockRemoved(Block *block) {}
+ virtual void notifyBlockErased(Block *block) {}
/// Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationModified(Operation *op) {}
@@ -430,7 +430,7 @@ class RewriterBase : public OpBuilder {
/// At this point, the operation has zero uses.
///
/// Note: This notification is not triggered when unlinking an operation.
- virtual void notifyOperationRemoved(Operation *op) {}
+ virtual void notifyOperationErased(Operation *op) {}
/// Notify the listener that the pattern failed to match the given
/// operation, and provide a callback to populate a diagnostic with the
@@ -457,9 +457,9 @@ class RewriterBase : public OpBuilder {
Region::iterator previousIt) override {
listener->notifyBlockInserted(block, previous, previousIt);
}
- void notifyBlockRemoved(Block *block) override {
+ void notifyBlockErased(Block *block) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
- rewriteListener->notifyBlockRemoved(block);
+ rewriteListener->notifyBlockErased(block);
}
void notifyOperationModified(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
@@ -474,9 +474,9 @@ class RewriterBase : public OpBuilder {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyOperationReplaced(op, replacement);
}
- void notifyOperationRemoved(Operation *op) override {
+ void notifyOperationErased(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
- rewriteListener->notifyOperationRemoved(op);
+ rewriteListener->notifyOperationErased(op);
}
void notifyMatchFailure(
Location loc,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 208cbda3a9eb63..6a0ad66549965a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -369,7 +369,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
}
protected:
- void notifyOperationRemoved(Operation *op) override {
+ void notifyOperationErased(Operation *op) override {
erasedOps.insert(op);
// Erase if present.
toMemrefOps.erase(op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 585fd14b40d764..4ef8859fd5c430 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -250,8 +250,8 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
assert(inserted.second && "expected newly created op");
}
- void notifyOperationRemoved(Operation *op) override {
- ForwardingListener::notifyOperationRemoved(op);
+ void notifyOperationErased(Operation *op) override {
+ ForwardingListener::notifyOperationErased(op);
op->walk([&](Operation *op) { newOps.erase(op); });
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index a964c205b62e84..bb9f6fec452986 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1274,7 +1274,7 @@ void transform::TrackingListener::notifyMatchFailure(
});
}
-void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
+void transform::TrackingListener::notifyOperationErased(Operation *op) {
// TODO: Walk can be removed when D144193 has landed.
op->walk([&](Operation *op) {
// Remove mappings for result values.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9204733c99bab7..5ba5328f14b89e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -209,7 +209,7 @@ void RewriterBase::eraseOp(Operation *op) {
assert(mayBeGraphRegion(*op->getParentRegion()) &&
"expected that op has no uses");
#endif // NDEBUG
- rewriteListener->notifyOperationRemoved(op);
+ rewriteListener->notifyOperationErased(op);
// Explicitly drop all uses in case the op is in a graph region.
op->dropAllUses();
@@ -265,7 +265,7 @@ void RewriterBase::eraseBlock(Block *block) {
// Notify the listener that the block is about to be removed.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyBlockRemoved(block);
+ rewriteListener->notifyBlockErased(block);
block->erase();
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index bde8c290e774bc..51d2f5e01b7235 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -130,8 +130,8 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
/// Invalidate the finger print of the given op, i.e., remove it from the map.
void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
- void notifyBlockRemoved(Block *block) override {
- RewriterBase::ForwardingListener::notifyBlockRemoved(block);
+ void notifyBlockErased(Block *block) override {
+ RewriterBase::ForwardingListener::notifyBlockErased(block);
// The block structure (number of blocks, types of block arguments, etc.)
// is part of the fingerprint of the parent op.
@@ -152,8 +152,8 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
invalidateFingerPrint(op);
}
- void notifyOperationRemoved(Operation *op) override {
- RewriterBase::ForwardingListener::notifyOperationRemoved(op);
+ void notifyOperationErased(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationErased(op);
op->walk([this](Operation *op) { invalidateFingerPrint(op); });
}
@@ -345,7 +345,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
/// worklist.
- void notifyOperationRemoved(Operation *op) override;
+ void notifyOperationErased(Operation *op) override;
/// Notify the driver that the specified operation was replaced. Update the
/// worklist as needed: New users are added enqueued.
@@ -384,7 +384,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
Region::iterator previousIt) override;
/// Notify the driver that the given block is about to be removed.
- void notifyBlockRemoved(Block *block) override;
+ void notifyBlockErased(Block *block) override;
/// For debugging only: Notify the driver of a pattern match failure.
void
@@ -647,9 +647,9 @@ void GreedyPatternRewriteDriver::notifyBlockInserted(
config.listener->notifyBlockInserted(block, previous, previousIt);
}
-void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
+void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
if (config.listener)
- config.listener->notifyBlockRemoved(block);
+ config.listener->notifyBlockErased(block);
}
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
@@ -689,7 +689,7 @@ void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
}
}
-void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
+void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
LLVM_DEBUG({
logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
<< ")\n";
@@ -707,7 +707,7 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
#endif // NDEBUG
if (config.listener)
- config.listener->notifyOperationRemoved(op);
+ config.listener->notifyOperationErased(op);
addOperandsToWorklist(op->getOperands());
worklist.remove(op);
@@ -901,8 +901,8 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
private:
- void notifyOperationRemoved(Operation *op) override {
- GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+ void notifyOperationErased(Operation *op) override {
+ GreedyPatternRewriteDriver::notifyOperationErased(op);
if (survivingOps)
survivingOps->erase(op);
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 559561b34ceecb..c87444cba8e1a2 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -52,8 +52,8 @@ func.func @test_replace_with_new_op() {
// -----
// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
-// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
-// CHECK-EN: notifyOperationRemoved: test.erase_op
+// CHECK-EN: notifyOperationErased: test.replace_with_new_op
+// CHECK-EN: notifyOperationErased: test.erase_op
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: "test.replace_with_new_op"
@@ -91,10 +91,10 @@ func.func @test_trigger_rewrite_through_block() {
// -----
-// CHECK-AN: notifyOperationRemoved: test.foo_b
-// CHECK-AN: notifyOperationRemoved: test.foo_a
-// CHECK-AN: notifyOperationRemoved: test.graph_region
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: test.foo_b
+// CHECK-AN: notifyOperationErased: test.foo_a
+// CHECK-AN: notifyOperationErased: test.graph_region
+// CHECK-AN: notifyOperationErased: test.erase_op
// CHECK-AN-LABEL: func @test_remove_graph_region()
// CHECK-AN-NEXT: return
func.func @test_remove_graph_region() {
@@ -109,13 +109,13 @@ func.func @test_remove_graph_region() {
// -----
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.dummy_op
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.dummy_op
+// CHECK-AN: notifyOperationErased: test.erase_op
// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
// CHECK-AN-NEXT: return
func.func @test_remove_cyclic_blocks() {
@@ -134,14 +134,14 @@ func.func @test_remove_cyclic_blocks() {
// -----
-// CHECK-AN: notifyOperationRemoved: test.dummy_op
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: test.qux
-// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
-// CHECK-AN: notifyOperationRemoved: test.nested_dummy
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: test.dummy_op
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: test.qux
+// CHECK-AN: notifyOperationErased: test.qux_unreachable
+// CHECK-AN: notifyOperationErased: test.nested_dummy
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: test.erase_op
// CHECK-AN-LABEL: func @test_remove_dead_blocks()
// CHECK-AN-NEXT: return
func.func @test_remove_dead_blocks() {
@@ -169,20 +169,20 @@ func.func @test_remove_dead_blocks() {
// test.nested_* must be deleted before test.foo.
// test.bar must be deleted before test.foo.
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.nested_b
-// CHECK-AN: notifyOperationRemoved: test.nested_a
-// CHECK-AN: notifyOperationRemoved: test.nested_d
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.nested_e
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.nested_c
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.dummy_op
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.nested_b
+// CHECK-AN: notifyOperationErased: test.nested_a
+// CHECK-AN: notifyOperationErased: test.nested_d
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.nested_e
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.nested_c
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.dummy_op
+// CHECK-AN: notifyOperationErased: test.erase_op
// CHECK-AN-LABEL: func @test_remove_nested_ops()
// CHECK-AN-NEXT: return
func.func @test_remove_nested_ops() {
@@ -212,12 +212,12 @@ func.func @test_remove_nested_ops() {
// -----
-// CHECK-AN: notifyOperationRemoved: test.qux
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: cf.cond_br
+// CHECK-AN: notifyOperationErased: test.qux
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: cf.cond_br
// CHECK-AN-LABEL: func @test_remove_diamond(
// CHECK-AN-NEXT: return
func.func @test_remove_diamond(%c: i1) {
@@ -277,7 +277,7 @@ func.func @test_inline_block_before() {
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
-// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN: notifyOperationErased: test.split_block_here
// CHECK-AN-LABEL: func @test_split_block(
// CHECK-AN: "test.op_with_region"() ({
// CHECK-AN: test.op_1
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 17b2f29c45dbc9..2102a4ffabf7b8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -349,8 +349,8 @@ struct DumpNotifications : public RewriterBase::Listener {
}
}
}
- void notifyOperationRemoved(Operation *op) override {
- llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
+ void notifyOperationErased(Operation *op) override {
+ llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
}
};
diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp
index b145ee1fef82c6..c97ab9091cb66d 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -31,7 +31,7 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
OpBuilder::InsertPoint previous) override {
existingConstants.push_back(op);
}
- void notifyOperationRemoved(Operation *op) override {
+ void notifyOperationErased(Operation *op) override {
auto *it = llvm::find(existingConstants, op);
if (it != existingConstants.end())
existingConstants.erase(it);
More information about the Mlir-commits
mailing list