[Mlir-commits] [mlir] 62bf771 - [mlir][IR] Add `notifyBlockRemoved` callback to listener (#78306)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 21 01:06:57 PST 2024


Author: Matthias Springer
Date: 2024-01-21T10:06:53+01:00
New Revision: 62bf7710ff295cc7bb0bb281c471ca0c91fd156e

URL: https://github.com/llvm/llvm-project/commit/62bf7710ff295cc7bb0bb281c471ca0c91fd156e
DIFF: https://github.com/llvm/llvm-project/commit/62bf7710ff295cc7bb0bb281c471ca0c91fd156e.diff

LOG: [mlir][IR] Add `notifyBlockRemoved` callback to listener (#78306)

There is already a "block inserted" notification (in
`OpBuilder::Listener`), so there should also be a "block removed"
notification.

The purpose of this change is to make the listener API more mature.
There is currently a gap between the kinds of IR changes that can be made
and the IR changes that can be listened to. At the moment, the only way to
inform listeners about "block removal" is to send a manual
`notifyOperationModified` for the parent op (e.g., by wrapping the
`eraseBlock(b)` method call in `updateRootInPlace(b->getParentOp())`).
This tells the listener that *something* has changed, but it is a mild
abuse of the API.

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b065d4e8d37689..815340c9185093 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified block is about to be erased.
+    /// At this point, the block has zero uses.
+    virtual void notifyBlockRemoved(Block *block) {}
+
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
@@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
     void notifyBlockCreated(Block *block) override {
       listener->notifyBlockCreated(block);
     }
+    void notifyBlockRemoved(Block *block) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyBlockRemoved(block);
+    }
     void notifyOperationModified(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationModified(op);

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index abb65bc3c38f22..b63baf330c8645 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
-  newLoop.getBody()->erase();
+  rewriter.eraseBlock(newLoop.getBody());
 
   newLoop.getRegion().getBlocks().splice(
       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 73f232fd0de01a..ba0516e0539b6c 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
           for (BlockArgument bbArg : b->getArguments())
             bbArg.dropAllUses();
           b->dropAllUses();
-          b->erase();
+          eraseBlock(b);
         }
       }
     }
@@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
 }
 
 void RewriterBase::eraseBlock(Block *block) {
+  assert(block->use_empty() && "expected 'block' to have no uses");
+
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
   }
+
+  // Notify the listener that the block is about to be removed.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyBlockRemoved(block);
+
   block->erase();
 }
 
@@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
   // Move operations from the source block to the dest block and erase the
   // source block.
   dest->getOperations().splice(before, source->getOperations());
-  source->erase();
+  eraseBlock(source);
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36d63d62bf10fc..ac73e82bfe92a6 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -373,6 +373,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the given block was created.
   void notifyBlockCreated(Block *block) override;
 
+  /// Notify the driver that the given block is about to be removed.
+  void notifyBlockRemoved(Block *block) override;
+
   /// For debugging only: Notify the driver of a pattern match failure.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -633,6 +636,11 @@ void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
     config.listener->notifyBlockCreated(block);
 }
 
+void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
+  if (config.listener)
+    config.listener->notifyBlockRemoved(block);
+}
+
 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op


        


More information about the Mlir-commits mailing list