[Mlir-commits] [mlir] 8faefe3 - [mlir][Transforms][NFC] Modularize block actions (#81237)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 14 08:15:34 PST 2024


Author: Matthias Springer
Date: 2024-02-14T17:15:30+01:00
New Revision: 8faefe36ed57c2dab2b50e76fd27045b908f8c1d

URL: https://github.com/llvm/llvm-project/commit/8faefe36ed57c2dab2b50e76fd27045b908f8c1d
DIFF: https://github.com/llvm/llvm-project/commit/8faefe36ed57c2dab2b50e76fd27045b908f8c1d.diff

LOG: [mlir][Transforms][NFC] Modularize block actions (#81237)

Throughout the rewrite process, the dialect conversion maintains a list
of "block actions" that can be rolled back upon failure. This commit
encapsulates the existing block actions into separate classes, making it
easier to add additional actions in the future.

This commit also renames "block actions" to "IR rewrites". In a
subsequent commit, an "operation rewrite" class that allows rolling back
movements of single operations is added. This is to support
`moveOpBefore` in the dialect conversion.

Rewrites have two methods: `commit()` commits a rewrite, after which it
can no longer be rolled back. `rollback()` undoes a rewrite, after which
it can no longer be committed.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index dbf5bf50d60e7f..9875f8668b65a8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,13 +154,12 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numBlockActions, unsigned numIgnoredOperations,
+                unsigned numRewrites, unsigned numIgnoredOperations,
                 unsigned numRootUpdates)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
-        numArgReplacements(numArgReplacements),
-        numBlockActions(numBlockActions),
+        numArgReplacements(numArgReplacements), numRewrites(numRewrites),
         numIgnoredOperations(numIgnoredOperations),
         numRootUpdates(numRootUpdates) {}
 
@@ -176,8 +175,8 @@ struct RewriterState {
   /// The current number of argument replacements queued.
   unsigned numArgReplacements;
 
-  /// The current number of block actions performed.
-  unsigned numBlockActions;
+  /// The current number of rewrites performed.
+  unsigned numRewrites;
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
@@ -235,86 +234,6 @@ struct OpReplacement {
   const TypeConverter *converter;
 };
 
-//===----------------------------------------------------------------------===//
-// BlockAction
-
-/// The kind of the block action performed during the rewrite.  Actions can be
-/// undone if the conversion fails.
-enum class BlockActionKind {
-  Create,
-  Erase,
-  Inline,
-  Move,
-  Split,
-  TypeConversion
-};
-
-/// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed before `insertBeforeBlock`.
-struct BlockPosition {
-  Region *region;
-  Block *insertBeforeBlock;
-};
-
-/// Information needed to undo inlining actions.
-/// - the source block
-/// - the first inlined operation (could be null if the source block was empty)
-/// - the last inlined operation (could be null if the source block was empty)
-struct InlineInfo {
-  Block *sourceBlock;
-  Operation *firstInlinedInst;
-  Operation *lastInlinedInst;
-};
-
-/// The storage class for an undoable block action (one of BlockActionKind),
-/// contains the information necessary to undo this action.
-struct BlockAction {
-  static BlockAction getCreate(Block *block) {
-    return {BlockActionKind::Create, block, {}};
-  }
-  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
-    return {BlockActionKind::Erase, block, {originalPosition}};
-  }
-  static BlockAction getInline(Block *block, Block *srcBlock,
-                               Block::iterator before) {
-    BlockAction action{BlockActionKind::Inline, block, {}};
-    action.inlineInfo = {srcBlock,
-                         srcBlock->empty() ? nullptr : &srcBlock->front(),
-                         srcBlock->empty() ? nullptr : &srcBlock->back()};
-    return action;
-  }
-  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
-    return {BlockActionKind::Move, block, {originalPosition}};
-  }
-  static BlockAction getSplit(Block *block, Block *originalBlock) {
-    BlockAction action{BlockActionKind::Split, block, {}};
-    action.originalBlock = originalBlock;
-    return action;
-  }
-  static BlockAction getTypeConversion(Block *block) {
-    return BlockAction{BlockActionKind::TypeConversion, block, {}};
-  }
-
-  // The action kind.
-  BlockActionKind kind;
-
-  // A pointer to the block that was created by the action.
-  Block *block;
-
-  union {
-    // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
-    // contains a pointer to the region that originally contained the block as
-    // well as the position of the block in that region.
-    BlockPosition originalPosition;
-    // In use if kind == BlockActionKind::Split and contains a pointer to the
-    // block that was split into two parts.
-    Block *originalBlock;
-    // In use if kind == BlockActionKind::Inline, and contains the information
-    // needed to undo the inlining.
-    InlineInfo inlineInfo;
-  };
-};
-
 //===----------------------------------------------------------------------===//
 // UnresolvedMaterialization
 
@@ -820,6 +739,251 @@ void ArgConverter::insertConversion(Block *newBlock,
   conversionInfo.insert({newBlock, std::move(info)});
 }
 
+//===----------------------------------------------------------------------===//
+// IR rewrites
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// An IR rewrite that can be committed (upon success) or rolled back (upon
+/// failure).
+///
+/// The dialect conversion keeps track of IR modifications (requested by the
+/// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites
+/// are directly applied to the IR as the rewriter API is used, some are applied
+/// partially, and some are delayed until the `IRRewrite` objects are committed.
+class IRRewrite {
+public:
+  /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
+  enum class Kind {
+    CreateBlock,
+    EraseBlock,
+    InlineBlock,
+    MoveBlock,
+    SplitBlock,
+    BlockTypeConversion
+  };
+
+  virtual ~IRRewrite() = default;
+
+  /// Roll back the rewrite.
+  virtual void rollback() = 0;
+
+  /// Commit the rewrite.
+  virtual void commit() {}
+
+  Kind getKind() const { return kind; }
+
+  static bool classof(const IRRewrite *rewrite) { return true; }
+
+protected:
+  IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
+      : kind(kind), rewriterImpl(rewriterImpl) {}
+
+  const Kind kind;
+  ConversionPatternRewriterImpl &rewriterImpl;
+};
+
+/// A block rewrite.
+class BlockRewrite : public IRRewrite {
+public:
+  /// Return the block that this rewrite operates on.
+  Block *getBlock() const { return block; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() >= Kind::CreateBlock &&
+           rewrite->getKind() <= Kind::BlockTypeConversion;
+  }
+
+protected:
+  BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Block *block)
+      : IRRewrite(kind, rewriterImpl), block(block) {}
+
+  // The block that this rewrite operates on.
+  Block *block;
+};
+
+/// Creation of a block. Block creations are immediately reflected in the IR.
+/// There is no extra work to commit the rewrite. During rollback, the newly
+/// created block is erased.
+class CreateBlockRewrite : public BlockRewrite {
+public:
+  CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
+      : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::CreateBlock;
+  }
+
+  void rollback() override {
+    // Unlink all of the operations within this block, they will be deleted
+    // separately.
+    auto &blockOps = block->getOperations();
+    while (!blockOps.empty())
+      blockOps.remove(blockOps.begin());
+    block->dropAllDefinedValueUses();
+    block->erase();
+  }
+};
+
+/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
+/// blocks are immediately unlinked, but only erased when the rewrite is
+/// committed. This makes it easier to rollback a block erasure: the block is
+/// simply inserted into its original location.
+class EraseBlockRewrite : public BlockRewrite {
+public:
+  EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                    Region *region, Block *insertBeforeBlock)
+      : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
+        insertBeforeBlock(insertBeforeBlock) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::EraseBlock;
+  }
+
+  ~EraseBlockRewrite() override {
+    assert(!block && "rewrite was neither rolled back nor committed");
+  }
+
+  void rollback() override {
+    // The block (owned by this rewrite) was not actually erased yet. It was
+    // just unlinked. Put it back into its original position.
+    assert(block && "expected block");
+    auto &blockList = region->getBlocks();
+    Region::iterator before = insertBeforeBlock
+                                  ? Region::iterator(insertBeforeBlock)
+                                  : blockList.end();
+    blockList.insert(before, block);
+    block = nullptr;
+  }
+
+  void commit() override {
+    // Erase the block.
+    assert(block && "expected block");
+    delete block;
+    block = nullptr;
+  }
+
+private:
+  // The region in which this block was previously contained.
+  Region *region;
+
+  // The original successor of this block before it was unlinked. "nullptr" if
+  // this block was the last block in the region.
+  Block *insertBeforeBlock;
+};
+
+/// Inlining of a block. This rewrite is immediately reflected in the IR.
+/// Note: This rewrite represents only the inlining of the operations. The
+/// erasure of the inlined block is a separate rewrite.
+class InlineBlockRewrite : public BlockRewrite {
+public:
+  InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                     Block *sourceBlock, Block::iterator before)
+      : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
+        sourceBlock(sourceBlock),
+        firstInlinedInst(sourceBlock->empty() ? nullptr
+                                              : &sourceBlock->front()),
+        lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+  }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::InlineBlock;
+  }
+
+  void rollback() override {
+    // Put the operations from the destination block (owned by the rewrite)
+    // back into the source block.
+    if (firstInlinedInst) {
+      assert(lastInlinedInst && "expected operation");
+      sourceBlock->getOperations().splice(sourceBlock->begin(),
+                                          block->getOperations(),
+                                          Block::iterator(firstInlinedInst),
+                                          ++Block::iterator(lastInlinedInst));
+    }
+  }
+
+private:
+  // The block that originally contained the operations.
+  Block *sourceBlock;
+
+  // The first inlined operation.
+  Operation *firstInlinedInst;
+
+  // The last inlined operation.
+  Operation *lastInlinedInst;
+};
+
+/// Moving of a block. This rewrite is immediately reflected in the IR.
+class MoveBlockRewrite : public BlockRewrite {
+public:
+  MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                   Region *region, Block *insertBeforeBlock)
+      : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
+        insertBeforeBlock(insertBeforeBlock) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::MoveBlock;
+  }
+
+  void rollback() override {
+    // Move the block back to its original position.
+    Region::iterator before =
+        insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
+    region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
+  }
+
+private:
+  // The region in which this block was previously contained.
+  Region *region;
+
+  // The original successor of this block before it was moved. "nullptr" if
+  // this block was the last block in the region.
+  Block *insertBeforeBlock;
+};
+
+/// Splitting of a block. This rewrite is immediately reflected in the IR.
+class SplitBlockRewrite : public BlockRewrite {
+public:
+  SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+                    Block *originalBlock)
+      : BlockRewrite(Kind::SplitBlock, rewriterImpl, block),
+        originalBlock(originalBlock) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::SplitBlock;
+  }
+
+  void rollback() override {
+    // Merge back the block that was split out.
+    originalBlock->getOperations().splice(originalBlock->end(),
+                                          block->getOperations());
+    block->dropAllDefinedValueUses();
+    block->erase();
+  }
+
+private:
+  // The original block from which this block was split.
+  Block *originalBlock;
+};
+
+/// Block type conversion. This rewrite is partially reflected in the IR.
+class BlockTypeConversionRewrite : public BlockRewrite {
+public:
+  BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                             Block *block)
+      : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::BlockTypeConversion;
+  }
+
+  // TODO: Block type conversions are currently committed in
+  // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+  void rollback() override;
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -848,13 +1012,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
-  /// Erase any blocks that were unlinked from their regions and stored in block
-  /// actions.
-  void eraseDanglingBlocks();
+  /// Append a rewrite. Rewrites are committed upon success and rolled back upon
+  /// failure.
+  template <typename RewriteTy, typename... Args>
+  void appendRewrite(Args &&...args) {
+    rewrites.push_back(
+        std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
+  }
 
-  /// Undo the block actions (motions, splits) one by one in reverse order until
-  /// "numActionsToKeep" actions remains.
-  void undoBlockActions(unsigned numActionsToKeep = 0);
+  /// Undo the rewrites (motions, splits) one by one in reverse order until
+  /// "numRewritesToKeep" rewrites remains.
+  void undoRewrites(unsigned numRewritesToKeep = 0);
 
   /// Remap the given values to those with potentially different types. Returns
   /// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -954,7 +1122,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   SmallVector<BlockArgument, 4> argReplacements;
 
   /// Ordered list of block operations (creations, splits, motions).
-  SmallVector<BlockAction, 4> blockActions;
+  SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
   /// A set of operations that should no longer be considered for legalization,
   /// but were not directly replace/erased/etc. by a pattern. These are
@@ -995,6 +1163,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
+void BlockTypeConversionRewrite::rollback() {
+  // Undo the type conversion.
+  rewriterImpl.argConverter.discardRewrites(block);
+}
+
 /// Detach any operations nested in the given operation from their parent
 /// blocks, and erase the given operation. This can be used when the nested
 /// operations are scheduled for erasure themselves, so deleting the regions of
@@ -1020,7 +1193,7 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   for (auto &state : rootUpdates)
     state.resetOperation();
 
-  undoBlockActions();
+  undoRewrites();
 
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
@@ -1083,8 +1256,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
   argConverter.applyRewrites(mapping);
 
-  // Now that the ops have been erased, also erase dangling blocks.
-  eraseDanglingBlocks();
+  // Commit all rewrites.
+  for (auto &rewrite : rewrites)
+    rewrite->commit();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1093,8 +1267,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       blockActions.size(), ignoredOps.size(),
-                       rootUpdates.size());
+                       rewrites.size(), ignoredOps.size(), rootUpdates.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1109,8 +1282,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     mapping.erase(replacedArg);
   argReplacements.resize(state.numArgReplacements);
 
-  // Undo any block actions.
-  undoBlockActions(state.numBlockActions);
+  // Undo any rewrites.
+  undoRewrites(state.numRewrites);
 
   // Reset any replaced operations and undo any saved mappings.
   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
@@ -1149,76 +1322,11 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     operationsWithChangedResults.pop_back();
 }
 
-void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
-  for (auto &action : blockActions)
-    if (action.kind == BlockActionKind::Erase)
-      delete action.block;
-}
-
-void ConversionPatternRewriterImpl::undoBlockActions(
-    unsigned numActionsToKeep) {
-  for (auto &action :
-       llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
-    switch (action.kind) {
-    // Delete the created block.
-    case BlockActionKind::Create: {
-      // Unlink all of the operations within this block, they will be deleted
-      // separately.
-      auto &blockOps = action.block->getOperations();
-      while (!blockOps.empty())
-        blockOps.remove(blockOps.begin());
-      action.block->dropAllDefinedValueUses();
-      action.block->erase();
-      break;
-    }
-    // Put the block (owned by action) back into its original position.
-    case BlockActionKind::Erase: {
-      auto &blockList = action.originalPosition.region->getBlocks();
-      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
-      blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
-                                          : blockList.end()),
-                       action.block);
-      break;
-    }
-    // Put the instructions from the destination block (owned by the action)
-    // back into the source block.
-    case BlockActionKind::Inline: {
-      Block *sourceBlock = action.inlineInfo.sourceBlock;
-      if (action.inlineInfo.firstInlinedInst) {
-        assert(action.inlineInfo.lastInlinedInst && "expected operation");
-        sourceBlock->getOperations().splice(
-            sourceBlock->begin(), action.block->getOperations(),
-            Block::iterator(action.inlineInfo.firstInlinedInst),
-            ++Block::iterator(action.inlineInfo.lastInlinedInst));
-      }
-      break;
-    }
-    // Move the block back to its original position.
-    case BlockActionKind::Move: {
-      Region *originalRegion = action.originalPosition.region;
-      Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
-      originalRegion->getBlocks().splice(
-          (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
-                             : originalRegion->end()),
-          action.block->getParent()->getBlocks(), action.block);
-      break;
-    }
-    // Merge back the block that was split out.
-    case BlockActionKind::Split: {
-      action.originalBlock->getOperations().splice(
-          action.originalBlock->end(), action.block->getOperations());
-      action.block->dropAllDefinedValueUses();
-      action.block->erase();
-      break;
-    }
-    // Undo the type conversion.
-    case BlockActionKind::TypeConversion: {
-      argConverter.discardRewrites(action.block);
-      break;
-    }
-    }
-  }
-  blockActions.resize(numActionsToKeep);
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+  for (auto &rewrite :
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+    rewrite->rollback();
+  rewrites.resize(numRewritesToKeep);
 }
 
 LogicalResult ConversionPatternRewriterImpl::remapValues(
@@ -1309,7 +1417,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     return failure();
   if (Block *newBlock = *result) {
     if (newBlock != block)
-      blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+      appendRewrite<BlockTypeConversionRewrite>(newBlock);
   }
   return result;
 }
@@ -1410,28 +1518,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
   Block *origNextBlock = block->getNextNode();
-  blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
+  appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   if (!previous) {
     // This is a newly created block.
-    blockActions.push_back(BlockAction::getCreate(block));
+    appendRewrite<CreateBlockRewrite>(block);
     return;
   }
   Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
-  blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
+  appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
 }
 
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
                                                      Block *continuation) {
-  blockActions.push_back(BlockAction::getSplit(continuation, block));
+  appendRewrite<SplitBlockRewrite>(continuation, block);
 }
 
 void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
     Block *block, Block *srcBlock, Block::iterator before) {
-  blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
+  appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
 }
 
 void ConversionPatternRewriterImpl::notifyMatchFailure(
@@ -1501,8 +1609,8 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
   for (Operation &op : *block)
     eraseOp(&op);
 
-  // Unlink the block from its parent region. The block is kept in the block
-  // action and will be actually destroyed when rewrites are applied. This
+  // Unlink the block from its parent region. The block is kept in the rewrite
+  // object and will be actually destroyed when rewrites are applied. This
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
   block->getParent()->getBlocks().remove(block);
@@ -1700,11 +1808,11 @@ class OperationLegalizer {
                                       RewriterState &curState);
 
   /// Legalizes the actions registered during the execution of a pattern.
-  LogicalResult legalizePatternBlockActions(Operation *op,
-                                            ConversionPatternRewriter &rewriter,
-                                            ConversionPatternRewriterImpl &impl,
-                                            RewriterState &state,
-                                            RewriterState &newState);
+  LogicalResult
+  legalizePatternBlockRewrites(Operation *op,
+                               ConversionPatternRewriter &rewriter,
+                               ConversionPatternRewriterImpl &impl,
+                               RewriterState &state, RewriterState &newState);
   LogicalResult legalizePatternCreatedOperations(
       ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
       RewriterState &state, RewriterState &newState);
@@ -1986,8 +2094,8 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
 
   // Legalize each of the actions registered during application.
   RewriterState newState = impl.getCurrentState();
-  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
-                                         newState)) ||
+  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
+                                          newState)) ||
       failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
       failed(legalizePatternCreatedOperations(rewriter, impl, curState,
                                               newState))) {
@@ -1998,7 +2106,7 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
   return success();
 }
 
-LogicalResult OperationLegalizer::legalizePatternBlockActions(
+LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     Operation *op, ConversionPatternRewriter &rewriter,
     ConversionPatternRewriterImpl &impl, RewriterState &state,
     RewriterState &newState) {
@@ -2006,22 +2114,22 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
 
   // If the pattern moved or created any blocks, make sure the types of block
   // arguments get legalized.
-  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
-       ++i) {
-    auto &action = impl.blockActions[i];
-    if (action.kind == BlockActionKind::TypeConversion ||
-        action.kind == BlockActionKind::Erase)
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
+    if (!rewrite)
+      continue;
+    Block *block = rewrite->getBlock();
+    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
       continue;
     // Only check blocks outside of the current operation.
-    Operation *parentOp = action.block->getParentOp();
-    if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
+    Operation *parentOp = block->getParentOp();
+    if (!parentOp || parentOp == op || block->getNumArguments() == 0)
       continue;
 
     // If the region of the block has a type converter, try to convert the block
     // directly.
-    if (auto *converter =
-            impl.argConverter.getConverter(action.block->getParent())) {
-      if (failed(impl.convertBlockSignature(action.block, converter))) {
+    if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+      if (failed(impl.convertBlockSignature(block, converter))) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();
@@ -2042,9 +2150,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
     // If this operation should be considered for re-legalization, try it.
     if (operationsToIgnore.insert(parentOp).second &&
         failed(legalize(parentOp, rewriter))) {
-      LLVM_DEBUG(logFailure(
-          impl.logger, "operation '{0}'({1}) became illegal after block action",
-          parentOp->getName(), parentOp));
+      LLVM_DEBUG(logFailure(impl.logger,
+                            "operation '{0}'({1}) became illegal after rewrite",
+                            parentOp->getName(), parentOp));
       return failure();
     }
   }


        


More information about the Mlir-commits mailing list