[Mlir-commits] [mlir] 3dd5833 - [mlir][Transforms] TypeConverter: Mark conversion/materialization functions as "const"

Matthias Springer llvmlistbot at llvm.org
Thu Aug 10 04:59:36 PDT 2023


Author: Matthias Springer
Date: 2023-08-10T13:54:04+02:00
New Revision: 3dd58333d09faf97872ada19a8349fbe93c0ef11

URL: https://github.com/llvm/llvm-project/commit/3dd58333d09faf97872ada19a8349fbe93c0ef11
DIFF: https://github.com/llvm/llvm-project/commit/3dd58333d09faf97872ada19a8349fbe93c0ef11.diff

LOG: [mlir][Transforms] TypeConverter: Mark conversion/materialization functions as "const"

Functions that materialize IR or convert types can be const.

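The caching data structures inside the TypeConverter (`cachedDirectConversions`, `cachedMultiConversions`) and the recursion stack used by nested type conversions (`conversionCallStack`) are marked as `mutable`, so that the now-const conversion functions can still update them. Because these const functions still modify this state without synchronization, calling them concurrently on the same TypeConverter remains unsafe.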

Differential Revision: https://reviews.llvm.org/D157597

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e8a0e6ec6991b0..b4051093d4b0a9 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -219,76 +219,78 @@ class TypeConverter {
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
   /// be removed during conversion.
-  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
+  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
 
   /// This hook simplifies defining 1-1 type conversions. This function returns
   /// the type to convert to on success, and a null type on failure.
-  Type convertType(Type t);
+  Type convertType(Type t) const;
 
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType>
-  TargetType convertType(Type t) {
+  template <typename TargetType> TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
   /// Convert the given set of types, filling 'results' as necessary. This
   /// returns failure if the conversion of any of the types fails, success
   /// otherwise.
-  LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
+  LogicalResult convertTypes(TypeRange types,
+                             SmallVectorImpl<Type> &results) const;
 
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
-  bool isLegal(Type type);
+  bool isLegal(Type type) const;
+
   /// Return true if all of the given types are legal for this type converter.
   template <typename RangeT>
   std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
                        !std::is_convertible<RangeT, Operation *>::value,
                    bool>
-  isLegal(RangeT &&range) {
+  isLegal(RangeT &&range) const {
     return llvm::all_of(range, [this](Type type) { return isLegal(type); });
   }
   /// Return true if the given operation has legal operand and result types.
-  bool isLegal(Operation *op);
+  bool isLegal(Operation *op) const;
 
   /// Return true if the types of block arguments within the region are legal.
-  bool isLegal(Region *region);
+  bool isLegal(Region *region) const;
 
   /// Return true if the inputs and outputs of the given function type are
   /// legal.
-  bool isSignatureLegal(FunctionType ty);
+  bool isSignatureLegal(FunctionType ty) const;
 
   /// This method allows for converting a specific argument of a signature. It
   /// takes as inputs the original argument input number, type.
   /// On success, it populates 'result' with any new mappings.
   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
-                                    SignatureConversion &result);
+                                    SignatureConversion &result) const;
   LogicalResult convertSignatureArgs(TypeRange types,
                                      SignatureConversion &result,
-                                     unsigned origInputOffset = 0);
+                                     unsigned origInputOffset = 0) const;
 
   /// This function converts the type signature of the given block, by invoking
   /// 'convertSignatureArg' for each argument. This function should return a
   /// valid conversion for the signature on success, std::nullopt otherwise.
-  std::optional<SignatureConversion> convertBlockSignature(Block *block);
+  std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
 
   /// Materialize a conversion from a set of types into one result type by
   /// generating a cast sequence of some kind. See the respective
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType, ValueRange inputs) {
+                                      Type resultType,
+                                      ValueRange inputs) const {
     return materializeConversion(argumentMaterializations, builder, loc,
                                  resultType, inputs);
   }
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) {
+                                    Type resultType, ValueRange inputs) const {
     return materializeConversion(sourceMaterializations, builder, loc,
                                  resultType, inputs);
   }
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) {
+                                    Type resultType, ValueRange inputs) const {
     return materializeConversion(targetMaterializations, builder, loc,
                                  resultType, inputs);
   }
@@ -297,7 +299,8 @@ class TypeConverter {
   /// the registered conversion functions. If no applicable conversion has been
   /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
   /// is a valid return value for this function.
-  std::optional<Attribute> convertTypeAttribute(Type type, Attribute attr);
+  std::optional<Attribute> convertTypeAttribute(Type type,
+                                                Attribute attr) const;
 
 private:
   /// The signature of the callback used to convert a type. If the new set of
@@ -316,16 +319,17 @@ class TypeConverter {
 
   /// Attempt to materialize a conversion using one of the provided
   /// materialization functions.
-  Value materializeConversion(
-      MutableArrayRef<MaterializationCallbackFn> materializations,
-      OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
+  Value
+  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
+                        OpBuilder &builder, Location loc, Type resultType,
+                        ValueRange inputs) const;
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
   /// With callback of form: `std::optional<Type>(T)`
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
-  wrapCallback(FnT &&callback) {
+  wrapCallback(FnT &&callback) const {
     return wrapCallback<T>(
         [callback = std::forward<FnT>(callback)](
             T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -343,7 +347,7 @@ class TypeConverter {
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
                    ConversionCallbackFn>
-  wrapCallback(FnT &&callback) {
+  wrapCallback(FnT &&callback) const {
     return wrapCallback<T>(
         [callback = std::forward<FnT>(callback)](
             T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
@@ -356,7 +360,7 @@ class TypeConverter {
   std::enable_if_t<
       std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
       ConversionCallbackFn>
-  wrapCallback(FnT &&callback) {
+  wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
                Type type, SmallVectorImpl<Type> &results,
                ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
@@ -378,7 +382,7 @@ class TypeConverter {
   /// may take any subclass of `Type` and the wrapper will check for the target
   /// type to be of the expected class before calling the callback.
   template <typename T, typename FnT>
-  MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
+  MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
                OpBuilder &builder, Type resultType, ValueRange inputs,
                Location loc) -> std::optional<Value> {
@@ -394,7 +398,7 @@ class TypeConverter {
   /// callback.
   template <typename T, typename A, typename FnT>
   TypeAttributeConversionCallbackFn
-  wrapTypeAttributeConversion(FnT &&callback) {
+  wrapTypeAttributeConversion(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
                Type type, Attribute attr) -> AttributeConversionResult {
       if (T derivedType = dyn_cast<T>(type)) {
@@ -428,13 +432,13 @@ class TypeConverter {
   /// A set of cached conversions to avoid recomputing in the common case.
   /// Direct 1-1 conversions are the most common, so this cache stores the
   /// successful 1-1 conversions as well as all failed conversions.
-  DenseMap<Type, Type> cachedDirectConversions;
+  mutable DenseMap<Type, Type> cachedDirectConversions;
   /// This cache stores the successful 1->N conversions, where N != 1.
-  DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
+  mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
 
   /// Stores the types that are being converted in the case when convertType
   /// is being called recursively to convert nested types.
-  SmallVector<Type, 2> conversionCallStack;
+  mutable SmallVector<Type, 2> conversionCallStack;
 };
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d2cbab2b50cf28..fa75d6efa15bb2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2906,7 +2906,7 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 }
 
 LogicalResult TypeConverter::convertType(Type t,
-                                         SmallVectorImpl<Type> &results) {
+                                         SmallVectorImpl<Type> &results) const {
   auto existingIt = cachedDirectConversions.find(t);
   if (existingIt != cachedDirectConversions.end()) {
     if (existingIt->second)
@@ -2925,7 +2925,7 @@ LogicalResult TypeConverter::convertType(Type t,
   conversionCallStack.push_back(t);
   auto popConversionCallStack =
       llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
-  for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
     if (std::optional<LogicalResult> result =
             converter(t, results, conversionCallStack)) {
       if (!succeeded(*result)) {
@@ -2943,7 +2943,7 @@ LogicalResult TypeConverter::convertType(Type t,
   return failure();
 }
 
-Type TypeConverter::convertType(Type t) {
+Type TypeConverter::convertType(Type t) const {
   // Use the multi-type result version to convert the type.
   SmallVector<Type, 1> results;
   if (failed(convertType(t, results)))
@@ -2953,31 +2953,35 @@ Type TypeConverter::convertType(Type t) {
   return results.size() == 1 ? results.front() : nullptr;
 }
 
-LogicalResult TypeConverter::convertTypes(TypeRange types,
-                                          SmallVectorImpl<Type> &results) {
+LogicalResult
+TypeConverter::convertTypes(TypeRange types,
+                            SmallVectorImpl<Type> &results) const {
   for (Type type : types)
     if (failed(convertType(type, results)))
       return failure();
   return success();
 }
 
-bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
-bool TypeConverter::isLegal(Operation *op) {
+bool TypeConverter::isLegal(Type type) const {
+  return convertType(type) == type;
+}
+bool TypeConverter::isLegal(Operation *op) const {
   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
 }
 
-bool TypeConverter::isLegal(Region *region) {
+bool TypeConverter::isLegal(Region *region) const {
   return llvm::all_of(*region, [this](Block &block) {
     return isLegal(block.getArgumentTypes());
   });
 }
 
-bool TypeConverter::isSignatureLegal(FunctionType ty) {
+bool TypeConverter::isSignatureLegal(FunctionType ty) const {
   return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
 }
 
-LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
-                                                 SignatureConversion &result) {
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+                                   SignatureConversion &result) const {
   // Try to convert the given input type.
   SmallVector<Type, 1> convertedTypes;
   if (failed(convertType(type, convertedTypes)))
@@ -2991,9 +2995,10 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   result.addInputs(inputNo, convertedTypes);
   return success();
 }
-LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
-                                                  SignatureConversion &result,
-                                                  unsigned origInputOffset) {
+LogicalResult
+TypeConverter::convertSignatureArgs(TypeRange types,
+                                    SignatureConversion &result,
+                                    unsigned origInputOffset) const {
   for (unsigned i = 0, e = types.size(); i != e; ++i)
     if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
       return failure();
@@ -3001,16 +3006,16 @@ LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
 }
 
 Value TypeConverter::materializeConversion(
-    MutableArrayRef<MaterializationCallbackFn> materializations,
-    OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
-  for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
+    ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
+    Location loc, Type resultType, ValueRange inputs) const {
+  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
     if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
       return *result;
   return nullptr;
 }
 
-auto TypeConverter::convertBlockSignature(Block *block)
-    -> std::optional<SignatureConversion> {
+std::optional<TypeConverter::SignatureConversion>
+TypeConverter::convertBlockSignature(Block *block) const {
   SignatureConversion conversion(block->getNumArguments());
   if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
     return std::nullopt;
@@ -3052,9 +3057,9 @@ Attribute TypeConverter::AttributeConversionResult::getResult() const {
   return impl.getPointer();
 }
 
-std::optional<Attribute> TypeConverter::convertTypeAttribute(Type type,
-                                                             Attribute attr) {
-  for (TypeAttributeConversionCallbackFn &fn :
+std::optional<Attribute>
+TypeConverter::convertTypeAttribute(Type type, Attribute attr) const {
+  for (const TypeAttributeConversionCallbackFn &fn :
        llvm::reverse(typeAttributeConversions)) {
     AttributeConversionResult res = fn(type, attr);
     if (res.hasResult())


        


More information about the Mlir-commits mailing list