[Mlir-commits] [mlir] [mlir][Transforms] Make lookup without type converter unambiguous (PR #151747)
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 6 04:44:58 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/151747
>From 45f831dc4c9644f41913b0285a2368605745654d Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 3 Aug 2025 11:03:21 +0000
Subject: [PATCH] make lookup unambiguous
---
mlir/docs/DialectConversion.md | 67 +++++-
.../Transforms/Utils/DialectConversion.cpp | 203 +++++++++++-------
mlir/test/Transforms/test-legalizer.mlir | 17 ++
mlir/test/lib/Dialect/Test/TestOps.td | 4 +
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 49 ++++-
5 files changed, 255 insertions(+), 85 deletions(-)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..556e73c2d56c7 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -202,17 +202,62 @@ struct MyConversionPattern : public ConversionPattern {
#### Type Safety
-The types of the remapped operands provided to a conversion pattern must be of a
-type expected by the pattern. The expected types of a pattern are determined by
-a provided [TypeConverter](#type-converter). If no type converter is provided,
-the types of the remapped operands are expected to match the types of the
-original operands. If a type converter is provided, the types of the remapped
-operands are expected to be legal as determined by the converter. If the
-remapped operand types are not of an expected type, and a materialization to the
-expected type could not be performed, the pattern fails application before the
-`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
-explicitly ensure type safety, or sanitize the types of the incoming remapped
-operands. More information on type conversion is detailed in the
+The types of the remapped operands provided to a conversion pattern (through
+the adaptor or `ArrayRef` of operands) depend on type conversion rules.
+
+If the pattern was initialized with a [type converter](#type-converter), the
+conversion driver passes values whose types match the legalized types of the
+operands of the matched operation as per the type converter. To that end, the
+conversion driver may insert target materializations to convert the most
+recently mapped values to the expected legalized types. The driver tries to
+reuse existing materializations on a best-effort basis, but this is not
+guaranteed by the infrastructure. If the operand types of the matched op could
+not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
+is invoked.
+
+Example:
+```c++
+// Type converter that converts all FloatTypes to IntegerTypes.
+TypeConverter converter;
+converter.addConversion([](FloatType t) {
+ return IntegerType::get(t.getContext(), t.getWidth());
+});
+
+// Assuming that `MyConversionPattern` was initialized with `converter`.
+struct MyConversionPattern : public ConversionPattern {
+ virtual LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands, /* ... */) const {
+// ^^^^^^^^
+// If `op` has a FloatType operand, the respective value in `operands`
+// is guaranteed to have the legalized IntegerType. If another pattern
+// previously replaced the operand SSA value with an SSA value of the
+// legalized type (via "replaceOp" or "applySignatureConversion"), you
+// will get that SSA value directly (unless the replacement value was
+// also replaced). Otherwise, you will get a materialization to the
+// legalized type.
+```
+
+If the pattern was initialized without a type converter, the conversion driver
+passes the most recently mapped values to the pattern, excluding any
+materializations. If a value with the same type as the original operand is
+desired, users can directly take the respective operand from the matched
+operation.
+
+Example: When initializing the pattern from the above example without a type
+converter, `operands` contains the most recent replacement values, regardless
+of their types.
+
+Note: When running without a type converter, materializations are intentionally
+excluded from the lookup process because their presence may depend on other
+patterns. Passing materializations would make the conversion infrastructure
+fragile and unpredictable. Moreover, there could be multiple materializations
+to different types. (This can be the case when multiple patterns are running
+with different type converters.) In such a case, it would be unclear which
+materialization to pass.
+
+The above rules ensure that patterns do not have to explicitly ensure type
+safety, or sanitize the types of the incoming remapped operands. More
+information on type conversion is detailed in the
[dedicated section](#type-conversion) below.
## Type Conversion
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..dedc84f1adde9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
- /// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
- ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+ /// Lookup a value in the mapping.
+ ValueVector lookup(const ValueVector &from) const;
template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +176,31 @@ struct ConversionValueMapping {
};
} // namespace
-ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
- TypeRange desiredTypes) const {
- // Try to find the deepest values that have the desired types. If there is no
- // such mapping, simply return the deepest values.
- ValueVector desiredValue;
- ValueVector current{from};
- do {
- // Store the current value if the types match.
- if (TypeRange(ValueRange(current)) == desiredTypes)
- desiredValue = current;
-
- // If possible, Replace each value with (one or multiple) mapped values.
- ValueVector next;
- for (Value v : current) {
- auto it = mapping.find({v});
- if (it != mapping.end()) {
- llvm::append_range(next, it->second);
- } else {
- next.push_back(v);
- }
- }
- if (next != current) {
- // If at least one value was replaced, continue the lookup from there.
- current = std::move(next);
- continue;
- }
-
- // Otherwise: Check if there is a mapping for the entire vector. Such
- // mappings are materializations. (N:M mapping are not supported for value
- // replacements.)
- //
- // Note: From a correctness point of view, materializations do not have to
- // be stored (and looked up) in the mapping. But for performance reasons,
- // we choose to reuse existing IR (when possible) instead of creating it
- // multiple times.
- auto it = mapping.find(current);
- if (it == mapping.end()) {
- // No mapping found: The lookup stops here.
- break;
- }
- current = it->second;
- } while (true);
+/// Marker attribute for pure type conversions. I.e., mappings whose only
+/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
+/// the replacement values of a "replaceOp" call, etc., are not pure type
+/// conversions.)
+static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+
+/// A vector of values is a pure type conversion if all values are defined by
+/// the same operation and the operation has the `kPureTypeConversionMarker`
+/// attribute.
+static bool isPureTypeConversion(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values))
+ if (v.getDefiningOp() != op)
+ return false;
+ return op && op->hasAttr(kPureTypeConversionMarker);
+}
- // If the desired values were found use them, otherwise default to the leaf
- // values.
- // Note: If `desiredTypes` is empty, this function always returns `current`.
- return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
+ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
+ auto it = mapping.find(from);
+ if (it == mapping.end()) {
+ // No mapping found: The lookup stops here.
+ return {};
+ }
+ return it->second;
}
//===----------------------------------------------------------------------===//
@@ -930,7 +898,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// recently mapped values.
/// - If there is no mapping for the given values at all, return the given
/// value.
- ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+ ///
+ /// If `skipPureTypeConversions` is "true", materializations that are pure
+ /// type conversions are not considered.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+ bool skipPureTypeConversions = false) const;
/// Lookup the given value within the map, or return an empty vector if the
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +965,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
/// the results of the unresolved materialization in the conversion value
/// mapping.
+ ///
+ /// If `isPureTypeConversion` is "true", the materialization is created only
+ /// to resolve a type mismatch. That means it is not a regular value
+ /// replacement issued by the user. (Replacement values that are created
+ /// "out of thin air" appear like unresolved materializations because they are
+ /// unrealized_conversion_cast ops. However, they must be treated like
+ /// regular value replacements.)
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr);
+ UnrealizedConversionCastOp *castOp = nullptr,
+ bool isPureTypeConversion = true);
/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1244,77 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
//===----------------------------------------------------------------------===//
-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
- TypeRange desiredTypes) const {
- return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+ Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up each value in `values` individually and then
+ // composes the results. If that fails, it tries to look up the entire vector
+ // at once.
+ auto composedLookup = [&](const ValueVector &values) -> ValueVector {
+ // If possible, replace each value with (one or multiple) mapped values.
+ ValueVector next;
+ for (Value v : values) {
+ ValueVector r = mapping.lookup({v});
+ if (!r.empty()) {
+ llvm::append_range(next, r);
+ } else {
+ next.push_back(v);
+ }
+ }
+ if (next != values) {
+ // At least one value was replaced.
+ return next;
+ }
+
+ // Otherwise: Check if there is a mapping for the entire vector. Such
+ // mappings are materializations. (N:M mapping are not supported for value
+ // replacements.)
+ //
+ // Note: From a correctness point of view, materializations do not have to
+ // be stored (and looked up) in the mapping. But for performance reasons,
+ // we choose to reuse existing IR (when possible) instead of creating it
+ // multiple times.
+ ValueVector r = mapping.lookup(values);
+ if (r.empty()) {
+ // No mapping found: The lookup stops here.
+ return {};
+ }
+ return r;
+ };
+
+ // Try to find the deepest values that have the desired types. If there is no
+ // such mapping, simply return the deepest values.
+ ValueVector desiredValue;
+ ValueVector current{from};
+ ValueVector lastNonMaterialization{from};
+ do {
+ // Store the current value if the types match.
+ bool match = TypeRange(ValueRange(current)) == desiredTypes;
+ if (skipPureTypeConversions) {
+ // Skip pure type conversions, if requested.
+ bool pureConversion = isPureTypeConversion(current);
+ match &= !pureConversion;
+ // Keep track of the last mapped value that was not a pure type
+ // conversion.
+ if (!pureConversion)
+ lastNonMaterialization = current;
+ }
+ if (match)
+ desiredValue = current;
+
+ // Lookup next value in the mapping.
+ ValueVector next = composedLookup(current);
+ if (next.empty())
+ break;
+ current = std::move(next);
+ } while (true);
+
+ // If the desired values were found use them, otherwise default to the leaf
+ // values. (Skip pure type conversions, if requested.)
+ if (!desiredTypes.empty())
+ return desiredValue;
+ if (skipPureTypeConversions)
+ return lastNonMaterialization;
+ return current;
}
ValueVector
@@ -1324,10 +1371,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
if (!currentTypeConverter) {
- // The current pattern does not have a type converter. I.e., it does not
- // distinguish between legal and illegal types. For each operand, simply
- // pass through the most recently mapped values.
- remapped.push_back(lookupOrDefault(operand));
+ // The current pattern does not have a type converter. Pass the most
+ // recently mapped values, excluding materializations. Materializations
+ // are intentionally excluded because their presence may depend on other
+ // patterns. Including materializations would make the lookup fragile
+ // and unpredictable.
+ remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+ /*skipPureTypeConversions=*/true));
continue;
}
@@ -1356,7 +1406,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
// Create a materialization for the most recently mapped values.
- repl = lookupOrDefault(operand);
+ repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+ /*skipPureTypeConversions=*/true);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1533,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
- /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+ /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
+ /*castOp=*/nullptr, /*isPureTypeConversion=*/false)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
@@ -1523,7 +1575,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp) {
+ UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
@@ -1535,6 +1587,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+ if (isPureTypeConversion)
+ convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
if (!valuesToMap.empty())
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
@@ -1650,7 +1704,8 @@ void ConversionPatternRewriterImpl::replaceOp(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
- currentTypeConverter);
+ currentTypeConverter, /*castOp=*/nullptr,
+ /*isPureTypeConversion=*/false);
continue;
}
@@ -2902,6 +2957,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+ // Drop markers.
+ for (UnrealizedConversionCastOp castOp : remainingCastOps)
+ castOp->removeAttr(kPureTypeConversionMarker);
+
// Try to legalize all unresolved materializations.
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e4406e60ffead..5630d1540e4d5 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
"test.invalid"(%0) : (f16) -> ()
}
+
+// -----
+
+// CHECK-LABEL: func @test_lookup_without_converter
+// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
+// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
+// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
+// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
+func.func @test_lookup_without_converter() {
+ %0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
+ "test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
+ // Make sure that the second "replace_with_valid_consumer" lowering does not
+ // lookup the materialization that was created for the above op.
+ "test.replace_with_valid_consumer"(%0) : (i64) -> ()
+ // expected-remark@+1 {{op 'func.return' is not legalizable}}
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2eaad552a7a3a..843bd30a51aff 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
Arguments<(ins Variadic<AnyType>)>;
def TestTypeProducerOp : TEST_Op<"type_producer">,
Results<(outs AnyType)>;
+def TestValidProducerOp : TEST_Op<"valid_producer">,
+ Results<(outs AnyType)>;
+def TestValidConsumerOp : TEST_Op<"valid_consumer">,
+ Arguments<(ins AnyType)>;
def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
Results<(outs AnyType)>;
def TestTypeConsumerOp : TEST_Op<"type_consumer">,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eda618f5b09c6..7150401bdbdce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1198,6 +1198,47 @@ class TestEraseOp : public ConversionPattern {
}
};
+/// Pattern that replaces test.replace_with_valid_producer with
+/// test.valid_producer and the specified type.
+class TestReplaceWithValidProducer : public ConversionPattern {
+public:
+ TestReplaceWithValidProducer(MLIRContext *ctx)
+ : ConversionPattern("test.replace_with_valid_producer", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto attr = op->getAttrOfType<TypeAttr>("type");
+ if (!attr)
+ return failure();
+ rewriter.replaceOpWithNewOp<TestValidProducerOp>(op, attr.getValue());
+ return success();
+ }
+};
+
+/// Pattern that replaces test.replace_with_valid_consumer with
+/// test.valid_consumer. Can be used with and without a type converter.
+class TestReplaceWithValidConsumer : public ConversionPattern {
+public:
+ TestReplaceWithValidConsumer(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.replace_with_valid_consumer", 1,
+ ctx) {}
+ TestReplaceWithValidConsumer(MLIRContext *ctx)
+ : ConversionPattern("test.replace_with_valid_consumer", 1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // with_converter present: pattern must have been initialized with a type
+ // converter.
+ // with_converter absent: pattern must have been initialized without a type
+ // converter.
+ if (op->hasAttr("with_converter") != static_cast<bool>(getTypeConverter()))
+ return failure();
+ rewriter.replaceOpWithNewOp<TestValidConsumerOp>(op, operands[0]);
+ return success();
+ }
+};
+
/// This pattern matches a test.convert_block_args op. It either:
/// a) Duplicates all block arguments,
/// b) or: drops all block arguments and replaces each with 2x the first
@@ -1314,6 +1355,7 @@ struct TestTypeConverter : public TypeConverter {
TestTypeConverter() {
addConversion(convertType);
addSourceMaterialization(materializeCast);
+ addTargetMaterialization(materializeCast);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -1389,10 +1431,12 @@ struct TestLegalizePatternDriver
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
TestUndoPropertiesModification, TestEraseOp,
+ TestReplaceWithValidProducer, TestReplaceWithValidConsumer,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
- TestBlockArgReplace>(&getContext(), converter);
+ TestBlockArgReplace, TestReplaceWithValidConsumer>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1402,7 +1446,8 @@ struct TestLegalizePatternDriver
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
- TerminatorOp, OneRegionOp>();
+ TerminatorOp, OneRegionOp, TestValidProducerOp,
+ TestValidConsumerOp>();
target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
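
The lookup rule this patch introduces for patterns without a type converter can be sketched with a toy model. This is not code from the patch: `ToyMapping`, `buildExampleMapping`, and the string value names are hypothetical stand-ins, and the model only handles 1:1 mappings, whereas the real driver works on `ValueVector`s and identifies pure type conversions via a marker attribute on `unrealized_conversion_cast` ops. The sketch follows the mapping chain to its end and, when `skipPureTypeConversions` is set, returns the last value on the chain that is not a pure type conversion.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Toy stand-in for the conversion value mapping (hypothetical, 1:1 only).
struct ToyMapping {
  // value -> value it was most recently mapped to
  std::map<std::string, std::string> next;
  // values that model results of casts marked as pure type conversions
  std::set<std::string> pureTypeConversions;

  // Follow the chain starting at `from`. Without skipping, return the leaf.
  // With skipping, return the last value that is not a pure type conversion.
  std::string lookupOrDefault(const std::string &from,
                              bool skipPureTypeConversions) const {
    std::string current = from;
    std::string lastNonMaterialization = from;
    while (true) {
      if (!pureTypeConversions.count(current))
        lastNonMaterialization = current;
      auto it = next.find(current);
      if (it == next.end())
        break; // No mapping found: the lookup stops here.
      current = it->second;
    }
    return skipPureTypeConversions ? lastNonMaterialization : current;
  }
};

// Models the situation in @test_lookup_without_converter: %0 was replaced
// with %producer, and a pattern with a type converter later caused a target
// materialization %cast of %producer.
inline ToyMapping buildExampleMapping() {
  ToyMapping m;
  m.next["%0"] = "%producer";
  m.next["%producer"] = "%cast";
  m.pureTypeConversions.insert("%cast");
  return m;
}
```

In this model, `buildExampleMapping().lookupOrDefault("%0", true)` yields `%producer`, which corresponds to the second `test.valid_consumer` in the new test receiving the `i16` producer rather than the `f64` cast. With `false` it yields `%cast`, the leaf of the chain.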