[Mlir-commits] [mlir] 3dd5833 - [mlir][Transforms] TypeConverter: Mark conversion/materialization functions as "const"
Matthias Springer
llvmlistbot at llvm.org
Thu Aug 10 04:59:36 PDT 2023
Author: Matthias Springer
Date: 2023-08-10T13:54:04+02:00
New Revision: 3dd58333d09faf97872ada19a8349fbe93c0ef11
URL: https://github.com/llvm/llvm-project/commit/3dd58333d09faf97872ada19a8349fbe93c0ef11
DIFF: https://github.com/llvm/llvm-project/commit/3dd58333d09faf97872ada19a8349fbe93c0ef11.diff
LOG: [mlir][Transforms] TypeConverter: Mark conversion/materialization functions as "const"
Functions that materialize IR or convert types can be const.
Caching data structures inside the TypeConverter are marked as `mutable`.
Differential Revision: https://reviews.llvm.org/D157597
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e8a0e6ec6991b0..b4051093d4b0a9 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -219,76 +219,78 @@ class TypeConverter {
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
- LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
+ LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
- Type convertType(Type t);
+ Type convertType(Type t) const;
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
- template <typename TargetType>
- TargetType convertType(Type t) {
+ template <typename TargetType> TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
/// otherwise.
- LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
+ LogicalResult convertTypes(TypeRange types,
+ SmallVectorImpl<Type> &results) const;
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
- bool isLegal(Type type);
+ bool isLegal(Type type) const;
+
/// Return true if all of the given types are legal for this type converter.
template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
!std::is_convertible<RangeT, Operation *>::value,
bool>
- isLegal(RangeT &&range) {
+ isLegal(RangeT &&range) const {
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
}
/// Return true if the given operation has legal operand and result types.
- bool isLegal(Operation *op);
+ bool isLegal(Operation *op) const;
/// Return true if the types of block arguments within the region are legal.
- bool isLegal(Region *region);
+ bool isLegal(Region *region) const;
/// Return true if the inputs and outputs of the given function type are
/// legal.
- bool isSignatureLegal(FunctionType ty);
+ bool isSignatureLegal(FunctionType ty) const;
/// This method allows for converting a specific argument of a signature. It
/// takes as inputs the original argument input number, type.
/// On success, it populates 'result' with any new mappings.
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
- SignatureConversion &result);
+ SignatureConversion &result) const;
LogicalResult convertSignatureArgs(TypeRange types,
SignatureConversion &result,
- unsigned origInputOffset = 0);
+ unsigned origInputOffset = 0) const;
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
/// valid conversion for the signature on success, std::nullopt otherwise.
- std::optional<SignatureConversion> convertBlockSignature(Block *block);
+ std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
/// Materialize a conversion from a set of types into one result type by
/// generating a cast sequence of some kind. See the respective
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) {
+ Type resultType,
+ ValueRange inputs) const {
return materializeConversion(argumentMaterializations, builder, loc,
resultType, inputs);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) {
+ Type resultType, ValueRange inputs) const {
return materializeConversion(sourceMaterializations, builder, loc,
resultType, inputs);
}
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) {
+ Type resultType, ValueRange inputs) const {
return materializeConversion(targetMaterializations, builder, loc,
resultType, inputs);
}
@@ -297,7 +299,8 @@ class TypeConverter {
/// the registered conversion functions. If no applicable conversion has been
/// registered, return std::nullopt. Note that the empty attribute/`nullptr`
/// is a valid return value for this function.
- std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
+ std::optional<Attribute> convertTypeAttribute(Type type,
+ Attribute attr) const;
private:
/// The signature of the callback used to convert a type. If the new set of
@@ -316,16 +319,17 @@ class TypeConverter {
/// Attempt to materialize a conversion using one of the provided
/// materialization functions.
- Value materializeConversion(
- MutableArrayRef<MaterializationCallbackFn> materializations,
- OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
+ Value
+ materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
+ OpBuilder &builder, Location loc, Type resultType,
+ ValueRange inputs) const;
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `std::optional<Type>(T)`
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
- wrapCallback(FnT &&callback) {
+ wrapCallback(FnT &&callback) const {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -343,7 +347,7 @@ class TypeConverter {
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
ConversionCallbackFn>
- wrapCallback(FnT &&callback) {
+ wrapCallback(FnT &&callback) const {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -356,7 +360,7 @@ class TypeConverter {
std::enable_if_t<
std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
ConversionCallbackFn>
- wrapCallback(FnT &&callback) {
+ wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
Type type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
@@ -378,7 +382,7 @@ class TypeConverter {
/// may take any subclass of `Type` and the wrapper will check for the target
/// type to be of the expected class before calling the callback.
template <typename T, typename FnT>
- MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
+ MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
@@ -394,7 +398,7 @@ class TypeConverter {
/// callback.
template <typename T, typename A, typename FnT>
TypeAttributeConversionCallbackFn
- wrapTypeAttributeConversion(FnT &&callback) {
+ wrapTypeAttributeConversion(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
Type type, Attribute attr) -> AttributeConversionResult {
if (T derivedType = dyn_cast<T>(type)) {
@@ -428,13 +432,13 @@ class TypeConverter {
/// A set of cached conversions to avoid recomputing in the common case.
/// Direct 1-1 conversions are the most common, so this cache stores the
/// successful 1-1 conversions as well as all failed conversions.
- DenseMap<Type, Type> cachedDirectConversions;
+ mutable DenseMap<Type, Type> cachedDirectConversions;
/// This cache stores the successful 1->N conversions, where N != 1.
- DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
+ mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// Stores the types that are being converted in the case when convertType
/// is being called recursively to convert nested types.
- SmallVector<Type, 2> conversionCallStack;
+ mutable SmallVector<Type, 2> conversionCallStack;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d2cbab2b50cf28..fa75d6efa15bb2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2906,7 +2906,7 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
}
LogicalResult TypeConverter::convertType(Type t,
- SmallVectorImpl<Type> &results) {
+ SmallVectorImpl<Type> &results) const {
auto existingIt = cachedDirectConversions.find(t);
if (existingIt != cachedDirectConversions.end()) {
if (existingIt->second)
@@ -2925,7 +2925,7 @@ LogicalResult TypeConverter::convertType(Type t,
conversionCallStack.push_back(t);
auto popConversionCallStack =
llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
- for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (std::optional<LogicalResult> result =
converter(t, results, conversionCallStack)) {
if (!succeeded(*result)) {
@@ -2943,7 +2943,7 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
-Type TypeConverter::convertType(Type t) {
+Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
if (failed(convertType(t, results)))
@@ -2953,31 +2953,35 @@ Type TypeConverter::convertType(Type t) {
return results.size() == 1 ? results.front() : nullptr;
}
-LogicalResult TypeConverter::convertTypes(TypeRange types,
- SmallVectorImpl<Type> &results) {
+LogicalResult
+TypeConverter::convertTypes(TypeRange types,
+ SmallVectorImpl<Type> &results) const {
for (Type type : types)
if (failed(convertType(type, results)))
return failure();
return success();
}
-bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
-bool TypeConverter::isLegal(Operation *op) {
+bool TypeConverter::isLegal(Type type) const {
+ return convertType(type) == type;
+}
+bool TypeConverter::isLegal(Operation *op) const {
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
}
-bool TypeConverter::isLegal(Region *region) {
+bool TypeConverter::isLegal(Region *region) const {
return llvm::all_of(*region, [this](Block &block) {
return isLegal(block.getArgumentTypes());
});
}
-bool TypeConverter::isSignatureLegal(FunctionType ty) {
+bool TypeConverter::isSignatureLegal(FunctionType ty) const {
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
}
-LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
- SignatureConversion &result) {
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+ SignatureConversion &result) const {
// Try to convert the given input type.
SmallVector<Type, 1> convertedTypes;
if (failed(convertType(type, convertedTypes)))
@@ -2991,9 +2995,10 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
result.addInputs(inputNo, convertedTypes);
return success();
}
-LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
- SignatureConversion &result,
- unsigned origInputOffset) {
+LogicalResult
+TypeConverter::convertSignatureArgs(TypeRange types,
+ SignatureConversion &result,
+ unsigned origInputOffset) const {
for (unsigned i = 0, e = types.size(); i != e; ++i)
if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
return failure();
@@ -3001,16 +3006,16 @@ LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
}
Value TypeConverter::materializeConversion(
- MutableArrayRef<MaterializationCallbackFn> materializations,
- OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
- for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
+ ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
+ Location loc, Type resultType, ValueRange inputs) const {
+ for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
return nullptr;
}
-auto TypeConverter::convertBlockSignature(Block *block)
- -> std::optional<SignatureConversion> {
+std::optional<TypeConverter::SignatureConversion>
+TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
return std::nullopt;
@@ -3052,9 +3057,9 @@ Attribute TypeConverter::AttributeConversionResult::getResult() const {
return impl.getPointer();
}
-std::optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
- Attribute attr) {
- for (TypeAttributeConversionCallbackFn &fn :
+std::optional<Attribute>
+TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
+ for (const TypeAttributeConversionCallbackFn &fn :
llvm::reverse(typeAttributeConversions)) {
AttributeConversionResult res = fn(type, attr);
if (res.hasResult())
More information about the Mlir-commits
mailing list