[Mlir-commits] [flang] [mlir] [mlir][IR] Change `notifyBlockCreated` to `notifyBlockInserted` (PR #79472)

Matthias Springer llvmlistbot at llvm.org
Fri Jan 26 01:29:39 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/79472

>From 8e048fed5e9b5e9161c1512abc50fcc23d68e4c7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 25 Jan 2024 17:09:57 +0000
Subject: [PATCH] [mlir][IR] Change `notifyBlockCreated` to
 `notifyBlockInserted`

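This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. In particular, the new callback is also triggered when an existing block is moved (e.g., by `RewriterBase::inlineRegionBefore`), in which case it reports the block's previous region and position. See also #78988.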

This change also simplifies the dialect conversion: `ConversionPatternRewriter` no longer needs to override `inlineRegionBefore`, because all information necessary for rollback is now provided by the `notifyBlockInserted` callback.
---
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       | 10 ++-
 mlir/include/mlir/IR/Builders.h               | 10 ++-
 mlir/include/mlir/IR/PatternMatch.h           |  9 +--
 .../mlir/Transforms/DialectConversion.h       |  8 +--
 mlir/lib/IR/Builders.cpp                      |  2 +-
 mlir/lib/IR/PatternMatch.cpp                  | 13 +++-
 .../Transforms/Utils/DialectConversion.cpp    | 66 +++++++------------
 .../Utils/GreedyPatternRewriteDriver.cpp      | 10 +--
 8 files changed, 67 insertions(+), 61 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 5fe78b7408026f8..a7eb6174eb423aa 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -735,9 +735,13 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
     builder.notifyOperationInserted(op, previous);
     rewriter.notifyOperationInserted(op, previous);
   }
-  virtual void notifyBlockCreated(mlir::Block *block) override {
-    builder.notifyBlockCreated(block);
-    rewriter.notifyBlockCreated(block);
+  virtual void notifyBlockInserted(mlir::Block *block, mlir::Region *previous,
+                                   mlir::Region::iterator previousIt) override {
+    // We only care about newly created blocks.
+    if (previous)
+      return;
+    builder.notifyBlockInserted(block, previous, previousIt);
+    rewriter.notifyBlockInserted(block, previous, previousIt);
   }
   fir::FirOpBuilder &builder;
   mlir::ConversionPatternRewriter &rewriter;
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 6b95be7c6d372f8..8c25a1aa2fad14a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -297,7 +297,15 @@ class OpBuilder : public Builder {
     virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
 
     /// Notify the listener that the specified block was inserted.
-    virtual void notifyBlockCreated(Block *block) {}
+    ///
+    /// * If the block was moved, then `previous` and `previousIt` are the
+    ///   previous location of the block.
+    /// * If the block was unlinked before it was inserted, then `previous`
+    ///   is `nullptr`.
+    ///
+    /// Note: Creating an (unlinked) block does not trigger this notification.
+    virtual void notifyBlockInserted(Block *block, Region *previous,
+                                     Region::iterator previousIt) {}
 
   protected:
     Listener(Kind kind) : ListenerBase(kind) {}
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 7f233cd3f4d4b3c..8eb129206b95ef6 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -455,8 +455,9 @@ class RewriterBase : public OpBuilder {
     void notifyOperationInserted(Operation *op, InsertPoint previous) override {
       listener->notifyOperationInserted(op, previous);
     }
-    void notifyBlockCreated(Block *block) override {
-      listener->notifyBlockCreated(block);
+    void notifyBlockInserted(Block *block, Region *previous,
+                             Region::iterator previousIt) override {
+      listener->notifyBlockInserted(block, previous, previousIt);
     }
     void notifyBlockRemoved(Block *block) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
@@ -495,8 +496,8 @@ class RewriterBase : public OpBuilder {
   /// another region "parent". The two regions must be different. The caller
   /// is responsible for creating or updating the operation transferring flow
   /// of control to the region and passing it the correct block arguments.
-  virtual void inlineRegionBefore(Region &region, Region &parent,
-                                  Region::iterator before);
+  void inlineRegionBefore(Region &region, Region &parent,
+                          Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
   /// Clone the blocks that belong to "region" before the given position in
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 32c5937d014e9ef..d9470de9ceb9f56 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -713,7 +713,8 @@ class ConversionPatternRewriter final : public PatternRewriter,
   void eraseBlock(Block *block) override;
 
   /// PatternRewriter hook creating a new block.
-  void notifyBlockCreated(Block *block) override;
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
@@ -723,11 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for moving blocks out of a region.
-  void inlineRegionBefore(Region &region, Region &parent,
-                          Region::iterator before) override;
-  using PatternRewriter::inlineRegionBefore;
-
   /// PatternRewriter hook for cloning blocks of one region into another. The
   /// given region to clone *must* not have been modified as part of conversion
   /// yet, i.e. it must be within an operation that is either in the process of
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index a319afcdc6a9a23..7acef1073c6de20 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -429,7 +429,7 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
   setInsertionPointToEnd(b);
 
   if (listener)
-    listener->notifyBlockCreated(b);
+    listener->notifyBlockInserted(b, /*previous=*/nullptr, /*previousIt=*/{});
   return b;
 }
 
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index affb8898fa07544..817bbb363e0d585 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -343,7 +343,18 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
 /// region and pass it the correct block arguments.
 void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
                                       Region::iterator before) {
-  parent.getBlocks().splice(before, region.getBlocks());
+  // Fast path: If no listener is attached, move all blocks at once.
+  if (!listener) {
+    parent.getBlocks().splice(before, region.getBlocks());
+    return;
+  }
+
+  // Move blocks from the beginning of the region one-by-one.
+  while (!region.empty()) {
+    Block *block = &region.front();
+    parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
+    listener->notifyBlockInserted(block, &region, region.begin());
+  }
 }
 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f5bede2b94f9cb2..a79e9076fc28faf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -250,10 +250,10 @@ enum class BlockActionKind {
 };
 
 /// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed after `insertAfterBlock`.
+/// actions, the block needs to be placed before `insertBeforeBlock`.
 struct BlockPosition {
   Region *region;
-  Block *insertAfterBlock;
+  Block *insertBeforeBlock;
 };
 
 /// Information needed to undo inlining actions.
@@ -910,7 +910,8 @@ struct ConversionPatternRewriterImpl {
   void notifyBlockIsBeingErased(Block *block);
 
   /// Notifies that a block was created.
-  void notifyCreatedBlock(Block *block);
+  void notifyInsertedBlock(Block *block, Region *previous,
+                           Region::iterator previousIt);
 
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
@@ -919,10 +920,6 @@ struct ConversionPatternRewriterImpl {
   void notifyBlockBeingInlined(Block *block, Block *srcBlock,
                                Block::iterator before);
 
-  /// Notifies that the blocks of a region are about to be moved.
-  void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
-                                        Region::iterator before);
-
   /// Notifies that a pattern match failed for the given reason.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -1173,10 +1170,9 @@ void ConversionPatternRewriterImpl::undoBlockActions(
     // Put the block (owned by action) back into its original position.
     case BlockActionKind::Erase: {
       auto &blockList = action.originalPosition.region->getBlocks();
-      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
-      blockList.insert((insertAfterBlock
-                            ? std::next(Region::iterator(insertAfterBlock))
-                            : blockList.begin()),
+      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
+      blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
+                                          : blockList.end()),
                        action.block);
       break;
     }
@@ -1196,10 +1192,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
     // Move the block back to its original position.
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
-      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
+      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
       originalRegion->getBlocks().splice(
-          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
-                            : originalRegion->end()),
+          (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
+                             : originalRegion->end()),
           action.block->getParent()->getBlocks(), action.block);
       break;
     }
@@ -1398,12 +1394,19 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
-  Block *origPrevBlock = block->getPrevNode();
-  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
+  Block *origNextBlock = block->getNextNode();
+  blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
 }
 
-void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
-  blockActions.push_back(BlockAction::getCreate(block));
+void ConversionPatternRewriterImpl::notifyInsertedBlock(
+    Block *block, Region *previous, Region::iterator previousIt) {
+  if (!previous) {
+    // This is a newly created block.
+    blockActions.push_back(BlockAction::getCreate(block));
+    return;
+  }
+  Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
+  blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
 }
 
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
@@ -1416,19 +1419,6 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
   blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
 }
 
-void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
-    Region &region, Region &parent, Region::iterator before) {
-  if (region.empty())
-    return;
-  Block *laterBlock = &region.back();
-  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
-    blockActions.push_back(
-        BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
-    laterBlock = &earlierBlock;
-  }
-  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
-}
-
 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -1551,8 +1541,9 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                            results);
 }
 
-void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
-  impl->notifyCreatedBlock(block);
+void ConversionPatternRewriter::notifyBlockInserted(
+    Block *block, Region *previous, Region::iterator previousIt) {
+  impl->notifyInsertedBlock(block, previous, previousIt);
 }
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
@@ -1582,13 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::inlineRegionBefore(Region &region,
-                                                   Region &parent,
-                                                   Region::iterator before) {
-  impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
-  PatternRewriter::inlineRegionBefore(region, parent, before);
-}
-
 void ConversionPatternRewriter::cloneRegionBefore(Region &region,
                                                   Region &parent,
                                                   Region::iterator before,
@@ -1600,7 +1584,7 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
 
   for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
     Block *cloned = mapping.lookup(&b);
-    impl->notifyCreatedBlock(cloned);
+    impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
     cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
         [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
   }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c27fee7a738eba0..543dab0f309136f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -377,8 +377,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// simplifications.
   void addOperandsToWorklist(ValueRange operands);
 
-  /// Notify the driver that the given block was created.
-  void notifyBlockCreated(Block *block) override;
+  /// Notify the driver that the given block was inserted.
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// Notify the driver that the given block is about to be removed.
   void notifyBlockRemoved(Block *block) override;
@@ -638,9 +639,10 @@ void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
     worklist.push(op);
 }
 
-void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
+void GreedyPatternRewriteDriver::notifyBlockInserted(
+    Block *block, Region *previous, Region::iterator previousIt) {
   if (config.listener)
-    config.listener->notifyBlockCreated(block);
+    config.listener->notifyBlockInserted(block, previous, previousIt);
 }
 
 void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {



More information about the Mlir-commits mailing list