[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn op/block arg replacements into `IRRewrite`s (PR #81757)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Feb 14 08:43:47 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81757
>From 613a6162af4ac07066dbcf87b580a085eca5a47a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 14 Feb 2024 16:27:53 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn op/block arg replacements into
`IRRewrite`s
This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).
Until now, op replacements and block argument replacements were tracked in separate data structures inside the dialect conversion. This commit turns them into `IRRewrite`s, so that they can be committed or rolled back just like any other rewrite. This simplifies the internal state of the dialect conversion.
Overview of changes:
* Add two new rewrite classes: `ReplaceBlockArgRewrite` and `ReplaceOperationRewrite`. Remove the `OpReplacement` helper class; it is now part of `ReplaceOperationRewrite`.
* Simplify `RewriterState`: `numReplacements` and `numArgReplacements` are no longer needed. (Now being kept track of by `numRewrites`.)
* Add `IRRewrite::cleanup`. Operations should not be erased in `commit` because they may still be referenced in other internal state of the dialect conversion (`mapping`). Detaching operations is fine.
---
.../Transforms/Utils/DialectConversion.cpp | 297 ++++++++++--------
1 file changed, 159 insertions(+), 138 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b2baa88879b6e9..a07c8a56822de5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,14 +153,12 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
- unsigned numReplacements, unsigned numArgReplacements,
unsigned numRewrites, unsigned numIgnoredOperations,
unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
- numReplacements(numReplacements),
- numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
+ numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -168,12 +166,6 @@ struct RewriterState {
/// The current number of unresolved materializations.
unsigned numUnresolvedMaterializations;
- /// The current number of replacements queued.
- unsigned numReplacements;
-
- /// The current number of argument replacements queued.
- unsigned numArgReplacements;
-
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -184,20 +176,6 @@ struct RewriterState {
unsigned numErased;
};
-//===----------------------------------------------------------------------===//
-// OpReplacement
-
-/// This class represents one requested operation replacement via 'replaceOp' or
-/// 'eraseOp`.
-struct OpReplacement {
- OpReplacement(const TypeConverter *converter = nullptr)
- : converter(converter) {}
-
- /// An optional type converter that can be used to materialize conversions
- /// between the new and old values if necessary.
- const TypeConverter *converter;
-};
-
//===----------------------------------------------------------------------===//
// UnresolvedMaterialization
@@ -318,8 +296,10 @@ class IRRewrite {
MoveBlock,
SplitBlock,
BlockTypeConversion,
+ ReplaceBlockArg,
MoveOperation,
- ModifyOperation
+ ModifyOperation,
+ ReplaceOperation
};
virtual ~IRRewrite() = default;
@@ -330,6 +310,12 @@ class IRRewrite {
/// Commit the rewrite.
virtual void commit() {}
+ /// Cleanup operations. Operations may be unlinked from their blocks during
+ /// the commit/rollback phase, but they must not be erased yet. This is
+ /// because internal dialect conversion state (such as `mapping`) may still
+ /// be using them. Operations must be erased during cleanup.
+ virtual void cleanup() {}
+
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op);
@@ -356,7 +342,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::BlockTypeConversion;
+ rewrite->getKind() <= Kind::ReplaceBlockArg;
}
protected:
@@ -424,6 +410,8 @@ class EraseBlockRewrite : public BlockRewrite {
void commit() override {
// Erase the block.
assert(block && "expected block");
+ assert(block->empty() && "expected empty block");
+ block->dropAllDefinedValueUses();
delete block;
block = nullptr;
}
@@ -585,6 +573,27 @@ class BlockTypeConversionRewrite : public BlockRewrite {
const TypeConverter *converter;
};
+/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but the actual replacement is delayed
+/// until the rewrite is committed.
+class ReplaceBlockArgRewrite : public BlockRewrite {
+public:
+ ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block, BlockArgument arg)
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceBlockArg;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+private:
+ BlockArgument arg;
+};
+
/// An operation rewrite.
class OperationRewrite : public IRRewrite {
public:
@@ -593,7 +602,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::ModifyOperation;
+ rewrite->getKind() <= Kind::ReplaceOperation;
}
protected:
@@ -664,6 +673,41 @@ class ModifyOperationRewrite : public OperationRewrite {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
};
+
+/// Replacing an operation. Erasing an operation is treated as a special case
+/// with "null" replacements. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but values are not replaced and the
+/// original op is not erased until the rewrite is committed.
+class ReplaceOperationRewrite : public OperationRewrite {
+public:
+ ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, const TypeConverter *converter,
+ bool changedResults)
+ : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
+ converter(converter), changedResults(changedResults) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceOperation;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+ void cleanup() override;
+
+private:
+ friend struct OperationConverter;
+
+ /// An optional type converter that can be used to materialize conversions
+ /// between the new and old values if necessary.
+ const TypeConverter *converter;
+
+ /// Whether at least one result of the original operation was replaced with a
+ /// value of a different type than the original result, e.g. by a 1->N
+ /// conversion of some kind.
+ bool changedResults;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -856,6 +900,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void eraseBlock(Block *block) override {
if (erased.contains(block))
return;
+ assert(block->empty() && "expected empty block");
block->dropAllDefinedValueUses();
RewriterBase::eraseBlock(block);
}
@@ -887,12 +932,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// conversion.
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
- /// Ordered map of requested operation replacements.
- llvm::MapVector<Operation *, OpReplacement> replacements;
-
- /// Ordered vector of any requested block argument replacements.
- SmallVector<BlockArgument, 4> argReplacements;
-
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -907,11 +946,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A vector of indices into `replacements` of operations that were replaced
- /// with values with different result types than the original operation, e.g.
- /// 1->N conversion of some kind.
- SmallVector<unsigned, 4> operationsWithChangedResults;
-
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -923,6 +957,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;
+ DenseSet<Operation *> *trackedOps = nullptr;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -969,6 +1005,8 @@ void BlockTypeConversionRewrite::commit() {
}
}
+ assert(origBlock->empty() && "expected empty block");
+ origBlock->dropAllDefinedValueUses();
delete origBlock;
origBlock = nullptr;
}
@@ -1031,6 +1069,47 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
return success();
}
+void ReplaceBlockArgRewrite::commit() {
+ Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
+ if (!repl)
+ return;
+
+ if (isa<BlockArgument>(repl)) {
+ arg.replaceAllUsesWith(repl);
+ return;
+ }
+
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = cast<OpResult>(repl).getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+}
+
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+
+void ReplaceOperationRewrite::commit() {
+ for (OpResult result : op->getResults())
+ if (Value newValue =
+ rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+ result.replaceAllUsesWith(newValue);
+ if (rewriterImpl.trackedOps)
+ rewriterImpl.trackedOps->erase(op);
+ // Do not erase the operation yet. It may still be referenced in `mapping`.
+ op->getBlock()->getOperations().remove(op);
+}
+
+void ReplaceOperationRewrite::rollback() {
+ for (auto result : op->getResults())
+ rewriterImpl.mapping.erase(result);
+}
+
+void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
for (Region ®ion : op->getRegions()) {
for (Block &block : region.getBlocks()) {
@@ -1053,51 +1132,16 @@ void ConversionPatternRewriterImpl::discardRewrites() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Apply all of the rewrites replacements requested during conversion.
- for (auto &repl : replacements) {
- for (OpResult result : repl.first->getResults())
- if (Value newValue = mapping.lookupOrNull(result, result.getType()))
- result.replaceAllUsesWith(newValue);
- }
-
- // Apply all of the requested argument replacements.
- for (BlockArgument arg : argReplacements) {
- Value repl = mapping.lookupOrNull(arg, arg.getType());
- if (!repl)
- continue;
-
- if (isa<BlockArgument>(repl)) {
- arg.replaceAllUsesWith(repl);
- continue;
- }
-
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
- Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
- Operation *user = operand.getOwner();
- return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- });
- }
+ // Commit all rewrites.
+ for (auto &rewrite : rewrites)
+ rewrite->commit();
+ for (auto &rewrite : rewrites)
+ rewrite->cleanup();
// Drop all of the unresolved materialization operations created during
// conversion.
for (auto &mat : unresolvedMaterializations)
eraseRewriter.eraseOp(mat.getOp());
-
- // In a second pass, erase all of the replaced operations in reverse. This
- // allows processing nested operations before their parent region is
- // destroyed. Because we process in reverse order, producers may be deleted
- // before their users (a pattern deleting a producer and then the consumer)
- // so we first drop all uses explicitly.
- for (auto &repl : llvm::reverse(replacements))
- eraseRewriter.eraseOp(repl.first);
-
- // Commit all rewrites.
- for (auto &rewrite : rewrites)
- rewrite->commit();
}
//===----------------------------------------------------------------------===//
@@ -1105,28 +1149,14 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
- replacements.size(), argReplacements.size(),
rewrites.size(), ignoredOps.size(),
eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any replaced arguments.
- for (BlockArgument replacedArg :
- llvm::drop_begin(argReplacements, state.numArgReplacements))
- mapping.erase(replacedArg);
- argReplacements.resize(state.numArgReplacements);
-
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Reset any replaced operations and undo any saved mappings.
- for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
- for (auto result : repl.first->getResults())
- mapping.erase(result);
- while (replacements.size() != state.numReplacements)
- replacements.pop_back();
-
// Pop all of the newly inserted materializations.
while (unresolvedMaterializations.size() !=
state.numUnresolvedMaterializations) {
@@ -1151,11 +1181,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
- // Reset operations with changed results.
- while (!operationsWithChangedResults.empty() &&
- operationsWithChangedResults.back() >= state.numReplacements)
- operationsWithChangedResults.pop_back();
-
while (eraseRewriter.erased.size() != state.numErased)
eraseRewriter.erased.pop_back();
}
@@ -1224,7 +1249,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation was replaced or its parent ignored.
- return replacements.count(op) || ignoredOps.count(op->getParentOp());
+ return ignoredOps.count(op->getParentOp()) ||
+ llvm::any_of(rewrites, [&](auto &rewrite) {
+ auto *opReplacement =
+ dyn_cast<ReplaceOperationRewrite>(rewrite.get());
+ if (!opReplacement)
+ return false;
+ return opReplacement->getOperation() == op;
+ });
}
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1374,7 +1406,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1408,7 +1440,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
@@ -1440,7 +1472,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!replacements.count(op) && "operation was already replaced");
+#ifndef NDEBUG
+ for (auto &rewrite : rewrites)
+ if (auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewrite.get()))
+ assert(opReplacement->getOperation() != op &&
+ "operation was already replaced");
+#endif // NDEBUG
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1455,11 +1492,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
mapping.map(result, newValue);
resultChanged |= (newValue.getType() != result.getType());
}
- if (resultChanged)
- operationsWithChangedResults.push_back(replacements.size());
- // Record the requested operation replacement.
- replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
+ appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
+ resultChanged);
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
@@ -1554,8 +1589,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- impl->notifyBlockIsBeingErased(block);
-
// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);
@@ -1564,6 +1597,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
// object and will be actually destroyed when rewrites are applied. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
+ impl->notifyBlockIsBeingErased(block);
block->getParent()->getBlocks().remove(block);
}
@@ -1593,7 +1627,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
- impl->argReplacements.push_back(from);
+ impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
@@ -2015,16 +2049,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#ifndef NDEBUG
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-
// Check that the root was either replaced or updated in place.
+ auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
auto replacedRoot = [&] {
- return llvm::any_of(
- llvm::drop_begin(impl.replacements, curState.numReplacements),
- [op](auto &it) { return it.first == op; });
+ return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
};
auto updatedRootInPlace = [&] {
- return hasRewrite<ModifyOperationRewrite>(
- llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
+ return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
assert((replacedRoot() || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
@@ -2057,7 +2088,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
if (!rewrite)
continue;
Block *block = rewrite->getBlock();
- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
+ if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
+ ReplaceBlockArgRewrite>(rewrite))
continue;
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
@@ -2452,6 +2484,7 @@ LogicalResult OperationConverter::convertOperations(
ConversionPatternRewriter rewriter(ops.front()->getContext());
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
rewriterImpl.notifyCallback = notifyCallback;
+ rewriterImpl.trackedOps = trackedOps;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
@@ -2469,13 +2502,6 @@ LogicalResult OperationConverter::convertOperations(
rewriterImpl.discardRewrites();
} else {
rewriterImpl.applyRewrites();
-
- // It is possible for a later pattern to erase an op that was originally
- // identified as illegal and added to the trackedOps, remove it now after
- // replacements have been computed.
- if (trackedOps)
- for (auto &repl : rewriterImpl.replacements)
- trackedOps->erase(repl.first);
}
return success();
}
@@ -2489,21 +2515,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
- if (rewriterImpl.operationsWithChangedResults.empty())
- return success();
-
// Process requested operation replacements.
- for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
- i != e; ++i) {
- unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
- auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
- for (OpResult result : repl.first->getResults()) {
+ for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+ auto *opReplacement =
+ dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
+ if (!opReplacement || !opReplacement->changedResults)
+ continue;
+ Operation *op = opReplacement->getOperation();
+ for (OpResult result : op->getResults()) {
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
if (!newValue) {
- if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
+ if (failed(legalizeErasedResult(op, result, rewriterImpl)))
return failure();
continue;
}
@@ -2517,15 +2542,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
inverseMapping = rewriterImpl.mapping.getInverse();
// Legalize this result.
- rewriter.setInsertionPoint(repl.first);
- if (failed(legalizeChangedResultType(repl.first, result, newValue,
- repl.second.converter, rewriter,
+ rewriter.setInsertionPoint(op);
+ if (failed(legalizeChangedResultType(op, result, newValue,
+ opReplacement->converter, rewriter,
rewriterImpl, *inverseMapping)))
return failure();
-
- // Update the end iterator for this loop in the case it was updated
- // when legalizing generated conversion operations.
- e = rewriterImpl.operationsWithChangedResults.size();
}
}
return success();
More information about the llvm-branch-commits
mailing list