[Mlir-commits] [mlir] [mlir] Dialect Conversion: Add support for post-order legalization order (PR #166292)

Matthias Springer llvmlistbot at llvm.org
Mon Nov 3 20:06:07 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/166292

>From b4c8f6916e8f2414da3ac4896463ee93b6d3f4ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 4 Nov 2025 03:05:05 +0000
Subject: [PATCH] [mlir] Dialect Conversion: Add support for post-order
 legalization order

---
 .../mlir/Transforms/DialectConversion.h       |  25 +++-
 .../Transforms/Utils/DialectConversion.cpp    | 120 +++++++++++++-----
 mlir/test/Transforms/test-legalizer-full.mlir |  18 +++
 .../Transforms/test-legalizer-rollback.mlir   |  19 +++
 mlir/test/Transforms/test-legalizer.mlir      |  32 +++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  22 +++-
 6 files changed, 199 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ed7e2a08ebfd9..5ac9e26e8636d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -981,6 +981,28 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
+  /// Attempt to legalize the given operation. This can be used within
+  /// conversion patterns to change the default pre-order legalization order.
+  /// Returns "success" if the operation was legalized, "failure" otherwise.
+  ///
+  /// Note: In a partial conversion, this function returns "success" even if
+  /// the operation could not be legalized, as long as it was not explicitly
+  /// marked as illegal in the conversion target.
+  LogicalResult legalize(Operation *op);
+
+  /// Attempt to legalize the given region. This can be used within
+  /// conversion patterns to change the default pre-order legalization order.
+  /// Returns "success" if the region was legalized, "failure" otherwise.
+  ///
+  /// If the current pattern runs with a type converter, the entry block
+  /// signature will be converted before legalizing the operations in the
+  /// region.
+  ///
+  /// Note: In a partial conversion, this function returns "success" even if
+  /// an operation could not be legalized, as long as it was not explicitly
+  /// marked as illegal in the conversion target.
+  LogicalResult legalize(Region *r);
+
 private:
   // Allow OperationConverter to construct new rewriters.
   friend struct OperationConverter;
@@ -989,7 +1011,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// conversions. They apply some IR rewrites in a delayed fashion and could
   /// bring the IR into an inconsistent state when used standalone.
   explicit ConversionPatternRewriter(MLIRContext *ctx,
-                                     const ConversionConfig &config);
+                                     const ConversionConfig &config,
+                                     OperationConverter &converter);
 
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2fe06970eb568..f8c38fadbd229 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
   return pt;
 }
 
+namespace {
+enum OpConversionMode {
+  /// In this mode, the conversion will ignore failed conversions to allow
+  /// illegal operations to co-exist in the IR.
+  Partial,
+
+  /// In this mode, all operations must be legal for the given target for the
+  /// conversion to succeed.
+  Full,
+
+  /// In this mode, operations are analyzed for legality. No actual rewrites are
+  /// applied to the operations on success.
+  Analysis,
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
@@ -866,8 +882,9 @@ namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
-                                         const ConversionConfig &config)
-      : rewriter(rewriter), config(config),
+                                         const ConversionConfig &config,
+                                         OperationConverter &opConverter)
+      : rewriter(rewriter), config(config), opConverter(opConverter),
         notifyingRewriter(rewriter.getContext(), config.listener) {}
 
   //===--------------------------------------------------------------------===//
@@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Dialect conversion configuration.
   const ConversionConfig &config;
 
+  /// The operation converter to use for recursive legalization.
+  OperationConverter &opConverter;
+
   /// A set of erased operations. This set is utilized only if
   /// `allowPatternRollback` is set to "false". Conceptually, this set is
   /// similar to `replacedOps` (which is maintained when the flag is set to
@@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 //===----------------------------------------------------------------------===//
 
 ConversionPatternRewriter::ConversionPatternRewriter(
-    MLIRContext *ctx, const ConversionConfig &config)
-    : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
+    MLIRContext *ctx, const ConversionConfig &config,
+    OperationConverter &opConverter)
+    : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
+                                *this, config, opConverter)) {
   setListener(impl.get());
 }
 
@@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
   return success();
 }
 
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+  // Fast path: If the region is empty, there is nothing to legalize.
+  if (r->empty())
+    return success();
+
+  // Gather a list of all operations to legalize. This is done before
+  // converting the entry block signature because unrealized_conversion_cast
+  // ops should not be included.
+  SmallVector<Operation *> ops;
+  for (Block &b : *r)
+    for (Operation &op : b)
+      ops.push_back(&op);
+
+  // If the current pattern runs with a type converter, convert the entry block
+  // signature.
+  if (const TypeConverter *converter = impl->currentTypeConverter) {
+    std::optional<TypeConverter::SignatureConversion> conversion =
+        converter->convertBlockSignature(&r->front());
+    if (!conversion)
+      return failure();
+    applySignatureConversion(&r->front(), *conversion, converter);
+  }
+
+  // Legalize all operations in the region.
+  for (Operation *op : ops)
+    if (failed(legalize(op)))
+      return failure();
+
+  return success();
+}
+
 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
                                                   Block::iterator before,
                                                   ValueRange argValues) {
@@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
 // OperationConverter
 //===----------------------------------------------------------------------===//
 
-namespace {
-enum OpConversionMode {
-  /// In this mode, the conversion will ignore failed conversions to allow
-  /// illegal operations to co-exist in the IR.
-  Partial,
-
-  /// In this mode, all operations must be legal for the given target for the
-  /// conversion to succeed.
-  Full,
-
-  /// In this mode, operations are analyzed for legality. No actual rewrites are
-  /// applied to the operations on success.
-  Analysis,
-};
-} // namespace
-
 namespace mlir {
 // This class converts operations to a given conversion target via a set of
 // rewrite patterns. The conversion behaves differently depending on the
@@ -3217,16 +3253,20 @@ struct OperationConverter {
                               const FrozenRewritePatternSet &patterns,
                               const ConversionConfig &config,
                               OpConversionMode mode)
-      : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+      : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
         mode(mode) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);
 
-private:
-  /// Converts an operation with the given rewriter.
-  LogicalResult convert(Operation *op);
+  /// Converts a single operation. If `isRecursiveLegalization` is "true", the
+  /// conversion is a recursive legalization request, triggered from within a
+  /// pattern. In that case, do not emit errors because there will be another
+  /// attempt at legalizing the operation later (via the regular pre-order
+  /// legalization mechanism).
+  LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
 
+private:
   /// The rewriter to use when converting operations.
   ConversionPatternRewriter rewriter;
 
@@ -3238,32 +3278,42 @@ struct OperationConverter {
 };
 } // namespace mlir
 
-LogicalResult OperationConverter::convert(Operation *op) {
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+  return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
+}
+
+LogicalResult OperationConverter::convert(Operation *op,
+                                          bool isRecursiveLegalization) {
   const ConversionConfig &config = rewriter.getConfig();
 
   // Legalize the given operation.
   if (failed(opLegalizer.legalize(op))) {
     // Handle the case of a failed conversion for each of the different modes.
     // Full conversions expect all operations to be converted.
-    if (mode == OpConversionMode::Full)
-      return op->emitError()
-             << "failed to legalize operation '" << op->getName() << "'";
+    if (mode == OpConversionMode::Full) {
+      if (!isRecursiveLegalization)
+        op->emitError() << "failed to legalize operation '" << op->getName()
+                        << "'";
+      return failure();
+    }
     // Partial conversions allow conversions to fail iff the operation was not
     // explicitly marked as illegal. If the user provided a `unlegalizedOps`
     // set, non-legalizable ops are added to that set.
     if (mode == OpConversionMode::Partial) {
-      if (opLegalizer.isIllegal(op))
-        return op->emitError()
-               << "failed to legalize operation '" << op->getName()
-               << "' that was explicitly marked illegal";
-      if (config.unlegalizedOps)
+      if (opLegalizer.isIllegal(op)) {
+        if (!isRecursiveLegalization)
+          op->emitError() << "failed to legalize operation '" << op->getName()
+                          << "' that was explicitly marked illegal";
+        return failure();
+      }
+      if (config.unlegalizedOps && !isRecursiveLegalization)
         config.unlegalizedOps->insert(op);
     }
   } else if (mode == OpConversionMode::Analysis) {
     // Analysis conversions don't fail if any operations fail to legalize,
     // they are only interested in the operations that were successfully
     // legalized.
-    if (config.legalizableOps)
+    if (config.legalizableOps && !isRecursiveLegalization)
       config.legalizableOps->insert(op);
   }
   return success();
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 42cec68b9fbbb..8da9109a32762 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -72,3 +72,21 @@ builtin.module {
   }
 
 }
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// expected-remark at +1 {{applyFullConversion failed}}
+builtin.module {
+func.func @test_preorder_legalization() {
+  // expected-error at +1 {{failed to legalize operation 'test.post_order_legalization'}}
+  "test.post_order_legalization"() ({
+  ^bb0(%arg0: i64):
+    // Not-explicitly-legal ops are not allowed to survive.
+    "test.remaining_consumer"(%arg0) : (i64) -> ()
+    "test.invalid"(%arg0) : (i64) -> ()
+  }) : () -> ()
+  return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 71e11782e14b0..4bcca6b7e5228 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
   "test.return"(%0) : (i32) -> ()
 }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_failed_preorder_legalization
+//       CHECK:   "test.post_order_legalization"() ({
+//       CHECK:     %[[r:.*]] = "test.illegal_op_g"() : () -> i32
+//       CHECK:     "test.return"(%[[r]]) : (i32) -> ()
+//       CHECK:   }) : () -> ()
+// expected-remark @+1 {{applyPartialConversion failed}}
+module {
+func.func @test_failed_preorder_legalization() {
+  // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
+  "test.post_order_legalization"() ({
+    %0 = "test.illegal_op_g"() : () -> (i32)
+    "test.return"(%0) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 7c43bb7bface0..88a71cc26ab0c 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
   "test.type_consumer"(%arg0) : (f16) -> ()
   "test.return"() : () -> ()
 }
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.invalid
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationModified: test.post_order_legalization
+
+// CHECK-LABEL: func @test_preorder_legalization
+//       CHECK:   "test.post_order_legalization"() ({
+//       CHECK:   ^{{.*}}(%[[arg0:.*]]: f64):
+// Note: The survival of a not-explicitly-invalid operation does *not* cause
+// a conversion failure when applying a partial conversion.
+//       CHECK:     %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
+//       CHECK:     "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
+//       CHECK:     "test.valid"(%[[arg0]]) : (f64) -> ()
+//       CHECK:   }) {is_legal} : () -> ()
+func.func @test_preorder_legalization() {
+  "test.post_order_legalization"() ({
+  ^bb0(%arg0: i64):
+    // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
+    "test.remaining_consumer"(%arg0) : (i64) -> ()
+    "test.invalid"(%arg0) : (i64) -> ()
+  }) : () -> ()
+  // expected-remark @+1 {{'func.return' is not legalizable}}
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 12edecc113495..9b64bc691588d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
   }
 };
 
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+  TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    for (Region &r : op->getRegions())
+      if (failed(rewriter.legalize(&r)))
+        return failure();
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+    return success();
+  }
+};
+
 /// Test unambiguous overload resolution of replaceOpWithMultiple. This
 /// function is just to trigger compiler errors. It is never executed.
 [[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
                  TestValueReplace, TestReplaceWithValidConsumer,
-                 TestTypeConsumerOpPattern>(&getContext(), converter);
+                 TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+        &getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp(
         OperationName("test.value_replace", &getContext()),
         [](Operation *op) { return op->hasAttr("is_legal"); });
+    target.addDynamicallyLegalOp(
+        OperationName("test.post_order_legalization", &getContext()),
+        [](Operation *op) { return op->hasAttr("is_legal"); });
 
     // TestCreateUnregisteredOp creates `arith.constant` operation,
     // which was not added to target intentionally to test



More information about the Mlir-commits mailing list