[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Context-aware type conversions (PR #140434)
Matthias Springer
llvmlistbot at llvm.org
Sat May 17 22:12:25 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/140434
This commit adds support for context-aware type conversions: type conversion rules that can return different types depending on the IR.
There is no change for existing (context-unaware) type conversion rules:
```
// Example: Convert any integer type to f32.
converter.addConversion([](IntegerType t) {
  return Float32Type::get(t.getContext());
});
```
There is now an additional overload to register context-aware type conversion rules:
```
// Example: Type conversion rule for integers, depending on the context:
// Get the defining op of `v`, read its "increment" attribute and return an
// integer with a bitwidth that is increased by "increment".
converter.addConversion([](Value v) -> std::optional<Type> {
  auto intType = dyn_cast<IntegerType>(v.getType());
  if (!intType)
    return std::nullopt;
  Operation *op = v.getDefiningOp();
  if (!op)
    return std::nullopt;
  auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
  if (!incrementAttr)
    return std::nullopt;
  return IntegerType::get(v.getContext(),
                          intType.getWidth() + incrementAttr.getInt());
});
```
For performance reasons, the type converter caches the result of type conversions. This is no longer possible when there are context-aware type conversions, because the same input type can be converted to different types depending on the value (and the surrounding IR) it belongs to. There is no performance degradation when there are only context-unaware type conversions.
Note: This commit just adds context-aware type conversions to the dialect conversion framework. There are many existing patterns that still call `converter.convertType(someValue.getType())`. These should be gradually updated in subsequent commits to call `converter.convertType(someValue)`.
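To make the caching argument concrete, the following is a minimal C++17 model of the mechanism described above. It is a sketch under stated assumptions, not MLIR code: `ToyType` (a bare bitwidth), `ToyValue` (a type plus an optional "increment" attribute), `ToyTypeConverter`, `addValueConversion` and `makeIncrementConverter` are all hypothetical names. It assumes the design shown in the patch: all rules live in one list of callbacks that take a `std::variant` of type-or-value, the most recently added rule is tried first, context-aware rules decline when handed a bare type, and value queries bypass the type-keyed cache once any context-aware rule is registered.

```cpp
// Hypothetical toy model of context-aware type conversion; none of these
// names are part of MLIR. A "type" is just an integer bitwidth.
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <variant>
#include <vector>

using ToyType = unsigned;

// Stand-in for an SSA value: its type plus the (assumed) "increment"
// attribute of its defining op, if there is one.
struct ToyValue {
  ToyType type;
  std::optional<unsigned> increment;
};

class ToyTypeConverter {
public:
  // All rules share one form, like the patch's std::variant<Type, Value>
  // callback signature. std::nullopt means "not handled, try the next rule".
  using Input = std::variant<ToyType, ToyValue>;
  using Rule = std::function<std::optional<ToyType>(Input)>;

  // Context-unaware rule: sees only a type, whichever alternative is passed.
  void addConversion(std::function<std::optional<ToyType>(ToyType)> fn) {
    rules.push_back([fn](Input in) {
      if (const ToyValue *v = std::get_if<ToyValue>(&in))
        return fn(v->type);
      return fn(std::get<ToyType>(in));
    });
  }

  // Context-aware rule: declines when given only a type. Registering one
  // means equal types may convert differently, so value queries skip the cache.
  void addValueConversion(
      std::function<std::optional<ToyType>(const ToyValue &)> fn) {
    hasContextAwareRules = true;
    rules.push_back([fn](Input in) -> std::optional<ToyType> {
      if (const ToyValue *v = std::get_if<ToyValue>(&in))
        return fn(*v);
      return std::nullopt;
    });
  }

  // Type overload: only context-unaware rules can fire, so the result depends
  // on the type alone and successful conversions are cached.
  std::optional<ToyType> convertType(ToyType t) {
    if (auto it = cache.find(t); it != cache.end())
      return it->second;
    std::optional<ToyType> result = apply(t);
    if (result)
      cache[t] = *result;
    return result;
  }

  // Value overload: reuse the cached type path when no context-aware rule
  // exists; otherwise run the rules on every query.
  std::optional<ToyType> convertType(const ToyValue &v) {
    if (!hasContextAwareRules)
      return convertType(v.type);
    return apply(v);
  }

private:
  // Try the most recently registered rule first.
  std::optional<ToyType> apply(const Input &in) const {
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if (std::optional<ToyType> r = (*it)(in))
        return r;
    return std::nullopt;
  }

  std::vector<Rule> rules;
  std::map<ToyType, ToyType> cache;
  bool hasContextAwareRules = false;
};

// Hypothetical helper mirroring the example above: identity for every type,
// plus a context-aware rule that widens a value by its "increment".
inline ToyTypeConverter makeIncrementConverter() {
  ToyTypeConverter converter;
  converter.addConversion([](ToyType t) -> std::optional<ToyType> { return t; });
  converter.addValueConversion([](const ToyValue &v) -> std::optional<ToyType> {
    if (!v.increment)
      return std::nullopt;
    return v.type + *v.increment;
  });
  return converter;
}
```

With the converter from `makeIncrementConverter()`, two 37-bit values with increments 1 and 2 convert to 38 and 39 bits respectively, which is exactly why a cache keyed only on the type would return wrong results. Querying the bare type 37 still yields 37, because the context-aware rule never fires without a value; this models why patterns that call `convertType(someValue.getType())` miss context-aware rules.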
>From d5656952d555ba96360ceefa25ad8bf312d8050b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 7 May 2025 10:42:58 +0200
Subject: [PATCH] prototype
experiment
more
---
mlir/docs/DialectConversion.md | 47 +++++----
.../mlir/Transforms/DialectConversion.h | 95 ++++++++++++++++---
.../Transforms/StructuralTypeConversions.cpp | 4 +-
.../Transforms/Utils/DialectConversion.cpp | 45 ++++++++-
.../test-legalize-type-conversion.mlir | 18 ++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 29 +++++-
6 files changed, 200 insertions(+), 38 deletions(-)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..61872d10670dc 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.
+There are two kinds of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a `Type`.
+The latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
- /// Register a conversion function. A conversion function defines how a given
- /// source type should be converted. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
- /// * Optional<Type>(T)
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
+ ///
+ /// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
- /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
- /// converter is allowed to try another conversion function to perform
- /// the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+ /// the converter is allowed to try another conversion function to
+ /// perform the conversion.
+ /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
- /// `failure` or `std::nullopt` to signify a failed conversion. If the new
- /// set of types is empty, the type is removed and any usages of the
+ /// `failure` or `std::nullopt` to signify a failed conversion. If the
+ /// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take the surrounding IR into account when
+ /// converting the type of the given value. Context-unaware conversion
+ /// functions accept `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e7d05c3ce1adf..07adbde3a5a60 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
#include <type_traits>
+#include <variant>
namespace mlir {
@@ -139,7 +140,8 @@ class TypeConverter {
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take the surrounding IR into account when
+ /// converting the type of the given value. Context-unaware conversion
+ /// functions accept `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}
- /// Convert the given type. This function should return failure if no valid
+ /// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
+ ///
+ /// Note: This overload invokes only context-unaware type conversion
+ /// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
+ /// Convert the type of the given value. This function returns failure if no
+ /// valid conversion exists, success otherwise. If the new set of types is
+ /// empty, the type is removed and any usages of the existing value are
+ /// expected to be removed during conversion.
+ ///
+ /// Note: This overload invokes both context-aware and context-unaware type
+ /// conversion functions.
+ LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
+ Type convertType(Value v) const;
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
+ template <typename TargetType>
+ TargetType convertType(Value v) const {
+ return dyn_cast_or_null<TargetType>(convertType(v));
+ }
- /// Convert the given set of types, filling 'results' as necessary. This
- /// returns failure if the conversion of any of the types fails, success
+ /// Convert the given types, filling 'results' as necessary. This returns
+ /// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;
+ /// Convert the types of the given values, filling 'results' as necessary.
+ /// This returns "failure" if the conversion of any of the types fails,
+ /// "success" otherwise.
+ LogicalResult convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const;
+
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
- Type, SmallVectorImpl<Type> &)>;
+ std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a source conversion.
///
@@ -348,13 +381,14 @@ class TypeConverter {
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
- /// With callback of form: `std::optional<Type>(T)`
+ /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+ /// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
- wrapCallback(FnT &&callback) const {
+ wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results) {
- if (std::optional<Type> resultOpt = callback(type)) {
+ T typeOrValue, SmallVectorImpl<Type> &results) {
+ if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
});
}
/// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+ /// T, SmallVectorImpl<Type> &)`, where `T` is `Type` or derived from it.
template <typename T, typename FnT>
- std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- Type type,
+ std::variant<Type, Value> type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- T derivedType = dyn_cast<T>(type);
+ T derivedType;
+ if (Type *t = std::get_if<Type>(&type)) {
+ derivedType = dyn_cast<T>(*t);
+ } else if (Value *v = std::get_if<Value>(&type)) {
+ derivedType = dyn_cast<T>(v->getType());
+ } else {
+ llvm_unreachable("unexpected variant");
+ }
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
+ /// With callback of form: `std::optional<LogicalResult>(
+ /// T, SmallVectorImpl<Type> &)`, where `T` is a `Value`.
+ template <typename T, typename FnT>
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_same_v<T, Value>,
+ ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ hasContextAwareTypeConversions = true;
+ return [callback = std::forward<FnT>(callback)](
+ std::variant<Type, Value> type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ if (std::holds_alternative<Type>(type)) {
+ // Context-aware type conversion was called with a type.
+ return std::nullopt;
+ }
+ if (Value *v = std::get_if<Value>(&type))
+ return callback(*v, results);
+ llvm_unreachable("unexpected variant");
+ };
+ }
/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+ /// Whether the type converter has context-aware type conversions. I.e.,
+ /// conversion rules that depend on the SSA value instead of just the type.
+ /// Type conversion caching is deactivated when there are context-aware
+ /// conversions because the type converter may return different results for
+ /// the same input type.
+ bool hasContextAwareTypeConversions = false;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 09326242eec2a..de4612fa0846a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bd11bbe58a3f6..2a1d154faeaf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions) {
+ return convertType(v.getType(), results);
+ }
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..7b5e6e796a528 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func @context_aware_conversion()
+func.func @context_aware_conversion() {
+ // Case 1: Convert i37 --> i38.
+ // CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
+ // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
+ %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+
+ // Case 2: Convert i37 --> i39.
+ // CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
+ // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
+ %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
+ "test.replace_with_legal_op"(%1) : (i37) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..bd85e6fd9ae7f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
: ConversionPattern(converter, "test.replace_with_legal_op",
/*benefit=*/1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+ rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
return success();
}
};
@@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver
return nullptr;
});
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
- // Drop all integer types.
+ // Drop all other integer types.
return success();
});
converter.addConversion(
@@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver
results.push_back(result);
return success();
});
+ converter.addConversion([](Value v) -> std::optional<Type> {
+ auto intType = dyn_cast<IntegerType>(v.getType());
+ if (!intType || intType.getWidth() != 37)
+ return std::nullopt;
+ Operation *op = v.getDefiningOp();
+ if (!op)
+ return std::nullopt;
+ auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+ if (!incrementAttr)
+ return std::nullopt;
+ return IntegerType::get(v.getContext(),
+ intType.getWidth() + incrementAttr.getInt());
+ });
/// Add the legal set of type materializations.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver
// Otherwise, fail.
return nullptr;
});
+ // Materialize i37 to any desired type with unrealized_conversion_cast.
+ converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1 || !inputs[0].getType().isInteger(37))
+ return Value();
+ return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+ .getResult(0);
+ });
// Initialize the conversion target.
mlir::ConversionTarget target(getContext());
+ target.addLegalOp(OperationName("test.context_op", &getContext()));
target.addLegalOp<LegalOpD>();
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());