[Mlir-commits] [mlir] 42c31d8 - [mlir][IR] Clean up mergeBlockBefore and mergeBlocks

Matthias Springer llvmlistbot at llvm.org
Mon Mar 6 04:53:41 PST 2023


Author: Matthias Springer
Date: 2023-03-06T13:46:08+01:00
New Revision: 42c31d8302ff8f716601df9c276a5cd9ace6b158

URL: https://github.com/llvm/llvm-project/commit/42c31d8302ff8f716601df9c276a5cd9ace6b158
DIFF: https://github.com/llvm/llvm-project/commit/42c31d8302ff8f716601df9c276a5cd9ace6b158.diff

LOG: [mlir][IR] Clean up mergeBlockBefore and mergeBlocks

* `RewriterBase::mergeBlocks` is simplified: it is implemented in terms of `mergeBlockBefore`.
* The signature of `mergeBlockBefore` is now consistent with other APIs (such as `inlineRegionBefore`): an overload taking a `Block::iterator` is added.
* Additional safety checks are added to `mergeBlockBefore`: they detect cases where the resulting IR could be invalid (the source block's uses are no longer silently dropped via `dropAllUses`) or partly unreachable (likely a case of incorrect API usage).
* Rename `mergeBlockBefore` to `inlineBlockBefore`.

Differential Revision: https://reviews.llvm.org/D144969

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ed431badf05b0..9c4790c031201 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -491,17 +491,36 @@ class RewriterBase : public OpBuilder {
   /// This method erases all operations in a block.
   virtual void eraseBlock(Block *block);
 
-  /// Merge the operations of block 'source' into the end of block 'dest'.
-  /// 'source's predecessors must either be empty or only contain 'dest`.
-  /// 'argValues' is used to replace the block arguments of 'source' after
-  /// merging.
-  virtual void mergeBlocks(Block *source, Block *dest,
-                           ValueRange argValues = std::nullopt);
-
-  // Merge the operations of block 'source' before the operation 'op'. Source
-  // block should not have existing predecessors or successors.
-  void mergeBlockBefore(Block *source, Operation *op,
-                        ValueRange argValues = std::nullopt);
+  /// Inline the operations of block 'source' into block 'dest' before the given
+  /// position. The source block will be deleted and must have no uses.
+  /// 'argValues' is used to replace the block arguments of 'source'.
+  ///
+  /// If the source block is inserted at the end of the dest block, the dest
+  /// block must have no successors. Similarly, if the source block is inserted
+  /// somewhere in the middle (or beginning) of the dest block, the source block
+  /// must have no successors. Otherwise, the resulting IR would have
+  /// unreachable operations.
+  virtual void inlineBlockBefore(Block *source, Block *dest,
+                                 Block::iterator before,
+                                 ValueRange argValues = std::nullopt);
+
+  /// Inline the operations of block 'source' before the operation 'op'. The
+  /// source block will be deleted and must have no uses. 'argValues' is used to
+  /// replace the block arguments of 'source'.
+  ///
+  /// The source block must have no successors. Otherwise, the resulting IR
+  /// would have unreachable operations.
+  void inlineBlockBefore(Block *source, Operation *op,
+                         ValueRange argValues = std::nullopt);
+
+  /// Inline the operations of block 'source' into the end of block 'dest'. The
+  /// source block will be deleted and must have no uses. 'argValues' is used to
+  /// replace the block arguments of 'source'.
+  ///
+  /// The dest block must have no successors. Otherwise, the resulting IR would
+  /// have unreachable operations.
+  void mergeBlocks(Block *source, Block *dest,
+                   ValueRange argValues = std::nullopt);
 
   /// Split the operations starting at "before" (inclusive) out of the given
   /// block into a new block, and return it.

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 229dc016957c6..b5127c099a366 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -702,8 +702,10 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
 
-  /// PatternRewriter hook for merging a block into another.
-  void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
+  /// PatternRewriter hook for inlining the ops of a block into another block.
+  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+                         ValueRange argValues = std::nullopt) override;
+  using PatternRewriter::inlineBlockBefore;
 
   /// PatternRewriter hook for moving blocks out of a region.
   void inlineRegionBefore(Region &region, Region &parent,

diff  --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 2f9ad2b7ad9ba..57ba18bd53f1b 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -506,7 +506,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
     Value arg = iterArgs[yieldOperands.size()];
     yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
     rewriter.eraseOp(reduceBlock.getTerminator());
-    rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
+    rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()});
     rewriter.eraseOp(reduce);
   }
 
@@ -516,8 +516,8 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   if (newBody->empty())
     rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
   else
-    rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
-                              ivs);
+    rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
+                               ivs);
 
   // Finally, create the terminator if required (for loops with no results, it
   // has been already created in loop construction).

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e74829ac7758d..8b46f530ae394 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2724,7 +2724,7 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
     Operation *blockToMoveTerminator = blockToMove->getTerminator();
     // Promote the "blockToMove" block to the parent operation block between the
     // prologue and epilogue of "op".
-    rewriter.mergeBlockBefore(blockToMove, op);
+    rewriter.inlineBlockBefore(blockToMove, op);
     // Replace the "op" operation with the operands of the
     // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
     // the affine.yield operation present in the "blockToMove" block. It has no

diff  --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index eaf69f734bae3..6bdf42e527887 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -147,8 +147,9 @@ simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
     return failure();
 
   // Merge the successor into the current block and erase the branch.
-  rewriter.mergeBlocks(succ, opParent, op.getOperands());
+  SmallVector<Value> brOperands(op.getOperands());
   rewriter.eraseOp(op);
+  rewriter.mergeBlocks(succ, opParent, brOperands);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 45fac0ce82829..c8999943d868e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -553,7 +553,7 @@ struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
     Block *block = &op.getRegion().front();
     Operation *terminator = block->getTerminator();
     ValueRange results = terminator->getOperands();
-    rewriter.mergeBlockBefore(block, op);
+    rewriter.inlineBlockBefore(block, op);
     rewriter.replaceOp(op, results);
     rewriter.eraseOp(terminator);
     return success();

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8852b79b7b7ee..ed5e1bf04c5dc 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -107,7 +107,7 @@ static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
   Block *block = &region.front();
   Operation *terminator = block->getTerminator();
   ValueRange results = terminator->getOperands();
-  rewriter.mergeBlockBefore(block, op, blockArgs);
+  rewriter.inlineBlockBefore(block, op, blockArgs);
   rewriter.replaceOp(op, results);
   rewriter.eraseOp(terminator);
 }
@@ -630,7 +630,7 @@ namespace {
 // the ForOp region and can just be forwarded after simplifying the op inits,
 // yields and returns.
 //
-// The implementation uses `mergeBlockBefore` to steal the content of the
+// The implementation uses `inlineBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
@@ -645,7 +645,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
     // transformed block argument mappings. This plays the role of a
     // IRMapping for the particular use case of calling into
-    // `mergeBlockBefore`.
+    // `inlineBlockBefore`.
     SmallVector<bool, 4> keepMask;
     keepMask.reserve(yieldOp.getNumOperands());
     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
@@ -715,7 +715,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     // original terminator that has been merged in.
     if (newIterArgs.empty()) {
       auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
-      rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
+      rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
       rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
       rewriter.replaceOp(forOp, newResultValues);
       return success();

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 0663bd927fe5e..dfe2d74b1e54f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -168,7 +168,7 @@ static LogicalResult genForeachOnSparseConstant(ForeachOp op,
         auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
         assert(args.size() == cloned.getBody()->getNumArguments());
         Operation *yield = cloned.getBody()->getTerminator();
-        rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+        rewriter.inlineBlockBefore(cloned.getBody(), op, args);
         // clean up
         rewriter.eraseOp(cloned);
         reduc = yield->getOperands();
@@ -988,7 +988,8 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
       // This is annoying, since scf.for inserts a implicit yield op when
       // there is no reduction variable upon creation, in this case we need to
       // merge the block *before* the yield op.
-      rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args);
+      rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(),
+                                 args);
     }
 
     for (Dimension d = 0; d < dimRank; d++) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index a0892b25e2c7b..4614a103b77e2 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1211,7 +1211,7 @@ static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
   // Merge cloned block and return yield value.
   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
+  rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals);
   Value val = clonedYield.getResult();
   rewriter.eraseOp(clonedYield);
   rewriter.eraseOp(placeholder);

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 1fc234e104219..052696d5cb13a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -303,29 +303,6 @@ void RewriterBase::finalizeRootUpdate(Operation *op) {
     rewriteListener->notifyOperationModified(op);
 }
 
-/// Merge the operations of block 'source' into the end of block 'dest'.
-/// 'source's predecessors must be empty or only contain 'dest`.
-/// 'argValues' is used to replace the block arguments of 'source' after
-/// merging.
-void RewriterBase::mergeBlocks(Block *source, Block *dest,
-                               ValueRange argValues) {
-  assert(llvm::all_of(source->getPredecessors(),
-                      [dest](Block *succ) { return succ == dest; }) &&
-         "expected 'source' to have no predecessors or only 'dest'");
-  assert(argValues.size() == source->getNumArguments() &&
-         "incorrect # of argument replacement values");
-
-  // Replace all of the successor arguments with the provided values.
-  for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
-
-  // Splice the operations of the 'source' block into the 'dest' block and erase
-  // it.
-  dest->getOperations().splice(dest->end(), source->getOperations());
-  source->dropAllUses();
-  source->erase();
-}
-
 /// Find uses of `from` and replace them with `to` if the `functor` returns
 /// true. It also marks every modified uses and notifies the rewriter that an
 /// in-place operation modification is about to happen.
@@ -337,26 +314,48 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
   }
 }
 
-// Merge the operations of block 'source' before the operation 'op'. Source
-// block should not have existing predecessors or successors.
-void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
-                                    ValueRange argValues) {
+void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
+                                     Block::iterator before,
+                                     ValueRange argValues) {
+  assert(argValues.size() == source->getNumArguments() &&
+         "incorrect # of argument replacement values");
+
+  // The source block will be deleted, so it should not have any users (i.e.,
+  // there should be no predecessors).
   assert(source->hasNoPredecessors() &&
          "expected 'source' to have no predecessors");
-  assert(source->hasNoSuccessors() &&
-         "expected 'source' to have no successors");
 
-  // Split the block containing 'op' into two, one containing all operations
-  // before 'op' (prologue) and another (epilogue) containing 'op' and all
-  // operations after it.
-  Block *prologue = op->getBlock();
-  Block *epilogue = splitBlock(prologue, op->getIterator());
+  if (dest->end() != before) {
+    // The source block will be inserted in the middle of the dest block, so
+    // the source block should have no successors. Otherwise, the remainder of
+    // the dest block would be unreachable.
+    assert(source->hasNoSuccessors() &&
+           "expected 'source' to have no successors");
+  } else {
+    // The source block will be inserted at the end of the dest block, so the
+    // dest block should have no successors. Otherwise, the inserted operations
+    // will be unreachable.
+    assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
+  }
 
-  // Merge the source block at the end of the prologue.
-  mergeBlocks(source, prologue, argValues);
+  // Replace all of the successor arguments with the provided values.
+  for (auto it : llvm::zip(source->getArguments(), argValues))
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
-  // Merge the epilogue at the end the prologue.
-  mergeBlocks(epilogue, prologue);
+  // Move operations from the source block to the dest block and erase the
+  // source block.
+  dest->getOperations().splice(before, source->getOperations());
+  source->erase();
+}
+
+void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
+                                     ValueRange argValues) {
+  inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
+}
+
+void RewriterBase::mergeBlocks(Block *source, Block *dest,
+                               ValueRange argValues) {
+  inlineBlockBefore(source, dest, dest->end(), argValues);
 }
 
 /// Split the operations starting at "before" (inclusive) out of the given

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 52ebd54e1bee8..99f4bf6ba092f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -289,7 +289,7 @@ struct OpReplacement {
 enum class BlockActionKind {
   Create,
   Erase,
-  Merge,
+  Inline,
   Move,
   Split,
   TypeConversion
@@ -302,13 +302,14 @@ struct BlockPosition {
   Block *insertAfterBlock;
 };
 
-/// Information needed to undo the merge actions.
-/// - the source block, and
-/// - the Operation that was the last operation in the dest block before the
-///   merge (could be null if the dest block was empty).
-struct MergeInfo {
+/// Information needed to undo inlining actions.
+/// - the source block
+/// - the first inlined operation (could be null if the source block was empty)
+/// - the last inlined operation (could be null if the source block was empty)
+struct InlineInfo {
   Block *sourceBlock;
-  Operation *destBlockLastInst;
+  Operation *firstInlinedInst;
+  Operation *lastInlinedInst;
 };
 
 /// The storage class for an undoable block action (one of BlockActionKind),
@@ -320,9 +321,12 @@ struct BlockAction {
   static BlockAction getErase(Block *block, BlockPosition originalPosition) {
     return {BlockActionKind::Erase, block, {originalPosition}};
   }
-  static BlockAction getMerge(Block *block, Block *sourceBlock) {
-    BlockAction action{BlockActionKind::Merge, block, {}};
-    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
+  static BlockAction getInline(Block *block, Block *srcBlock,
+                               Block::iterator before) {
+    BlockAction action{BlockActionKind::Inline, block, {}};
+    action.inlineInfo = {srcBlock,
+                         srcBlock->empty() ? nullptr : &srcBlock->front(),
+                         srcBlock->empty() ? nullptr : &srcBlock->back()};
     return action;
   }
   static BlockAction getMove(Block *block, BlockPosition originalPosition) {
@@ -344,16 +348,16 @@ struct BlockAction {
   Block *block;
 
   union {
-    // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
+    // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
     // contains a pointer to the region that originally contained the block as
     // well as the position of the block in that region.
     BlockPosition originalPosition;
     // In use if kind == BlockActionKind::Split and contains a pointer to the
     // block that was split into two parts.
     Block *originalBlock;
-    // In use if kind == BlockActionKind::Merge, and contains the information
-    // needed to undo the merge.
-    MergeInfo mergeInfo;
+    // In use if kind == BlockActionKind::Inline, and contains the information
+    // needed to undo the inlining.
+    InlineInfo inlineInfo;
   };
 };
 
@@ -945,8 +949,9 @@ struct ConversionPatternRewriterImpl {
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
 
-  /// Notifies that `block` is being merged with `srcBlock`.
-  void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
+  /// Notifies that a block is being inlined into another block.
+  void notifyBlockBeingInlined(Block *block, Block *srcBlock,
+                               Block::iterator before);
 
   /// Notifies that the blocks of a region are about to be moved.
   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
@@ -1213,18 +1218,17 @@ void ConversionPatternRewriterImpl::undoBlockActions(
                        action.block);
       break;
     }
-    // Split the block at the position which was originally the end of the
-    // destination block (owned by action), and put the instructions back into
-    // the block used before the merge.
-    case BlockActionKind::Merge: {
-      Block *sourceBlock = action.mergeInfo.sourceBlock;
-      Block::iterator splitPoint =
-          (action.mergeInfo.destBlockLastInst
-               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
-               : action.block->begin());
-      sourceBlock->getOperations().splice(sourceBlock->begin(),
-                                          action.block->getOperations(),
-                                          splitPoint, action.block->end());
+    // Put the instructions from the destination block (owned by the action)
+    // back into the source block.
+    case BlockActionKind::Inline: {
+      Block *sourceBlock = action.inlineInfo.sourceBlock;
+      if (action.inlineInfo.firstInlinedInst) {
+        assert(action.inlineInfo.lastInlinedInst && "expected operation");
+        sourceBlock->getOperations().splice(
+            sourceBlock->begin(), action.block->getOperations(),
+            Block::iterator(action.inlineInfo.firstInlinedInst),
+            ++Block::iterator(action.inlineInfo.lastInlinedInst));
+      }
       break;
     }
     // Move the block back to its original position.
@@ -1445,9 +1449,9 @@ void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
   blockActions.push_back(BlockAction::getSplit(continuation, block));
 }
 
-void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
-                                                            Block *srcBlock) {
-  blockActions.push_back(BlockAction::getMerge(block, srcBlock));
+void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
+    Block *block, Block *srcBlock, Block::iterator before) {
+  blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
 }
 
 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
@@ -1603,17 +1607,23 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
   return continuation;
 }
 
-void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
-                                            ValueRange argValues) {
-  impl->notifyBlocksBeingMerged(dest, source);
-  assert(llvm::all_of(source->getPredecessors(),
-                      [dest](Block *succ) { return succ == dest; }) &&
-         "expected 'source' to have no predecessors or only 'dest'");
+void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
+                                                  Block::iterator before,
+                                                  ValueRange argValues) {
   assert(argValues.size() == source->getNumArguments() &&
          "incorrect # of argument replacement values");
+#ifndef NDEBUG
+  auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
+#endif // NDEBUG
+  // The source block will be deleted, so it should not have any users (i.e.,
+  // there should be no predecessors).
+  assert(llvm::all_of(source->getUsers(), opIgnored) &&
+         "expected 'source' to have no predecessors");
+
+  impl->notifyBlockBeingInlined(dest, source, before);
   for (auto it : llvm::zip(source->getArguments(), argValues))
     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
-  dest->getOperations().splice(dest->end(), source->getOperations());
+  dest->getOperations().splice(before, source->getOperations());
   eraseBlock(source);
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 66c369d845bdf..29dc580b081c1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1592,7 +1592,7 @@ struct TestMergeSingleBlockOps
     Block &innerBlock = op.getRegion().front();
     TerminatorOp innerTerminator =
         cast<TerminatorOp>(innerBlock.getTerminator());
-    rewriter.mergeBlockBefore(&innerBlock, op);
+    rewriter.inlineBlockBefore(&innerBlock, op);
     rewriter.eraseOp(innerTerminator);
     rewriter.eraseOp(op);
     rewriter.updateRootInPlace(op, [] {});


        


More information about the Mlir-commits mailing list