[Mlir-commits] [mlir] [mlir][IR] Rename `notify*Removed` to `notify*Erased` (PR #82253)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 19 06:54:09 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Rename listener callback names:
* `notifyOperationRemoved` -> `notifyOperationErased`
* `notifyBlockRemoved` -> `notifyBlockErased`

The current callback names are misnomers. The callbacks are triggered when an operation/block is erased, not when it is removed (unlinked).

E.g.:
```c++
/// Notify the listener that the specified operation is about to be erased.
/// At this point, the operation has zero uses.
///
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationErased(Operation *op) {}
```

This change is in preparation for adding listener support to the dialect conversion. The dialect conversion internally unlinks IR before erasing it at a later point in time. There is an important difference between "remove" and "erase". Listener callback names should be accurate to avoid confusion.

---
Full diff: https://github.com/llvm/llvm-project/pull/82253.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+1-1) 
- (modified) mlir/include/mlir/IR/PatternMatch.h (+6-6) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+1-1) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+2-2) 
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+12-12) 
- (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+42-42) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+2-2) 
- (modified) mlir/test/lib/Transforms/TestConstantFold.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 2e096e1f552924..313cdc27f780a7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1012,7 +1012,7 @@ class TrackingListener : public RewriterBase::Listener,
 private:
   friend class TransformRewriter;
 
-  void notifyOperationRemoved(Operation *op) override;
+  void notifyOperationErased(Operation *op) override;
 
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
   using Listener::notifyOperationReplaced;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b8aeea0d23475b..2ce3bc3fc2e783 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -404,7 +404,7 @@ class RewriterBase : public OpBuilder {
 
     /// Notify the listener that the specified block is about to be erased.
     /// At this point, the block has zero uses.
-    virtual void notifyBlockRemoved(Block *block) {}
+    virtual void notifyBlockErased(Block *block) {}
 
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
@@ -430,7 +430,7 @@ class RewriterBase : public OpBuilder {
     /// At this point, the operation has zero uses.
     ///
     /// Note: This notification is not triggered when unlinking an operation.
-    virtual void notifyOperationRemoved(Operation *op) {}
+    virtual void notifyOperationErased(Operation *op) {}
 
     /// Notify the listener that the pattern failed to match the given
     /// operation, and provide a callback to populate a diagnostic with the
@@ -457,9 +457,9 @@ class RewriterBase : public OpBuilder {
                              Region::iterator previousIt) override {
       listener->notifyBlockInserted(block, previous, previousIt);
     }
-    void notifyBlockRemoved(Block *block) override {
+    void notifyBlockErased(Block *block) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
-        rewriteListener->notifyBlockRemoved(block);
+        rewriteListener->notifyBlockErased(block);
     }
     void notifyOperationModified(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
@@ -474,9 +474,9 @@ class RewriterBase : public OpBuilder {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationReplaced(op, replacement);
     }
-    void notifyOperationRemoved(Operation *op) override {
+    void notifyOperationErased(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
-        rewriteListener->notifyOperationRemoved(op);
+        rewriteListener->notifyOperationErased(op);
     }
     void notifyMatchFailure(
         Location loc,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 208cbda3a9eb63..6a0ad66549965a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -369,7 +369,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
   }
 
 protected:
-  void notifyOperationRemoved(Operation *op) override {
+  void notifyOperationErased(Operation *op) override {
     erasedOps.insert(op);
     // Erase if present.
     toMemrefOps.erase(op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 585fd14b40d764..4ef8859fd5c430 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -250,8 +250,8 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
     assert(inserted.second && "expected newly created op");
   }
 
-  void notifyOperationRemoved(Operation *op) override {
-    ForwardingListener::notifyOperationRemoved(op);
+  void notifyOperationErased(Operation *op) override {
+    ForwardingListener::notifyOperationErased(op);
     op->walk([&](Operation *op) { newOps.erase(op); });
   }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index a964c205b62e84..bb9f6fec452986 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1274,7 +1274,7 @@ void transform::TrackingListener::notifyMatchFailure(
   });
 }
 
-void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
+void transform::TrackingListener::notifyOperationErased(Operation *op) {
   // TODO: Walk can be removed when D144193 has landed.
   op->walk([&](Operation *op) {
     // Remove mappings for result values.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9204733c99bab7..5ba5328f14b89e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -209,7 +209,7 @@ void RewriterBase::eraseOp(Operation *op) {
       assert(mayBeGraphRegion(*op->getParentRegion()) &&
              "expected that op has no uses");
 #endif // NDEBUG
-    rewriteListener->notifyOperationRemoved(op);
+    rewriteListener->notifyOperationErased(op);
 
     // Explicitly drop all uses in case the op is in a graph region.
     op->dropAllUses();
@@ -265,7 +265,7 @@ void RewriterBase::eraseBlock(Block *block) {
 
   // Notify the listener that the block is about to be removed.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyBlockRemoved(block);
+    rewriteListener->notifyBlockErased(block);
 
   block->erase();
 }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index bde8c290e774bc..51d2f5e01b7235 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -130,8 +130,8 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
   /// Invalidate the finger print of the given op, i.e., remove it from the map.
   void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
 
-  void notifyBlockRemoved(Block *block) override {
-    RewriterBase::ForwardingListener::notifyBlockRemoved(block);
+  void notifyBlockErased(Block *block) override {
+    RewriterBase::ForwardingListener::notifyBlockErased(block);
 
     // The block structure (number of blocks, types of block arguments, etc.)
     // is part of the fingerprint of the parent op.
@@ -152,8 +152,8 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
     invalidateFingerPrint(op);
   }
 
-  void notifyOperationRemoved(Operation *op) override {
-    RewriterBase::ForwardingListener::notifyOperationRemoved(op);
+  void notifyOperationErased(Operation *op) override {
+    RewriterBase::ForwardingListener::notifyOperationErased(op);
     op->walk([this](Operation *op) { invalidateFingerPrint(op); });
   }
 
@@ -345,7 +345,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the specified operation was removed. Update the
   /// worklist as needed: The operation and its children are removed from the
   /// worklist.
-  void notifyOperationRemoved(Operation *op) override;
+  void notifyOperationErased(Operation *op) override;
 
   /// Notify the driver that the specified operation was replaced. Update the
   /// worklist as needed: New users are added enqueued.
@@ -384,7 +384,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
                            Region::iterator previousIt) override;
 
   /// Notify the driver that the given block is about to be removed.
-  void notifyBlockRemoved(Block *block) override;
+  void notifyBlockErased(Block *block) override;
 
   /// For debugging only: Notify the driver of a pattern match failure.
   void
@@ -647,9 +647,9 @@ void GreedyPatternRewriteDriver::notifyBlockInserted(
     config.listener->notifyBlockInserted(block, previous, previousIt);
 }
 
-void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
+void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
   if (config.listener)
-    config.listener->notifyBlockRemoved(block);
+    config.listener->notifyBlockErased(block);
 }
 
 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
@@ -689,7 +689,7 @@ void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
   }
 }
 
-void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
+void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
                        << ")\n";
@@ -707,7 +707,7 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
 #endif // NDEBUG
 
   if (config.listener)
-    config.listener->notifyOperationRemoved(op);
+    config.listener->notifyOperationErased(op);
 
   addOperandsToWorklist(op->getOperands());
   worklist.remove(op);
@@ -901,8 +901,8 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
 
 private:
-  void notifyOperationRemoved(Operation *op) override {
-    GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+  void notifyOperationErased(Operation *op) override {
+    GreedyPatternRewriteDriver::notifyOperationErased(op);
     if (survivingOps)
       survivingOps->erase(op);
   }
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 559561b34ceecb..c87444cba8e1a2 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -52,8 +52,8 @@ func.func @test_replace_with_new_op() {
 // -----
 
 // CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
-// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
-// CHECK-EN: notifyOperationRemoved: test.erase_op
+// CHECK-EN: notifyOperationErased: test.replace_with_new_op
+// CHECK-EN: notifyOperationErased: test.erase_op
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //   CHECK-EN-NOT:   "test.replace_with_new_op"
@@ -91,10 +91,10 @@ func.func @test_trigger_rewrite_through_block() {
 
 // -----
 
-// CHECK-AN: notifyOperationRemoved: test.foo_b
-// CHECK-AN: notifyOperationRemoved: test.foo_a
-// CHECK-AN: notifyOperationRemoved: test.graph_region
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: test.foo_b
+// CHECK-AN: notifyOperationErased: test.foo_a
+// CHECK-AN: notifyOperationErased: test.graph_region
+// CHECK-AN: notifyOperationErased: test.erase_op
 // CHECK-AN-LABEL: func @test_remove_graph_region()
 //  CHECK-AN-NEXT:   return
 func.func @test_remove_graph_region() {
@@ -109,13 +109,13 @@ func.func @test_remove_graph_region() {
 
 // -----
 
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.dummy_op
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.dummy_op
+// CHECK-AN: notifyOperationErased: test.erase_op
 // CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
 //  CHECK-AN-NEXT:   return
 func.func @test_remove_cyclic_blocks() {
@@ -134,14 +134,14 @@ func.func @test_remove_cyclic_blocks() {
 
 // -----
 
-// CHECK-AN: notifyOperationRemoved: test.dummy_op
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: test.qux
-// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
-// CHECK-AN: notifyOperationRemoved: test.nested_dummy
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: test.dummy_op
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: test.qux
+// CHECK-AN: notifyOperationErased: test.qux_unreachable
+// CHECK-AN: notifyOperationErased: test.nested_dummy
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: test.erase_op
 // CHECK-AN-LABEL: func @test_remove_dead_blocks()
 //  CHECK-AN-NEXT:   return
 func.func @test_remove_dead_blocks() {
@@ -169,20 +169,20 @@ func.func @test_remove_dead_blocks() {
 // test.nested_* must be deleted before test.foo.
 // test.bar must be deleted before test.foo.
 
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.nested_b
-// CHECK-AN: notifyOperationRemoved: test.nested_a
-// CHECK-AN: notifyOperationRemoved: test.nested_d
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.nested_e
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.nested_c
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.dummy_op
-// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.nested_b
+// CHECK-AN: notifyOperationErased: test.nested_a
+// CHECK-AN: notifyOperationErased: test.nested_d
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.nested_e
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.nested_c
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.dummy_op
+// CHECK-AN: notifyOperationErased: test.erase_op
 // CHECK-AN-LABEL: func @test_remove_nested_ops()
 //  CHECK-AN-NEXT:   return
 func.func @test_remove_nested_ops() {
@@ -212,12 +212,12 @@ func.func @test_remove_nested_ops() {
 
 // -----
 
-// CHECK-AN: notifyOperationRemoved: test.qux
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.foo
-// CHECK-AN: notifyOperationRemoved: cf.br
-// CHECK-AN: notifyOperationRemoved: test.bar
-// CHECK-AN: notifyOperationRemoved: cf.cond_br
+// CHECK-AN: notifyOperationErased: test.qux
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.foo
+// CHECK-AN: notifyOperationErased: cf.br
+// CHECK-AN: notifyOperationErased: test.bar
+// CHECK-AN: notifyOperationErased: cf.cond_br
 // CHECK-AN-LABEL: func @test_remove_diamond(
 //  CHECK-AN-NEXT:   return
 func.func @test_remove_diamond(%c: i1) {
@@ -277,7 +277,7 @@ func.func @test_inline_block_before() {
 // CHECK-AN: notifyOperationInserted: test.op_2, was last in block
 // CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
 // CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
-// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN: notifyOperationErased: test.split_block_here
 // CHECK-AN-LABEL: func @test_split_block(
 //       CHECK-AN:   "test.op_with_region"() ({
 //       CHECK-AN:     test.op_1
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 1c02232b8adbb1..875ba892463c11 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -349,8 +349,8 @@ struct DumpNotifications : public RewriterBase::Listener {
       }
     }
   }
-  void notifyOperationRemoved(Operation *op) override {
-    llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
+  void notifyOperationErased(Operation *op) override {
+    llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
   }
 };
 
diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp
index 81f634d8d3ef78..87aaeb0d6fc0c7 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -31,7 +31,7 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
                                OpBuilder::InsertPoint previous) override {
     existingConstants.push_back(op);
   }
-  void notifyOperationRemoved(Operation *op) override {
+  void notifyOperationErased(Operation *op) override {
     auto it = llvm::find(existingConstants, op);
     if (it != existingConstants.end())
       existingConstants.erase(it);

``````````

</details>


https://github.com/llvm/llvm-project/pull/82253


More information about the Mlir-commits mailing list