[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement (PR #117513)
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 23 04:10:47 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/117513
>From 759edf6f0dcc6bc96da75e93f8f75f628bbee35a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Do not build target
mat. during 1:N replacement
fix test
experiment
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 130 ++++++++++++------
.../Transforms/Utils/DialectConversion.cpp | 46 ++-----
mlir/test/Transforms/test-legalizer.mlir | 8 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 47 +++----
4 files changed, 128 insertions(+), 103 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Add generic source and target materializations to handle cases where
+ // non-LLVM types persist after an LLVM conversion.
+ addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+ addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+
// Helper function that checks if the given value range is a bare pointer.
auto isBarePointer = [](ValueRange values) {
return values.size() == 1 &&
isa<LLVM::LLVMPointerType>(values.front().getType());
};
- // Argument materializations convert from the new block argument types
- // (multiple SSA values that make up a memref descriptor) back to the
- // original block argument type. The dialect conversion framework will then
- // insert a target materialization from the original block argument type to
- // a legal type.
- addArgumentMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
+ // FIXME: The stored `[&]` callbacks below capture this local lambda by reference, so
+ // it dangles once the constructor returns (that is why `this` looked like nullptr).
+ auto packUnrankedMemRefDesc =
+ [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+ Location loc, LLVMTypeConverter &converter) -> Value {
// Note: Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
- if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+ if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
return Value();
- Value desc =
- UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+ };
+
+ // MemRef descriptor elements -> UnrankedMemRefType
+ auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc) {
+ // An argument or source materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc) {
- Value desc;
- if (isBarePointer(inputs)) {
- desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
- inputs[0]);
- } else if (TypeRange(inputs) ==
- getMemRefDescriptorFields(resultType,
- /*unpackAggregates=*/true)) {
- desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- } else {
- // The inputs are neither a bare pointer nor an unpacked memref
- // descriptor. This materialization function cannot be used.
+ Value packed =
+ packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+ if (!packed)
return Value();
- }
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ .getResult(0);
+ };
+
+ // FIXME: The stored `[&]` callbacks below capture this local lambda (which in turn
+ // captures `isBarePointer`) by reference; both dangle after the constructor returns.
+ auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc,
+ LLVMTypeConverter &converter) -> Value {
+ assert(resultType && "expected non-null result type");
+ if (isBarePointer(inputs))
+ return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+ resultType, inputs[0]);
+ if (TypeRange(inputs) ==
+ converter.getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true))
+ return MemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+ // This materialization function cannot be used.
+ return Value();
+ };
+
+ // MemRef descriptor elements -> MemRefType
+ auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+ MemRefType resultType,
+ ValueRange inputs, Location loc) {
+ // An argument or source materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- // Add generic source and target materializations to handle cases where
- // non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() != 1)
+ Value packed =
+ packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+ if (!packed)
return Value();
-
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
.getResult(0);
- });
+ };
+
+ // Argument and source materializations convert from the new types (multiple
+ // SSA values that make up a memref descriptor) back to the original memref
+ // type.
+ addArgumentMaterialization(unrakedMemRefMaterialization);
+ addArgumentMaterialization(rankedMemRefMaterialization);
+ addSourceMaterialization(unrakedMemRefMaterialization);
+ addSourceMaterialization(rankedMemRefMaterialization);
+
+ // Bare pointer or unpacked descriptor elements -> Packed MemRef descriptor
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() != 1)
+ ValueRange inputs, Location loc,
+ Type originalType) -> Value {
+ // The original MemRef type is required to build a MemRef descriptor
+ // because the sizes/strides of the MemRef cannot be inferred from just the
+ // bare pointer.
+ if (!originalType)
return Value();
-
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
- .getResult(0);
+ if (resultType != convertType(originalType))
+ return Value();
+ if (auto memrefType = dyn_cast<MemRefType>(originalType))
+ return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+ if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+ return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+ *this);
+ return Value();
});
// Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
- /// type, followed by a target materialization to the legalized type (if
- /// applicable).
+ /// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ if (replArgs.size() == 1) {
+ mapping.map(origArg, replArgs.front());
+ } else {
+ insertNTo1Materialization(
+ OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+ /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ }
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
UnrealizedConversionCastOp argCastOp;
- Value argMat = buildUnresolvedMaterialization(
+ buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
- &argCastOp);
+ /*inputs=*/replacements, originalType,
+ /*originalType=*/Type(), converter, &argCastOp);
if (argCastOp)
nTo1TempMaterializations.insert(argCastOp);
-
- // Insert target materialization to the legalized type.
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(originalType);
- } else if (replacements.size() == 1) {
- // When there is no type converter, assume that the replacement value
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original value type.
- legalOutputType = replacements[0].getType();
- }
- if (legalOutputType && legalOutputType != originalType) {
- UnrealizedConversionCastOp targetCastOp;
- buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat), loc,
- /*valueToMap=*/argMat, /*inputs=*/argMat,
- /*outputType=*/legalOutputType, /*originalType=*/originalType,
- converter, &targetCastOp);
- if (targetCastOp)
- nTo1TempMaterializations.insert(targetCastOp);
- }
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
+ assert(this && "expected non-null type converter");
+ assert(t && "expected non-null type");
+
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark at +1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
- // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
- ^bb0(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
- "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+ ^bb0(%i0: f64, %unused: i16, %i1: f64):
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark at +1 {{op 'func.return' is not legalizable}}
return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ce2820b80a945d..a470497fdbb560 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -985,8 +985,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
- TestPassthroughInvalidOp(MLIRContext *ctx)
- : ConversionPattern("test.invalid", 1, ctx) {}
+ TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1307,19 +1307,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.add<
- TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
- TestUpdateConsumerType, TestNonRootReplacement,
- TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
- TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp,
- TestRepetitive1ToNConsumer>(&getContext());
- patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
- &getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp,
+ TestRepetitive1ToNConsumer>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+ TestPassthroughInvalidOp>(&getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1755,8 +1755,9 @@ struct TestTypeConversionAnotherProducer
};
struct TestReplaceWithLegalOp : public ConversionPattern {
- TestReplaceWithLegalOp(MLIRContext *ctx)
- : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+ TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(converter, "test.replace_with_legal_op",
+ /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1878,12 +1879,12 @@ struct TestTypeConversionDriver
// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
- patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo,
- TestTestSignatureConversionNoConverter>(converter,
- &getContext());
- patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
- &getContext());
+ patterns
+ .add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo,
+ TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+ converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
More information about the Mlir-commits
mailing list