[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn in-place op modifications into `RewriteAction`s (PR #81245)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 9 03:38:38 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/81245
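This commit simplifies the internal state of the dialect conversion: in-place op modifications are now recorded as `ModifyOperationAction`s in the list of rewrite actions. A separate field (`rootUpdates`) for the previous state of in-place op modifications is therefore no longer needed.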
>From cdbd92786887ba8f801cb0c4299f708f0a410465 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 9 Feb 2024 11:36:59 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn in-place op modifications into
`RewriteAction`s
This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.
---
.../Transforms/Utils/DialectConversion.cpp | 128 ++++++++----------
1 file changed, 58 insertions(+), 70 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ffdb069f6e9b81..d0114a148cd374 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,15 +154,13 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewriteActions, unsigned numIgnoredOperations,
- unsigned numRootUpdates)
+ unsigned numRewriteActions, unsigned numIgnoredOperations)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements),
numRewriteActions(numRewriteActions),
- numIgnoredOperations(numIgnoredOperations),
- numRootUpdates(numRootUpdates) {}
+ numIgnoredOperations(numIgnoredOperations) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -181,44 +179,6 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
-
- /// The current number of operations that were updated in place.
- unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
- OperationTransactionState() = default;
- OperationTransactionState(Operation *op)
- : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
- operands(op->operand_begin(), op->operand_end()),
- successors(op->successor_begin(), op->successor_end()) {}
-
- /// Discard the transaction state and reset the state of the original
- /// operation.
- void resetOperation() const {
- op->setLoc(loc);
- op->setAttrs(attrs);
- op->setOperands(operands);
- for (const auto &it : llvm::enumerate(successors))
- op->setSuccessor(it.value(), it.index());
- }
-
- /// Return the original operation of this state.
- Operation *getOperation() const { return op; }
-
-private:
- Operation *op;
- LocationAttr loc;
- DictionaryAttr attrs;
- SmallVector<Value, 8> operands;
- SmallVector<Block *, 2> successors;
};
//===----------------------------------------------------------------------===//
@@ -758,7 +718,8 @@ class RewriteAction {
MoveBlock,
SplitBlock,
BlockTypeConversion,
- MoveOperation
+ MoveOperation,
+ ModifyOperation
};
virtual ~RewriteAction() = default;
@@ -980,7 +941,7 @@ class OperationAction : public RewriteAction {
static bool classof(const RewriteAction *action) {
return action->getKind() >= Kind::MoveOperation &&
- action->getKind() <= Kind::MoveOperation;
+ action->getKind() <= Kind::ModifyOperation;
}
protected:
@@ -1019,6 +980,34 @@ class MoveOperationAction : public OperationAction {
// this operation was the only operation in the region.
Operation *insertBeforeOp;
};
+
+/// Rewrite action that represents the in-place modification of an operation.
+/// The previous state of the operation is stored in this action.
+class ModifyOperationAction : public OperationAction {
+public:
+ ModifyOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : OperationAction(Kind::ModifyOperation, rewriterImpl, op),
+ loc(op->getLoc()), attrs(op->getAttrDictionary()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void rollback() override {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (const auto &it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+private:
+ LocationAttr loc;
+ DictionaryAttr attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1172,9 +1161,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A transaction state for each of operations that were updated in-place.
- SmallVector<OperationTransactionState, 4> rootUpdates;
-
/// A vector of indices into `replacements` of operations that were replaced
/// with values with different result types than the original operation, e.g.
/// 1->N conversion of some kind.
@@ -1226,10 +1212,6 @@ static void detachNestedAndErase(Operation *op) {
}
void ConversionPatternRewriterImpl::discardRewrites() {
- // Reset any operations that were updated in place.
- for (auto &state : rootUpdates)
- state.resetOperation();
-
undoRewriteActions();
// Remove any newly created ops.
@@ -1304,16 +1286,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
replacements.size(), argReplacements.size(),
- rewriteActions.size(), ignoredOps.size(),
- rootUpdates.size());
+ rewriteActions.size(), ignoredOps.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any operations that were updated in place.
- for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
- rootUpdates[i].resetOperation();
- rootUpdates.resize(state.numRootUpdates);
-
// Reset any replaced arguments.
for (BlockArgument replacedArg :
llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1740,7 +1716,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
- impl->rootUpdates.emplace_back(op);
+ impl->appendRewriteAction<ModifyOperationAction>(op);
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1759,13 +1735,17 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
"operation did not have a pending in-place update");
#endif
// Erase the last update for this operation.
- auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
- auto &rootUpdates = impl->rootUpdates;
- auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
- assert(it != rootUpdates.rend() && "no root update started on op");
- (*it).resetOperation();
- int updateIdx = std::prev(rootUpdates.rend()) - it;
- rootUpdates.erase(rootUpdates.begin() + updateIdx);
+ auto it =
+ llvm::find_if(llvm::reverse(impl->rewriteActions),
+ [&](std::unique_ptr<RewriteAction> &action) {
+ auto *modifyAction =
+ dyn_cast<ModifyOperationAction>(action.get());
+ return modifyAction && modifyAction->getOperation() == op;
+ });
+ assert(it != impl->rewriteActions.rend() && "no root update started on op");
+ (*it)->rollback();
+ int updateIdx = std::prev(impl->rewriteActions.rend()) - it;
+ impl->rewriteActions.erase(impl->rewriteActions.begin() + updateIdx);
}
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2118,8 +2098,11 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
};
auto updatedRootInPlace = [&] {
return llvm::any_of(
- llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
- [op](auto &state) { return state.getOperation() == op; });
+ llvm::drop_begin(impl.rewriteActions, curState.numRewriteActions),
+ [op](auto &action) {
+ auto *modifyAction = dyn_cast<ModifyOperationAction>(action.get());
+ return modifyAction && modifyAction->getOperation() == op;
+ });
};
(void)replacedRoot;
(void)updatedRootInPlace;
@@ -2213,8 +2196,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
- for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
- Operation *op = impl.rootUpdates[i].getOperation();
+ for (int i = state.numRewriteActions, e = newState.numRewriteActions; i != e;
+ ++i) {
+ auto *action =
+ dyn_cast<ModifyOperationAction>(impl.rewriteActions[i].get());
+ if (!action)
+ continue;
+ Operation *op = action->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "failed to legalize operation updated in-place '{0}'",
More information about the llvm-branch-commits
mailing list