[Mlir-commits] [mlir] d68d295 - [mlir][Transforms][NFC] Turn op/block arg replacements into `IRRewrite`s (#81757)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 23 00:48:17 PST 2024
Author: Matthias Springer
Date: 2024-02-23T09:48:13+01:00
New Revision: d68d29516102252f6bf6dc23fb22cef144ca1cb3
URL: https://github.com/llvm/llvm-project/commit/d68d29516102252f6bf6dc23fb22cef144ca1cb3
DIFF: https://github.com/llvm/llvm-project/commit/d68d29516102252f6bf6dc23fb22cef144ca1cb3.diff
LOG: [mlir][Transforms][NFC] Turn op/block arg replacements into `IRRewrite`s (#81757)
This commit is a refactoring of the dialect conversion. The dialect
conversion maintains a list of "IR rewrites" that can be committed (upon
success) or rolled back (upon failure).
Until now, op replacements and block argument replacements were tracked
in separate data structures inside the dialect conversion. This
commit turns them into `IRRewrite`s, so that they can be committed or
rolled back just like any other rewrite. This simplifies the internal
state of the dialect conversion.
Overview of changes:
* Add two new rewrite classes: `ReplaceBlockArgRewrite` and
`ReplaceOperationRewrite`. Remove the `OpReplacement` helper class; it
is now part of `ReplaceOperationRewrite`.
* Simplify `RewriterState`: `numReplacements` and `numArgReplacements`
are no longer needed. (They are now tracked via `numRewrites`.)
* Add `IRRewrite::cleanup`. Operations should not be erased in `commit`
because they may still be referenced in other internal state of the
dialect conversion (`mapping`). Detaching operations is fine.
* `trackedOps` are now updated during the "commit" phase instead of
after applying all rewrites.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index db41b9f19e7e8d..dec68048dc1d30 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,14 +153,12 @@ namespace {
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
- unsigned numReplacements, unsigned numArgReplacements,
unsigned numRewrites, unsigned numIgnoredOperations,
unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
- numReplacements(numReplacements),
- numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations), numErased(numErased) {}
+ numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ numErased(numErased) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -168,12 +166,6 @@ struct RewriterState {
/// The current number of unresolved materializations.
unsigned numUnresolvedMaterializations;
- /// The current number of replacements queued.
- unsigned numReplacements;
-
- /// The current number of argument replacements queued.
- unsigned numArgReplacements;
-
/// The current number of rewrites performed.
unsigned numRewrites;
@@ -184,20 +176,6 @@ struct RewriterState {
unsigned numErased;
};
-//===----------------------------------------------------------------------===//
-// OpReplacement
-
-/// This class represents one requested operation replacement via 'replaceOp' or
-/// 'eraseOp`.
-struct OpReplacement {
- OpReplacement(const TypeConverter *converter = nullptr)
- : converter(converter) {}
-
- /// An optional type converter that can be used to materialize conversions
- /// between the new and old values if necessary.
- const TypeConverter *converter;
-};
-
//===----------------------------------------------------------------------===//
// UnresolvedMaterialization
@@ -321,19 +299,27 @@ class IRRewrite {
MoveBlock,
SplitBlock,
BlockTypeConversion,
+ ReplaceBlockArg,
// Operation rewrites
MoveOperation,
- ModifyOperation
+ ModifyOperation,
+ ReplaceOperation
};
virtual ~IRRewrite() = default;
- /// Roll back the rewrite.
+ /// Roll back the rewrite. Operations may be erased during rollback.
virtual void rollback() = 0;
- /// Commit the rewrite.
+ /// Commit the rewrite. Operations may be unlinked from their blocks during
+ /// the commit phase, but they must not be erased yet. This is because
+ /// internal dialect conversion state (such as `mapping`) may still be using
+ /// them. Operations must be erased during cleanup.
virtual void commit() {}
+ /// Cleanup operations. Cleanup is called after commit.
+ virtual void cleanup() {}
+
Kind getKind() const { return kind; }
static bool classof(const IRRewrite *rewrite) { return true; }
@@ -360,7 +346,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::BlockTypeConversion;
+ rewrite->getKind() <= Kind::ReplaceBlockArg;
}
protected:
@@ -428,6 +414,8 @@ class EraseBlockRewrite : public BlockRewrite {
void commit() override {
// Erase the block.
assert(block && "expected block");
+ assert(block->empty() && "expected empty block");
+ block->dropAllDefinedValueUses();
delete block;
block = nullptr;
}
@@ -589,6 +577,27 @@ class BlockTypeConversionRewrite : public BlockRewrite {
const TypeConverter *converter;
};
+/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but the actual replacement is delayed
+/// until the rewrite is committed.
+class ReplaceBlockArgRewrite : public BlockRewrite {
+public:
+ ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block, BlockArgument arg)
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceBlockArg;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+private:
+ BlockArgument arg;
+};
+
/// An operation rewrite.
class OperationRewrite : public IRRewrite {
public:
@@ -597,7 +606,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::ModifyOperation;
+ rewrite->getKind() <= Kind::ReplaceOperation;
}
protected:
@@ -698,6 +707,39 @@ class ModifyOperationRewrite : public OperationRewrite {
SmallVector<Block *, 2> successors;
void *propertiesStorage = nullptr;
};
+
+/// Replacing an operation. Erasing an operation is treated as a special case
+/// with "null" replacements. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but values are not replaced and the
+/// original op is not erased until the rewrite is committed.
+class ReplaceOperationRewrite : public OperationRewrite {
+public:
+ ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, const TypeConverter *converter,
+ bool changedResults)
+ : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
+ converter(converter), changedResults(changedResults) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceOperation;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+ void cleanup() override;
+
+private:
+ friend struct OperationConverter;
+
+ /// An optional type converter that can be used to materialize conversions
+ /// between the new and old values if necessary.
+ const TypeConverter *converter;
+
+ /// A boolean flag that indicates whether result types have changed or not.
+ bool changedResults;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -890,6 +932,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void eraseBlock(Block *block) override {
if (erased.contains(block))
return;
+ assert(block->empty() && "expected empty block");
block->dropAllDefinedValueUses();
RewriterBase::eraseBlock(block);
}
@@ -921,12 +964,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// conversion.
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
- /// Ordered map of requested operation replacements.
- llvm::MapVector<Operation *, OpReplacement> replacements;
-
- /// Ordered vector of any requested block argument replacements.
- SmallVector<BlockArgument, 4> argReplacements;
-
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -941,11 +978,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A vector of indices into `replacements` of operations that were replaced
- /// with values with different result types than the original operation, e.g.
- /// 1->N conversion of some kind.
- SmallVector<unsigned, 4> operationsWithChangedResults;
-
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -957,6 +989,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;
+ /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+ /// this is populated with ops found to be legalizable to the target.
+ /// When mode == OpConversionMode::Partial, this is populated with ops found
+ /// *not* to be legalizable to the target.
+ DenseSet<Operation *> *trackedOps = nullptr;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1001,6 +1039,8 @@ void BlockTypeConversionRewrite::commit() {
}
}
+ assert(origBlock->empty() && "expected empty block");
+ origBlock->dropAllDefinedValueUses();
delete origBlock;
origBlock = nullptr;
}
@@ -1063,6 +1103,47 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
return success();
}
+void ReplaceBlockArgRewrite::commit() {
+ Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
+ if (!repl)
+ return;
+
+ if (isa<BlockArgument>(repl)) {
+ arg.replaceAllUsesWith(repl);
+ return;
+ }
+
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = cast<OpResult>(repl).getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+}
+
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+
+void ReplaceOperationRewrite::commit() {
+ for (OpResult result : op->getResults())
+ if (Value newValue =
+ rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+ result.replaceAllUsesWith(newValue);
+ if (rewriterImpl.trackedOps)
+ rewriterImpl.trackedOps->erase(op);
+ // Do not erase the operation yet. It may still be referenced in `mapping`.
+ op->getBlock()->getOperations().remove(op);
+}
+
+void ReplaceOperationRewrite::rollback() {
+ for (auto result : op->getResults())
+ rewriterImpl.mapping.erase(result);
+}
+
+void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+
void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
@@ -1085,51 +1166,16 @@ void ConversionPatternRewriterImpl::discardRewrites() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Apply all of the rewrites replacements requested during conversion.
- for (auto &repl : replacements) {
- for (OpResult result : repl.first->getResults())
- if (Value newValue = mapping.lookupOrNull(result, result.getType()))
- result.replaceAllUsesWith(newValue);
- }
-
- // Apply all of the requested argument replacements.
- for (BlockArgument arg : argReplacements) {
- Value repl = mapping.lookupOrNull(arg, arg.getType());
- if (!repl)
- continue;
-
- if (isa<BlockArgument>(repl)) {
- arg.replaceAllUsesWith(repl);
- continue;
- }
-
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
- Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
- Operation *user = operand.getOwner();
- return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- });
- }
+ // Commit all rewrites.
+ for (auto &rewrite : rewrites)
+ rewrite->commit();
+ for (auto &rewrite : rewrites)
+ rewrite->cleanup();
// Drop all of the unresolved materialization operations created during
// conversion.
for (auto &mat : unresolvedMaterializations)
eraseRewriter.eraseOp(mat.getOp());
-
- // In a second pass, erase all of the replaced operations in reverse. This
- // allows processing nested operations before their parent region is
- // destroyed. Because we process in reverse order, producers may be deleted
- // before their users (a pattern deleting a producer and then the consumer)
- // so we first drop all uses explicitly.
- for (auto &repl : llvm::reverse(replacements))
- eraseRewriter.eraseOp(repl.first);
-
- // Commit all rewrites.
- for (auto &rewrite : rewrites)
- rewrite->commit();
}
//===----------------------------------------------------------------------===//
@@ -1137,28 +1183,14 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
- replacements.size(), argReplacements.size(),
rewrites.size(), ignoredOps.size(),
eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any replaced arguments.
- for (BlockArgument replacedArg :
- llvm::drop_begin(argReplacements, state.numArgReplacements))
- mapping.erase(replacedArg);
- argReplacements.resize(state.numArgReplacements);
-
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Reset any replaced operations and undo any saved mappings.
- for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
- for (auto result : repl.first->getResults())
- mapping.erase(result);
- while (replacements.size() != state.numReplacements)
- replacements.pop_back();
-
// Pop all of the newly inserted materializations.
while (unresolvedMaterializations.size() !=
state.numUnresolvedMaterializations) {
@@ -1183,11 +1215,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
- // Reset operations with changed results.
- while (!operationsWithChangedResults.empty() &&
- operationsWithChangedResults.back() >= state.numReplacements)
- operationsWithChangedResults.pop_back();
-
while (eraseRewriter.erased.size() != state.numErased)
eraseRewriter.erased.pop_back();
}
@@ -1256,7 +1283,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation was replaced or its parent ignored.
- return replacements.count(op) || ignoredOps.count(op->getParentOp());
+ return ignoredOps.count(op->getParentOp()) ||
+ hasRewrite<ReplaceOperationRewrite>(rewrites, op);
}
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1396,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1430,7 +1458,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
@@ -1462,7 +1490,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!replacements.count(op) && "operation was already replaced");
+#ifndef NDEBUG
+ for (auto &rewrite : rewrites)
+ if (auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewrite.get()))
+ assert(opReplacement->getOperation() != op &&
+ "operation was already replaced");
+#endif // NDEBUG
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1477,11 +1510,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
mapping.map(result, newValue);
resultChanged |= (newValue.getType() != result.getType());
}
- if (resultChanged)
- operationsWithChangedResults.push_back(replacements.size());
- // Record the requested operation replacement.
- replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
+ appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
+ resultChanged);
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
@@ -1576,8 +1607,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- impl->notifyBlockIsBeingErased(block);
-
// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);
@@ -1586,6 +1615,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
// object and will be actually destroyed when rewrites are applied. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
+ impl->notifyBlockIsBeingErased(block);
block->getParent()->getBlocks().remove(block);
}
@@ -1615,7 +1645,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
- impl->argReplacements.push_back(from);
+ impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
@@ -2039,16 +2069,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#ifndef NDEBUG
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-
// Check that the root was either replaced or updated in place.
+ auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
auto replacedRoot = [&] {
- return llvm::any_of(
- llvm::drop_begin(impl.replacements, curState.numReplacements),
- [op](auto &it) { return it.first == op; });
+ return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
};
auto updatedRootInPlace = [&] {
- return hasRewrite<ModifyOperationRewrite>(
- llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
+ return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
assert((replacedRoot() || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
@@ -2081,7 +2108,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
if (!rewrite)
continue;
Block *block = rewrite->getBlock();
- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
+ if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
+ ReplaceBlockArgRewrite>(rewrite))
continue;
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
@@ -2476,6 +2504,7 @@ LogicalResult OperationConverter::convertOperations(
ConversionPatternRewriter rewriter(ops.front()->getContext());
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
rewriterImpl.notifyCallback = notifyCallback;
+ rewriterImpl.trackedOps = trackedOps;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
@@ -2493,13 +2522,6 @@ LogicalResult OperationConverter::convertOperations(
rewriterImpl.discardRewrites();
} else {
rewriterImpl.applyRewrites();
-
- // It is possible for a later pattern to erase an op that was originally
- // identified as illegal and added to the trackedOps, remove it now after
- // replacements have been computed.
- if (trackedOps)
- for (auto &repl : rewriterImpl.replacements)
- trackedOps->erase(repl.first);
}
return success();
}
@@ -2513,21 +2535,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
- if (rewriterImpl.operationsWithChangedResults.empty())
- return success();
-
// Process requested operation replacements.
- for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
- i != e; ++i) {
- unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
- auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
- for (OpResult result : repl.first->getResults()) {
+ for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+ auto *opReplacement =
+ dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
+ if (!opReplacement || !opReplacement->changedResults)
+ continue;
+ Operation *op = opReplacement->getOperation();
+ for (OpResult result : op->getResults()) {
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
if (!newValue) {
- if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
+ if (failed(legalizeErasedResult(op, result, rewriterImpl)))
return failure();
continue;
}
@@ -2541,15 +2562,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
inverseMapping = rewriterImpl.mapping.getInverse();
// Legalize this result.
- rewriter.setInsertionPoint(repl.first);
- if (failed(legalizeChangedResultType(repl.first, result, newValue,
- repl.second.converter, rewriter,
+ rewriter.setInsertionPoint(op);
+ if (failed(legalizeChangedResultType(op, result, newValue,
+ opReplacement->converter, rewriter,
rewriterImpl, *inverseMapping)))
return failure();
-
- // Update the end iterator for this loop in the case it was updated
- // when legalizing generated conversion operations.
- e = rewriterImpl.operationsWithChangedResults.size();
}
}
return success();
More information about the Mlir-commits
mailing list