[llvm-branch-commits] [mlir] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion (PR #81240)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 12 01:10:02 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81240
>From c60c43bcd2296715ceca83a3f9666433883ec303 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 12 Feb 2024 09:05:50 +0000
Subject: [PATCH 1/2] [mlir][Transforms][WIP] RewriteAction
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
.../Transforms/Utils/DialectConversion.cpp | 504 +++++++++++-------
1 file changed, 306 insertions(+), 198 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e41231d7cbd390..edca84e5a73f04 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,13 +154,12 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numBlockActions, unsigned numIgnoredOperations,
+ unsigned numRewrites, unsigned numIgnoredOperations,
unsigned numRootUpdates)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
- numArgReplacements(numArgReplacements),
- numBlockActions(numBlockActions),
+ numArgReplacements(numArgReplacements), numRewrites(numRewrites),
numIgnoredOperations(numIgnoredOperations),
numRootUpdates(numRootUpdates) {}
@@ -176,8 +175,8 @@ struct RewriterState {
/// The current number of argument replacements queued.
unsigned numArgReplacements;
- /// The current number of block actions performed.
- unsigned numBlockActions;
+ /// The current number of rewrites performed.
+ unsigned numRewrites;
/// The current number of ignored operations.
unsigned numIgnoredOperations;
@@ -235,86 +234,6 @@ struct OpReplacement {
const TypeConverter *converter;
};
-//===----------------------------------------------------------------------===//
-// BlockAction
-
-/// The kind of the block action performed during the rewrite. Actions can be
-/// undone if the conversion fails.
-enum class BlockActionKind {
- Create,
- Erase,
- Inline,
- Move,
- Split,
- TypeConversion
-};
-
-/// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed before `insertBeforeBlock`.
-struct BlockPosition {
- Region *region;
- Block *insertBeforeBlock;
-};
-
-/// Information needed to undo inlining actions.
-/// - the source block
-/// - the first inlined operation (could be null if the source block was empty)
-/// - the last inlined operation (could be null if the source block was empty)
-struct InlineInfo {
- Block *sourceBlock;
- Operation *firstInlinedInst;
- Operation *lastInlinedInst;
-};
-
-/// The storage class for an undoable block action (one of BlockActionKind),
-/// contains the information necessary to undo this action.
-struct BlockAction {
- static BlockAction getCreate(Block *block) {
- return {BlockActionKind::Create, block, {}};
- }
- static BlockAction getErase(Block *block, BlockPosition originalPosition) {
- return {BlockActionKind::Erase, block, {originalPosition}};
- }
- static BlockAction getInline(Block *block, Block *srcBlock,
- Block::iterator before) {
- BlockAction action{BlockActionKind::Inline, block, {}};
- action.inlineInfo = {srcBlock,
- srcBlock->empty() ? nullptr : &srcBlock->front(),
- srcBlock->empty() ? nullptr : &srcBlock->back()};
- return action;
- }
- static BlockAction getMove(Block *block, BlockPosition originalPosition) {
- return {BlockActionKind::Move, block, {originalPosition}};
- }
- static BlockAction getSplit(Block *block, Block *originalBlock) {
- BlockAction action{BlockActionKind::Split, block, {}};
- action.originalBlock = originalBlock;
- return action;
- }
- static BlockAction getTypeConversion(Block *block) {
- return BlockAction{BlockActionKind::TypeConversion, block, {}};
- }
-
- // The action kind.
- BlockActionKind kind;
-
- // A pointer to the block that was created by the action.
- Block *block;
-
- union {
- // In use if kind == BlockActionKind::Inline or BlockActionKind::Erase, and
- // contains a pointer to the region that originally contained the block as
- // well as the position of the block in that region.
- BlockPosition originalPosition;
- // In use if kind == BlockActionKind::Split and contains a pointer to the
- // block that was split into two parts.
- Block *originalBlock;
- // In use if kind == BlockActionKind::Inline, and contains the information
- // needed to undo the inlining.
- InlineInfo inlineInfo;
- };
-};
-
//===----------------------------------------------------------------------===//
// UnresolvedMaterialization
@@ -820,6 +739,251 @@ void ArgConverter::insertConversion(Block *newBlock,
conversionInfo.insert({newBlock, std::move(info)});
}
+//===----------------------------------------------------------------------===//
+// IR rewrites
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// An IR rewrite that can be committed (upon success) or rolled back (upon
+/// failure).
+///
+/// The dialect conversion keeps track of IR modifications (requested by the
+/// user via the rewriter API) in `IRRewrite` objects. Some kinds of rewrites
+/// are directly applied to the IR as the rewriter API is used, some are applied
+/// partially, and some are delayed until the `IRRewrite` objects are committed.
+class IRRewrite {
+public:
+ /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
+ enum class Kind {
+ CreateBlock,
+ EraseBlock,
+ InlineBlock,
+ MoveBlock,
+ SplitBlock,
+ BlockTypeConversion
+ };
+
+ virtual ~IRRewrite() = default;
+
+ /// Roll back the rewrite. Every rewrite kind must implement this.
+ virtual void rollback() = 0;
+
+ /// Commit the rewrite. By default, there is nothing to commit.
+ virtual void commit() {}
+
+ Kind getKind() const { return kind; }
+
+ static bool classof(const IRRewrite *rewrite) { return true; }
+
+protected:
+ IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
+ : kind(kind), rewriterImpl(rewriterImpl) {}
+
+ const Kind kind;
+ ConversionPatternRewriterImpl &rewriterImpl;
+};
+
+/// Base class of all rewrites that operate on a block (see `classof`).
+class BlockRewrite : public IRRewrite {
+public:
+ /// Return the block that this rewrite operates on.
+ Block *getBlock() const { return block; }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() >= Kind::CreateBlock &&
+ rewrite->getKind() <= Kind::BlockTypeConversion;
+ }
+
+protected:
+ BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block)
+ : IRRewrite(kind, rewriterImpl), block(block) {}
+
+ // The block that this rewrite operates on.
+ Block *block;
+};
+
+/// Creation of a block. Block creations are immediately reflected in the IR.
+/// There is no extra work to commit the rewrite. During rollback, the newly
+/// created block is erased.
+class CreateBlockRewrite : public BlockRewrite {
+public:
+ CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block)
+ : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::CreateBlock;
+ }
+
+ void rollback() override {
+ // Unlink all of the operations within this block; they will be deleted
+ // separately. Then drop all uses of values defined in the block and erase it.
+ auto &blockOps = block->getOperations();
+ while (!blockOps.empty())
+ blockOps.remove(blockOps.begin());
+ block->dropAllDefinedValueUses();
+ block->erase();
+ }
+};
+
+/// Erasure of a block. Block erasures are partially reflected in the IR. Erased
+/// blocks are immediately unlinked, but only erased when the rewrite is
+/// committed. This makes it easier to roll back a block erasure: the block is
+/// simply inserted into its original location.
+class EraseBlockRewrite : public BlockRewrite {
+public:
+ EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Region *region, Block *insertBeforeBlock)
+ : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region),
+ insertBeforeBlock(insertBeforeBlock) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::EraseBlock;
+ }
+
+ ~EraseBlockRewrite() override {
+ assert(!block && "rewrite was neither rolled back nor committed");
+ }
+
+ void rollback() override {
+ // The block (owned by this rewrite) was not actually erased yet. It was
+ // just unlinked. Put it back into its original position.
+ assert(block && "expected block");
+ auto &blockList = region->getBlocks();
+ Region::iterator before = insertBeforeBlock
+ ? Region::iterator(insertBeforeBlock)
+ : blockList.end();
+ blockList.insert(before, block);
+ block = nullptr;
+ }
+
+ void commit() override {
+ // Erase (delete) the unlinked block; it is no longer owned by the rewrite.
+ assert(block && "expected block");
+ delete block;
+ block = nullptr;
+ }
+
+private:
+ // The region in which this block was previously contained.
+ Region *region;
+
+ // The original successor of this block before it was unlinked. "nullptr" if
+ // this block was the last block in the region.
+ Block *insertBeforeBlock;
+};
+
+/// Inlining of a block. This rewrite is immediately reflected in the IR.
+/// Note: This rewrite represents only the inlining of the operations. The
+/// erasure of the inlined block is a separate rewrite. (`before` is unused.)
+class InlineBlockRewrite : public BlockRewrite {
+public:
+ InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *sourceBlock, Block::iterator before)
+ : BlockRewrite(Kind::InlineBlock, rewriterImpl, block),
+ sourceBlock(sourceBlock),
+ firstInlinedInst(sourceBlock->empty() ? nullptr
+ : &sourceBlock->front()),
+ lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+ }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::InlineBlock;
+ }
+
+ void rollback() override {
+ // Put the operations from the destination block (the block that this
+ // rewrite operates on) back at the front of the source block.
+ if (firstInlinedInst) {
+ assert(lastInlinedInst && "expected operation");
+ sourceBlock->getOperations().splice(sourceBlock->begin(),
+ block->getOperations(),
+ Block::iterator(firstInlinedInst),
+ ++Block::iterator(lastInlinedInst));
+ }
+ }
+
+private:
+ // The block that originally contained the operations.
+ Block *sourceBlock;
+
+ // The first inlined operation. "nullptr" if the source block was empty.
+ Operation *firstInlinedInst;
+
+ // The last inlined operation. "nullptr" if the source block was empty.
+ Operation *lastInlinedInst;
+};
+
+/// Moving of a block. This rewrite is immediately reflected in the IR.
+class MoveBlockRewrite : public BlockRewrite {
+public:
+ MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Region *region, Block *insertBeforeBlock)
+ : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
+ insertBeforeBlock(insertBeforeBlock) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::MoveBlock;
+ }
+
+ void rollback() override {
+ // Splice the block from its current region back to its original position.
+ Region::iterator before =
+ insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end();
+ region->getBlocks().splice(before, block->getParent()->getBlocks(), block);
+ }
+
+private:
+ // The region in which this block was previously contained.
+ Region *region;
+
+ // The original successor of this block before it was moved. "nullptr" if
+ // this block was the last block in the region.
+ Block *insertBeforeBlock;
+};
+
+/// Splitting of a block. This rewrite is immediately reflected in the IR.
+class SplitBlockRewrite : public BlockRewrite {
+public:
+ SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *originalBlock)
+ : BlockRewrite(Kind::SplitBlock, rewriterImpl, block),
+ originalBlock(originalBlock) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::SplitBlock;
+ }
+
+ void rollback() override {
+ // Merge the split-out block back into the original block and erase it.
+ originalBlock->getOperations().splice(originalBlock->end(),
+ block->getOperations());
+ block->dropAllDefinedValueUses();
+ block->erase();
+ }
+
+private:
+ // The original block from which this block was split.
+ Block *originalBlock;
+};
+
+/// Block type conversion that produced `block`. Partially reflected in the IR.
+class BlockTypeConversionRewrite : public BlockRewrite {
+public:
+ BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::BlockTypeConversion;
+ }
+
+ // TODO: Block type conversions are currently committed in
+ // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+ void rollback() override;
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -848,13 +1012,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
- /// Erase any blocks that were unlinked from their regions and stored in block
- /// actions.
- void eraseDanglingBlocks();
+ /// Append a rewrite. Rewrites are committed upon success and rolled back upon
+ /// failure. `ActionTy` is constructed from `*this` followed by `args`.
+ template <typename ActionTy, typename... Args>
+ void appendRewrite(Args &&...args) {
+ rewrites.push_back(
+ std::make_unique<ActionTy>(*this, std::forward<Args>(args)...));
+ }
- /// Undo the block actions (motions, splits) one by one in reverse order until
- /// "numActionsToKeep" actions remains.
- void undoBlockActions(unsigned numActionsToKeep = 0);
+ /// Undo the rewrites (motions, splits) one by one in reverse order until
+ /// "numRewritesToKeep" rewrites remain.
+ void undoRewrites(unsigned numRewritesToKeep = 0);
/// Remap the given values to those with potentially different types. Returns
/// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -954,7 +1122,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
SmallVector<BlockArgument, 4> argReplacements;
/// Ordered list of block operations (creations, splits, motions).
- SmallVector<BlockAction, 4> blockActions;
+ SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization,
/// but were not directly replace/erased/etc. by a pattern. These are
@@ -995,6 +1163,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
} // namespace detail
} // namespace mlir
+void BlockTypeConversionRewrite::rollback() {
+ // Undo the type conversion: the ArgConverter discards its rewrites of `block`.
+ rewriterImpl.argConverter.discardRewrites(block);
+}
+
/// Detach any operations nested in the given operation from their parent
/// blocks, and erase the given operation. This can be used when the nested
/// operations are scheduled for erasure themselves, so deleting the regions of
@@ -1020,7 +1193,7 @@ void ConversionPatternRewriterImpl::discardRewrites() {
for (auto &state : rootUpdates)
state.resetOperation();
- undoBlockActions();
+ undoRewrites();
// Remove any newly created ops.
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
@@ -1083,8 +1256,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
argConverter.applyRewrites(mapping);
- // Now that the ops have been erased, also erase dangling blocks.
- eraseDanglingBlocks();
+ // Commit all rewrites.
+ for (auto &rewrite : rewrites)
+ rewrite->commit();
}
//===----------------------------------------------------------------------===//
@@ -1093,8 +1267,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
replacements.size(), argReplacements.size(),
- blockActions.size(), ignoredOps.size(),
- rootUpdates.size());
+ rewrites.size(), ignoredOps.size(), rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1109,8 +1282,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
mapping.erase(replacedArg);
argReplacements.resize(state.numArgReplacements);
- // Undo any block actions.
- undoBlockActions(state.numBlockActions);
+ // Undo any rewrites.
+ undoRewrites(state.numRewrites);
// Reset any replaced operations and undo any saved mappings.
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
@@ -1149,76 +1322,11 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
operationsWithChangedResults.pop_back();
}
-void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
- for (auto &action : blockActions)
- if (action.kind == BlockActionKind::Erase)
- delete action.block;
-}
-
-void ConversionPatternRewriterImpl::undoBlockActions(
- unsigned numActionsToKeep) {
- for (auto &action :
- llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
- switch (action.kind) {
- // Delete the created block.
- case BlockActionKind::Create: {
- // Unlink all of the operations within this block, they will be deleted
- // separately.
- auto &blockOps = action.block->getOperations();
- while (!blockOps.empty())
- blockOps.remove(blockOps.begin());
- action.block->dropAllDefinedValueUses();
- action.block->erase();
- break;
- }
- // Put the block (owned by action) back into its original position.
- case BlockActionKind::Erase: {
- auto &blockList = action.originalPosition.region->getBlocks();
- Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
- blockList.insert((insertBeforeBlock ? Region::iterator(insertBeforeBlock)
- : blockList.end()),
- action.block);
- break;
- }
- // Put the instructions from the destination block (owned by the action)
- // back into the source block.
- case BlockActionKind::Inline: {
- Block *sourceBlock = action.inlineInfo.sourceBlock;
- if (action.inlineInfo.firstInlinedInst) {
- assert(action.inlineInfo.lastInlinedInst && "expected operation");
- sourceBlock->getOperations().splice(
- sourceBlock->begin(), action.block->getOperations(),
- Block::iterator(action.inlineInfo.firstInlinedInst),
- ++Block::iterator(action.inlineInfo.lastInlinedInst));
- }
- break;
- }
- // Move the block back to its original position.
- case BlockActionKind::Move: {
- Region *originalRegion = action.originalPosition.region;
- Block *insertBeforeBlock = action.originalPosition.insertBeforeBlock;
- originalRegion->getBlocks().splice(
- (insertBeforeBlock ? Region::iterator(insertBeforeBlock)
- : originalRegion->end()),
- action.block->getParent()->getBlocks(), action.block);
- break;
- }
- // Merge back the block that was split out.
- case BlockActionKind::Split: {
- action.originalBlock->getOperations().splice(
- action.originalBlock->end(), action.block->getOperations());
- action.block->dropAllDefinedValueUses();
- action.block->erase();
- break;
- }
- // Undo the type conversion.
- case BlockActionKind::TypeConversion: {
- argConverter.discardRewrites(action.block);
- break;
- }
- }
- }
- blockActions.resize(numActionsToKeep);
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+ for (auto &rewrite :
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+ rewrite->rollback();
+ rewrites.resize(numRewritesToKeep);
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
@@ -1309,7 +1417,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
return failure();
if (Block *newBlock = *result) {
if (newBlock != block)
- blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+ appendRewrite<BlockTypeConversionRewrite>(newBlock);
}
return result;
}
@@ -1410,28 +1518,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
Region *region = block->getParent();
Block *origNextBlock = block->getNextNode();
- blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
+ appendRewrite<EraseBlockRewrite>(block, region, origNextBlock);
}
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (!previous) {
// This is a newly created block.
- blockActions.push_back(BlockAction::getCreate(block));
+ appendRewrite<CreateBlockRewrite>(block);
return;
}
Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
- blockActions.push_back(BlockAction::getMove(block, {previous, prevBlock}));
+ appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
}
void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
Block *continuation) {
- blockActions.push_back(BlockAction::getSplit(continuation, block));
+ appendRewrite<SplitBlockRewrite>(continuation, block);
}
void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
Block *block, Block *srcBlock, Block::iterator before) {
- blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
+ appendRewrite<InlineBlockRewrite>(block, srcBlock, before);
}
void ConversionPatternRewriterImpl::notifyMatchFailure(
@@ -1501,8 +1609,8 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
for (Operation &op : *block)
eraseOp(&op);
- // Unlink the block from its parent region. The block is kept in the block
- // action and will be actually destroyed when rewrites are applied. This
+ // Unlink the block from its parent region. The block is kept in the rewrite
+ // object and will actually be destroyed when rewrites are committed. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
block->getParent()->getBlocks().remove(block);
@@ -1700,11 +1808,11 @@ class OperationLegalizer {
RewriterState &curState);
/// Legalizes the actions registered during the execution of a pattern.
- LogicalResult legalizePatternBlockActions(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- RewriterState &state,
- RewriterState &newState);
+ LogicalResult
+ legalizePatternBlockRewrites(Operation *op,
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &impl,
+ RewriterState &state, RewriterState &newState);
LogicalResult legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState);
@@ -1986,8 +2094,8 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
// Legalize each of the actions registered during application.
RewriterState newState = impl.getCurrentState();
- if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
- newState)) ||
+ if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
+ newState)) ||
failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
failed(legalizePatternCreatedOperations(rewriter, impl, curState,
newState))) {
@@ -1998,7 +2106,7 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
return success();
}
-LogicalResult OperationLegalizer::legalizePatternBlockActions(
+LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
Operation *op, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl, RewriterState &state,
RewriterState &newState) {
@@ -2006,22 +2114,22 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
- for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
- ++i) {
- auto &action = impl.blockActions[i];
- if (action.kind == BlockActionKind::TypeConversion ||
- action.kind == BlockActionKind::Erase)
+ for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+ BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
+ if (!rewrite)
+ continue;
+ Block *block = rewrite->getBlock();
+ if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
continue;
// Only check blocks outside of the current operation.
- Operation *parentOp = action.block->getParentOp();
- if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
+ Operation *parentOp = block->getParentOp();
+ if (!parentOp || parentOp == op || block->getNumArguments() == 0)
continue;
// If the region of the block has a type converter, try to convert the block
// directly.
- if (auto *converter =
- impl.argConverter.getConverter(action.block->getParent())) {
- if (failed(impl.convertBlockSignature(action.block, converter))) {
+ if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
+ if (failed(impl.convertBlockSignature(block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
@@ -2042,9 +2150,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
// If this operation should be considered for re-legalization, try it.
if (operationsToIgnore.insert(parentOp).second &&
failed(legalize(parentOp, rewriter))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "operation '{0}'({1}) became illegal after block action",
- parentOp->getName(), parentOp));
+ LLVM_DEBUG(logFailure(impl.logger,
+ "operation '{0}'({1}) became illegal after rewrite",
+ parentOp->getName(), parentOp));
return failure();
}
}
>From ebfaca6b688394233b0d6a22f77b8b7cccaf67a8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 12 Feb 2024 09:08:21 +0000
Subject: [PATCH 2/2] [mlir][Transforms] Support `moveOpBefore`/`After` in
dialect conversion
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.
`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
mlir/include/mlir/IR/PatternMatch.h | 6 +-
.../mlir/Transforms/DialectConversion.h | 9 +--
.../Transforms/Utils/DialectConversion.cpp | 74 +++++++++++++++----
mlir/test/Transforms/test-legalizer.mlir | 14 ++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 20 ++++-
5 files changed, 95 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d2..b8aeea0d23475b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
- virtual void moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator);
+ void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
- virtual void moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator);
+ void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this block and insert it right before `existingBlock`.
void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f061d761ecefbb..b028d2b71b3762 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -721,8 +721,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
- /// and not nested regions. Updates to regions will still require notification
- /// through other more specific hooks above.
+ /// and not nested regions. Updates to regions will still require
+ /// notification through other more specific hooks above.
void startOpModification(Operation *op) override;
/// PatternRewriter hook for updating the given operation in-place.
@@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
- void moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator) override;
- void moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator) override;
-
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index edca84e5a73f04..85b67bb834de7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -760,7 +760,8 @@ class IRRewrite {
InlineBlock,
MoveBlock,
SplitBlock,
- BlockTypeConversion
+ BlockTypeConversion,
+ MoveOperation
};
virtual ~IRRewrite() = default;
@@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite {
// `ArgConverter::applyRewrites`. This should be done in the "commit" method.
void rollback() override;
};
+
+/// Base class of all rewrites that operate on an operation (see `classof`).
+class OperationRewrite : public IRRewrite {
+public:
+ /// Return the operation that this rewrite operates on.
+ Operation *getOperation() const { return op; }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() >= Kind::MoveOperation &&
+ rewrite->getKind() <= Kind::MoveOperation;
+ }
+
+protected:
+ OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : IRRewrite(kind, rewriterImpl), op(op) {}
+
+ // The operation that this rewrite operates on.
+ Operation *op;
+};
+
+/// Moving of an operation. This rewrite is immediately reflected in the IR.
+class MoveOperationRewrite : public OperationRewrite {
+public:
+ MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, Block *block, Operation *insertBeforeOp)
+ : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
+ insertBeforeOp(insertBeforeOp) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::MoveOperation;
+ }
+
+ void rollback() override {
+ // Splice the operation from its current block back to its original position.
+ Block::iterator before =
+ insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+ block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+ }
+
+private:
+ // The block in which this operation was previously contained.
+ Block *block;
+
+ // The original successor of this operation before it was moved. "nullptr" if
+ // this operation was the last operation in the block.
+ Operation *insertBeforeOp;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
- assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
- createdOps.push_back(op);
+ if (!previous.isSet()) {
+ // This is a newly created op.
+ createdOps.push_back(op);
+ return;
+ }
+ Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+ ? nullptr
+ : &*previous.getPoint();
+ appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
}
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
- Block::iterator iterator) {
- llvm_unreachable(
- "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
- Block::iterator iterator) {
- llvm_unreachable(
- "moving single ops is not supported in a dialect conversion");
-}
-
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719cede..84fcc18ab7d370 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
return
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+ // CHECK: "test.one_region_op"()
+ // CHECK: "test.hoist_me"()
+ "test.one_region_op"() ({
+ // expected-remark @below{{'test.hoist_me' is not legalizable}}
+ %0 = "test.hoist_me"() : () -> (i32)
+ "test.valid"(%0) : (i32) -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1fb..1c02232b8adbb1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};
+/// This pattern hoists "test.hoist_me" ops out of their parent op and then
+/// fails conversion. This is to test the rollback logic of `moveOpBefore`.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+ TestUndoMoveOpBefore(MLIRContext *ctx)
+ : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.moveOpBefore(op, op->getParentOp());
+ // Replace with an illegal op to ensure the conversion fails.
+ rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+ return success();
+ }
+};
+
/// A rewrite pattern that tests the undo mechanism when erasing a block.
struct TestUndoBlockErase : public ConversionPattern {
TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
- TestCreateUnregisteredOp>(&getContext());
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp>();
+ TerminatorOp, OneRegionOp>();
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
More information about the llvm-branch-commits
mailing list