[flang-commits] [flang] 250f43d - [mlir] Remove the use of "kinds" from Attributes and Types
River Riddle via flang-commits
flang-commits at lists.llvm.org
Tue Aug 18 16:20:35 PDT 2020
Author: River Riddle
Date: 2020-08-18T16:20:14-07:00
New Revision: 250f43d3ecc8d6a3780c9aa2e3770c0193a28850
URL: https://github.com/llvm/llvm-project/commit/250f43d3ecc8d6a3780c9aa2e3770c0193a28850
DIFF: https://github.com/llvm/llvm-project/commit/250f43d3ecc8d6a3780c9aa2e3770c0193a28850.diff
LOG: [mlir] Remove the use of "kinds" from Attributes and Types
This greatly simplifies a large portion of the underlying infrastructure and allows lookups of singleton classes to be much more efficient and always thread-safe (no locking). As a result, the dialect symbol registry has been removed, as it is no longer necessary.
For users broken by this change, an alert was sent out (https://llvm.discourse.group/t/removing-kinds-from-attributes-and-types) with advice that avoids most of the breakage. If the advice in that alert was followed, all that should be necessary is removing the kind passed to the ::get methods.
Differential Revision: https://reviews.llvm.org/D86121
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIRAttr.h
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Optimizer/Dialect/FIRAttr.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
mlir/docs/Tutorials/Toy/Ch-7.md
mlir/examples/toy/Ch7/include/toy/Dialect.h
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Dialect/Quant/QuantTypes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/StandardTypes.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/IR/Types.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
mlir/lib/Dialect/SDBM/SDBMDialect.cpp
mlir/lib/Dialect/SDBM/SDBMExpr.cpp
mlir/lib/Dialect/SDBM/SDBMExprDetail.h
mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/IR/AffineExpr.cpp
mlir/lib/IR/AffineExprDetail.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Location.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Support/StorageUniquer.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/lib/IR/TestTypes.cpp
Removed:
mlir/include/mlir/IR/DialectSymbolRegistry.def
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.h b/flang/include/flang/Optimizer/Dialect/FIRAttr.h
index e9b16909f3fb..e0008161f993 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.h
@@ -25,17 +25,6 @@ struct RealAttributeStorage;
struct TypeAttributeStorage;
} // namespace detail
-enum AttributeKind {
- FIR_ATTR = mlir::Attribute::FIRST_FIR_ATTR,
- FIR_EXACTTYPE, // instance_of, precise type relation
- FIR_SUBCLASS, // subsumed_by, is-a (subclass) relation
- FIR_POINT,
- FIR_CLOSEDCLOSED_INTERVAL,
- FIR_OPENCLOSED_INTERVAL,
- FIR_CLOSEDOPEN_INTERVAL,
- FIR_REAL_ATTR,
-};
-
class ExactTypeAttr
: public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
detail::TypeAttributeStorage> {
@@ -47,8 +36,6 @@ class ExactTypeAttr
static ExactTypeAttr get(mlir::Type value);
mlir::Type getType() const;
-
- static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; }
};
class SubclassAttr
@@ -62,8 +49,6 @@ class SubclassAttr
static SubclassAttr get(mlir::Type value);
mlir::Type getType() const;
-
- static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; }
};
// Attributes for building SELECT CASE multiway branches
@@ -80,9 +65,6 @@ class ClosedIntervalAttr
static constexpr llvm::StringRef getAttrName() { return "interval"; }
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
- static constexpr unsigned getId() {
- return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL;
- }
};
/// An upper bound is an open interval (including the bound value) as given as
@@ -97,9 +79,6 @@ class UpperBoundAttr
static constexpr llvm::StringRef getAttrName() { return "upper"; }
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
- static constexpr unsigned getId() {
- return AttributeKind::FIR_OPENCLOSED_INTERVAL;
- }
};
/// A lower bound is an open interval (including the bound value) as given as
@@ -114,9 +93,6 @@ class LowerBoundAttr
static constexpr llvm::StringRef getAttrName() { return "lower"; }
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
- static constexpr unsigned getId() {
- return AttributeKind::FIR_CLOSEDOPEN_INTERVAL;
- }
};
/// A pointer interval is a closed interval as given as an ssa-value. The
@@ -131,7 +107,6 @@ class PointIntervalAttr
static constexpr llvm::StringRef getAttrName() { return "point"; }
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
- static constexpr unsigned getId() { return AttributeKind::FIR_POINT; }
};
/// A real attribute is used to workaround MLIR's default parsing of a real
@@ -150,8 +125,6 @@ class RealAttr
int getFKind() const;
llvm::APFloat getValue() const;
-
- static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; }
};
mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 3d3125c97e93..6d2aec25fa8f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -54,29 +54,6 @@ struct SequenceTypeStorage;
struct TypeDescTypeStorage;
} // namespace detail
-/// Integral identifier for all the types comprising the FIR type system
-enum TypeKind {
- // The enum starts at the range reserved for this dialect.
- FIR_TYPE = mlir::Type::FIRST_FIR_TYPE,
- FIR_BOX, // (static) descriptor
- FIR_BOXCHAR, // CHARACTER pointer and length
- FIR_BOXPROC, // procedure with host association
- FIR_CHARACTER, // intrinsic type
- FIR_COMPLEX, // intrinsic type
- FIR_DERIVED, // derived
- FIR_DIMS,
- FIR_FIELD,
- FIR_HEAP,
- FIR_INT, // intrinsic type
- FIR_LEN,
- FIR_LOGICAL, // intrinsic type
- FIR_POINTER, // POINTER attr
- FIR_REAL, // intrinsic type
- FIR_REFERENCE,
- FIR_SEQUENCE, // DIMENSION attr
- FIR_TYPEDESC,
-};
-
// These isa_ routines follow the precedent of llvm::isa_or_null<>
/// Is `t` any of the FIR dialect types?
@@ -111,12 +88,6 @@ bool isa_aggregate(mlir::Type t);
/// not a memory reference type, then returns a null `Type`.
mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
-/// Boilerplate mixin template
-template <typename A, unsigned Id>
-struct IntrinsicTypeMixin {
- static constexpr unsigned getId() { return Id; }
-};
-
// Intrinsic types
/// Model of the Fortran CHARACTER intrinsic type, including the KIND type
@@ -124,8 +95,7 @@ struct IntrinsicTypeMixin {
/// is thus the type of a single character value.
class CharacterType
: public mlir::Type::TypeBase<CharacterType, mlir::Type,
- detail::CharacterTypeStorage>,
- public IntrinsicTypeMixin<CharacterType, TypeKind::FIR_CHARACTER> {
+ detail::CharacterTypeStorage> {
public:
using Base::Base;
static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -136,8 +106,7 @@ class CharacterType
/// parameter. COMPLEX is a floating point type with a real and imaginary
/// member.
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
- detail::CplxTypeStorage>,
- public IntrinsicTypeMixin<CplxType, TypeKind::FIR_COMPLEX> {
+ detail::CplxTypeStorage> {
public:
using Base::Base;
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -151,8 +120,7 @@ class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
/// parameter.
class IntType
- : public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage>,
- public IntrinsicTypeMixin<IntType, TypeKind::FIR_INT> {
+ : public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
public:
using Base::Base;
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -163,8 +131,7 @@ class IntType
/// parameter.
class LogicalType
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
- detail::LogicalTypeStorage>,
- public IntrinsicTypeMixin<LogicalType, TypeKind::FIR_LOGICAL> {
+ detail::LogicalTypeStorage> {
public:
using Base::Base;
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -174,8 +141,7 @@ class LogicalType
/// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
/// KIND type parameter.
class RealType : public mlir::Type::TypeBase<RealType, mlir::Type,
- detail::RealTypeStorage>,
- public IntrinsicTypeMixin<RealType, TypeKind::FIR_REAL> {
+ detail::RealTypeStorage> {
public:
using Base::Base;
static RealType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -400,7 +366,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name);
void finalize(llvm::ArrayRef<TypePair> lenPList,
llvm::ArrayRef<TypePair> typeList);
- static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; }
detail::RecordTypeStorage const *uniqueKey() const;
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index 09780d306fb8..0a219d1cab74 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -74,13 +74,13 @@ struct TypeAttributeStorage : public mlir::AttributeStorage {
} // namespace detail
ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
- return Base::get(value.getContext(), FIR_EXACTTYPE, value);
+ return Base::get(value.getContext(), value);
}
mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
SubclassAttr SubclassAttr::get(mlir::Type value) {
- return Base::get(value.getContext(), FIR_SUBCLASS, value);
+ return Base::get(value.getContext(), value);
}
mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
@@ -88,26 +88,26 @@ mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
using AttributeUniquer = mlir::detail::AttributeUniquer;
ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
- return AttributeUniquer::get<ClosedIntervalAttr>(ctxt, getId());
+ return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
}
UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
- return AttributeUniquer::get<UpperBoundAttr>(ctxt, getId());
+ return AttributeUniquer::get<UpperBoundAttr>(ctxt);
}
LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
- return AttributeUniquer::get<LowerBoundAttr>(ctxt, getId());
+ return AttributeUniquer::get<LowerBoundAttr>(ctxt);
}
PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
- return AttributeUniquer::get<PointIntervalAttr>(ctxt, getId());
+ return AttributeUniquer::get<PointIntervalAttr>(ctxt);
}
// RealAttr
RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
const RealAttr::ValueType &key) {
- return Base::get(ctxt, getId(), key);
+ return Base::get(ctxt, key);
}
int RealAttr::getFKind() const { return getImpl()->getFKind(); }
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index c29412b86944..e4c548035d4a 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -824,13 +824,11 @@ bool inbounds(A v, B lb, B ub) {
}
bool isa_fir_type(mlir::Type t) {
- return inbounds(t.getKind(), mlir::Type::FIRST_FIR_TYPE,
- mlir::Type::LAST_FIR_TYPE);
+ return llvm::isa<FIROpsDialect>(t.getDialect());
}
bool isa_std_type(mlir::Type t) {
- return inbounds(t.getKind(), mlir::Type::FIRST_STANDARD_TYPE,
- mlir::Type::LAST_STANDARD_TYPE);
+ return t.getDialect().getNamespace().empty();
}
bool isa_fir_or_std_type(mlir::Type t) {
@@ -868,7 +866,7 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
// CHARACTER
CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) {
- return Base::get(ctxt, FIR_CHARACTER, kind);
+ return Base::get(ctxt, kind);
}
int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
@@ -876,7 +874,7 @@ int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); }
// Dims
DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) {
- return Base::get(ctxt, FIR_DIMS, rank);
+ return Base::get(ctxt, rank);
}
unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
@@ -884,19 +882,19 @@ unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); }
// Field
FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) {
- return Base::get(ctxt, FIR_FIELD, 0);
+ return Base::get(ctxt, 0);
}
// Len
LenType fir::LenType::get(mlir::MLIRContext *ctxt) {
- return Base::get(ctxt, FIR_LEN, 0);
+ return Base::get(ctxt, 0);
}
// LOGICAL
LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) {
- return Base::get(ctxt, FIR_LOGICAL, kind);
+ return Base::get(ctxt, kind);
}
int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
@@ -904,7 +902,7 @@ int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); }
// INTEGER
IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) {
- return Base::get(ctxt, FIR_INT, kind);
+ return Base::get(ctxt, kind);
}
int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
@@ -912,7 +910,7 @@ int fir::IntType::getFKind() const { return getImpl()->getFKind(); }
// COMPLEX
CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) {
- return Base::get(ctxt, FIR_COMPLEX, kind);
+ return Base::get(ctxt, kind);
}
mlir::Type fir::CplxType::getElementType() const {
@@ -924,7 +922,7 @@ KindTy fir::CplxType::getFKind() const { return getImpl()->getFKind(); }
// REAL
RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
- return Base::get(ctxt, FIR_REAL, kind);
+ return Base::get(ctxt, kind);
}
int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
@@ -932,7 +930,7 @@ int fir::RealType::getFKind() const { return getImpl()->getFKind(); }
// Box<T>
BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) {
- return Base::get(elementType.getContext(), FIR_BOX, elementType, map);
+ return Base::get(elementType.getContext(), elementType, map);
}
mlir::Type fir::BoxType::getEleTy() const {
@@ -953,7 +951,7 @@ fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
// BoxChar<C>
BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) {
- return Base::get(ctxt, FIR_BOXCHAR, kind);
+ return Base::get(ctxt, kind);
}
CharacterType fir::BoxCharType::getEleTy() const {
@@ -963,7 +961,7 @@ CharacterType fir::BoxCharType::getEleTy() const {
// BoxProc<T>
BoxProcType fir::BoxProcType::get(mlir::Type elementType) {
- return Base::get(elementType.getContext(), FIR_BOXPROC, elementType);
+ return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::BoxProcType::getEleTy() const {
@@ -984,7 +982,7 @@ fir::BoxProcType::verifyConstructionInvariants(mlir::Location loc,
// Reference<T>
ReferenceType fir::ReferenceType::get(mlir::Type elementType) {
- return Base::get(elementType.getContext(), FIR_REFERENCE, elementType);
+ return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::ReferenceType::getEleTy() const {
@@ -1005,7 +1003,7 @@ fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
PointerType fir::PointerType::get(mlir::Type elementType) {
assert(singleIndirectionLevel(elementType) && "invalid element type");
- return Base::get(elementType.getContext(), FIR_POINTER, elementType);
+ return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::PointerType::getEleTy() const {
@@ -1033,7 +1031,7 @@ fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
HeapType fir::HeapType::get(mlir::Type elementType) {
assert(singleIndirectionLevel(elementType) && "invalid element type");
- return Base::get(elementType.getContext(), FIR_HEAP, elementType);
+ return Base::get(elementType.getContext(), elementType);
}
mlir::Type fir::HeapType::getEleTy() const {
@@ -1054,7 +1052,7 @@ fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType,
mlir::AffineMapAttr map) {
auto *ctxt = elementType.getContext();
- return Base::get(ctxt, FIR_SEQUENCE, shape, elementType, map);
+ return Base::get(ctxt, shape, elementType, map);
}
mlir::Type fir::SequenceType::getEleTy() const {
@@ -1136,7 +1134,7 @@ llvm::hash_code fir::hash_value(const SequenceType::Shape &sh) {
/// This type captures a Fortran "derived type"
RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) {
- return Base::get(ctxt, FIR_DERIVED, name);
+ return Base::get(ctxt, name);
}
void fir::RecordType::finalize(llvm::ArrayRef<TypePair> lenPList,
@@ -1179,7 +1177,7 @@ mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
assert(!ofType.isa<ReferenceType>());
- return Base::get(ofType.getContext(), FIR_TYPEDESC, ofType);
+ return Base::get(ofType.getContext(), ofType);
}
mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
@@ -1222,9 +1220,7 @@ void fir::verifyIntegralType(mlir::Type type) {
void fir::printFirType(FIROpsDialect *, mlir::Type ty,
mlir::DialectAsmPrinter &p) {
auto &os = p.getStream();
- switch (ty.getKind()) {
- case fir::FIR_BOX: {
- auto type = ty.cast<BoxType>();
+ if (auto type = ty.dyn_cast<BoxType>()) {
os << "box<";
p.printType(type.getEleTy());
if (auto map = type.getLayoutMap()) {
@@ -1232,24 +1228,28 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
p.printAttribute(map);
}
os << '>';
- } break;
- case fir::FIR_BOXCHAR: {
- auto type = ty.cast<BoxCharType>().getEleTy();
- os << "boxchar<" << type.cast<fir::CharacterType>().getFKind() << '>';
- } break;
- case fir::FIR_BOXPROC:
+ return;
+ }
+ if (auto type = ty.dyn_cast<BoxCharType>()) {
+ os << "boxchar<" << type.getEleTy().cast<fir::CharacterType>().getFKind()
+ << '>';
+ return;
+ }
+ if (auto type = ty.dyn_cast<BoxProcType>()) {
os << "boxproc<";
- p.printType(ty.cast<BoxProcType>().getEleTy());
+ p.printType(type.getEleTy());
os << '>';
- break;
- case fir::FIR_CHARACTER: // intrinsic
- os << "char<" << ty.cast<CharacterType>().getFKind() << '>';
- break;
- case fir::FIR_COMPLEX: // intrinsic
- os << "complex<" << ty.cast<CplxType>().getFKind() << '>';
- break;
- case fir::FIR_DERIVED: { // derived
- auto type = ty.cast<fir::RecordType>();
+ return;
+ }
+ if (auto type = ty.dyn_cast<CharacterType>()) {
+ os << "char<" << type.getFKind() << '>';
+ return;
+ }
+ if (auto type = ty.dyn_cast<CplxType>()) {
+ os << "complex<" << type.getFKind() << '>';
+ return;
+ }
+ if (auto type = ty.dyn_cast<RecordType>()) {
os << "type<" << type.getName();
if (!recordTypeVisited.count(type.uniqueKey())) {
recordTypeVisited.insert(type.uniqueKey());
@@ -1274,43 +1274,52 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
recordTypeVisited.erase(type.uniqueKey());
}
os << '>';
- } break;
- case fir::FIR_DIMS:
- os << "dims<" << ty.cast<DimsType>().getRank() << '>';
- break;
- case fir::FIR_FIELD:
+ return;
+ }
+ if (auto type = ty.dyn_cast<DimsType>()) {
+ os << "dims<" << type.getRank() << '>';
+ return;
+ }
+ if (ty.isa<FieldType>()) {
os << "field";
- break;
- case fir::FIR_HEAP:
+ return;
+ }
+ if (auto type = ty.dyn_cast<HeapType>()) {
os << "heap<";
- p.printType(ty.cast<HeapType>().getEleTy());
+ p.printType(type.getEleTy());
os << '>';
- break;
- case fir::FIR_INT: // intrinsic
- os << "int<" << ty.cast<fir::IntType>().getFKind() << '>';
- break;
- case fir::FIR_LEN:
+ return;
+ }
+ if (auto type = ty.dyn_cast<fir::IntType>()) {
+ os << "int<" << type.getFKind() << '>';
+ return;
+ }
+ if (auto type = ty.dyn_cast<LenType>()) {
os << "len";
- break;
- case fir::FIR_LOGICAL: // intrinsic
- os << "logical<" << ty.cast<LogicalType>().getFKind() << '>';
- break;
- case fir::FIR_POINTER:
+ return;
+ }
+ if (auto type = ty.dyn_cast<LogicalType>()) {
+ os << "logical<" << type.getFKind() << '>';
+ return;
+ }
+ if (auto type = ty.dyn_cast<PointerType>()) {
os << "ptr<";
- p.printType(ty.cast<PointerType>().getEleTy());
+ p.printType(type.getEleTy());
os << '>';
- break;
- case fir::FIR_REAL: // intrinsic
- os << "real<" << ty.cast<fir::RealType>().getFKind() << '>';
- break;
- case fir::FIR_REFERENCE:
+ return;
+ }
+ if (auto type = ty.dyn_cast<fir::RealType>()) {
+ os << "real<" << type.getFKind() << '>';
+ return;
+ }
+ if (auto type = ty.dyn_cast<ReferenceType>()) {
os << "ref<";
- p.printType(ty.cast<ReferenceType>().getEleTy());
+ p.printType(type.getEleTy());
os << '>';
- break;
- case fir::FIR_SEQUENCE: {
+ return;
+ }
+ if (auto type = ty.dyn_cast<SequenceType>()) {
os << "array";
- auto type = ty.cast<SequenceType>();
auto shape = type.getShape();
if (shape.size()) {
printBounds(os, shape);
@@ -1323,11 +1332,12 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
map.print(os);
}
os << '>';
- } break;
- case fir::FIR_TYPEDESC:
+ return;
+ }
+ if (auto type = ty.dyn_cast<TypeDescType>()) {
os << "tdesc<";
- p.printType(ty.cast<TypeDescType>().getOfTy());
+ p.printType(type.getOfTy());
os << '>';
- break;
+ return;
}
}
diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index cbab1e1cadb0..c20b8d95617d 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -190,11 +190,10 @@ public:
assert(!elementTypes.empty() && "expected at least 1 element type");
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
- // of this type. The first two parameters are the context to unique in and
- // the kind of the type. The parameters after the type kind are forwarded to
- // the storage instance.
+ // of this type. The first parameter is the context to unique in. The
+ // remaining parameters are forwarded to the storage instance.
mlir::MLIRContext *ctx = elementTypes.front().getContext();
- return Base::get(ctx, ToyTypes::Struct, elementTypes);
+ return Base::get(ctx, elementTypes);
}
/// Returns the element types of this struct type.
diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h
index b69516992401..4eceb422efa6 100644
--- a/mlir/examples/toy/Ch7/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h
@@ -63,13 +63,6 @@ class ToyDialect : public mlir::Dialect {
// Toy Types
//===----------------------------------------------------------------------===//
-/// Create a local enumeration with all of the types that are defined by Toy.
-namespace ToyTypes {
-enum Types {
- Struct = mlir::Type::FIRST_TOY_TYPE,
-};
-} // end namespace ToyTypes
-
/// This class defines the Toy struct type. It represents a collection of
/// element types. All derived types in MLIR must inherit from the CRTP class
/// 'Type::TypeBase'. It takes as template parameters the concrete type
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index e233a5549934..04c796ce6d0b 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -474,11 +474,10 @@ StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
assert(!elementTypes.empty() && "expected at least 1 element type");
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
- // of this type. The first two parameters are the context to unique in and the
- // kind of the type. The parameters after the type kind are forwarded to the
- // storage instance.
+ // of this type. The first parameter is the context to unique in. The
+ // remaining parameters are forwarded to the storage instance.
mlir::MLIRContext *ctx = elementTypes.front().getContext();
- return Base::get(ctx, ToyTypes::Struct, elementTypes);
+ return Base::get(ctx, elementTypes);
}
/// Returns the element types of this struct type.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index b71964b5d0f8..e9a62cf5bac5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -64,34 +64,6 @@ class LLVMIntegerType;
/// structs, the entire type is the identifier) and are thread-safe.
class LLVMType : public Type {
public:
- enum Kind {
- // Keep non-parametric types contiguous in the enum.
- VoidType = FIRST_LLVM_TYPE + 1,
- HalfType,
- BFloatType,
- FloatType,
- DoubleType,
- FP128Type,
- X86FP80Type,
- PPCFP128Type,
- X86MMXType,
- LabelType,
- TokenType,
- MetadataType,
- // End of non-parametric types.
- FunctionType,
- IntegerType,
- PointerType,
- FixedVectorType,
- ScalableVectorType,
- ArrayType,
- StructType,
- FIRST_NEW_LLVM_TYPE = VoidType,
- LAST_NEW_LLVM_TYPE = StructType,
- FIRST_TRIVIAL_TYPE = VoidType,
- LAST_TRIVIAL_TYPE = MetadataType
- };
-
/// Inherit base constructors.
using Type::Type;
@@ -256,27 +228,24 @@ class LLVMType : public Type {
//===----------------------------------------------------------------------===//
// Batch-define trivial types.
-#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> { \
public: \
using Base::Base; \
- static ClassName get(MLIRContext *context) { \
- return Base::get(context, Kind); \
- } \
}
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMType::VoidType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMType::HalfType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMType::BFloatType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMType::FloatType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMType::DoubleType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMType::FP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMType::X86FP80Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMType::PPCFP128Type);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMType::X86MMXType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMType::TokenType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMType::LabelType);
-DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMType::MetadataType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
+DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
#undef DEFINE_TRIVIAL_LLVM_TYPE
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 17e803db8211..18b2c3aaa53d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -16,11 +16,6 @@ namespace mlir {
class MLIRContext;
namespace linalg {
-enum LinalgTypes {
- Range = Type::FIRST_LINALG_TYPE,
- LAST_USED_LINALG_TYPE = Range,
-};
-
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
/// A RangeType represents a minimal range abstraction (min, max, step).
@@ -36,11 +31,6 @@ class RangeType : public Type::TypeBase<RangeType, Type, TypeStorage> {
public:
// Used for generic hooks in TypeBase.
using Base::Base;
- /// Construction hook.
- static RangeType get(MLIRContext *context) {
- /// Custom, uniq'ed construction in the MLIRContext.
- return Base::get(context, LinalgTypes::Range);
- }
};
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index ccdc289a9a7c..567b63936dd3 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -31,15 +31,6 @@ struct UniformQuantizedPerAxisTypeStorage;
} // namespace detail
-namespace QuantizationTypes {
-enum Kind {
- Any = Type::FIRST_QUANTIZATION_TYPE,
- UniformQuantized,
- UniformQuantizedPerAxis,
- LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
-};
-} // namespace QuantizationTypes
-
/// Enumeration of bit-mapped flags related to quantized types.
namespace QuantizationFlags {
enum FlagValue {
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
index 6788d5952cd4..b1909b367553 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
@@ -32,15 +32,6 @@ struct TargetEnvAttributeStorage;
struct VerCapExtAttributeStorage;
} // namespace detail
-/// SPIR-V dialect-specific attribute kinds.
-namespace AttrKind {
-enum Kind {
- InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
- TargetEnv, /// Target environment
- VerCapExt, /// (version, extension, capability) triple
-};
-} // namespace AttrKind
-
/// An attribute that specifies the information regarding the interface
/// variable: descriptor set, binding, storage class.
class InterfaceVarABIAttr
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index a9d120b5d114..2d224effdee3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -65,19 +65,6 @@ struct StructTypeStorage;
} // namespace detail
-namespace TypeKind {
-enum Kind {
- Array = Type::FIRST_SPIRV_TYPE,
- CooperativeMatrix,
- Image,
- Matrix,
- Pointer,
- RuntimeArray,
- Struct,
- LAST_SPIRV_TYPE = Struct,
-};
-}
-
// Base SPIR-V type for providing availability queries.
class SPIRVType : public Type {
public:
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 3168e87b3df0..cc601bdedaca 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -29,56 +29,28 @@ namespace shape {
/// Alias type for extent tensors.
RankedTensorType getExtentTensorType(MLIRContext *ctx);
-namespace ShapeTypes {
-enum Kind {
- Component = Type::FIRST_SHAPE_TYPE,
- Element,
- Shape,
- Size,
- ValueShape,
- Witness,
- LAST_SHAPE_TYPE = Witness
-};
-} // namespace ShapeTypes
-
/// The component type corresponding to shape, element type and attribute.
class ComponentType : public Type::TypeBase<ComponentType, Type, TypeStorage> {
public:
using Base::Base;
-
- static ComponentType get(MLIRContext *context) {
- return Base::get(context, ShapeTypes::Kind::Component);
- }
};
/// The element type of the shaped type.
class ElementType : public Type::TypeBase<ElementType, Type, TypeStorage> {
public:
using Base::Base;
-
- static ElementType get(MLIRContext *context) {
- return Base::get(context, ShapeTypes::Kind::Element);
- }
};
/// The shape descriptor type represents rank and dimension sizes.
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
public:
using Base::Base;
-
- static ShapeType get(MLIRContext *context) {
- return Base::get(context, ShapeTypes::Kind::Shape);
- }
};
/// The type of a single dimension.
class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
public:
using Base::Base;
-
- static SizeType get(MLIRContext *context) {
- return Base::get(context, ShapeTypes::Kind::Size);
- }
};
/// The ValueShape represents a (potentially unknown) runtime value and shape.
@@ -86,10 +58,6 @@ class ValueShapeType
: public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
public:
using Base::Base;
-
- static ValueShapeType get(MLIRContext *context) {
- return Base::get(context, ShapeTypes::Kind::ValueShape);
- }
};
/// The Witness represents a runtime constraint, to be used as shape related
@@ -97,10 +65,6 @@ class ValueShapeType
class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
public:
using Base::Base;
-
- static WitnessType get(MLIRContext *context) {
- return Base::get(context, ShapeTypes::Kind::Witness);
- }
};
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 31e6285164ab..35084a20493f 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -137,15 +137,23 @@ namespace detail {
// MLIRContext. This class manages all creation and uniquing of attributes.
class AttributeUniquer {
public:
- /// Get an uniqued instance of attribute T.
+ /// Get an uniqued instance of a parametric attribute T.
template <typename T, typename... Args>
- static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+ static typename std::enable_if_t<
+ !std::is_same<typename T::ImplType, AttributeStorage>::value, T>
+ get(MLIRContext *ctx, Args &&...args) {
return ctx->getAttributeUniquer().get<typename T::ImplType>(
- T::getTypeID(),
[ctx](AttributeStorage *storage) {
initializeAttributeStorage(storage, ctx, T::getTypeID());
},
- kind, std::forward<Args>(args)...);
+ T::getTypeID(), std::forward<Args>(args)...);
+ }
+ /// Get an uniqued instance of a singleton attribute T.
+ template <typename T>
+ static typename std::enable_if_t<
+ std::is_same<typename T::ImplType, AttributeStorage>::value, T>
+ get(MLIRContext *ctx) {
+ return ctx->getAttributeUniquer().get<typename T::ImplType>(T::getTypeID());
}
template <typename T, typename... Args>
@@ -156,6 +164,26 @@ class AttributeUniquer {
std::forward<Args>(args)...);
}
+ /// Register a parametric attribute instance T with the uniquer.
+ template <typename T>
+ static typename std::enable_if_t<
+ !std::is_same<typename T::ImplType, AttributeStorage>::value>
+ registerAttribute(MLIRContext *ctx) {
+ ctx->getAttributeUniquer()
+ .registerParametricStorageType<typename T::ImplType>(T::getTypeID());
+ }
+ /// Register a singleton attribute instance T with the uniquer.
+ template <typename T>
+ static typename std::enable_if_t<
+ std::is_same<typename T::ImplType, AttributeStorage>::value>
+ registerAttribute(MLIRContext *ctx) {
+ ctx->getAttributeUniquer()
+ .registerSingletonStorageType<typename T::ImplType>(
+ T::getTypeID(), [ctx](AttributeStorage *storage) {
+ initializeAttributeStorage(storage, ctx, T::getTypeID());
+ });
+ }
+
private:
/// Initialize the given attribute storage instance.
static void initializeAttributeStorage(AttributeStorage *storage,
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 75ac2adc302c..aa8f2eafb896 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -54,14 +54,6 @@ struct SparseElementsAttributeStorage;
/// passed by value.
class Attribute {
public:
- /// Integer identifier for all the concrete attribute kinds.
- enum Kind {
- // Reserve attribute kinds for dialect specific extensions.
-#define DEFINE_SYM_KIND_RANGE(Dialect) \
- FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff,
-#include "DialectSymbolRegistry.def"
- };
-
/// Utility class for implementing attributes.
template <typename ConcreteType, typename BaseType, typename StorageType,
template <typename T> class... Traits>
@@ -94,9 +86,6 @@ class Attribute {
// Support dyn_cast'ing Attribute to itself.
static bool classof(Attribute) { return true; }
- /// Return the classification for this attribute.
- unsigned getKind() const { return impl->getKind(); }
-
/// Return a unique identifier for the concrete attribute type. This is used
/// to support dynamic type casting.
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); }
@@ -173,54 +162,6 @@ class AttributeInterface
friend InterfaceBase;
};
-//===----------------------------------------------------------------------===//
-// StandardAttributes
-//===----------------------------------------------------------------------===//
-
-namespace StandardAttributes {
-enum Kind {
- AffineMap = Attribute::FIRST_STANDARD_ATTR,
- Array,
- Dictionary,
- Float,
- Integer,
- IntegerSet,
- Opaque,
- String,
- SymbolRef,
- Type,
- Unit,
-
- /// Elements Attributes.
- DenseIntOrFPElements,
- DenseStringElements,
- OpaqueElements,
- SparseElements,
- FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
- LAST_ELEMENTS_ATTR = SparseElements,
-
- /// Locations.
- CallSiteLocation,
- FileLineColLocation,
- FusedLocation,
- NameLocation,
- OpaqueLocation,
- UnknownLocation,
-
- // Represents a location as a 'void*' pointer to a front-end's opaque
- // location information, which must live longer than the MLIR objects that
- // refer to it. OpaqueLocation's are never serialized.
- //
- // TODO: OpaqueLocation,
-
- // Represents a value inlined through a function call.
- // TODO: InlinedLocation,
-
- FIRST_LOCATION_ATTR = CallSiteLocation,
- LAST_LOCATION_ATTR = UnknownLocation,
-};
-} // namespace StandardAttributes
-
//===----------------------------------------------------------------------===//
// AffineMapAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 4f9e4cb3618b..12a19af20546 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -154,21 +154,15 @@ class Dialect {
void addOperation(AbstractOperation opInfo);
- /// This method is used by derived classes to add their types to the set.
+ /// Register a set of type classes with this dialect.
template <typename... Args> void addTypes() {
- (void)std::initializer_list<int>{
- 0, (addType(Args::getTypeID(), AbstractType::get<Args>(*this)), 0)...};
+ (void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
}
- void addType(TypeID typeID, AbstractType &&typeInfo);
- /// This method is used by derived classes to add their attributes to the set.
+ /// Register a set of attribute classes with this dialect.
template <typename... Args> void addAttributes() {
- (void)std::initializer_list<int>{
- 0,
- (addAttribute(Args::getTypeID(), AbstractAttribute::get<Args>(*this)),
- 0)...};
+ (void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
}
- void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
/// Enable support for unregistered operations.
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
@@ -189,6 +183,22 @@ class Dialect {
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;
+ /// Register an attribute instance with this dialect.
+ template <typename T> void addAttribute() {
+ // Add this attribute to the dialect and register it with the uniquer.
+ addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
+ detail::AttributeUniquer::registerAttribute<T>(context);
+ }
+ void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
+
+ /// Register a type instance with this dialect.
+ template <typename T> void addType() {
+ // Add this type to the dialect and register it with the uniquer.
+ addType(T::getTypeID(), AbstractType::get<T>(*this));
+ detail::TypeUniquer::registerType<T>(context);
+ }
+ void addType(TypeID typeID, AbstractType &&typeInfo);
+
/// The namespace of this dialect.
StringRef name;
diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def
deleted file mode 100644
index acba383e9113..000000000000
--- a/mlir/include/mlir/IR/DialectSymbolRegistry.def
+++ /dev/null
@@ -1,44 +0,0 @@
-//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===//
-//
-// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file enumerates the different dialects that define custom classes
-// within the attribute or type system.
-//
-//===----------------------------------------------------------------------===//
-
-DEFINE_SYM_KIND_RANGE(STANDARD)
-DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL)
-DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR)
-DEFINE_SYM_KIND_RANGE(TENSORFLOW)
-DEFINE_SYM_KIND_RANGE(LLVM)
-DEFINE_SYM_KIND_RANGE(QUANTIZATION)
-DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine
-DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect
-DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect
-DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect
-DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect
-DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
-DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
-DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect
-DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect
-DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect
-
-// The following ranges are reserved for experimenting with MLIR dialects in a
-// private context without having to register them here.
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8)
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9)
-
-#undef DEFINE_SYM_KIND_RANGE
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0124ef5f7c0a..df54919ade1e 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -756,7 +756,7 @@ class OpAsmDialectInterface
/// all attributes of the given kind in the form : <alias>[0-9]+. These
/// aliases must not contain `.`.
virtual void getAttributeKindAliases(
- SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
+ SmallVectorImpl<std::pair<TypeID, StringRef>> &aliases) const {}
/// Hook for defining Attribute aliases. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
virtual void getAttributeAliases(
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 6ceddec53377..e309595415d1 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -38,33 +38,6 @@ struct TupleTypeStorage;
} // namespace detail
-namespace StandardTypes {
-enum Kind {
- // Floating point.
- BF16 = Type::Kind::FIRST_STANDARD_TYPE,
- F16,
- F32,
- F64,
- FIRST_FLOATING_POINT_TYPE = BF16,
- LAST_FLOATING_POINT_TYPE = F64,
-
- // Target pointer sized integer, used (e.g.) in affine mappings.
- Index,
-
- // Derived types.
- Integer,
- Vector,
- RankedTensor,
- UnrankedTensor,
- MemRef,
- UnrankedMemRef,
- Complex,
- Tuple,
- None,
-};
-
-} // namespace StandardTypes
-
//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 48026c219082..75bc40abdaef 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -82,29 +82,29 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
}
-protected:
/// Get or create a new ConcreteT instance within the ctx. This
/// function is guaranteed to return a non null object and will assert if
/// the arguments provided are invalid.
template <typename... Args>
- static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
+ static ConcreteT get(MLIRContext *ctx, Args... args) {
// Ensure that the invariants are correct for construction.
assert(succeeded(ConcreteT::verifyConstructionInvariants(
generateUnknownStorageLocation(ctx), args...)));
- return UniquerT::template get<ConcreteT>(ctx, kind, args...);
+ return UniquerT::template get<ConcreteT>(ctx, args...);
}
/// Get or create a new ConcreteT instance within the ctx, defined at
/// the given, potentially unknown, location. If the arguments provided are
/// invalid then emit errors and return a null object.
template <typename LocationT, typename... Args>
- static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
+ static ConcreteT getChecked(LocationT loc, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
return ConcreteT();
- return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
+ return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
}
+protected:
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
template <typename... Args> LogicalResult mutate(Args &&...args) {
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index aa2daefd26c4..ace5eaa73345 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -121,15 +121,23 @@ namespace detail {
/// A utility class to get, or create, unique instances of types within an
/// MLIRContext. This class manages all creation and uniquing of types.
struct TypeUniquer {
- /// Get an uniqued instance of a type T.
+ /// Get an uniqued instance of a parametric type T.
template <typename T, typename... Args>
- static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
+ static typename std::enable_if_t<
+ !std::is_same<typename T::ImplType, TypeStorage>::value, T>
+ get(MLIRContext *ctx, Args &&...args) {
return ctx->getTypeUniquer().get<typename T::ImplType>(
- T::getTypeID(),
[&](TypeStorage *storage) {
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
},
- kind, std::forward<Args>(args)...);
+ T::getTypeID(), std::forward<Args>(args)...);
+ }
+ /// Get an uniqued instance of a singleton type T.
+ template <typename T>
+ static typename std::enable_if_t<
+ std::is_same<typename T::ImplType, TypeStorage>::value, T>
+ get(MLIRContext *ctx) {
+ return ctx->getTypeUniquer().get<typename T::ImplType>(T::getTypeID());
}
/// Change the mutable component of the given type instance in the provided
@@ -141,6 +149,25 @@ struct TypeUniquer {
return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
std::forward<Args>(args)...);
}
+
+ /// Register a parametric type instance T with the uniquer.
+ template <typename T>
+ static typename std::enable_if_t<
+ !std::is_same<typename T::ImplType, TypeStorage>::value>
+ registerType(MLIRContext *ctx) {
+ ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
+ T::getTypeID());
+ }
+ /// Register a singleton type instance T with the uniquer.
+ template <typename T>
+ static typename std::enable_if_t<
+ std::is_same<typename T::ImplType, TypeStorage>::value>
+ registerType(MLIRContext *ctx) {
+ ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
+ T::getTypeID(), [&](TypeStorage *storage) {
+ storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
+ });
+ }
};
} // namespace detail
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 8101690daeb6..ad7e436068bc 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,11 +34,11 @@ struct OpaqueTypeStorage;
///
/// Some types are "primitives" meaning they do not have any parameters, for
/// example the Index type. Parametric types have additional information that
-/// differentiates the types of the same kind between them, for example the
-/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
-/// different instances of the IntegerType. Type parameters are part of the
-/// unique immutable key. The mutable component of the type can be modified
-/// after the type is created, but cannot affect the identity of the type.
+/// differentiates the types of the same class, for example the Integer type has
+/// bitwidth, making i8 and i16 belong to the same kind by be different
+/// instances of the IntegerType. Type parameters are part of the unique
+/// immutable key. The mutable component of the type can be modified after the
+/// type is created, but cannot affect the identity of the type.
///
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
///
@@ -53,20 +53,19 @@ struct OpaqueTypeStorage;
/// * This method is expected to return failure if a type cannot be
/// constructed with 'args', success otherwise.
/// * 'args' must correspond with the arguments passed into the
-/// 'TypeBase::get' call after the type kind.
+/// 'TypeBase::get' call.
///
///
/// Type storage objects inherit from TypeStorage and contain the following:
-/// - The type kind (for LLVM-style RTTI).
/// - The dialect that defined the type.
/// - Any parameters of the type.
/// - An optional mutable component.
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
/// Parametric storage types must derive TypeStorage and respect the following:
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
-/// instance of the type within its kind.
+/// instance of the type.
/// * The key type must be constructible from the values passed into the
-/// detail::TypeUniquer::get call after the type kind.
+/// detail::TypeUniquer::get call.
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
/// storage class must define a hashing method:
/// 'static unsigned hashKey(const KeyTy &)'
@@ -84,23 +83,6 @@ struct OpaqueTypeStorage;
// the key.
class Type {
public:
- /// Integer identifier for all the concrete type kinds.
- /// Note: This is not an enum class as each dialect will likely define a
- /// separate enumeration for the specific types that they define. Not being an
- /// enum class also simplifies the handling of type kinds by not requiring
- /// casts for each use.
- enum Kind {
- // Builtin types.
- Function,
- Opaque,
- LAST_BUILTIN_TYPE = Opaque,
-
- // Reserve type kinds for dialect specific type system extensions.
-#define DEFINE_SYM_KIND_RANGE(Dialect) \
- FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
-#include "DialectSymbolRegistry.def"
- };
-
/// Utility class for implementing types.
template <typename ConcreteType, typename BaseType, typename StorageType,
template <typename T> class... Traits>
@@ -136,9 +118,6 @@ class Type {
/// dynamic type casting.
TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
- /// Return the classification for this type.
- unsigned getKind() const;
-
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const;
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 6c7c7b0496da..eb04688be190 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -11,12 +11,11 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
-class TypeID;
-
namespace detail {
struct StorageUniquerImpl;
@@ -29,22 +28,19 @@ template <typename ImplTy, typename T>
using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
} // namespace detail
-/// A utility class to get, or create instances of storage classes. These
-/// storage classes must respect the following constraints:
-/// - Derive from StorageUniquer::BaseStorage.
-/// - Provide an unsigned 'kind' value to be used as part of the unique'ing
-/// process.
+/// A utility class to get or create instances of "storage classes". These
+/// storage classes must derive from 'StorageUniquer::BaseStorage'.
///
-/// For non-parametric storage classes, i.e. those that are solely uniqued by
-/// their kind, nothing else is needed. Instances of these classes can be
-/// created by calling `get` without trailing arguments.
+/// For non-parametric storage classes, i.e. singleton classes, nothing else is
+/// needed. Instances of these classes can be created by calling `get` without
+/// trailing arguments.
///
/// Otherwise, the parametric storage classes may be created with `get`,
/// and must respect the following:
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
-/// instance of the storage class within its kind.
+/// instance of the storage class.
/// * The key type must be constructible from the values passed into the
-/// getComplex call after the kind.
+/// getComplex call.
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
/// storage class must define a hashing method:
/// 'static unsigned hashKey(const KeyTy &)'
@@ -83,32 +79,11 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
/// class.
class StorageUniquer {
public:
- StorageUniquer();
- ~StorageUniquer();
-
- /// Set the flag specifying if multi-threading is disabled within the uniquer.
- void disableMultithreading(bool disable = true);
-
- /// Register a new storage object with this uniquer using the given unique
- /// type id.
- void registerStorageType(TypeID id);
-
/// This class acts as the base storage that all storage classes must derived
/// from.
class BaseStorage {
- public:
- /// Get the kind classification of this storage.
- unsigned getKind() const { return kind; }
-
protected:
- BaseStorage() : kind(0) {}
-
- private:
- /// Allow access to the kind field.
- friend detail::StorageUniquerImpl;
-
- /// Classification of the subclass, used for type checking.
- unsigned kind;
+ BaseStorage() = default;
};
/// This is a utility allocator used to allocate memory for instances of
@@ -145,19 +120,61 @@ class StorageUniquer {
llvm::BumpPtrAllocator allocator;
};
- /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
- /// that can be used to initialize a newly inserted storage instance. This
- /// function is used for derived types that have complex storage or uniquing
+ StorageUniquer();
+ ~StorageUniquer();
+
+ /// Set the flag specifying if multi-threading is disabled within the uniquer.
+ void disableMultithreading(bool disable = true);
+
+ /// Register a new parametric storage class, this is necessary to create
+ /// instances of this class type. `id` is the type identifier that will be
+ /// used to identify this type when creating instances of it via 'get'.
+ template <typename Storage> void registerParametricStorageType(TypeID id) {
+ registerParametricStorageTypeImpl(id);
+ }
+ /// Utility override when the storage type represents the type id.
+ template <typename Storage> void registerParametricStorageType() {
+ registerParametricStorageType<Storage>(TypeID::get<Storage>());
+ }
+ /// Register a new singleton storage class, this is necessary to get the
+ /// singletone instance. `id` is the type identifier that will be used to
+ /// access the singleton instance via 'get'. An optional initialization
+ /// function may also be provided to initialize the newly created storage
+ /// instance, and used when the singleton instance is created.
+ template <typename Storage>
+ void registerSingletonStorageType(TypeID id,
+ function_ref<void(Storage *)> initFn) {
+ auto ctorFn = [&](StorageAllocator &allocator) {
+ auto *storage = new (allocator.allocate<Storage>()) Storage();
+ if (initFn)
+ initFn(storage);
+ return storage;
+ };
+ registerSingletonImpl(id, ctorFn);
+ }
+ template <typename Storage> void registerSingletonStorageType(TypeID id) {
+ registerSingletonStorageType<Storage>(id, llvm::None);
+ }
+ /// Utility override when the storage type represents the type id.
+ template <typename Storage>
+ void registerSingletonStorageType(
+ function_ref<void(Storage *)> initFn = llvm::None) {
+ registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
+ }
+
+ /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
+ /// registering the storage instance. 'initFn' is an optional parameter that
+ /// can be used to initialize a newly inserted storage instance. This function
+ /// is used for derived types that have complex storage or uniquing
/// constraints.
- template <typename Storage, typename Arg, typename... Args>
- Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
- unsigned kind, Arg &&arg, Args &&...args) {
+ template <typename Storage, typename... Args>
+ Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
+ Args &&...args) {
// Construct a value of the derived key type.
- auto derivedKey =
- getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
+ auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
- // Create a hash of the kind and the derived key.
- unsigned hashValue = getHash<Storage>(kind, derivedKey);
+ // Create a hash of the derived key.
+ unsigned hashValue = getHash<Storage>(derivedKey);
// Generate an equality function for the derived storage.
auto isEqual = [&derivedKey](const BaseStorage *existing) {
@@ -174,29 +191,29 @@ class StorageUniquer {
// Get an instance for the derived storage.
return static_cast<Storage *>(
- getImpl(id, kind, hashValue, isEqual, ctorFn));
+ getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
+ }
+ /// Utility override when the storage type represents the type id.
+ template <typename Storage, typename... Args>
+ Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
+ return get<Storage>(initFn, TypeID::get<Storage>(),
+ std::forward<Args>(args)...);
}
- /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
- /// that can be used to initialize a newly inserted storage instance. This
- /// function is used for derived types that use no additional storage or
- /// uniquing outside of the kind.
- template <typename Storage>
- Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
- unsigned kind) {
- auto ctorFn = [&](StorageAllocator &allocator) {
- auto *storage = new (allocator.allocate<Storage>()) Storage();
- if (initFn)
- initFn(storage);
- return storage;
- };
- return static_cast<Storage *>(getImpl(id, kind, ctorFn));
+ /// Gets a uniqued instance of 'Storage' which is a singleton storage type.
+ /// 'id' is the type id used when registering the storage instance.
+ template <typename Storage> Storage *get(TypeID id) {
+ return static_cast<Storage *>(getSingletonImpl(id));
+ }
+ /// Utility override when the storage type represents the type id.
+ template <typename Storage> Storage *get() {
+ return get<Storage>(TypeID::get<Storage>());
}
/// Changes the mutable component of 'storage' by forwarding the trailing
/// arguments to the 'mutate' function of the derived class.
template <typename Storage, typename... Args>
- LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) {
+ LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
return static_cast<Storage &>(*storage).mutate(
allocator, std::forward<Args>(args)...);
@@ -207,13 +224,13 @@ class StorageUniquer {
/// Erases a uniqued instance of 'Storage'. This function is used for derived
/// types that have complex storage or uniquing constraints.
template <typename Storage, typename Arg, typename... Args>
- void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) {
+ void erase(TypeID id, Arg &&arg, Args &&...args) {
// Construct a value of the derived key type.
auto derivedKey =
getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
- // Create a hash of the kind and the derived key.
- unsigned hashValue = getHash<Storage>(kind, derivedKey);
+ // Create a hash of the derived key.
+ unsigned hashValue = getHash<Storage>(derivedKey);
// Generate an equality function for the derived storage.
auto isEqual = [&derivedKey](const BaseStorage *existing) {
@@ -221,32 +238,42 @@ class StorageUniquer {
};
// Attempt to erase the storage instance.
- eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) {
+ eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) {
static_cast<Storage *>(storage)->cleanup();
});
}
private:
/// Implementation for getting/creating an instance of a derived type with
- /// complex storage.
- BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue,
- function_ref<bool(const BaseStorage *)> isEqual,
- function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
+ /// parametric storage.
+ BaseStorage *getParametricStorageTypeImpl(
+ TypeID id, unsigned hashValue,
+ function_ref<bool(const BaseStorage *)> isEqual,
+ function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
- /// Implementation for getting/creating an instance of a derived type with
- /// default storage.
- BaseStorage *getImpl(const TypeID &id, unsigned kind,
- function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
+ /// Implementation for registering an instance of a derived type with
+ /// parametric storage.
+ void registerParametricStorageTypeImpl(TypeID id);
+
+ /// Implementation for getting an instance of a derived type with default
+ /// storage.
+ BaseStorage *getSingletonImpl(TypeID id);
+
+ /// Implementation for registering an instance of a derived type with default
+ /// storage.
+ void
+ registerSingletonImpl(TypeID id,
+ function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for erasing an instance of a derived type with complex
/// storage.
- void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue,
+ void eraseImpl(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<void(BaseStorage *)> cleanupFn);
/// Implementation for mutating an instance of a derived storage.
LogicalResult
- mutateImpl(const TypeID &id,
+ mutateImpl(TypeID id,
function_ref<LogicalResult(StorageAllocator &)> mutationFn);
/// The internal implementation class.
@@ -276,27 +303,26 @@ class StorageUniquer {
}
//===--------------------------------------------------------------------===//
- // Key and Kind Hashing
+ // Key Hashing
//===--------------------------------------------------------------------===//
- /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage
- /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
+ /// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if
+ /// there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<
llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
::llvm::hash_code>::type
- getHash(unsigned kind, const DerivedKey &derivedKey) {
- return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey));
+ getHash(const DerivedKey &derivedKey) {
+ return ImplTy::hashKey(derivedKey);
}
- /// If there is no 'ImplTy::hashKey' default to using the
- /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash.
+ /// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo'
+ /// definition for 'DerivedKey' for generating a hash.
template <typename ImplTy, typename DerivedKey>
static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
ImplTy, DerivedKey>::value,
::llvm::hash_code>::type
- getHash(unsigned kind, const DerivedKey &derivedKey) {
- return llvm::hash_combine(
- kind, DenseMapInfo<DerivedKey>::getHashValue(derivedKey));
+ getHash(const DerivedKey &derivedKey) {
+ return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
}
};
} // end namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index f4a278dfbbb0..727efbb4704f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -264,14 +264,13 @@ bool LLVMArrayType::isValidElementType(LLVMType type) {
LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
assert(elementType && "expected non-null subtype");
- return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType,
- numElements);
+ return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
- return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements);
+ return Base::getChecked(loc, elementType, numElements);
}
LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
@@ -301,16 +300,14 @@ LLVMFunctionType LLVMFunctionType::get(LLVMType result,
ArrayRef<LLVMType> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
- return Base::get(result.getContext(), LLVMType::FunctionType, result,
- arguments, isVarArg);
+ return Base::get(result.getContext(), result, arguments, isVarArg);
}
LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
ArrayRef<LLVMType> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
- return Base::getChecked(loc, LLVMType::FunctionType, result, arguments,
- isVarArg);
+ return Base::getChecked(loc, result, arguments, isVarArg);
}
LLVMType LLVMFunctionType::getReturnType() {
@@ -347,11 +344,11 @@ LogicalResult LLVMFunctionType::verifyConstructionInvariants(
// Integer type.
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
- return Base::get(ctx, LLVMType::IntegerType, bitwidth);
+ return Base::get(ctx, bitwidth);
}
LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
- return Base::getChecked(loc, LLVMType::IntegerType, bitwidth);
+ return Base::getChecked(loc, bitwidth);
}
unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
@@ -374,13 +371,12 @@ bool LLVMPointerType::isValidElementType(LLVMType type) {
LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
assert(pointee && "expected non-null subtype");
- return Base::get(pointee.getContext(), LLVMType::PointerType, pointee,
- addressSpace);
+ return Base::get(pointee.getContext(), pointee, addressSpace);
}
LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
unsigned addressSpace) {
- return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace);
+ return Base::getChecked(loc, pointee, addressSpace);
}
LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
@@ -405,32 +401,32 @@ bool LLVMStructType::isValidElementType(LLVMType type) {
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
StringRef name) {
- return Base::get(context, LLVMType::StructType, name, /*opaque=*/false);
+ return Base::get(context, name, /*opaque=*/false);
}
LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
StringRef name) {
- return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false);
+ return Base::getChecked(loc, name, /*opaque=*/false);
}
LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
ArrayRef<LLVMType> types,
bool isPacked) {
- return Base::get(context, LLVMType::StructType, types, isPacked);
+ return Base::get(context, types, isPacked);
}
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
ArrayRef<LLVMType> types,
bool isPacked) {
- return Base::getChecked(loc, LLVMType::StructType, types, isPacked);
+ return Base::getChecked(loc, types, isPacked);
}
LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
- return Base::get(context, LLVMType::StructType, name, /*opaque=*/true);
+ return Base::get(context, name, /*opaque=*/true);
}
LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
- return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true);
+ return Base::getChecked(loc, name, /*opaque=*/true);
}
LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
@@ -508,16 +504,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
- return Base::get(elementType.getContext(), LLVMType::FixedVectorType,
- elementType, numElements);
+ return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
LLVMType elementType,
unsigned numElements) {
assert(elementType && "expected non-null subtype");
- return Base::getChecked(loc, LLVMType::FixedVectorType, elementType,
- numElements);
+ return Base::getChecked(loc, elementType, numElements);
}
unsigned LLVMFixedVectorType::getNumElements() {
@@ -527,16 +521,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
- return Base::get(elementType.getContext(), LLVMType::ScalableVectorType,
- elementType, minNumElements);
+ return Base::get(elementType.getContext(), elementType, minNumElements);
}
LLVMScalableVectorType
LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
- return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType,
- minNumElements);
+ return Base::getChecked(loc, elementType, minNumElements);
}
unsigned LLVMScalableVectorType::getMinNumElements() {
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index ef7d8144b259..41e64d1540f3 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -204,8 +204,8 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
Type expressedType,
int64_t storageTypeMin,
int64_t storageTypeMax) {
- return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
- storageType, expressedType, storageTypeMin, storageTypeMax);
+ return Base::get(storageType.getContext(), flags, storageType, expressedType,
+ storageTypeMin, storageTypeMax);
}
AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
@@ -213,8 +213,8 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
int64_t storageTypeMin,
int64_t storageTypeMax,
Location location) {
- return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
- expressedType, storageTypeMin, storageTypeMax);
+ return Base::getChecked(location, flags, storageType, expressedType,
+ storageTypeMin, storageTypeMax);
}
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
@@ -240,10 +240,8 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
int64_t zeroPoint,
int64_t storageTypeMin,
int64_t storageTypeMax) {
- return Base::get(storageType.getContext(),
- QuantizationTypes::UniformQuantized, flags, storageType,
- expressedType, scale, zeroPoint, storageTypeMin,
- storageTypeMax);
+ return Base::get(storageType.getContext(), flags, storageType, expressedType,
+ scale, zeroPoint, storageTypeMin, storageTypeMax);
}
UniformQuantizedType
@@ -251,9 +249,8 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax, Location location) {
- return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
- storageType, expressedType, scale, zeroPoint,
- storageTypeMin, storageTypeMax);
+ return Base::getChecked(location, flags, storageType, expressedType, scale,
+ zeroPoint, storageTypeMin, storageTypeMax);
}
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
@@ -295,10 +292,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax) {
- return Base::get(storageType.getContext(),
- QuantizationTypes::UniformQuantizedPerAxis, flags,
- storageType, expressedType, scales, zeroPoints,
- quantizedDimension, storageTypeMin, storageTypeMax);
+ return Base::get(storageType.getContext(), flags, storageType, expressedType,
+ scales, zeroPoints, quantizedDimension, storageTypeMin,
+ storageTypeMax);
}
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
@@ -306,9 +302,9 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
Location location) {
- return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
- flags, storageType, expressedType, scales, zeroPoints,
- quantizedDimension, storageTypeMin, storageTypeMax);
+ return Base::getChecked(location, flags, storageType, expressedType, scales,
+ zeroPoints, quantizedDimension, storageTypeMin,
+ storageTypeMax);
}
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
index 09c9d1dfd3d8..4e3e050b4a4f 100644
--- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
@@ -13,11 +13,11 @@ using namespace mlir;
SDBMDialect::SDBMDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
- uniquer.registerStorageType(TypeID::get<detail::SDBMBinaryExprStorage>());
- uniquer.registerStorageType(TypeID::get<detail::SDBMConstantExprStorage>());
- uniquer.registerStorageType(TypeID::get<detail::SDBMDiffExprStorage>());
- uniquer.registerStorageType(TypeID::get<detail::SDBMNegExprStorage>());
- uniquer.registerStorageType(TypeID::get<detail::SDBMTermExprStorage>());
+ uniquer.registerParametricStorageType<detail::SDBMBinaryExprStorage>();
+ uniquer.registerParametricStorageType<detail::SDBMConstantExprStorage>();
+ uniquer.registerParametricStorageType<detail::SDBMDiffExprStorage>();
+ uniquer.registerParametricStorageType<detail::SDBMNegExprStorage>();
+ uniquer.registerParametricStorageType<detail::SDBMTermExprStorage>();
}
SDBMDialect::~SDBMDialect() = default;
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
index 435c7fe25f0c..8da6c40bba88 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -246,7 +246,6 @@ SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
return uniquer.get<detail::SDBMBinaryExprStorage>(
- TypeID::get<detail::SDBMBinaryExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
}
@@ -533,9 +532,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
assert(rhs && "expected SDBM dimension");
StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
- return uniquer.get<detail::SDBMDiffExprStorage>(
- TypeID::get<detail::SDBMDiffExprStorage>(),
- /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
+ return uniquer.get<detail::SDBMDiffExprStorage>(/*initFn=*/{}, lhs, rhs);
}
SDBMDirectExpr SDBMDiffExpr::getLHS() const {
@@ -575,7 +572,6 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
StorageUniquer &uniquer = var.getDialect()->getUniquer();
return uniquer.get<detail::SDBMBinaryExprStorage>(
- TypeID::get<detail::SDBMBinaryExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
stripeFactor);
}
@@ -611,8 +607,7 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMTermExprStorage>(
- TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
- static_cast<unsigned>(SDBMExprKind::DimId), position);
+ assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
}
//===----------------------------------------------------------------------===//
@@ -628,8 +623,7 @@ SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
StorageUniquer &uniquer = dialect->getUniquer();
return uniquer.get<detail::SDBMTermExprStorage>(
- TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
- static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+ assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
}
//===----------------------------------------------------------------------===//
@@ -644,9 +638,7 @@ SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
};
StorageUniquer &uniquer = dialect->getUniquer();
- return uniquer.get<detail::SDBMConstantExprStorage>(
- TypeID::get<detail::SDBMConstantExprStorage>(), assignCtx,
- static_cast<unsigned>(SDBMExprKind::Constant), value);
+ return uniquer.get<detail::SDBMConstantExprStorage>(assignCtx, value);
}
int64_t SDBMConstantExpr::getValue() const {
@@ -661,9 +653,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
assert(var && "expected non-null SDBM direct expression");
StorageUniquer &uniquer = var.getDialect()->getUniquer();
- return uniquer.get<detail::SDBMNegExprStorage>(
- TypeID::get<detail::SDBMNegExprStorage>(),
- /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
+ return uniquer.get<detail::SDBMNegExprStorage>(/*initFn=*/{}, var);
}
SDBMDirectExpr SDBMNegExpr::getVar() const {
diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
index e344917fb2ca..8d91334c807e 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
+++ b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h
@@ -25,27 +25,28 @@ namespace detail {
// Base storage class for SDBMExpr.
struct SDBMExprStorage : public StorageUniquer::BaseStorage {
- SDBMExprKind getKind() {
- return static_cast<SDBMExprKind>(BaseStorage::getKind());
- }
+ SDBMExprKind getKind() { return kind; }
SDBMDialect *dialect;
+ SDBMExprKind kind;
};
// Storage class for SDBM sum and stripe expressions.
struct SDBMBinaryExprStorage : public SDBMExprStorage {
- using KeyTy = std::pair<SDBMDirectExpr, SDBMConstantExpr>;
+ using KeyTy = std::tuple<unsigned, SDBMDirectExpr, SDBMConstantExpr>;
bool operator==(const KeyTy &key) const {
- return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
+ return static_cast<SDBMExprKind>(std::get<0>(key)) == kind &&
+ std::get<1>(key) == lhs && std::get<2>(key) == rhs;
}
static SDBMBinaryExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMBinaryExprStorage>();
- result->lhs = std::get<0>(key);
- result->rhs = std::get<1>(key);
+ result->lhs = std::get<1>(key);
+ result->rhs = std::get<2>(key);
result->dialect = result->lhs.getDialect();
+ result->kind = static_cast<SDBMExprKind>(std::get<0>(key));
return result;
}
@@ -67,6 +68,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
result->lhs = std::get<0>(key);
result->rhs = std::get<1>(key);
result->dialect = result->lhs.getDialect();
+ result->kind = SDBMExprKind::Diff;
return result;
}
@@ -84,6 +86,7 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMConstantExprStorage>();
result->constant = key;
+ result->kind = SDBMExprKind::Constant;
return result;
}
@@ -92,14 +95,18 @@ struct SDBMConstantExprStorage : public SDBMExprStorage {
// Storage class for SDBM dimension and symbol expressions.
struct SDBMTermExprStorage : public SDBMExprStorage {
- using KeyTy = unsigned;
+ using KeyTy = std::pair<unsigned, unsigned>;
- bool operator==(const KeyTy &key) const { return position == key; }
+ bool operator==(const KeyTy &key) const {
+ return kind == static_cast<SDBMExprKind>(key.first) &&
+ position == key.second;
+ }
static SDBMTermExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<SDBMTermExprStorage>();
- result->position = key;
+ result->kind = static_cast<SDBMExprKind>(key.first);
+ result->position = key.second;
return result;
}
@@ -117,6 +124,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage {
auto *result = allocator.allocate<SDBMNegExprStorage>();
result->expr = key;
result->dialect = key.getDialect();
+ result->kind = SDBMExprKind::Neg;
return result;
}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
index b2df52b07608..c2bf4840ddc8 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
@@ -120,8 +120,7 @@ spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass) {
assert(descriptorSet && binding);
MLIRContext *context = descriptorSet.getContext();
- return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
- binding, storageClass);
+ return Base::get(context, descriptorSet, binding, storageClass);
}
StringRef spirv::InterfaceVarABIAttr::getKindName() {
@@ -195,8 +194,7 @@ spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
ArrayAttr extensions) {
assert(version && capabilities && extensions);
MLIRContext *context = version.getContext();
- return Base::get(context, spirv::AttrKind::VerCapExt, version, capabilities,
- extensions);
+ return Base::get(context, version, capabilities, extensions);
}
StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
@@ -272,7 +270,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
DictionaryAttr limits) {
assert(triple && limits && "expected valid triple and limits");
MLIRContext *context = triple.getContext();
- return Base::get(context, spirv::AttrKind::TargetEnv, triple, limits);
+ return Base::get(context, triple, limits);
}
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index b52ea812306e..e9cb4b2835e5 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -124,15 +124,14 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage {
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
assert(elementCount && "ArrayType needs at least one element");
- return Base::get(elementType.getContext(), TypeKind::Array, elementType,
- elementCount, /*stride=*/0);
+ return Base::get(elementType.getContext(), elementType, elementCount,
+ /*stride=*/0);
}
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
unsigned stride) {
assert(elementCount && "ArrayType needs at least one element");
- return Base::get(elementType.getContext(), TypeKind::Array, elementType,
- elementCount, stride);
+ return Base::get(elementType.getContext(), elementType, elementCount, stride);
}
unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
@@ -285,8 +284,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
Scope scope, unsigned rows,
unsigned columns) {
- return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix,
- elementType, scope, rows, columns);
+ return Base::get(elementType.getContext(), elementType, scope, rows, columns);
}
Type CooperativeMatrixNVType::getElementType() const {
@@ -389,7 +387,7 @@ ImageType
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
value) {
- return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
+ return Base::get(std::get<0>(value).getContext(), value);
}
Type ImageType::getElementType() const { return getImpl()->elementType; }
@@ -453,8 +451,7 @@ struct spirv::detail::PointerTypeStorage : public TypeStorage {
};
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
- return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType,
- storageClass);
+ return Base::get(pointeeType.getContext(), pointeeType, storageClass);
}
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
@@ -511,13 +508,11 @@ struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
};
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
- return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
- elementType, /*stride=*/0);
+ return Base::get(elementType.getContext(), elementType, /*stride=*/0);
}
RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
- return Base::get(elementType.getContext(), TypeKind::RuntimeArray,
- elementType, stride);
+ return Base::get(elementType.getContext(), elementType, stride);
}
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
@@ -846,12 +841,12 @@ StructType::get(ArrayRef<Type> memberTypes,
SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
memberDecorations.begin(), memberDecorations.end());
llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
- return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
- memberTypes, offsetInfo, sortedDecorations);
+ return Base::get(memberTypes.vec().front().getContext(), memberTypes,
+ offsetInfo, sortedDecorations);
}
StructType StructType::getEmpty(MLIRContext *context) {
- return Base::get(context, TypeKind::Struct, ArrayRef<Type>(),
+ return Base::get(context, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
ArrayRef<StructType::MemberDecorationInfo>());
}
@@ -946,13 +941,12 @@ struct spirv::detail::MatrixTypeStorage : public TypeStorage {
};
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
- return Base::get(columnType.getContext(), TypeKind::Matrix, columnType,
- columnCount);
+ return Base::get(columnType.getContext(), columnType, columnCount);
}
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
Location location) {
- return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount);
+ return Base::getChecked(location, columnType, columnCount);
}
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 83d080f17d7d..fdecdc6c7168 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -20,9 +20,7 @@ using namespace mlir::detail;
MLIRContext *AffineExpr::getContext() const { return expr->context; }
-AffineExprKind AffineExpr::getKind() const {
- return static_cast<AffineExprKind>(expr->getKind());
-}
+AffineExprKind AffineExpr::getKind() const { return expr->kind; }
/// Walk all of the AffineExprs in this subgraph in postorder.
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
@@ -449,8 +447,7 @@ static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
StorageUniquer &uniquer = context->getAffineUniquer();
return uniquer.get<AffineDimExprStorage>(
- TypeID::get<AffineDimExprStorage>(), assignCtx,
- static_cast<unsigned>(kind), position);
+ assignCtx, static_cast<unsigned>(kind), position);
}
AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
@@ -484,9 +481,7 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
};
StorageUniquer &uniquer = context->getAffineUniquer();
- return uniquer.get<AffineConstantExprStorage>(
- TypeID::get<AffineConstantExprStorage>(), assignCtx,
- static_cast<unsigned>(AffineExprKind::Constant), constant);
+ return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
}
/// Simplify add expression. Return nullptr if it can't be simplified.
@@ -594,7 +589,6 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
}
@@ -655,7 +649,6 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
}
@@ -722,7 +715,6 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
other);
}
@@ -766,7 +758,6 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
other);
}
@@ -814,7 +805,6 @@ AffineExpr AffineExpr::operator%(AffineExpr other) const {
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- TypeID::get<AffineBinaryOpExprStorage>(),
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
}
diff --git a/mlir/lib/IR/AffineExprDetail.h b/mlir/lib/IR/AffineExprDetail.h
index ff47cd95969c..1c38e54bc3e6 100644
--- a/mlir/lib/IR/AffineExprDetail.h
+++ b/mlir/lib/IR/AffineExprDetail.h
@@ -27,21 +27,24 @@ namespace detail {
/// Base storage class appearing in an affine expression.
struct AffineExprStorage : public StorageUniquer::BaseStorage {
MLIRContext *context;
+ AffineExprKind kind;
};
/// A binary operation appearing in an affine expression.
struct AffineBinaryOpExprStorage : public AffineExprStorage {
- using KeyTy = std::pair<AffineExpr, AffineExpr>;
+ using KeyTy = std::tuple<unsigned, AffineExpr, AffineExpr>;
bool operator==(const KeyTy &key) const {
- return key.first == lhs && key.second == rhs;
+ return static_cast<AffineExprKind>(std::get<0>(key)) == kind &&
+ std::get<1>(key) == lhs && std::get<2>(key) == rhs;
}
static AffineBinaryOpExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
- result->lhs = key.first;
- result->rhs = key.second;
+ result->kind = static_cast<AffineExprKind>(std::get<0>(key));
+ result->lhs = std::get<1>(key);
+ result->rhs = std::get<2>(key);
result->context = result->lhs.getContext();
return result;
}
@@ -52,14 +55,18 @@ struct AffineBinaryOpExprStorage : public AffineExprStorage {
/// A dimensional or symbolic identifier appearing in an affine expression.
struct AffineDimExprStorage : public AffineExprStorage {
- using KeyTy = unsigned;
+ using KeyTy = std::pair<unsigned, unsigned>;
- bool operator==(const KeyTy &key) const { return position == key; }
+ bool operator==(const KeyTy &key) const {
+ return kind == static_cast<AffineExprKind>(key.first) &&
+ position == key.second;
+ }
static AffineDimExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<AffineDimExprStorage>();
- result->position = key;
+ result->kind = static_cast<AffineExprKind>(key.first);
+ result->position = key.second;
return result;
}
@@ -76,6 +83,7 @@ struct AffineConstantExprStorage : public AffineExprStorage {
static AffineConstantExprStorage *
construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
auto *result = allocator.allocate<AffineConstantExprStorage>();
+ result->kind = AffineExprKind::Constant;
result->constant = key;
return result;
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 61eecb811085..2247fe390ad1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -271,7 +271,7 @@ class AliasState {
/// Mapping between attribute kind and a pair comprised of a base alias name
/// and a unique list of attributes belonging to this kind sorted by location
/// seen in the module.
- llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
+ llvm::MapVector<TypeID, std::pair<StringRef, std::vector<Attribute>>>
attrKindToAlias;
/// Set of types known to be used within the module.
@@ -301,13 +301,13 @@ void AliasState::initialize(
llvm::StringSet<> usedAliases;
// Collect the set of aliases from each dialect.
- SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
+ SmallVector<std::pair<TypeID, StringRef>, 8> attributeKindAliases;
SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
// AffineMap/Integer set have specific kind aliases.
- attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
- attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
+ attributeKindAliases.emplace_back(AffineMapAttr::getTypeID(), "map");
+ attributeKindAliases.emplace_back(IntegerSetAttr::getTypeID(), "set");
for (auto &interface : interfaces) {
interface.getAttributeKindAliases(attributeKindAliases);
@@ -317,7 +317,7 @@ void AliasState::initialize(
// Setup the attribute kind aliases.
StringRef alias;
- unsigned attrKind;
+ TypeID attrKind;
for (auto &attrAliasPair : attributeKindAliases) {
std::tie(attrKind, alias) = attrAliasPair;
assert(!alias.empty() && "expected non-empty alias string");
@@ -420,7 +420,7 @@ void AliasState::recordAttributeReference(Attribute attr) {
return;
// If this attribute kind has an alias, then record one for this attribute.
- auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+ auto alias = attrKindToAlias.find(attr.getTypeID());
if (alias == attrKindToAlias.end())
return;
std::pair<StringRef, int> attrAlias(alias->second.first,
# NOTE(review): Attributes.cpp -- every builtin attribute factory below drops its
# StandardAttributes::<Kind> argument and forwards the remaining storage
# parameters to Base::get / Base::getChecked unchanged.
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index dba7872b2a6c..ac51cba88f1c 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -57,7 +57,7 @@ Dialect &Attribute::getDialect() const {
//===----------------------------------------------------------------------===//
AffineMapAttr AffineMapAttr::get(AffineMap value) {
- return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
+ return Base::get(value.getContext(), value);
}
AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
@@ -67,7 +67,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
- return Base::get(context, StandardAttributes::Array, value);
+ return Base::get(context, value);
}
ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
@@ -156,7 +156,7 @@ DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
value = storage;
- return Base::get(context, StandardAttributes::Dictionary, value);
+ return Base::get(context, value);
}
/// Construct a dictionary with an array of values that is known to already be
/// sorted by name and uniqued.
@@ -175,7 +175,7 @@ DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
return l.first == r.first;
}) == value.end() &&
"DictionaryAttr element names must be unique");
- return Base::get(context, StandardAttributes::Dictionary, value);
+ return Base::get(context, value);
}
ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
@@ -219,19 +219,19 @@ size_t DictionaryAttr::size() const { return getValue().size(); }
//===----------------------------------------------------------------------===//
FloatAttr FloatAttr::get(Type type, double value) {
- return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+ return Base::get(type.getContext(), type, value);
}
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
- return Base::getChecked(loc, StandardAttributes::Float, type, value);
+ return Base::getChecked(loc, type, value);
}
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
- return Base::get(type.getContext(), StandardAttributes::Float, type, value);
+ return Base::get(type.getContext(), type, value);
}
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
- return Base::getChecked(loc, StandardAttributes::Float, type, value);
+ return Base::getChecked(loc, type, value);
}
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@@ -279,14 +279,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
//===----------------------------------------------------------------------===//
FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
- return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
- .cast<FlatSymbolRefAttr>();
+ return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
}
SymbolRefAttr SymbolRefAttr::get(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences,
MLIRContext *ctx) {
- return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+ return Base::get(ctx, value, nestedReferences);
}
StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
@@ -307,7 +306,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
if (type.isSignlessInteger(1))
return BoolAttr::get(value.getBoolValue(), type.getContext());
- return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
+ return Base::get(type.getContext(), type, value);
}
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
@@ -380,8 +379,7 @@ bool BoolAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
- return Base::get(value.getConstraint(0).getContext(),
- StandardAttributes::IntegerSet, value);
+ return Base::get(value.getConstraint(0).getContext(), value);
}
IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
@@ -392,14 +390,12 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
MLIRContext *context) {
- return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
- type);
+ return Base::get(context, dialect, attrData, type);
}
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
Type type, Location location) {
- return Base::getChecked(location, StandardAttributes::Opaque, dialect,
- attrData, type);
+ return Base::getChecked(location, dialect, attrData, type);
}
/// Returns the dialect namespace of the opaque attribute.
@@ -430,7 +426,7 @@ StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
/// Get an instance of a StringAttr with the given string and Type.
StringAttr StringAttr::get(StringRef bytes, Type type) {
- return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
+ return Base::get(type.getContext(), bytes, type);
}
StringRef StringAttr::getValue() const { return getImpl()->value; }
@@ -440,7 +436,7 @@ StringRef StringAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
TypeAttr TypeAttr::get(Type value) {
- return Base::get(value.getContext(), StandardAttributes::Type, value);
+ return Base::get(value.getContext(), value);
}
Type TypeAttr::getValue() const { return getImpl()->value; }
@@ -1036,8 +1032,7 @@ DenseElementsAttr DenseElementsAttr::mapValues(
DenseStringElementsAttr
DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
- return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
- type, values, (values.size() == 1));
+ return Base::get(type.getContext(), type, values, (values.size() == 1));
}
//===----------------------------------------------------------------------===//
@@ -1088,8 +1083,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
- return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
- type, data, isSplat);
+ return Base::get(type.getContext(), type, data, isSplat);
}
/// Overload of the raw 'get' method that asserts that the given type is of
@@ -1210,8 +1204,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
StringRef bytes) {
assert(TensorType::isValidElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
- return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
- dialect, bytes);
+ return Base::get(type.getContext(), type, dialect, bytes);
}
StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
@@ -1248,7 +1241,7 @@ SparseElementsAttr SparseElementsAttr::get(ShapedType type,
assert((type.isa<RankedTensorType, VectorType>()) &&
"type must be ranked tensor or vector");
assert(type.hasStaticShape() && "type must have static shape");
- return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
+ return Base::get(type.getContext(), type,
indices.cast<DenseIntElementsAttr>(), values);
}
# NOTE(review): Location.cpp -- the CallSiteLoc, FileLineColLoc, FusedLoc,
# NameLoc and OpaqueLoc factories drop their StandardAttributes::*Location kind
# argument; the remaining arguments are forwarded to Base::get unchanged.
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index 48b05ba0eb40..151e2cf0cd61 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -28,8 +28,7 @@ bool LocationAttr::classof(Attribute attr) {
//===----------------------------------------------------------------------===//
Location CallSiteLoc::get(Location callee, Location caller) {
- return Base::get(callee->getContext(), StandardAttributes::CallSiteLocation,
- callee, caller);
+ return Base::get(callee->getContext(), callee, caller);
}
Location CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
@@ -50,8 +49,7 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
Location FileLineColLoc::get(Identifier filename, unsigned line,
unsigned column, MLIRContext *context) {
- return Base::get(context, StandardAttributes::FileLineColLocation, filename,
- line, column);
+ return Base::get(context, filename, line, column);
}
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
@@ -95,7 +93,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
return UnknownLoc::get(context);
if (locs.size() == 1)
return locs.front();
- return Base::get(context, StandardAttributes::FusedLocation, locs, metadata);
+ return Base::get(context, locs, metadata);
}
ArrayRef<Location> FusedLoc::getLocations() const {
@@ -111,8 +109,7 @@ Attribute FusedLoc::getMetadata() const { return getImpl()->metadata; }
Location NameLoc::get(Identifier name, Location child) {
assert(!child.isa<NameLoc>() &&
"a NameLoc cannot be used as a child of another NameLoc");
- return Base::get(child->getContext(), StandardAttributes::NameLocation, name,
- child);
+ return Base::get(child->getContext(), name, child);
}
Location NameLoc::get(Identifier name, MLIRContext *context) {
@@ -131,9 +128,8 @@ Location NameLoc::getChildLoc() const { return getImpl()->child; }
Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
Location fallbackLocation) {
- return Base::get(fallbackLocation->getContext(),
- StandardAttributes::OpaqueLocation, underlyingLocation,
- typeID, fallbackLocation);
+ return Base::get(fallbackLocation->getContext(), underlyingLocation, typeID,
+ fallbackLocation);
}
uintptr_t OpaqueLoc::getUnderlyingLocation() const {
# NOTE(review): MLIRContext.cpp -- BuiltinDialect now registers its types before
# its attributes; the context's cached singleton types/attributes are created
# without kind arguments; affine storages are registered through
# registerParametricStorageType<T>(); Dialect::addType/addAttribute no longer
# register a storage uniquer; IntegerType::get/getChecked drop the kind; and
# NoneType::get falls back to Base::get when the context cache is not yet set.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 0d66070657aa..a86f27a7145f 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -87,6 +87,10 @@ namespace {
struct BuiltinDialect : public Dialect {
BuiltinDialect(MLIRContext *context)
: Dialect(/*name=*/"", context, TypeID::get<BuiltinDialect>()) {
+ addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
+ FunctionType, IndexType, IntegerType, MemRefType,
+ UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
+ TupleType, UnrankedTensorType, VectorType>();
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
@@ -95,11 +99,6 @@ struct BuiltinDialect : public Dialect {
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>();
- addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
- FunctionType, IndexType, IntegerType, MemRefType,
- UnrankedMemRefType, NoneType, OpaqueType, RankedTensorType,
- TupleType, UnrankedTensorType, VectorType>();
-
// TODO: These operations should be moved to a different dialect when they
// have been fully decoupled from the core.
addOperations<FuncOp, ModuleOp, ModuleTerminatorOp>();
@@ -363,56 +362,50 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
//// Types.
/// Floating-point Types.
- impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this, StandardTypes::BF16);
- impl->f16Ty = TypeUniquer::get<Float16Type>(this, StandardTypes::F16);
- impl->f32Ty = TypeUniquer::get<Float32Type>(this, StandardTypes::F32);
- impl->f64Ty = TypeUniquer::get<Float64Type>(this, StandardTypes::F64);
+ impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
+ impl->f16Ty = TypeUniquer::get<Float16Type>(this);
+ impl->f32Ty = TypeUniquer::get<Float32Type>(this);
+ impl->f64Ty = TypeUniquer::get<Float64Type>(this);
/// Index Type.
- impl->indexTy = TypeUniquer::get<IndexType>(this, StandardTypes::Index);
+ impl->indexTy = TypeUniquer::get<IndexType>(this);
/// Integer Types.
- impl->int1Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 1,
- IntegerType::Signless);
- impl->int8Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer, 8,
- IntegerType::Signless);
- impl->int16Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
- 16, IntegerType::Signless);
- impl->int32Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
- 32, IntegerType::Signless);
- impl->int64Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
- 64, IntegerType::Signless);
- impl->int128Ty = TypeUniquer::get<IntegerType>(this, StandardTypes::Integer,
- 128, IntegerType::Signless);
+ impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
+ impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
+ impl->int16Ty =
+ TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
+ impl->int32Ty =
+ TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
+ impl->int64Ty =
+ TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
+ impl->int128Ty =
+ TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
/// None Type.
- impl->noneType = TypeUniquer::get<NoneType>(this, StandardTypes::None);
+ impl->noneType = TypeUniquer::get<NoneType>(this);
//// Attributes.
//// Note: These must be registered after the types as they may generate one
//// of the above types internally.
/// Bool Attributes.
impl->falseAttr = AttributeUniquer::get<IntegerAttr>(
- this, StandardAttributes::Integer, impl->int1Ty,
- APInt(/*numBits=*/1, false))
+ this, impl->int1Ty, APInt(/*numBits=*/1, false))
.cast<BoolAttr>();
impl->trueAttr = AttributeUniquer::get<IntegerAttr>(
- this, StandardAttributes::Integer, impl->int1Ty,
- APInt(/*numBits=*/1, true))
+ this, impl->int1Ty, APInt(/*numBits=*/1, true))
.cast<BoolAttr>();
/// Unit Attribute.
- impl->unitAttr =
- AttributeUniquer::get<UnitAttr>(this, StandardAttributes::Unit);
+ impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
/// Unknown Location Attribute.
- impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(
- this, StandardAttributes::UnknownLocation);
+ impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
/// The empty dictionary attribute.
- impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
- this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
+ impl->emptyDictionaryAttr =
+ AttributeUniquer::get<DictionaryAttr>(this, ArrayRef<NamedAttribute>());
// Register the affine storage objects with the uniquer.
- impl->affineUniquer.registerStorageType(
- TypeID::get<AffineBinaryOpExprStorage>());
- impl->affineUniquer.registerStorageType(
- TypeID::get<AffineConstantExprStorage>());
- impl->affineUniquer.registerStorageType(TypeID::get<AffineDimExprStorage>());
+ impl->affineUniquer
+ .registerParametricStorageType<AffineBinaryOpExprStorage>();
+ impl->affineUniquer
+ .registerParametricStorageType<AffineConstantExprStorage>();
+ impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
}
MLIRContext::~MLIRContext() {}
@@ -582,7 +575,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
AbstractType(std::move(typeInfo));
if (!impl.registeredTypes.insert({typeID, newInfo}).second)
llvm::report_fatal_error("Dialect Type already registered.");
- impl.typeUniquer.registerStorageType(typeID);
}
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
@@ -592,7 +584,6 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
AbstractAttribute(std::move(attrInfo));
if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
llvm::report_fatal_error("Dialect Attribute already registered.");
- impl.attributeUniquer.registerStorageType(typeID);
}
/// Get the dialect that registered the attribute with the provided typeid.
@@ -718,7 +709,7 @@ IntegerType IntegerType::get(unsigned width,
MLIRContext *context) {
if (auto cached = getCachedIntegerType(width, signedness, context))
return cached;
- return Base::get(context, StandardTypes::Integer, width, signedness);
+ return Base::get(context, width, signedness);
}
IntegerType IntegerType::getChecked(unsigned width, Location location) {
@@ -731,12 +722,16 @@ IntegerType IntegerType::getChecked(unsigned width,
if (auto cached =
getCachedIntegerType(width, signedness, location->getContext()))
return cached;
- return Base::getChecked(location, StandardTypes::Integer, width, signedness);
+ return Base::getChecked(location, width, signedness);
}
/// Get an instance of the NoneType.
NoneType NoneType::get(MLIRContext *context) {
- return context->getImpl().noneType;
+ if (NoneType cachedInst = context->getImpl().noneType)
+ return cachedInst;
+ // Note: May happen when initializing the singleton attributes of the builtin
+ // dialect.
+ return Base::get(context);
}
//===----------------------------------------------------------------------===//
# NOTE(review): StandardTypes.cpp -- the ComplexType, VectorType,
# RankedTensorType, UnrankedTensorType, MemRefType, UnrankedMemRefType and
# TupleType factories drop their StandardTypes::<Kind> argument; the remaining
# arguments are forwarded to Base::get / Base::getChecked unchanged.
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index f07532434a54..8eb9025b7735 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -102,12 +102,11 @@ unsigned Type::getIntOrFloatBitWidth() {
//===----------------------------------------------------------------------===//
ComplexType ComplexType::get(Type elementType) {
- return Base::get(elementType.getContext(), StandardTypes::Complex,
- elementType);
+ return Base::get(elementType.getContext(), elementType);
}
ComplexType ComplexType::getChecked(Type elementType, Location location) {
- return Base::getChecked(location, StandardTypes::Complex, elementType);
+ return Base::getChecked(location, elementType);
}
/// Verify the construction of an integer type.
@@ -265,13 +264,12 @@ bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
//===----------------------------------------------------------------------===//
VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
- return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
- elementType);
+ return Base::get(elementType.getContext(), shape, elementType);
}
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
Location location) {
- return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
+ return Base::getChecked(location, shape, elementType);
}
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
@@ -320,15 +318,13 @@ bool TensorType::isValidElementType(Type type) {
RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
Type elementType) {
- return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
- elementType);
+ return Base::get(elementType.getContext(), shape, elementType);
}
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
Type elementType,
Location location) {
- return Base::getChecked(location, StandardTypes::RankedTensor, shape,
- elementType);
+ return Base::getChecked(location, shape, elementType);
}
LogicalResult RankedTensorType::verifyConstructionInvariants(
@@ -349,13 +345,12 @@ ArrayRef<int64_t> RankedTensorType::getShape() const {
//===----------------------------------------------------------------------===//
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
- return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
- elementType);
+ return Base::get(elementType.getContext(), elementType);
}
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
Location location) {
- return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
+ return Base::getChecked(location, elementType);
}
LogicalResult
@@ -444,8 +439,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
cleanedAffineMapComposition.push_back(map);
}
- return Base::get(context, StandardTypes::MemRef, shape, elementType,
- cleanedAffineMapComposition, memorySpace);
+ return Base::get(context, shape, elementType, cleanedAffineMapComposition,
+ memorySpace);
}
ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
@@ -462,15 +457,13 @@ unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
unsigned memorySpace) {
- return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
- elementType, memorySpace);
+ return Base::get(elementType.getContext(), elementType, memorySpace);
}
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
unsigned memorySpace,
Location location) {
- return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
- memorySpace);
+ return Base::getChecked(location, elementType, memorySpace);
}
unsigned UnrankedMemRefType::getMemorySpace() const {
@@ -642,7 +635,7 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
- return Base::get(context, StandardTypes::Tuple, elementTypes);
+ return Base::get(context, elementTypes);
}
/// Get or create an empty tuple type.
# NOTE(review): Types.cpp -- removes the out-of-line Type::getKind() definition;
# FunctionType::get and OpaqueType::get/getChecked drop their Type::Kind (or
# Kind) argument and forward the remaining arguments unchanged.
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index ae2dd909ff59..cdcd6a9c6ea5 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -19,8 +19,6 @@ using namespace mlir::detail;
// Type
//===----------------------------------------------------------------------===//
-unsigned Type::getKind() const { return impl->getKind(); }
-
Dialect &Type::getDialect() const {
return impl->getAbstractType().getDialect();
}
@@ -33,7 +31,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
MLIRContext *context) {
- return Base::get(context, Type::Kind::Function, inputs, results);
+ return Base::get(context, inputs, results);
}
unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
@@ -54,12 +52,12 @@ ArrayRef<Type> FunctionType::getResults() const {
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
MLIRContext *context) {
- return Base::get(context, Type::Kind::Opaque, dialect, typeData);
+ return Base::get(context, dialect, typeData);
}
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
MLIRContext *context, Location location) {
- return Base::getChecked(location, Kind::Opaque, dialect, typeData);
+ return Base::getChecked(location, dialect, typeData);
}
/// Returns the dialect namespace of the opaque type.
# NOTE(review): StorageUniquer.cpp -- InstSpecificUniquer becomes
# ParametricStorageUniquer (kind-free lookup keys, a single instance set). The
# old per-kind "simple" path is replaced by singleton instances that are created
# once at registration (registerSingletonImpl) and afterwards only looked up
# (getSingleton), which must therefore never insert into the map.
diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index 49e7272091fb..73578b5c91ac 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -16,19 +16,17 @@ using namespace mlir;
using namespace mlir::detail;
namespace {
-/// This class represents a uniquer for storage instances of a specific type. It
-/// contains all of the necessary data to unique storage instances in a thread
-/// safe way. This allows for the main uniquer to bucket each of the individual
-/// sub-types removing the need to lock the main uniquer itself.
-struct InstSpecificUniquer {
+/// This class represents a uniquer for storage instances of a specific type
+/// that has parametric storage. It contains all of the necessary data to unique
+/// storage instances in a thread safe way. This allows for the main uniquer to
+/// bucket each of the individual sub-types removing the need to lock the main
+/// uniquer itself.
+struct ParametricStorageUniquer {
using BaseStorage = StorageUniquer::BaseStorage;
using StorageAllocator = StorageUniquer::StorageAllocator;
/// A lookup key for derived instances of storage objects.
struct LookupKey {
- /// The known derived kind for the storage.
- unsigned kind;
-
/// The known hash value of the key.
unsigned hashValue;
@@ -63,18 +61,14 @@ struct InstSpecificUniquer {
static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
return false;
- // If the lookup kind matches the kind of the storage, then invoke the
- // equality function on the lookup key.
- return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
+ // Invoke the equality function on the lookup key.
+ return lhs.isEqual(rhs.storage);
}
};
- /// Unique types with specific hashing or storage constraints.
+ /// The set containing the allocated storage instances.
using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
- StorageTypeSet complexInstances;
-
- /// Instances of this storage object.
- llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
+ StorageTypeSet instances;
/// Allocator to use when constructing derived instances.
StorageAllocator allocator;
@@ -91,107 +85,79 @@ struct StorageUniquerImpl {
using BaseStorage = StorageUniquer::BaseStorage;
using StorageAllocator = StorageUniquer::StorageAllocator;
- /// Get or create an instance of a complex derived type.
+ //===--------------------------------------------------------------------===//
+ // Parametric Storage
+ //===--------------------------------------------------------------------===//
+
+ /// Get or create an instance of a parametric type.
BaseStorage *
- getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
+ getOrCreate(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
- assert(instUniquers.count(id) && "creating unregistered storage instance");
- InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
- InstSpecificUniquer &storageUniquer = *instUniquers[id];
+ assert(parametricUniquers.count(id) &&
+ "creating unregistered storage instance");
+ ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
+ ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
if (!threadingIsEnabled)
- return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+ return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
// Check for an existing instance in read-only mode.
{
llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
- auto it = storageUniquer.complexInstances.find_as(lookupKey);
- if (it != storageUniquer.complexInstances.end())
+ auto it = storageUniquer.instances.find_as(lookupKey);
+ if (it != storageUniquer.instances.end())
return it->storage;
}
// Acquire a writer-lock so that we can safely create the new type instance.
llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
- return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+ return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
}
- /// Get or create an instance of a complex derived type in an thread-unsafe
- /// fashion.
+ /// Get or create an instance of a parametric type in a thread-unsafe
+ /// fashion.
BaseStorage *
- getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
- InstSpecificUniquer::LookupKey &lookupKey,
+ getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer,
+ ParametricStorageUniquer::LookupKey &lookupKey,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
- auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
+ auto existing = storageUniquer.instances.insert_as({}, lookupKey);
if (!existing.second)
return existing.first->storage;
// Otherwise, construct and initialize the derived storage for this type
// instance.
- BaseStorage *storage =
- initializeStorage(kind, storageUniquer.allocator, ctorFn);
+ BaseStorage *storage = ctorFn(storageUniquer.allocator);
*existing.first =
- InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
+ ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage};
return storage;
}
- /// Get or create an instance of a simple derived type.
- BaseStorage *
- getOrCreate(TypeID id, unsigned kind,
- function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
- assert(instUniquers.count(id) && "creating unregistered storage instance");
- InstSpecificUniquer &storageUniquer = *instUniquers[id];
- if (!threadingIsEnabled)
- return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
-
- // Check for an existing instance in read-only mode.
- {
- llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
- auto it = storageUniquer.simpleInstances.find(kind);
- if (it != storageUniquer.simpleInstances.end())
- return it->second;
- }
-
- // Acquire a writer-lock so that we can safely create the new type instance.
- llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
- return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
- }
- /// Get or create an instance of a simple derived type in an thread-unsafe
- /// fashion.
- BaseStorage *
- getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
- function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
- auto &result = storageUniquer.simpleInstances[kind];
- if (result)
- return result;
-
- // Otherwise, create and return a new storage instance.
- return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
- }
-
- /// Erase an instance of a complex derived type.
- void erase(TypeID id, unsigned kind, unsigned hashValue,
+ /// Erase an instance of a parametric derived type.
+ void erase(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<void(BaseStorage *)> cleanupFn) {
- assert(instUniquers.count(id) && "erasing unregistered storage instance");
- InstSpecificUniquer &storageUniquer = *instUniquers[id];
- InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
+ assert(parametricUniquers.count(id) &&
+ "erasing unregistered storage instance");
+ ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
+ ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
// Acquire a writer-lock so that we can safely erase the type instance.
llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
- auto existing = storageUniquer.complexInstances.find_as(lookupKey);
- if (existing == storageUniquer.complexInstances.end())
+ auto existing = storageUniquer.instances.find_as(lookupKey);
+ if (existing == storageUniquer.instances.end())
return;
// Cleanup the storage and remove it from the map.
cleanupFn(existing->storage);
- storageUniquer.complexInstances.erase(existing);
+ storageUniquer.instances.erase(existing);
}
/// Mutates an instance of a derived storage in a thread-safe way.
LogicalResult
mutate(TypeID id,
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
- assert(instUniquers.count(id) && "mutating unregistered storage instance");
- InstSpecificUniquer &storageUniquer = *instUniquers[id];
+ assert(parametricUniquers.count(id) &&
+ "mutating unregistered storage instance");
+ ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
if (!threadingIsEnabled)
return mutationFn(storageUniquer.allocator);
@@ -200,20 +166,31 @@ struct StorageUniquerImpl {
}
//===--------------------------------------------------------------------===//
- // Instance Storage
+ // Singleton Storage
//===--------------------------------------------------------------------===//
- /// Utility to create and initialize a storage instance.
- BaseStorage *
- initializeStorage(unsigned kind, StorageAllocator &allocator,
- function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
- BaseStorage *storage = ctorFn(allocator);
- storage->kind = kind;
- return storage;
+ /// Get the registered instance of a singleton storage class. This path takes
+ /// no lock, so use the non-inserting `lookup` rather than `operator[]`.
+ BaseStorage *getSingleton(TypeID id) {
+ BaseStorage *singletonInstance = singletonInstances.lookup(id);
+ assert(singletonInstance && "expected singleton instance to exist");
+ return singletonInstance;
}
+ //===--------------------------------------------------------------------===//
+ // Instance Storage
+ //===--------------------------------------------------------------------===//
+
/// Map of type ids to the storage uniquer to use for registered objects.
- DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
+ DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
+ parametricUniquers;
+
+ /// Map of type ids to a singleton instance when the storage class is a
+ /// singleton.
+ DenseMap<TypeID, BaseStorage *> singletonInstances;
+
+ /// Allocator used for uniquing singleton instances.
+ StorageAllocator singletonAllocator;
/// Flag specifying if multi-threading is enabled within the uniquer.
bool threadingIsEnabled = true;
@@ -229,41 +205,47 @@ void StorageUniquer::disableMultithreading(bool disable) {
impl->threadingIsEnabled = !disable;
}
-/// Register a new storage object with this uniquer using the given unique type
-/// id.
-void StorageUniquer::registerStorageType(TypeID id) {
- impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
-}
-
/// Implementation for getting/creating an instance of a derived type with
-/// complex storage.
-auto StorageUniquer::getImpl(
- const TypeID &id, unsigned kind, unsigned hashValue,
+/// parametric storage.
+auto StorageUniquer::getParametricStorageTypeImpl(
+ TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
- return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
+ return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
}
-/// Implementation for getting/creating an instance of a derived type with
-/// default storage.
-auto StorageUniquer::getImpl(
- const TypeID &id, unsigned kind,
- function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
- return impl->getOrCreate(id, kind, ctorFn);
+/// Implementation for registering an instance of a derived type with
+/// parametric storage.
+void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
+ impl->parametricUniquers.try_emplace(
+ id, std::make_unique<ParametricStorageUniquer>());
+}
+
+/// Implementation for getting the registered instance of a derived type with
+/// singleton storage.
+auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
+ return impl->getSingleton(id);
+}
+
+/// Implementation for registering an instance of a derived type with
+/// singleton storage.
+void StorageUniquer::registerSingletonImpl(
+ TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+ assert(!impl->singletonInstances.count(id) &&
+ "storage class already registered");
+ impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
}
-/// Implementation for erasing an instance of a derived type with complex
+/// Implementation for erasing an instance of a derived type with parametric
/// storage.
-void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
- unsigned hashValue,
+void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue,
function_ref<bool(const BaseStorage *)> isEqual,
function_ref<void(BaseStorage *)> cleanupFn) {
- impl->erase(id, kind, hashValue, isEqual, cleanupFn);
+ impl->erase(id, hashValue, isEqual, cleanupFn);
}
/// Implementation for mutating an instance of a derived storage.
LogicalResult StorageUniquer::mutateImpl(
- const TypeID &id,
- function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
+ TypeID id, function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
return impl->mutate(id, mutationFn);
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index c873a009f151..7bea72dbc040 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -156,7 +156,7 @@ static Type parseTestType(DialectAsmParser &parser,
StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
- auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
+ auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
// If this type already has been parsed above in the stack, expect just the
// name.
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 1df165591672..c7fd80ef1fb9 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -26,10 +26,6 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
TestTypeInterface::Trait> {
using Base::Base;
- static TestType get(MLIRContext *context) {
- return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE);
- }
-
/// Provide a definition for the necessary interface methods.
void printTypeC(Location loc) const {
emitRemark(loc) << *this << " - TestC";
@@ -72,9 +68,8 @@ class TestRecursiveType
public:
using Base::Base;
- static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
- return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
- name);
+ static TestRecursiveType get(MLIRContext *ctx, StringRef name) {
+ return Base::get(ctx, name);
}
/// Body getter and setter.
diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp
index f62c06eededf..37e322c610c1 100644
--- a/mlir/test/lib/IR/TestTypes.cpp
+++ b/mlir/test/lib/IR/TestTypes.cpp
@@ -41,7 +41,7 @@ struct TestRecursiveTypesPass
LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
MLIRContext *ctx = &getContext();
FuncOp func = getFunction();
- auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
+ auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
if (failed(type.setBody(type)))
return func.emitError("expected to be able to set the type body");
@@ -56,7 +56,7 @@ LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
"not expected to be able to change function body more than once");
// Expecting to get the same type for the same name.
- auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
+ auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
if (type != other)
return func.emitError("expected type name to be the uniquing key");
More information about the flang-commits
mailing list