[Mlir-commits] [mlir] dd48773 - [mlir][Types] Remove the subclass data from Type

River Riddle llvmlistbot at llvm.org
Fri Aug 7 13:43:46 PDT 2020


Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: dd48773396f77fd7b19adc43b23d41aef356809a

URL: https://github.com/llvm/llvm-project/commit/dd48773396f77fd7b19adc43b23d41aef356809a
DIFF: https://github.com/llvm/llvm-project/commit/dd48773396f77fd7b19adc43b23d41aef356809a.diff

LOG: [mlir][Types] Remove the subclass data from Type

Subclass data is useful when a certain amount of memory is already allocated but not all of it is used. For Type, that has not been the case for a while: the subclass data now just takes up a full `unsigned` in every TypeStorage. Removing it (the `unsigned` plus the padding that follows the `abstractType` pointer) frees up ~8 bytes for almost every type instance.

Differential Revision: https://reviews.llvm.org/D85348

Added: 
    

Modified: 
    mlir/include/mlir/IR/StandardTypes.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/IR/TypeDetail.h
    mlir/lib/IR/Types.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 3daf226603a8..406598c9d061 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -127,7 +127,7 @@ class IntegerType
   using Base::Base;
 
   /// Signedness semantics.
-  enum SignednessSemantics {
+  enum SignednessSemantics : uint32_t {
     Signless, /// No signedness semantics
     Signed,   /// Signed integer
     Unsigned, /// Unsigned integer

diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 0e1a6c72c11d..c26aec6411c0 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -79,16 +79,9 @@ class TypeStorage : public StorageUniquer::BaseStorage {
     return *abstractType;
   }
 
-  /// Get the subclass data.
-  unsigned getSubclassData() const { return subclassData; }
-
-  /// Set the subclass data.
-  void setSubclassData(unsigned val) { subclassData = val; }
-
 protected:
   /// This constructor is used by derived classes as part of the TypeUniquer.
-  TypeStorage(unsigned subclassData = 0)
-      : abstractType(nullptr), subclassData(subclassData) {}
+  TypeStorage() : abstractType(nullptr) {}
 
 private:
   /// Set the abstract type for this storage instance. This is used by the
@@ -99,9 +92,6 @@ class TypeStorage : public StorageUniquer::BaseStorage {
 
   /// The abstract description for this type.
   const AbstractType *abstractType;
-
-  /// Space for subclasses to store data.
-  unsigned subclassData;
 };
 
 /// Default storage type for types that require no additional initialization or

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index ed63f696a84c..bd0ea4bbd5dc 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -192,9 +192,6 @@ class Type {
 
   friend ::llvm::hash_code hash_value(Type arg);
 
-  unsigned getSubclassData() const;
-  void setSubclassData(unsigned val);
-
   /// Methods for supporting PointerLikeTypeTraits.
   const void *getAsOpaquePointer() const {
     return static_cast<const void *>(impl);
@@ -264,7 +261,8 @@ class FunctionType
                           MLIRContext *context);
 
   // Input types.
-  unsigned getNumInputs() const { return getSubclassData(); }
+  unsigned getNumInputs() const;
+
   Type getInput(unsigned i) const { return getInputs()[i]; }
   ArrayRef<Type> getInputs() const;
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 583a779408b4..be42ca833f21 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -108,14 +108,15 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, getSubclassData(), stride);
+    return key == KeyTy(elementType, elementCount, stride);
   }
 
   ArrayTypeStorage(const KeyTy &key)
-      : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
+      : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
         stride(std::get<2>(key)) {}
 
   Type elementType;
+  unsigned elementCount;
   unsigned stride;
 };
 
@@ -132,9 +133,7 @@ ArrayType ArrayType::get(Type elementType, unsigned elementCount,
                    elementCount, stride);
 }
 
-unsigned ArrayType::getNumElements() const {
-  return getImpl()->getSubclassData();
-}
+unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
 
 Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
@@ -321,19 +320,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, getScope(), rows, columns);
+    return key == KeyTy(elementType, scope, rows, columns);
   }
 
   CooperativeMatrixTypeStorage(const KeyTy &key)
-      : TypeStorage(static_cast<unsigned>(std::get<1>(key))),
-        elementType(std::get<0>(key)), rows(std::get<2>(key)),
-        columns(std::get<3>(key)) {}
-
-  Scope getScope() const { return static_cast<Scope>(getSubclassData()); }
+      : elementType(std::get<0>(key)), rows(std::get<2>(key)),
+        columns(std::get<3>(key)), scope(std::get<1>(key)) {}
 
   Type elementType;
   unsigned rows;
   unsigned columns;
+  Scope scope;
 };
 
 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
@@ -347,9 +344,7 @@ Type CooperativeMatrixNVType::getElementType() const {
   return getImpl()->elementType;
 }
 
-Scope CooperativeMatrixNVType::getScope() const {
-  return getImpl()->getScope();
-}
+Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
 
 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
 
@@ -412,20 +407,6 @@ template <> constexpr unsigned getNumBits<ImageFormat>() {
 }
 
 struct spirv::detail::ImageTypeStorage : public TypeStorage {
-private:
-  /// Define a bit-field struct to pack the enum values
-  union EnumPack {
-    struct {
-      unsigned dimEncoding : getNumBits<Dim>();
-      unsigned depthInfoEncoding : getNumBits<ImageDepthInfo>();
-      unsigned arrayedInfoEncoding : getNumBits<ImageArrayedInfo>();
-      unsigned samplingInfoEncoding : getNumBits<ImageSamplingInfo>();
-      unsigned samplerUseInfoEncoding : getNumBits<ImageSamplerUseInfo>();
-      unsigned formatEncoding : getNumBits<ImageFormat>();
-    } data;
-    unsigned storage;
-  };
-
 public:
   using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                            ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
@@ -436,95 +417,23 @@ struct spirv::detail::ImageTypeStorage : public TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, getDim(), getDepthInfo(), getArrayedInfo(),
-                        getSamplingInfo(), getSamplerUseInfo(),
-                        getImageFormat());
-  }
-
-  Dim getDim() const {
-    EnumPack v;
-    v.storage = getSubclassData();
-    return static_cast<Dim>(v.data.dimEncoding);
-  }
-  void setDim(Dim dim) {
-    EnumPack v;
-    v.storage = getSubclassData();
-    v.data.dimEncoding = static_cast<unsigned>(dim);
-    setSubclassData(v.storage);
-  }
-
-  ImageDepthInfo getDepthInfo() const {
-    EnumPack v;
-    v.storage = getSubclassData();
-    return static_cast<ImageDepthInfo>(v.data.depthInfoEncoding);
-  }
-  void setDepthInfo(ImageDepthInfo depthInfo) {
-    EnumPack v;
-    v.storage = getSubclassData();
-    v.data.depthInfoEncoding = static_cast<unsigned>(depthInfo);
-    setSubclassData(v.storage);
-  }
-
-  ImageArrayedInfo getArrayedInfo() const {
-    EnumPack v;
-    v.storage = getSubclassData();
-    return static_cast<ImageArrayedInfo>(v.data.arrayedInfoEncoding);
-  }
-  void setArrayedInfo(ImageArrayedInfo arrayedInfo) {
-    EnumPack v;
-    v.storage = getSubclassData();
-    v.data.arrayedInfoEncoding = static_cast<unsigned>(arrayedInfo);
-    setSubclassData(v.storage);
+    return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
+                        samplerUseInfo, format);
   }
 
-  ImageSamplingInfo getSamplingInfo() const {
-    EnumPack v;
-    v.storage = getSubclassData();
-    return static_cast<ImageSamplingInfo>(v.data.samplingInfoEncoding);
-  }
-  void setSamplingInfo(ImageSamplingInfo samplingInfo) {
-    EnumPack v;
-    v.storage = getSubclassData();
-    v.data.samplingInfoEncoding = static_cast<unsigned>(samplingInfo);
-    setSubclassData(v.storage);
-  }
-
-  ImageSamplerUseInfo getSamplerUseInfo() const {
-    EnumPack v;
-    v.storage = getSubclassData();
-    return static_cast<ImageSamplerUseInfo>(v.data.samplerUseInfoEncoding);
-  }
-  void setSamplerUseInfo(ImageSamplerUseInfo samplerUseInfo) {
-    EnumPack v;
-    v.storage = getSubclassData();
-    v.data.samplerUseInfoEncoding = static_cast<unsigned>(samplerUseInfo);
-    setSubclassData(v.storage);
-  }
-
-  ImageFormat getImageFormat() const {
-    EnumPack v;
-    v.storage = getSubclassData();
-    return static_cast<ImageFormat>(v.data.formatEncoding);
-  }
-  void setImageFormat(ImageFormat format) {
-    EnumPack v;
-    v.storage = getSubclassData();
-    v.data.formatEncoding = static_cast<unsigned>(format);
-    setSubclassData(v.storage);
-  }
-
-  ImageTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)) {
-    static_assert(sizeof(EnumPack) <= sizeof(getSubclassData()),
-                  "EnumPack size greater than subClassData type size");
-    setDim(std::get<1>(key));
-    setDepthInfo(std::get<2>(key));
-    setArrayedInfo(std::get<3>(key));
-    setSamplingInfo(std::get<4>(key));
-    setSamplerUseInfo(std::get<5>(key));
-    setImageFormat(std::get<6>(key));
-  }
+  ImageTypeStorage(const KeyTy &key)
+      : elementType(std::get<0>(key)), dim(std::get<1>(key)),
+        depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
+        samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
+        format(std::get<6>(key)) {}
 
   Type elementType;
+  Dim dim : getNumBits<Dim>();
+  ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
+  ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
+  ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
+  ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
+  ImageFormat format : getNumBits<ImageFormat>();
 };
 
 ImageType
@@ -536,27 +445,23 @@ ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
 
 Type ImageType::getElementType() const { return getImpl()->elementType; }
 
-Dim ImageType::getDim() const { return getImpl()->getDim(); }
+Dim ImageType::getDim() const { return getImpl()->dim; }
 
-ImageDepthInfo ImageType::getDepthInfo() const {
-  return getImpl()->getDepthInfo();
-}
+ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
 
 ImageArrayedInfo ImageType::getArrayedInfo() const {
-  return getImpl()->getArrayedInfo();
+  return getImpl()->arrayedInfo;
 }
 
 ImageSamplingInfo ImageType::getSamplingInfo() const {
-  return getImpl()->getSamplingInfo();
+  return getImpl()->samplingInfo;
 }
 
 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
-  return getImpl()->getSamplerUseInfo();
+  return getImpl()->samplerUseInfo;
 }
 
-ImageFormat ImageType::getImageFormat() const {
-  return getImpl()->getImageFormat();
-}
+ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
 
 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
                               Optional<StorageClass>) {
@@ -588,18 +493,14 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(pointeeType, getStorageClass());
+    return key == KeyTy(pointeeType, storageClass);
   }
 
   PointerTypeStorage(const KeyTy &key)
-      : TypeStorage(static_cast<unsigned>(key.second)), pointeeType(key.first) {
-  }
-
-  StorageClass getStorageClass() const {
-    return static_cast<StorageClass>(getSubclassData());
-  }
+      : pointeeType(key.first), storageClass(key.second) {}
 
   Type pointeeType;
+  StorageClass storageClass;
 };
 
 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
@@ -610,7 +511,7 @@ PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
 
 StorageClass PointerType::getStorageClass() const {
-  return getImpl()->getStorageClass();
+  return getImpl()->storageClass;
 }
 
 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -650,13 +551,14 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, getSubclassData());
+    return key == KeyTy(elementType, stride);
   }
 
   RuntimeArrayTypeStorage(const KeyTy &key)
-      : TypeStorage(key.second), elementType(key.first) {}
+      : elementType(key.first), stride(key.second) {}
 
   Type elementType;
+  unsigned stride;
 };
 
 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
@@ -671,9 +573,7 @@ RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
 
 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
 
-unsigned RuntimeArrayType::getArrayStride() const {
-  return getImpl()->getSubclassData();
-}
+unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
 
 void RuntimeArrayType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
@@ -921,8 +821,8 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
       unsigned numMembers, Type const *memberTypes,
       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
       StructType::MemberDecorationInfo const *memberDecorationsInfo)
-      : TypeStorage(numMembers), memberTypes(memberTypes),
-        offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+      : memberTypes(memberTypes), offsetInfo(layoutInfo),
+        numMembers(numMembers), numMemberDecorations(numMemberDecorations),
         memberDecorationsInfo(memberDecorationsInfo) {}
 
   using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
@@ -964,12 +864,12 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
   }
 
   ArrayRef<Type> getMemberTypes() const {
-    return ArrayRef<Type>(memberTypes, getSubclassData());
+    return ArrayRef<Type>(memberTypes, numMembers);
   }
 
   ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
     if (offsetInfo) {
-      return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
+      return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
     }
     return {};
   }
@@ -984,6 +884,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
 
   Type const *memberTypes;
   StructType::OffsetInfo const *offsetInfo;
+  unsigned numMembers;
   unsigned numMemberDecorations;
   StructType::MemberDecorationInfo const *memberDecorationsInfo;
 };
@@ -1007,9 +908,7 @@ StructType StructType::getEmpty(MLIRContext *context) {
                    ArrayRef<StructType::MemberDecorationInfo>());
 }
 
-unsigned StructType::getNumElements() const {
-  return getImpl()->getSubclassData();
-}
+unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
 
 Type StructType::getElementType(unsigned index) const {
   assert(getNumElements() > index && "member index out of range");

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index f4bb79362ffd..0cc82477d380 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -130,10 +130,10 @@ IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
   return success();
 }
 
-unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); }
+unsigned IntegerType::getWidth() const { return getImpl()->width; }
 
 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
-  return getImpl()->getSignedness();
+  return getImpl()->signedness;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 783983473a38..b5019f00d222 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -56,17 +56,17 @@ struct OpaqueTypeStorage : public TypeStorage {
 struct IntegerTypeStorage : public TypeStorage {
   IntegerTypeStorage(unsigned width,
                      IntegerType::SignednessSemantics signedness)
-      : TypeStorage(packKeyBits(width, signedness)) {}
+      : width(width), signedness(signedness) {}
 
   /// The hash key used for uniquing.
   using KeyTy = std::pair<unsigned, IntegerType::SignednessSemantics>;
 
   static llvm::hash_code hashKey(const KeyTy &key) {
-    return llvm::hash_value(packKeyBits(key.first, key.second));
+    return llvm::hash_value(key);
   }
 
   bool operator==(const KeyTy &key) const {
-    return getSubclassData() == packKeyBits(key.first, key.second);
+    return KeyTy(width, signedness) == key;
   }
 
   static IntegerTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -75,35 +75,15 @@ struct IntegerTypeStorage : public TypeStorage {
         IntegerTypeStorage(key.first, key.second);
   }
 
-  struct KeyBits {
-    unsigned width : 30;
-    unsigned signedness : 2;
-  };
-
-  /// Pack the given `width` and `signedness` as a key.
-  static unsigned packKeyBits(unsigned width,
-                              IntegerType::SignednessSemantics signedness) {
-    KeyBits bits{width, static_cast<unsigned>(signedness)};
-    return llvm::bit_cast<unsigned>(bits);
-  }
-
-  static KeyBits unpackKeyBits(unsigned bits) {
-    return llvm::bit_cast<KeyBits>(bits);
-  }
-
-  unsigned getWidth() { return unpackKeyBits(getSubclassData()).width; }
-
-  IntegerType::SignednessSemantics getSignedness() {
-    return static_cast<IntegerType::SignednessSemantics>(
-        unpackKeyBits(getSubclassData()).signedness);
-  }
+  unsigned width : 30;
+  IntegerType::SignednessSemantics signedness : 2;
 };
 
 /// Function Type Storage and Uniquing.
 struct FunctionTypeStorage : public TypeStorage {
   FunctionTypeStorage(unsigned numInputs, unsigned numResults,
                       Type const *inputsAndResults)
-      : TypeStorage(numInputs), numResults(numResults),
+      : numInputs(numInputs), numResults(numResults),
         inputsAndResults(inputsAndResults) {}
 
   /// The hash key used for uniquing.
@@ -130,20 +110,20 @@ struct FunctionTypeStorage : public TypeStorage {
   }
 
   ArrayRef<Type> getInputs() const {
-    return ArrayRef<Type>(inputsAndResults, getSubclassData());
+    return ArrayRef<Type>(inputsAndResults, numInputs);
   }
   ArrayRef<Type> getResults() const {
-    return ArrayRef<Type>(inputsAndResults + getSubclassData(), numResults);
+    return ArrayRef<Type>(inputsAndResults + numInputs, numResults);
   }
 
+  unsigned numInputs;
   unsigned numResults;
   Type const *inputsAndResults;
 };
 
 /// Shaped Type Storage.
 struct ShapedTypeStorage : public TypeStorage {
-  ShapedTypeStorage(Type elementTy, unsigned subclassData = 0)
-      : TypeStorage(subclassData), elementType(elementTy) {}
+  ShapedTypeStorage(Type elementTy) : elementType(elementTy) {}
 
   /// The hash key used for uniquing.
   using KeyTy = Type;
@@ -156,7 +136,8 @@ struct ShapedTypeStorage : public TypeStorage {
 struct VectorTypeStorage : public ShapedTypeStorage {
   VectorTypeStorage(unsigned shapeSize, Type elementTy,
                     const int64_t *shapeElements)
-      : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
+      : ShapedTypeStorage(elementTy), shapeElements(shapeElements),
+        shapeSize(shapeSize) {}
 
   /// The hash key used for uniquing.
   using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
@@ -176,16 +157,18 @@ struct VectorTypeStorage : public ShapedTypeStorage {
   }
 
   ArrayRef<int64_t> getShape() const {
-    return ArrayRef<int64_t>(shapeElements, getSubclassData());
+    return ArrayRef<int64_t>(shapeElements, shapeSize);
   }
 
   const int64_t *shapeElements;
+  unsigned shapeSize;
 };
 
 struct RankedTensorTypeStorage : public ShapedTypeStorage {
   RankedTensorTypeStorage(unsigned shapeSize, Type elementTy,
                           const int64_t *shapeElements)
-      : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
+      : ShapedTypeStorage(elementTy), shapeElements(shapeElements),
+        shapeSize(shapeSize) {}
 
   /// The hash key used for uniquing.
   using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
@@ -205,10 +188,11 @@ struct RankedTensorTypeStorage : public ShapedTypeStorage {
   }
 
   ArrayRef<int64_t> getShape() const {
-    return ArrayRef<int64_t>(shapeElements, getSubclassData());
+    return ArrayRef<int64_t>(shapeElements, shapeSize);
   }
 
   const int64_t *shapeElements;
+  unsigned shapeSize;
 };
 
 struct UnrankedTensorTypeStorage : public ShapedTypeStorage {
@@ -227,9 +211,9 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
   MemRefTypeStorage(unsigned shapeSize, Type elementType,
                     const int64_t *shapeElements, const unsigned numAffineMaps,
                     AffineMap const *affineMapList, const unsigned memorySpace)
-      : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements),
-        numAffineMaps(numAffineMaps), affineMapList(affineMapList),
-        memorySpace(memorySpace) {}
+      : ShapedTypeStorage(elementType), shapeElements(shapeElements),
+        shapeSize(shapeSize), numAffineMaps(numAffineMaps),
+        affineMapList(affineMapList), memorySpace(memorySpace) {}
 
   /// The hash key used for uniquing.
   // MemRefs are uniqued based on their shape, element type, affine map
@@ -258,7 +242,7 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
   }
 
   ArrayRef<int64_t> getShape() const {
-    return ArrayRef<int64_t>(shapeElements, getSubclassData());
+    return ArrayRef<int64_t>(shapeElements, shapeSize);
   }
 
   ArrayRef<AffineMap> getAffineMaps() const {
@@ -267,6 +251,8 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
 
   /// An array of integers which stores the shape dimension sizes.
   const int64_t *shapeElements;
+  /// The number of shape elements.
+  unsigned shapeSize;
   /// The number of affine maps in the 'affineMapList' array.
   const unsigned numAffineMaps;
   /// List of affine maps in the memref's layout/index map composition.
@@ -324,7 +310,7 @@ struct TupleTypeStorage final
       public llvm::TrailingObjects<TupleTypeStorage, Type> {
   using KeyTy = TypeRange;
 
-  TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {}
+  TupleTypeStorage(unsigned numTypes) : numElements(numTypes) {}
 
   /// Construction.
   static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -343,12 +329,15 @@ struct TupleTypeStorage final
   bool operator==(const KeyTy &key) const { return key == getTypes(); }
 
   /// Return the number of held types.
-  unsigned size() const { return getSubclassData(); }
+  unsigned size() const { return numElements; }
 
   /// Return the held types.
   ArrayRef<Type> getTypes() const {
     return {getTrailingObjects<Type>(), size()};
   }
+
+  /// The number of tuple elements.
+  unsigned numElements;
 };
 
 } // namespace detail

diff  --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fea2cc6648e3..ae2dd909ff59 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -27,9 +27,6 @@ Dialect &Type::getDialect() const {
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
-unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
-void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
-
 //===----------------------------------------------------------------------===//
 // FunctionType
 //===----------------------------------------------------------------------===//
@@ -39,6 +36,8 @@ FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
   return Base::get(context, Type::Kind::Function, inputs, results);
 }
 
+unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
+
 ArrayRef<Type> FunctionType::getInputs() const {
   return getImpl()->getInputs();
 }


        


More information about the Mlir-commits mailing list