[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Modularize block actions (PR #81237)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 9 01:39:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Throughout the rewrite process, the dialect conversion maintains a list of "block actions" that can be rolled back upon failure. This commit encapsulates the existing block actions into separate classes, making it easier to add additional actions in the future.

This commit also renames "block actions" to "rewrite actions". A subsequent commit will add an "operation action" that allows rolling back the movement of individual operations; this is needed to support `moveOpBefore` in the dialect conversion.

Rewrite actions have two methods: `commit()` commits an action, after which it can no longer be rolled back; `rollback()` undoes an action, after which it can no longer be committed.

---

Patch is 23.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81237.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+283-183) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e41231d7cbd390..44c107c8733f3d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,13 +154,13 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numBlockActions, unsigned numIgnoredOperations,
+                unsigned numRewriteActions, unsigned numIgnoredOperations,
                 unsigned numRootUpdates)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements),
-        numBlockActions(numBlockActions),
+        numRewriteActions(numRewriteActions),
         numIgnoredOperations(numIgnoredOperations),
         numRootUpdates(numRootUpdates) {}
 
@@ -176,8 +176,8 @@ struct RewriterState {
   /// The current number of argument replacements queued.
   unsigned numArgReplacements;
 
-  /// The current number of block actions performed.
-  unsigned numBlockActions;
+  /// The current number of rewrite actions performed.
+  unsigned numRewriteActions;
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
@@ -235,86 +235,6 @@ struct OpReplacement {
   const TypeConverter *converter;
 };
 
-//===----------------------------------------------------------------------===//
-// BlockAction
-
-/// The kind of the block action performed during the rewrite.  Actions can be
-/// undone if the conversion fails.
-enum class BlockActionKind {
-  Create,
-  Erase,
-  Inline,
-  Move,
-  Split,
-  TypeConversion
-};
-
-/// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed before `insertBeforeBlock`.
-struct BlockPosition {
-  Region *region;
-  Block *insertBeforeBlock;
-};
-
-/// Information needed to undo inlining actions.
-/// - the source block
-/// - the first inlined operation (could be null if the source block was empty)
-/// - the last inlined operation (could be null if the source block was empty)
-struct InlineInfo {
-  Block *sourceBlock;
-  Operation *firstInlinedInst;
-  Operation *lastInlinedInst;
-};
-
-/// The storage class for an undoable block action (one of BlockActionKind),
-/// contains the information necessary to undo this action.
-struct BlockAction {
-  static BlockAction getCreate(Block *block) {
-    return {BlockActionKind::Create, block, {}};
-  }
-  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
-    return {BlockActionKind::Erase, block, {originalPosition}};
-  }
-  static BlockAction getInline(Block *block, Block *srcBlock,
-                               Block::iterator before) {
-    BlockAction action{BlockActionKind::Inline, block, {}};
-    action.inlineInfo = {srcBlock,
-                         srcBlock->empty() ? nullptr : &srcBlock->front(),
-                         srcBlock->empty() ? nullptr : &srcBlock->back()};
-    return action;
-  }
-  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
-    return {BlockActionKind::Move, block, {originalPosition}};
-  }
-  static BlockAction getSplit(Block *block, Block *originalBlock) {
-    BlockAction action{BlockActionKind::Split, block, {}};
-    action.originalBlock = originalBlock;
-    return action;
-  }
-  static BlockAction getTypeConversion(Block *block) {
-    return BlockAction{BlockActionKind::TypeConversion, block, {}};
-  }
-
-  // The action kind.
-  BlockActionKind kind;
-
-  // A pointer to the block that was created by the action.
-  Block *block;
-
-  union {
-    // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
-    // contains a pointer to the region that originally contained the block as
-    // well as the position of the block in that region.
-    BlockPosition originalPosition;
-    // In use if kind == BlockActionKind::Split and contains a pointer to the
-    // block that was split into two parts.
-    Block *originalBlock;
-    // In use if kind == BlockActionKind::Inline, and contains the information
-    // needed to undo the inlining.
-    InlineInfo inlineInfo;
-  };
-};
-
 //===----------------------------------------------------------------------===//
 // UnresolvedMaterialization
 
@@ -820,6 +740,238 @@ void ArgConverter::insertConversion(Block *newBlock,
   conversionInfo.insert({newBlock, std::move(info)});
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteAction
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// An IR rewrite that can be committed (upon success) or rolled back (upon
+/// failure).
+class RewriteAction {
+public:
+  /// The kind of the action performed during the rewrite. Actions can be
+  /// undone if the conversion fails.
+  enum class Kind {
+    CreateBlock,
+    EraseBlock,
+    InlineBlock,
+    MoveBlock,
+    SplitBlock,
+    BlockTypeConversion
+  };
+
+  virtual ~RewriteAction() = default;
+
+  /// Roll back the action.
+  virtual void rollback() = 0;
+
+  /// Commit the action.
+  virtual void commit() {}
+
+  Kind getKind() const { return kind; }
+
+  static bool classof(const RewriteAction *action) { return true; }
+
+protected:
+  RewriteAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
+      : kind(kind), rewriterImpl(rewriterImpl) {}
+
+  const Kind kind;
+  ConversionPatternRewriterImpl &rewriterImpl;
+};
+
+/// A block rewrite.
+class BlockAction : public RewriteAction {
+public:
+  /// Return the block that this action operates on.
+  Block *getBlock() const { return block; }
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() >= Kind::CreateBlock &&
+           action->getKind() <= Kind::BlockTypeConversion;
+  }
+
+protected:
+  BlockAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+              Block *block)
+      : RewriteAction(kind, rewriterImpl), block(block) {}
+
+  // The block that this action operates on.
+  Block *block;
+};
+
+/// Rewrite action that represent the creation of a block.
+class CreateBlockAction : public BlockAction {
+public:
+  CreateBlockAction(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
+      : BlockAction(Kind::CreateBlock, rewriterImpl, block) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::CreateBlock;
+  }
+
+  void rollback() override {
+    // Unlink all of the operations within this block, they will be deleted
+    // separately.
+    auto &blockOps = block->getOperations();
+    while (!blockOps.empty())
+      blockOps.remove(blockOps.begin());
+    block->dropAllDefinedValueUses();
+    block->erase();
+  }
+};
+
+/// Rewrite action that represent the erasure of a block.
+class EraseBlockAction : public BlockAction {
+public:
+  EraseBlockAction(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                   Region *region, Block *insertBeforeBlock)
+      : BlockAction(Kind::EraseBlock, rewriterImpl, block), region(region),
+        insertBeforeBlock(insertBeforeBlock) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::EraseBlock;
+  }
+
+  ~EraseBlockAction() override {
+    assert(!block && "action was neither rolled back nor committed");
+  }
+
+  void rollback() override {
+    // The block (owned by this action) was not actually erased yet. It was just
+    // unlinked. Put it back into its original position.
+    assert(block && "expected block");
+    auto &blockList = region->getBlocks();
+    Region::iterator before = insertBeforeBlock
+                                  ? Region::iterator(insertBeforeBlock)
+                                  : blockList.end();
+    blockList.insert(before, block);
+    block = nullptr;
+  }
+
+  void commit() override {
+    // Erase the block.
+    assert(block && "expected block");
+    delete block;
+    block = nullptr;
+  }
+
+private:
+  // The region in which this block was previously contained.
+  Region *region;
+
+  // The original successor of this block before it was unlinked. "nullptr" if
+  // this block was the only block in the region.
+  Block *insertBeforeBlock;
+};
+
+/// Rewrite action that represent the inlining of a block.
+class InlineBlockAction : public BlockAction {
+public:
+  InlineBlockAction(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                    Block *sourceBlock, Block::iterator before)
+      : BlockAction(Kind::InlineBlock, rewriterImpl, block),
+        sourceBlock(sourceBlock),
+        firstInlinedInst(sourceBlock->empty() ? nullptr
+                                              : &sourceBlock->front()),
+        lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+  }
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::InlineBlock;
+  }
+
+  void rollback() override {
+    // Put the operations from the destination block (owned by the action)
+    // back into the source block.
+    if (firstInlinedInst) {
+      assert(lastInlinedInst && "expected operation");
+      sourceBlock->getOperations().splice(sourceBlock->begin(),
+                                          block->getOperations(),
+                                          Block::iterator(firstInlinedInst),
+                                          ++Block::iterator(lastInlinedInst));
+    }
+  }
+
+private:
+  // The block that originally contained the operations.
+  Block *sourceBlock;
+
+  // The first inlined operation.
+  Operation *firstInlinedInst;
+
+  // The last inlined operation.
+  Operation *lastInlinedInst;
+};
+
+/// Rewrite action that represent the moving of a block.
+class MoveBlockAction : public BlockAction {
+public:
+  MoveBlockAction(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                  Region *region, Block *insertBeforeBlock)
+      : BlockAction(Kind::MoveBlock, rewriterImpl, block), region(region),
+        insertBeforeBlock(insertBeforeBlock) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::MoveBlock;
+  }
+
+  void rollback() override {
+    // Move the block back to its original position.
+    Region::iterator before =
+        insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
+    region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
+  }
+
+private:
+  // The region in which this block was previously contained.
+  Region *region;
+
+  // The original successor of this block before it was moved. "nullptr" if
+  // this block was the only block in the region.
+  Block *insertBeforeBlock;
+};
+
+/// Rewrite action that represent the splitting of a block.
+class SplitBlockAction : public BlockAction {
+public:
+  SplitBlockAction(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                   Block *originalBlock)
+      : BlockAction(Kind::SplitBlock, rewriterImpl, block),
+        originalBlock(originalBlock) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::SplitBlock;
+  }
+
+  void rollback() override {
+    // Merge back the block that was split out.
+    originalBlock->getOperations().splice(originalBlock->end(),
+                                          block->getOperations());
+    block->dropAllDefinedValueUses();
+    block->erase();
+  }
+
+private:
+  // The original block from which this block was split.
+  Block *originalBlock;
+};
+
+/// Rewrite action that represent a block type conversion.
+class BlockTypeConversionAction : public BlockAction {
+public:
+  BlockTypeConversionAction(ConversionPatternRewriterImpl &rewriterImpl,
+                            Block *block)
+      : BlockAction(Kind::BlockTypeConversion, rewriterImpl, block) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::BlockTypeConversion;
+  }
+
+  void rollback() override;
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -848,13 +1000,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
-  /// Erase any blocks that were unlinked from their regions and stored in block
-  /// actions.
-  void eraseDanglingBlocks();
+  /// Append a rewrite action. Actions are committed upon success and rolled
+  /// back upon failure.
+  template <typename ActionTy, typename... Args>
+  void appendRewriteAction(Args &&...args) {
+    rewriteActions.push_back(
+        std::make_unique<ActionTy>(*this, std::forward<Args>(args)...));
+  }
 
-  /// Undo the block actions (motions, splits) one by one in reverse order until
-  /// "numActionsToKeep" actions remains.
-  void undoBlockActions(unsigned numActionsToKeep = 0);
+  /// Undo the rewrite actions (motions, splits) one by one in reverse order
+  /// until "numActionsToKeep" actions remains.
+  void undoRewriteActions(unsigned numActionsToKeep = 0);
 
   /// Remap the given values to those with potentially different types. Returns
   /// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -954,7 +1110,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   SmallVector<BlockArgument, 4> argReplacements;
 
   /// Ordered list of block operations (creations, splits, motions).
-  SmallVector<BlockAction, 4> blockActions;
+  SmallVector<std::unique_ptr<RewriteAction>> rewriteActions;
 
   /// A set of operations that should no longer be considered for legalization,
   /// but were not directly replace/erased/etc. by a pattern. These are
@@ -995,6 +1151,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
+void BlockTypeConversionAction::rollback() {
+  // Undo the type conversion.
+  rewriterImpl.argConverter.discardRewrites(block);
+}
+
 /// Detach any operations nested in the given operation from their parent
 /// blocks, and erase the given operation. This can be used when the nested
 /// operations are scheduled for erasure themselves, so deleting the regions of
@@ -1020,7 +1181,7 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   for (auto &state : rootUpdates)
     state.resetOperation();
 
-  undoBlockActions();
+  undoRewriteActions();
 
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
@@ -1083,8 +1244,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
   argConverter.applyRewrites(mapping);
 
-  // Now that the ops have been erased, also erase dangling blocks.
-  eraseDanglingBlocks();
+  // Commit all rewrite actions.
+  for (auto &action : rewriteActions)
+    action->commit();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1093,7 +1255,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       blockActions.size(), ignoredOps.size(),
+                       rewriteActions.size(), ignoredOps.size(),
                        rootUpdates.size());
 }
 
@@ -1109,8 +1271,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     mapping.erase(replacedArg);
   argReplacements.resize(state.numArgReplacements);
 
-  // Undo any block actions.
-  undoBlockActions(state.numBlockActions);
+  // Undo any rewrite actions.
+  undoRewriteActions(state.numRewriteActions);
 
   // Reset any replaced operations and undo any saved mappings.
   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
@@ -1149,76 +1311,12 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     operationsWithChangedResults.pop_back();
 }
 
-void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
-  for (auto &action : blockActions)
-    if (action.kind == BlockActionKind::Erase)
-      delete action.block;
-}
-
-void ConversionPatternRewriterImpl::undoBlockActions(
+void ConversionPatternRewriterImpl::undoRewriteActions(
     unsigned numActionsToKeep) {
   for (auto &action :
-       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
-    switch (action.kind) {
-    // Delete the created block.
-    case BlockActionKind::Create: {
-      // Unlink all of the operations within this block, they will be deleted
-      // separately.
-      auto &blockOps = action.block->getOperations();
-      while (!blockOps.empty())
-        blockOps.remove(blockOps.begin());
-      action.block->dropAllDefinedValueUses();
-      action.block->erase();
-      break;
-    }
-    // Put the block (owned by action) back into its original position.
-    case BlockActionKind::Erase: {
-      auto &blockList = action.originalPosition.region->getBlocks();
-      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
-      blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
-                                          : blockList.end()),
-                       action.block);
-      break;
-    }
-    // Put the instructions from the destination block (owned by the action)
-    // back into the source block.
-    case BlockActionKind::Inline: {
-      Block *sourceBlock = action.inlineInfo.sourceBlock;
-      if (action.inlineInfo.firstInlinedInst) {
-        assert(action.inlineInfo.lastInlinedInst && "expected operation");
-        sourceBlock->getOperations().splice(
-            sourceBlock->begin(), action.block->getOperations(),
-            Block::iterator(action.inlineInfo.firstInlinedInst),
-            ++Block::iterator(action.inlineInfo.lastInlinedInst));
-      }
-      break;
-    }
-    // Move the block back to its original position.
-    case BlockActionKind::Move: {
-      Region *originalRegion = action.originalPosition.region;
-      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
-      originalRegion->getBlocks().splice(
-          (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
-                             : originalRegion->end()),
-          action.block->getParent()->getBlocks(), action.block);
-      break;
-    }
-    // Merge back the block that was split out.
-    case BlockActionKind::Split: {
-      action.originalBlock->getOperations().splice(
-          action.originalBlock->end(), action.block->getOperations());
-      action.block->dropAllDefinedValueUses();
-      action.block->erase();
-      break;
-    }
-    // Undo the type conversion.
-    case BlockActionKind::TypeConversion: {
-      argConverter.discardRewrites(action.block);
-      break;
-    }
-    }
-  }
-  blockActions.resize(numActionsToKeep);
+       llvm::reverse(llvm::drop_begin(rewriteActions, numActionsToKeep)))
+    action->rollback();
+  rewriteActions.resize(numActionsToKeep);
 }
 
 LogicalResult ConversionPatternRewriterImpl::remapValues(
@@ -1309,7 +1407,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     return failure();
   if (Block *newBlock = *result) {
     if (newBlock != block)
-      blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+      appendRewriteAction<BlockTypeConversionAction>(newBlock);
   }
   return result;
 }
@@ -1410,28 +1508,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
   Block *origNextBlock = block->getNextNode();
-  blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
+  appendRewriteAction<EraseBlockAction>(block, region, origNextBlock);
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Re...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81237


More information about the llvm-branch-commits mailing list