[llvm-branch-commits] [mlir] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion (PR #81240)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 12 01:10:02 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81240

>From c60c43bcd2296715ceca83a3f9666433883ec303 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 12 Feb 2024 09:05:50 +0000
Subject: [PATCH 1/2] [mlir][Transforms][WIP] RewriteAction

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../Transforms/Utils/DialectConversion.cpp    | 504 +++++++++++-------
 1 file changed, 306 insertions(+), 198 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e41231d7cbd390..edca84e5a73f04 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,13 +154,12 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numBlockActions, unsigned numIgnoredOperations,
+                unsigned numRewrites, unsigned numIgnoredOperations,
                 unsigned numRootUpdates)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
-        numArgReplacements(numArgReplacements),
-        numBlockActions(numBlockActions),
+        numArgReplacements(numArgReplacements), numRewrites(numRewrites),
         numIgnoredOperations(numIgnoredOperations),
         numRootUpdates(numRootUpdates) {}
 
@@ -176,8 +175,8 @@ struct RewriterState {
   /// The current number of argument replacements queued.
   unsigned numArgReplacements;
 
-  /// The current number of block actions performed.
-  unsigned numBlockActions;
+  /// The current number of rewrites performed.
+  unsigned numRewrites;
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
@@ -235,86 +234,6 @@ struct OpReplacement {
   const TypeConverter *converter;
 };
 
-//===----------------------------------------------------------------------===//
-// BlockAction
-
-/// The kind of the block action performed during the rewrite.  Actions can be
-/// undone if the conversion fails.
-enum class BlockActionKind {
-  Create,
-  Erase,
-  Inline,
-  Move,
-  Split,
-  TypeConversion
-};
-
-/// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed before `insertBeforeBlock`.
-struct BlockPosition {
-  Region *region;
-  Block *insertBeforeBlock;
-};
-
-/// Information needed to undo inlining actions.
-/// - the source block
-/// - the first inlined operation (could be null if the source block was empty)
-/// - the last inlined operation (could be null if the source block was empty)
-struct InlineInfo {
-  Block *sourceBlock;
-  Operation *firstInlinedInst;
-  Operation *lastInlinedInst;
-};
-
-/// The storage class for an undoable block action (one of BlockActionKind),
-/// contains the information necessary to undo this action.
-struct BlockAction {
-  static BlockAction getCreate(Block *block) {
-    return {BlockActionKind::Create, block, {}};
-  }
-  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
-    return {BlockActionKind::Erase, block, {originalPosition}};
-  }
-  static BlockAction getInline(Block *block, Block *srcBlock,
-                               Block::iterator before) {
-    BlockAction action{BlockActionKind::Inline, block, {}};
-    action.inlineInfo = {srcBlock,
-                         srcBlock->empty() ? nullptr : &srcBlock->front(),
-                         srcBlock->empty() ? nullptr : &srcBlock->back()};
-    return action;
-  }
-  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
-    return {BlockActionKind::Move, block, {originalPosition}};
-  }
-  static BlockAction getSplit(Block *block, Block *originalBlock) {
-    BlockAction action{BlockActionKind::Split, block, {}};
-    action.originalBlock = originalBlock;
-    return action;
-  }
-  static BlockAction getTypeConversion(Block *block) {
-    return BlockAction{BlockActionKind::TypeConversion, block, {}};
-  }
-
-  // The action kind.
-  BlockActionKind kind;
-
-  // A pointer to the block that was created by the action.
-  Block *block;
-
-  union {
-    // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
-    // contains a pointer to the region that originally contained the block as
-    // well as the position of the block in that region.
-    BlockPosition originalPosition;
-    // In use if kind == BlockActionKind::Split and contains a pointer to the
-    // block that was split into two parts.
-    Block *originalBlock;
-    // In use if kind == BlockActionKind::Inline, and contains the information
-    // needed to undo the inlining.
-    InlineInfo inlineInfo;
-  };
-};
-
 //===----------------------------------------------------------------------===//
 // UnresolvedMaterialization
 
@@ -820,6 +739,251 @@ void ArgConverter::insertConversion(Block *newBlock,
   conversionInfo.insert({newBlock, std::move(info)});
 }
 
+//===----------------------------------------------------------------------===//
+// IR rewrites
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// An IR rewrite that can be committed (upon success) or rolled back (upon
+/// failure).
+///
+/// The dialect conversion keeps track of IR modifications (requested by the
+/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
+/// are directly applied to the IR as the rewriter API is used, some are applied
+/// partially, and some are delayed until the `IRRewrite` objects are committed.
+class IRRewrite {
+public:
+  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
+  enum class Kind {
+    CreateBlock,
+    EraseBlock,
+    InlineBlock,
+    MoveBlock,
+    SplitBlock,
+    BlockTypeConversion
+  };
+
+  virtual ~IRRewrite() = default;
+
+  /// Roll back the rewrite.
+  virtual void rollback() = 0;
+
+  /// Commit the rewrite.
+  virtual void commit() {}
+
+  Kind getKind() const { return kind; }
+
+  static bool classof(const IRRewrite *rewrite) { return true; }
+
+protected:
+  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
+      : kind(kind), rewriterImpl(rewriterImpl) {}
+
+  const Kind kind;
+  ConversionPatternRewriterImpl &rewriterImpl;
+};
+
+/// A block rewrite.
+class BlockRewrite : public IRRewrite {
+public:
+  /// Return the block that this rewrite operates on.
+  Block *getBlock() const { return block; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() >= Kind::CreateBlock &&
+           rewrite->getKind() <= Kind::BlockTypeConversion;
+  }
+
+protected:
+  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Block *block)
+      : IRRewrite(kind, rewriterImpl), block(block) {}
+
+  // The block that this rewrite operates on.
+  Block *block;
+};
+
+/// Creation of a block. Block creations are immediately reflected in the IR.
+/// There is no extra work to commit the rewrite. During rollback, the newly
+/// created block is erased.
+class CreateBlockRewrite : public BlockRewrite {
+public:
+  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
+      : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::CreateBlock;
+  }
+
+  void rollback() override {
+    // Unlink all of the operations within this block, they will be deleted
+    // separately.
+    auto &blockOps = block->getOperations();
+    while (!blockOps.empty())
+      blockOps.remove(blockOps.begin());
+    block->dropAllDefinedValueUses();
+    block->erase();
+  }
+};
+
+/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
+/// blocks are immediately unlinked, but only erased when the rewrite is
+/// committed. This makes it easier to rollback a block erasure: the block is
+/// simply inserted into its original location.
+class EraseBlockRewrite : public BlockRewrite {
+public:
+  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                    Region *region, Block *insertBeforeBlock)
+      : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
+        insertBeforeBlock(insertBeforeBlock) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::EraseBlock;
+  }
+
+  ~EraseBlockRewrite() override {
+    assert(!block && "rewrite was neither rolled back nor committed");
+  }
+
+  void rollback() override {
+    // The block (owned by this rewrite) was not actually erased yet. It was
+    // just unlinked. Put it back into its original position.
+    assert(block && "expected block");
+    auto &blockList = region->getBlocks();
+    Region::iterator before = insertBeforeBlock
+                                  ? Region::iterator(insertBeforeBlock)
+                                  : blockList.end();
+    blockList.insert(before, block);
+    block = nullptr;
+  }
+
+  void commit() override {
+    // Erase the block.
+    assert(block && "expected block");
+    delete block;
+    block = nullptr;
+  }
+
+private:
+  // The region in which this block was previously contained.
+  Region *region;
+
+  // The original successor of this block before it was unlinked. "nullptr" if
+  // this block was the only block in the region.
+  Block *insertBeforeBlock;
+};
+
+/// Inlining of a block. This rewrite is immediately reflected in the IR.
+/// Note: This rewrite represents only the inlining of the operations. The
+/// erasure of the inlined block is a separate rewrite.
+class InlineBlockRewrite : public BlockRewrite {
+public:
+  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                     Block *sourceBlock, Block::iterator before)
+      : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
+        sourceBlock(sourceBlock),
+        firstInlinedInst(sourceBlock->empty() ? nullptr
+                                              : &sourceBlock->front()),
+        lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+  }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::InlineBlock;
+  }
+
+  void rollback() override {
+    // Put the operations from the destination block (owned by the rewrite)
+    // back into the source block.
+    if (firstInlinedInst) {
+      assert(lastInlinedInst && "expected operation");
+      sourceBlock->getOperations().splice(sourceBlock->begin(),
+                                          block->getOperations(),
+                                          Block::iterator(firstInlinedInst),
+                                          ++Block::iterator(lastInlinedInst));
+    }
+  }
+
+private:
+  // The block that originally contained the operations.
+  Block *sourceBlock;
+
+  // The first inlined operation.
+  Operation *firstInlinedInst;
+
+  // The last inlined operation.
+  Operation *lastInlinedInst;
+};
+
+/// Moving of a block. This rewrite is immediately reflected in the IR.
+class MoveBlockRewrite : public BlockRewrite {
+public:
+  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                   Region *region, Block *insertBeforeBlock)
+      : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
+        insertBeforeBlock(insertBeforeBlock) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::MoveBlock;
+  }
+
+  void rollback() override {
+    // Move the block back to its original position.
+    Region::iterator before =
+        insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
+    region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
+  }
+
+private:
+  // The region in which this block was previously contained.
+  Region *region;
+
+  // The original successor of this block before it was moved. "nullptr" if
+  // this block was the only block in the region.
+  Block *insertBeforeBlock;
+};
+
+/// Splitting of a block. This rewrite is immediately reflected in the IR.
+class SplitBlockRewrite : public BlockRewrite {
+public:
+  SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                    Block *originalBlock)
+      : BlockRewrite(Kind::SplitBlock, rewriterImpl, block),
+        originalBlock(originalBlock) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::SplitBlock;
+  }
+
+  void rollback() override {
+    // Merge back the block that was split out.
+    originalBlock->getOperations().splice(originalBlock->end(),
+                                          block->getOperations());
+    block->dropAllDefinedValueUses();
+    block->erase();
+  }
+
+private:
+  // The original block from which this block was split.
+  Block *originalBlock;
+};
+
+/// Block type conversion. This rewrite is partially reflected in the IR.
+class BlockTypeConversionRewrite : public BlockRewrite {
+public:
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::BlockTypeConversion;
+  }
+
+  // TODO: Block type conversions are currently committed in
+  // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+  void rollback() override;
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -848,13 +1012,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
-  /// Erase any blocks that were unlinked from their regions and stored in block
-  /// actions.
-  void eraseDanglingBlocks();
+  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
+  /// failure.
+  template <typename ActionTy, typename... Args>
+  void appendRewrite(Args &&...args) {
+    rewrites.push_back(
+        std::make_unique<ActionTy>(*this, std::forward<Args>(args)...));
+  }
 
-  /// Undo the block actions (motions, splits) one by one in reverse order until
-  /// "numActionsToKeep" actions remains.
-  void undoBlockActions(unsigned numActionsToKeep = 0);
+  /// Undo the rewrites (motions, splits) one by one in reverse order until
+  /// "numRewritesToKeep" rewrites remain.
+  void undoRewrites(unsigned numRewritesToKeep = 0);
 
   /// Remap the given values to those with potentially different types. Returns
   /// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -954,7 +1122,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   SmallVector<BlockArgument, 4> argReplacements;
 
   /// Ordered list of block operations (creations, splits, motions).
-  SmallVector<BlockAction, 4> blockActions;
+  SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
   /// A set of operations that should no longer be considered for legalization,
   /// but were not directly replace/erased/etc. by a pattern. These are
@@ -995,6 +1163,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
+void BlockTypeConversionRewrite::rollback() {
+  // Undo the type conversion.
+  rewriterImpl.argConverter.discardRewrites(block);
+}
+
 /// Detach any operations nested in the given operation from their parent
 /// blocks, and erase the given operation. This can be used when the nested
 /// operations are scheduled for erasure themselves, so deleting the regions of
@@ -1020,7 +1193,7 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   for (auto &state : rootUpdates)
     state.resetOperation();
 
-  undoBlockActions();
+  undoRewrites();
 
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
@@ -1083,8 +1256,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
   argConverter.applyRewrites(mapping);
 
-  // Now that the ops have been erased, also erase dangling blocks.
-  eraseDanglingBlocks();
+  // Commit all rewrites.
+  for (auto &rewrite : rewrites)
+    rewrite->commit();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1093,8 +1267,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       blockActions.size(), ignoredOps.size(),
-                       rootUpdates.size());
+                       rewrites.size(), ignoredOps.size(), rootUpdates.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1109,8 +1282,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     mapping.erase(replacedArg);
   argReplacements.resize(state.numArgReplacements);
 
-  // Undo any block actions.
-  undoBlockActions(state.numBlockActions);
+  // Undo any rewrites.
+  undoRewrites(state.numRewrites);
 
   // Reset any replaced operations and undo any saved mappings.
   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
@@ -1149,76 +1322,11 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     operationsWithChangedResults.pop_back();
 }
 
-void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
-  for (auto &action : blockActions)
-    if (action.kind == BlockActionKind::Erase)
-      delete action.block;
-}
-
-void ConversionPatternRewriterImpl::undoBlockActions(
-    unsigned numActionsToKeep) {
-  for (auto &action :
-       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
-    switch (action.kind) {
-    // Delete the created block.
-    case BlockActionKind::Create: {
-      // Unlink all of the operations within this block, they will be deleted
-      // separately.
-      auto &blockOps = action.block->getOperations();
-      while (!blockOps.empty())
-        blockOps.remove(blockOps.begin());
-      action.block->dropAllDefinedValueUses();
-      action.block->erase();
-      break;
-    }
-    // Put the block (owned by action) back into its original position.
-    case BlockActionKind::Erase: {
-      auto &blockList = action.originalPosition.region->getBlocks();
-      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
-      blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
-                                          : blockList.end()),
-                       action.block);
-      break;
-    }
-    // Put the instructions from the destination block (owned by the action)
-    // back into the source block.
-    case BlockActionKind::Inline: {
-      Block *sourceBlock = action.inlineInfo.sourceBlock;
-      if (action.inlineInfo.firstInlinedInst) {
-        assert(action.inlineInfo.lastInlinedInst && "expected operation");
-        sourceBlock->getOperations().splice(
-            sourceBlock->begin(), action.block->getOperations(),
-            Block::iterator(action.inlineInfo.firstInlinedInst),
-            ++Block::iterator(action.inlineInfo.lastInlinedInst));
-      }
-      break;
-    }
-    // Move the block back to its original position.
-    case BlockActionKind::Move: {
-      Region *originalRegion = action.originalPosition.region;
-      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
-      originalRegion->getBlocks().splice(
-          (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
-                             : originalRegion->end()),
-          action.block->getParent()->getBlocks(), action.block);
-      break;
-    }
-    // Merge back the block that was split out.
-    case BlockActionKind::Split: {
-      action.originalBlock->getOperations().splice(
-          action.originalBlock->end(), action.block->getOperations());
-      action.block->dropAllDefinedValueUses();
-      action.block->erase();
-      break;
-    }
-    // Undo the type conversion.
-    case BlockActionKind::TypeConversion: {
-      argConverter.discardRewrites(action.block);
-      break;
-    }
-    }
-  }
-  blockActions.resize(numActionsToKeep);
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+  for (auto &rewrite :
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+    rewrite->rollback();
+  rewrites.resize(numRewritesToKeep);
 }
 
 LogicalResult ConversionPatternRewriterImpl::remapValues(
@@ -1309,7 +1417,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     return failure();
   if (Block *newBlock = *result) {
     if (newBlock != block)
-      blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+      appendRewrite<BlockTypeConversionRewrite>(newBlock);
   }
   return result;
 }
@@ -1410,28 +1518,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
   Block *origNextBlock = block->getNextNode();
-  blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
+  appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   if (!previous) {
     // This is a newly created block.
-    blockActions.push_back(BlockAction::getCreate(block));
+    appendRewrite<CreateBlockRewrite>(block);
     return;
   }
   Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
-  blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
+  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
 }
 
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
                                                      Block *continuation) {
-  blockActions.push_back(BlockAction::getSplit(continuation, block));
+  appendRewrite<SplitBlockRewrite>(continuation, block);
 }
 
 void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
     Block *block, Block *srcBlock, Block::iterator before) {
-  blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
+  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
 }
 
 void ConversionPatternRewriterImpl::notifyMatchFailure(
@@ -1501,8 +1609,8 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
   for (Operation &op : *block)
     eraseOp(&op);
 
-  // Unlink the block from its parent region. The block is kept in the block
-  // action and will be actually destroyed when rewrites are applied. This
+  // Unlink the block from its parent region. The block is kept in the rewrite
+  // object and will be actually destroyed when rewrites are applied. This
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
   block->getParent()->getBlocks().remove(block);
@@ -1700,11 +1808,11 @@ class OperationLegalizer {
                                       RewriterState &curState);
 
   /// Legalizes the actions registered during the execution of a pattern.
-  LogicalResult legalizePatternBlockActions(Operation *op,
-                                            ConversionPatternRewriter &rewriter,
-                                            ConversionPatternRewriterImpl &impl,
-                                            RewriterState &state,
-                                            RewriterState &newState);
+  LogicalResult
+  legalizePatternBlockRewrites(Operation *op,
+                               ConversionPatternRewriter &rewriter,
+                               ConversionPatternRewriterImpl &impl,
+                               RewriterState &state, RewriterState &newState);
   LogicalResult legalizePatternCreatedOperations(
       ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
       RewriterState &state, RewriterState &newState);
@@ -1986,8 +2094,8 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
 
   // Legalize each of the actions registered during application.
   RewriterState newState = impl.getCurrentState();
-  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
-                                         newState)) ||
+  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
+                                          newState)) ||
       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
                                               newState))) {
@@ -1998,7 +2106,7 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
   return success();
 }
 
-LogicalResult OperationLegalizer::legalizePatternBlockActions(
+LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     Operation *op, ConversionPatternRewriter &rewriter,
     ConversionPatternRewriterImpl &impl, RewriterState &state,
     RewriterState &newState) {
@@ -2006,22 +2114,22 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
 
   // If the pattern moved or created any blocks, make sure the types of block
   // arguments get legalized.
-  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
-       ++i) {
-    auto &action = impl.blockActions[i];
-    if (action.kind == BlockActionKind::TypeConversion ||
-        action.kind == BlockActionKind::Erase)
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
+    if (!rewrite)
+      continue;
+    Block *block = rewrite->getBlock();
+    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
       continue;
     // Only check blocks outside of the current operation.
-    Operation *parentOp = action.block->getParentOp();
-    if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
+    Operation *parentOp = block->getParentOp();
+    if (!parentOp || parentOp == op || block->getNumArguments() == 0)
       continue;
 
     // If the region of the block has a type converter, try to convert the block
     // directly.
-    if (auto *converter =
-            impl.argConverter.getConverter(action.block->getParent())) {
-      if (failed(impl.convertBlockSignature(action.block, converter))) {
+    if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+      if (failed(impl.convertBlockSignature(block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();
@@ -2042,9 +2150,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
     // If this operation should be considered for re-legalization, try it.
     if (operationsToIgnore.insert(parentOp).second &&
         failed(legalize(parentOp, rewriter))) {
-      LLVM_DEBUG(logFailure(
-          impl.logger, "operation '{0}'({1}) became illegal after block action",
-          parentOp->getName(), parentOp));
+      LLVM_DEBUG(logFailure(impl.logger,
+                            "operation '{0}'({1}) became illegal after rewrite",
+                            parentOp->getName(), parentOp));
       return failure();
     }
   }

>From ebfaca6b688394233b0d6a22f77b8b7cccaf67a8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 12 Feb 2024 09:08:21 +0000
Subject: [PATCH 2/2] [mlir][Transforms] Support `moveOpBefore`/`After` in
 dialect conversion

Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.

`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 mlir/include/mlir/IR/PatternMatch.h           |  6 +-
 .../mlir/Transforms/DialectConversion.h       |  9 +--
 .../Transforms/Utils/DialectConversion.cpp    | 74 +++++++++++++++----
 mlir/test/Transforms/test-legalizer.mlir      | 14 ++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 20 ++++-
 5 files changed, 95 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d2..b8aeea0d23475b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right before
   /// `iterator` in the specified block.
-  virtual void moveOpBefore(Operation *op, Block *block,
-                            Block::iterator iterator);
+  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this operation from its current block and insert it right after
   /// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right after
   /// `iterator` in the specified block.
-  virtual void moveOpAfter(Operation *op, Block *block,
-                           Block::iterator iterator);
+  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this block and insert it right before `existingBlock`.
   void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f061d761ecefbb..b028d2b71b3762 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -721,8 +721,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
-  /// and not nested regions. Updates to regions will still require notification
-  /// through other more specific hooks above.
+  /// and not nested regions. Updates to regions will still require
+  /// notification through other more specific hooks above.
   void startOpModification(Operation *op) override;
 
   /// PatternRewriter hook for updating the given operation in-place.
@@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
 
-  void moveOpBefore(Operation *op, Block *block,
-                    Block::iterator iterator) override;
-  void moveOpAfter(Operation *op, Block *block,
-                   Block::iterator iterator) override;
-
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index edca84e5a73f04..85b67bb834de7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -760,7 +760,8 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     SplitBlock,
-    BlockTypeConversion
+    BlockTypeConversion,
+    MoveOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
   void rollback() override;
 };
+
+/// An operation rewrite.
+class OperationRewrite : public IRRewrite {
+public:
+  /// Return the operation that this rewrite operates on.
+  Operation *getOperation() const { return op; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() >= Kind::MoveOperation &&
+           rewrite->getKind() <= Kind::MoveOperation;
+  }
+
+protected:
+  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+                   Operation *op)
+      : IRRewrite(kind, rewriterImpl), op(op) {}
+
+  // The operation that this rewrite operates on.
+  Operation *op;
+};
+
+/// Moving of an operation. This rewrite is immediately reflected in the IR.
+class MoveOperationRewrite : public OperationRewrite {
+public:
+  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                       Operation *op, Block *block, Operation *insertBeforeOp)
+      : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
+        insertBeforeOp(insertBeforeOp) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::MoveOperation;
+  }
+
+  void rollback() override {
+    // Move the operation back to its original position.
+    Block::iterator before =
+        insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+    block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+  }
+
+private:
+  // The block in which this operation was previously contained.
+  Block *block;
+
+  // The original successor of this operation before it was moved. "nullptr" if
+  // this operation was the last operation in its block.
+  Operation *insertBeforeOp;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 
 void ConversionPatternRewriterImpl::notifyOperationInserted(
     Operation *op, OpBuilder::InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  createdOps.push_back(op);
+  if (!previous.isSet()) {
+    // This is a newly created op.
+    createdOps.push_back(op);
+    return;
+  }
+  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+                          ? nullptr
+                          : &*previous.getPoint();
+  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
-                                             Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
-                                            Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719cede..84fcc18ab7d370 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+  // CHECK: "test.one_region_op"()
+  // CHECK: "test.hoist_me"()
+  "test.one_region_op"() ({
+    // expected-remark @below{{'test.hoist_me' is not legalizable}}
+    %0 = "test.hoist_me"() : () -> (i32)
+    "test.valid"(%0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1fb..1c02232b8adbb1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+/// This pattern hoists a "test.hoist_me" op out of its parent op and then fails
+/// conversion. This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+  TestUndoMoveOpBefore(MLIRContext *ctx)
+      : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.moveOpBefore(op, op->getParentOp());
+    // Replace with an illegal op to ensure the conversion fails.
+    rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+    return success();
+  }
+};
+
 /// A rewrite pattern that tests the undo mechanism when erasing a block.
 struct TestUndoBlockErase : public ConversionPattern {
   TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp>(&getContext());
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp>();
+                      TerminatorOp, OneRegionOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {



More information about the llvm-branch-commits mailing list