[Mlir-commits] [mlir] [mlir][IR] Add `notifyBlockRemoved` callback to listener (PR #78306)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 16 08:15:07 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

There is already a "block inserted" notification (`notifyBlockCreated` in `OpBuilder::Listener`), so there should also be a "block removed" notification.

The purpose of this change is to make the listener API more mature. There is currently a gap between the kinds of IR changes that can be made and the kinds of IR changes that listeners can observe. At the moment, the only way to inform listeners about a block removal is to send a manual `notifyOperationModified` for the parent op (e.g., by wrapping the `eraseBlock(b)` method call in `updateRootInPlace(b->getParentOp())`). This tells the listener that *something* has changed, but not what. It is also an abuse of the API, because the parent op itself was not modified.

---
Full diff: https://github.com/llvm/llvm-project/pull/78306.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/PatternMatch.h (+8) 
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-1) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+9-2) 


``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65bff49e1..6c5d317a79b7ab 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified block is about to be erased.
+    /// At this point, the block has zero uses.
+    virtual void notifyBlockRemoved(Block *block) {}
+
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
@@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
     void notifyBlockCreated(Block *block) override {
       listener->notifyBlockCreated(block);
     }
+    void notifyBlockRemoved(Block *block) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyBlockRemoved(block);
+    }
     void notifyOperationModified(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationModified(op);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index abb65bc3c38f22..b63baf330c8645 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
-  newLoop.getBody()->erase();
+  rewriter.eraseBlock(newLoop.getBody());
 
   newLoop.getRegion().getBlocks().splice(
       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e788cdb4897d3..de226f40cb3e9a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
           for (BlockArgument bbArg : b->getArguments())
             bbArg.dropAllUses();
           b->dropAllUses();
-          b->erase();
+          eraseBlock(b);
         }
       }
     }
@@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
 }
 
 void RewriterBase::eraseBlock(Block *block) {
+  assert(block->use_empty() && "expected 'block' to have no uses");
+
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
   }
+
+  // Notify the listener that the block is about to be removed.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyBlockRemoved(block);
+
   block->erase();
 }
 
@@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
   // Move operations from the source block to the dest block and erase the
   // source block.
   dest->getOperations().splice(before, source->getOperations());
-  source->erase();
+  eraseBlock(source);
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,

``````````

</details>


https://github.com/llvm/llvm-project/pull/78306


More information about the Mlir-commits mailing list