[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Convert entry block only (PR #165180)
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 30 08:15:00 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/165180
From a9aa0a29525572910b43f4b2e789b481f2bf5bad Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 26 Oct 2025 23:21:56 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Convert entry block only
---
.../Transforms/Utils/DialectConversion.cpp | 112 ++++--------------
mlir/test/Transforms/test-legalizer.mlir | 30 -----
2 files changed, 20 insertions(+), 122 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbfd70eac..2fe06970eb568 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// A set of operations that were modified by the current pattern.
SetVector<Operation *> patternModifiedOps;
- /// A set of blocks that were inserted (newly-created blocks or moved blocks)
- /// by the current pattern.
- SetVector<Block *> patternInsertedBlocks;
-
/// A list of unresolved materializations that were created by the current
/// pattern.
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
if (!config.allowPatternRollback && config.listener)
config.listener->notifyBlockInserted(block, previous, previousIt);
- patternInsertedBlocks.insert(block);
-
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
- LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- const RewriterState &curState,
- const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks);
-
- /// Legalizes the actions registered during the execution of a pattern.
LogicalResult
- legalizePatternBlockRewrites(Operation *op,
- const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps);
+ legalizePatternResult(Operation *op, const Pattern &pattern,
+ const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps);
+
LogicalResult
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto cleanup = llvm::make_scope_exit([&]() {
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
});
// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
static void
reportNewIrLegalizationFatalError(const Pattern &pattern,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
auto newOpNames = llvm::map_range(
newOps, [](Operation *op) { return op->getName().getStringRef(); });
auto modifiedOpNames = llvm::map_range(
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
- StringRef detachedBlockStr = "(detached block)";
- auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
- if (block->getParentOp())
- return block->getParentOp()->getName().getStringRef();
- return detachedBlockStr;
- });
- llvm::report_fatal_error(
- "pattern '" + pattern.getDebugName() +
- "' produced IR that could not be legalized. " + "new ops: {" +
- llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
- llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
- llvm::join(insertedBlockNames, ", ") + "}");
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' produced IR that could not be legalized. " +
+ "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+ "modified ops: {" +
+ llvm::join(modifiedOpNames, ", ") + "}");
}
LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
}
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
- SetVector<Block *> insertedBlocks =
- moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, curState, newOps,
- modifiedOps, insertedBlocks);
+ auto result =
+ legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
- reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
- insertedBlocks);
+ reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, const RewriterState &curState,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
[[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
- failed(legalizePatternRootUpdates(modifiedOps)) ||
+ if (failed(legalizePatternRootUpdates(modifiedOps)) ||
failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return success();
}
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps) {
- ConversionPatternRewriterImpl &impl = rewriter.getImpl();
- SmallPtrSet<Operation *, 16> alreadyLegalized;
-
- // If the pattern moved or created any blocks, make sure the types of block
- // arguments get legalized.
- for (Block *block : insertedBlocks) {
- if (impl.erasedBlocks.contains(block))
- continue;
-
- // Only check blocks outside of the current operation.
- Operation *parentOp = block->getParentOp();
- if (!parentOp || parentOp == op || block->getNumArguments() == 0)
- continue;
-
- // If the region of the block has a type converter, try to convert the block
- // directly.
- if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- std::optional<TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion) {
- LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
- "block"));
- return failure();
- }
- impl.applySignatureConversion(block, converter, *conversion);
- continue;
- }
-
- // Otherwise, try to legalize the parent operation if it was not generated
- // by this pattern. This is because we will attempt to legalize the parent
- // operation, and blocks in regions created by this pattern will already be
- // legalized later on.
- if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "operation '{0}'({1}) became illegal after rewrite",
- parentOp->getName(), parentOp));
- return failure();
- }
- }
- }
- return success();
-}
-
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
TypeConverter::SignatureConversion result(type.getNumInputs());
SmallVector<Type, 1> newResults;
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
- failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
- failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
- typeConverter, &result)))
+ failed(typeConverter.convertTypes(type.getResults(), newResults)))
return failure();
+ if (!funcOp.getFunctionBody().empty())
+ rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+ &typeConverter);
// Update the function signature in-place.
auto newType = FunctionType::get(rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4e93b06..ba1f962fdb68b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -146,36 +146,6 @@ func.func @no_remap_nested() {
// -----
-// CHECK-LABEL: func @remap_moved_region_args
-func.func @remap_moved_region_args() {
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
- // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
- "test.region"() ({
- ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
- }) : () -> ()
- // expected-remark at +1 {{op 'func.return' is not legalizable}}
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @remap_cloned_region_args
-func.func @remap_cloned_region_args() {
- // CHECK-NEXT: return
- // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
- // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
- // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
- "test.region"() ({
- ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
- "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
- }) {legalizer.should_clone} : () -> ()
- // expected-remark at +1 {{op 'func.return' is not legalizable}}
- return
-}
-
// CHECK-LABEL: func @remap_drop_region
func.func @remap_drop_region() {
// CHECK-NEXT: return
More information about the Mlir-commits
mailing list