[Mlir-commits] [mlir] a517191 - [mlir][NFC] Refactor ClassID into a TypeID class.
River Riddle
llvmlistbot at llvm.org
Fri Apr 10 23:57:23 PDT 2020
Author: River Riddle
Date: 2020-04-10T23:52:33-07:00
New Revision: a517191a474f7d6867621d0f8e8cc454c27334bf
URL: https://github.com/llvm/llvm-project/commit/a517191a474f7d6867621d0f8e8cc454c27334bf
DIFF: https://github.com/llvm/llvm-project/commit/a517191a474f7d6867621d0f8e8cc454c27334bf.diff
LOG: [mlir][NFC] Refactor ClassID into a TypeID class.
Summary: ClassID is a bit janky right now as it involves passing a magic pointer around. This revision hides the internal implementation mechanism within a new class TypeID. This class is a value-typed wrapper around the original ClassID implementation.
Differential Revision: https://reviews.llvm.org/D77768
Added:
mlir/include/mlir/Support/TypeID.h
Modified:
mlir/docs/WritingAPass.md
mlir/include/mlir/IR/AttributeSupport.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectHooks.h
mlir/include/mlir/IR/DialectInterface.h
mlir/include/mlir/IR/Location.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/IR/TypeSupport.h
mlir/include/mlir/Interfaces/SideEffects.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassInstrumentation.h
mlir/include/mlir/Pass/PassRegistry.h
mlir/include/mlir/Support/STLExtras.h
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/Location.cpp
mlir/lib/IR/LocationDetail.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassRegistry.cpp
mlir/lib/Pass/PassTiming.cpp
mlir/tools/mlir-tblgen/PassGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/WritingAPass.md b/mlir/docs/WritingAPass.md
index 4fd9ceda32fd..0da9fd4911d8 100644
--- a/mlir/docs/WritingAPass.md
+++ b/mlir/docs/WritingAPass.md
@@ -753,8 +753,8 @@ struct DominanceCounterInstrumentation : public PassInstrumentation {
unsigned &count;
DominanceCounterInstrumentation(unsigned &count) : count(count) {}
- void runAfterAnalysis(llvm::StringRef, AnalysisID *id, Operation *) override {
- if (id == AnalysisID::getID<DominanceInfo>())
+ void runAfterAnalysis(llvm::StringRef, TypeID id, Operation *) override {
+ if (id == TypeID::get<DominanceInfo>())
++count;
}
};
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 98b0c6370801..0289e9ce3175 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -92,13 +92,13 @@ class AttributeUniquer {
template <typename T, typename... Args>
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
return ctx->getAttributeUniquer().get<typename T::ImplType>(
- getInitFn(ctx, T::getClassID()), kind, std::forward<Args>(args)...);
+ getInitFn(ctx, T::getTypeID()), kind, std::forward<Args>(args)...);
}
private:
/// Returns a functor used to initialize new attribute storage instances.
- static std::function<void(AttributeStorage *)>
- getInitFn(MLIRContext *ctx, const ClassID *const attrID);
+ static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx,
+ TypeID attrID);
};
} // namespace detail
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index f65060238171..292c14db3f1f 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -160,7 +160,7 @@ class Dialect {
/// Lookup an interface for the given ID if one is registered, otherwise
/// nullptr.
- const DialectInterface *getRegisteredInterface(ClassID *interfaceID) {
+ const DialectInterface *getRegisteredInterface(TypeID interfaceID) {
auto it = registeredInterfaces.find(interfaceID);
return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
}
@@ -190,12 +190,12 @@ class Dialect {
/// This method is used by derived classes to add their types to the set.
template <typename... Args> void addTypes() {
- (void)std::initializer_list<int>{0, (addSymbol(Args::getClassID()), 0)...};
+ (void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...};
}
/// This method is used by derived classes to add their attributes to the set.
template <typename... Args> void addAttributes() {
- (void)std::initializer_list<int>{0, (addSymbol(Args::getClassID()), 0)...};
+ (void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...};
}
/// Enable support for unregistered operations.
@@ -215,7 +215,7 @@ class Dialect {
private:
// Register a symbol(e.g. type) with its given unique class identifier.
- void addSymbol(const ClassID *const classID);
+ void addSymbol(TypeID typeID);
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;
@@ -241,14 +241,14 @@ class Dialect {
bool unknownTypesAllowed = false;
/// A collection of registered dialect interfaces.
- DenseMap<ClassID *, std::unique_ptr<DialectInterface>> registeredInterfaces;
+ DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
/// Registers a specific dialect creation function with the global registry.
/// Used through the registerDialect template.
- /// Registrations are deduplicated by dialect ClassID and only the first
+ /// Registrations are deduplicated by dialect TypeID and only the first
/// registration will be used.
static void
- registerDialectAllocator(const ClassID *classId,
+ registerDialectAllocator(TypeID typeID,
const DialectAllocatorFunction &function);
template <typename ConcreteDialect>
friend void registerDialect();
@@ -260,7 +260,7 @@ void registerAllDialects(MLIRContext *context);
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
template <typename ConcreteDialect> void registerDialect() {
- Dialect::registerDialectAllocator(ClassID::getID<ConcreteDialect>(),
+ Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
[](MLIRContext *ctx) {
// Just allocate the dialect, the context
// takes ownership of it.
diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h
index 4e59b4953e65..39862667fd75 100644
--- a/mlir/include/mlir/IR/DialectHooks.h
+++ b/mlir/include/mlir/IR/DialectHooks.h
@@ -38,15 +38,15 @@ class DialectHooks {
private:
/// Registers a function that will set hooks in the registered dialects.
- /// Registrations are deduplicated by dialect ClassID and only the first
+ /// Registrations are deduplicated by dialect TypeID and only the first
/// registration will be used.
- static void registerDialectHooksSetter(const ClassID *classId,
+ static void registerDialectHooksSetter(TypeID typeID,
const DialectHooksSetter &function);
template <typename ConcreteHooks>
friend void registerDialectHooks(StringRef dialectName);
};
-void registerDialectHooksSetter(const ClassID *classId,
+void registerDialectHooksSetter(TypeID typeID,
const DialectHooksSetter &function);
/// Utility to register dialect hooks. Client can register their dialect hooks
@@ -55,7 +55,7 @@ void registerDialectHooksSetter(const ClassID *classId,
template <typename ConcreteHooks>
void registerDialectHooks(StringRef dialectName) {
DialectHooks::registerDialectHooksSetter(
- ClassID::getID<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
+ TypeID::get<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
Dialect *dialect = ctx->getRegisteredDialect(dialectName);
if (!dialect) {
llvm::errs() << "error: cannot register hooks for unknown dialect '"
diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h
index 5fc2a8849c1d..79826c2f6ffc 100644
--- a/mlir/include/mlir/IR/DialectInterface.h
+++ b/mlir/include/mlir/IR/DialectInterface.h
@@ -9,7 +9,7 @@
#ifndef MLIR_IR_DIALECTINTERFACE_H
#define MLIR_IR_DIALECTINTERFACE_H
-#include "mlir/Support/STLExtras.h"
+#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
@@ -29,7 +29,7 @@ class DialectInterfaceBase : public BaseT {
using Base = DialectInterfaceBase<ConcreteType, BaseT>;
/// Get a unique id for the derived interface type.
- static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+ static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
protected:
DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {}
@@ -50,10 +50,10 @@ class DialectInterface {
Dialect *getDialect() const { return dialect; }
/// Return the derived interface id.
- ClassID *getID() const { return interfaceID; }
+ TypeID getID() const { return interfaceID; }
protected:
- DialectInterface(Dialect *dialect, ClassID *id)
+ DialectInterface(Dialect *dialect, TypeID id)
: dialect(dialect), interfaceID(id) {}
private:
@@ -61,7 +61,7 @@ class DialectInterface {
Dialect *dialect;
/// The unique identifier for the derived interface type.
- ClassID *interfaceID;
+ TypeID interfaceID;
};
//===----------------------------------------------------------------------===//
@@ -93,7 +93,7 @@ class DialectInterfaceCollectionBase {
using InterfaceVectorT = std::vector<const DialectInterface *>;
public:
- DialectInterfaceCollectionBase(MLIRContext *ctx, ClassID *interfaceKind);
+ DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind);
virtual ~DialectInterfaceCollectionBase();
protected:
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index e0853611b67c..ea6c954f9aac 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -237,7 +237,7 @@ class OpaqueLoc : public Attribute::AttrBase<OpaqueLoc, LocationAttr,
template <typename T>
static Location get(T underlyingLocation, MLIRContext *context) {
return get(reinterpret_cast<uintptr_t>(underlyingLocation),
- ClassID::getID<T>(), UnknownLoc::get(context));
+ TypeID::get<T>(), UnknownLoc::get(context));
}
/// Returns an instance of opaque location which contains a given pointer to
@@ -245,7 +245,7 @@ class OpaqueLoc : public Attribute::AttrBase<OpaqueLoc, LocationAttr,
template <typename T>
static Location get(T underlyingLocation, Location fallbackLocation) {
return get(reinterpret_cast<uintptr_t>(underlyingLocation),
- ClassID::getID<T>(), fallbackLocation);
+ TypeID::get<T>(), fallbackLocation);
}
/// Returns a pointer to some data structure that opaque location stores.
@@ -270,14 +270,14 @@ class OpaqueLoc : public Attribute::AttrBase<OpaqueLoc, LocationAttr,
/// to an object of particular type.
template <typename T> static bool isa(Location location) {
auto opaque_loc = location.dyn_cast<OpaqueLoc>();
- return opaque_loc && opaque_loc.getClassId() == ClassID::getID<T>();
+ return opaque_loc && opaque_loc.getUnderlyingTypeID() == TypeID::get<T>();
}
/// Returns a pointer to the corresponding object.
uintptr_t getUnderlyingLocation() const;
- /// Returns a ClassID* that represents the underlying objects c++ type.
- ClassID *getClassId() const;
+ /// Returns a TypeID that represents the underlying objects c++ type.
+ TypeID getUnderlyingTypeID() const;
/// Returns a fallback location.
Location getFallbackLocation() const;
@@ -288,7 +288,7 @@ class OpaqueLoc : public Attribute::AttrBase<OpaqueLoc, LocationAttr,
}
private:
- static Location get(uintptr_t underlyingLocation, ClassID *classID,
+ static Location get(uintptr_t underlyingLocation, TypeID typeID,
Location fallbackLocation);
};
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 38403e86f952..e8944c7002b7 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1185,7 +1185,7 @@ class Op : public OpState,
/// Return true if this "op class" can match against the specified operation.
static bool classof(Operation *op) {
if (auto *abstractOp = op->getAbstractOperation())
- return ClassID::getID<ConcreteType>() == abstractOp->classID;
+ return TypeID::get<ConcreteType>() == abstractOp->typeID;
return op->getName().getStringRef() == ConcreteType::getOperationName();
}
@@ -1278,15 +1278,15 @@ class Op : public OpState,
}
};
- /// Returns true if this operation contains the trait for the given classID.
- static bool hasTrait(ClassID *traitID) {
- return llvm::is_contained(llvm::makeArrayRef({ClassID::getID<Traits>()...}),
+ /// Returns true if this operation contains the trait for the given typeID.
+ static bool hasTrait(TypeID traitID) {
+ return llvm::is_contained(llvm::makeArrayRef({TypeID::get<Traits>()...}),
traitID);
}
/// Returns an opaque pointer to a concept instance of the interface with the
/// given ID if one was registered to this operation.
- static void *getRawInterface(ClassID *id) {
+ static void *getRawInterface(TypeID id) {
return InterfaceLookup::template lookup<Traits<ConcreteType>...>(id);
}
@@ -1300,7 +1300,7 @@ class Op : public OpState,
template <typename T>
static typename std::enable_if<is_detected<has_get_interface_id, T>::value,
void *>::type
- lookup(ClassID *interfaceID) {
+ lookup(TypeID interfaceID) {
return (T::getInterfaceID() == interfaceID) ? &T::instance() : nullptr;
}
@@ -1308,12 +1308,12 @@ class Op : public OpState,
template <typename T>
static typename std::enable_if<!is_detected<has_get_interface_id, T>::value,
void *>::type
- lookup(ClassID *) {
+ lookup(TypeID) {
return nullptr;
}
template <typename T, typename T2, typename... Ts>
- static void *lookup(ClassID *interfaceID) {
+ static void *lookup(TypeID interfaceID) {
auto *concept = lookup<T>(interfaceID);
return concept ? concept : lookup<T2, Ts...>(interfaceID);
}
@@ -1359,14 +1359,14 @@ class OpInterface : public Op<ConcreteType> {
static bool classof(Operation *op) { return getInterfaceFor(op); }
/// Define an accessor for the ID of this interface.
- static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+ static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
/// This is a special trait that registers a given interface with an
/// operation.
template <typename ConcreteOp>
struct Trait : public OpTrait::TraitBase<ConcreteOp, Trait> {
/// Define an accessor for the ID of this interface.
- static ClassID *getInterfaceID() { return ClassID::getID<ConcreteType>(); }
+ static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }
/// Provide an accessor to a static instance of the interface model for the
/// concrete operation type.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 8d6bda3dbba0..6efdff2fb5a0 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -94,7 +94,7 @@ class AbstractOperation {
Dialect &dialect;
/// The unique identifier of the derived Op class.
- ClassID *classID;
+ TypeID typeID;
/// Use the specified object to parse this ops custom assembly format.
ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result);
@@ -149,7 +149,7 @@ class AbstractOperation {
/// Returns if the operation has a particular trait.
template <template <typename T> class Trait> bool hasTrait() const {
- return hasRawTrait(ClassID::getID<Trait>());
+ return hasRawTrait(TypeID::get<Trait>());
}
/// Look up the specified operation in the specified MLIRContext and return a
@@ -162,7 +162,7 @@ class AbstractOperation {
template <typename T> static AbstractOperation get(Dialect &dialect) {
return AbstractOperation(
T::getOperationName(), dialect, T::getOperationProperties(),
- ClassID::getID<T>(), T::parseAssembly, T::printAssembly,
+ TypeID::get<T>(), T::parseAssembly, T::printAssembly,
T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
T::getRawInterface, T::hasTrait);
}
@@ -170,7 +170,7 @@ class AbstractOperation {
private:
AbstractOperation(
StringRef name, Dialect &dialect, OperationProperties opProperties,
- ClassID *classID,
+ TypeID typeID,
ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
void (&printAssembly)(Operation *op, OpAsmPrinter &p),
LogicalResult (&verifyInvariants)(Operation *op),
@@ -178,9 +178,9 @@ class AbstractOperation {
SmallVectorImpl<OpFoldResult> &results),
void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
MLIRContext *context),
- void *(&getRawInterface)(ClassID *interfaceID),
- bool (&hasTrait)(ClassID *traitID))
- : name(name), dialect(dialect), classID(classID),
+ void *(&getRawInterface)(TypeID interfaceID),
+ bool (&hasTrait)(TypeID traitID))
+ : name(name), dialect(dialect), typeID(typeID),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), foldHook(foldHook),
getCanonicalizationPatterns(getCanonicalizationPatterns),
@@ -193,11 +193,11 @@ class AbstractOperation {
/// Returns a raw instance of the concept for the given interface id if it is
/// registered to this operation, nullptr otherwise. This should not be used
/// directly.
- void *(&getRawInterface)(ClassID *interfaceID);
+ void *(&getRawInterface)(TypeID interfaceID);
/// This hook returns if the operation contains the trait corresponding
- /// to the given ClassID.
- bool (&hasRawTrait)(ClassID *traitID);
+ /// to the given TypeID.
+ bool (&hasRawTrait)(TypeID traitID);
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 7520f8e053ed..1b01b7a9970c 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -14,8 +14,8 @@
#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
#include "mlir/Support/StorageUniquer.h"
+#include "mlir/Support/TypeID.h"
namespace mlir {
class AttributeStorage;
@@ -41,7 +41,7 @@ class StorageUserBase : public BaseT {
using ImplType = StorageT;
/// Return a unique identifier for the concrete type.
- static ClassID *getClassID() { return ClassID::getID<ConcreteT>(); }
+ static TypeID getTypeID() { return TypeID::get<ConcreteT>(); }
/// Provide a default implementation of 'classof' that invokes a 'kindof'
/// method on the concrete type.
diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index 7d01883c1f86..cb093c2188d6 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -17,7 +17,6 @@
#include "mlir/IR/StorageUniquerSupport.h"
namespace mlir {
-struct ClassID;
class Dialect;
class MLIRContext;
@@ -36,7 +35,7 @@ class TypeStorage : public StorageUniquer::BaseStorage {
protected:
/// This constructor is used by derived classes as part of the TypeUniquer.
- /// When using this constructor, the initializeTypeInfo function must be
+ /// When using this constructor, the initializeDialect function must be
/// invoked afterwards for the storage to be valid.
TypeStorage(unsigned subclassData = 0)
: dialect(nullptr), subclassData(subclassData) {}
@@ -98,12 +97,11 @@ class TypeUniquer {
private:
/// Get the dialect that the type 'T' was registered with.
template <typename T> static Dialect &lookupDialectForType(MLIRContext *ctx) {
- return lookupDialectForType(ctx, T::getClassID());
+ return lookupDialectForType(ctx, T::getTypeID());
}
/// Get the dialect that registered the type with the provided typeid.
- static Dialect &lookupDialectForType(MLIRContext *ctx,
- const ClassID *const typeID);
+ static Dialect &lookupDialectForType(MLIRContext *ctx, TypeID typeID);
};
} // namespace detail
diff --git a/mlir/include/mlir/Interfaces/SideEffects.h b/mlir/include/mlir/Interfaces/SideEffects.h
index e0fd17590fdb..f3f4c44238a2 100644
--- a/mlir/include/mlir/Interfaces/SideEffects.h
+++ b/mlir/include/mlir/Interfaces/SideEffects.h
@@ -32,7 +32,7 @@ class Effect {
using BaseT = Base<DerivedEffect>;
/// Return the unique identifier for the base effects class.
- static ClassID *getEffectID() { return ClassID::getID<DerivedEffect>(); }
+ static TypeID getEffectID() { return TypeID::get<DerivedEffect>(); }
/// 'classof' used to support llvm style cast functionality.
static bool classof(const ::mlir::SideEffects::Effect *effect) {
@@ -46,11 +46,11 @@ class Effect {
using BaseEffect::get;
protected:
- Base() : BaseEffect(BaseT::getEffectID()){};
+ Base() : BaseEffect(BaseT::getEffectID()) {}
};
/// Return the unique identifier for the base effects class.
- ClassID *getEffectID() const { return id; }
+ TypeID getEffectID() const { return id; }
/// Returns a unique instance for the given effect class.
template <typename DerivedEffect> static DerivedEffect *get() {
@@ -62,11 +62,11 @@ class Effect {
}
protected:
- Effect(ClassID *id) : id(id) {}
+ Effect(TypeID id) : id(id) {}
private:
/// The id of the derived effect class.
- ClassID *id;
+ TypeID id;
};
//===----------------------------------------------------------------------===//
@@ -92,9 +92,7 @@ class Resource {
}
/// Return the unique identifier for the base resource class.
- static ClassID *getResourceID() {
- return ClassID::getID<DerivedResource>();
- }
+ static TypeID getResourceID() { return TypeID::get<DerivedResource>(); }
/// 'classof' used to support llvm style cast functionality.
static bool classof(const Resource *resource) {
@@ -106,17 +104,17 @@ class Resource {
};
/// Return the unique identifier for the base resource class.
- ClassID *getResourceID() const { return id; }
+ TypeID getResourceID() const { return id; }
/// Return a string name of the resource.
virtual StringRef getName() = 0;
protected:
- Resource(ClassID *id) : id(id) {}
+ Resource(TypeID id) : id(id) {}
private:
/// The id of the derived resource class.
- ClassID *id;
+ TypeID id;
};
/// A conservative default resource kind.
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index a50bee000b25..2b948c2f73da 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -17,10 +17,6 @@
#include "llvm/Support/TypeName.h"
namespace mlir {
-/// A special type used by analyses to provide an address that identifies a
-/// particular analysis set or a concrete analysis type.
-using AnalysisID = ClassID;
-
//===----------------------------------------------------------------------===//
// Analysis Preservation and Concept Modeling
//===----------------------------------------------------------------------===//
@@ -28,43 +24,43 @@ using AnalysisID = ClassID;
namespace detail {
/// A utility class to represent the analyses that are known to be preserved.
class PreservedAnalyses {
+ /// A type used to represent all potential analyses.
+ struct AllAnalysesType;
+
public:
/// Mark all analyses as preserved.
- void preserveAll() { preservedIDs.insert(&allAnalysesID); }
+ void preserveAll() { preservedIDs.insert(TypeID::get<AllAnalysesType>()); }
/// Returns true if all analyses were marked preserved.
- bool isAll() const { return preservedIDs.count(&allAnalysesID); }
+ bool isAll() const {
+ return preservedIDs.count(TypeID::get<AllAnalysesType>());
+ }
/// Returns true if no analyses were marked preserved.
bool isNone() const { return preservedIDs.empty(); }
/// Preserve the given analyses.
template <typename AnalysisT> void preserve() {
- preserve(AnalysisID::getID<AnalysisT>());
+ preserve(TypeID::get<AnalysisT>());
}
template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT>
void preserve() {
preserve<AnalysisT>();
preserve<AnalysisT2, OtherAnalysesT...>();
}
- void preserve(const AnalysisID *id) { preservedIDs.insert(id); }
+ void preserve(TypeID id) { preservedIDs.insert(id); }
/// Returns if the given analysis has been marked as preserved. Note that this
/// simply checks for the presence of a given analysis ID and should not be
/// used as a general preservation checker.
template <typename AnalysisT> bool isPreserved() const {
- return isPreserved(AnalysisID::getID<AnalysisT>());
- }
- bool isPreserved(const AnalysisID *id) const {
- return preservedIDs.count(id);
+ return isPreserved(TypeID::get<AnalysisT>());
}
+ bool isPreserved(TypeID id) const { return preservedIDs.count(id); }
private:
- /// An identifier used to represent all potential analyses.
- constexpr static AnalysisID allAnalysesID = {};
-
/// The set of analyses that are known to be preserved.
- SmallPtrSet<const void *, 2> preservedIDs;
+ SmallPtrSet<TypeID, 2> preservedIDs;
};
namespace analysis_impl {
@@ -118,8 +114,7 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
/// computation, caching, and invalidation of analyses takes place here.
class AnalysisMap {
/// A mapping between an analysis id and an existing analysis instance.
- using ConceptMap =
- DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
+ using ConceptMap = DenseMap<TypeID, std::unique_ptr<AnalysisConcept>>;
/// Utility to return the name of the given analysis class.
template <typename AnalysisT> static StringRef getAnalysisName() {
@@ -134,7 +129,7 @@ class AnalysisMap {
/// Get an analysis for the current IR unit, computing it if necessary.
template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
- auto *id = AnalysisID::getID<AnalysisT>();
+ TypeID id = TypeID::get<AnalysisT>();
typename ConceptMap::iterator it;
bool wasInserted;
@@ -157,7 +152,7 @@ class AnalysisMap {
/// Get a cached analysis instance if one exists, otherwise return null.
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
- auto res = analyses.find(AnalysisID::getID<AnalysisT>());
+ auto res = analyses.find(TypeID::get<AnalysisT>());
if (res == analyses.end())
return llvm::None;
return {static_cast<AnalysisModel<AnalysisT> &>(*res->second).analysis};
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 9b9dd7533a1e..5e0098458755 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -37,22 +37,22 @@ struct PassExecutionState {
} // namespace detail
/// The abstract base pass class. This class contains information describing the
-/// derived pass object, e.g its kind and abstract PassInfo.
+/// derived pass object, e.g. its kind and abstract TypeID.
class Pass {
public:
virtual ~Pass() = default;
/// Returns the unique identifier that corresponds to this pass.
- const PassID *getPassID() const { return passID; }
+ TypeID getTypeID() const { return passID; }
/// Returns the pass info for the specified pass class or null if unknown.
- static const PassInfo *lookupPassInfo(const PassID *passID);
+ static const PassInfo *lookupPassInfo(TypeID passID);
template <typename PassT> static const PassInfo *lookupPassInfo() {
- return lookupPassInfo(PassID::getID<PassT>());
+ return lookupPassInfo(TypeID::get<PassT>());
}
/// Returns the pass info for this pass.
- const PassInfo *lookupPassInfo() const { return lookupPassInfo(getPassID()); }
+ const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); }
/// Returns the derived pass name.
virtual StringRef getName() const = 0;
@@ -130,7 +130,7 @@ class Pass {
MutableArrayRef<Statistic *> getStatistics() { return statistics; }
protected:
- explicit Pass(const PassID *passID, Optional<StringRef> opName = llvm::None)
+ explicit Pass(TypeID passID, Optional<StringRef> opName = llvm::None)
: passID(passID), opName(opName) {}
Pass(const Pass &other) : Pass(other.passID, other.opName) {}
@@ -183,7 +183,7 @@ class Pass {
template <typename... AnalysesT> void markAnalysesPreserved() {
getPassState().preservedAnalyses.preserve<AnalysesT...>();
}
- void markAnalysesPreserved(const AnalysisID *id) {
+ void markAnalysesPreserved(TypeID id) {
getPassState().preservedAnalyses.preserve(id);
}
@@ -234,7 +234,7 @@ class Pass {
virtual void anchor();
/// Represents a unique identifier for the pass.
- const PassID *passID;
+ TypeID passID;
/// The name of the operation that this pass operates on, or None if this is a
/// generic OperationPass.
@@ -274,7 +274,7 @@ class Pass {
/// - A 'std::unique_ptr<Pass> clonePass() const' method.
template <typename OpT = void> class OperationPass : public Pass {
protected:
- OperationPass(const PassID *passID) : Pass(passID, OpT::getOperationName()) {}
+ OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {}
/// Support isa/dyn_cast functionality.
static bool classof(const Pass *pass) {
@@ -299,7 +299,7 @@ template <typename OpT = void> class OperationPass : public Pass {
/// - A 'std::unique_ptr<Pass> clonePass() const' method.
template <> class OperationPass<void> : public Pass {
protected:
- OperationPass(const PassID *passID) : Pass(passID) {}
+ OperationPass(TypeID passID) : Pass(passID) {}
};
/// A model for providing function pass specific utilities.
@@ -333,11 +333,11 @@ template <typename PassT, typename BaseT> class PassWrapper : public BaseT {
public:
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const Pass *pass) {
- return pass->getPassID() == PassID::getID<PassT>();
+ return pass->getTypeID() == TypeID::get<PassT>();
}
protected:
- PassWrapper() : BaseT(PassID::getID<PassT>()) {}
+ PassWrapper() : BaseT(TypeID::get<PassT>()) {}
/// Returns the derived pass name.
StringRef getName() const override { return llvm::getTypeName<PassT>(); }
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index dc57a5ad7d72..dc648b2b0edf 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -10,12 +10,9 @@
#define MLIR_PASS_PASSINSTRUMENTATION_H_
#include "mlir/Support/LLVM.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/StringRef.h"
+#include "mlir/Support/TypeID.h"
namespace mlir {
-using AnalysisID = ClassID;
class Operation;
class OperationName;
class Pass;
@@ -72,16 +69,14 @@ class PassInstrumentation {
virtual void runAfterPassFailed(Pass *pass, Operation *op) {}
/// A callback to run before an analysis is computed. This function takes the
- /// name of the analysis to be computed, its AnalysisID, as well as the
+ /// name of the analysis to be computed, its TypeID, as well as the
/// current operation being analyzed.
- virtual void runBeforeAnalysis(StringRef name, AnalysisID *id,
- Operation *op) {}
+ virtual void runBeforeAnalysis(StringRef name, TypeID id, Operation *op) {}
/// A callback to run before an analysis is computed. This function takes the
- /// name of the analysis that was computed, its AnalysisID, as well as the
+ /// name of the analysis that was computed, its TypeID, as well as the
/// current operation being analyzed.
- virtual void runAfterAnalysis(StringRef name, AnalysisID *id, Operation *op) {
- }
+ virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
};
/// This class holds a collection of PassInstrumentation objects, and invokes
@@ -113,10 +108,10 @@ class PassInstrumentor {
void runAfterPassFailed(Pass *pass, Operation *op);
/// See PassInstrumentation::runBeforeAnalysis for details.
- void runBeforeAnalysis(StringRef name, AnalysisID *id, Operation *op);
+ void runBeforeAnalysis(StringRef name, TypeID id, Operation *op);
/// See PassInstrumentation::runAfterAnalysis for details.
- void runAfterAnalysis(StringRef name, AnalysisID *id, Operation *op);
+ void runAfterAnalysis(StringRef name, TypeID id, Operation *op);
/// Add the given instrumentation to the collection.
void addInstrumentation(std::unique_ptr<PassInstrumentation> pi);
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index 31c09c125919..3187d8eb34c7 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -15,6 +15,7 @@
#define MLIR_PASS_PASSREGISTRY_H_
#include "mlir/Pass/PassOptions.h"
+#include "mlir/Support/TypeID.h"
#include <functional>
namespace mlir {
@@ -31,10 +32,6 @@ using PassRegistryFunction =
std::function<LogicalResult(OpPassManager &, StringRef options)>;
using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
-/// A special type used by transformation passes to provide an address that can
-/// act as a unique identifier during pass registration.
-using PassID = ClassID;
-
//===----------------------------------------------------------------------===//
// PassRegistry
//===----------------------------------------------------------------------===//
@@ -105,7 +102,7 @@ class PassInfo : public PassRegistryEntry {
public:
/// PassInfo constructor should not be invoked directly, instead use
/// PassRegistration or registerPass.
- PassInfo(StringRef arg, StringRef description, const PassID *passID,
+ PassInfo(StringRef arg, StringRef description, TypeID passID,
const PassAllocatorFunction &allocator);
};
diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h
index 6cd1047b1ea5..9a2b2d35bb49 100644
--- a/mlir/include/mlir/Support/STLExtras.h
+++ b/mlir/include/mlir/Support/STLExtras.h
@@ -88,23 +88,6 @@ inline void interleaveComma(const Container &c, raw_ostream &os) {
interleaveComma(c, os, [&](const T &a) { os << a; });
}
-/// A special type used to provide an address for a given class that can act as
-/// a unique identifier during pass registration.
-/// Note: We specify an explicit alignment here to allow use with PointerIntPair
-/// and other utilities/data structures that require a known pointer alignment.
-struct alignas(8) ClassID {
- template <typename T>
- LLVM_EXTERNAL_VISIBILITY static ClassID *getID() {
- static ClassID id;
- return &id;
- }
- template <template <typename T> class Trait>
- LLVM_EXTERNAL_VISIBILITY static ClassID *getID() {
- static ClassID id;
- return &id;
- }
-};
-
/// Utilities for detecting if a given trait holds for some set of arguments
/// 'Args'. For example, the given trait could be used to detect if a given type
/// has a copy assignment operator:
diff --git a/mlir/include/mlir/Support/TypeID.h b/mlir/include/mlir/Support/TypeID.h
new file mode 100644
index 000000000000..518ff39e8669
--- /dev/null
+++ b/mlir/include/mlir/Support/TypeID.h
@@ -0,0 +1,139 @@
+//===- TypeID.h - TypeID RTTI class -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a definition of the TypeID class. This provides a non-RTTI
+// mechanism for producing unique type IDs in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_TYPEID_H
+#define MLIR_SUPPORT_TYPEID_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/Support/PointerLikeTypeTraits.h"
+
+namespace mlir {
+
+/// This class provides an efficient unique identifier for a specific C++ type.
+/// This allows for a C++ type to be compared, hashed, and stored in an opaque
+/// context. This class is similar in some ways to std::type_index, but can be
+/// used for any type. For example, this class could be used to implement LLVM
+/// style isa/dyn_cast functionality for a type hierarchy:
+///
+///   struct Base {
+///     Base(TypeID typeID) : typeID(typeID) {}
+///     TypeID typeID;
+///   };
+///
+///   struct DerivedA : public Base {
+///     DerivedA() : Base(TypeID::get<DerivedA>()) {}
+///
+///     static bool classof(const Base *base) {
+///       return base->typeID == TypeID::get<DerivedA>();
+///     }
+///   };
+///
+///   void foo(Base *base) {
+///     if (DerivedA *a = llvm::dyn_cast<DerivedA>(base))
+///        ...
+///   }
+///
+class TypeID {
+  /// This class represents the storage of a type info object.
+  /// Note: We specify an explicit alignment here to allow use with
+  /// PointerIntPair and other utilities/data structures that require a known
+  /// pointer alignment.
+  struct alignas(8) Storage {};
+
+public:
+  /// A default-constructed TypeID compares equal to `TypeID::get<void>()`.
+  TypeID() : TypeID(get<void>()) {}
+  TypeID(const TypeID &) = default;
+
+  /// Comparison operations.
+  bool operator==(const TypeID &other) const {
+    return storage == other.storage;
+  }
+  bool operator!=(const TypeID &other) const { return !(*this == other); }
+
+  /// Construct a type info object for the given type T.
+  /// TODO: This currently won't work when using DLLs as it requires properly
+  /// attaching dllimport and dllexport. Fix this when that information is
+  /// available within LLVM.
+  template <typename T>
+  LLVM_EXTERNAL_VISIBILITY static TypeID get() {
+    static Storage instance;
+    return TypeID(&instance);
+  }
+  /// Construct a type info object for the given class template `Trait`. The
+  /// result identifies the template itself, not any particular instantiation.
+  template <template <typename> class Trait>
+  LLVM_EXTERNAL_VISIBILITY static TypeID get() {
+    static Storage instance;
+    return TypeID(&instance);
+  }
+
+  /// Methods for supporting PointerLikeTypeTraits.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(storage);
+  }
+  static TypeID getFromOpaquePointer(const void *pointer) {
+    return TypeID(reinterpret_cast<const Storage *>(pointer));
+  }
+
+  /// Enable hashing TypeID.
+  friend ::llvm::hash_code hash_value(TypeID id);
+
+private:
+  /// TypeIDs are only created through get() or getFromOpaquePointer().
+  explicit TypeID(const Storage *storage) : storage(storage) {}
+
+  /// The storage of this type info object.
+  const Storage *storage;
+};
+
+/// Enable hashing TypeID.
+inline ::llvm::hash_code hash_value(TypeID id) {
+  return llvm::hash_value(id.storage);
+}
+
+} // end namespace mlir
+
+namespace llvm {
+/// Allow TypeID to be used as a key in DenseMap/DenseSet.
+template <> struct DenseMapInfo<mlir::TypeID> {
+  static mlir::TypeID getEmptyKey() {
+    void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::TypeID::getFromOpaquePointer(pointer);
+  }
+  static mlir::TypeID getTombstoneKey() {
+    void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::TypeID::getFromOpaquePointer(pointer);
+  }
+  static unsigned getHashValue(mlir::TypeID val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::TypeID lhs, mlir::TypeID rhs) { return lhs == rhs; }
+};
+
+/// We align TypeID::Storage by 8, so allow LLVM to steal the low bits.
+template <> struct PointerLikeTypeTraits<mlir::TypeID> {
+  static inline void *getAsVoidPointer(mlir::TypeID info) {
+    return const_cast<void *>(info.getAsOpaquePointer());
+  }
+  static inline mlir::TypeID getFromVoidPointer(void *ptr) {
+    return mlir::TypeID::getFromOpaquePointer(ptr);
+  }
+  static constexpr int NumLowBitsAvailable = 3;
+};
+
+} // end namespace llvm
+
+#endif // MLIR_SUPPORT_TYPEID_H
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index e48e7f64010d..501cddad2a48 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -28,30 +28,29 @@ DialectAsmParser::~DialectAsmParser() {}
//===----------------------------------------------------------------------===//
/// Registry for all dialect allocation functions.
-static llvm::ManagedStatic<
- llvm::MapVector<const ClassID *, DialectAllocatorFunction>>
+static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
dialectRegistry;
/// Registry for functions that set dialect hooks.
-static llvm::ManagedStatic<llvm::MapVector<const ClassID *, DialectHooksSetter>>
+static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectHooksSetter>>
dialectHooksRegistry;
void Dialect::registerDialectAllocator(
- const ClassID *classId, const DialectAllocatorFunction &function) {
+ TypeID typeID, const DialectAllocatorFunction &function) {
assert(function &&
"Attempting to register an empty dialect initialize function");
- dialectRegistry->insert({classId, function});
+ dialectRegistry->insert({typeID, function});
}
/// Registers a function to set specific hooks for a specific dialect, typically
/// used through the DialectHooksRegistration template.
void DialectHooks::registerDialectHooksSetter(
- const ClassID *classId, const DialectHooksSetter &function) {
+ TypeID typeID, const DialectHooksSetter &function) {
assert(
function &&
"Attempting to register an empty dialect hooks initialization function");
- dialectHooksRegistry->insert({classId, function});
+ dialectHooksRegistry->insert({typeID, function});
}
/// Registers all dialects and hooks from the global registries with the
@@ -59,9 +58,8 @@ void DialectHooks::registerDialectHooksSetter(
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &it : *dialectRegistry)
it.second(context);
- for (const auto &it : *dialectHooksRegistry) {
+ for (const auto &it : *dialectHooksRegistry)
it.second(context);
- }
}
//===----------------------------------------------------------------------===//
@@ -139,7 +137,7 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
DialectInterface::~DialectInterface() {}
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
- MLIRContext *ctx, ClassID *interfaceKind) {
+ MLIRContext *ctx, TypeID interfaceKind) {
for (auto *dialect : ctx->getRegisteredDialects()) {
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index c9727c500858..f22fd5cb7852 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -119,18 +119,18 @@ Location NameLoc::getChildLoc() const { return getImpl()->child; }
// OpaqueLoc
//===----------------------------------------------------------------------===//
-Location OpaqueLoc::get(uintptr_t underlyingLocation, ClassID *classID,
+Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID,
Location fallbackLocation) {
return Base::get(fallbackLocation->getContext(),
StandardAttributes::OpaqueLocation, underlyingLocation,
- classID, fallbackLocation);
+ typeID, fallbackLocation);
}
uintptr_t OpaqueLoc::getUnderlyingLocation() const {
return Base::getImpl()->underlyingLocation;
}
-ClassID *OpaqueLoc::getClassId() const { return getImpl()->classId; }
+TypeID OpaqueLoc::getUnderlyingTypeID() const { return getImpl()->typeID; }
Location OpaqueLoc::getFallbackLocation() const {
return Base::getImpl()->fallbackLocation;
diff --git a/mlir/lib/IR/LocationDetail.h b/mlir/lib/IR/LocationDetail.h
index cf9a10903d88..c84f685d134a 100644
--- a/mlir/lib/IR/LocationDetail.h
+++ b/mlir/lib/IR/LocationDetail.h
@@ -126,15 +126,15 @@ struct NameLocationStorage : public AttributeStorage {
};
struct OpaqueLocationStorage : public AttributeStorage {
- OpaqueLocationStorage(uintptr_t underlyingLocation, ClassID *classId,
+ OpaqueLocationStorage(uintptr_t underlyingLocation, TypeID typeID,
Location fallbackLocation)
- : underlyingLocation(underlyingLocation), classId(classId),
+ : underlyingLocation(underlyingLocation), typeID(typeID),
fallbackLocation(fallbackLocation) {}
/// The hash key used for uniquing.
- using KeyTy = std::tuple<uintptr_t, ClassID *, Location>;
+ using KeyTy = std::tuple<uintptr_t, TypeID, Location>;
bool operator==(const KeyTy &key) const {
- return key == KeyTy(underlyingLocation, classId, fallbackLocation);
+ return key == KeyTy(underlyingLocation, typeID, fallbackLocation);
}
/// Construct a new storage instance.
@@ -149,7 +149,7 @@ struct OpaqueLocationStorage : public AttributeStorage {
uintptr_t underlyingLocation;
/// A unique pointer for each type of underlyingLocation.
- ClassID *classId;
+ TypeID typeID;
/// An additional location that can be used if the external one is not
/// suitable.
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2407c8f30c4f..1623122df39c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -205,9 +205,9 @@ class MLIRContextImpl {
/// operations.
llvm::StringMap<AbstractOperation> registeredOperations;
- /// This is a mapping from class identifier to Dialect for registered
- /// attributes and types.
- DenseMap<const ClassID *, Dialect *> registeredDialectSymbols;
+ /// This is a mapping from type id to Dialect for registered attributes and
+ /// types.
+ DenseMap<TypeID, Dialect *> registeredDialectSymbols;
/// These are identifiers uniqued into this MLIRContext.
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
@@ -462,12 +462,12 @@ void Dialect::addOperation(AbstractOperation opInfo) {
}
/// Register a dialect-specific symbol(e.g. type) with the current context.
-void Dialect::addSymbol(const ClassID *const classID) {
+void Dialect::addSymbol(TypeID typeID) {
auto &impl = context->getImpl();
// Lock access to the context registry.
llvm::sys::SmartScopedWriter<true> registryLock(impl.contextMutex);
- if (!impl.registeredDialectSymbols.insert({classID, this}).second) {
+ if (!impl.registeredDialectSymbols.insert({typeID, this}).second) {
llvm::errs() << "error: dialect symbol already registered.\n";
abort();
}
@@ -516,10 +516,9 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
// Type uniquing
//===----------------------------------------------------------------------===//
-static Dialect &lookupDialectForSymbol(MLIRContext *ctx,
- const ClassID *const classID) {
+static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) {
auto &impl = ctx->getImpl();
- auto it = impl.registeredDialectSymbols.find(classID);
+ auto it = impl.registeredDialectSymbols.find(typeID);
assert(it != impl.registeredDialectSymbols.end() &&
"symbol is not registered.");
return *it->second;
@@ -530,8 +529,7 @@ static Dialect &lookupDialectForSymbol(MLIRContext *ctx,
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
/// Get the dialect that registered the type with the provided typeid.
-Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx,
- const ClassID *const typeID) {
+Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx, TypeID typeID) {
return lookupDialectForSymbol(ctx, typeID);
}
@@ -625,7 +623,7 @@ StorageUniquer &MLIRContext::getAttributeUniquer() {
/// Returns a functor used to initialize new attribute storage instances.
std::function<void(AttributeStorage *)>
-AttributeUniquer::getInitFn(MLIRContext *ctx, const ClassID *const attrID) {
+AttributeUniquer::getInitFn(MLIRContext *ctx, TypeID attrID) {
return [ctx, attrID](AttributeStorage *storage) {
storage->initializeDialect(lookupDialectForSymbol(ctx, attrID));
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0c11973b9433..8bcf3b7a7a0f 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -738,7 +738,7 @@ void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
}
/// See PassInstrumentation::runBeforeAnalysis for details.
-void PassInstrumentor::runBeforeAnalysis(StringRef name, AnalysisID *id,
+void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : impl->instrumentations)
@@ -746,7 +746,7 @@ void PassInstrumentor::runBeforeAnalysis(StringRef name, AnalysisID *id,
}
/// See PassInstrumentation::runAfterAnalysis for details.
-void PassInstrumentor::runAfterAnalysis(StringRef name, AnalysisID *id,
+void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
Operation *op) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : llvm::reverse(impl->instrumentations))
@@ -759,5 +759,3 @@ void PassInstrumentor::addInstrumentation(
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
impl->instrumentations.emplace_back(std::move(pi));
}
-
-constexpr AnalysisID mlir::detail::PreservedAnalyses::allAnalysesID;
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 69962ecd20c8..1d88ebd82882 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -18,7 +18,7 @@ using namespace mlir;
using namespace detail;
/// Static mapping of all of the registered passes.
-static llvm::ManagedStatic<DenseMap<const PassID *, PassInfo>> passRegistry;
+static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
/// Static mapping of all of the registered pass pipelines.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
@@ -86,7 +86,7 @@ void mlir::registerPassPipeline(
// PassInfo
//===----------------------------------------------------------------------===//
-PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
+PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
const PassAllocatorFunction &allocator)
: PassRegistryEntry(
arg, description, buildDefaultRegistryFn(allocator),
@@ -98,13 +98,13 @@ PassInfo::PassInfo(StringRef arg, StringRef description, const PassID *passID,
void mlir::registerPass(StringRef arg, StringRef description,
const PassAllocatorFunction &function) {
// TODO: We should use the 'arg' as the lookup key instead of the pass id.
- const PassID *passID = function()->getPassID();
+ TypeID passID = function()->getTypeID();
PassInfo passInfo(arg, description, passID, function);
passRegistry->try_emplace(passID, passInfo);
}
/// Returns the pass info for the specified pass class or null if unknown.
-const PassInfo *mlir::Pass::lookupPassInfo(const PassID *passID) {
+const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
auto it = passRegistry->find(passID);
if (it == passRegistry->end())
return nullptr;
diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index 768c68620ab9..663cbdad7c39 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -173,10 +173,10 @@ struct PassTiming : public PassInstrumentation {
void runAfterPassFailed(Pass *pass, Operation *op) override {
runAfterPass(pass, op);
}
- void runBeforeAnalysis(StringRef name, AnalysisID *id, Operation *) override {
+ void runBeforeAnalysis(StringRef name, TypeID id, Operation *) override {
startAnalysisTimer(name, id);
}
- void runAfterAnalysis(StringRef, AnalysisID *, Operation *) override;
+ void runAfterAnalysis(StringRef, TypeID, Operation *) override;
/// Print and clear the timing results.
void print();
@@ -185,7 +185,7 @@ struct PassTiming : public PassInstrumentation {
void startPassTimer(Pass *pass);
/// Start a new timer for the given analysis.
- void startAnalysisTimer(StringRef name, AnalysisID *id);
+ void startAnalysisTimer(StringRef name, TypeID id);
/// Pop the last active timer for the current thread.
Timer *popLastActiveTimer() {
@@ -291,8 +291,8 @@ void PassTiming::startPassTimer(Pass *pass) {
}
/// Start a new timer for the given analysis.
-void PassTiming::startAnalysisTimer(StringRef name, AnalysisID *id) {
- Timer *timer = getTimer(id, TimerKind::PassOrAnalysis,
+void PassTiming::startAnalysisTimer(StringRef name, TypeID id) {
+ Timer *timer = getTimer(id.getAsOpaquePointer(), TimerKind::PassOrAnalysis,
[name] { return "(A) " + name.str(); });
timer->start();
}
@@ -320,7 +320,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) {
}
/// Stop a timer.
-void PassTiming::runAfterAnalysis(StringRef, AnalysisID *, Operation *) {
+void PassTiming::runAfterAnalysis(StringRef, TypeID, Operation *) {
popLastActiveTimer()->stop();
}
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index 7cd9590110ad..2ca8671a2f78 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -38,8 +38,8 @@ const char *const passDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
- {0}Base() : {1}(PassID::getID<DerivedT>()) {{}
- {0}Base(const {0}Base &) : {1}(PassID::getID<DerivedT>()) {{}
+ {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
+ {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}
/// Returns the command-line argument attached to this pass.
llvm::StringRef getArgument() const override { return "{2}"; }
@@ -49,7 +49,7 @@ class {0}Base : public {1} {
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
- return pass->getPassID() == ::mlir::PassID::getID<DerivedT>();
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// A clone method to create a copy of this pass.
More information about the Mlir-commits
mailing list