[Mlir-commits] [mlir] 337707a - [mlir][Transforms] Dialect conversion: Context-aware type conversions (#140434)
llvmlistbot at llvm.org
Wed Aug 27 00:13:55 PDT 2025
Author: Matthias Springer
Date: 2025-08-27T09:13:52+02:00
New Revision: 337707a5417dbdc8751c2a11eda920e250417b5a
URL: https://github.com/llvm/llvm-project/commit/337707a5417dbdc8751c2a11eda920e250417b5a
DIFF: https://github.com/llvm/llvm-project/commit/337707a5417dbdc8751c2a11eda920e250417b5a.diff
LOG: [mlir][Transforms] Dialect conversion: Context-aware type conversions (#140434)
This commit adds support for context-aware type conversions: type
conversion rules that can return different types depending on the IR.
There is no change for existing (context-unaware) type conversion rules:
```c++
// Example: Convert any integer type to f32.
converter.addConversion([](IntegerType t) {
  return Float32Type::get(t.getContext());
});
```
There is now an additional overload to register context-aware type
conversion rules:
```c++
// Example: Type conversion rule for integers, depending on the context:
// Get the defining op of `v`, read its "increment" attribute and return an
// integer with a bitwidth that is increased by "increment".
converter.addConversion([](Value v) -> std::optional<Type> {
  auto intType = dyn_cast<IntegerType>(v.getType());
  if (!intType)
    return std::nullopt;
  Operation *op = v.getDefiningOp();
  if (!op)
    return std::nullopt;
  auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
  if (!incrementAttr)
    return std::nullopt;
  return IntegerType::get(v.getContext(),
                          intType.getWidth() + incrementAttr.getInt());
});
```
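The following standalone C++ model illustrates the `std::nullopt` semantics of the rule above. It does not use MLIR. `FakeOp`, `FakeValue`, `Rule`, `incrementRule` and `convertWithRules` are hypothetical names. Every value is assumed to be an integer, so the `IntegerType` check is dropped.

When a rule returns `std::nullopt`, an older rule gets the value instead. This happens, for example, for a block argument, which has no defining op.

```c++
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-ins for MLIR IR (not MLIR classes). Every value is an
// integer of some bitwidth; `definingOp == nullptr` models a block argument.
struct FakeOp {
  std::optional<unsigned> increment; // the "increment" attribute, if any
};
struct FakeValue {
  unsigned width;
  const FakeOp *definingOp;
};

using Rule = std::function<std::optional<unsigned>(const FakeValue &)>;

// Follows the structure of the rule above: bail out with std::nullopt
// whenever the rule does not apply.
std::optional<unsigned> incrementRule(const FakeValue &v) {
  const FakeOp *op = v.definingOp;
  if (!op)
    return std::nullopt; // block argument: no defining op
  if (!op->increment)
    return std::nullopt; // no "increment" attribute
  return v.width + *op->increment;
}

// Tries the most recently added rule first; the first rule that does not
// return std::nullopt decides the result.
std::optional<unsigned> convertWithRules(const std::vector<Rule> &rules,
                                         const FakeValue &v) {
  for (auto it = rules.rbegin(); it != rules.rend(); ++it)
    if (std::optional<unsigned> result = (*it)(v))
      return result;
  return std::nullopt;
}
```

If an identity rule is registered before `incrementRule`, the identity rule only sees values that `incrementRule` declines.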
For performance reasons, the type converter caches the result of type
conversions. This is no longer possible when there are context-aware type
conversions, because each conversion could compute a different type
depending on the context. There is no performance degradation when there
are only context-unaware type conversions.
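The sketch below shows why a cache keyed on the type alone cannot coexist with context-aware rules. It is a standalone model, not the actual `TypeConverter` code. `CachingConverter` and `FakeValue` are hypothetical. For simplicity, the model uses two differently named registration methods instead of the single overloaded `addConversion`.

```c++
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <vector>

// Hypothetical model: a "type" is a bitwidth; a value also carries the
// "increment" of its surrounding context (0 if absent).
struct FakeValue {
  unsigned width;
  unsigned increment;
};

class CachingConverter {
public:
  using TypeRule = std::function<std::optional<unsigned>(unsigned)>;
  using ValueRule = std::function<std::optional<unsigned>(const FakeValue &)>;

  // Context-unaware rule: sees only the type, so its result may be cached.
  void addConversion(TypeRule rule) {
    rules.push_back([rule](const FakeValue &v) { return rule(v.width); });
  }

  // Context-aware rule: registering one turns caching off for good, like
  // `hasContextAwareTypeConversions = true` in this patch.
  void addContextAwareConversion(ValueRule rule) {
    hasContextAware = true;
    rules.push_back(std::move(rule));
  }

  std::optional<unsigned> convert(const FakeValue &v) {
    if (!hasContextAware) {
      auto hit = cache.find(v.width);
      if (hit != cache.end()) {
        ++cacheHits;
        return hit->second;
      }
    }
    // Most recently added rule first; stop at the first non-nullopt result.
    std::optional<unsigned> result;
    for (auto it = rules.rbegin(); it != rules.rend() && !result; ++it)
      result = (*it)(v);
    if (!hasContextAware && result)
      cache[v.width] = *result;
    return result;
  }

  unsigned cacheHits = 0;

private:
  std::vector<ValueRule> rules;
  std::map<unsigned, unsigned> cache; // keyed by the type (bitwidth) only
  bool hasContextAware = false;
};
```

In the model, two i37 values with different increments must map to i38 and i39. A type-keyed cache would return the first answer for both, so the model simply bypasses the cache once a context-aware rule exists.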
Note: This commit just adds context-aware type conversions to the
dialect conversion framework. There are many existing patterns that
still call `converter.convertType(someValue.getType())`. These should be
gradually updated in subsequent commits to call
`converter.convertType(someValue)`.
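The migration matters because the two `convertType` overloads see different rules. The sketch below is a hypothetical standalone model: `MiniConverter`, `FakeType` and `FakeValue` are not MLIR classes. It mimics the dispatch in this patch:

* Every rule receives either a type or a value. Here `std::variant` stands in for `llvm::PointerUnion<Type, Value>`.
* Context-aware rules decline when given only a type.
* The value overload falls back to the type overload when no context-aware rule is registered.

```c++
#include <cassert>
#include <functional>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical stand-ins: a "type" is an integer bitwidth; a "value" is a
// type plus an "increment" read from the surrounding IR.
struct FakeType {
  unsigned width;
};
struct FakeValue {
  FakeType type;
  unsigned increment;
};

// Every rule receives a type or a value, similar to ConversionCallbackFn.
using Input = std::variant<FakeType, FakeValue>;
using Rule = std::function<std::optional<FakeType>(const Input &)>;

struct MiniConverter {
  std::vector<Rule> rules;
  bool hasContextAware = false;

  // Context-unaware rule: given a value, it only looks at the value's type.
  void addTypeRule(std::function<std::optional<FakeType>(FakeType)> fn) {
    rules.push_back([fn](const Input &in) {
      if (const FakeValue *v = std::get_if<FakeValue>(&in))
        return fn(v->type);
      return fn(std::get<FakeType>(in));
    });
  }

  // Context-aware rule: declines (std::nullopt) when it only gets a type.
  void addValueRule(
      std::function<std::optional<FakeType>(const FakeValue &)> fn) {
    hasContextAware = true;
    rules.push_back([fn](const Input &in) -> std::optional<FakeType> {
      if (const FakeValue *v = std::get_if<FakeValue>(&in))
        return fn(*v);
      return std::nullopt;
    });
  }

  // Most recently added rule first; the first non-nullopt answer wins.
  std::optional<FakeType> run(const Input &in) const {
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if (std::optional<FakeType> r = (*it)(in))
        return r;
    return std::nullopt;
  }

  // Models convertType(Type): context-aware rules never fire.
  std::optional<FakeType> convertType(FakeType t) const { return run(t); }

  // Models convertType(Value): uses the type overload when there is no
  // context-aware rule, otherwise offers the value to every rule.
  std::optional<FakeType> convertType(const FakeValue &v) const {
    if (!hasContextAware)
      return convertType(v.type);
    return run(v);
  }
};
```

Suppose an i37 value has increment 1. A call site that still passes `v.type` gets i37 back. Passing `v` lets the context-aware rule produce i38.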
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
Added:
mlir/test/Transforms/test-context-aware-type-converter.mlir
Modified:
mlir/docs/DialectConversion.md
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalize-type-conversion.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 556e73c2d56c7..7070351755e7a 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -280,6 +280,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.
+There are two kinds of conversion functions: context-aware and context-unaware
+conversions. A context-unaware conversion function converts a `Type` into a
+`Type`. A context-aware conversion function converts a `Value` into a type. The
+latter allows users to customize type conversion rules based on the IR.
+
+Note: When there is at least one context-aware type conversion function, the
+result of type conversions can no longer be cached, which can increase
+compilation time. Use this feature with caution!
+
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -332,29 +341,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
- /// Register a conversion function. A conversion function defines how a given
- /// source type should be converted. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
- /// * Optional<Type>(T)
+ /// Register a conversion function. A conversion function must be convertible
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
+ ///
+ /// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
- /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
- /// converter is allowed to try another conversion function to perform
- /// the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
+ /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
+ /// the converter is allowed to try another conversion function to
+ /// perform the conversion.
+ /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
- /// `failure` or `std::nullopt` to signify a failed conversion. If the new
- /// set of types is empty, the type is removed and any usages of the
+ /// `failure` or `std::nullopt` to signify a failed conversion. If the
+ /// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f23a70601fc0a..14dfbf18836c6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -139,7 +139,8 @@ class TypeConverter {
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms (where `T` is a class derived from `Type`):
+ /// to any of the following forms (where `T` is `Value` or a class derived
+ /// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +155,14 @@ class TypeConverter {
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
+ /// Conversion functions that accept `Value` as the first argument are
+ /// context-aware. I.e., they can take into account IR when converting the
+ /// type of the given value. Context-unaware conversion functions accept
+ /// `Type` or a derived class as the first argument.
+ ///
+ /// Note: Context-unaware conversions are cached, but context-aware
+ /// conversions are not.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -242,15 +251,28 @@ class TypeConverter {
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}
- /// Convert the given type. This function should return failure if no valid
+ /// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
+ ///
+ /// Note: This overload invokes only context-unaware type conversion
+ /// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
+ /// Convert the type of the given value. This function returns failure if no
+ /// valid conversion exists, success otherwise. If the new set of types is
+ /// empty, the type is removed and any usages of the existing value are
+ /// expected to be removed during conversion.
+ ///
+ /// Note: This overload invokes both context-aware and context-unaware type
+ /// conversion functions.
+ LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
+
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
+ Type convertType(Value v) const;
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -259,25 +281,36 @@ class TypeConverter {
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
+ template <typename TargetType>
+ TargetType convertType(Value v) const {
+ return dyn_cast_or_null<TargetType>(convertType(v));
+ }
- /// Convert the given set of types, filling 'results' as necessary. This
- /// returns failure if the conversion of any of the types fails, success
+ /// Convert the given types, filling 'results' as necessary. This returns
+ /// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;
+ /// Convert the types of the given values, filling 'results' as necessary.
+ /// This returns "failure" if the conversion of any of the types fails,
+ /// "success" otherwise.
+ LogicalResult convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const;
+
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
+ bool isLegal(Value value) const;
/// Return true if all of the given types are legal for this type converter.
- template <typename RangeT>
- std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
- !std::is_convertible<RangeT, Operation *>::value,
- bool>
- isLegal(RangeT &&range) const {
+ bool isLegal(TypeRange range) const {
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
}
+ bool isLegal(ValueRange range) const {
+ return llvm::all_of(range, [this](Value value) { return isLegal(value); });
+ }
+
/// Return true if the given operation has legal operand and result types.
bool isLegal(Operation *op) const;
@@ -296,6 +329,11 @@ class TypeConverter {
LogicalResult convertSignatureArgs(TypeRange types,
SignatureConversion &result,
unsigned origInputOffset = 0) const;
+ LogicalResult convertSignatureArg(unsigned inputNo, Value value,
+ SignatureConversion &result) const;
+ LogicalResult convertSignatureArgs(ValueRange values,
+ SignatureConversion &result,
+ unsigned origInputOffset = 0) const;
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
@@ -329,7 +367,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
- Type, SmallVectorImpl<Type> &)>;
+ PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a source conversion.
///
@@ -349,13 +387,14 @@ class TypeConverter {
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
- /// With callback of form: `std::optional<Type>(T)`
+ /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
+ /// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
- wrapCallback(FnT &&callback) const {
+ wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results) {
- if (std::optional<Type> resultOpt = callback(type)) {
+ T typeOrValue, SmallVectorImpl<Type> &results) {
+ if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
@@ -365,20 +404,49 @@ class TypeConverter {
});
}
/// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
+ /// T, SmallVectorImpl<Type> &)`, where `T` is a type.
template <typename T, typename FnT>
- std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- Type type,
+ PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
- T derivedType = dyn_cast<T>(type);
+ T derivedType;
+ if (Type t = dyn_cast<Type>(typeOrValue)) {
+ derivedType = dyn_cast<T>(t);
+ } else if (Value v = dyn_cast<Value>(typeOrValue)) {
+ derivedType = dyn_cast<T>(v.getType());
+ } else {
+ llvm_unreachable("unexpected variant");
+ }
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
+ /// With callback of form: `std::optional<LogicalResult>(
+ /// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
+ template <typename T, typename FnT>
+ std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
+ std::is_same_v<T, Value>,
+ ConversionCallbackFn>
+ wrapCallback(FnT &&callback) {
+ hasContextAwareTypeConversions = true;
+ return [callback = std::forward<FnT>(callback)](
+ PointerUnion<Type, Value> typeOrValue,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ if (Type t = dyn_cast<Type>(typeOrValue)) {
+ // Context-aware type conversion was called with a type.
+ return std::nullopt;
+ } else if (Value v = dyn_cast<Value>(typeOrValue)) {
+ return callback(v, results);
+ }
+ llvm_unreachable("unexpected variant");
+ return std::nullopt;
+ };
+ }
/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
@@ -505,6 +573,12 @@ class TypeConverter {
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
+ /// Whether the type converter has context-aware type conversions. I.e.,
+ /// conversion rules that depend on the SSA value instead of just the type.
+ /// Type conversion caching is deactivated when there are context-aware
/// conversions because the type converter may return different results for
+ /// the same input type.
+ bool hasContextAwareTypeConversions = false;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 3b75970c98ad4..072bc501aa5c6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
- for (Type type : op.getResultTypes()) {
- if (failed(typeConverter->convertTypes(type, dstTypes)))
+ for (Value v : op.getResults()) {
+ if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
@@ -127,7 +127,6 @@ class ConvertForOpTypes
// Inline the type converted region from the original operation.
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
-
return newOp;
}
};
@@ -226,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(
void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
- target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
- return typeConverter.isLegal(op->getResultTypes());
- });
+ target.addDynamicallyLegalOp<ForOp, IfOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
return true;
- return typeConverter.isLegal(op.getOperandTypes());
+ return typeConverter.isLegal(op.getOperands());
});
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e3248204d6694..a0232937e9a78 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1436,7 +1436,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -3430,6 +3430,27 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions)
+ return convertType(v.getType(), results);
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -3440,6 +3461,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -3449,21 +3480,38 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
+
+bool TypeConverter::isLegal(Value value) const {
+ return convertType(value) == value.getType();
+}
+
bool TypeConverter::isLegal(Operation *op) const {
- return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
+ return isLegal(op->getOperands()) && isLegal(op->getResults());
}
bool TypeConverter::isLegal(Region *region) const {
- return llvm::all_of(*region, [this](Block &block) {
- return isLegal(block.getArgumentTypes());
- });
+ return llvm::all_of(
+ *region, [this](Block &block) { return isLegal(block.getArguments()); });
}
bool TypeConverter::isSignatureLegal(FunctionType ty) const {
- return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
+ if (!isLegal(ty.getInputs()))
+ return false;
+ if (!isLegal(ty.getResults()))
+ return false;
+ return true;
}
LogicalResult
@@ -3491,6 +3539,31 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return failure();
return success();
}
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
+ SignatureConversion &result) const {
+ // Try to convert the given input type.
+ SmallVector<Type, 1> convertedTypes;
+ if (failed(convertType(value, convertedTypes)))
+ return failure();
+
+ // If this argument is being dropped, there is nothing left to do.
+ if (convertedTypes.empty())
+ return success();
+
+ // Otherwise, add the new inputs.
+ result.addInputs(inputNo, convertedTypes);
+ return success();
+}
+LogicalResult
+TypeConverter::convertSignatureArgs(ValueRange values,
+ SignatureConversion &result,
+ unsigned origInputOffset) const {
+ for (unsigned i = 0, e = values.size(); i != e; ++i)
+ if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
+ return failure();
+ return success();
+}
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType,
@@ -3534,7 +3607,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion(
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
- if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
+ if (failed(convertSignatureArgs(block->getArguments(), conversion)))
return std::nullopt;
return conversion;
}
@@ -3659,7 +3732,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
diff --git a/mlir/test/Transforms/test-context-aware-type-converter.mlir b/mlir/test/Transforms/test-context-aware-type-converter.mlir
new file mode 100644
index 0000000000000..ae178b676a392
--- /dev/null
+++ b/mlir/test/Transforms/test-context-aware-type-converter.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-legalize-type-conversion="allow-pattern-rollback=0" -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @simple_context_aware_conversion_1()
+func.func @simple_context_aware_conversion_1() attributes {increment = 1 : i64} {
+ // Case 1: Convert i37 --> i38.
+ // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
+ // CHECK: "test.legal_op_d"(%[[cast]]) : (i38) -> ()
+ %0 = "test.context_op"() : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+ return
+}
+
+// CHECK-LABEL: func @simple_context_aware_conversion_2()
+func.func @simple_context_aware_conversion_2() attributes {increment = 2 : i64} {
+ // Case 2: Convert i37 --> i39.
+ // CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i39
+ // CHECK: "test.legal_op_d"(%[[cast]]) : (i39) -> ()
+ %0 = "test.context_op"() : () -> (i37)
+ "test.replace_with_legal_op"(%0) : (i37) -> ()
+ return
+}
+
+// -----
+
+// Note: This test case does not work with allow-pattern-rollback=1. When
+// rollback is enabled, the type converter cannot find the enclosing function
+// because the operand of the scf.yield is pointing to a detached block.
+
+// CHECK-LABEL: func @convert_block_arguments
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %{{.*}} : i37 to i38
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[iter:.*]] = %[[cast]]) -> (i38) {
+// CHECK: scf.yield %[[iter]] : i38
+// CHECK: }
+func.func @convert_block_arguments(%lb: index, %ub: index, %step: index) attributes {increment = 1 : i64} {
+ %0 = "test.context_op"() : () -> (i37)
+ scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %0) -> i37 {
+ scf.yield %arg0 : i37
+ }
+ return
+}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 9bffe92b374d5..c003f8b2cb1cd 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -142,3 +142,4 @@ func.func @test_signature_conversion_no_converter() {
}) : () -> ()
return
}
+
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b6f16ac1b5c48..95f381ec471d6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
@@ -1983,9 +1984,9 @@ struct TestReplaceWithLegalOp : public ConversionPattern {
: ConversionPattern(converter, "test.replace_with_legal_op",
/*benefit=*/1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]);
+ rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0].front());
return success();
}
};
@@ -1994,6 +1995,10 @@ struct TestTypeConversionDriver
: public PassWrapper<TestTypeConversionDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver)
+ TestTypeConversionDriver() = default;
+ TestTypeConversionDriver(const TestTypeConversionDriver &other)
+ : PassWrapper(other) {}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TestDialect>();
}
@@ -2020,8 +2025,13 @@ struct TestTypeConversionDriver
// Otherwise, the type is illegal.
return nullptr;
});
- converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
- // Drop all integer types.
+ converter.addConversion([](IndexType type) { return type; });
+ converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+ if (type.isInteger(38)) {
+ // i38 is legal.
+ types.push_back(type);
+ }
+ // Drop all other integer types.
return success();
});
converter.addConversion(
@@ -2058,6 +2068,33 @@ struct TestTypeConversionDriver
results.push_back(result);
return success();
});
+ converter.addConversion([](Value v) -> std::optional<Type> {
+ // Context-aware type conversion rule that converts i37 to
+ // i(37 + increment). The increment is taken from the enclosing
+ // function.
+ auto intType = dyn_cast<IntegerType>(v.getType());
+ if (!intType || intType.getWidth() != 37)
+ return std::nullopt;
+ Region *r = v.getParentRegion();
+ if (!r) {
+ // No enclosing region found. This can happen when running with
+ // allow-pattern-rollback = true. Context-aware type conversions are
+ // not fully supported when running in rollback mode.
+ return Type();
+ }
+ Operation *op = r->getParentOp();
+ if (!op)
+ return Type();
+ if (!isa<FunctionOpInterface>(op))
+ op = op->getParentOfType<FunctionOpInterface>();
+ if (!op)
+ return Type();
+ auto incrementAttr = op->getAttrOfType<IntegerAttr>("increment");
+ if (!incrementAttr)
+ return Type();
+ return IntegerType::get(v.getContext(),
+ intType.getWidth() + incrementAttr.getInt());
+ });
/// Add the legal set of type materializations.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -2078,9 +2115,19 @@ struct TestTypeConversionDriver
// Otherwise, fail.
return nullptr;
});
+ // Materialize i37 to any desired type with unrealized_conversion_cast.
+ converter.addTargetMaterialization([](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1 || !inputs[0].getType().isInteger(37))
+ return Value();
+ return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+ .getResult(0);
+ });
// Initialize the conversion target.
mlir::ConversionTarget target(getContext());
+ target.addLegalOp(OperationName("test.context_op", &getContext()));
target.addLegalOp<LegalOpD>();
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
@@ -2111,11 +2158,19 @@ struct TestTypeConversionDriver
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
+ mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+ converter, patterns, target);
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ std::move(patterns), config)))
signalPassFailure();
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace