[Mlir-commits] [mlir] [mlir][IR] Change `notifyBlockCreated` to `notifyBlockInserted` (PR #79472)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 25 09:12:50 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This change makes the callback consistent with `notifyOperationInserted`: both now notify about IR insertion, not IR creation. See also #78988.
This change also simplifies the dialect conversion: it is no longer necessary to override the `inlineRegionBefore` method. All information that is necessary for rollback is provided with the `notifyBlockInserted` callback.
---
Full diff: https://github.com/llvm/llvm-project/pull/79472.diff
7 Files Affected:
- (modified) mlir/include/mlir/IR/Builders.h (+9-1)
- (modified) mlir/include/mlir/IR/PatternMatch.h (+5-4)
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+2-6)
- (modified) mlir/lib/IR/Builders.cpp (+1-1)
- (modified) mlir/lib/IR/PatternMatch.cpp (+12-1)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+25-41)
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+6-4)
``````````diff
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 6b95be7c6d372f8..8c25a1aa2fad14a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -297,7 +297,15 @@ class OpBuilder : public Builder {
virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
/// Notify the listener that the specified block was inserted.
- virtual void notifyBlockCreated(Block *block) {}
+ ///
+ /// * If the block was moved, then `previous` and `previousIt` are the
+ /// previous location of the block.
+ /// * If the block was unlinked before it was inserted, then `previous`
+ /// is "nullptr".
+ ///
+ /// Note: Creating an (unlinked) block does not trigger this notification.
+ virtual void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) {}
protected:
Listener(Kind kind) : ListenerBase(kind) {}
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 7f233cd3f4d4b3c..8eb129206b95ef6 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -455,8 +455,9 @@ class RewriterBase : public OpBuilder {
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
}
- void notifyBlockCreated(Block *block) override {
- listener->notifyBlockCreated(block);
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override {
+ listener->notifyBlockInserted(block, previous, previousIt);
}
void notifyBlockRemoved(Block *block) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
@@ -495,8 +496,8 @@ class RewriterBase : public OpBuilder {
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
/// of control to the region and passing it the correct block arguments.
- virtual void inlineRegionBefore(Region &region, Region &parent,
- Region::iterator before);
+ void inlineRegionBefore(Region &region, Region &parent,
+ Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
/// Clone the blocks that belong to "region" before the given position in
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 32c5937d014e9ef..d9470de9ceb9f56 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -713,7 +713,8 @@ class ConversionPatternRewriter final : public PatternRewriter,
void eraseBlock(Block *block) override;
/// PatternRewriter hook creating a new block.
- void notifyBlockCreated(Block *block) override;
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override;
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
@@ -723,11 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;
- /// PatternRewriter hook for moving blocks out of a region.
- void inlineRegionBefore(Region &region, Region &parent,
- Region::iterator before) override;
- using PatternRewriter::inlineRegionBefore;
-
/// PatternRewriter hook for cloning blocks of one region into another. The
/// given region to clone *must* not have been modified as part of conversion
/// yet, i.e. it must be within an operation that is either in the process of
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index a319afcdc6a9a23..7acef1073c6de20 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -429,7 +429,7 @@ Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
setInsertionPointToEnd(b);
if (listener)
- listener->notifyBlockCreated(b);
+ listener->notifyBlockInserted(b, /*previous=*/nullptr, /*previousIt=*/{});
return b;
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index affb8898fa07544..817bbb363e0d585 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -343,7 +343,18 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
/// region and pass it the correct block arguments.
void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
Region::iterator before) {
- parent.getBlocks().splice(before, region.getBlocks());
+ // Fast path: If no listener is attached, move all blocks at once.
+ if (!listener) {
+ parent.getBlocks().splice(before, region.getBlocks());
+ return;
+ }
+
+ // Move blocks from the beginning of the region one-by-one.
+ while (!region.empty()) {
+ Block *block = &region.front();
+ parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
+ listener->notifyBlockInserted(block, &region, region.begin());
+ }
}
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f5bede2b94f9cb2..a79e9076fc28faf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -250,10 +250,10 @@ enum class BlockActionKind {
};
/// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed after `insertAfterBlock`.
+/// actions, the block needs to be placed before `insertBeforeBlock`.
struct BlockPosition {
Region *region;
- Block *insertAfterBlock;
+ Block *insertBeforeBlock;
};
/// Information needed to undo inlining actions.
@@ -910,7 +910,8 @@ struct ConversionPatternRewriterImpl {
void notifyBlockIsBeingErased(Block *block);
/// Notifies that a block was created.
- void notifyCreatedBlock(Block *block);
+ void notifyInsertedBlock(Block *block, Region *previous,
+ Region::iterator previousIt);
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
@@ -919,10 +920,6 @@ struct ConversionPatternRewriterImpl {
void notifyBlockBeingInlined(Block *block, Block *srcBlock,
Block::iterator before);
- /// Notifies that the blocks of a region are about to be moved.
- void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
- Region::iterator before);
-
/// Notifies that a pattern match failed for the given reason.
LogicalResult
notifyMatchFailure(Location loc,
@@ -1173,10 +1170,9 @@ void ConversionPatternRewriterImpl::undoBlockActions(
// Put the block (owned by action) back into its original position.
case BlockActionKind::Erase: {
auto &blockList = action.originalPosition.region->getBlocks();
- Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
- blockList.insert((insertAfterBlock
- ? std::next(Region::iterator(insertAfterBlock))
- : blockList.begin()),
+ Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
+ blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
+ : blockList.end()),
action.block);
break;
}
@@ -1196,10 +1192,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
- Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
+ Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
originalRegion->getBlocks().splice(
- (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
- : originalRegion->end()),
+ (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
+ : originalRegion->end()),
action.block->getParent()->getBlocks(), action.block);
break;
}
@@ -1398,12 +1394,19 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
Region *region = block->getParent();
- Block *origPrevBlock = block->getPrevNode();
- blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
+ Block *origNextBlock = block->getNextNode();
+ blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
}
-void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
- blockActions.push_back(BlockAction::getCreate(block));
+void ConversionPatternRewriterImpl::notifyInsertedBlock(
+ Block *block, Region *previous, Region::iterator previousIt) {
+ if (!previous) {
+ // This is a newly created block.
+ blockActions.push_back(BlockAction::getCreate(block));
+ return;
+ }
+ Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
+ blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
}
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
@@ -1416,19 +1419,6 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
}
-void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
- Region &region, Region &parent, Region::iterator before) {
- if (region.empty())
- return;
- Block *laterBlock = &region.back();
- for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
- blockActions.push_back(
- BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
- laterBlock = &earlierBlock;
- }
- blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
-}
-
LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
@@ -1551,8 +1541,9 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}
-void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
- impl->notifyCreatedBlock(block);
+void ConversionPatternRewriter::notifyBlockInserted(
+ Block *block, Region *previous, Region::iterator previousIt) {
+ impl->notifyInsertedBlock(block, previous, previousIt);
}
Block *ConversionPatternRewriter::splitBlock(Block *block,
@@ -1582,13 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}
-void ConversionPatternRewriter::inlineRegionBefore(Region &region,
- Region &parent,
- Region::iterator before) {
- impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
- PatternRewriter::inlineRegionBefore(region, parent, before);
-}
-
void ConversionPatternRewriter::cloneRegionBefore(Region &region,
Region &parent,
Region::iterator before,
@@ -1600,7 +1584,7 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
Block *cloned = mapping.lookup(&b);
- impl->notifyCreatedBlock(cloned);
+ impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
[&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c27fee7a738eba0..543dab0f309136f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -377,8 +377,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// simplifications.
void addOperandsToWorklist(ValueRange operands);
- /// Notify the driver that the given block was created.
- void notifyBlockCreated(Block *block) override;
+ /// Notify the driver that the given block was inserted.
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override;
/// Notify the driver that the given block is about to be removed.
void notifyBlockRemoved(Block *block) override;
@@ -638,9 +639,10 @@ void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
worklist.push(op);
}
-void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
+void GreedyPatternRewriteDriver::notifyBlockInserted(
+ Block *block, Region *previous, Region::iterator previousIt) {
if (config.listener)
- config.listener->notifyBlockCreated(block);
+ config.listener->notifyBlockInserted(block, previous, previousIt);
}
void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/79472
More information about the Mlir-commits
mailing list