[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Convert entry block only (PR #165180)

Matthias Springer llvmlistbot at llvm.org
Thu Oct 30 08:15:00 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/165180

From a9aa0a29525572910b43f4b2e789b481f2bf5bad Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 26 Oct 2025 23:21:56 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Convert entry block
 only

---
 .../Transforms/Utils/DialectConversion.cpp    | 112 ++++--------------
 mlir/test/Transforms/test-legalizer.mlir      |  30 -----
 2 files changed, 20 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbfd70eac..2fe06970eb568 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// A set of operations that were modified by the current pattern.
   SetVector<Operation *> patternModifiedOps;
 
-  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
-  /// by the current pattern.
-  SetVector<Block *> patternInsertedBlocks;
-
   /// A list of unresolved materializations that were created by the current
   /// pattern.
   DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
   if (!config.allowPatternRollback && config.listener)
     config.listener->notifyBlockInserted(block, previous, previousIt);
 
-  patternInsertedBlocks.insert(block);
-
   if (wasDetached) {
     // If the block was detached, it is most likely a newly created block.
     if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
   bool canApplyPattern(Operation *op, const Pattern &pattern);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
-  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
-                                      const RewriterState &curState,
-                                      const SetVector<Operation *> &newOps,
-                                      const SetVector<Operation *> &modifiedOps,
-                                      const SetVector<Block *> &insertedBlocks);
-
-  /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
-  legalizePatternBlockRewrites(Operation *op,
-                               const SetVector<Block *> &insertedBlocks,
-                               const SetVector<Operation *> &newOps);
+  legalizePatternResult(Operation *op, const Pattern &pattern,
+                        const RewriterState &curState,
+                        const SetVector<Operation *> &newOps,
+                        const SetVector<Operation *> &modifiedOps);
+
   LogicalResult
   legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
   LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
   auto cleanup = llvm::make_scope_exit([&]() {
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
-    rewriterImpl.patternInsertedBlocks.clear();
   });
 
   // Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
 static void
 reportNewIrLegalizationFatalError(const Pattern &pattern,
                                   const SetVector<Operation *> &newOps,
-                                  const SetVector<Operation *> &modifiedOps,
-                                  const SetVector<Block *> &insertedBlocks) {
+                                  const SetVector<Operation *> &modifiedOps) {
   auto newOpNames = llvm::map_range(
       newOps, [](Operation *op) { return op->getName().getStringRef(); });
   auto modifiedOpNames = llvm::map_range(
       modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
-  StringRef detachedBlockStr = "(detached block)";
-  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
-    if (block->getParentOp())
-      return block->getParentOp()->getName().getStringRef();
-    return detachedBlockStr;
-  });
-  llvm::report_fatal_error(
-      "pattern '" + pattern.getDebugName() +
-      "' produced IR that could not be legalized. " + "new ops: {" +
-      llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
-      llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
-      llvm::join(insertedBlockNames, ", ") + "}");
+  llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+                           "' produced IR that could not be legalized. " +
+                           "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+                           "modified ops: {" +
+                           llvm::join(modifiedOpNames, ", ") + "}");
 }
 
 LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
     }
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
-    rewriterImpl.patternInsertedBlocks.clear();
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
       if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
     SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
     SetVector<Operation *> modifiedOps =
         moveAndReset(rewriterImpl.patternModifiedOps);
-    SetVector<Block *> insertedBlocks =
-        moveAndReset(rewriterImpl.patternInsertedBlocks);
-    auto result = legalizePatternResult(op, pattern, curState, newOps,
-                                        modifiedOps, insertedBlocks);
+    auto result =
+        legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
       if (!rewriterImpl.config.allowPatternRollback)
-        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
-                                          insertedBlocks);
+        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
       rewriterImpl.resetState(curState, pattern.getDebugName());
     }
     if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
 LogicalResult OperationLegalizer::legalizePatternResult(
     Operation *op, const Pattern &pattern, const RewriterState &curState,
     const SetVector<Operation *> &newOps,
-    const SetVector<Operation *> &modifiedOps,
-    const SetVector<Block *> &insertedBlocks) {
+    const SetVector<Operation *> &modifiedOps) {
   [[maybe_unused]] auto &impl = rewriter.getImpl();
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
-      failed(legalizePatternRootUpdates(modifiedOps)) ||
+  if (failed(legalizePatternRootUpdates(modifiedOps)) ||
       failed(legalizePatternCreatedOperations(newOps))) {
     return failure();
   }
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   return success();
 }
 
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
-    Operation *op, const SetVector<Block *> &insertedBlocks,
-    const SetVector<Operation *> &newOps) {
-  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
-  SmallPtrSet<Operation *, 16> alreadyLegalized;
-
-  // If the pattern moved or created any blocks, make sure the types of block
-  // arguments get legalized.
-  for (Block *block : insertedBlocks) {
-    if (impl.erasedBlocks.contains(block))
-      continue;
-
-    // Only check blocks outside of the current operation.
-    Operation *parentOp = block->getParentOp();
-    if (!parentOp || parentOp == op || block->getNumArguments() == 0)
-      continue;
-
-    // If the region of the block has a type converter, try to convert the block
-    // directly.
-    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      std::optional<TypeConverter::SignatureConversion> conversion =
-          converter->convertBlockSignature(block);
-      if (!conversion) {
-        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
-                                           "block"));
-        return failure();
-      }
-      impl.applySignatureConversion(block, converter, *conversion);
-      continue;
-    }
-
-    // Otherwise, try to legalize the parent operation if it was not generated
-    // by this pattern. This is because we will attempt to legalize the parent
-    // operation, and blocks in regions created by this pattern will already be
-    // legalized later on.
-    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
-      if (failed(legalize(parentOp))) {
-        LLVM_DEBUG(logFailure(
-            impl.logger, "operation '{0}'({1}) became illegal after rewrite",
-            parentOp->getName(), parentOp));
-        return failure();
-      }
-    }
-  }
-  return success();
-}
-
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     const SetVector<Operation *> &newOps) {
   for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
   TypeConverter::SignatureConversion result(type.getNumInputs());
   SmallVector<Type, 1> newResults;
   if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
-      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
-      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
-                                         typeConverter, &result)))
+      failed(typeConverter.convertTypes(type.getResults(), newResults)))
     return failure();
+  if (!funcOp.getFunctionBody().empty())
+    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+                                      &typeConverter);
 
   // Update the function signature in-place.
   auto newType = FunctionType::get(rewriter.getContext(),
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4e93b06..ba1f962fdb68b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -146,36 +146,6 @@ func.func @no_remap_nested() {
 
 // -----
 
-// CHECK-LABEL: func @remap_moved_region_args
-func.func @remap_moved_region_args() {
-  // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
-  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
-  "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
-  }) : () -> ()
-  // expected-remark at +1 {{op 'func.return' is not legalizable}}
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @remap_cloned_region_args
-func.func @remap_cloned_region_args() {
-  // CHECK-NEXT: return
-  // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
-  // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
-  // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
-  "test.region"() ({
-    ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
-      "test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
-  }) {legalizer.should_clone} : () -> ()
-  // expected-remark at +1 {{op 'func.return' is not legalizable}}
-  return
-}
-
 // CHECK-LABEL: func @remap_drop_region
 func.func @remap_drop_region() {
   // CHECK-NEXT: return



More information about the Mlir-commits mailing list