[llvm-branch-commits] [mlir] [mlir][Transforms] Add 1:N support to `replaceUsesOfBlockArgument` (PR #145171)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jun 21 07:35:35 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds 1:N support to `ConversionPatternRewriter::replaceUsesOfBlockArgument`. This was one of the few remaining dialect conversion APIs that did not yet support 1:N conversions.

This commit also reuses `replaceUsesOfBlockArgument` in the implementation of `applySignatureConversion`. This is in preparation for the One-Shot Dialect Conversion refactoring. The goal is to bring the `applySignatureConversion` implementation into a state where it works both with and without rollbacks. To that end, `applySignatureConversion` should not directly access the `mapping`.

Depends on #<!-- -->145155.


---
Full diff: https://github.com/llvm/llvm-project/pull/145171.diff


5 Files Affected:

- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+3-2) 
- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+25-15) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+24-7) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+29-22) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5a5f116073a9a..81858812d2623 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with value `to`.
-  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
+  /// Replace all the uses of the block argument `from` with `to`. This
+  /// function supports both 1:1 and 1:N replacements.
+  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 538016927256b..9e8e746507557 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
     Type resTy = typeConverter.convertType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
-    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
     rewriter.replaceUsesOfBlockArgument(arg, valueArg);
   }
 }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 774d58973eb91..9cb6f2ba1eaae 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
+  /// Replace the given block argument with the given values. The specified
+  /// converter is used to build materializations (if necessary).
+  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
+                                  const TypeConverter *converter);
+
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
 
@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
-          MaterializationKind::Source,
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
-          /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      Value mat =
+          buildUnresolvedMaterialization(
+              MaterializationKind::Source,
+              OpBuilder::InsertPoint(newBlock, newBlock->begin()),
+              origArg.getLoc(),
+              /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+              /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+              .front();
+      replaceUsesOfBlockArgument(origArg, mat, converter);
       continue;
     }
 
@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, inputMap->replacementValues);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
+                                 converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
-    mapping.map(origArg, std::move(replArgVals));
-    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+    replaceUsesOfBlockArgument(origArg, replArgs, converter);
   }
 
   appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
+void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
+    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
+  mapping.map(from, to);
+}
+
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   assert(!wasOpReplaced(block->getParentOp()) &&
          "attempting to erase a block within a replaced/erased op");
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           Value to) {
+                                                           ValueRange to) {
   LLVM_DEBUG({
     impl->logger.startLine() << "** Replace Argument : '" << from << "'";
     if (Operation *parentOp = from.getOwner()->getParentOp()) {
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
       impl->logger.getOStream() << " (unlinked block)\n";
     }
   });
-  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
-                                              impl->currentTypeConverter);
-  impl->mapping.map(from, to);
+  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 204c8c1456826..79518b04e7158 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
 // -----
 
 // CHECK-LABEL: @undo_block_arg_replace
+// expected-remark@+1{{applyPartialConversion failed}}
+module {
 func.func @undo_block_arg_replace() {
-  // expected-remark@+1 {{op 'test.undo_block_arg_replace' is not legalizable}}
-  "test.undo_block_arg_replace"() ({
-  ^bb0(%arg0: i32):
-    // CHECK: ^bb0(%[[ARG:.*]]: i32):
-    // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+  // expected-error@+1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
+  "test.block_arg_replace"() ({
+  ^bb0(%arg0: i32, %arg1: i16):
+    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
 
     "test.return"(%arg0) : (i32) -> ()
-  }) : () -> ()
-  // expected-remark@+1 {{op 'func.return' is not legalizable}}
+  }) {trigger_rollback} : () -> ()
   return
 }
+}
+
+// -----
+
+// CHECK-LABEL: @replace_block_arg_1_to_n
+func.func @replace_block_arg_1_to_n() {
+  // CHECK: "test.block_arg_replace"
+  "test.block_arg_replace"() ({
+  ^bb0(%arg0: i32, %arg1: i16):
+    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+    "test.return"(%arg0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
 
 // -----
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..588e529665dd1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the undo mechanism when replacing the uses of a
-/// block argument.
-struct TestUndoBlockArgReplace : public ConversionPattern {
-  TestUndoBlockArgReplace(MLIRContext *ctx)
-      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
+struct TestBlockArgReplace : public ConversionPattern {
+  TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
+                          ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto illegalOp =
-        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    // Replace the first block argument with 2x the second block argument.
+    Value repl = op->getRegion(0).getArgument(1);
     rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
-                                        illegalOp->getResult(0));
-    rewriter.modifyOpInPlace(op, [] {});
+                                        {repl, repl});
+    rewriter.modifyOpInPlace(op, [&] {
+      // If the "trigger_rollback" attribute is set, keep the op illegal, so
+      // that a rollback is triggered.
+      if (!op->hasAttr("trigger_rollback"))
+        op->setAttr("is_legal", rewriter.getUnitAttr());
+    });
     return success();
   }
 };
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns
-        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
-             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-             TestNonRootReplacement, TestBoundedRecursiveRewrite,
-             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-             TestUndoPropertiesModification, TestEraseOp,
-             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<
+        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+        TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
+        TestUpdateConsumerType, TestNonRootReplacement,
+        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
+        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+        TestUndoPropertiesModification, TestEraseOp,
+        TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
-                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
-        &getContext(), converter);
+                 TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
+                 TestBlockArgReplace>(&getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
     });
     target.addDynamicallyLegalOp<func::CallOp>(
         [&](func::CallOp op) { return converter.isLegal(op); });
+    target.addDynamicallyLegalOp(
+        OperationName("test.block_arg_replace", &getContext()),
+        [](Operation *op) { return op->hasAttr("is_legal"); });
 
     // TestCreateUnregisteredOp creates `arith.constant` operation,
     // which was not added to target intentionally to test

``````````

</details>


https://github.com/llvm/llvm-project/pull/145171


More information about the llvm-branch-commits mailing list