[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Add 1:N op replacement test case (PR #121271)
Matthias Springer
llvmlistbot at llvm.org
Sat Dec 28 11:32:11 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/121271
This commit adds a test case that performs two back-to-back 1:N replacements: `(f16) -> (f16, f16) -> ((f16, f16), (f16, f16))`. For the moment, 3 argument materializations are inserted. In the future (when the conversion value mapping supports 1:N), a single target materialization will be inserted. Addresses [comment](https://github.com/llvm/llvm-project/pull/116524#discussion_r1894629711) in #116524.
>From 36e18d06b60fd84df7a6524fb9a233e8e61a723a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 28 Dec 2024 17:24:41 +0100
Subject: [PATCH] test double replacement
---
mlir/test/Transforms/test-legalizer.mlir | 24 +++++++++-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 51 +++++++++++++++++++--
2 files changed, 70 insertions(+), 5 deletions(-)
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 2ca5f49637523f..297eb5acef21b7 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -450,7 +450,7 @@ func.func @fold_legalization() -> i32 {
// -----
// CHECK-LABEL: func @convert_detached_signature()
-// CHECK: "test.legal_op_with_region"() ({
+// CHECK: "test.legal_op"() ({
// CHECK: ^bb0(%arg0: f64):
// CHECK: "test.return"() : () -> ()
// CHECK: }) : () -> ()
@@ -483,3 +483,25 @@ func.func @test_1_to_n_block_signature_conversion() {
"test.return"() : () -> ()
}
+// -----
+// Two back-to-back 1:N op replacements: f16 -> 2 x f16 -> 4 x f16.
+// CHECK: notifyOperationInserted: test.step_1
+// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement
+// CHECK: notifyOperationInserted: test.legal_op
+// CHECK: notifyOperationReplaced: test.step_1
+// CHECK: notifyOperationErased: test.step_1
+
+// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
+// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
+// TODO: There should be a single cast (i.e., a single target materialization).
+// Due to 1:N limitations of the conversion mapping, there are currently 3
+// argument materializations: one per result pair, plus one combining both.
+// CHECK: %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
+// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
+// CHECK: %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
+// CHECK: "test.valid"(%[[cast3]]) : (f16) -> ()
+func.func @test_multiple_1_to_n_replacement() {
+ %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
+ "test.invalid"(%0) : (f16) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a470497fdbb560..62ba7057019099 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
if (op->getNumRegions() != 1)
return failure();
- OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+ OperationState state(op->getLoc(), "test.legal_op", operands,
op->getResultTypes(), {}, BlockRange());
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,49 @@ class TestRepetitive1ToNConsumer : public ConversionPattern {
}
};
+/// A pattern that tests two back-to-back 1 -> 2 op replacements.
+class TestMultiple1ToNReplacement : public ConversionPattern {
+public:
+  TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
+                          ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Helper function that replaces `oldOp` with a new op of the given name
+    // and doubles each result (1 -> 2 replacement of each result).
+    auto replaceWithDoubleResults = [&](Operation *oldOp, StringRef name) {
+      // Result i of `oldOp` becomes results 2*i and 2*i+1 of the new op.
+      SmallVector<Type> types;
+      for (Type t : oldOp->getResultTypes())
+        types.append(/*NumInputs=*/2, t);
+      OperationState state(oldOp->getLoc(), name,
+                           /*operands=*/{}, types, oldOp->getAttrs());
+      auto *newOp = rewriter.create(state);
+      SmallVector<ValueRange> repls;
+      // `unsigned` matches getNumResults(); avoids a signed/unsigned compare.
+      for (unsigned i = 0, e = oldOp->getNumResults(); i < e; ++i)
+        repls.push_back(newOp->getResults().slice(2 * i, 2));
+      rewriter.replaceOpWithMultiple(oldOp, repls);
+      return newOp;
+    };
+
+    // Replace test.multiple_1_to_n_replacement with test.step_1.
+    Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
+    // Now replace test.step_1 with test.legal_op.
+    // TODO: Ideally, it should not be necessary to reset the insertion point
+    // here. Based on the API calls, it looks like test.step_1 is entirely
+    // erased. But that's not the case: an argument materialization will
+    // survive. And that argument materialization will be used by the users of
+    // `op`. If we don't reset the insertion point here, we get dominance
+    // errors. This will be fixed when we have 1:N support in the conversion
+    // value mapping.
+    rewriter.setInsertionPoint(repl1);
+    replaceWithDoubleResults(repl1, "test.legal_op");
+    return success();
+  }
+};
+
} // namespace
namespace {
@@ -1319,7 +1362,8 @@ struct TestLegalizePatternDriver
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
+ &getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1330,8 +1374,7 @@ struct TestLegalizePatternDriver
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
TerminatorOp, OneRegionOp>();
- target.addLegalOp(
- OperationName("test.legal_op_with_region", &getContext()));
+ target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
More information about the Mlir-commits
mailing list