[llvm-branch-commits] [mlir] [mlir][Transforms] Add 1:N support to `replaceUsesOfBlockArgument` (PR #145171)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Jun 21 07:35:35 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds 1:N support to `ConversionPatternRewriter::replaceUsesOfBlockArgument`. This was one of the few remaining dialect conversion APIs that do not support 1:N conversions yet.
This commit also reuses `replaceUsesOfBlockArgument` in the implementation of `applySignatureConversion`. This is in preparation for the One-Shot Dialect Conversion refactoring. The goal is to bring the `applySignatureConversion` implementation into a state where it works both with and without rollbacks. To that end, `applySignatureConversion` should not directly access the `mapping`.
Depends on #145155.
---
Full diff: https://github.com/llvm/llvm-project/pull/145171.diff
5 Files Affected:
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+3-2)
- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+1-1)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+25-15)
- (modified) mlir/test/Transforms/test-legalizer.mlir (+24-7)
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+29-22)
``````````diff
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5a5f116073a9a..81858812d2623 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -763,8 +763,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
- /// Replace all the uses of the block argument `from` with value `to`.
- void replaceUsesOfBlockArgument(BlockArgument from, Value to);
+ /// Replace all the uses of the block argument `from` with `to`. This
+ /// function supports both 1:1 and 1:N replacements.
+ void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 538016927256b..9e8e746507557 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -294,7 +294,7 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
- auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+ Value valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 774d58973eb91..9cb6f2ba1eaae 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -948,6 +948,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// uses.
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
+ /// Replace the given block argument with the given values. The specified
+ /// converter is used to build materializations (if necessary).
+ void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
+ const TypeConverter *converter);
+
/// Erase the given block and its contents.
void eraseBlock(Block *block);
@@ -1434,12 +1439,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
- buildUnresolvedMaterialization(
- MaterializationKind::Source,
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
- /*outputTypes=*/origArgType, /*originalType=*/Type(), converter);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ Value mat =
+ buildUnresolvedMaterialization(
+ MaterializationKind::Source,
+ OpBuilder::InsertPoint(newBlock, newBlock->begin()),
+ origArg.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+ /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+ .front();
+ replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
}
@@ -1448,17 +1456,15 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map(origArg, inputMap->replacementValues);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
+ converter);
continue;
}
// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
- mapping.map(origArg, std::move(replArgVals));
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
@@ -1612,6 +1618,12 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
+void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
+ BlockArgument from, ValueRange to, const TypeConverter *converter) {
+ appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
+ mapping.map(from, to);
+}
+
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
@@ -1744,7 +1756,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
- Value to) {
+ ValueRange to) {
LLVM_DEBUG({
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
if (Operation *parentOp = from.getOwner()->getParentOp()) {
@@ -1754,9 +1766,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
impl->logger.getOStream() << " (unlinked block)\n";
}
});
- impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
- impl->currentTypeConverter);
- impl->mapping.map(from, to);
+ impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 204c8c1456826..79518b04e7158 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -300,18 +300,35 @@ func.func @create_illegal_block() {
// -----
// CHECK-LABEL: @undo_block_arg_replace
+// expected-remark at +1{{applyPartialConversion failed}}
+module {
func.func @undo_block_arg_replace() {
- // expected-remark at +1 {{op 'test.undo_block_arg_replace' is not legalizable}}
- "test.undo_block_arg_replace"() ({
- ^bb0(%arg0: i32):
- // CHECK: ^bb0(%[[ARG:.*]]: i32):
- // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+ // expected-error at +1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
+ "test.block_arg_replace"() ({
+ ^bb0(%arg0: i32, %arg1: i16):
+ // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+ // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
"test.return"(%arg0) : (i32) -> ()
- }) : () -> ()
- // expected-remark at +1 {{op 'func.return' is not legalizable}}
+ }) {trigger_rollback} : () -> ()
return
}
+}
+
+// -----
+
+// CHECK-LABEL: @replace_block_arg_1_to_n
+func.func @replace_block_arg_1_to_n() {
+ // CHECK: "test.block_arg_replace"
+ "test.block_arg_replace"() ({
+ ^bb0(%arg0: i32, %arg1: i16):
+ // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+ // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+ // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+ "test.return"(%arg0) : (i32) -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
// -----
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..588e529665dd1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -891,20 +891,25 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
-/// A simple pattern that tests the undo mechanism when replacing the uses of a
-/// block argument.
-struct TestUndoBlockArgReplace : public ConversionPattern {
- TestUndoBlockArgReplace(MLIRContext *ctx)
- : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
+struct TestBlockArgReplace : public ConversionPattern {
+ TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
+ ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- auto illegalOp =
- rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+ // Replace the first block argument with 2x the second block argument.
+ Value repl = op->getRegion(0).getArgument(1);
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
- illegalOp->getResult(0));
- rewriter.modifyOpInPlace(op, [] {});
+ {repl, repl});
+ rewriter.modifyOpInPlace(op, [&] {
+ // If the "trigger_rollback" attribute is set, keep the op illegal, so
+ // that a rollback is triggered.
+ if (!op->hasAttr("trigger_rollback"))
+ op->setAttr("is_legal", rewriter.getUnitAttr());
+ });
return success();
}
};
@@ -1375,20 +1380,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns
- .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
- TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
- TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
- TestNonRootReplacement, TestBoundedRecursiveRewrite,
- TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
- TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp,
- TestRepetitive1ToNConsumer>(&getContext());
+ patterns.add<
+ TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockErase, TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+ TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
+ TestUpdateConsumerType, TestNonRootReplacement,
+ TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
+ TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp,
+ TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
- &getContext(), converter);
+ TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
+ TestBlockArgReplace>(&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1413,6 +1417,9 @@ struct TestLegalizePatternDriver
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.block_arg_replace", &getContext()),
+ [](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
``````````
</details>
https://github.com/llvm/llvm-project/pull/145171
More information about the llvm-branch-commits
mailing list