[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect conversion: Fix `replaceUsesOfBlockArgument` (PR #117666)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Nov 25 20:48:55 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/117666
This commit fixes the implementation of `ConversionPatternRewriter::replaceUsesOfBlockArgument`. The old implementation did not match its documentation.
```
/// Replace all the uses of the block argument `from` with value `to`.
void ConversionPatternRewriter::replaceUsesOfBlockArgument(
BlockArgument from, Value to) {
// ...
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
```
The extra `mapping.lookupOrDefault` was incorrect: it caused the function to replace not `from` itself, but the value that `from` is mapped to (if it is mapped).
This function is typically used after a block signature conversion to "fix up" some block arguments. During a 1:N conversion, an argument materialization is inserted. The old implementation could be used to replace the argument materialization by passing the old block argument as the `from` parameter. This was unintuitive, because it is not the block argument that is being replaced. Furthermore, replacing a block argument of an erased block (scheduled for erasure, to be precise) is incorrect from an API perspective, because a block argument of an erased block should not have any uses anymore.
The new implementation of `replaceUsesOfBlockArgument` now does what the documentation says: it replaces the `from` argument. No extra lookup magic anymore. When an argument materialization should be replaced, users can call `replaceOp` on the argument materialization.
>From 18302a346179fda0b04416883a380decae3e4bfd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 26 Nov 2024 05:38:50 +0100
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Fix
`replaceUsesOfBlockArgument`
---
.../Transforms/Utils/DialectConversion.cpp | 2 +-
mlir/test/Transforms/test-legalizer.mlir | 20 +++++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 45 ++++++++++++++++++-
3 files changed, 65 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 60b3656d98a38e..8b7ffd791d2591 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1641,7 +1641,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+ impl->mapping.map(from, to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 624add08846a28..dfa619796700eb 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,3 +472,23 @@ func.func @circular_mapping() {
%0 = "test.erase_op"() : () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
+
+// -----
+
+// CHECK-LABEL: func @test_replace_uses_of_block_arg() {
+// CHECK: "test.convert_block_and_replace_arg"() ({
+// CHECK: ^bb0(%[[arg0:.*]]: f64, %[[arg1:.*]]: f64):
+// CHECK: %[[producer:.*]] = "test.type_producer"() : () -> f64
+// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]], %[[arg1]]) : (f64, f64) -> f32
+// CHECK: "test.some_user"(%[[cast]]) : (f32) -> ()
+// CHECK: }) {legal} : () -> ()
+// CHECK: "test.return"() : () -> ()
+// CHECK: }
+func.func @test_replace_uses_of_block_arg() {
+ "test.convert_block_and_replace_arg"() ({
+ ^bb0(%arg0: f32):
+ // expected-remark @below{{'test.some_user' is not legalizable}}
+ "test.some_user"(%arg0) : (f32) -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e931b394c86210..54699f402e2f1e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -902,6 +902,44 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
}
};
+struct TestConvertBlockAndReplaceArg : public ConversionPattern {
+ TestConvertBlockAndReplaceArg(MLIRContext *ctx,
+ const TypeConverter &converter)
+ : ConversionPattern(converter, "test.convert_block_and_replace_arg",
+ /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Expect a single region containing a single block with a single argument.
+ if (op->getNumRegions() != 1)
+ return failure();
+ if (op->getRegion(0).getBlocks().size() != 1)
+ return failure();
+ Block *block = &op->getRegion(0).front();
+ if (block->getArguments().size() != 1)
+ return failure();
+
+ // Convert the block argument into two F64 block arguments (1:2 conversion).
+ TypeConverter::SignatureConversion result(1);
+ result.addInputs(0, {rewriter.getF64Type(), rewriter.getF64Type()});
+ Block *newBlock =
+ rewriter.applySignatureConversion(block, result, getTypeConverter());
+
+ // Replace the first argument of the converted block with a new op's result.
+ BlockArgument arg = newBlock->getArgument(0);
+ rewriter.setInsertionPointToStart(newBlock);
+ Value zero = rewriter.create<TestTypeProducerOp>(op->getLoc(),
+ rewriter.getF64Type());
+ rewriter.replaceUsesOfBlockArgument(arg, zero);
+
+ // Mark the op as legal by attaching the "legal" unit attribute.
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// This pattern hoists ops out of a "test.hoist_me" and then fails conversion.
/// This is to test the rollback logic.
struct TestUndoMoveOpBefore : public ConversionPattern {
@@ -1265,7 +1303,8 @@ struct TestLegalizePatternDriver
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestConvertBlockAndReplaceArg>(
+ &getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1317,6 +1356,10 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.convert_block_and_replace_arg", &getContext()),
+ [](Operation *op) { return op->hasAttr("legal"); });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
More information about the llvm-branch-commits
mailing list