[Mlir-commits] [mlir] 337707a - [mlir][Transforms] Dialect conversion: Context-aware type conversions (#140434)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 27 00:13:55 PDT 2025


Author: Matthias Springer
Date: 2025-08-27T09:13:52+02:00
New Revision: 337707a5417dbdc8751c2a11eda920e250417b5a

URL: https://github.com/llvm/llvm-project/commit/337707a5417dbdc8751c2a11eda920e250417b5a
DIFF: https://github.com/llvm/llvm-project/commit/337707a5417dbdc8751c2a11eda920e250417b5a.diff

LOG: [mlir][Transforms] Dialect conversion: Context-aware type conversions (#140434)

This commit adds support for context-aware type conversions: type
conversion rules that can return different types depending on the IR.

There is no change for existing (context-unaware) type conversion rules:
```c++
// Example: Convert any integer type to f32.
converter.addConversion([](IntegerType t) {
  return Float32Type::get(t.getContext());
});
```

There is now an additional overload to register context-aware type
conversion rules:
```c++
// Example: Type conversion rule for integers, depending on the context:
// Get the defining op of `v`, read its "increment" attribute and return an
// integer with a bitwidth that is increased by "increment".
converter.addConversion([](Value v) -> std::optional<Type> {
  auto intType = dyn_cast<IntegerType>(v.getType());
  if (!intType)
    return std::nullopt;
  Operation *op = v.getDefiningOp();
  if (!op)
    return std::nullopt;
  auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
  if (!incrementAttr)
    return std::nullopt;
  return IntegerType::get(v.getContext(),
                          intType.getWidth() + incrementAttr.getInt());
});
```

For performance reasons, the type converter caches the result of type
conversions. This is no longer possible when there are context-aware type
conversions because each conversion could compute a different type
depending on the context. There is no performance degradation when there
are only context-unaware type conversions.

Note: This commit just adds context-aware type conversions to the
dialect conversion framework. There are many existing patterns that
still call `converter.convertType(someValue.getType())`. These should be
gradually updated in subsequent commits to call
`converter.convertType(someValue)`.
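
The caching behavior and the difference between the two `convertType`
overloads can be sketched with a small standalone model. This is only an
illustration under stated assumptions, not the MLIR implementation:
`std::string` stands in for `Type`, and the hypothetical `ToyValue`,
`ToyConverter`, `addTypeRule`, `addValueRule` and `makeDemoConverter` names
are invented for this sketch and are not MLIR APIs.

```c++
// Toy model only: std::string stands in for mlir::Type and ToyValue for
// mlir::Value. None of these names are MLIR APIs.
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

struct ToyValue {
  std::string type; // the value's type, e.g. "i37"
  int increment;    // stands in for IR context such as an attribute
};

class ToyConverter {
public:
  using Result = std::optional<std::string>;
  using TypeOrValue = std::variant<std::string, ToyValue>;

  // Context-unaware rule: only ever sees the type.
  void addTypeRule(std::function<Result(const std::string &)> fn) {
    assert(fn && "expected a callable rule");
    rules.push_back([fn = std::move(fn)](const TypeOrValue &in) -> Result {
      if (const auto *v = std::get_if<ToyValue>(&in))
        return fn(v->type);
      return fn(std::get<std::string>(in));
    });
  }

  // Context-aware rule: declines when only a type is available.
  void addValueRule(std::function<Result(const ToyValue &)> fn) {
    assert(fn && "expected a callable rule");
    hasContextAwareRules = true;
    rules.push_back([fn = std::move(fn)](const TypeOrValue &in) -> Result {
      if (const auto *v = std::get_if<ToyValue>(&in))
        return fn(*v);
      return std::nullopt;
    });
  }

  // Type-based query: the result depends only on the type, so it is cached.
  Result convertType(const std::string &type) {
    if (auto it = cache.find(type); it != cache.end()) {
      ++cacheHits;
      return it->second;
    }
    Result result = run(type);
    cache.emplace(type, result);
    return result;
  }

  // Value-based query: bypasses the cache once a context-aware rule exists.
  Result convertType(const ToyValue &value) {
    if (!hasContextAwareRules)
      return convertType(value.type);
    return run(value);
  }

  int cacheHits = 0;

private:
  // The most recently registered rule is tried first.
  Result run(const TypeOrValue &in) const {
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if (Result result = (*it)(in))
        return result;
    return std::nullopt;
  }

  std::vector<std::function<Result(const TypeOrValue &)>> rules;
  std::map<std::string, Result> cache;
  bool hasContextAwareRules = false;
};

// Registers one rule of each kind, loosely modeled on the examples above.
inline ToyConverter makeDemoConverter() {
  ToyConverter c;
  c.addTypeRule([](const std::string &t) -> ToyConverter::Result {
    if (t == "i32")
      return std::string("f32");
    return std::nullopt;
  });
  c.addValueRule([](const ToyValue &v) -> ToyConverter::Result {
    if (v.type != "i37")
      return std::nullopt;
    return "i" + std::to_string(37 + v.increment);
  });
  return c;
}
```

In this model, querying with `ToyValue{"i37", 1}` and `ToyValue{"i37", 2}`
yields different results for the same type, which is why the value-based path
cannot reuse a cache keyed by type. Querying with only the type `"i37"` skips
the context-aware rule entirely, which is the situation the note above
describes for patterns that still pass `someValue.getType()`.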

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Added: 
    mlir/test/Transforms/test-context-aware-type-converter.mlir

Modified: 
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalize-type-conversion.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 556e73c2d56c7..7070351755e7a 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -280,6 +280,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
 type. Type conversions are specified via the `addConversion` method described
 below.
 
+There are two kind of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
 A `materialization` describes how a list of values should be converted to a
 list of values with specific types. An important distinction from a
 `conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -332,29 +341,31 @@ Several of the available hooks are detailed below:
 ```c++
 class TypeConverter {
  public:
-  /// Register a conversion function. A conversion function defines how a given
-  /// source type should be converted. A conversion function must be convertible
-  /// to any of the following forms(where `T` is a class derived from `Type`:
-  ///   * Optional<Type>(T)
+  /// Register a conversion function. A conversion function must be convertible
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
+  ///
+  ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
-  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
-  ///       converter is allowed to try another conversion function to perform
-  ///       the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+  ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+  ///       the converter is allowed to try another conversion function to
+  ///       perform the conversion.
+  ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
   ///     - This form represents a 1-N type conversion. It should return
-  ///       `failure` or `std::nullopt` to signify a failed conversion. If the new
-  ///       set of types is empty, the type is removed and any usages of the
+  ///       `failure` or `std::nullopt` to signify a failed conversion. If the
+  ///       new set of types is empty, the type is removed and any usages of the
   ///       existing value are expected to be removed during conversion. If
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
-  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
-  ///     - This form represents a 1-N type conversion supporting recursive
-  ///       types. The first two arguments and the return value are the same as
-  ///       for the regular 1-N form. The third argument is contains is the
-  ///       "call stack" of the recursive conversion: it contains the list of
-  ///       types currently being converted, with the current type being the
-  ///       last one. If it is present more than once in the list, the
-  ///       conversion concerns a recursive type.
+  ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT,

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f23a70601fc0a..14dfbf18836c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -139,7 +139,8 @@ class TypeConverter {
   };
 
   /// Register a conversion function. A conversion function must be convertible
-  /// to any of the following forms (where `T` is a class derived from `Type`):
+  /// to any of the following forms (where `T` is `Value` or a class derived
+  /// from `Type`, including `Type` itself):
   ///
   ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +155,14 @@ class TypeConverter {
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
   ///
+  /// Conversion functions that accept `Value` as the first argument are
+  /// context-aware. I.e., they can take into account IR when converting the
+  /// type of the given value. Context-unaware conversion functions accept
+  /// `Type` or a derived class as the first argument.
+  ///
+  /// Note: Context-unaware conversions are cached, but context-aware
+  /// conversions are not.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -242,15 +251,28 @@ class TypeConverter {
         wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
   }
 
-  /// Convert the given type. This function should return failure if no valid
+  /// Convert the given type. This function returns failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
   /// be removed during conversion.
+  ///
+  /// Note: This overload invokes only context-unaware type conversion
+  /// functions. Users should call the other overload if possible.
   LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
 
+  /// Convert the type of the given value. This function returns failure if no
+  /// valid conversion exists, success otherwise. If the new set of types is
+  /// empty, the type is removed and any usages of the existing value are
+  /// expected to be removed during conversion.
+  ///
+  /// Note: This overload invokes both context-aware and context-unaware type
+  /// conversion functions.
+  LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
   Type convertType(Type t) const;
+  Type convertType(Value v) const;
 
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -259,25 +281,36 @@ class TypeConverter {
   TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
+  template <typename TargetType>
+  TargetType convertType(Value v) const {
+    return dyn_cast_or_null<TargetType>(convertType(v));
+  }
 
-  /// Convert the given set of types, filling 'results' as necessary. This
-  /// returns failure if the conversion of any of the types fails, success
+  /// Convert the given types, filling 'results' as necessary. This returns
+  /// "failure" if the conversion of any of the types fails, "success"
   /// otherwise.
   LogicalResult convertTypes(TypeRange types,
                              SmallVectorImpl<Type> &results) const;
 
+  /// Convert the types of the given values, filling 'results' as necessary.
+  /// This returns "failure" if the conversion of any of the types fails,
+  /// "success" otherwise.
+  LogicalResult convertTypes(ValueRange values,
+                             SmallVectorImpl<Type> &results) const;
+
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
   bool isLegal(Type type) const;
+  bool isLegal(Value value) const;
 
   /// Return true if all of the given types are legal for this type converter.
-  template <typename RangeT>
-  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
-                       !std::is_convertible<RangeT, Operation *>::value,
-                   bool>
-  isLegal(RangeT &&range) const {
+  bool isLegal(TypeRange range) const {
     return llvm::all_of(range, [this](Type type) { return isLegal(type); });
   }
+  bool isLegal(ValueRange range) const {
+    return llvm::all_of(range, [this](Value value) { return isLegal(value); });
+  }
+
   /// Return true if the given operation has legal operand and result types.
   bool isLegal(Operation *op) const;
 
@@ -296,6 +329,11 @@ class TypeConverter {
   LogicalResult convertSignatureArgs(TypeRange types,
                                      SignatureConversion &result,
                                      unsigned origInputOffset = 0) const;
+  LogicalResult convertSignatureArg(unsigned inputNo, Value value,
+                                    SignatureConversion &result) const;
+  LogicalResult convertSignatureArgs(ValueRange values,
+                                     SignatureConversion &result,
+                                     unsigned origInputOffset = 0) const;
 
   /// This function converts the type signature of the given block, by invoking
   /// 'convertSignatureArg' for each argument. This function should return a
@@ -329,7 +367,7 @@ class TypeConverter {
   /// types is empty, the type is removed and any usages of the existing value
   /// are expected to be removed during conversion.
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
-      Type, SmallVectorImpl<Type> &)>;
+      PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a source conversion.
   ///
@@ -349,13 +387,14 @@ class TypeConverter {
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
-  /// With callback of form: `std::optional<Type>(T)`
+  /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+  /// `Value` or a `Type` (or a class derived from `Type`).
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
-  wrapCallback(FnT &&callback) const {
+  wrapCallback(FnT &&callback) {
     return wrapCallback<T>([callback = std::forward<FnT>(callback)](
-                               T type, SmallVectorImpl<Type> &results) {
-      if (std::optional<Type> resultOpt = callback(type)) {
+                               T typeOrValue, SmallVectorImpl<Type> &results) {
+      if (std::optional<Type> resultOpt = callback(typeOrValue)) {
         bool wasSuccess = static_cast<bool>(*resultOpt);
         if (wasSuccess)
           results.push_back(*resultOpt);
@@ -365,20 +404,49 @@ class TypeConverter {
     });
   }
   /// With callback of form: `std::optional<LogicalResult>(
-  ///     T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+  ///     T, SmallVectorImpl<Type> &)`, where `T` is a type.
   template <typename T, typename FnT>
-  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_base_of_v<Type, T>,
                    ConversionCallbackFn>
   wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               Type type,
+               PointerUnion<Type, Value> typeOrValue,
                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
-      T derivedType = dyn_cast<T>(type);
+      T derivedType;
+      if (Type t = dyn_cast<Type>(typeOrValue)) {
+        derivedType = dyn_cast<T>(t);
+      } else if (Value v = dyn_cast<Value>(typeOrValue)) {
+        derivedType = dyn_cast<T>(v.getType());
+      } else {
+        llvm_unreachable("unexpected variant");
+      }
       if (!derivedType)
         return std::nullopt;
       return callback(derivedType, results);
     };
   }
+  /// With callback of form: `std::optional<LogicalResult>(
+  ///     T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+  template <typename T, typename FnT>
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+                       std::is_same_v<T, Value>,
+                   ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    hasContextAwareTypeConversions = true;
+    return [callback = std::forward<FnT>(callback)](
+               PointerUnion<Type, Value> typeOrValue,
+               SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+      if (Type t = dyn_cast<Type>(typeOrValue)) {
+        // Context-aware type conversion was called with a type.
+        return std::nullopt;
+      } else if (Value v = dyn_cast<Value>(typeOrValue)) {
+        return callback(v, results);
+      }
+      llvm_unreachable("unexpected variant");
+      return std::nullopt;
+    };
+  }
 
   /// Register a type conversion.
   void registerConversion(ConversionCallbackFn callback) {
@@ -505,6 +573,12 @@ class TypeConverter {
   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
   /// A mutex used for cache access
   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+  /// Whether the type converter has context-aware type conversions. I.e.,
+  /// conversion rules that depend on the SSA value instead of just the type.
+  /// Type conversion caching is deactivated when there are context-aware
+  /// conversions because the type converter may return different results for
+  /// the same input type.
+  bool hasContextAwareTypeConversions = false;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 3b75970c98ad4..072bc501aa5c6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
-    for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, dstTypes)))
+    for (Value v : op.getResults()) {
+      if (failed(typeConverter->convertType(v, dstTypes)))
         return rewriter.notifyMatchFailure(op, "could not convert result type");
       offsets.push_back(dstTypes.size());
     }
@@ -127,7 +127,6 @@ class ConvertForOpTypes
     // Inline the type converted region from the original operation.
     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                                 newOp.getRegion().end());
-
     return newOp;
   }
 };
@@ -226,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(
 
 void mlir::scf::populateSCFStructuralTypeConversionTarget(
     const TypeConverter &typeConverter, ConversionTarget &target) {
-  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
-    return typeConverter.isLegal(op->getResultTypes());
-  });
+  target.addDynamicallyLegalOp<ForOp, IfOp>(
+      [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
     // We only have conversions for a subset of ops that use scf.yield
     // terminators.
     if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
       return true;
-    return typeConverter.isLegal(op.getOperandTypes());
+    return typeConverter.isLegal(op.getOperands());
   });
   target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
       [&](Operation *op) { return typeConverter.isLegal(op); });

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e3248204d6694..a0232937e9a78 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1436,7 +1436,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 
     // If there is no legal conversion, fail to match this pattern.
     SmallVector<Type, 1> legalTypes;
-    if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+    if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
       notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
         diag << "unable to convert type for " << valueDiagTag << " #"
              << it.index() << ", type was " << origType;
@@ -3430,6 +3430,27 @@ LogicalResult TypeConverter::convertType(Type t,
   return failure();
 }
 
+LogicalResult TypeConverter::convertType(Value v,
+                                         SmallVectorImpl<Type> &results) const {
+  assert(v && "expected non-null value");
+
+  // If this type converter does not have context-aware type conversions, call
+  // the type-based overload, which has caching.
+  if (!hasContextAwareTypeConversions)
+    return convertType(v.getType(), results);
+
+  // Walk the added converters in reverse order to apply the most recently
+  // registered first.
+  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+    if (std::optional<LogicalResult> result = converter(v, results)) {
+      if (!succeeded(*result))
+        return failure();
+      return success();
+    }
+  }
+  return failure();
+}
+
 Type TypeConverter::convertType(Type t) const {
   // Use the multi-type result version to convert the type.
   SmallVector<Type, 1> results;
@@ -3440,6 +3461,16 @@ Type TypeConverter::convertType(Type t) const {
   return results.size() == 1 ? results.front() : nullptr;
 }
 
+Type TypeConverter::convertType(Value v) const {
+  // Use the multi-type result version to convert the type.
+  SmallVector<Type, 1> results;
+  if (failed(convertType(v, results)))
+    return nullptr;
+
+  // Check to ensure that only one type was produced.
+  return results.size() == 1 ? results.front() : nullptr;
+}
+
 LogicalResult
 TypeConverter::convertTypes(TypeRange types,
                             SmallVectorImpl<Type> &results) const {
@@ -3449,21 +3480,38 @@ TypeConverter::convertTypes(TypeRange types,
   return success();
 }
 
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+                            SmallVectorImpl<Type> &results) const {
+  for (Value value : values)
+    if (failed(convertType(value, results)))
+      return failure();
+  return success();
+}
+
 bool TypeConverter::isLegal(Type type) const {
   return convertType(type) == type;
 }
+
+bool TypeConverter::isLegal(Value value) const {
+  return convertType(value) == value.getType();
+}
+
 bool TypeConverter::isLegal(Operation *op) const {
-  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
+  return isLegal(op->getOperands()) && isLegal(op->getResults());
 }
 
 bool TypeConverter::isLegal(Region *region) const {
-  return llvm::all_of(*region, [this](Block &block) {
-    return isLegal(block.getArgumentTypes());
-  });
+  return llvm::all_of(
+      *region, [this](Block &block) { return isLegal(block.getArguments()); });
 }
 
 bool TypeConverter::isSignatureLegal(FunctionType ty) const {
-  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
+  if (!isLegal(ty.getInputs()))
+    return false;
+  if (!isLegal(ty.getResults()))
+    return false;
+  return true;
 }
 
 LogicalResult
@@ -3491,6 +3539,31 @@ TypeConverter::convertSignatureArgs(TypeRange types,
       return failure();
   return success();
 }
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
+                                   SignatureConversion &result) const {
+  // Try to convert the given input type.
+  SmallVector<Type, 1> convertedTypes;
+  if (failed(convertType(value, convertedTypes)))
+    return failure();
+
+  // If this argument is being dropped, there is nothing left to do.
+  if (convertedTypes.empty())
+    return success();
+
+  // Otherwise, add the new inputs.
+  result.addInputs(inputNo, convertedTypes);
+  return success();
+}
+LogicalResult
+TypeConverter::convertSignatureArgs(ValueRange values,
+                                    SignatureConversion &result,
+                                    unsigned origInputOffset) const {
+  for (unsigned i = 0, e = values.size(); i != e; ++i)
+    if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
+      return failure();
+  return success();
+}
 
 Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
                                                  Location loc, Type resultType,
@@ -3534,7 +3607,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion(
 std::optional<TypeConverter::SignatureConversion>
 TypeConverter::convertBlockSignature(Block *block) const {
   SignatureConversion conversion(block->getNumArguments());
-  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
+  if (failed(convertSignatureArgs(block->getArguments(), conversion)))
     return std::nullopt;
   return conversion;
 }
@@ -3659,7 +3732,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
   newOp.addOperands(operands);
 
   SmallVector<Type> newResultTypes;
-  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+  if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
     return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
 
   newOp.addTypes(newResultTypes);

diff  --git a/mlir/test/Transforms/test-context-aware-type-converter.mlir b/mlir/test/Transforms/test-context-aware-type-converter.mlir
new file mode 100644
index 0000000000000..ae178b676a392
--- /dev/null
+++ b/mlir/test/Transforms/test-context-aware-type-converter.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-legalize-type-conversion="allow-pattern-rollback=0" -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @simple_context_aware_conversion_1()
+func.func @simple_context_aware_conversion_1() attributes {increment = 1 : i64} {
+  // Case 1: Convert i37 --> i38.
+  // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
+  // CHECK: "test.legal_op_d"(%[[cast]]) : (i38) -> ()
+  %0 = "test.context_op"() : () -> (i37)
+  "test.replace_with_legal_op"(%0) : (i37) -> ()
+  return
+}
+
+// CHECK-LABEL: func @simple_context_aware_conversion_2()
+func.func @simple_context_aware_conversion_2() attributes {increment = 2 : i64} {
+  // Case 2: Convert i37 --> i39.
+  // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i39
+  // CHECK: "test.legal_op_d"(%[[cast]]) : (i39) -> ()
+  %0 = "test.context_op"() : () -> (i37)
+  "test.replace_with_legal_op"(%0) : (i37) -> ()
+  return
+}
+
+// -----
+
+// Note: This test case does not work with allow-pattern-rollback=1. When
+// rollback is enabled, the type converter cannot find the enclosing function
+// because the operand of the scf.yield is pointing to a detached block.
+
+// CHECK-LABEL: func @convert_block_arguments
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
+//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[iter:.*]] = %[[cast]]) -> (i38) {
+//       CHECK:     scf.yield %[[iter]] : i38
+//       CHECK:   }
+func.func @convert_block_arguments(%lb: index, %ub: index, %step: index) attributes {increment = 1 : i64} {
+  %0 = "test.context_op"() : () -> (i37)
+  scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %0) -> i37 {
+    scf.yield %arg0 : i37
+  }
+  return
+}

diff  --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 9bffe92b374d5..c003f8b2cb1cd 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,4 @@ func.func @test_signature_conversion_no_converter() {
   }) : () -> ()
   return
 }
+

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b6f16ac1b5c48..95f381ec471d6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
@@ -1983,9 +1984,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
       : ConversionPattern(converter, "test.replace_with_legal_op",
                           /*benefit=*/1, ctx) {}
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+    rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
     return success();
   }
 };
@@ -1994,6 +1995,10 @@ struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
 
+  TestTypeConversionDriver() = default;
+  TestTypeConversionDriver(const TestTypeConversionDriver &other)
+      : PassWrapper(other) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<TestDialect>();
   }
@@ -2020,8 +2025,13 @@ struct TestTypeConversionDriver
       // Otherwise, the type is illegal.
       return nullptr;
     });
-    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
-      // Drop all integer types.
+    converter.addConversion([](IndexType type) { return type; });
+    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+      if (type.isInteger(38)) {
+        // i38 is legal.
+        types.push_back(type);
+      }
+      // Drop all other integer types.
       return success();
     });
     converter.addConversion(
@@ -2058,6 +2068,33 @@ struct TestTypeConversionDriver
           results.push_back(result);
           return success();
         });
+    converter.addConversion([](Value v) -> std::optional<Type> {
+      // Context-aware type conversion rule that converts i37 to
+      // i(37 + increment). The increment is taken from the enclosing
+      // function.
+      auto intType = dyn_cast<IntegerType>(v.getType());
+      if (!intType || intType.getWidth() != 37)
+        return std::nullopt;
+      Region *r = v.getParentRegion();
+      if (!r) {
+        // No enclosing region found. This can happen when running with
+        // allow-pattern-rollback = true. Context-aware type conversions are
+        // not fully supported when running in rollback mode.
+        return Type();
+      }
+      Operation *op = r->getParentOp();
+      if (!op)
+        return Type();
+      if (!isa<FunctionOpInterface>(op))
+        op = op->getParentOfType<FunctionOpInterface>();
+      if (!op)
+        return Type();
+      auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+      if (!incrementAttr)
+        return Type();
+      return IntegerType::get(v.getContext(),
+                              intType.getWidth() + incrementAttr.getInt());
+    });
 
     /// Add the legal set of type materializations.
     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -2078,9 +2115,19 @@ struct TestTypeConversionDriver
       // Otherwise, fail.
       return nullptr;
     });
+    // Materialize i37 to any desired type with unrealized_conversion_cast.
+    converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                          ValueRange inputs,
+                                          Location loc) -> Value {
+      if (inputs.size() != 1 || !inputs[0].getType().isInteger(37))
+        return Value();
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResult(0);
+    });
 
     // Initialize the conversion target.
     mlir::ConversionTarget target(getContext());
+    target.addLegalOp(OperationName("test.context_op", &getContext()));
     target.addLegalOp<LegalOpD>();
     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
       auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
@@ -2111,11 +2158,19 @@ struct TestTypeConversionDriver
     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        converter, patterns, target);
 
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
     if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+                                      std::move(patterns), config)))
       signalPassFailure();
   }
+
+  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+                                    llvm::cl::desc("Allow pattern rollback"),
+                                    llvm::cl::init(true)};
 };
 } // namespace
 


        

