[Mlir-commits] [mlir] 1d6a8de - [mlir] Remove the need to define `kindof` on attribute and type classes.
River Riddle
llvmlistbot at llvm.org
Fri Aug 7 13:43:50 PDT 2020
Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: 1d6a8deb41221f73c57b57fe9add180da34af77f
URL: https://github.com/llvm/llvm-project/commit/1d6a8deb41221f73c57b57fe9add180da34af77f
DIFF: https://github.com/llvm/llvm-project/commit/1d6a8deb41221f73c57b57fe9add180da34af77f.diff
LOG: [mlir] Remove the need to define `kindof` on attribute and type classes.
This revision refactors the default definition of the attribute and type `classof` methods to use the TypeID of the concrete class instead of invoking the `kindof` method. The TypeID is already used as part of uniquing, and this allows for removing the need for users to define any of the type casting utilities themselves.
Differential Revision: https://reviews.llvm.org/D85356
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIRAttr.h
flang/include/flang/Optimizer/Dialect/FIRType.h
mlir/docs/Tutorials/DefiningAttributesAndTypes.md
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/include/toy/Dialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Dialect/Quant/QuantTypes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Location.h
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.h b/flang/include/flang/Optimizer/Dialect/FIRAttr.h
index 0b22bc21224c..e9b16909f3fb 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.h
@@ -48,7 +48,6 @@ class ExactTypeAttr
mlir::Type getType() const;
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; }
};
@@ -64,7 +63,6 @@ class SubclassAttr
mlir::Type getType() const;
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; }
};
@@ -82,7 +80,6 @@ class ClosedIntervalAttr
static constexpr llvm::StringRef getAttrName() { return "interval"; }
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() {
return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL;
}
@@ -100,7 +97,6 @@ class UpperBoundAttr
static constexpr llvm::StringRef getAttrName() { return "upper"; }
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() {
return AttributeKind::FIR_OPENCLOSED_INTERVAL;
}
@@ -118,7 +114,6 @@ class LowerBoundAttr
static constexpr llvm::StringRef getAttrName() { return "lower"; }
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() {
return AttributeKind::FIR_CLOSEDOPEN_INTERVAL;
}
@@ -136,7 +131,6 @@ class PointIntervalAttr
static constexpr llvm::StringRef getAttrName() { return "point"; }
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() { return AttributeKind::FIR_POINT; }
};
@@ -157,7 +151,6 @@ class RealAttr
int getFKind() const;
llvm::APFloat getValue() const;
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; }
};
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index b1f1cc85b744..3d3125c97e93 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -114,7 +114,6 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
/// Boilerplate mixin template
template <typename A, unsigned Id>
struct IntrinsicTypeMixin {
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() { return Id; }
};
@@ -194,7 +193,6 @@ class BoxType
public:
using Base::Base;
static BoxType get(mlir::Type eleTy, mlir::AffineMapAttr map = {});
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_BOX; }
mlir::Type getEleTy() const;
mlir::AffineMapAttr getLayoutMap() const;
@@ -211,7 +209,6 @@ class BoxCharType : public mlir::Type::TypeBase<BoxCharType, mlir::Type,
public:
using Base::Base;
static BoxCharType get(mlir::MLIRContext *ctxt, KindTy kind);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_BOXCHAR; }
CharacterType getEleTy() const;
};
@@ -223,7 +220,6 @@ class BoxProcType : public mlir::Type::TypeBase<BoxProcType, mlir::Type,
public:
using Base::Base;
static BoxProcType get(mlir::Type eleTy);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_BOXPROC; }
mlir::Type getEleTy() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
@@ -239,7 +235,6 @@ class DimsType : public mlir::Type::TypeBase<DimsType, mlir::Type,
public:
using Base::Base;
static DimsType get(mlir::MLIRContext *ctx, unsigned rank);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_DIMS; }
/// returns -1 if the rank is unknown
unsigned getRank() const;
@@ -253,7 +248,6 @@ class FieldType : public mlir::Type::TypeBase<FieldType, mlir::Type,
public:
using Base::Base;
static FieldType get(mlir::MLIRContext *ctxt);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_FIELD; }
};
/// The type of a heap pointer. Fortran entities with the ALLOCATABLE attribute
@@ -265,7 +259,6 @@ class HeapType : public mlir::Type::TypeBase<HeapType, mlir::Type,
public:
using Base::Base;
static HeapType get(mlir::Type elementType);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_HEAP; }
mlir::Type getEleTy() const;
@@ -281,7 +274,6 @@ class LenType
public:
using Base::Base;
static LenType get(mlir::MLIRContext *ctxt);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_LEN; }
};
/// The type of entities with the POINTER attribute. These pointers are
@@ -292,7 +284,6 @@ class PointerType : public mlir::Type::TypeBase<PointerType, mlir::Type,
public:
using Base::Base;
static PointerType get(mlir::Type elementType);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_POINTER; }
mlir::Type getEleTy() const;
@@ -307,7 +298,6 @@ class ReferenceType
public:
using Base::Base;
static ReferenceType get(mlir::Type elementType);
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_REFERENCE; }
mlir::Type getEleTy() const;
@@ -361,8 +351,6 @@ class SequenceType : public mlir::Type::TypeBase<SequenceType, mlir::Type,
/// The value `-1` represents an unknown extent for a dimension
static constexpr Extent getUnknownExtent() { return -1; }
- static bool kindof(unsigned kind) { return kind == TypeKind::FIR_SEQUENCE; }
-
static mlir::LogicalResult
verifyConstructionInvariants(mlir::Location loc, const Shape &shape,
mlir::Type eleTy, mlir::AffineMapAttr map);
@@ -379,9 +367,6 @@ class TypeDescType : public mlir::Type::TypeBase<TypeDescType, mlir::Type,
public:
using Base::Base;
static TypeDescType get(mlir::Type ofType);
- static constexpr bool kindof(unsigned kind) {
- return kind == TypeKind::FIR_TYPEDESC;
- }
mlir::Type getOfTy() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
@@ -415,7 +400,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
void finalize(llvm::ArrayRef<TypePair> lenPList,
llvm::ArrayRef<TypePair> typeList);
- static constexpr bool kindof(unsigned kind) { return kind == getId(); }
static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; }
detail::RecordTypeStorage const *uniqueKey() const;
diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 45756e1a31ea..13cd75d91261 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -89,10 +89,6 @@ public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
- /// This static method is used to support type inquiry through isa, cast,
- /// and dyn_cast.
- static bool kindof(unsigned kind) { return kind == MyTypes::Simple; }
-
/// This method is used to get an instance of the 'SimpleType'. Given that
/// this is a parameterless type, it just needs to take the context for
/// uniquing purposes.
@@ -193,10 +189,6 @@ public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
- /// This static method is used to support type inquiry through isa, cast,
- /// and dyn_cast.
- static bool kindof(unsigned kind) { return kind == MyTypes::Complex; }
-
/// This method is used to get an instance of the 'ComplexType'. This method
/// asserts that all of the construction invariants were satisfied. To
/// gracefully handle failed construction, getChecked should be used instead.
@@ -327,10 +319,6 @@ public:
/// Inherit parent constructors.
using Base::Base;
- /// This static method is used to support type inquiry through isa, cast,
- /// and dyn_cast.
- static bool kindof(unsigned kind) { return kind == MyTypes::Recursive; }
-
/// Creates an instance of the Recursive type. This only takes the type name
/// and returns the type with uninitialized body.
static RecursiveType get(MLIRContext *ctx, StringRef name) {
diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index 733e22c5b0a5..cbab1e1cadb0 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -184,10 +184,6 @@ public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
- /// This static method is used to support type inquiry through isa, cast,
- /// and dyn_cast.
- static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; }
-
/// Create an instance of a `StructType` with the given element types. There
/// *must* be at least one element type.
static StructType get(llvm::ArrayRef<mlir::Type> elementTypes) {
diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h
index da6e9553d30b..b69516992401 100644
--- a/mlir/examples/toy/Ch7/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h
@@ -81,10 +81,6 @@ class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
- /// This static method is used to support type inquiry through isa, cast,
- /// and dyn_cast.
- static bool kindof(unsigned kind) { return kind == ToyTypes::Struct; }
-
/// Create an instance of a `StructType` with the given element types. There
/// *must* be atleast one element type.
static StructType get(llvm::ArrayRef<mlir::Type> elementTypes);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 0d3a6f3249b1..32ec77ad63e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -61,7 +61,7 @@ class LLVMIntegerType;
/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR
/// context, have an immutable identifier (for most types except identified
/// structs, the entire type is the identifier) and are thread-safe.
-class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
+class LLVMType : public Type {
public:
enum Kind {
// Keep non-parametric types contiguous in the enum.
@@ -92,7 +92,7 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
};
/// Inherit base constructors.
- using Base::Base;
+ using Type::Type;
/// Support for PointerLikeTypeTraits.
using Type::getAsOpaquePointer;
@@ -101,8 +101,9 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
}
/// Support for isa/cast.
- static bool kindof(unsigned kind) {
- return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE;
+ static bool classof(Type type) {
+ return type.getKind() >= FIRST_NEW_LLVM_TYPE &&
+ type.getKind() <= LAST_NEW_LLVM_TYPE;
}
LLVMDialect &getDialect();
@@ -256,7 +257,6 @@ class LLVMType : public Type::TypeBase<LLVMType, Type, TypeStorage> {
class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
public: \
using Base::Base; \
- static bool kindof(unsigned kind) { return kind == Kind; } \
static ClassName get(MLIRContext *context) { \
return Base::get(context, Kind); \
} \
@@ -290,9 +290,6 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMType,
/// Inherit base constructors.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) { return kind == LLVMType::ArrayType; }
-
/// Gets or creates an instance of LLVM dialect array type containing
/// `numElements` of `elementType`, in the same context as `elementType`.
static LLVMArrayType get(LLVMType elementType, unsigned numElements);
@@ -318,9 +315,6 @@ class LLVMFunctionType
/// Inherit base constructors.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) { return kind == LLVMType::FunctionType; }
-
/// Gets or creates an instance of LLVM dialect function in the same context
/// as the `result` type.
static LLVMFunctionType get(LLVMType result, ArrayRef<LLVMType> arguments,
@@ -354,9 +348,6 @@ class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMType,
/// Inherit base constructor.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) { return kind == LLVMType::IntegerType; }
-
/// Gets or creates an instance of the integer of the specified `bitwidth` in
/// the given context.
static LLVMIntegerType get(MLIRContext *ctx, unsigned bitwidth);
@@ -378,9 +369,6 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMType,
/// Inherit base constructors.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) { return kind == LLVMType::PointerType; }
-
/// Gets or creates an instance of LLVM dialect pointer type pointing to an
/// object of `pointee` type in the given address space. The pointer type is
/// created in the same context as `pointee`.
@@ -427,9 +415,6 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
/// Inherit base construtors.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) { return kind == LLVMType::StructType; }
-
/// Gets or creates an identified struct with the given name in the provided
/// context. Note that unlike llvm::StructType::create, this function will
/// _NOT_ rename a struct in case a struct with the same name already exists
@@ -485,17 +470,13 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
/// LLVM dialect vector type, represents a sequence of elements that can be
/// processed as one, typically in SIMD context. This is a base class for fixed
/// and scalable vectors.
-class LLVMVectorType : public Type::TypeBase<LLVMVectorType, LLVMType,
- detail::LLVMTypeAndSizeStorage> {
+class LLVMVectorType : public LLVMType {
public:
/// Inherit base constructor.
- using Base::Base;
+ using LLVMType::LLVMType;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) {
- return kind == LLVMType::FixedVectorType ||
- kind == LLVMType::ScalableVectorType;
- }
+ /// Support type casting functionality.
+ static bool classof(Type type);
/// Returns the element type of the vector.
LLVMType getElementType();
@@ -517,11 +498,6 @@ class LLVMFixedVectorType
/// Inherit base constructor.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) {
- return kind == LLVMType::FixedVectorType;
- }
-
/// Gets or creates a fixed vector type containing `numElements` of
/// `elementType` in the same context as `elementType`.
static LLVMFixedVectorType get(LLVMType elementType, unsigned numElements);
@@ -544,11 +520,6 @@ class LLVMScalableVectorType
/// Inherit base constructor.
using Base::Base;
- /// Support for isa/cast.
- static bool kindof(unsigned kind) {
- return kind == LLVMType::ScalableVectorType;
- }
-
/// Gets or creates a scalable vector type containing a non-zero multiple of
/// `minNumElements` of `elementType` in the same context as `elementType`.
static LLVMScalableVectorType get(LLVMType elementType,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 9f23e64e7c5d..17e803db8211 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -41,8 +41,6 @@ class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
/// Custom, uniq'ed construction in the MLIRContext.
return Base::get(context, LinalgTypes::Range);
}
- /// Used to implement llvm-style cast.
- static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
};
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index 6062da0a0d5b..689ac49163eb 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -211,9 +211,6 @@ class AnyQuantizedType
public:
using Base::Base;
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; }
-
/// Gets an instance of the type with all parameters specified but not
/// checked.
static AnyQuantizedType get(unsigned flags, Type storageType,
@@ -292,11 +289,6 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) {
- return kind == QuantizationTypes::UniformQuantized;
- }
-
/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
double getScale() const;
@@ -357,11 +349,6 @@ class UniformQuantizedPerAxisType
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) {
- return kind == QuantizationTypes::UniformQuantizedPerAxis;
- }
-
/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
/// by 1. The ith scale corresponds to the ith slice in the
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
index 36344b41d6cb..6788d5952cd4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
@@ -68,10 +68,6 @@ class InterfaceVarABIAttr
/// Returns `spirv::StorageClass`.
Optional<StorageClass> getStorageClass();
- static bool kindof(unsigned kind) {
- return kind == AttrKind::InterfaceVarABI;
- }
-
static LogicalResult verifyConstructionInvariants(Location loc,
IntegerAttr descriptorSet,
IntegerAttr binding,
@@ -123,8 +119,6 @@ class VerCapExtAttr
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();
- static bool kindof(unsigned kind) { return kind == AttrKind::VerCapExt; }
-
static LogicalResult verifyConstructionInvariants(Location loc,
IntegerAttr version,
ArrayAttr capabilities,
@@ -165,8 +159,6 @@ class TargetEnvAttr
/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits();
- static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
-
static LogicalResult verifyConstructionInvariants(Location loc,
VerCapExtAttr triple,
DictionaryAttr limits);
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index b13d3c7892be..a9d120b5d114 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -170,8 +170,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::Array; }
-
static ArrayType get(Type elementType, unsigned elementCount);
/// Returns an array type with the given stride in bytes.
@@ -202,8 +200,6 @@ class ImageType
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::Image; }
-
static ImageType
get(Type elementType, Dim dim,
ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
@@ -243,8 +239,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::Pointer; }
-
static PointerType get(Type pointeeType, StorageClass storageClass);
Type getPointeeType() const;
@@ -264,8 +258,6 @@ class RuntimeArrayType
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; }
-
static RuntimeArrayType get(Type elementType);
/// Returns a runtime array type with the given stride in bytes.
@@ -318,8 +310,6 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
}
};
- static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
-
/// Construct a StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
@@ -385,10 +375,6 @@ class CooperativeMatrixNVType
public:
using Base::Base;
- static bool kindof(unsigned kind) {
- return kind == TypeKind::CooperativeMatrix;
- }
-
static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope,
unsigned rows, unsigned columns);
Type getElementType() const;
@@ -412,8 +398,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; }
-
static MatrixType get(Type columnType, uint32_t columnCount);
static MatrixType getChecked(Type columnType, uint32_t columnCount,
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index ca1e0668d070..3168e87b3df0 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -49,11 +49,6 @@ class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
static ComponentType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Component);
}
-
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) {
- return kind == ShapeTypes::Kind::Component;
- }
};
/// The element type of the shaped type.
@@ -64,11 +59,6 @@ class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
static ElementType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Element);
}
-
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) {
- return kind == ShapeTypes::Kind::Element;
- }
};
/// The shape descriptor type represents rank and dimension sizes.
@@ -79,9 +69,6 @@ class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
static ShapeType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Shape);
}
-
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) { return kind == ShapeTypes::Kind::Shape; }
};
/// The type of a single dimension.
@@ -92,9 +79,6 @@ class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
static SizeType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Size);
}
-
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) { return kind == ShapeTypes::Kind::Size; }
};
/// The ValueShape represents a (potentially unknown) runtime value and shape.
@@ -106,11 +90,6 @@ class ValueShapeType
static ValueShapeType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::ValueShape);
}
-
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) {
- return kind == ShapeTypes::Kind::ValueShape;
- }
};
/// The Witness represents a runtime constraint, to be used as shape related
@@ -122,11 +101,6 @@ class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
static WitnessType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Witness);
}
-
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) {
- return kind == ShapeTypes::Kind::Witness;
- }
};
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 79ce1dd2db95..31e6285164ab 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -36,7 +36,7 @@ class AbstractAttribute {
/// This method is used by Dialect objects when they register the list of
/// attributes they contain.
template <typename T> static AbstractAttribute get(Dialect &dialect) {
- return AbstractAttribute(dialect, T::getInterfaceMap());
+ return AbstractAttribute(dialect, T::getInterfaceMap(), T::getTypeID());
}
/// Return the dialect this attribute was registered to.
@@ -49,15 +49,23 @@ class AbstractAttribute {
return interfaceMap.lookup<T>();
}
+ /// Return the unique identifier representing the concrete attribute class.
+ TypeID getTypeID() const { return typeID; }
+
private:
- AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap)
- : dialect(dialect), interfaceMap(std::move(interfaceMap)) {}
+ AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+ TypeID typeID)
+ : dialect(dialect), interfaceMap(std::move(interfaceMap)),
+ typeID(typeID) {}
/// This is the dialect that this attribute was registered to.
Dialect &dialect;
/// This is a collection of the interfaces registered to this attribute.
detail::InterfaceMap interfaceMap;
+
+ /// The unique identifier of the derived Attribute class.
+ TypeID typeID;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 5ecf5763ecd4..a57adb315bc3 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -97,6 +97,10 @@ class Attribute {
/// Return the classification for this attribute.
unsigned getKind() const { return impl->getKind(); }
+ /// Return a unique identifier for the concrete attribute type. This is used
+ /// to support dynamic type casting.
+ TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
+
/// Return the type of this attribute.
Type getType() const;
@@ -231,11 +235,6 @@ class AffineMapAttr
static AffineMapAttr get(AffineMap value);
AffineMap getValue() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::AffineMap;
- }
};
//===----------------------------------------------------------------------===//
@@ -262,11 +261,6 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
size_t size() const { return getValue().size(); }
bool empty() const { return size() == 0; }
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::Array;
- }
-
private:
/// Class for underlying value iterator support.
template <typename AttrTy>
@@ -357,11 +351,6 @@ class DictionaryAttr
/// Requires: uniquely named attributes.
static bool sortInPlace(SmallVectorImpl<NamedAttribute> &array);
- /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::Dictionary;
- }
-
private:
/// Return empty dictionary.
static DictionaryAttr getEmpty(MLIRContext *context);
@@ -394,11 +383,6 @@ class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
double getValueAsDouble() const;
static double getValueAsDouble(APFloat val);
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::Float;
- }
-
/// Verify the construction invariants for a double value.
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
double value);
@@ -432,11 +416,6 @@ class IntegerAttr
/// an unsigned integer.
uint64_t getUInt() const;
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::Integer;
- }
-
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
int64_t value);
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
@@ -480,11 +459,6 @@ class IntegerSetAttr
static IntegerSetAttr get(IntegerSet value);
IntegerSet getValue() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::IntegerSet;
- }
};
//===----------------------------------------------------------------------===//
@@ -520,10 +494,6 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
Identifier dialect,
StringRef attrData,
Type type);
-
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::Opaque;
- }
};
//===----------------------------------------------------------------------===//
@@ -543,11 +513,6 @@ class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
static StringAttr get(StringRef bytes, Type type);
StringRef getValue() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::String;
- }
};
//===----------------------------------------------------------------------===//
@@ -584,11 +549,6 @@ class SymbolRefAttr
/// Returns the set of nested references representing the path to the symbol
/// nested under the root reference.
ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::SymbolRef;
- }
};
/// A symbol reference with a reference path containing a single element. This
@@ -630,9 +590,6 @@ class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
static TypeAttr get(Type value);
Type getValue() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
};
//===----------------------------------------------------------------------===//
@@ -647,8 +604,6 @@ class UnitAttr
using Base::Base;
static UnitAttr get(MLIRContext *context);
-
- static bool kindof(unsigned kind) { return kind == StandardAttributes::Unit; }
};
//===----------------------------------------------------------------------===//
@@ -1229,11 +1184,6 @@ class DenseStringElementsAttr
public:
using Base::Base;
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::DenseStringElements;
- }
-
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
@@ -1252,11 +1202,6 @@ class DenseIntOrFPElementsAttr
public:
using Base::Base;
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::DenseIntOrFPElements;
- }
-
protected:
friend DenseElementsAttr;
@@ -1394,11 +1339,6 @@ class OpaqueElementsAttr
/// Returns dialect associated with this opaque constant.
Dialect *getDialect() const;
-
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::OpaqueElements;
- }
};
/// An attribute that represents a reference to a sparse vector or tensor
@@ -1460,11 +1400,6 @@ class SparseElementsAttr
/// expected to refer to a valid element.
Attribute getValue(ArrayRef<uint64_t> index) const;
- /// Method for support type inquiry through isa, cast and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::SparseElements;
- }
-
private:
/// Get a zero APFloat for the given sparse attribute.
APFloat getZeroAPFloat() const;
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index c1b6dd016db7..5bbf003a4271 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -120,11 +120,6 @@ class CallSiteLoc
/// The caller's location.
Location getCaller() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::CallSiteLocation;
- }
};
/// Represents a location derived from a file/line/column location. The column
@@ -146,11 +141,6 @@ class FileLineColLoc
unsigned getLine() const;
unsigned getColumn() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::FileLineColLocation;
- }
};
/// Represents a value composed of multiple source constructs, with an optional
@@ -174,11 +164,6 @@ class FusedLoc : public Attribute::AttrBase<FusedLoc, LocationAttr,
/// Returns the optional metadata attached to this fused location. Given that
/// it is optional, the return value may be a null node.
Attribute getMetadata() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::FusedLocation;
- }
};
/// Represents an identity name attached to a child location.
@@ -199,11 +184,6 @@ class NameLoc : public Attribute::AttrBase<NameLoc, LocationAttr,
/// Return the child location.
Location getChildLoc() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::NameLocation;
- }
};
/// Represents an unknown location. This is always a singleton for a given
@@ -215,11 +195,6 @@ class UnknownLoc
/// Get an instance of the UnknownLoc.
static Location get(MLIRContext *context);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::UnknownLocation;
- }
};
/// Represents a location that is external to MLIR. Contains a pointer to some
@@ -283,11 +258,6 @@ class OpaqueLoc : public Attribute::AttrBase<OpaqueLoc, LocationAttr,
/// Returns a fallback location.
Location getFallbackLocation() const;
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind == StandardAttributes::OpaqueLocation;
- }
-
private:
static Location get(uintptr_t underlyingLocation, TypeID typeID,
Location fallbackLocation);
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 406598c9d061..ff9af5bcdeeb 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -92,8 +92,6 @@ class ComplexType
Type elementType);
Type getElementType();
-
- static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
};
//===----------------------------------------------------------------------===//
@@ -109,9 +107,6 @@ class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
/// Get an instance of the IndexType.
static IndexType get(MLIRContext *context);
- /// Support method to enable LLVM-style type casting.
- static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
-
/// Storage bit width used for IndexType by internal compiler data structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
};
@@ -177,9 +172,6 @@ class IntegerType
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
-
/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096;
};
@@ -208,12 +200,6 @@ class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
return get(StandardTypes::F64, ctx);
}
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) {
- return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
- kind <= StandardTypes::LAST_FLOATING_POINT_TYPE;
- }
-
/// Return the bitwidth of this float type.
unsigned getWidth();
@@ -233,8 +219,6 @@ class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
/// Get an instance of the NoneType.
static NoneType get(MLIRContext *context);
-
- static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
};
//===----------------------------------------------------------------------===//
@@ -361,9 +345,6 @@ class VectorType
}
ArrayRef<int64_t> getShape() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
};
//===----------------------------------------------------------------------===//
@@ -422,10 +403,6 @@ class RankedTensorType
Type elementType);
ArrayRef<int64_t> getShape() const;
-
- static bool kindof(unsigned kind) {
- return kind == StandardTypes::RankedTensor;
- }
};
//===----------------------------------------------------------------------===//
@@ -454,10 +431,6 @@ class UnrankedTensorType
Type elementType);
ArrayRef<int64_t> getShape() const { return llvm::None; }
-
- static bool kindof(unsigned kind) {
- return kind == StandardTypes::UnrankedTensor;
- }
};
//===----------------------------------------------------------------------===//
@@ -568,8 +541,6 @@ class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
return ShapedType::kDynamicStrideOrOffset;
}
- static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
-
private:
/// Get or create a new MemRefType defined by the arguments. If the resulting
/// type would be ill-formed, return nullptr. If the location is provided,
@@ -611,9 +582,6 @@ class UnrankedMemRefType
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const;
- static bool kindof(unsigned kind) {
- return kind == StandardTypes::UnrankedMemRef;
- }
};
//===----------------------------------------------------------------------===//
@@ -659,8 +627,6 @@ class TupleType
assert(index < size() && "invalid index for tuple type");
return getTypes()[index];
}
-
- static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 707d4d7f9ad6..48026c219082 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -68,12 +68,12 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Return a unique identifier for the concrete type.
static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
- /// Provide a default implementation of 'classof' that invokes a 'kindof'
- /// method on the concrete type.
+ /// Provide an implementation of 'classof' that compares the type id of the
+ /// provided value with that of the concrete type.
template <typename T> static bool classof(T val) {
static_assert(std::is_convertible<ConcreteT, T>::value,
"casting from a non-convertible type");
- return ConcreteT::kindof(val.getKind());
+ return val.getTypeID() == getTypeID();
}
/// Returns an interface map for the interfaces registered to this storage
@@ -107,8 +107,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
- template <typename... Args>
- LogicalResult mutate(Args &&...args) {
+ template <typename... Args> LogicalResult mutate(Args &&...args) {
return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
std::forward<Args>(args)...);
}
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index c26aec6411c0..aa2daefd26c4 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -35,7 +35,7 @@ class AbstractType {
/// This method is used by Dialect objects when they register the list of
/// types they contain.
template <typename T> static AbstractType get(Dialect &dialect) {
- return AbstractType(dialect, T::getInterfaceMap());
+ return AbstractType(dialect, T::getInterfaceMap(), T::getTypeID());
}
/// Return the dialect this type was registered to.
@@ -48,15 +48,23 @@ class AbstractType {
return interfaceMap.lookup<T>();
}
+ /// Return the unique identifier representing the concrete type class.
+ TypeID getTypeID() const { return typeID; }
+
private:
- AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap)
- : dialect(dialect), interfaceMap(std::move(interfaceMap)) {}
+ AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
+ TypeID typeID)
+ : dialect(dialect), interfaceMap(std::move(interfaceMap)),
+ typeID(typeID) {}
/// This is the dialect that this type was registered to.
Dialect &dialect;
/// This is a collection of the interfaces registered to this type.
detail::InterfaceMap interfaceMap;
+
+ /// The unique identifier of the derived Type class.
+ TypeID typeID;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index bd0ea4bbd5dc..8101690daeb6 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -44,11 +44,6 @@ struct OpaqueTypeStorage;
///
/// Derived type classes are expected to implement several required
/// implementation hooks:
-/// * Required:
-/// - static bool kindof(unsigned kind);
-/// * Returns if the provided type kind corresponds to an instance of the
-/// current type. Used for isa/dyn_cast casting functionality.
-///
/// * Optional:
/// - static LogicalResult verifyConstructionInvariants(Location loc,
/// Args... args)
@@ -137,6 +132,10 @@ class Type {
// Support type casting Type to itself.
static bool classof(Type) { return true; }
+ /// Return a unique identifier for the concrete type. This is used to support
+ /// dynamic type casting.
+ TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
+
/// Return the classification for this type.
unsigned getKind() const;
@@ -262,7 +261,6 @@ class FunctionType
// Input types.
unsigned getNumInputs() const;
-
Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const;
@@ -270,9 +268,6 @@ class FunctionType
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type> getResults() const;
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool kindof(unsigned kind) { return kind == Kind::Function; }
};
//===----------------------------------------------------------------------===//
@@ -307,8 +302,6 @@ class OpaqueType
static LogicalResult verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef typeData);
-
- static bool kindof(unsigned kind) { return kind == Kind::Opaque; }
};
// Make Type hashable.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index f8cadeb0c40f..ee429c1f73e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -322,6 +322,11 @@ ArrayRef<LLVMType> LLVMStructType::getBody() {
//===----------------------------------------------------------------------===//
// Vector types.
+/// Support type casting functionality.
+bool LLVMVectorType::classof(Type type) {
+ return type.isa<LLVMFixedVectorType, LLVMScalableVectorType>();
+}
+
LLVMType LLVMVectorType::getElementType() {
// Both derived classes share the implementation type.
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
@@ -331,7 +336,7 @@ llvm::ElementCount LLVMVectorType::getElementCount() {
// Both derived classes share the implementation type.
return llvm::ElementCount(
static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->numElements,
- this->isa<LLVMScalableVectorType>());
+ isa<LLVMScalableVectorType>());
}
LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 42c4d4855e50..ace355319659 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -676,7 +676,6 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
- assert(kindof(kind) && "Not a FP kind.");
switch (kind) {
case StandardTypes::BF16:
return context->getImpl().bf16Ty;
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 9e2c297c6a89..1df165591672 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -26,10 +26,6 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
TestTypeInterface::Trait> {
using Base::Base;
- static bool kindof(unsigned kind) {
- return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE;
- }
-
static TestType get(MLIRContext *context) {
return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
}
@@ -76,10 +72,6 @@ class TestRecursiveType
public:
using Base::Base;
- static bool kindof(unsigned kind) {
- return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1;
- }
-
static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
name);
More information about the Mlir-commits
mailing list