[Mlir-commits] [mlir] dd48773 - [mlir][Types] Remove the subclass data from Type
River Riddle
llvmlistbot at llvm.org
Fri Aug 7 13:43:46 PDT 2020
Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: dd48773396f77fd7b19adc43b23d41aef356809a
URL: https://github.com/llvm/llvm-project/commit/dd48773396f77fd7b19adc43b23d41aef356809a
DIFF: https://github.com/llvm/llvm-project/commit/dd48773396f77fd7b19adc43b23d41aef356809a.diff
LOG: [mlir][Types] Remove the subclass data from Type
Subclass data is useful when a fixed amount of memory is allocated but not all of it is used. For Type, that hasn't been the case for a while, and the subclass data field is just taking up a full `unsigned`. Removing it frees up ~8 bytes for almost every type instance.
Differential Revision: https://reviews.llvm.org/D85348
Added:
Modified:
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/TypeDetail.h
mlir/lib/IR/Types.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 3daf226603a8..406598c9d061 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -127,7 +127,7 @@ class IntegerType
using Base::Base;
/// Signedness semantics.
- enum SignednessSemantics {
+ enum SignednessSemantics : uint32_t {
Signless, /// No signedness semantics
Signed, /// Signed integer
Unsigned, /// Unsigned integer
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 0e1a6c72c11d..c26aec6411c0 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -79,16 +79,9 @@ class TypeStorage : public StorageUniquer::BaseStorage {
return *abstractType;
}
- /// Get the subclass data.
- unsigned getSubclassData() const { return subclassData; }
-
- /// Set the subclass data.
- void setSubclassData(unsigned val) { subclassData = val; }
-
protected:
/// This constructor is used by derived classes as part of the TypeUniquer.
- TypeStorage(unsigned subclassData = 0)
- : abstractType(nullptr), subclassData(subclassData) {}
+ TypeStorage() : abstractType(nullptr) {}
private:
/// Set the abstract type for this storage instance. This is used by the
@@ -99,9 +92,6 @@ class TypeStorage : public StorageUniquer::BaseStorage {
/// The abstract description for this type.
const AbstractType *abstractType;
-
- /// Space for subclasses to store data.
- unsigned subclassData;
};
/// Default storage type for types that require no additional initialization or
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index ed63f696a84c..bd0ea4bbd5dc 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -192,9 +192,6 @@ class Type {
friend ::llvm::hash_code hash_value(Type arg);
- unsigned getSubclassData() const;
- void setSubclassData(unsigned val);
-
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const {
return static_cast<const void *>(impl);
@@ -264,7 +261,8 @@ class FunctionType
MLIRContext *context);
// Input types.
- unsigned getNumInputs() const { return getSubclassData(); }
+ unsigned getNumInputs() const;
+
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 583a779408b4..be42ca833f21 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -108,14 +108,15 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, getSubclassData(), stride);
+ return key == KeyTy(elementType, elementCount, stride);
}
ArrayTypeStorage(const KeyTy &key)
- : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)),
+ : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
stride(std::get<2>(key)) {}
Type elementType;
+ unsigned elementCount;
unsigned stride;
};
@@ -132,9 +133,7 @@ ArrayType ArrayType::get(Type elementType, unsigned elementCount,
elementCount, stride);
}
-unsigned ArrayType::getNumElements() const {
- return getImpl()->getSubclassData();
-}
+unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
Type ArrayType::getElementType() const { return getImpl()->elementType; }
@@ -321,19 +320,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, getScope(), rows, columns);
+ return key == KeyTy(elementType, scope, rows, columns);
}
CooperativeMatrixTypeStorage(const KeyTy &key)
- : TypeStorage(static_cast<unsigned>(std::get<1>(key))),
- elementType(std::get<0>(key)), rows(std::get<2>(key)),
- columns(std::get<3>(key)) {}
-
- Scope getScope() const { return static_cast<Scope>(getSubclassData()); }
+ : elementType(std::get<0>(key)), rows(std::get<2>(key)),
+ columns(std::get<3>(key)), scope(std::get<1>(key)) {}
Type elementType;
unsigned rows;
unsigned columns;
+ Scope scope;
};
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
@@ -347,9 +344,7 @@ Type CooperativeMatrixNVType::getElementType() const {
return getImpl()->elementType;
}
-Scope CooperativeMatrixNVType::getScope() const {
- return getImpl()->getScope();
-}
+Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
@@ -412,20 +407,6 @@ template <> constexpr unsigned getNumBits<ImageFormat>() {
}
struct spirv::detail::ImageTypeStorage : public TypeStorage {
-private:
- /// Define a bit-field struct to pack the enum values
- union EnumPack {
- struct {
- unsigned dimEncoding : getNumBits<Dim>();
- unsigned depthInfoEncoding : getNumBits<ImageDepthInfo>();
- unsigned arrayedInfoEncoding : getNumBits<ImageArrayedInfo>();
- unsigned samplingInfoEncoding : getNumBits<ImageSamplingInfo>();
- unsigned samplerUseInfoEncoding : getNumBits<ImageSamplerUseInfo>();
- unsigned formatEncoding : getNumBits<ImageFormat>();
- } data;
- unsigned storage;
- };
-
public:
using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
@@ -436,95 +417,23 @@ struct spirv::detail::ImageTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, getDim(), getDepthInfo(), getArrayedInfo(),
- getSamplingInfo(), getSamplerUseInfo(),
- getImageFormat());
- }
-
- Dim getDim() const {
- EnumPack v;
- v.storage = getSubclassData();
- return static_cast<Dim>(v.data.dimEncoding);
- }
- void setDim(Dim dim) {
- EnumPack v;
- v.storage = getSubclassData();
- v.data.dimEncoding = static_cast<unsigned>(dim);
- setSubclassData(v.storage);
- }
-
- ImageDepthInfo getDepthInfo() const {
- EnumPack v;
- v.storage = getSubclassData();
- return static_cast<ImageDepthInfo>(v.data.depthInfoEncoding);
- }
- void setDepthInfo(ImageDepthInfo depthInfo) {
- EnumPack v;
- v.storage = getSubclassData();
- v.data.depthInfoEncoding = static_cast<unsigned>(depthInfo);
- setSubclassData(v.storage);
- }
-
- ImageArrayedInfo getArrayedInfo() const {
- EnumPack v;
- v.storage = getSubclassData();
- return static_cast<ImageArrayedInfo>(v.data.arrayedInfoEncoding);
- }
- void setArrayedInfo(ImageArrayedInfo arrayedInfo) {
- EnumPack v;
- v.storage = getSubclassData();
- v.data.arrayedInfoEncoding = static_cast<unsigned>(arrayedInfo);
- setSubclassData(v.storage);
+ return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
+ samplerUseInfo, format);
}
- ImageSamplingInfo getSamplingInfo() const {
- EnumPack v;
- v.storage = getSubclassData();
- return static_cast<ImageSamplingInfo>(v.data.samplingInfoEncoding);
- }
- void setSamplingInfo(ImageSamplingInfo samplingInfo) {
- EnumPack v;
- v.storage = getSubclassData();
- v.data.samplingInfoEncoding = static_cast<unsigned>(samplingInfo);
- setSubclassData(v.storage);
- }
-
- ImageSamplerUseInfo getSamplerUseInfo() const {
- EnumPack v;
- v.storage = getSubclassData();
- return static_cast<ImageSamplerUseInfo>(v.data.samplerUseInfoEncoding);
- }
- void setSamplerUseInfo(ImageSamplerUseInfo samplerUseInfo) {
- EnumPack v;
- v.storage = getSubclassData();
- v.data.samplerUseInfoEncoding = static_cast<unsigned>(samplerUseInfo);
- setSubclassData(v.storage);
- }
-
- ImageFormat getImageFormat() const {
- EnumPack v;
- v.storage = getSubclassData();
- return static_cast<ImageFormat>(v.data.formatEncoding);
- }
- void setImageFormat(ImageFormat format) {
- EnumPack v;
- v.storage = getSubclassData();
- v.data.formatEncoding = static_cast<unsigned>(format);
- setSubclassData(v.storage);
- }
-
- ImageTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)) {
- static_assert(sizeof(EnumPack) <= sizeof(getSubclassData()),
- "EnumPack size greater than subClassData type size");
- setDim(std::get<1>(key));
- setDepthInfo(std::get<2>(key));
- setArrayedInfo(std::get<3>(key));
- setSamplingInfo(std::get<4>(key));
- setSamplerUseInfo(std::get<5>(key));
- setImageFormat(std::get<6>(key));
- }
+ ImageTypeStorage(const KeyTy &key)
+ : elementType(std::get<0>(key)), dim(std::get<1>(key)),
+ depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
+ samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
+ format(std::get<6>(key)) {}
Type elementType;
+ Dim dim : getNumBits<Dim>();
+ ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
+ ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
+ ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
+ ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
+ ImageFormat format : getNumBits<ImageFormat>();
};
ImageType
@@ -536,27 +445,23 @@ ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
Type ImageType::getElementType() const { return getImpl()->elementType; }
-Dim ImageType::getDim() const { return getImpl()->getDim(); }
+Dim ImageType::getDim() const { return getImpl()->dim; }
-ImageDepthInfo ImageType::getDepthInfo() const {
- return getImpl()->getDepthInfo();
-}
+ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
ImageArrayedInfo ImageType::getArrayedInfo() const {
- return getImpl()->getArrayedInfo();
+ return getImpl()->arrayedInfo;
}
ImageSamplingInfo ImageType::getSamplingInfo() const {
- return getImpl()->getSamplingInfo();
+ return getImpl()->samplingInfo;
}
ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
- return getImpl()->getSamplerUseInfo();
+ return getImpl()->samplerUseInfo;
}
-ImageFormat ImageType::getImageFormat() const {
- return getImpl()->getImageFormat();
-}
+ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
Optional<StorageClass>) {
@@ -588,18 +493,14 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(pointeeType, getStorageClass());
+ return key == KeyTy(pointeeType, storageClass);
}
PointerTypeStorage(const KeyTy &key)
- : TypeStorage(static_cast<unsigned>(key.second)), pointeeType(key.first) {
- }
-
- StorageClass getStorageClass() const {
- return static_cast<StorageClass>(getSubclassData());
- }
+ : pointeeType(key.first), storageClass(key.second) {}
Type pointeeType;
+ StorageClass storageClass;
};
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
@@ -610,7 +511,7 @@ PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
StorageClass PointerType::getStorageClass() const {
- return getImpl()->getStorageClass();
+ return getImpl()->storageClass;
}
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -650,13 +551,14 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, getSubclassData());
+ return key == KeyTy(elementType, stride);
}
RuntimeArrayTypeStorage(const KeyTy &key)
- : TypeStorage(key.second), elementType(key.first) {}
+ : elementType(key.first), stride(key.second) {}
Type elementType;
+ unsigned stride;
};
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
@@ -671,9 +573,7 @@ RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
-unsigned RuntimeArrayType::getArrayStride() const {
- return getImpl()->getSubclassData();
-}
+unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
void RuntimeArrayType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
@@ -921,8 +821,8 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
StructType::MemberDecorationInfo const *memberDecorationsInfo)
- : TypeStorage(numMembers), memberTypes(memberTypes),
- offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations),
+ : memberTypes(memberTypes), offsetInfo(layoutInfo),
+ numMembers(numMembers), numMemberDecorations(numMemberDecorations),
memberDecorationsInfo(memberDecorationsInfo) {}
using KeyTy = std::tuple<ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
@@ -964,12 +864,12 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
ArrayRef<Type> getMemberTypes() const {
- return ArrayRef<Type>(memberTypes, getSubclassData());
+ return ArrayRef<Type>(memberTypes, numMembers);
}
ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
if (offsetInfo) {
- return ArrayRef<StructType::OffsetInfo>(offsetInfo, getSubclassData());
+ return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
}
return {};
}
@@ -984,6 +884,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
Type const *memberTypes;
StructType::OffsetInfo const *offsetInfo;
+ unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
};
@@ -1007,9 +908,7 @@ StructType StructType::getEmpty(MLIRContext *context) {
ArrayRef<StructType::MemberDecorationInfo>());
}
-unsigned StructType::getNumElements() const {
- return getImpl()->getSubclassData();
-}
+unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
Type StructType::getElementType(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index f4bb79362ffd..0cc82477d380 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -130,10 +130,10 @@ IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
return success();
}
-unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); }
+unsigned IntegerType::getWidth() const { return getImpl()->width; }
IntegerType::SignednessSemantics IntegerType::getSignedness() const {
- return getImpl()->getSignedness();
+ return getImpl()->signedness;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
index 783983473a38..b5019f00d222 100644
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -56,17 +56,17 @@ struct OpaqueTypeStorage : public TypeStorage {
struct IntegerTypeStorage : public TypeStorage {
IntegerTypeStorage(unsigned width,
IntegerType::SignednessSemantics signedness)
- : TypeStorage(packKeyBits(width, signedness)) {}
+ : width(width), signedness(signedness) {}
/// The hash key used for uniquing.
using KeyTy = std::pair<unsigned, IntegerType::SignednessSemantics>;
static llvm::hash_code hashKey(const KeyTy &key) {
- return llvm::hash_value(packKeyBits(key.first, key.second));
+ return llvm::hash_value(key);
}
bool operator==(const KeyTy &key) const {
- return getSubclassData() == packKeyBits(key.first, key.second);
+ return KeyTy(width, signedness) == key;
}
static IntegerTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -75,35 +75,15 @@ struct IntegerTypeStorage : public TypeStorage {
IntegerTypeStorage(key.first, key.second);
}
- struct KeyBits {
- unsigned width : 30;
- unsigned signedness : 2;
- };
-
- /// Pack the given `width` and `signedness` as a key.
- static unsigned packKeyBits(unsigned width,
- IntegerType::SignednessSemantics signedness) {
- KeyBits bits{width, static_cast<unsigned>(signedness)};
- return llvm::bit_cast<unsigned>(bits);
- }
-
- static KeyBits unpackKeyBits(unsigned bits) {
- return llvm::bit_cast<KeyBits>(bits);
- }
-
- unsigned getWidth() { return unpackKeyBits(getSubclassData()).width; }
-
- IntegerType::SignednessSemantics getSignedness() {
- return static_cast<IntegerType::SignednessSemantics>(
- unpackKeyBits(getSubclassData()).signedness);
- }
+ unsigned width : 30;
+ IntegerType::SignednessSemantics signedness : 2;
};
/// Function Type Storage and Uniquing.
struct FunctionTypeStorage : public TypeStorage {
FunctionTypeStorage(unsigned numInputs, unsigned numResults,
Type const *inputsAndResults)
- : TypeStorage(numInputs), numResults(numResults),
+ : numInputs(numInputs), numResults(numResults),
inputsAndResults(inputsAndResults) {}
/// The hash key used for uniquing.
@@ -130,20 +110,20 @@ struct FunctionTypeStorage : public TypeStorage {
}
ArrayRef<Type> getInputs() const {
- return ArrayRef<Type>(inputsAndResults, getSubclassData());
+ return ArrayRef<Type>(inputsAndResults, numInputs);
}
ArrayRef<Type> getResults() const {
- return ArrayRef<Type>(inputsAndResults + getSubclassData(), numResults);
+ return ArrayRef<Type>(inputsAndResults + numInputs, numResults);
}
+ unsigned numInputs;
unsigned numResults;
Type const *inputsAndResults;
};
/// Shaped Type Storage.
struct ShapedTypeStorage : public TypeStorage {
- ShapedTypeStorage(Type elementTy, unsigned subclassData = 0)
- : TypeStorage(subclassData), elementType(elementTy) {}
+ ShapedTypeStorage(Type elementTy) : elementType(elementTy) {}
/// The hash key used for uniquing.
using KeyTy = Type;
@@ -156,7 +136,8 @@ struct ShapedTypeStorage : public TypeStorage {
struct VectorTypeStorage : public ShapedTypeStorage {
VectorTypeStorage(unsigned shapeSize, Type elementTy,
const int64_t *shapeElements)
- : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
+ : ShapedTypeStorage(elementTy), shapeElements(shapeElements),
+ shapeSize(shapeSize) {}
/// The hash key used for uniquing.
using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
@@ -176,16 +157,18 @@ struct VectorTypeStorage : public ShapedTypeStorage {
}
ArrayRef<int64_t> getShape() const {
- return ArrayRef<int64_t>(shapeElements, getSubclassData());
+ return ArrayRef<int64_t>(shapeElements, shapeSize);
}
const int64_t *shapeElements;
+ unsigned shapeSize;
};
struct RankedTensorTypeStorage : public ShapedTypeStorage {
RankedTensorTypeStorage(unsigned shapeSize, Type elementTy,
const int64_t *shapeElements)
- : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {}
+ : ShapedTypeStorage(elementTy), shapeElements(shapeElements),
+ shapeSize(shapeSize) {}
/// The hash key used for uniquing.
using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
@@ -205,10 +188,11 @@ struct RankedTensorTypeStorage : public ShapedTypeStorage {
}
ArrayRef<int64_t> getShape() const {
- return ArrayRef<int64_t>(shapeElements, getSubclassData());
+ return ArrayRef<int64_t>(shapeElements, shapeSize);
}
const int64_t *shapeElements;
+ unsigned shapeSize;
};
struct UnrankedTensorTypeStorage : public ShapedTypeStorage {
@@ -227,9 +211,9 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
MemRefTypeStorage(unsigned shapeSize, Type elementType,
const int64_t *shapeElements, const unsigned numAffineMaps,
AffineMap const *affineMapList, const unsigned memorySpace)
- : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements),
- numAffineMaps(numAffineMaps), affineMapList(affineMapList),
- memorySpace(memorySpace) {}
+ : ShapedTypeStorage(elementType), shapeElements(shapeElements),
+ shapeSize(shapeSize), numAffineMaps(numAffineMaps),
+ affineMapList(affineMapList), memorySpace(memorySpace) {}
/// The hash key used for uniquing.
// MemRefs are uniqued based on their shape, element type, affine map
@@ -258,7 +242,7 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
}
ArrayRef<int64_t> getShape() const {
- return ArrayRef<int64_t>(shapeElements, getSubclassData());
+ return ArrayRef<int64_t>(shapeElements, shapeSize);
}
ArrayRef<AffineMap> getAffineMaps() const {
@@ -267,6 +251,8 @@ struct MemRefTypeStorage : public ShapedTypeStorage {
/// An array of integers which stores the shape dimension sizes.
const int64_t *shapeElements;
+ /// The number of shape elements.
+ unsigned shapeSize;
/// The number of affine maps in the 'affineMapList' array.
const unsigned numAffineMaps;
/// List of affine maps in the memref's layout/index map composition.
@@ -324,7 +310,7 @@ struct TupleTypeStorage final
public llvm::TrailingObjects<TupleTypeStorage, Type> {
using KeyTy = TypeRange;
- TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {}
+ TupleTypeStorage(unsigned numTypes) : numElements(numTypes) {}
/// Construction.
static TupleTypeStorage *construct(TypeStorageAllocator &allocator,
@@ -343,12 +329,15 @@ struct TupleTypeStorage final
bool operator==(const KeyTy &key) const { return key == getTypes(); }
/// Return the number of held types.
- unsigned size() const { return getSubclassData(); }
+ unsigned size() const { return numElements; }
/// Return the held types.
ArrayRef<Type> getTypes() const {
return {getTrailingObjects<Type>(), size()};
}
+
+ /// The number of tuple elements.
+ unsigned numElements;
};
} // namespace detail
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fea2cc6648e3..ae2dd909ff59 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -27,9 +27,6 @@ Dialect &Type::getDialect() const {
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
-void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
-
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
@@ -39,6 +36,8 @@ FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
return Base::get(context, Type::Kind::Function, inputs, results);
}
+unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
+
ArrayRef<Type> FunctionType::getInputs() const {
return getImpl()->getInputs();
}
More information about the Mlir-commits
mailing list