[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (PR #117513)
Matthias Springer
llvmlistbot at llvm.org
Thu Nov 28 22:51:16 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/117513
>From 43699861bfd161217fa7ba06c2f71e3871cfd571 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Do not build target
materialization during 1:N replacement
fix test
experiment
---
.../Transforms/Utils/DialectConversion.cpp | 26 +----------
mlir/test/Transforms/test-legalizer.mlir | 8 ++--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 45 ++++++++++---------
3 files changed, 29 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1424c4974f2d43..202695cb97c13d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -839,8 +839,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
- /// type, followed by a target materialization to the legalized type (if
- /// applicable).
+ /// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
@@ -1379,31 +1378,10 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
Value originalValue, const TypeConverter *converter) {
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
- Value argMat = buildUnresolvedMaterialization(
+ buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
/*inputs=*/replacements, originalType, /*originalType=*/Type(),
converter);
-
- // Insert target materialization to the legalized type.
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(originalType);
- } else if (replacements.size() == 1) {
- // When there is no type converter, assume that the replacement value
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original value type.
- legalOutputType = replacements[0].getType();
- }
- if (legalOutputType && legalOutputType != originalType) {
- buildUnresolvedMaterialization(MaterializationKind::Target,
- computeInsertPoint(argMat), loc,
- /*valueToMap=*/argMat, /*inputs=*/argMat,
- /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
- }
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e05f444afa68f0..624add08846a28 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark at +1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
- // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
- ^bb0(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
- "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+ ^bb0(%i0: f64, %unused: i16, %i1: f64):
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark at +1 {{op 'func.return' is not legalizable}}
return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index bbd55938718fe7..e931b394c86210 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
- TestPassthroughInvalidOp(MLIRContext *ctx)
- : ConversionPattern("test.invalid", 1, ctx) {}
+ TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1254,18 +1254,18 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.add<
- TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
- TestUpdateConsumerType, TestNonRootReplacement,
- TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
- TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp>(&getContext());
- patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
- &getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+ TestPassthroughInvalidOp>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1697,8 +1697,9 @@ struct TestTypeConversionAnotherProducer
};
struct TestReplaceWithLegalOp : public ConversionPattern {
- TestReplaceWithLegalOp(MLIRContext *ctx)
- : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+ TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(converter, "test.replace_with_legal_op",
+ /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1820,12 +1821,12 @@ struct TestTypeConversionDriver
// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
- patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo,
- TestTestSignatureConversionNoConverter>(converter,
- &getContext());
- patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
- &getContext());
+ patterns
+ .add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo,
+ TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+ converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
More information about the Mlir-commits
mailing list