[Mlir-commits] [mlir] [mlir][IR] Add `notifyBlockRemoved` callback to listener (PR #78306)

Matthias Springer llvmlistbot at llvm.org
Tue Jan 16 08:19:13 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/78306

>From d661ff6e56f470263b6c6467d5600a3a93365fce Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 16 Jan 2024 16:12:47 +0000
Subject: [PATCH] [mlir][IR] Add `notifyBlockRemoved` callback to listener

There is already a "block inserted" notification (`notifyBlockCreated` in `OpBuilder::Listener`), so there should also be a "block removed" notification.

The purpose of this change is to make the listener API more mature. There is currently a gap between the kinds of IR changes that can be made and the IR changes that can be listened to. At the moment, the only way to inform listeners about a block removal is to send a manual `notifyOperationModified` for the parent op (e.g., by wrapping the `eraseBlock(b)` call in `updateRootInPlace(b->getParentOp())`). This tells the listener that *something* has changed, but it is somewhat of an API abuse.
---
 mlir/include/mlir/IR/PatternMatch.h                   |  8 ++++++++
 mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp       |  2 +-
 mlir/lib/IR/PatternMatch.cpp                          | 11 +++++++++--
 .../Transforms/Utils/GreedyPatternRewriteDriver.cpp   |  8 ++++++++
 4 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65bff49e12..6c5d317a79b7ab5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified block is about to be erased.
+    /// At this point, the block has zero uses.
+    virtual void notifyBlockRemoved(Block *block) {}
+
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
@@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
     void notifyBlockCreated(Block *block) override {
       listener->notifyBlockCreated(block);
     }
+    void notifyBlockRemoved(Block *block) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyBlockRemoved(block);
+    }
     void notifyOperationModified(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationModified(op);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index abb65bc3c38f221..b63baf330c86457 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
-  newLoop.getBody()->erase();
+  rewriter.eraseBlock(newLoop.getBody());
 
   newLoop.getRegion().getBlocks().splice(
       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e788cdb4897d3d..de226f40cb3e9ad 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
           for (BlockArgument bbArg : b->getArguments())
             bbArg.dropAllUses();
           b->dropAllUses();
-          b->erase();
+          eraseBlock(b);
         }
       }
     }
@@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
 }
 
 void RewriterBase::eraseBlock(Block *block) {
+  assert(block->use_empty() && "expected 'block' to have no uses");
+
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
   }
+
+  // Notify the listener that the block is about to be removed.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyBlockRemoved(block);
+
   block->erase();
 }
 
@@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
   // Move operations from the source block to the dest block and erase the
   // source block.
   dest->getOperations().splice(before, source->getOperations());
-  source->erase();
+  eraseBlock(source);
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36d63d62bf10fc2..ac73e82bfe92a69 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -373,6 +373,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the given block was created.
   void notifyBlockCreated(Block *block) override;
 
+  /// Notify the driver that the given block is about to be removed.
+  void notifyBlockRemoved(Block *block) override;
+
   /// For debugging only: Notify the driver of a pattern match failure.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -633,6 +636,11 @@ void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
     config.listener->notifyBlockCreated(block);
 }
 
+void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
+  if (config.listener)
+    config.listener->notifyBlockRemoved(block);
+}
+
 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op



More information about the Mlir-commits mailing list