[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Convert entry block only (PR #165180)

Matthias Springer llvmlistbot at llvm.org
Sun Oct 26 16:30:52 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/165180

When converting a function, convert only the signature of the entry block. The signatures of the remaining blocks should be converted by the branch ops that jump to them. The `FuncToLLVM` / `ControlFlowToLLVM` patterns already follow this design.

This is consistent with the existing behavior that operations in unreachable blocks are not added to the initial worklist.

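With this change, parent ops are no longer recursively legalized when a pattern inserts a block into them. Likewise, the signatures of blocks that a pattern inserts or moves into a region with a registered type converter are no longer converted automatically; this is why the `remap_moved_region_args` and `remap_cloned_region_args` tests are removed.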

>From b91825b4f661d12a8b4b78ce12e562184e5a0a90 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 26 Oct 2025 23:21:56 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Convert entry block
 only

---
 .../Transforms/Utils/DialectConversion.cpp    | 112 ++++--------------
 mlir/test/Transforms/test-legalizer.mlir      |  30 -----
 2 files changed, 20 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbfd70eac..2fe06970eb568 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// A set of operations that were modified by the current pattern.
   SetVector<Operation *> patternModifiedOps;
 
-  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
-  /// by the current pattern.
-  SetVector<Block *> patternInsertedBlocks;
-
   /// A list of unresolved materializations that were created by the current
   /// pattern.
   DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
   if (!config.allowPatternRollback && config.listener)
     config.listener->notifyBlockInserted(block, previous, previousIt);
 
-  patternInsertedBlocks.insert(block);
-
   if (wasDetached) {
     // If the block was detached, it is most likely a newly created block.
     if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
   bool canApplyPattern(Operation *op, const Pattern &pattern);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
-  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
-                                      const RewriterState &curState,
-                                      const SetVector<Operation *> &newOps,
-                                      const SetVector<Operation *> &modifiedOps,
-                                      const SetVector<Block *> &insertedBlocks);
-
-  /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
-  legalizePatternBlockRewrites(Operation *op,
-                               const SetVector<Block *> &insertedBlocks,
-                               const SetVector<Operation *> &newOps);
+  legalizePatternResult(Operation *op, const Pattern &pattern,
+                        const RewriterState &curState,
+                        const SetVector<Operation *> &newOps,
+                        const SetVector<Operation *> &modifiedOps);
+
   LogicalResult
   legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
   LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
   auto cleanup = llvm::make_scope_exit([&]() {
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
-    rewriterImpl.patternInsertedBlocks.clear();
   });
 
   // Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
 static void
 reportNewIrLegalizationFatalError(const Pattern &pattern,
                                   const SetVector<Operation *> &newOps,
-                                  const SetVector<Operation *> &modifiedOps,
-                                  const SetVector<Block *> &insertedBlocks) {
+                                  const SetVector<Operation *> &modifiedOps) {
   auto newOpNames = llvm::map_range(
       newOps, [](Operation *op) { return op->getName().getStringRef(); });
   auto modifiedOpNames = llvm::map_range(
       modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
-  StringRef detachedBlockStr = "(detached block)";
-  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
-    if (block->getParentOp())
-      return block->getParentOp()->getName().getStringRef();
-    return detachedBlockStr;
-  });
-  llvm::report_fatal_error(
-      "pattern '" + pattern.getDebugName() +
-      "' produced IR that could not be legalized. " + "new ops: {" +
-      llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
-      llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
-      llvm::join(insertedBlockNames, ", ") + "}");
+  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+                           "' produced IR that could not be legalized. " +
+                           "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+                           "modified ops: {" +
+                           llvm::join(modifiedOpNames, ", ") + "}");
 }
 
 LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
     }
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
-    rewriterImpl.patternInsertedBlocks.clear();
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
       if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
     SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
     SetVector<Operation *> modifiedOps =
         moveAndReset(rewriterImpl.patternModifiedOps);
-    SetVector<Block *> insertedBlocks =
-        moveAndReset(rewriterImpl.patternInsertedBlocks);
-    auto result = legalizePatternResult(op, pattern, curState, newOps,
-                                        modifiedOps, insertedBlocks);
+    auto result =
+        legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
       if (!rewriterImpl.config.allowPatternRollback)
-        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
-                                          insertedBlocks);
+        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
       rewriterImpl.resetState(curState, pattern.getDebugName());
     }
     if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
 LogicalResult OperationLegalizer::legalizePatternResult(
     Operation *op, const Pattern &pattern, const RewriterState &curState,
     const SetVector<Operation *> &newOps,
-    const SetVector<Operation *> &modifiedOps,
-    const SetVector<Block *> &insertedBlocks) {
+    const SetVector<Operation *> &modifiedOps) {
   [[maybe_unused]] auto &impl = rewriter.getImpl();
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
-      failed(legalizePatternRootUpdates(modifiedOps)) ||
+  if (failed(legalizePatternRootUpdates(modifiedOps)) ||
       failed(legalizePatternCreatedOperations(newOps))) {
     return failure();
   }
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   return success();
 }
 
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
-    Operation *op, const SetVector<Block *> &insertedBlocks,
-    const SetVector<Operation *> &newOps) {
-  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
-  SmallPtrSet<Operation *, 16> alreadyLegalized;
-
-  // If the pattern moved or created any blocks, make sure the types of block
-  // arguments get legalized.
-  for (Block *block : insertedBlocks) {
-    if (impl.erasedBlocks.contains(block))
-      continue;
-
-    // Only check blocks outside of the current operation.
-    Operation *parentOp = block->getParentOp();
-    if (!parentOp || parentOp == op || block->getNumArguments() == 0)
-      continue;
-
-    // If the region of the block has a type converter, try to convert the block
-    // directly.
-    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      std::optional<TypeConverter::SignatureConversion> conversion =
-          converter->convertBlockSignature(block);
-      if (!conversion) {
-        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
-                                           "block"));
-        return failure();
-      }
-      impl.applySignatureConversion(block, converter, *conversion);
-      continue;
-    }
-
-    // Otherwise, try to legalize the parent operation if it was not generated
-    // by this pattern. This is because we will attempt to legalize the parent
-    // operation, and blocks in regions created by this pattern will already be
-    // legalized later on.
-    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
-      if (failed(legalize(parentOp))) {
-        LLVM_DEBUG(logFailure(
-            impl.logger, "operation '{0}'({1}) became illegal after rewrite",
-            parentOp->getName(), parentOp));
-        return failure();
-      }
-    }
-  }
-  return success();
-}
-
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     const SetVector<Operation *> &newOps) {
   for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
   TypeConverter::SignatureConversion result(type.getNumInputs());
   SmallVector<Type, 1> newResults;
   if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
-      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
-      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
-                                         typeConverter, &result)))
+      failed(typeConverter.convertTypes(type.getResults(), newResults)))
     return failure();
+  if (!funcOp.getFunctionBody().empty())
+    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+                                      &typeConverter);
 
   // Update the function signature in-place.
   auto newType = FunctionType::get(rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4e93b06..ba1f962fdb68b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -146,36 +146,6 @@ func.func @no_remap_nested() {
 
 // -----
 
-// CHECK-LABEL: func @remap_moved_region_args
-func.func @remap_moved_region_args() {
-  // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
-  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
-  "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
-  }) : () -> ()
-  // expected-remark at +1 {{op 'func.return' is not legalizable}}
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @remap_cloned_region_args
-func.func @remap_cloned_region_args() {
-  // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
-  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
-  "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
-  }) {legalizer.should_clone} : () -> ()
-  // expected-remark at +1 {{op 'func.return' is not legalizable}}
-  return
-}
-
 // CHECK-LABEL: func @remap_drop_region
 func.func @remap_drop_region() {
   // CHECK-NEXT: return



More information about the Mlir-commits mailing list