[Mlir-commits] [mlir] b22f94d - [MLIR] Enable caching of type conversion in the presence of context-aware conversion (#158072)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 11 07:17:14 PDT 2025


Author: Mehdi Amini
Date: 2025-09-11T14:17:10Z
New Revision: b22f94dcc58e09710c188045b498a201db83d9a2

URL: https://github.com/llvm/llvm-project/commit/b22f94dcc58e09710c188045b498a201db83d9a2
DIFF: https://github.com/llvm/llvm-project/commit/b22f94dcc58e09710c188045b498a201db83d9a2.diff

LOG: [MLIR] Enable caching of type conversion in the presence of context-aware conversion (#158072)

The current implementation is overly conservative and disables all
possible caching as soon as a context-aware conversion is present.
However, a context-aware conversion only affects subsequent converters,
so we can still cache the results of the previous ones.

This isn't NFC because it fixed a bug where we used to unconditionally
cache when using the `convertType(Type t, ...` API, while now all APIs
are aware of context-aware conversions.

Added: 
    

Modified: 
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 7070351755e7a..5ae35155b4ed5 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a
 `Type`. A context-aware conversion function converts a `Value` into a type. The
 latter allows users to customize type conversion rules based on the IR.
 
-Note: When there is at least one context-aware type conversion function, the
-result of type conversions can no longer be cached, which can increase
-compilation time. Use this feature with caution!
+Note: context-aware type conversion functions impact the ability of the
+framework to cache the conversion result. In the absence of a context-aware
+conversion, all context-free type conversions can be cached. Otherwise only the
+context-free conversions added after a context-aware type conversion can be
+cached (conversions are applied in reverse order). 
+As such it is advised to add context-aware conversions as early as possible in
+the sequence of `addConversion` calls (so that they apply last).
 
 A `materialization` describes how a list of values should be converted to a
 list of values with specific types. An important distinction from a

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6949f4a14fdba..a096f82a4cfd8 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -433,7 +433,7 @@ class TypeConverter {
                        std::is_same_v<T, Value>,
                    ConversionCallbackFn>
   wrapCallback(FnT &&callback) {
-    hasContextAwareTypeConversions = true;
+    contextAwareTypeConversionsIndex = conversions.size();
     return [callback = std::forward<FnT>(callback)](
                PointerUnion<Type, Value> typeOrValue,
                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -555,6 +555,10 @@ class TypeConverter {
     cachedMultiConversions.clear();
   }
 
+  /// Internal implementation of the type conversion.
+  LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
+                                SmallVectorImpl<Type> &results) const;
+
   /// The set of registered conversion functions.
   SmallVector<ConversionCallbackFn, 4> conversions;
 
@@ -575,10 +579,13 @@ class TypeConverter {
   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
   /// Whether the type converter has context-aware type conversions. I.e.,
   /// conversion rules that depend on the SSA value instead of just the type.
-  /// Type conversion caching is deactivated when there are context-aware
-  /// conversions because the type converter may return different results for
-  /// the same input type.
-  bool hasContextAwareTypeConversions = false;
+  /// We store here the index in the `conversions` vector of the last added
+  /// context-aware conversion, if any. This is useful because we can't cache
+  /// the result of type conversion happening after context-aware conversions,
+  /// because the type converter may return different results for the same input
+  /// type. This is why it is recommened to add context-aware conversions first,
+  /// any context-free conversions after will benefit from caching.
+  int contextAwareTypeConversionsIndex = -1;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..df9700f11200f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput(
       SmallVector<Value, 1>(replacements.begin(), replacements.end())};
 }
 
-LogicalResult TypeConverter::convertType(Type t,
-                                         SmallVectorImpl<Type> &results) const {
-  assert(t && "expected non-null type");
-
+/// Internal implementation of the type conversion.
+/// This is used with either a Type or a Value as the first argument.
+/// - we can cache the context-free conversions until the last registered
+/// context-aware conversion.
+/// - we can't cache the result of type conversion happening after context-aware
+/// conversions, because the type converter may return different results for the
+/// same input type.
+LogicalResult
+TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
+                               SmallVectorImpl<Type> &results) const {
+  assert(typeOrValue && "expected non-null type");
+  Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
+                                     : cast<Type>(typeOrValue);
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);
@@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t,
   // registered first.
   size_t currentCount = results.size();
 
+  // We can cache the context-free conversions until the last registered
+  // context-aware conversion. But only if we're processing a Value right now.
+  auto isCacheable = [&](int index) {
+    int numberOfConversionsUntilContextAware =
+        conversions.size() - 1 - contextAwareTypeConversionsIndex;
+    return index < numberOfConversionsUntilContextAware;
+  };
+
   std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
                                                         std::defer_lock);
 
-  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
-    if (std::optional<LogicalResult> result = converter(t, results)) {
-      if (t.getContext()->isMultithreadingEnabled())
-        cacheWriteLock.lock();
-      if (!succeeded(*result)) {
-        assert(results.size() == currentCount &&
-               "failed type conversion should not change results");
-        cachedDirectConversions.try_emplace(t, nullptr);
-        return failure();
-      }
-      auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
-      if (newTypes.size() == 1)
-        cachedDirectConversions.try_emplace(t, newTypes.front());
-      else
-        cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+  for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
+    const ConversionCallbackFn &converter = indexedConverter.value();
+    std::optional<LogicalResult> result = converter(typeOrValue, results);
+    if (!result) {
+      assert(results.size() == currentCount &&
+             "failed type conversion should not change results");
+      continue;
+    }
+    if (!isCacheable(indexedConverter.index()))
       return success();
-    } else {
+    if (t.getContext()->isMultithreadingEnabled())
+      cacheWriteLock.lock();
+    if (!succeeded(*result)) {
       assert(results.size() == currentCount &&
              "failed type conversion should not change results");
+      cachedDirectConversions.try_emplace(t, nullptr);
+      return failure();
     }
+    auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+    if (newTypes.size() == 1)
+      cachedDirectConversions.try_emplace(t, newTypes.front());
+    else
+      cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+    return success();
   }
   return failure();
 }
 
-LogicalResult TypeConverter::convertType(Value v,
+LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
-  assert(v && "expected non-null value");
-
-  // If this type converter does not have context-aware type conversions, call
-  // the type-based overload, which has caching.
-  if (!hasContextAwareTypeConversions)
-    return convertType(v.getType(), results);
+  return convertTypeImpl(t, results);
+}
 
-  // Walk the added converters in reverse order to apply the most recently
-  // registered first.
-  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
-    if (std::optional<LogicalResult> result = converter(v, results)) {
-      if (!succeeded(*result))
-        return failure();
-      return success();
-    }
-  }
-  return failure();
+LogicalResult TypeConverter::convertType(Value v,
+                                         SmallVectorImpl<Type> &results) const {
+  return convertTypeImpl(v, results);
 }
 
 Type TypeConverter::convertType(Type t) const {


        


More information about the Mlir-commits mailing list