[Mlir-commits] [mlir] [mlir][Transforms][NFC] Turn in-place op modification into `IRRewrite` (PR #81245)
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 21 07:19:19 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81245
From 8a8a79da31d7bcef8209e34ddbbef9a6c0343852 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:10:21 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn in-place op modifications into
`IRRewrite`s
This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
.../mlir/Transforms/DialectConversion.h | 4 +-
.../Transforms/Utils/DialectConversion.cpp | 146 +++++++++---------
2 files changed, 74 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 15fa39bde104b9..0d7722aa07ee38 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
- /// and not nested regions. Updates to regions will still require
- /// notification through other more specific hooks above.
+ /// and not nested regions. Updates to regions will still require notification
+ /// through other more specific hooks above.
void startOpModification(Operation *op) override;
/// PatternRewriter hook for updating the given operation in-place.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 84597fb7986b07..9e0ca3fc9b3491 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,14 +154,12 @@ namespace {
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewrites, unsigned numIgnoredOperations,
- unsigned numRootUpdates)
+ unsigned numRewrites, unsigned numIgnoredOperations)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
numReplacements(numReplacements),
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations),
- numRootUpdates(numRootUpdates) {}
+ numIgnoredOperations(numIgnoredOperations) {}
/// The current number of created operations.
unsigned numCreatedOps;
@@ -180,44 +178,6 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
-
- /// The current number of operations that were updated in place.
- unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
- OperationTransactionState() = default;
- OperationTransactionState(Operation *op)
- : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
- operands(op->operand_begin(), op->operand_end()),
- successors(op->successor_begin(), op->successor_end()) {}
-
- /// Discard the transaction state and reset the state of the original
- /// operation.
- void resetOperation() const {
- op->setLoc(loc);
- op->setAttrs(attrs);
- op->setOperands(operands);
- for (const auto &it : llvm::enumerate(successors))
- op->setSuccessor(it.value(), it.index());
- }
-
- /// Return the original operation of this state.
- Operation *getOperation() const { return op; }
-
-private:
- Operation *op;
- LocationAttr loc;
- DictionaryAttr attrs;
- SmallVector<Value, 8> operands;
- SmallVector<Block *, 2> successors;
};
//===----------------------------------------------------------------------===//
@@ -754,14 +714,19 @@ namespace {
class IRRewrite {
public:
/// The kind of the rewrite. Rewrites can be undone if the conversion fails.
+ /// Enum values are ordered, so that they can be used in `classof`: first all
+ /// block rewrites, then all operation rewrites.
enum class Kind {
+ // Block rewrites
CreateBlock,
EraseBlock,
InlineBlock,
MoveBlock,
SplitBlock,
BlockTypeConversion,
- MoveOperation
+ // Operation rewrites
+ MoveOperation,
+ ModifyOperation
};
virtual ~IRRewrite() = default;
@@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::MoveOperation;
+ rewrite->getKind() <= Kind::ModifyOperation;
}
protected:
@@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite {
// this operation was the only operation in the region.
Operation *insertBeforeOp;
};
+
+/// In-place modification of an op. This rewrite is immediately reflected in
+/// the IR. The previous state of the operation is stored in this object.
+class ModifyOperationRewrite : public OperationRewrite {
+public:
+ ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
+ loc(op->getLoc()), attrs(op->getAttrDictionary()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ModifyOperation;
+ }
+
+ void rollback() override {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (const auto &it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+private:
+ LocationAttr loc;
+ DictionaryAttr attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
+};
} // namespace
+/// Return "true" if there is an operation rewrite that matches the specified
+/// rewrite type and operation among the given rewrites.
+template <typename RewriteTy, typename R>
+static bool hasRewrite(R &&rewrites, Operation *op) {
+ return any_of(std::move(rewrites), [&](auto &rewrite) {
+ auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+ return rewriteTy && rewriteTy->getOperation() == op;
+ });
+}
+
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A transaction state for each of operations that were updated in-place.
- SmallVector<OperationTransactionState, 4> rootUpdates;
-
/// A vector of indices into `replacements` of operations that were replaced
/// with values with different result types than the original operation, e.g.
/// 1->N conversion of some kind.
@@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) {
}
void ConversionPatternRewriterImpl::discardRewrites() {
- // Reset any operations that were updated in place.
- for (auto &state : rootUpdates)
- state.resetOperation();
-
undoRewrites();
// Remove any newly created ops.
@@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
replacements.size(), argReplacements.size(),
- rewrites.size(), ignoredOps.size(), rootUpdates.size());
+ rewrites.size(), ignoredOps.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any operations that were updated in place.
- for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
- rootUpdates[i].resetOperation();
- rootUpdates.resize(state.numRootUpdates);
-
// Reset any replaced arguments.
for (BlockArgument replacedArg :
llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
#endif
- impl->rootUpdates.emplace_back(op);
+ impl->appendRewrite<ModifyOperationRewrite>(op);
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
"operation did not have a pending in-place update");
#endif
// Erase the last update for this operation.
- auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
- auto &rootUpdates = impl->rootUpdates;
- auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
- assert(it != rootUpdates.rend() && "no root update started on op");
- (*it).resetOperation();
- int updateIdx = std::prev(rootUpdates.rend()) - it;
- rootUpdates.erase(rootUpdates.begin() + updateIdx);
+ auto it = llvm::find_if(
+ llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
+ auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
+ return modifyRewrite && modifyRewrite->getOperation() == op;
+ });
+ assert(it != impl->rewrites.rend() && "no root update started on op");
+ (*it)->rollback();
+ int updateIdx = std::prev(impl->rewrites.rend()) - it;
+ impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
}
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that cleans up the rewriter state after a pattern failed to match.
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
+ assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.notifyCallback) {
@@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that performs additional legalization when a pattern is
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
+ assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
auto result = legalizePatternResult(op, pattern, rewriter, curState);
appliedPatterns.erase(&pattern);
if (failed(result))
@@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#ifndef NDEBUG
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-#endif
// Check that the root was either replaced or updated in place.
auto replacedRoot = [&] {
@@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
[op](auto &it) { return it.first == op; });
};
auto updatedRootInPlace = [&] {
- return llvm::any_of(
- llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
- [op](auto &state) { return state.getOperation() == op; });
+ return hasRewrite<ModifyOperationRewrite>(
+ llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
};
- (void)replacedRoot;
- (void)updatedRootInPlace;
assert((replacedRoot() || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
+#endif // NDEBUG
// Legalize each of the actions registered during application.
RewriterState newState = impl.getCurrentState();
@@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
- for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
- Operation *op = impl.rootUpdates[i].getOperation();
+ for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+ auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
+ if (!rewrite)
+ continue;
+ Operation *op = rewrite->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "failed to legalize operation updated in-place '{0}'",
@@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
// Full Conversion
LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
+mlir::applyFullConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
const FrozenRewritePatternSet &patterns) {
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
return opConverter.convertOperations(ops);
More information about the Mlir-commits mailing list