[Mlir-commits] [mlir] [mlir] Dialect Conversion: Add support for post-order legalization order (PR #166292)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 3 20:06:07 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/166292
>From b4c8f6916e8f2414da3ac4896463ee93b6d3f4ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 4 Nov 2025 03:05:05 +0000
Subject: [PATCH] [mlir] Dialect Conversion: Add support for post-order
legalization order
---
.../mlir/Transforms/DialectConversion.h | 25 +++-
.../Transforms/Utils/DialectConversion.cpp | 120 +++++++++++++-----
mlir/test/Transforms/test-legalizer-full.mlir | 18 +++
.../Transforms/test-legalizer-rollback.mlir | 19 +++
mlir/test/Transforms/test-legalizer.mlir | 32 +++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 22 +++-
6 files changed, 199 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ed7e2a08ebfd9..5ac9e26e8636d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -981,6 +981,28 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
+ /// Attempt to legalize the given operation. This can be used within
+ /// conversion patterns to change the default pre-order legalization order.
+ /// Returns "success" if the operation was legalized, "failure" otherwise.
+ ///
+ /// Note: In a partial conversion, this function returns "success" even if
+ /// the operation could not be legalized, as long as it was not explicitly
+ /// marked as illegal in the conversion target.
+ LogicalResult legalize(Operation *op);
+
+ /// Attempt to legalize the given region. This can be used within
+ /// conversion patterns to change the default pre-order legalization order.
+ /// Returns "success" if the region was legalized, "failure" otherwise.
+ ///
+ /// If the current pattern runs with a type converter, the entry block
+ /// signature will be converted before legalizing the operations in the
+ /// region.
+ ///
+ /// Note: In a partial conversion, this function returns "success" even if
+ /// an operation could not be legalized, as long as it was not explicitly
+ /// marked as illegal in the conversion target.
+ LogicalResult legalize(Region *r);
+
private:
// Allow OperationConverter to construct new rewriters.
friend struct OperationConverter;
@@ -989,7 +1011,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
explicit ConversionPatternRewriter(MLIRContext *ctx,
- const ConversionConfig &config);
+ const ConversionConfig &config,
+ OperationConverter &converter);
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2fe06970eb568..f8c38fadbd229 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
return pt;
}
+namespace {
+enum OpConversionMode {
+ /// In this mode, the conversion will ignore failed conversions to allow
+ /// illegal operations to co-exist in the IR.
+ Partial,
+
+ /// In this mode, all operations must be legal for the given target for the
+ /// conversion to succeed.
+ Full,
+
+ /// In this mode, operations are analyzed for legality. No actual rewrites are
+ /// applied to the operations on success.
+ Analysis,
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
@@ -866,8 +882,9 @@ namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
- const ConversionConfig &config)
- : rewriter(rewriter), config(config),
+ const ConversionConfig &config,
+ OperationConverter &opConverter)
+ : rewriter(rewriter), config(config), opConverter(opConverter),
notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
@@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// The operation converter to use for recursive legalization.
+ OperationConverter &opConverter;
+
/// A set of erased operations. This set is utilized only if
/// `allowPatternRollback` is set to "false". Conceptually, this set is
/// similar to `replacedOps` (which is maintained when the flag is set to
@@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
//===----------------------------------------------------------------------===//
ConversionPatternRewriter::ConversionPatternRewriter(
- MLIRContext *ctx, const ConversionConfig &config)
- : PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
+ MLIRContext *ctx, const ConversionConfig &config,
+ OperationConverter &opConverter)
+ : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
+ *this, config, opConverter)) {
setListener(impl.get());
}
@@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+ // Fast path: If the region is empty, there is nothing to legalize.
+ if (r->empty())
+ return success();
+
+ // Gather a list of all operations to legalize. This is done before
+ // converting the entry block signature, so that unrealized_conversion_cast
+ // ops created by the signature conversion are not included.
+ SmallVector<Operation *> ops;
+ for (Block &b : *r)
+ for (Operation &op : b)
+ ops.push_back(&op);
+
+ // If the current pattern runs with a type converter, convert the entry block
+ // signature.
+ if (const TypeConverter *converter = impl->currentTypeConverter) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(&r->front());
+ if (!conversion)
+ return failure();
+ applySignatureConversion(&r->front(), *conversion, converter);
+ }
+
+ // Legalize all operations in the region.
+ for (Operation *op : ops)
+ if (failed(legalize(op)))
+ return failure();
+
+ return success();
+}
+
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
@@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
// OperationConverter
//===----------------------------------------------------------------------===//
-namespace {
-enum OpConversionMode {
- /// In this mode, the conversion will ignore failed conversions to allow
- /// illegal operations to co-exist in the IR.
- Partial,
-
- /// In this mode, all operations must be legal for the given target for the
- /// conversion to succeed.
- Full,
-
- /// In this mode, operations are analyzed for legality. No actual rewrites are
- /// applied to the operations on success.
- Analysis,
-};
-} // namespace
-
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
@@ -3217,16 +3253,20 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+ : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
-private:
- /// Converts an operation with the given rewriter.
- LogicalResult convert(Operation *op);
+ /// Converts a single operation. If `isRecursiveLegalization` is "true", the
+ /// conversion is a recursive legalization request, triggered from within a
+ /// pattern. In that case, do not emit errors because there will be another
+ /// attempt at legalizing the operation later (via the regular pre-order
+ /// legalization mechanism).
+ LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
+private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;
@@ -3238,32 +3278,42 @@ struct OperationConverter {
};
} // namespace mlir
-LogicalResult OperationConverter::convert(Operation *op) {
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+ return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
+}
+
+LogicalResult OperationConverter::convert(Operation *op,
+ bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();
// Legalize the given operation.
if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
- if (mode == OpConversionMode::Full)
- return op->emitError()
- << "failed to legalize operation '" << op->getName() << "'";
+ if (mode == OpConversionMode::Full) {
+ if (!isRecursiveLegalization)
+ op->emitError() << "failed to legalize operation '" << op->getName()
+ << "'";
+ return failure();
+ }
// Partial conversions allow conversions to fail iff the operation was not
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
// set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
- if (opLegalizer.isIllegal(op))
- return op->emitError()
- << "failed to legalize operation '" << op->getName()
- << "' that was explicitly marked illegal";
- if (config.unlegalizedOps)
+ if (opLegalizer.isIllegal(op)) {
+ if (!isRecursiveLegalization)
+ op->emitError() << "failed to legalize operation '" << op->getName()
+ << "' that was explicitly marked illegal";
+ return failure();
+ }
+ if (config.unlegalizedOps && !isRecursiveLegalization)
config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- if (config.legalizableOps)
+ if (config.legalizableOps && !isRecursiveLegalization)
config.legalizableOps->insert(op);
}
return success();
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 42cec68b9fbbb..8da9109a32762 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -72,3 +72,21 @@ builtin.module {
}
}
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+func.func @test_preorder_legalization() {
+ // expected-error at +1 {{failed to legalize operation 'test.post_order_legalization'}}
+ "test.post_order_legalization"() ({
+ ^bb0(%arg0: i64):
+ // Not-explicitly-legal ops are not allowed to survive.
+ "test.remaining_consumer"(%arg0) : (i64) -> ()
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
+ return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 71e11782e14b0..4bcca6b7e5228 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
"test.return"(%0) : (i32) -> ()
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_failed_preorder_legalization
+// CHECK: "test.post_order_legalization"() ({
+// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
+// CHECK: "test.return"(%[[r]]) : (i32) -> ()
+// CHECK: }) : () -> ()
+// expected-remark @+1 {{applyPartialConversion failed}}
+module {
+func.func @test_failed_preorder_legalization() {
+ // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
+ "test.post_order_legalization"() ({
+ %0 = "test.illegal_op_g"() : () -> (i32)
+ "test.return"(%0) : (i32) -> ()
+ }) : () -> ()
+ return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 7c43bb7bface0..88a71cc26ab0c 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
"test.type_consumer"(%arg0) : (f16) -> ()
"test.return"() : () -> ()
}
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.invalid
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationModified: test.post_order_legalization
+
+// CHECK-LABEL: func @test_preorder_legalization
+// CHECK: "test.post_order_legalization"() ({
+// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
+// Note: The survival of a not-explicitly-invalid operation does *not* cause
+// a conversion failure when applying a partial conversion.
+// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
+// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
+// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
+// CHECK: }) {is_legal} : () -> ()
+func.func @test_preorder_legalization() {
+ "test.post_order_legalization"() ({
+ ^bb0(%arg0: i64):
+ // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
+ "test.remaining_consumer"(%arg0) : (i64) -> ()
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
+ // expected-remark @+1 {{'func.return' is not legalizable}}
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 12edecc113495..9b64bc691588d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
}
};
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+ TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ for (Region &r : op->getRegions())
+ if (failed(rewriter.legalize(&r)))
+ return failure();
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
- TestTypeConsumerOpPattern>(&getContext(), converter);
+ TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.post_order_legalization", &getContext()),
+ [](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
More information about the Mlir-commits
mailing list