[Mlir-commits] [mlir] [mlir][Transforms][NFC] Simplify function signatures (PR #155997)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 29 02:40:05 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Many internal functions take a `ConversionPatternRewriter &` or `ConversionPatternRewriterImpl &` as a parameter. There is only a single instance of each of these classes, so it is simpler to store a reference to it in a field than to pass it through every call. This commit is in preparation for another PR that will require access to `ConversionPatternRewriter` in additional helper functions.
---
Patch is 26.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155997.diff
1 Files Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+87-102)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, &region->front(), &converter,
+ return applySignatureConversion(&region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, &region->front(), &converter,
- *conversion);
+ return applySignatureConversion(&region->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget &target;
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
// is 'BeforePatterns'. 'Never' will skip this.
const ConversionConfig &config = rewriter.getConfig();
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
// If the operation can't be legalized via patterns, try to fold it in-place
// if the folding mode is 'AfterPatterns'.
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
llvm::join(insertedBlockNames, ", ") + "}");
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure()...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/155997
More information about the Mlir-commits
mailing list