[Mlir-commits] [mlir] 42c31d8 - [mlir][IR] Clean up mergeBlockBefore and mergeBlocks
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 6 04:53:41 PST 2023
Author: Matthias Springer
Date: 2023-03-06T13:46:08+01:00
New Revision: 42c31d8302ff8f716601df9c276a5cd9ace6b158
URL: https://github.com/llvm/llvm-project/commit/42c31d8302ff8f716601df9c276a5cd9ace6b158
DIFF: https://github.com/llvm/llvm-project/commit/42c31d8302ff8f716601df9c276a5cd9ace6b158.diff
LOG: [mlir][IR] Clean up mergeBlockBefore and mergeBlocks
* `RewriterBase::mergeBlocks` is simplified: it is now implemented in terms of `inlineBlockBefore` (the renamed `mergeBlockBefore`).
* The signature of `mergeBlockBefore` is made consistent with other APIs (such as `inlineRegionBefore`): an overload taking a destination `Block *` and a `Block::iterator` insertion point is added.
* Additional safety checks are added to `mergeBlockBefore`: detect cases where the resulting IR could be invalid (remaining uses of the source block are no longer silently dropped via `dropAllUses`) or partly unreachable (likely a case of incorrect API usage).
* Rename `mergeBlockBefore` to `inlineBlockBefore`.
Differential Revision: https://reviews.llvm.org/D144969
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ed431badf05b0..9c4790c031201 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -491,17 +491,36 @@ class RewriterBase : public OpBuilder {
/// This method erases all operations in a block.
virtual void eraseBlock(Block *block);
- /// Merge the operations of block 'source' into the end of block 'dest'.
- /// 'source's predecessors must either be empty or only contain 'dest`.
- /// 'argValues' is used to replace the block arguments of 'source' after
- /// merging.
- virtual void mergeBlocks(Block *source, Block *dest,
- ValueRange argValues = std::nullopt);
-
- // Merge the operations of block 'source' before the operation 'op'. Source
- // block should not have existing predecessors or successors.
- void mergeBlockBefore(Block *source, Operation *op,
- ValueRange argValues = std::nullopt);
+ /// Inline the operations of block 'source' into block 'dest' before the given
+ /// position. The source block will be deleted and must have no uses.
+ /// 'argValues' is used to replace the block arguments of 'source'.
+ ///
+ /// If the source block is inserted at the end of the dest block, the dest
+ /// block must have no successors. Similarly, if the source block is inserted
+ /// somewhere in the middle (or beginning) of the dest block, the source block
+ /// must have no successors. Otherwise, the resulting IR would have
+ /// unreachable operations.
+ virtual void inlineBlockBefore(Block *source, Block *dest,
+ Block::iterator before,
+ ValueRange argValues = std::nullopt);
+
+ /// Inline the operations of block 'source' before the operation 'op'. The
+ /// source block will be deleted and must have no uses. 'argValues' is used to
+ /// replace the block arguments of 'source'.
+ ///
+ /// The source block must have no successors. Otherwise, the resulting IR
+ /// would have unreachable operations.
+ void inlineBlockBefore(Block *source, Operation *op,
+ ValueRange argValues = std::nullopt);
+
+ /// Inline the operations of block 'source' into the end of block 'dest'. The
+ /// source block will be deleted and must have no uses. 'argValues' is used to
+ /// replace the block arguments of 'source'.
+ ///
+ /// The dest block must have no successors. Otherwise, the resulting IR would
+ /// have unreachable operations.
+ void mergeBlocks(Block *source, Block *dest,
+ ValueRange argValues = std::nullopt);
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 229dc016957c6..b5127c099a366 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -702,8 +702,10 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
- /// PatternRewriter hook for merging a block into another.
- void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
+ /// PatternRewriter hook for inlining the ops of a block into another block.
+ void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+ ValueRange argValues = std::nullopt) override;
+ using PatternRewriter::inlineBlockBefore;
/// PatternRewriter hook for moving blocks out of a region.
void inlineRegionBefore(Region ®ion, Region &parent,
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 2f9ad2b7ad9ba..57ba18bd53f1b 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -506,7 +506,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
Value arg = iterArgs[yieldOperands.size()];
yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
rewriter.eraseOp(reduceBlock.getTerminator());
- rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
+ rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
rewriter.eraseOp(reduce);
}
@@ -516,8 +516,8 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
if (newBody->empty())
rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
else
- rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
- ivs);
+ rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
+ ivs);
// Finally, create the terminator if required (for loops with no results, it
// has been already created in loop construction).
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e74829ac7758d..8b46f530ae394 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2724,7 +2724,7 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
Operation *blockToMoveTerminator = blockToMove->getTerminator();
// Promote the "blockToMove" block to the parent operation block between the
// prologue and epilogue of "op".
- rewriter.mergeBlockBefore(blockToMove, op);
+ rewriter.inlineBlockBefore(blockToMove, op);
// Replace the "op" operation with the operands of the
// "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
// the affine.yield operation present in the "blockToMove" block. It has no
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index eaf69f734bae3..6bdf42e527887 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -147,8 +147,9 @@ simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
return failure();
// Merge the successor into the current block and erase the branch.
- rewriter.mergeBlocks(succ, opParent, op.getOperands());
+ SmallVector<Value> brOperands(op.getOperands());
rewriter.eraseOp(op);
+ rewriter.mergeBlocks(succ, opParent, brOperands);
return success();
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 45fac0ce82829..c8999943d868e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -553,7 +553,7 @@ struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
Block *block = &op.getRegion().front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
- rewriter.mergeBlockBefore(block, op);
+ rewriter.inlineBlockBefore(block, op);
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
return success();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8852b79b7b7ee..ed5e1bf04c5dc 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -107,7 +107,7 @@ static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
Block *block = ®ion.front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
- rewriter.mergeBlockBefore(block, op, blockArgs);
+ rewriter.inlineBlockBefore(block, op, blockArgs);
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
}
@@ -630,7 +630,7 @@ namespace {
// the ForOp region and can just be forwarded after simplifying the op inits,
// yields and returns.
//
-// The implementation uses `mergeBlockBefore` to steal the content of the
+// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
@@ -645,7 +645,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
// arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
// transformed block argument mappings. This plays the role of a
// IRMapping for the particular use case of calling into
- // `mergeBlockBefore`.
+ // `inlineBlockBefore`.
SmallVector<bool, 4> keepMask;
keepMask.reserve(yieldOp.getNumOperands());
SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
@@ -715,7 +715,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
// original terminator that has been merged in.
if (newIterArgs.empty()) {
auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
+ rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
rewriter.replaceOp(forOp, newResultValues);
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 0663bd927fe5e..dfe2d74b1e54f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -168,7 +168,7 @@ static LogicalResult genForeachOnSparseConstant(ForeachOp op,
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
assert(args.size() == cloned.getBody()->getNumArguments());
Operation *yield = cloned.getBody()->getTerminator();
- rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+ rewriter.inlineBlockBefore(cloned.getBody(), op, args);
// clean up
rewriter.eraseOp(cloned);
reduc = yield->getOperands();
@@ -988,7 +988,8 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
// This is annoying, since scf.for inserts a implicit yield op when
// there is no reduction variable upon creation, in this case we need to
// merge the block *before* the yield op.
- rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args);
+ rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(),
+ args);
}
for (Dimension d = 0; d < dimRank; d++) {
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index a0892b25e2c7b..4614a103b77e2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1211,7 +1211,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion,
YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
// Merge cloned block and return yield value.
Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
+ rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
Value val = clonedYield.getResult();
rewriter.eraseOp(clonedYield);
rewriter.eraseOp(placeholder);
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 1fc234e104219..052696d5cb13a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -303,29 +303,6 @@ void RewriterBase::finalizeRootUpdate(Operation *op) {
rewriteListener->notifyOperationModified(op);
}
-/// Merge the operations of block 'source' into the end of block 'dest'.
-/// 'source's predecessors must be empty or only contain 'dest`.
-/// 'argValues' is used to replace the block arguments of 'source' after
-/// merging.
-void RewriterBase::mergeBlocks(Block *source, Block *dest,
- ValueRange argValues) {
- assert(llvm::all_of(source->getPredecessors(),
- [dest](Block *succ) { return succ == dest; }) &&
- "expected 'source' to have no predecessors or only 'dest'");
- assert(argValues.size() == source->getNumArguments() &&
- "incorrect # of argument replacement values");
-
- // Replace all of the successor arguments with the provided values.
- for (auto it : llvm::zip(source->getArguments(), argValues))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
-
- // Splice the operations of the 'source' block into the 'dest' block and erase
- // it.
- dest->getOperations().splice(dest->end(), source->getOperations());
- source->dropAllUses();
- source->erase();
-}
-
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
/// in-place operation modification is about to happen.
@@ -337,26 +314,48 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
}
}
-// Merge the operations of block 'source' before the operation 'op'. Source
-// block should not have existing predecessors or successors.
-void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
- ValueRange argValues) {
+void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
+ Block::iterator before,
+ ValueRange argValues) {
+ assert(argValues.size() == source->getNumArguments() &&
+ "incorrect # of argument replacement values");
+
+ // The source block will be deleted, so it should not have any users (i.e.,
+ // there should be no predecessors).
assert(source->hasNoPredecessors() &&
"expected 'source' to have no predecessors");
- assert(source->hasNoSuccessors() &&
- "expected 'source' to have no successors");
- // Split the block containing 'op' into two, one containing all operations
- // before 'op' (prologue) and another (epilogue) containing 'op' and all
- // operations after it.
- Block *prologue = op->getBlock();
- Block *epilogue = splitBlock(prologue, op->getIterator());
+ if (dest->end() != before) {
+ // The source block will be inserted in the middle of the dest block, so
+ // the source block should have no successors. Otherwise, the remainder of
+ // the dest block would be unreachable.
+ assert(source->hasNoSuccessors() &&
+ "expected 'source' to have no successors");
+ } else {
+ // The source block will be inserted at the end of the dest block, so the
+ // dest block should have no successors. Otherwise, the inserted operations
+ // will be unreachable.
+ assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
+ }
- // Merge the source block at the end of the prologue.
- mergeBlocks(source, prologue, argValues);
+ // Replace all uses of the source block arguments with the provided values.
+ for (auto it : llvm::zip(source->getArguments(), argValues))
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- // Merge the epilogue at the end the prologue.
- mergeBlocks(epilogue, prologue);
+ // Move operations from the source block to the dest block and erase the
+ // source block.
+ dest->getOperations().splice(before, source->getOperations());
+ source->erase();
+}
+
+void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
+ ValueRange argValues) {
+ inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
+}
+
+void RewriterBase::mergeBlocks(Block *source, Block *dest,
+ ValueRange argValues) {
+ inlineBlockBefore(source, dest, dest->end(), argValues);
}
/// Split the operations starting at "before" (inclusive) out of the given
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 52ebd54e1bee8..99f4bf6ba092f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -289,7 +289,7 @@ struct OpReplacement {
enum class BlockActionKind {
Create,
Erase,
- Merge,
+ Inline,
Move,
Split,
TypeConversion
@@ -302,13 +302,14 @@ struct BlockPosition {
Block *insertAfterBlock;
};
-/// Information needed to undo the merge actions.
-/// - the source block, and
-/// - the Operation that was the last operation in the dest block before the
-/// merge (could be null if the dest block was empty).
-struct MergeInfo {
+/// Information needed to undo inlining actions.
+/// - the source block
+/// - the first inlined operation (could be null if the source block was empty)
+/// - the last inlined operation (could be null if the source block was empty)
+struct InlineInfo {
Block *sourceBlock;
- Operation *destBlockLastInst;
+ Operation *firstInlinedInst;
+ Operation *lastInlinedInst;
};
/// The storage class for an undoable block action (one of BlockActionKind),
@@ -320,9 +321,12 @@ struct BlockAction {
static BlockAction getErase(Block *block, BlockPosition originalPosition) {
return {BlockActionKind::Erase, block, {originalPosition}};
}
- static BlockAction getMerge(Block *block, Block *sourceBlock) {
- BlockAction action{BlockActionKind::Merge, block, {}};
- action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
+ static BlockAction getInline(Block *block, Block *srcBlock,
+ Block::iterator before) {
+ BlockAction action{BlockActionKind::Inline, block, {}};
+ action.inlineInfo = {srcBlock,
+ srcBlock->empty() ? nullptr : &srcBlock->front(),
+ srcBlock->empty() ? nullptr : &srcBlock->back()};
return action;
}
static BlockAction getMove(Block *block, BlockPosition originalPosition) {
@@ -344,16 +348,16 @@ struct BlockAction {
Block *block;
union {
// In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
// contains a pointer to the region that originally contained the block as
// well as the position of the block in that region.
BlockPosition originalPosition;
// In use if kind == BlockActionKind::Split and contains a pointer to the
// block that was split into two parts.
Block *originalBlock;
- // In use if kind == BlockActionKind::Merge, and contains the information
- // needed to undo the merge.
- MergeInfo mergeInfo;
+ // In use if kind == BlockActionKind::Inline, and contains the information
+ // needed to undo the inlining.
+ InlineInfo inlineInfo;
};
};
@@ -945,8 +949,9 @@ struct ConversionPatternRewriterImpl {
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
- /// Notifies that `block` is being merged with `srcBlock`.
- void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
+ /// Notifies that a block is being inlined into another block.
+ void notifyBlockBeingInlined(Block *block, Block *srcBlock,
+ Block::iterator before);
/// Notifies that the blocks of a region are about to be moved.
void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent,
@@ -1213,18 +1218,17 @@ void ConversionPatternRewriterImpl::undoBlockActions(
action.block);
break;
}
- // Split the block at the position which was originally the end of the
- // destination block (owned by action), and put the instructions back into
- // the block used before the merge.
- case BlockActionKind::Merge: {
- Block *sourceBlock = action.mergeInfo.sourceBlock;
- Block::iterator splitPoint =
- (action.mergeInfo.destBlockLastInst
- ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
- : action.block->begin());
- sourceBlock->getOperations().splice(sourceBlock->begin(),
- action.block->getOperations(),
- splitPoint, action.block->end());
+ // Put the instructions from the destination block (owned by the action)
+ // back into the source block.
+ case BlockActionKind::Inline: {
+ Block *sourceBlock = action.inlineInfo.sourceBlock;
+ if (action.inlineInfo.firstInlinedInst) {
+ assert(action.inlineInfo.lastInlinedInst && "expected operation");
+ sourceBlock->getOperations().splice(
+ sourceBlock->begin(), action.block->getOperations(),
+ Block::iterator(action.inlineInfo.firstInlinedInst),
+ ++Block::iterator(action.inlineInfo.lastInlinedInst));
+ }
break;
}
// Move the block back to its original position.
@@ -1445,9 +1449,9 @@ void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
blockActions.push_back(BlockAction::getSplit(continuation, block));
}
-void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
- Block *srcBlock) {
- blockActions.push_back(BlockAction::getMerge(block, srcBlock));
+void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
+ Block *block, Block *srcBlock, Block::iterator before) {
+ blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
}
void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
@@ -1603,17 +1607,23 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
return continuation;
}
-void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
- ValueRange argValues) {
- impl->notifyBlocksBeingMerged(dest, source);
- assert(llvm::all_of(source->getPredecessors(),
- [dest](Block *succ) { return succ == dest; }) &&
- "expected 'source' to have no predecessors or only 'dest'");
+void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
+ Block::iterator before,
+ ValueRange argValues) {
assert(argValues.size() == source->getNumArguments() &&
"incorrect # of argument replacement values");
+#ifndef NDEBUG
+ auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
+#endif // NDEBUG
+ // The source block will be deleted, so it should not have any users (i.e.,
+ // there should be no predecessors).
+ assert(llvm::all_of(source->getUsers(), opIgnored) &&
+ "expected 'source' to have no predecessors");
+
+ impl->notifyBlockBeingInlined(dest, source, before);
for (auto it : llvm::zip(source->getArguments(), argValues))
replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
- dest->getOperations().splice(dest->end(), source->getOperations());
+ dest->getOperations().splice(before, source->getOperations());
eraseBlock(source);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 66c369d845bdf..29dc580b081c1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1592,7 +1592,7 @@ struct TestMergeSingleBlockOps
Block &innerBlock = op.getRegion().front();
TerminatorOp innerTerminator =
cast<TerminatorOp>(innerBlock.getTerminator());
- rewriter.mergeBlockBefore(&innerBlock, op);
+ rewriter.inlineBlockBefore(&innerBlock, op);
rewriter.eraseOp(innerTerminator);
rewriter.eraseOp(op);
rewriter.updateRootInPlace(op, [] {});
More information about the Mlir-commits
mailing list