[Mlir-commits] [mlir] [mlir][Transforms][NFC] Store per-pattern IR modifications in separate state (PR #145319)
Matthias Springer
llvmlistbot at llvm.org
Tue Jun 24 00:15:55 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/145319
From 46bf825dddc524101c5eea3bd03f2b898ed713e8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 23 Jun 2025 12:36:02 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Store per-pattern IR modifications in
separate state
---
.../Transforms/Utils/DialectConversion.cpp | 139 ++++++++++--------
1 file changed, 75 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7cfe7250d02c3..955a106c21941 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;
+ /// A set of operations that were created by the current pattern.
+ SetVector<Operation *> patternNewOps;
+
+ /// A set of operations that were modified by the current pattern.
+ SetVector<Operation *> patternModifiedOps;
+
+ /// A set of blocks that were inserted (newly-created blocks or moved blocks)
+ /// by the current pattern.
+ SetVector<Block *> patternInsertedBlocks;
+
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
/// to the corresponding rewrite objects.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
@@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
if (!previous.isSet()) {
// This is a newly created op.
appendRewrite<CreateOperationRewrite>(op);
+ patternNewOps.insert(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
}
});
+ patternInsertedBlocks.insert(block);
+
if (!previous) {
// This is a newly created block.
appendRewrite<CreateBlockRewrite>(block);
@@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
PatternRewriter::finalizeOpModification(op);
+ impl->patternModifiedOps.insert(op);
+
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
@@ -1964,21 +1979,25 @@ class OperationLegalizer {
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
- RewriterState &curState);
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps,
+ const SetVector<Block *> &insertedBlocks);
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &impl,
- RewriterState &state, RewriterState &newState);
- LogicalResult legalizePatternCreatedOperations(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
- RewriterState &state, RewriterState &newState);
- LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- RewriterState &state,
- RewriterState &newState);
+ const SetVector<Block *> &insertedBlocks,
+ const SetVector<Operation *> &newOps);
+ LogicalResult
+ legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &impl,
+ const SetVector<Operation *> &newOps);
+ LogicalResult
+ legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &impl,
+ const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
return failure();
}
+/// Moves `obj` into the returned value, then re-assigns `obj` a freshly
+/// default-constructed `T` so the original is left empty and reusable.
+template <typename T>
+static T moveAndReset(T &obj) {
+ T result = std::move(obj);
+ obj = T(); // Moved-from state is only "valid but unspecified"; force empty.
+ return result;
+}
+
LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
@@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ rewriterImpl.patternNewOps.clear();
+ rewriterImpl.patternModifiedOps.clear();
+ rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
@@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
- auto result = legalizePatternResult(op, pattern, rewriter, curState);
+ SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
+ SetVector<Operation *> modifiedOps =
+ moveAndReset(rewriterImpl.patternModifiedOps);
+ SetVector<Block *> insertedBlocks =
+ moveAndReset(rewriterImpl.patternInsertedBlocks);
+ auto result = legalizePatternResult(op, pattern, rewriter, newOps,
+ modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
@@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
return true;
}
-LogicalResult
-OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
- RewriterState &curState) {
+LogicalResult OperationLegalizer::legalizePatternResult(
+ Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps,
+ const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
@@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- RewriterState newState = impl.getCurrentState();
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
- newState)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, curState,
- newState))) {
+ if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
+ newOps)) ||
+ failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
return failure();
}
@@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl, RewriterState &state,
- RewriterState &newState) {
- SmallPtrSet<Operation *, 16> operationsToIgnore;
+ ConversionPatternRewriterImpl &impl,
+ const SetVector<Block *> &insertedBlocks,
+ const SetVector<Operation *> &newOps) {
+ SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
- for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
- BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
- if (!rewrite)
- continue;
- Block *block = rewrite->getBlock();
- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
- ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
- continue;
+ for (Block *block : insertedBlocks) {
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
continue;
}
- // Otherwise, check that this operation isn't one generated by this pattern.
- // This is because we will attempt to legalize the parent operation, and
- // blocks in regions created by this pattern will already be legalized later
- // on. If we haven't built the set yet, build it now.
- if (operationsToIgnore.empty()) {
- for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
- ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
- if (!createOp)
- continue;
- operationsToIgnore.insert(createOp->getOperation());
+ // Otherwise, try to legalize the parent operation if it was not generated
+ // by this pattern. This is because we will attempt to legalize the parent
+ // operation, and blocks in regions created by this pattern will already be
+ // legalized later on.
+ if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
+ if (failed(legalize(parentOp, rewriter))) {
+ LLVM_DEBUG(logFailure(
+ impl.logger, "operation '{0}'({1}) became illegal after rewrite",
+ parentOp->getName(), parentOp));
+ return failure();
}
}
-
- // If this operation should be considered for re-legalization, try it.
- if (operationsToIgnore.insert(parentOp).second &&
- failed(legalize(parentOp, rewriter))) {
- LLVM_DEBUG(logFailure(impl.logger,
- "operation '{0}'({1}) became illegal after rewrite",
- parentOp->getName(), parentOp));
- return failure();
- }
}
return success();
}
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
- RewriterState &state, RewriterState &newState) {
- for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
- auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
- if (!createOp)
- continue;
- Operation *op = createOp->getOperation();
+ const SetVector<Operation *> &newOps) {
+ for (Operation *op : newOps) {
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"failed to legalize generated operation '{0}'({1})",
@@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
- RewriterState &state, RewriterState &newState) {
- for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
- auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
- if (!rewrite)
- continue;
- Operation *op = rewrite->getOperation();
+ const SetVector<Operation *> &modifiedOps) {
+ for (Operation *op : modifiedOps) {
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(
impl.logger, "failed to legalize operation updated in-place '{0}'",
More information about the Mlir-commits
mailing list