[Mlir-commits] [mlir] [mlir][Transforms] Make lookup without type converter unambiguous (PR #151747)
Matthias Springer
llvmlistbot at llvm.org
Sat Aug 2 06:37:56 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/151747
>From d087c2f8fb250f94654447792a9d8c91da8063ed Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 1 Aug 2025 17:15:23 +0000
Subject: [PATCH] [mlir][Transforms] Make lookup without type converter
unambiguous
---
mlir/docs/DialectConversion.md | 34 ++--
.../Transforms/Utils/DialectConversion.cpp | 185 ++++++++++++------
mlir/test/Transforms/test-legalizer.mlir | 17 ++
mlir/test/lib/Dialect/Test/TestOps.td | 4 +
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 49 ++++-
5 files changed, 220 insertions(+), 69 deletions(-)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..89c8c78749957 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -202,17 +202,29 @@ struct MyConversionPattern : public ConversionPattern {
#### Type Safety
-The types of the remapped operands provided to a conversion pattern must be of a
-type expected by the pattern. The expected types of a pattern are determined by
-a provided [TypeConverter](#type-converter). If no type converter is provided,
-the types of the remapped operands are expected to match the types of the
-original operands. If a type converter is provided, the types of the remapped
-operands are expected to be legal as determined by the converter. If the
-remapped operand types are not of an expected type, and a materialization to the
-expected type could not be performed, the pattern fails application before the
-`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
-explicitly ensure type safety, or sanitize the types of the incoming remapped
-operands. More information on type conversion is detailed in the
+The types of the remapped operands provided to a conversion pattern (through
+the adaptor or `ArrayRef` of operands) depend on type conversion rules.
+
+If the pattern was initialized with a [type converter](#type-converter), the
+conversion driver passes values whose types match the legalized types of the
+operands of the matched operation as per the type converter. To that end, the
+conversion driver may insert target materializations to convert the most
+recently mapped values to the expected legalized types. The driver tries to
+reuse existing materializations on a best-effort basis, but this is not
+guaranteed by the infrastructure. If the operand types of the matched op could
+not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
+is invoked.
+
+If the pattern was initialized without a type converter, the conversion driver
+passes the most recently mapped values to the pattern, excluding any
+materializations. Materializations are intentionally excluded because their
+presence may depend on other patterns. If a value of the same type as an
+operand is desired, users can directly take the respective operand from the
+matched operation.
+
+The above rules ensure that patterns do not have to explicitly ensure type
+safety, or sanitize the types of the incoming remapped operands. More
+information on type conversion is detailed in the
[dedicated section](#type-conversion) below.
## Type Conversion
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..7c13c98b7b0d0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -131,7 +131,16 @@ struct ConversionValueMapping {
/// recently mapped values.
/// - If there is no mapping for the given values at all, return the given
/// value.
- ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+ ///
+ /// If `skipMaterializations` is true, materializations are not considered.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+ bool skipMaterializations = false) const;
+
+ /// Look up a value from the mapping (just once, not following the chain of
+ /// potential mappings). Look for actual replacements first, then for
+ /// materializations, which can be skipped. Returns {} if nothing is mapped.
+ ValueVector lookupSingleStep(const ValueVector &from,
+ bool skipMaterializations = false) const;
template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -139,8 +148,8 @@ struct ConversionValueMapping {
/// Map a value vector to the one provided.
template <typename OldVal, typename NewVal>
std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
- map(OldVal &&oldVal, NewVal &&newVal) {
- LLVM_DEBUG({
+ map(OldVal &&oldVal, NewVal &&newVal, bool isOnlyTypeConversion = false) {
+ auto checkCircularMapping = [&](auto &mapping) {
ValueVector next(newVal);
while (true) {
assert(next != oldVal && "inserting cyclic mapping");
@@ -149,45 +158,117 @@ struct ConversionValueMapping {
break;
next = it->second;
}
- });
+ };
+ (void)checkCircularMapping;
+
mappedTo.insert_range(newVal);
- mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+ if (isOnlyTypeConversion) {
+ // This is a materialization.
+ LLVM_DEBUG({ checkCircularMapping(materializations); });
+ materializations[std::forward<OldVal>(oldVal)] =
+ std::forward<NewVal>(newVal);
+ } else {
+ // This is a regular value replacement.
+ LLVM_DEBUG({ checkCircularMapping(mapping); });
+ mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+ }
}
/// Map a value vector or single value to the one provided.
template <typename OldVal, typename NewVal>
std::enable_if_t<!IsValueVector<OldVal>::value ||
!IsValueVector<NewVal>::value>
- map(OldVal &&oldVal, NewVal &&newVal) {
+ map(OldVal &&oldVal, NewVal &&newVal, bool isOnlyTypeConversion = false) {
if constexpr (IsValueVector<OldVal>{}) {
- map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+ map(std::forward<OldVal>(oldVal), ValueVector{newVal},
+ isOnlyTypeConversion);
} else if constexpr (IsValueVector<NewVal>{}) {
- map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+ map(ValueVector{oldVal}, std::forward<NewVal>(newVal),
+ isOnlyTypeConversion);
} else {
- map(ValueVector{oldVal}, ValueVector{newVal});
+ map(ValueVector{oldVal}, ValueVector{newVal}, isOnlyTypeConversion);
}
}
- void map(Value oldVal, SmallVector<Value> &&newVal) {
- map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
+ void map(Value oldVal, SmallVector<Value> &&newVal,
+ bool isOnlyTypeConversion = false) {
+ map(ValueVector{oldVal}, ValueVector(std::move(newVal)), isOnlyTypeConversion);
}
/// Drop the last mapping for the given values.
- void erase(const ValueVector &value) { mapping.erase(value); }
+ void erase(const ValueVector &value) {
+ mapping.erase(value);
+ materializations.erase(value);
+ }
private:
- /// Current value mappings.
+ /// Mapping of actual replacements.
DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
+ /// Mapping of materializations that are created only to resolve type
+ /// mismatches.
+ DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> materializations;
+
/// All SSA values that are mapped to. May contain false positives.
DenseSet<Value> mappedTo;
};
} // namespace
ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
- TypeRange desiredTypes) const {
+ConversionValueMapping::lookupSingleStep(const ValueVector &from,
+ bool skipMaterializations) const {
+ // Continue the lookup on each value separately. (Each value could have been
+ // mapped to one or multiple other values.)
+ ValueVector next;
+ for (Value v : from) {
+ // First check regular value replacements.
+ auto it = mapping.find({v});
+ if (it != mapping.end()) {
+ llvm::append_range(next, it->second);
+ continue;
+ }
+ if (skipMaterializations) {
+ next.push_back(v);
+ continue;
+ }
+ // Then check materializations.
+ it = materializations.find({v});
+ if (it != materializations.end()) {
+ llvm::append_range(next, it->second);
+ continue;
+ }
+ next.push_back(v);
+ }
+
+ if (next != from)
+ return next;
+
+ // Otherwise: Check if there is a mapping for the entire vector. Such
+ // mappings are materializations. (N:M mappings are not supported for value
+ // replacements.)
+ //
+ // Note: From a correctness point of view, materializations do not have to
+ // be stored (and looked up) in the mapping. But for performance reasons,
+ // we choose to reuse existing IR (when possible) instead of creating it
+ // multiple times.
+ //
+ // First check regular value replacements.
+ auto it = mapping.find(from);
+ if (it != mapping.end())
+ return it->second;
+ if (skipMaterializations)
+ return {};
+ // Then check materializations.
+ it = materializations.find(from);
+ if (it != materializations.end())
+ return it->second;
+ return {};
+}
+
+ValueVector
+ConversionValueMapping::lookupOrDefault(Value from, TypeRange desiredTypes,
+ bool skipMaterializations) const {
// Try to find the deepest values that have the desired types. If there is no
// such mapping, simply return the deepest values.
ValueVector desiredValue;
@@ -197,36 +278,13 @@ ConversionValueMapping::lookupOrDefault(Value from,
if (TypeRange(ValueRange(current)) == desiredTypes)
desiredValue = current;
- // If possible, Replace each value with (one or multiple) mapped values.
- ValueVector next;
- for (Value v : current) {
- auto it = mapping.find({v});
- if (it != mapping.end()) {
- llvm::append_range(next, it->second);
- } else {
- next.push_back(v);
- }
- }
- if (next != current) {
- // If at least one value was replaced, continue the lookup from there.
- current = std::move(next);
- continue;
- }
-
- // Otherwise: Check if there is a mapping for the entire vector. Such
- // mappings are materializations. (N:M mapping are not supported for value
- // replacements.)
- //
- // Note: From a correctness point of view, materializations do not have to
- // be stored (and looked up) in the mapping. But for performance reasons,
- // we choose to reuse existing IR (when possible) instead of creating it
- // multiple times.
- auto it = mapping.find(current);
- if (it == mapping.end()) {
+ ValueVector next = lookupSingleStep(current, skipMaterializations);
+ if (next.empty()) {
// No mapping found: The lookup stops here.
break;
}
- current = it->second;
+
+ current = std::move(next);
} while (true);
// If the desired values were found use them, otherwise default to the leaf
@@ -930,7 +988,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// recently mapped values.
/// - If there is no mapping for the given values at all, return the given
/// value.
- ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+ ///
+ /// If `skipMaterializations` is true, materializations are not considered.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+ bool skipMaterializations = false) const;
/// Lookup the given value within the map, or return an empty vector if the
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +1054,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
/// the results of the unresolved materialization in the conversion value
/// mapping.
+ ///
+ /// If `isOnlyTypeConversion` is "true", the materialization is created to
+ /// resolve a type mismatch, and not a regular value replacement issued by
+ /// the user. (Replacement values that are created "out of thin air" are
+ /// treated like unresolved materializations, but are not just type
+ /// conversions.)
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr);
+ UnrealizedConversionCastOp *castOp = nullptr,
+ bool isOnlyTypeConversion = true);
/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1332,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
//===----------------------------------------------------------------------===//
-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
- TypeRange desiredTypes) const {
- return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+ Value from, TypeRange desiredTypes, bool skipMaterializations) const {
+ return mapping.lookupOrDefault(from, desiredTypes, skipMaterializations);
}
ValueVector
@@ -1324,10 +1391,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
if (!currentTypeConverter) {
- // The current pattern does not have a type converter. I.e., it does not
- // distinguish between legal and illegal types. For each operand, simply
- // pass through the most recently mapped values.
- remapped.push_back(lookupOrDefault(operand));
+ // The current pattern does not have a type converter. Pass the most
+ // recently mapped values, excluding materializations. Materializations
+ // are intentionally excluded because their presence may depend on other
+ // patterns. Including materializations would make the lookup fragile
+ // and unpredictable.
+ remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+ /*skipMaterializations=*/true));
continue;
}
@@ -1356,7 +1426,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
// Create a materialization for the most recently mapped values.
- repl = lookupOrDefault(operand);
+ repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+ /*skipMaterializations=*/true);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1553,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
- /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+ /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
+ /*castOp=*/nullptr, /*isOnlyTypeConversion=*/false)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
@@ -1523,7 +1595,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp) {
+ UnrealizedConversionCastOp *castOp, bool isOnlyTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
@@ -1536,7 +1608,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
auto convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
+ mapping.map(valuesToMap, convertOp.getResults(), isOnlyTypeConversion);
if (castOp)
*castOp = convertOp;
unresolvedMaterializations[convertOp] =
@@ -1650,7 +1722,8 @@ void ConversionPatternRewriterImpl::replaceOp(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
- currentTypeConverter);
+ currentTypeConverter, /*castOp=*/nullptr,
+ /*isOnlyTypeConversion=*/false);
continue;
}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e4406e60ffead..5630d1540e4d5 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
"test.invalid"(%0) : (f16) -> ()
}
+
+// -----
+
+// CHECK-LABEL: func @test_lookup_without_converter
+// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
+// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
+// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+ %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+ "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+ // Make sure that the second "replace_with_valid_consumer" lowering does not
+ // look up the materialization that was created for the above op.
+ "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+ // expected-remark@+1 {{op 'func.return' is not legalizable}}
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..843bd30a51aff 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
Arguments<(ins Variadic<AnyType>)>;
def TestTypeProducerOp : TEST_Op<"type_producer">,
Results<(outs AnyType)>;
+def TestValidProducerOp : TEST_Op<"valid_producer">,
+ Results<(outs AnyType)>;
+def TestValidConsumerOp : TEST_Op<"valid_consumer">,
+ Arguments<(ins AnyType)>;
def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
Results<(outs AnyType)>;
def TestTypeConsumerOp : TEST_Op<"type_consumer">,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..7150401bdbdce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern {
}
};
+/// Pattern that replaces test.replace_with_valid_producer with a
+/// test.valid_producer whose result type is given by the `type` attribute.
+class TestReplaceWithValidProducer : public ConversionPattern {
+public:
+ TestReplaceWithValidProducer(MLIRContext *ctx)
+ : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto attr = op->getAttrOfType<TypeAttr>("type");
+ if (!attr)
+ return failure();
+ rewriter.replaceOpWithNewOp<TestValidProducerOp>(op, attr.getValue());
+ return success();
+ }
+};
+
+/// Pattern that replaces test.replace_with_valid_consumer with
+/// test.valid_consumer. Can be used with and without a type converter.
+class TestReplaceWithValidConsumer : public ConversionPattern {
+public:
+ TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.replace_with_valid_consumer", 1,
+ ctx) {}
+ TestReplaceWithValidConsumer(MLIRContext *ctx)
+ : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Apply only to ops with exactly one operand whose "with_converter"
+ // attribute matches this pattern's flavor: present iff the pattern was
+ // initialized with a type converter.
+ if (operands.size() != 1 ||
+ op->hasAttr("with_converter") != static_cast<bool>(getTypeConverter()))
+ return failure();
+ rewriter.replaceOpWithNewOp<TestValidConsumerOp>(op, operands[0]);
+ return success();
+ }
+};
+
/// This pattern matches a test.convert_block_args op. It either:
/// a) Duplicates all block arguments,
/// b) or: drops all block arguments and replaces each with 2x the first
@@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter {
TestTypeConverter() {
addConversion(convertType);
addSourceMaterialization(materializeCast);
+ addTargetMaterialization(materializeCast);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
+ TestReplaceWithValidProducer, TestReplaceWithValidConsumer,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
- TestBlockArgReplace>(&getContext(), converter);
+ TestBlockArgReplace, TestReplaceWithValidConsumer>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp, OneRegionOp>();
+ TerminatorOp, OneRegionOp, TestValidProducerOp,
+ TestValidConsumerOp>();
target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
More information about the Mlir-commits
mailing list