[Mlir-commits] [mlir] [mlir][Transforms][NFC] Simplify function signatures (PR #155997)
Matthias Springer
llvmlistbot at llvm.org
Fri Aug 29 02:39:21 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/155997
Many internal functions take a `ConversionPatternRewriter &` or `ConversionPatternRewriterImpl &` as a parameter. There's only a single instance of each of these classes, so it's better to store the reference in a field. This commit is in preparation for another PR that will require access to `ConversionPatternRewriter` in additional helper functions.
>From 1e9b64b1dd3e17db92831a8ee18658ccfd8dfac9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 29 Aug 2025 09:36:42 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Simplify function signatures
---
.../Transforms/Utils/DialectConversion.cpp | 189 ++++++++----------
1 file changed, 87 insertions(+), 102 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, &region->front(), &converter,
+ return applySignatureConversion(&region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, &region->front(), &converter,
- *conversion);
+ return applySignatureConversion(&region->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget &target;
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
// is 'BeforePatterns'. 'Never' will skip this.
const ConversionConfig &config = rewriter.getConfig();
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
// If the operation can't be legalized via patterns, try to fold it in-place
// if the folding mode is 'AfterPatterns'.
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
llvm::join(insertedBlockNames, ", ") + "}");
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure();
}
- impl.applySignatureConversion(rewriter, block, converter, *conversion);
+ impl.applySignatureConversion(block, converter, *conversion);
continue;
}
@@ -2840,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// operation, and blocks in regions created by this pattern will already be
// legalized later on.
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp, rewriter))) {
+ if (failed(legalize(parentOp))) {
LLVM_DEBUG(logFailure(
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
@@ -2852,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
}
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(impl.logger,
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
"failed to legalize generated operation '{0}'({1})",
op->getName(), op));
return failure();
@@ -2866,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
}
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &modifiedOps) {
for (Operation *op : modifiedOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "failed to legalize operation updated in-place '{0}'",
- op->getName()));
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(
+ logFailure(rewriter.getImpl().logger,
+ "failed to legalize operation updated in-place '{0}'",
+ op->getName()));
return failure();
}
}
@@ -3092,21 +3073,22 @@ namespace mlir {
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
- explicit OperationConverter(const ConversionTarget &target,
+ explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : config(config), opLegalizer(target, patterns), mode(mode) {}
+ : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+ mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
- LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+ LogicalResult convert(Operation *op);
- /// Dialect conversion configuration.
- ConversionConfig config;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter rewriter;
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
@@ -3116,10 +3098,11 @@ struct OperationConverter {
};
} // namespace mlir
-LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
- Operation *op) {
+LogicalResult OperationConverter::convert(Operation *op) {
+ const ConversionConfig &config = rewriter.getConfig();
+
// Legalize the given operation.
- if (failed(opLegalizer.legalize(op, rewriter))) {
+ if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
@@ -3195,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
}
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
- assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
@@ -3214,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
for (auto *op : toConvert) {
- if (failed(convert(rewriter, op))) {
+ if (failed(convert(op))) {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
@@ -3253,13 +3234,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
castOp->removeAttr(kPureTypeConversionMarker);
// Try to legalize all unresolved materializations.
- if (config.buildMaterializations) {
- IRRewriter rewriter(rewriterImpl.context, config.listener);
+ if (rewriter.getConfig().buildMaterializations) {
+ // Use a new rewriter, so the modifications are not tracked for rollback
+ // purposes etc.
+ IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
+ rewriter.getConfig().listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(
- legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
+ if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
+ it->second)))
return failure();
}
}
@@ -4001,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
ctx->executeAction<ApplyConversionAction>(
[&] {
- OperationConverter opConverter(target, patterns, config, mode);
+ OperationConverter opConverter(ops.front()->getContext(), target,
+ patterns, config, mode);
status = opConverter.convertOperations(ops);
},
irUnits);
More information about the Mlir-commits
mailing list