[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Context-aware type conversions (PR #140434)

Matthias Springer llvmlistbot at llvm.org
Sat May 17 22:12:25 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/140434

This commit adds support for context-aware type conversions: type conversion rules that can return different types depending on the IR.

There is no change for existing (context-unaware) type conversion rules:
```
// Example: Convert any integer type to f32.
converter.addConversion([](IntegerType t) {
  return Float32Type::get(t.getContext());
});
```

There is now an additional overload to register context-aware type conversion rules:
```
// Example: Type conversion rule for integers, depending on the context:
// Get the defining op of `v`, read its "increment" attribute and return an
// integer with a bitwidth that is increased by "increment".
converter.addConversion([](Value v) -> std::optional<Type> {
  auto intType = dyn_cast<IntegerType>(v.getType());
  if (!intType)
    return std::nullopt;
  Operation *op = v.getDefiningOp();
  if (!op)
    return std::nullopt;
  auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
  if (!incrementAttr)
    return std::nullopt;
  return IntegerType::get(v.getContext(),
                          intType.getWidth() + incrementAttr.getInt());
});
```

For performance reasons, the type converter caches the result of type conversions. This is no longer possible when there are context-aware type conversions, because each conversion could compute a different type depending on the context. There is no performance degradation when there are only context-unaware type conversions.

Note: This commit just adds context-aware type conversions to the dialect conversion framework. There are many existing patterns that still call `converter.convertType(someValue.getType())`. These should be gradually updated in subsequent commits to call `converter.convertType(someValue)`.
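
To illustrate why caching is skipped and why patterns should pass the value rather than its type, here is a minimal, self-contained sketch. It is a toy model, not the MLIR implementation: `ToyType`, `ToyValue`, `ToyConverter` and `makeExampleConverter` are hypothetical names, a type is reduced to a bitwidth, and a value carries the "increment" attribute of its defining op directly. The `std::variant` dispatch and the most-recent-rule-first order are only loosely modelled on this patch.

```cpp
#include <functional>
#include <map>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

// Toy stand-ins (hypothetical, not MLIR classes): a type is only a bitwidth,
// and a value carries its type plus the optional "increment" attribute of its
// defining op.
struct ToyType {
  unsigned width;
};
struct ToyValue {
  ToyType type;
  std::optional<unsigned> increment;
};

class ToyConverter {
public:
  using Input = std::variant<ToyType, ToyValue>;
  using Rule = std::function<std::optional<ToyType>(const Input &)>;

  // Context-unaware rule: only ever sees the type, even when given a value.
  void addTypeRule(std::function<std::optional<ToyType>(ToyType)> fn) {
    rules.push_back([fn](const Input &in) -> std::optional<ToyType> {
      if (const ToyType *t = std::get_if<ToyType>(&in))
        return fn(*t);
      return fn(std::get<ToyValue>(in).type);
    });
  }

  // Context-aware rule: declines when only a type is available, and switches
  // off caching for value-based queries.
  void addValueRule(std::function<std::optional<ToyType>(const ToyValue &)> fn) {
    hasContextAwareRules = true;
    rules.push_back([fn](const Input &in) -> std::optional<ToyType> {
      if (const ToyValue *v = std::get_if<ToyValue>(&in))
        return fn(*v);
      return std::nullopt;
    });
  }

  // Type-based query: the result depends only on the type, so it is cached.
  std::optional<ToyType> convert(ToyType t) {
    auto it = cache.find(t.width);
    if (it != cache.end())
      return it->second;
    std::optional<ToyType> result = run(Input(std::in_place_type<ToyType>, t));
    cache.emplace(t.width, result);
    return result;
  }

  // Value-based query: bypasses the cache once a context-aware rule exists.
  std::optional<ToyType> convert(const ToyValue &v) {
    if (!hasContextAwareRules)
      return convert(v.type);
    return run(Input(std::in_place_type<ToyValue>, v));
  }

private:
  // Try the most recently added rule first.
  std::optional<ToyType> run(const Input &in) const {
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if (std::optional<ToyType> result = (*it)(in))
        return result;
    return std::nullopt;
  }

  std::vector<Rule> rules;
  std::map<unsigned, std::optional<ToyType>> cache;
  bool hasContextAwareRules = false;
};

// Builds a converter with one rule of each kind. Width 32 is a stand-in for
// the fixed target type of the context-unaware example.
inline ToyConverter makeExampleConverter() {
  ToyConverter converter;
  converter.addTypeRule(
      [](ToyType) -> std::optional<ToyType> { return ToyType{32}; });
  converter.addValueRule([](const ToyValue &v) -> std::optional<ToyType> {
    if (!v.increment)
      return std::nullopt;
    return ToyType{v.type.width + *v.increment};
  });
  return converter;
}
```

In this toy model, two values of width 37 with increments 1 and 2 convert to widths 38 and 39 through `convert(value)`, whereas `convert(value.type)` yields 32 for both because the context-aware rule never sees the value. A cache keyed on the type alone could not tell the two values apart, which is the reason value-based queries skip it.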


From d5656952d555ba96360ceefa25ad8bf312d8050b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 7 May 2025 10:42:58 +0200
Subject: [PATCH] prototype

experiment

more
---
 mlir/docs/DialectConversion.md                | 47 +++++----
 .../mlir/Transforms/DialectConversion.h       | 95 ++++++++++++++++---
 .../Transforms/StructuralTypeConversions.cpp  |  4 +-
 .../Transforms/Utils/DialectConversion.cpp    | 45 ++++++++-
 .../test-legalize-type-conversion.mlir        | 18 ++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 29 +++++-
 6 files changed, 200 insertions(+), 38 deletions(-)

diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index cf577eca5b9a6..61872d10670dc 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -235,6 +235,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
 type. Type conversions are specified via the `addConversion` method described
 below.
 
+There are two kinds of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
 A `materialization` describes how a list of values should be converted to a
 list of values with specific types. An important distinction from a
 `conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -287,29 +296,31 @@ Several of the available hooks are detailed below:
 ```c++
 class TypeConverter {
  public:
-  /// Register a conversion function. A conversion function defines how a given
-  /// source type should be converted. A conversion function must be convertible
-  /// to any of the following forms(where `T` is a class derived from `Type`:
-  ///   * Optional<Type>(T)
+  /// Register a conversion function. A conversion function must be convertible
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
+  ///
+  ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
-  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
-  ///       converter is allowed to try another conversion function to perform
-  ///       the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+  ///       the converter is allowed to try another conversion function to
+  ///       perform the conversion.
+  ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
   ///     - This form represents a 1-N type conversion. It should return
-  ///       `failure` or `std::nullopt` to signify a failed conversion. If the new
-  ///       set of types is empty, the type is removed and any usages of the
+  ///       `failure` or `std::nullopt` to signify a failed conversion. If the
+  ///       new set of types is empty, the type is removed and any usages of the
   ///       existing value are expected to be removed during conversion. If
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
-  ///     - This form represents a 1-N type conversion supporting recursive
-  ///       types. The first two arguments and the return value are the same as
-  ///       for the regular 1-N form. The third argument is contains is the
-  ///       "call stack" of the recursive conversion: it contains the list of
-  ///       types currently being converted, with the current type being the
-  ///       last one. If it is present more than once in the list, the
-  ///       conversion concerns a recursive type.
+  ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e7d05c3ce1adf..07adbde3a5a60 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/StringMap.h"
 #include <type_traits>
+#include <variant>
 
 namespace mlir {
 
@@ -139,7 +140,8 @@ class TypeConverter {
   };
 
   /// Register a conversion function. A conversion function must be convertible
-  /// to any of the following forms (where `T` is a class derived from `Type`):
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
   ///
   ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +156,14 @@ class TypeConverter {
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
   ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -241,15 +251,28 @@ class TypeConverter {
         wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
   }
 
-  /// Convert the given type. This function should return failure if no valid
+  /// Convert the given type. This function returns failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
   /// be removed during conversion.
+  ///
+  /// Note: This overload invokes only context-unaware type conversion
+  /// functions. Users should call the other overload if possible.
   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
 
+  /// Convert the type of the given value. This function returns failure if no
+  /// valid conversion exists, success otherwise. If the new set of types is
+  /// empty, the type is removed and any usages of the existing value are
+  /// expected to be removed during conversion.
+  ///
+  /// Note: This overload invokes both context-aware and context-unaware type
+  /// conversion functions.
+  LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
   Type convertType(Type t) const;
+  Type convertType(Value v) const;
 
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -258,13 +281,23 @@ class TypeConverter {
   TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
+  template <typename TargetType>
+  TargetType convertType(Value v) const {
+    return dyn_cast_or_null<TargetType>(convertType(v));
+  }
 
-  /// Convert the given set of types, filling 'results' as necessary. This
-  /// returns failure if the conversion of any of the types fails, success
+  /// Convert the given types, filling 'results' as necessary. This returns
+  /// "failure" if the conversion of any of the types fails, "success"
   /// otherwise.
   LogicalResult convertTypes(TypeRange types,
                              SmallVectorImpl<Type> &results) const;
 
+  /// Convert the types of the given values, filling 'results' as necessary.
+  /// This returns "failure" if the conversion of any of the types fails,
+  /// "success" otherwise.
+  LogicalResult convertTypes(ValueRange values,
+                             SmallVectorImpl<Type> &results) const;
+
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
   bool isLegal(Type type) const;
@@ -328,7 +361,7 @@ class TypeConverter {
   /// types is empty, the type is removed and any usages of the existing value
   /// are expected to be removed during conversion.
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
-      Type, SmallVectorImpl<Type> &)>;
+      std::variant<Type, Value>, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a source conversion.
   ///
@@ -348,13 +381,14 @@ class TypeConverter {
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
-  /// With callback of form: `std::optional<Type>(T)`
+  /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+  /// `Value` or a `Type` (or a class derived from `Type`).
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
-  wrapCallback(FnT &&callback) const {
+  wrapCallback(FnT &&callback) {
     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
-                               T type, SmallVectorImpl<Type> &results) {
-      if (std::optional<Type> resultOpt = callback(type)) {
+                               T typeOrValue, SmallVectorImpl<Type> &results) {
+      if (std::optional<Type> resultOpt = callback(typeOrValue)) {
         bool wasSuccess = static_cast<bool>(*resultOpt);
         if (wasSuccess)
           results.push_back(*resultOpt);
@@ -364,20 +398,49 @@ class TypeConverter {
     });
   }
   /// With callback of form: `std::optional<LogicalResult>(
-  ///     T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+  ///     T, SmallVectorImpl<Type> &)`, where `T` is a type.
   template <typename T, typename FnT>
-  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_base_of_v<Type, T>,
                    ConversionCallbackFn>
   wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               Type type,
+               std::variant<Type, Value> type,
                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-      T derivedType = dyn_cast<T>(type);
+      T derivedType;
+      if (Type *t = std::get_if<Type>(&type)) {
+        derivedType = dyn_cast<T>(*t);
+      } else if (Value *v = std::get_if<Value>(&type)) {
+        derivedType = dyn_cast<T>(v->getType());
+      } else {
+        llvm_unreachable("unexpected variant");
+      }
       if (!derivedType)
         return std::nullopt;
       return callback(derivedType, results);
     };
   }
+  /// With callback of form: `std::optional<LogicalResult>(
+  ///     T, SmallVectorImpl<Type> &)`, where `T` is a `Value`.
+  template <typename T, typename FnT>
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_same_v<T, Value>,
+                   ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    hasContextAwareTypeConversions = true;
+    return [callback = std::forward<FnT>(callback)](
+               std::variant<Type, Value> type,
+               SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+      if (Type *t = std::get_if<Type>(&type)) {
+        // Context-aware type conversion was called with a type.
+        return std::nullopt;
+      } else if (Value *v = std::get_if<Value>(&type)) {
+        return callback(*v, results);
+      }
+      llvm_unreachable("unexpected variant");
+      return std::nullopt;
+    };
+  }
 
   /// Register a type conversion.
   void registerConversion(ConversionCallbackFn callback) {
@@ -504,6 +567,12 @@ class TypeConverter {
   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
   /// A mutex used for cache access
   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+  /// Whether the type converter has context-aware type conversions. I.e.,
+  /// conversion rules that depend on the SSA value instead of just the type.
+  /// Type conversion caching is deactivated when there are context-aware
+  /// conversions because the type converter may return different results for
+  /// the same input type.
+  bool hasContextAwareTypeConversions = false;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 09326242eec2a..de4612fa0846a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
-    for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, dstTypes)))
+    for (Value v : op.getResults()) {
+      if (failed(typeConverter->convertType(v, dstTypes)))
         return rewriter.notifyMatchFailure(op, "could not convert result type");
       offsets.push_back(dstTypes.size());
     }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bd11bbe58a3f6..2a1d154faeaf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1256,7 +1256,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 
     // If there is no legal conversion, fail to match this pattern.
     SmallVector<Type, 1> legalTypes;
-    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+    if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
       notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
         diag << "unable to convert type for " << valueDiagTag << " #"
              << it.index() << ", type was " << origType;
@@ -2899,6 +2899,28 @@ LogicalResult TypeConverter::convertType(Type t,
   return failure();
 }
 
+LogicalResult TypeConverter::convertType(Value v,
+                                         SmallVectorImpl<Type> &results) const {
+  assert(v && "expected non-null value");
+
+  // If this type converter does not have context-aware type conversions, call
+  // the type-based overload, which has caching.
+  if (!hasContextAwareTypeConversions) {
+    return convertType(v.getType(), results);
+  }
+
+  // Walk the added converters in reverse order to apply the most recently
+  // registered first.
+  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+    if (std::optional<LogicalResult> result = converter(v, results)) {
+      if (!succeeded(*result))
+        return failure();
+      return success();
+    }
+  }
+  return failure();
+}
+
 Type TypeConverter::convertType(Type t) const {
   // Use the multi-type result version to convert the type.
   SmallVector<Type, 1> results;
@@ -2909,6 +2931,16 @@ Type TypeConverter::convertType(Type t) const {
   return results.size() == 1 ? results.front() : nullptr;
 }
 
+Type TypeConverter::convertType(Value v) const {
+  // Use the multi-type result version to convert the type.
+  SmallVector<Type, 1> results;
+  if (failed(convertType(v, results)))
+    return nullptr;
+
+  // Check to ensure that only one type was produced.
+  return results.size() == 1 ? results.front() : nullptr;
+}
+
 LogicalResult
 TypeConverter::convertTypes(TypeRange types,
                             SmallVectorImpl<Type> &results) const {
@@ -2918,6 +2950,15 @@ TypeConverter::convertTypes(TypeRange types,
   return success();
 }
 
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+                            SmallVectorImpl<Type> &results) const {
+  for (Value value : values)
+    if (failed(convertType(value, results)))
+      return failure();
+  return success();
+}
+
 bool TypeConverter::isLegal(Type type) const {
   return convertType(type) == type;
 }
@@ -3128,7 +3169,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
   newOp.addOperands(operands);
 
   SmallVector<Type> newResultTypes;
-  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
     return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
 
   newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..7b5e6e796a528 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,21 @@ func.func @test_signature_conversion_no_converter() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @context_aware_conversion()
+func.func @context_aware_conversion() {
+  // Case 1: Convert i37 --> i38.
+  // CHECK: %[[cast0:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i38
+  // CHECK: "test.legal_op_d"(%[[cast0]]) : (i38) -> ()
+  %0 = "test.context_op"() {increment = 1 : i64} : () -> (i37)
+  "test.replace_with_legal_op"(%0) : (i37) -> ()
+
+  // Case 2: Convert i37 --> i39.
+  // CHECK: %[[cast1:.*]] = unrealized_conversion_cast %{{.*}} : i37 to i39
+  // CHECK: "test.legal_op_d"(%[[cast1]]) : (i39) -> ()
+  %1 = "test.context_op"() {increment = 2 : i64} : () -> (i37)
+  "test.replace_with_legal_op"(%1) : (i37) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d073843484d81..bd85e6fd9ae7f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1827,9 +1827,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
       : ConversionPattern(converter, "test.replace_with_legal_op",
                           /*benefit=*/1, ctx) {}
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
     return success();
   }
 };
@@ -1865,7 +1865,7 @@ struct TestTypeConversionDriver
       return nullptr;
     });
     converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
-      // Drop all integer types.
+      // Drop all other integer types.
       return success();
     });
     converter.addConversion(
@@ -1902,6 +1902,19 @@ struct TestTypeConversionDriver
           results.push_back(result);
           return success();
         });
+    converter.addConversion([](Value v) -> std::optional<Type> {
+      auto intType = dyn_cast<IntegerType>(v.getType());
+      if (!intType || intType.getWidth() != 37)
+        return std::nullopt;
+      Operation *op = v.getDefiningOp();
+      if (!op)
+        return std::nullopt;
+      auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+      if (!incrementAttr)
+        return std::nullopt;
+      return IntegerType::get(v.getContext(),
+                              intType.getWidth() + incrementAttr.getInt());
+    });
 
     /// Add the legal set of type materializations.
     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1922,9 +1935,19 @@ struct TestTypeConversionDriver
       // Otherwise, fail.
       return nullptr;
     });
+    // Materialize i37 to any desired type with unrealized_conversion_cast.
+    converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                          ValueRange inputs,
+                                          Location loc) -> Value {
+      if (inputs.size() != 1 || !inputs[0].getType().isInteger(37))
+        return Value();
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResult(0);
+    });
 
     // Initialize the conversion target.
     mlir::ConversionTarget target(getContext());
+    target.addLegalOp(OperationName("test.context_op", &getContext()));
     target.addLegalOp<LegalOpD>();
     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());


