[Mlir-commits] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)
Matthias Springer
llvmlistbot at llvm.org
Sat Dec 14 06:03:16 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524
>From f9e875810d6c8af844c90128d7a052a7a40297c0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH 1/2] [mlir][Transforms] Dialect Conversion: Do not build
target mat. during 1:N replacement
fix test
experiment
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 130 ++++++++++++------
.../Transforms/Utils/DialectConversion.cpp | 46 ++-----
mlir/test/Transforms/test-legalizer.mlir | 8 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 47 +++----
4 files changed, 128 insertions(+), 103 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});
+ // Add generic source and target materializations to handle cases where
+ // non-LLVM types persist after an LLVM conversion.
+ addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+ addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ });
+
// Helper function that checks if the given value range is a bare pointer.
auto isBarePointer = [](ValueRange values) {
return values.size() == 1 &&
isa<LLVM::LLVMPointerType>(values.front().getType());
};
- // Argument materializations convert from the new block argument types
- // (multiple SSA values that make up a memref descriptor) back to the
- // original block argument type. The dialect conversion framework will then
- // insert a target materialization from the original block argument type to
- // a legal type.
- addArgumentMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+ // must be passed explicitly.
+ auto packUnrankedMemRefDesc =
+ [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+ Location loc, LLVMTypeConverter &converter) -> Value {
// Note: Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
- if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+ if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
return Value();
- Value desc =
- UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+ return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+ };
+
+ // MemRef descriptor elements -> UnrankedMemRefType
+ auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc) {
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc) {
- Value desc;
- if (isBarePointer(inputs)) {
- desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
- inputs[0]);
- } else if (TypeRange(inputs) ==
- getMemRefDescriptorFields(resultType,
- /*unpackAggregates=*/true)) {
- desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
- } else {
- // The inputs are neither a bare pointer nor an unpacked memref
- // descriptor. This materialization function cannot be used.
+ Value packed =
+ packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+ if (!packed)
return Value();
- }
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ .getResult(0);
+ };
+
+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+ // must be passed explicitly.
+ auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc,
+ LLVMTypeConverter &converter) -> Value {
+ assert(resultType && "expected non-null result type");
+ if (isBarePointer(inputs))
+ return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+ resultType, inputs[0]);
+ if (TypeRange(inputs) ==
+ converter.getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true))
+ return MemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+ // This materialization function cannot be used.
+ return Value();
+ };
+
+ // MemRef descriptor elements -> MemRefType
+ auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+ MemRefType resultType,
+ ValueRange inputs, Location loc) {
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
- .getResult(0);
- });
- // Add generic source and target materializations to handle cases where
- // non-LLVM types persist after an LLVM conversion.
- addSourceMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() != 1)
+ Value packed =
+ packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+ if (!packed)
return Value();
-
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
.getResult(0);
- });
+ };
+
+ // Argument materializations convert from the new block argument types
+ // (multiple SSA values that make up a memref descriptor) back to the
+ // original block argument type.
+ addArgumentMaterialization(unrakedMemRefMaterialization);
+ addArgumentMaterialization(rankedMemRefMaterialization);
+ addSourceMaterialization(unrakedMemRefMaterialization);
+ addSourceMaterialization(rankedMemRefMaterialization);
+
+ // Bare pointer -> Packed MemRef descriptor
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
- ValueRange inputs, Location loc) {
- if (inputs.size() != 1)
+ ValueRange inputs, Location loc,
+ Type originalType) -> Value {
+ // The original MemRef type is required to build a MemRef descriptor
+ // because the sizes/strides of the MemRef cannot be inferred from just the
+ // bare pointer.
+ if (!originalType)
return Value();
-
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
- .getResult(0);
+ if (resultType != convertType(originalType))
+ return Value();
+ if (auto memrefType = dyn_cast<MemRefType>(originalType))
+ return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+ if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+ return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+ *this);
+ return Value();
});
// Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
- /// type, followed by a target materialization to the legalized type (if
- /// applicable).
+ /// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ if (replArgs.size() == 1) {
+ mapping.map(origArg, replArgs.front());
+ } else {
+ insertNTo1Materialization(
+ OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+ /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+ }
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
UnrealizedConversionCastOp argCastOp;
- Value argMat = buildUnresolvedMaterialization(
+ buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
- &argCastOp);
+ /*inputs=*/replacements, originalType,
+ /*originalType=*/Type(), converter, &argCastOp);
if (argCastOp)
nTo1TempMaterializations.insert(argCastOp);
-
- // Insert target materialization to the legalized type.
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(originalType);
- } else if (replacements.size() == 1) {
- // When there is no type converter, assume that the replacement value
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original value type.
- legalOutputType = replacements[0].getType();
- }
- if (legalOutputType && legalOutputType != originalType) {
- UnrealizedConversionCastOp targetCastOp;
- buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat), loc,
- /*valueToMap=*/argMat, /*inputs=*/argMat,
- /*outputType=*/legalOutputType, /*originalType=*/originalType,
- converter, &targetCastOp);
- if (targetCastOp)
- nTo1TempMaterializations.insert(targetCastOp);
- }
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
+ assert(this && "expected non-null type converter");
+ assert(t && "expected non-null type");
+
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark at +1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
- // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
- ^bb0(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
- "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+ ^bb0(%i0: f64, %unused: i16, %i1: f64):
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark at +1 {{op 'func.return' is not legalizable}}
return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 8a0bc597c56beb..466ae7ff6f46f1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
- TestPassthroughInvalidOp(MLIRContext *ctx)
- : ConversionPattern("test.invalid", 1, ctx) {}
+ TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1301,19 +1301,19 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.add<
- TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
- TestUpdateConsumerType, TestNonRootReplacement,
- TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
- TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp,
- TestRepetitive1ToNConsumer>(&getContext());
- patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
- &getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp,
+ TestRepetitive1ToNConsumer>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+ TestPassthroughInvalidOp>(&getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1749,8 +1749,9 @@ struct TestTypeConversionAnotherProducer
};
struct TestReplaceWithLegalOp : public ConversionPattern {
- TestReplaceWithLegalOp(MLIRContext *ctx)
- : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+ TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(converter, "test.replace_with_legal_op",
+ /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1872,12 +1873,12 @@ struct TestTypeConversionDriver
// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
- patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo,
- TestTestSignatureConversionNoConverter>(converter,
- &getContext());
- patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
- &getContext());
+ patterns
+ .add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo,
+ TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+ converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
>From f983e30b097770fa04358d7a827545a0997de04f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 15:02:10 +0100
Subject: [PATCH 2/2] prototype
---
.../Conversion/LLVMCommon/TypeConverter.cpp | 2 -
.../EmitC/Transforms/TypeConversions.cpp | 1 -
.../Dialect/Linalg/Transforms/Detensorize.cpp | 1 -
.../Quant/Transforms/StripFuncQuantTypes.cpp | 1 -
.../Utils/SparseTensorDescriptor.cpp | 3 -
.../Vector/Transforms/VectorLinearize.cpp | 1 -
.../Transforms/Utils/DialectConversion.cpp | 526 +++++++++++-------
.../Func/TestDecomposeCallGraphTypes.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 -
.../lib/Transforms/TestDialectConversion.cpp | 1 -
10 files changed, 318 insertions(+), 221 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..ef8181e80cee38 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -237,8 +237,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type.
- addArgumentMaterialization(unrakedMemRefMaterialization);
- addArgumentMaterialization(rankedMemRefMaterialization);
addSourceMaterialization(unrakedMemRefMaterialization);
addSourceMaterialization(rankedMemRefMaterialization);
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
converter.addSourceMaterialization(materializeAsUnrealizedCast);
converter.addTargetMaterialization(materializeAsUnrealizedCast);
- converter.addArgumentMaterialization(materializeAsUnrealizedCast);
}
/// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index af38485291182f..61bc5022893741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
});
addSourceMaterialization(sourceMaterializationCallback);
- addArgumentMaterialization(sourceMaterializationCallback);
}
};
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
addConversion(convertQuantizedType);
addConversion(convertTensorType);
- addArgumentMaterialization(materializeConversion);
addSourceMaterialization(materializeConversion);
addTargetMaterialization(materializeConversion);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
// Required by scf.for 1:N type conversion.
addSourceMaterialization(materializeTuple);
-
- // Required as a workaround until we have full 1:N support.
- addArgumentMaterialization(materializeTuple);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
- typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 51686646a0a2fc..81c8c1f422551f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -63,11 +63,45 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
return OpBuilder::InsertPoint(insertBlock, insertPt);
}
+/// Helper function that computes an insertion point where all given values
+/// are defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+ assert(!vals.empty() && "expected at least one value");
+ OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+ for (Value v : vals.drop_front()) {
+ OpBuilder::InsertPoint pt2 = computeInsertPoint(v);
+ assert(pt.getBlock() == pt2.getBlock());
+ if (pt.getPoint() == pt.getBlock()->begin()) {
+ pt = pt2;
+ continue;
+ }
+ if (pt2.getPoint() == pt2.getBlock()->begin()) {
+ continue;
+ }
+ if (pt.getPoint()->isBeforeInBlock(&*pt2.getPoint()))
+ pt = pt2;
+ }
+ return pt;
+}
+
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
namespace {
+struct SmallVectorMapInfo {
+ static SmallVector<Value, 1> getEmptyKey() { return SmallVector<Value, 1>{}; }
+ static SmallVector<Value, 1> getTombstoneKey() {
+ return SmallVector<Value, 1>{};
+ }
+ static ::llvm::hash_code getHashValue(SmallVector<Value, 1> val) {
+ return ::llvm::hash_combine_range(val.begin(), val.end());
+ }
+ static bool isEqual(SmallVector<Value, 1> LHS, SmallVector<Value, 1> RHS) {
+ return LHS == RHS;
+ }
+};
+
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
@@ -75,71 +109,240 @@ struct ConversionValueMapping {
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
- /// Lookup the most recently mapped value with the desired type in the
- /// mapping.
+ /// Find the most recently mapped values for the given value. If the value is
+ /// not mapped at all, return the given value.
+ SmallVector<Value, 1> lookupOrDefault(Value from) const;
+
+ /// TODO: Find most recently mapped or materialization with matching type. May
+ /// return the given value if the type matches.
+ SmallVector<Value, 1>
+ lookupOrDefault(Value from, SmallVector<Type, 1> desiredTypes) const;
+
+ Value lookupDirectSingleReplacement(Value from) const {
+ auto it = mapping.find(from);
+ if (it == mapping.end())
+ return Value();
+ const SmallVector<Value, 1> &repl = it->second;
+ if (repl.size() != 1)
+ return Value();
+ return repl.front();
+ /*
+ if (!mapping.contains(from)) return Value();
+ auto it = llvm::find(mapping, from);
+ const SmallVector<Value, 1> &repl = it->second;
+ if (repl.size() != 1) return Value();
+ return repl.front();
+ */
+ }
+
+ SmallVector<Value,1> lookupDirectReplacement(Value from) const {
+ auto it = mapping.find(from);
+ if (it == mapping.end())
+ return {};
+ return it->second;
+ }
+
+ /// Find the most recently mapped values for the given value. If the value is
+ /// not mapped at all, return an empty vector.
+ SmallVector<Value, 1> lookupOrNull(Value from) const;
+
+ /// Find the most recently mapped values for the given value. If those values
+ /// have the desired types, return them. Otherwise, try to find a
+ /// materialization to the desired types.
///
- /// Special cases:
- /// - If the desired type is "null", simply return the most recently mapped
- /// value.
- /// - If there is no mapping to the desired type, also return the most
- /// recently mapped value.
- /// - If there is no mapping for the given value at all, return the given
- /// value.
- Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
-
- /// Lookup a mapped value within the map, or return null if a mapping does not
- /// exist. If a mapping exists, this follows the same behavior of
- /// `lookupOrDefault`.
- Value lookupOrNull(Value from, Type desiredType = nullptr) const;
-
- /// Map a value to the one provided.
- void map(Value oldVal, Value newVal) {
- LLVM_DEBUG({
- for (Value it = newVal; it; it = mapping.lookupOrNull(it))
- assert(it != oldVal && "inserting cyclic mapping");
- });
- mapping.map(oldVal, newVal);
- mappedTo.insert(newVal);
+ /// If the given value is not mapped at all or if there are no mapped values/
+ /// materialization results with the desired types, return an empty vector.
+ SmallVector<Value, 1> lookupOrNull(Value from,
+ SmallVector<Type, 1> desiredTypes) const;
+
+ Value lookupOrNull(Value from, Type desiredType) {
+ SmallVector<Value, 1> vals =
+ lookupOrNull(from, SmallVector<Type, 1>{desiredType});
+ if (vals.empty())
+ return Value();
+ assert(vals.size() == 1 && "expected single value");
+ return vals.front();
+ }
+
+ void erase(Value from) { mapping.erase(from); }
+
+ void map(Value from, ValueRange to) {
+#ifndef NDEBUG
+ assert(from && "expected non-null value");
+ assert(!to.empty() && "cannot map to zero values");
+ for (Value v : to)
+ assert(v && "expected non-null value");
+#endif
+ // assert(from != to && "cannot map value to itself");
+ // TODO: Check for cyclic mapping.
+ assert(!mapping.contains(from) && "value is already mapped");
+ mapping[from].assign(to.begin(), to.end());
+ for (Value v : to)
+ mappedTo.insert(v);
+ }
+
+ void map(Value from, ArrayRef<BlockArgument> to) {
+ SmallVector<Value> vals;
+ for (Value v : to)
+ vals.push_back(v);
+ map(from, vals);
+ }
+ /*
+ void map(Value from, ArrayRef<Value> to) {
+ #ifndef NDEBUG
+ assert(from && "expected non-null value");
+ assert(!to.empty() && "cannot map to zero values");
+ for (Value v : to)
+ assert(v && "expected non-null value");
+ #endif
+ // assert(from != to && "cannot map value to itself");
+ // TODO: Check for cyclic mapping.
+ assert(!mapping.contains(from) && "value is already mapped");
+ mapping[from].assign(to.begin(), to.end());
+ }
+ */
+
+ void mapMaterialization(SmallVector<Value, 1> from,
+ SmallVector<Value, 1> to) {
+#ifndef NDEBUG
+ assert(!from.empty() && "from cannot be empty");
+ assert(!to.empty() && "to cannot be empty");
+ for (Value v : from) {
+ assert(v && "expected non-null value");
+ assert(!mapping.contains(v) &&
+ "cannot add materialization for mapped value");
+ }
+ for (Value v : to) {
+ assert(v && "expected non-null value");
+ }
+ assert(TypeRange(from) != TypeRange(to) &&
+ "cannot add materialization for identical type");
+ for (const SmallVector<Value, 1> &mat : materializations[from])
+ assert(TypeRange(mat) != TypeRange(to) &&
+ "cannot register duplicate materialization");
+#endif // NDEBUG
+ materializations[from].push_back(to);
+ for (Value v : to)
+ mappedTo.insert(v);
+ }
+
+ void eraseMaterialization(SmallVector<Value, 1> from,
+ SmallVector<Value, 1> to) {
+ if (!materializations.count(from))
+ return;
+ auto it = llvm::find(materializations[from], to);
+ if (it == materializations[from].end())
+ return;
+ if (materializations[from].size() == 1)
+ materializations.erase(from);
+ else
+ materializations[from].erase(it);
}
- /// Drop the last mapping for the given value.
- void erase(Value value) { mapping.erase(value); }
+ /// Returns the inverse raw value mapping (without recursive query support).
+ DenseMap<Value, SmallVector<Value>> getInverse() const {
+ DenseMap<Value, SmallVector<Value>> inverse;
+
+ for (auto &it : mapping)
+ for (Value v : it.second)
+ inverse[v].push_back(it.first);
+
+ for (auto &it : materializations)
+ for (const SmallVector<Value, 1> &mat : it.second)
+ for (Value v : mat)
+ for (Value v2 : it.first)
+ inverse[v].push_back(v2);
+
+ return inverse;
+ }
private:
- /// Current value mappings.
- IRMapping mapping;
+ /// Replacement mapping: Value -> ValueRange
+ DenseMap<Value, SmallVector<Value, 1>> mapping;
+
+ /// Materializations: ValueRange -> ValueRange*
+ DenseMap<SmallVector<Value, 1>, SmallVector<SmallVector<Value, 1>>,
+ SmallVectorMapInfo>
+ materializations;
/// All SSA values that are mapped to. May contain false positives.
DenseSet<Value> mappedTo;
};
} // namespace
-Value ConversionValueMapping::lookupOrDefault(Value from,
- Type desiredType) const {
- // Try to find the deepest value that has the desired type. If there is no
- // such value, simply return the deepest value.
- Value desiredValue;
- do {
- if (!desiredType || from.getType() == desiredType)
- desiredValue = from;
-
- Value mappedValue = mapping.lookupOrNull(from);
- if (!mappedValue)
- break;
- from = mappedValue;
- } while (true);
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrDefault(Value from) const {
+ SmallVector<Value, 1> to = lookupOrNull(from);
+ return to.empty() ? SmallVector<Value, 1>{from} : to;
+}
- // If the desired value was found use it, otherwise default to the leaf value.
- return desiredValue ? desiredValue : from;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrDefault(
+ Value from, SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+ assert(desiredTypes.size() > 0 && "expected non-empty types");
+ for (Type t : desiredTypes)
+ assert(t && "expected non-null type");
+#endif // NDEBUG
+
+ SmallVector<Value, 1> vals = lookupOrNull(from);
+ if (vals.empty()) {
+ // Value is not mapped. Return if the type matches.
+ if (TypeRange(from) == desiredTypes)
+ return {from};
+ // Check materializations.
+ auto it = materializations.find({from});
+ if (it == materializations.end())
+ return {};
+ for (const SmallVector<Value, 1> &mat : it->second)
+ if (TypeRange(mat) == desiredTypes)
+ return mat;
+ return {};
+ }
+
+ return lookupOrNull(from, desiredTypes);
}
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
- Value result = lookupOrDefault(from, desiredType);
- if (result == from || (desiredType && result.getType() != desiredType))
- return nullptr;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrNull(Value from) const {
+ auto it = mapping.find(from);
+ if (it == mapping.end())
+ return {};
+ SmallVector<Value, 1> result;
+ for (Value v : it->second) {
+ llvm::append_range(result, lookupOrDefault(v));
+ }
return result;
}
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrNull(Value from,
+ SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+ assert(desiredTypes.size() > 0 && "expected non-empty types");
+ for (Type t : desiredTypes)
+ assert(t && "expected non-null type");
+#endif // NDEBUG
+
+ SmallVector<Value, 1> vals = lookupOrNull(from);
+ if (vals.empty())
+ return {};
+
+ // There is a mapping and the types match.
+ if (TypeRange(vals) == desiredTypes)
+ return vals;
+
+ // There is a mapping, but the types do not match. Try to find a matching
+ // materialization.
+ auto it = materializations.find(vals);
+ if (it == materializations.end())
+ return {};
+ for (const SmallVector<Value, 1> &mat : it->second)
+ if (TypeRange(mat) == desiredTypes)
+ return mat;
+
+ // No materialization found. Return an empty vector.
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// Rewriter and Translation State
//===----------------------------------------------------------------------===//
@@ -673,7 +876,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnrealizedConversionCastOp op,
const TypeConverter *converter,
MaterializationKind kind, Type originalType,
- Value mappedValue);
+ SmallVector<Value, 1> mappedValue);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,7 +913,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The value in the conversion value mapping that is being replaced by the
/// results of this unresolved materialization.
- Value mappedValue;
+ SmallVector<Value, 1> mappedValue;
};
} // namespace
@@ -779,7 +982,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVector<SmallVector<Value>> &remapped);
+ SmallVector<SmallVector<Value, 1>> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -825,7 +1028,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// mapping.
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ SmallVector<Value,1> valueToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp = nullptr);
Value buildUnresolvedMaterialization(
@@ -833,27 +1036,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
const TypeConverter *converter,
UnrealizedConversionCastOp *castOp = nullptr) {
- return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
+ SmallVector<Value,1> valuesToMap;
+ if (valueToMap) valuesToMap.push_back(valueToMap);
+ return buildUnresolvedMaterialization(kind, ip, loc, valuesToMap, inputs,
TypeRange(outputType), originalType,
converter, castOp)
.front();
}
- /// Build an N:1 materialization for the given original value that was
- /// replaced with the given replacement values.
- ///
- /// This is a workaround around incomplete 1:N support in the dialect
- /// conversion driver. The conversion mapping can store only 1:1 replacements
- /// and the conversion patterns only support single Value replacements in the
- /// adaptor, so N values must be converted back to a single value. This
- /// function will be deleted when full 1:N support has been added.
- ///
- /// This function inserts an argument materialization back to the original
- /// type.
- void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
- ValueRange replacements, Value originalValue,
- const TypeConverter *converter);
-
/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
/// value. If there is no replacement value with the correct type, find the
@@ -862,16 +1052,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value findOrBuildReplacementValue(Value value,
const TypeConverter *converter);
- /// Unpack an N:1 materialization and return the inputs of the
- /// materialization. This function unpacks only those materializations that
- /// were built with `insertNTo1Materialization`.
- ///
- /// This is a workaround around incomplete 1:N support in the dialect
- /// conversion driver. It allows us to write 1:N conversion patterns while
- /// 1:N support is still missing in the conversion value mapping. This
- /// function will be deleted when full 1:N support has been added.
- SmallVector<Value> unpackNTo1Materialization(Value value);
-
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -1101,7 +1281,7 @@ void CreateOperationRewrite::rollback() {
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind, Type originalType,
- Value mappedValue)
+ SmallVector<Value, 1> mappedValue)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), originalType(originalType),
mappedValue(mappedValue) {
@@ -1111,8 +1291,8 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
}
void UnresolvedMaterializationRewrite::rollback() {
- if (mappedValue)
- rewriterImpl.mapping.erase(mappedValue);
+ if (!mappedValue.empty())
+ rewriterImpl.mapping.eraseMaterialization(mappedValue, op->getResults());
rewriterImpl.unresolvedMaterializations.erase(getOperation());
rewriterImpl.nTo1TempMaterializations.erase(getOperation());
op->erase();
@@ -1160,7 +1340,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVector<SmallVector<Value>> &remapped) {
+ SmallVector<SmallVector<Value, 1>> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1348,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type origType = operand.getType();
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
- // Find the most recently mapped value. Unpack all temporary N:1
- // materializations. Such conversions are a workaround around missing
- // 1:N support in the ConversionValueMapping. (The conversion patterns
- // already support 1:N replacements.)
- Value repl = mapping.lookupOrDefault(operand);
- SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
if (!currentTypeConverter) {
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped value.
- remapped.push_back(std::move(unpacked));
+ SmallVector<Value, 1> repl = mapping.lookupOrDefault(operand);
+ remapped.push_back(repl);
continue;
}
@@ -1199,44 +1373,23 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
continue;
}
- if (legalTypes.size() != 1) {
- // TODO: This is a 1:N conversion. The conversion value mapping does not
- // store such materializations yet. If the types of the most recently
- // mapped values do not match, build a target materialization.
- ValueRange unpackedRange(unpacked);
- if (TypeRange(unpackedRange) == legalTypes) {
- remapped.push_back(std::move(unpacked));
- continue;
- }
-
- // Insert a target materialization if the current pattern expects
- // different legalized types.
- ValueRange targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
- /*valueToMap=*/Value(), /*inputs=*/unpacked,
- /*outputType=*/legalTypes, /*originalType=*/origType,
- currentTypeConverter);
- remapped.push_back(targetMat);
+ SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes);
+ if (!mat.empty()) {
+ // Mapped value has the correct type or there is an existing
+ // materialization. Or the value is not mapped at all and has the
+ // correct type.
+ remapped.push_back(mat);
continue;
}
- // Handle 1->1 type conversions.
- Type desiredType = legalTypes.front();
- // Try to find a mapped value with the desired type. (Or the operand itself
- // if the value is not mapped at all.)
- Value newOperand = mapping.lookupOrDefault(operand, desiredType);
- if (newOperand.getType() != desiredType) {
- // If the looked up value's type does not have the desired type, it means
- // that the value was replaced with a value of different type and no
- // target materialization was created yet.
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
- /*outputType=*/desiredType, /*originalType=*/origType,
- currentTypeConverter);
- newOperand = castValue;
- }
- remapped.push_back({newOperand});
+ // Create a materialization for the most recently mapped value.
+ SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
+ ValueRange castValues = buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
+ /*valueToMap=*/vals,
+ /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+ currentTypeConverter);
+ remapped.push_back(castValues);
}
return success();
}
@@ -1350,11 +1503,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
- buildUnresolvedMaterialization(
+ Value sourceMat = buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+ /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
+ mapping.map(origArg, sourceMat);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -1369,19 +1523,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
continue;
}
- // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
- // dialect conversion. Therefore, we need an argument materialization to
- // turn the replacement block arguments into a single SSA value that can be
- // used as a replacement.
+ // This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- if (replArgs.size() == 1) {
- mapping.map(origArg, replArgs.front());
- } else {
- insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
- }
+ mapping.map(origArg, replArgs);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1402,7 +1547,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ SmallVector<Value, 1> valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp) {
assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1555,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// Avoid materializing an unnecessary cast.
if (TypeRange(inputs) == outputTypes) {
- if (valueToMap) {
- assert(inputs.size() == 1 && "1:N mapping is not supported");
- mapping.map(valueToMap, inputs.front());
- }
+ if (!valuesToMap.empty())
+ mapping.mapMaterialization(valuesToMap, inputs);
return inputs;
}
@@ -1423,36 +1566,23 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
- if (valueToMap) {
- assert(outputTypes.size() == 1 && "1:N mapping is not supported");
- mapping.map(valueToMap, convertOp.getResult(0));
- }
+ if (!valuesToMap.empty())
+ mapping.mapMaterialization(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- originalType, valueToMap);
+ originalType, valuesToMap);
return convertOp.getResults();
}
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
- OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
- Value originalValue, const TypeConverter *converter) {
- // Insert argument materialization back to the original type.
- Type originalType = originalValue.getType();
- UnrealizedConversionCastOp argCastOp;
- buildUnresolvedMaterialization(
- MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType,
- /*originalType=*/Type(), converter, &argCastOp);
- if (argCastOp)
- nTo1TempMaterializations.insert(argCastOp);
-}
-
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ // NOTE(review): both lookupOrNull overloads visible in this patch return
+ // SmallVector<Value, 1>; confirm a single-Value overload still exists for
+ // the `Value repl = ...` initialization below, or unwrap the vector here.
+
// Find a replacement value with the same type.
- Value repl = mapping.lookupOrNull(value, value.getType());
- if (repl)
+ if (Value repl = mapping.lookupOrNull(value, value.getType()))
return repl;
// Check if the value is dead. No replacement value is needed in that case.
@@ -1467,8 +1597,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// No replacement value was found. Get the latest replacement value
// (regardless of the type) and build a source materialization to the
// original type.
- repl = mapping.lookupOrNull(value);
- if (!repl) {
+ SmallVector<Value, 1> repl = mapping.lookupOrNull(value);
+ if (repl.empty()) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
// a source materialization producing a replacement value "out of thin air"
@@ -1478,34 +1608,12 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
}
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
- /*originalType=*/Type(), converter);
- mapping.map(value, castValue);
+ /*valueToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/{value.getType()},
+ /*originalType=*/Type(), converter)[0];
+ // No mapping.map() needed: the materialization is recorded under `repl`.
return castValue;
}
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
- // Unpack unrealized_conversion_cast ops that were inserted as a N:1
- // workaround.
- auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!castOp)
- return {value};
- if (!nTo1TempMaterializations.contains(castOp))
- return {value};
- assert(castOp->getNumResults() == 1 && "expected single result");
-
- SmallVector<Value> result;
- for (Value v : castOp.getOperands()) {
- // Keep unpacking if possible. This is needed because during block
- // signature conversions and 1:N op replacements, the driver may have
- // inserted two materializations back-to-back: first an argument
- // materialization, then a target materialization.
- llvm::append_range(result, unpackNTo1Materialization(v));
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1552,11 +1660,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
}
// Materialize a replacement value "out of thin air".
- buildUnresolvedMaterialization(
+ Value sourceMat = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
- result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+ result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
+ mapping.map(result, sourceMat);
continue;
} else {
// Make sure that the user does not mess with unresolved materializations
@@ -1572,16 +1681,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
-
- if (repl.size() == 1) {
- // Single replacement value: replace directly.
- mapping.map(result, repl.front());
- } else {
- // Multiple replacement values: insert N:1 materialization.
- insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
- /*replacements=*/repl, /*outputValue=*/result,
- currentTypeConverter);
- }
+ mapping.map(result, repl);
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1760,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
SmallVector<ValueRange> newVals;
- for (size_t i = 0; i < newValues.size(); ++i)
- newVals.push_back(newValues.slice(i, 1));
+ for (size_t i = 0; i < newValues.size(); ++i) {
+ if (newValues[i]) {
+ newVals.push_back(newValues.slice(i, 1));
+ } else {
+ newVals.push_back(ValueRange());
+ }
+ }
impl->notifyOpReplaced(op, newVals);
}
@@ -1729,11 +1834,14 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+ SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from);
+ assert(mapped.size() == 1 &&
+ "replaceUsesOfBlockArgument is not supported for 1:N replacements");
+ impl->mapping.map(mapped.front(), to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<SmallVector<Value>> remappedValues;
+ SmallVector<SmallVector<Value, 1>> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
@@ -1746,7 +1854,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- SmallVector<SmallVector<Value>> remapped;
+ SmallVector<SmallVector<Value, 1>> remapped;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
remapped)))
return failure();
@@ -1872,7 +1980,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<SmallVector<Value>> remapped;
+ SmallVector<SmallVector<Value, 1>> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
op->getOperands(), remapped))) {
return failure();
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
tupleType.getFlattenedTypes(types);
return success();
});
- typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+ typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 466ae7ff6f46f1..749df2cb9ea7cc 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1235,7 +1235,6 @@ struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TestTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
struct PDLLTypeConverter : public TypeConverter {
PDLLTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
More information about the Mlir-commits
mailing list