[Mlir-commits] [mlir] [MLIR] Enable caching of type conversion in the presence of context-aware conversion (PR #158072)
Mehdi Amini
llvmlistbot at llvm.org
Thu Sep 11 06:36:52 PDT 2025
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/158072
>From bfb414612f83633ea91e4f65b8edea80ecba0054 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 11 Sep 2025 06:23:47 -0700
Subject: [PATCH 1/2] [MLIR] Enable caching of type conversion in the presence
of context-aware conversion
The current implementation is overly conservative and disables all possible
caching as soon as a context-aware conversion is present. However, a
context-aware conversion only affects the converters tried after it, so we
can still cache the results of the converters tried before it.
---
.../mlir/Transforms/DialectConversion.h | 17 +++-
.../Transforms/Utils/DialectConversion.cpp | 95 +++++++++++--------
2 files changed, 69 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6949f4a14fdba..1b9f7a76fc579 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -433,7 +433,7 @@ class TypeConverter {
std::is_same_v<T, Value>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
- hasContextAwareTypeConversions = true;
+ contextAwareTypeConversionsIndex = conversions.size();
return [callback = std::forward<FnT>(callback)](
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -555,6 +555,10 @@ class TypeConverter {
cachedMultiConversions.clear();
}
+ /// Shared implementation of convertType for a Type or a Value input.
+ template <typename T>
+ LogicalResult convertTypeImpl(T t, SmallVectorImpl<Type> &results) const;
+
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
@@ -575,10 +579,13 @@ class TypeConverter {
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Whether the type converter has context-aware type conversions. I.e.,
/// conversion rules that depend on the SSA value instead of just the type.
- /// Type conversion caching is deactivated when there are context-aware
- /// conversions because the type converter may return different results for
- /// the same input type.
- bool hasContextAwareTypeConversions = false;
+ /// Index in the `conversions` vector of the last added context-aware
+ /// conversion, or -1 if there is none. Conversions are tried in reverse
+ /// order, so results of conversions tried at or after this index (i.e.,
+ /// registered no later than it) can't be cached: the type converter may
+ /// return different results for the same input type. Register context-aware
+ /// conversions first so that later context-free ones benefit from caching.
+ int contextAwareTypeConversionsIndex = -1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..8f36a653e3a17 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3406,22 +3406,37 @@ void TypeConverter::SignatureConversion::remapInput(
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
}
-LogicalResult TypeConverter::convertType(Type t,
- SmallVectorImpl<Type> &results) const {
+/// Internal implementation of the type conversion.
+/// This is used with either a Type or a Value as the first argument.
+/// When using a value, the caching behavior is different:
+/// - results of context-free conversions tried before the last registered
+///   context-aware conversion can be cached;
+/// - results of conversions tried from that context-aware conversion onwards
+///   can't be cached, because the type converter may return different results
+///   for the same input type.
+/// An explicit failure from any conversion is returned to the caller as is.
+template <typename T>
+LogicalResult
+TypeConverter::convertTypeImpl(T t, SmallVectorImpl<Type> &results) const {
assert(t && "expected non-null type");
-
+ auto getType = [&](auto typeOrValue) {
+ if constexpr (std::is_same_v<decltype(typeOrValue), Type>)
+ return typeOrValue;
+ else
+ return typeOrValue.getType();
+ };
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
if (t.getContext()->isMultithreadingEnabled())
cacheReadLock.lock();
- auto existingIt = cachedDirectConversions.find(t);
+ auto existingIt = cachedDirectConversions.find(getType(t));
if (existingIt != cachedDirectConversions.end()) {
if (existingIt->second)
results.push_back(existingIt->second);
return success(existingIt->second != nullptr);
}
- auto multiIt = cachedMultiConversions.find(t);
+ auto multiIt = cachedMultiConversions.find(getType(t));
if (multiIt != cachedMultiConversions.end()) {
results.append(multiIt->second.begin(), multiIt->second.end());
return success();
@@ -3431,52 +3446,56 @@ LogicalResult TypeConverter::convertType(Type t,
// registered first.
size_t currentCount = results.size();
+ // We can cache the context-free conversions until the last registered
+ // context-aware conversion. But only if we're processing a Value right now.
+ auto isCacheable = [&](int index) {
+ if constexpr (std::is_same_v<T, Type>)
+ return true;
+ int numberOfConversionsUntilContextAware =
+ conversions.size() - 1 - contextAwareTypeConversionsIndex;
+ return index < numberOfConversionsUntilContextAware;
+ };
+
std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
std::defer_lock);
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(t, results)) {
- if (t.getContext()->isMultithreadingEnabled())
- cacheWriteLock.lock();
- if (!succeeded(*result)) {
- assert(results.size() == currentCount &&
- "failed type conversion should not change results");
- cachedDirectConversions.try_emplace(t, nullptr);
- return failure();
- }
- auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
- if (newTypes.size() == 1)
- cachedDirectConversions.try_emplace(t, newTypes.front());
- else
- cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+ for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
+ const ConversionCallbackFn &converter = indexedConverter.value();
+ std::optional<LogicalResult> result = converter(t, results);
+ if (!result) {
+ assert(results.size() == currentCount &&
+ "failed type conversion should not change results");
+ continue;
+ }
+ if (!isCacheable(indexedConverter.index()))
- return success();
+ return *result;
- } else {
+ if (t.getContext()->isMultithreadingEnabled())
+ cacheWriteLock.lock();
+ if (!succeeded(*result)) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
+ cachedDirectConversions.try_emplace(getType(t), nullptr);
+ return failure();
}
+ auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+ if (newTypes.size() == 1)
+ cachedDirectConversions.try_emplace(getType(t), newTypes.front());
+ else
+ cachedMultiConversions.try_emplace(getType(t),
+ llvm::to_vector<2>(newTypes));
+ return success();
}
return failure();
}
-LogicalResult TypeConverter::convertType(Value v,
+LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
- assert(v && "expected non-null value");
-
- // If this type converter does not have context-aware type conversions, call
- // the type-based overload, which has caching.
- if (!hasContextAwareTypeConversions)
- return convertType(v.getType(), results);
+ return convertTypeImpl(t, results);
+}
- // Walk the added converters in reverse order to apply the most recently
- // registered first.
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(v, results)) {
- if (!succeeded(*result))
- return failure();
- return success();
- }
- }
- return failure();
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ return convertTypeImpl(v, results);
}
Type TypeConverter::convertType(Type t) const {
>From 7e09fe858bc8db2b35e28b6bba60db04401edff1 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 11 Sep 2025 06:36:14 -0700
Subject: [PATCH 2/2] Make `convertType(Type t,...` caching context-aware
---
mlir/docs/DialectConversion.md | 9 ++++++---
mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 --
2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 7070351755e7a..60038c12609c1 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -285,9 +285,12 @@ conversions. A context-unaware conversion function converts a `Type` into a
`Type`. A context-aware conversion function converts a `Value` into a type. The
latter allows users to customize type conversion rules based on the IR.
-Note: When there is at least one context-aware type conversion function, the
-result of type conversions can no longer be cached, which can increase
-compilation time. Use this feature with caution!
+Note: context-aware type conversion functions impact the ability of the
+framework to cache conversion results: only the results of context-free type
+conversions added after the last context-aware type conversion can be cached
+(conversions are applied in reverse order of registration).
+As such, it is advised to add context-aware conversions as early as possible
+in the chain of conversions (so that they apply last).
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8f36a653e3a17..a2835c568af2e 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3449,8 +3449,6 @@ TypeConverter::convertTypeImpl(T t, SmallVectorImpl<Type> &results) const {
- // We can cache the context-free conversions until the last registered
- // context-aware conversion. But only if we're processing a Value right now.
+ // We can cache the context-free conversions until the last registered
+ // context-aware conversion, whether we're processing a Type or a Value.
auto isCacheable = [&](int index) {
- if constexpr (std::is_same_v<T, Type>)
- return true;
int numberOfConversionsUntilContextAware =
conversions.size() - 1 - contextAwareTypeConversionsIndex;
return index < numberOfConversionsUntilContextAware;
More information about the Mlir-commits
mailing list