[Mlir-commits] [mlir] 86646be - [mlir] Refactor StorageUniquer to require registration of possible storage types

River Riddle llvmlistbot at llvm.org
Fri Aug 7 13:43:42 PDT 2020


Author: River Riddle
Date: 2020-08-07T13:43:24-07:00
New Revision: 86646be3158933330bf3342e9d7e4250945bb70c

URL: https://github.com/llvm/llvm-project/commit/86646be3158933330bf3342e9d7e4250945bb70c
DIFF: https://github.com/llvm/llvm-project/commit/86646be3158933330bf3342e9d7e4250945bb70c.diff

LOG: [mlir] Refactor StorageUniquer to require registration of possible storage types

This allows for bucketing the different possible storage types, with each bucket having its own allocator/mutex/instance map. This greatly reduces the amount of lock contention when multi-threading is enabled. On some non-trivial .mlir modules (>300K operations), this reduced the compile time of a single conversion pass by around half a second (>25%).

Differential Revision: https://reviews.llvm.org/D82596

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
    mlir/include/mlir/IR/AttributeSupport.h
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/IR/TypeSupport.h
    mlir/include/mlir/Support/StorageUniquer.h
    mlir/lib/Dialect/SDBM/SDBMDialect.cpp
    mlir/lib/Dialect/SDBM/SDBMExpr.cpp
    mlir/lib/IR/AffineExpr.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Support/StorageUniquer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
index 901ada94482f..85cfe91d2c9b 100644
--- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
+++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h
@@ -17,8 +17,7 @@ class MLIRContext;
 
 class SDBMDialect : public Dialect {
 public:
-  SDBMDialect(MLIRContext *context)
-      : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {}
+  SDBMDialect(MLIRContext *context);
 
   /// Since there are no other virtual methods in this derived class, override
   /// the destructor so that key methods get defined in the corresponding

diff  --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 72a89be43867..79ce1dd2db95 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -133,17 +133,19 @@ class AttributeUniquer {
   template <typename T, typename... Args>
   static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
     return ctx->getAttributeUniquer().get<typename T::ImplType>(
+        T::getTypeID(),
         [ctx](AttributeStorage *storage) {
           initializeAttributeStorage(storage, ctx, T::getTypeID());
         },
         kind, std::forward<Args>(args)...);
   }
 
-  template <typename ImplType, typename... Args>
-  static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
+  template <typename T, typename... Args>
+  static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
                               Args &&...args) {
     assert(impl && "cannot mutate null attribute");
-    return ctx->getAttributeUniquer().mutate(impl, std::forward<Args>(args)...);
+    return ctx->getAttributeUniquer().mutate(T::getTypeID(), impl,
+                                             std::forward<Args>(args)...);
   }
 
 private:

diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 4c7693c28d2f..707d4d7f9ad6 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -109,8 +109,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// The arguments are forwarded to 'ConcreteT::mutate'.
   template <typename... Args>
   LogicalResult mutate(Args &&...args) {
-    return UniquerT::mutate(this->getContext(), getImpl(),
-                            std::forward<Args>(args)...);
+    return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
+                                                std::forward<Args>(args)...);
   }
 
   /// Default implementation that just returns success.

diff  --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h
index ddb91e09dc89..0e1a6c72c11d 100644
--- a/mlir/include/mlir/IR/TypeSupport.h
+++ b/mlir/include/mlir/IR/TypeSupport.h
@@ -127,6 +127,7 @@ struct TypeUniquer {
   template <typename T, typename... Args>
   static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
     return ctx->getTypeUniquer().get<typename T::ImplType>(
+        T::getTypeID(),
         [&](TypeStorage *storage) {
           storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
         },
@@ -135,11 +136,12 @@ struct TypeUniquer {
 
   /// Change the mutable component of the given type instance in the provided
   /// context.
-  template <typename ImplType, typename... Args>
-  static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
+  template <typename T, typename... Args>
+  static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
                               Args &&...args) {
     assert(impl && "cannot mutate null type");
-    return ctx->getTypeUniquer().mutate(impl, std::forward<Args>(args)...);
+    return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
+                                        std::forward<Args>(args)...);
   }
 };
 } // namespace detail

diff  --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 3100b4454197..6c7c7b0496da 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -15,6 +15,8 @@
 #include "llvm/Support/Allocator.h"
 
 namespace mlir {
+class TypeID;
+
 namespace detail {
 struct StorageUniquerImpl;
 
@@ -75,6 +77,10 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
 ///      value of the function is used to indicate whether the mutation was
 ///      successful, e.g., to limit the number of mutations or enable deferred
 ///      one-time assignment of the mutable component.
+///
+/// All storage classes must be registered with the uniquer via
+/// `registerStorageType` using an appropriate unique `TypeID` for the storage
+/// class.
 class StorageUniquer {
 public:
   StorageUniquer();
@@ -83,6 +89,10 @@ class StorageUniquer {
   /// Set the flag specifying if multi-threading is disabled within the uniquer.
   void disableMultithreading(bool disable = true);
 
+  /// Register a new storage object with this uniquer using the given unique
+  /// type id.
+  void registerStorageType(TypeID id);
+
   /// This class acts as the base storage that all storage classes must derived
   /// from.
   class BaseStorage {
@@ -140,8 +150,8 @@ class StorageUniquer {
   /// function is used for derived types that have complex storage or uniquing
   /// constraints.
   template <typename Storage, typename Arg, typename... Args>
-  Storage *get(function_ref<void(Storage *)> initFn, unsigned kind, Arg &&arg,
-               Args &&... args) {
+  Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
+               unsigned kind, Arg &&arg, Args &&...args) {
     // Construct a value of the derived key type.
     auto derivedKey =
         getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
@@ -163,7 +173,8 @@ class StorageUniquer {
     };
 
     // Get an instance for the derived storage.
-    return static_cast<Storage *>(getImpl(kind, hashValue, isEqual, ctorFn));
+    return static_cast<Storage *>(
+        getImpl(id, kind, hashValue, isEqual, ctorFn));
   }
 
   /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter
@@ -171,31 +182,32 @@ class StorageUniquer {
   /// function is used for derived types that use no additional storage or
   /// uniquing outside of the kind.
   template <typename Storage>
-  Storage *get(function_ref<void(Storage *)> initFn, unsigned kind) {
+  Storage *get(const TypeID &id, function_ref<void(Storage *)> initFn,
+               unsigned kind) {
     auto ctorFn = [&](StorageAllocator &allocator) {
       auto *storage = new (allocator.allocate<Storage>()) Storage();
       if (initFn)
         initFn(storage);
       return storage;
     };
-    return static_cast<Storage *>(getImpl(kind, ctorFn));
+    return static_cast<Storage *>(getImpl(id, kind, ctorFn));
   }
 
   /// Changes the mutable component of 'storage' by forwarding the trailing
   /// arguments to the 'mutate' function of the derived class.
   template <typename Storage, typename... Args>
-  LogicalResult mutate(Storage *storage, Args &&...args) {
+  LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) {
     auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
       return static_cast<Storage &>(*storage).mutate(
           allocator, std::forward<Args>(args)...);
     };
-    return mutateImpl(mutationFn);
+    return mutateImpl(id, mutationFn);
   }
 
   /// Erases a uniqued instance of 'Storage'. This function is used for derived
   /// types that have complex storage or uniquing constraints.
   template <typename Storage, typename Arg, typename... Args>
-  void erase(unsigned kind, Arg &&arg, Args &&... args) {
+  void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) {
     // Construct a value of the derived key type.
     auto derivedKey =
         getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
@@ -209,7 +221,7 @@ class StorageUniquer {
     };
 
     // Attempt to erase the storage instance.
-    eraseImpl(kind, hashValue, isEqual, [](BaseStorage *storage) {
+    eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) {
       static_cast<Storage *>(storage)->cleanup();
     });
   }
@@ -217,24 +229,25 @@ class StorageUniquer {
 private:
   /// Implementation for getting/creating an instance of a derived type with
   /// complex storage.
-  BaseStorage *getImpl(unsigned kind, unsigned hashValue,
+  BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue,
                        function_ref<bool(const BaseStorage *)> isEqual,
                        function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
 
   /// Implementation for getting/creating an instance of a derived type with
   /// default storage.
-  BaseStorage *getImpl(unsigned kind,
+  BaseStorage *getImpl(const TypeID &id, unsigned kind,
                        function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
 
   /// Implementation for erasing an instance of a derived type with complex
   /// storage.
-  void eraseImpl(unsigned kind, unsigned hashValue,
+  void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue,
                  function_ref<bool(const BaseStorage *)> isEqual,
                  function_ref<void(BaseStorage *)> cleanupFn);
 
   /// Implementation for mutating an instance of a derived storage.
   LogicalResult
-  mutateImpl(function_ref<LogicalResult(StorageAllocator &)> mutationFn);
+  mutateImpl(const TypeID &id,
+             function_ref<LogicalResult(StorageAllocator &)> mutationFn);
 
   /// The internal implementation class.
   std::unique_ptr<detail::StorageUniquerImpl> impl;
@@ -249,7 +262,7 @@ class StorageUniquer {
   static typename std::enable_if<
       llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
       typename ImplTy::KeyTy>::type
-  getKey(Args &&... args) {
+  getKey(Args &&...args) {
     return ImplTy::getKey(args...);
   }
   /// If there is no 'ImplTy::getKey' method, then we try to directly construct
@@ -258,7 +271,7 @@ class StorageUniquer {
   static typename std::enable_if<
       !llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
       typename ImplTy::KeyTy>::type
-  getKey(Args &&... args) {
+  getKey(Args &&...args) {
     return typename ImplTy::KeyTy(args...);
   }
 

diff  --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
index 6306063181b3..09c9d1dfd3d8 100644
--- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp
@@ -7,7 +7,17 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SDBM/SDBMDialect.h"
+#include "SDBMExprDetail.h"
 
 using namespace mlir;
 
+SDBMDialect::SDBMDialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context, TypeID::get<SDBMDialect>()) {
+  uniquer.registerStorageType(TypeID::get<detail::SDBMBinaryExprStorage>());
+  uniquer.registerStorageType(TypeID::get<detail::SDBMConstantExprStorage>());
+  uniquer.registerStorageType(TypeID::get<detail::SDBMDiffExprStorage>());
+  uniquer.registerStorageType(TypeID::get<detail::SDBMNegExprStorage>());
+  uniquer.registerStorageType(TypeID::get<detail::SDBMTermExprStorage>());
+}
+
 SDBMDialect::~SDBMDialect() = default;

diff  --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
index 5d60158c34e4..435c7fe25f0c 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -246,6 +246,7 @@ SDBMSumExpr SDBMSumExpr::get(SDBMTermExpr lhs, SDBMConstantExpr rhs) {
 
   StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMBinaryExprStorage>(
+      TypeID::get<detail::SDBMBinaryExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
 }
 
@@ -533,6 +534,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMDirectExpr lhs, SDBMTermExpr rhs) {
 
   StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMDiffExprStorage>(
+      TypeID::get<detail::SDBMDiffExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
 }
 
@@ -573,6 +575,7 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMDirectExpr var,
 
   StorageUniquer &uniquer = var.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMBinaryExprStorage>(
+      TypeID::get<detail::SDBMBinaryExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
       stripeFactor);
 }
@@ -608,7 +611,8 @@ SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
 
   StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMTermExprStorage>(
-      assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
+      TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
+      static_cast<unsigned>(SDBMExprKind::DimId), position);
 }
 
 //===----------------------------------------------------------------------===//
@@ -624,7 +628,8 @@ SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
 
   StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMTermExprStorage>(
-      assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+      TypeID::get<detail::SDBMTermExprStorage>(), assignDialect,
+      static_cast<unsigned>(SDBMExprKind::SymbolId), position);
 }
 
 //===----------------------------------------------------------------------===//
@@ -640,7 +645,8 @@ SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
 
   StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMConstantExprStorage>(
-      assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value);
+      TypeID::get<detail::SDBMConstantExprStorage>(), assignCtx,
+      static_cast<unsigned>(SDBMExprKind::Constant), value);
 }
 
 int64_t SDBMConstantExpr::getValue() const {
@@ -656,6 +662,7 @@ SDBMNegExpr SDBMNegExpr::get(SDBMDirectExpr var) {
 
   StorageUniquer &uniquer = var.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMNegExprStorage>(
+      TypeID::get<detail::SDBMNegExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
 }
 

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c78e7e1eac57..83d080f17d7d 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/MathExtras.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
@@ -448,7 +449,8 @@ static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
 
   StorageUniquer &uniquer = context->getAffineUniquer();
   return uniquer.get<AffineDimExprStorage>(
-      assignCtx, static_cast<unsigned>(kind), position);
+      TypeID::get<AffineDimExprStorage>(), assignCtx,
+      static_cast<unsigned>(kind), position);
 }
 
 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
@@ -483,7 +485,8 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
 
   StorageUniquer &uniquer = context->getAffineUniquer();
   return uniquer.get<AffineConstantExprStorage>(
-      assignCtx, static_cast<unsigned>(AffineExprKind::Constant), constant);
+      TypeID::get<AffineConstantExprStorage>(), assignCtx,
+      static_cast<unsigned>(AffineExprKind::Constant), constant);
 }
 
 /// Simplify add expression. Return nullptr if it can't be simplified.
@@ -591,6 +594,7 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
+      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
 }
 
@@ -651,6 +655,7 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
+      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
 }
 
@@ -717,6 +722,7 @@ AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
+      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
       other);
 }
@@ -760,6 +766,7 @@ AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
+      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
       other);
 }
@@ -807,6 +814,7 @@ AffineExpr AffineExpr::operator%(AffineExpr other) const {
 
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
+      TypeID::get<AffineBinaryOpExprStorage>(),
       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
 }
 

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 32ee60f5512a..df58b957bc32 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -402,6 +402,13 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
   /// The empty dictionary attribute.
   impl->emptyDictionaryAttr = AttributeUniquer::get<DictionaryAttr>(
       this, StandardAttributes::Dictionary, ArrayRef<NamedAttribute>());
+
+  // Register the affine storage objects with the uniquer.
+  impl->affineUniquer.registerStorageType(
+      TypeID::get<AffineBinaryOpExprStorage>());
+  impl->affineUniquer.registerStorageType(
+      TypeID::get<AffineConstantExprStorage>());
+  impl->affineUniquer.registerStorageType(TypeID::get<AffineDimExprStorage>());
 }
 
 MLIRContext::~MLIRContext() {}
@@ -571,6 +578,7 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
           AbstractType(std::move(typeInfo));
   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
     llvm::report_fatal_error("Dialect Type already registered.");
+  impl.typeUniquer.registerStorageType(typeID);
 }
 
 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
@@ -580,6 +588,7 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
           AbstractAttribute(std::move(attrInfo));
   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
     llvm::report_fatal_error("Dialect Attribute already registered.");
+  impl.attributeUniquer.registerStorageType(typeID);
 }
 
 /// Get the dialect that registered the attribute with the provided typeid.

diff  --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index f7c953e98140..49e7272091fb 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -9,15 +9,18 @@
 #include "mlir/Support/StorageUniquer.h"
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/Support/RWMutex.h"
 
 using namespace mlir;
 using namespace mlir::detail;
 
-namespace mlir {
-namespace detail {
-/// This is the implementation of the StorageUniquer class.
-struct StorageUniquerImpl {
+namespace {
+/// This class represents a uniquer for storage instances of a specific type. It
+/// contains all of the data necessary to unique storage instances in a
+/// thread-safe way. This allows the main uniquer to bucket each of the
+/// individual sub-types, removing the need to lock the main uniquer itself.
+struct InstSpecificUniquer {
   using BaseStorage = StorageUniquer::BaseStorage;
   using StorageAllocator = StorageUniquer::StorageAllocator;
 
@@ -40,98 +43,160 @@ struct StorageUniquerImpl {
     BaseStorage *storage;
   };
 
+  /// DenseSet key info for the hashed storage instances held by this uniquer.
+  struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
+    static HashedStorage getEmptyKey() {
+      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+    }
+    static HashedStorage getTombstoneKey() {
+      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+    }
+
+    static unsigned getHashValue(const HashedStorage &key) {
+      return key.hashValue;
+    }
+    static unsigned getHashValue(LookupKey key) { return key.hashValue; }
+
+    static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
+      return lhs.storage == rhs.storage;
+    }
+    static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
+      if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
+        return false;
+      // If the lookup kind matches the kind of the storage, then invoke the
+      // equality function on the lookup key.
+      return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
+    }
+  };
+
+  /// Unique types with specific hashing or storage constraints.
+  using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
+  StorageTypeSet complexInstances;
+
+  /// Instances of this storage object.
+  llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
+
+  /// Allocator to use when constructing derived instances.
+  StorageAllocator allocator;
+
+  /// A mutex to keep type uniquing thread-safe.
+  llvm::sys::SmartRWMutex<true> mutex;
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace detail {
+/// This is the implementation of the StorageUniquer class.
+struct StorageUniquerImpl {
+  using BaseStorage = StorageUniquer::BaseStorage;
+  using StorageAllocator = StorageUniquer::StorageAllocator;
+
   /// Get or create an instance of a complex derived type.
   BaseStorage *
-  getOrCreate(unsigned kind, unsigned hashValue,
+  getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
               function_ref<bool(const BaseStorage *)> isEqual,
               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    LookupKey lookupKey{kind, hashValue, isEqual};
+    assert(instUniquers.count(id) && "creating unregistered storage instance");
+    InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
+    InstSpecificUniquer &storageUniquer = *instUniquers[id];
     if (!threadingIsEnabled)
-      return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn);
+      return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
 
     // Check for an existing instance in read-only mode.
     {
-      llvm::sys::SmartScopedReader<true> typeLock(mutex);
-      auto it = storageTypes.find_as(lookupKey);
-      if (it != storageTypes.end())
+      llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
+      auto it = storageUniquer.complexInstances.find_as(lookupKey);
+      if (it != storageUniquer.complexInstances.end())
         return it->storage;
     }
 
     // Acquire a writer-lock so that we can safely create the new type instance.
-    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-    return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn);
+    llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
+    return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
   }
-  /// Get or create an instance of a complex derived type in an unsafe fashion.
+  /// Get or create an instance of a complex derived type in a thread-unsafe
+  /// fashion.
   BaseStorage *
-  getOrCreateUnsafe(unsigned kind, unsigned hashValue, LookupKey &lookupKey,
+  getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
+                    InstSpecificUniquer::LookupKey &lookupKey,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto existing = storageTypes.insert_as({}, lookupKey);
+    auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
     if (!existing.second)
       return existing.first->storage;
 
     // Otherwise, construct and initialize the derived storage for this type
     // instance.
-    BaseStorage *storage = initializeStorage(kind, ctorFn);
-    *existing.first = HashedStorage{hashValue, storage};
+    BaseStorage *storage =
+        initializeStorage(kind, storageUniquer.allocator, ctorFn);
+    *existing.first =
+        InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
     return storage;
   }
 
   /// Get or create an instance of a simple derived type.
   BaseStorage *
-  getOrCreate(unsigned kind,
+  getOrCreate(TypeID id, unsigned kind,
               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
+    assert(instUniquers.count(id) && "creating unregistered storage instance");
+    InstSpecificUniquer &storageUniquer = *instUniquers[id];
     if (!threadingIsEnabled)
-      return getOrCreateUnsafe(kind, ctorFn);
+      return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
 
     // Check for an existing instance in read-only mode.
     {
-      llvm::sys::SmartScopedReader<true> typeLock(mutex);
-      auto it = simpleTypes.find(kind);
-      if (it != simpleTypes.end())
+      llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
+      auto it = storageUniquer.simpleInstances.find(kind);
+      if (it != storageUniquer.simpleInstances.end())
         return it->second;
     }
 
     // Acquire a writer-lock so that we can safely create the new type instance.
-    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-    return getOrCreateUnsafe(kind, ctorFn);
+    llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
+    return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
   }
-  /// Get or create an instance of a simple derived type in an unsafe fashion.
+  /// Get or create an instance of a simple derived type in a thread-unsafe
+  /// fashion.
   BaseStorage *
-  getOrCreateUnsafe(unsigned kind,
+  getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto &result = simpleTypes[kind];
+    auto &result = storageUniquer.simpleInstances[kind];
     if (result)
       return result;
 
     // Otherwise, create and return a new storage instance.
-    return result = initializeStorage(kind, ctorFn);
+    return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
   }
 
   /// Erase an instance of a complex derived type.
-  void erase(unsigned kind, unsigned hashValue,
+  void erase(TypeID id, unsigned kind, unsigned hashValue,
              function_ref<bool(const BaseStorage *)> isEqual,
              function_ref<void(BaseStorage *)> cleanupFn) {
-    LookupKey lookupKey{kind, hashValue, isEqual};
+    assert(instUniquers.count(id) && "erasing unregistered storage instance");
+    InstSpecificUniquer &storageUniquer = *instUniquers[id];
+    InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
 
     // Acquire a writer-lock so that we can safely erase the type instance.
-    llvm::sys::SmartScopedWriter<true> typeLock(mutex);
-    auto existing = storageTypes.find_as(lookupKey);
-    if (existing == storageTypes.end())
+    llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
+    auto existing = storageUniquer.complexInstances.find_as(lookupKey);
+    if (existing == storageUniquer.complexInstances.end())
       return;
 
     // Cleanup the storage and remove it from the map.
     cleanupFn(existing->storage);
-    storageTypes.erase(existing);
+    storageUniquer.complexInstances.erase(existing);
   }
 
   /// Mutates an instance of a derived storage in a thread-safe way.
   LogicalResult
-  mutate(function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
+  mutate(TypeID id,
+         function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
+    assert(instUniquers.count(id) && "mutating unregistered storage instance");
+    InstSpecificUniquer &storageUniquer = *instUniquers[id];
     if (!threadingIsEnabled)
-      return mutationFn(allocator);
+      return mutationFn(storageUniquer.allocator);
 
-    llvm::sys::SmartScopedWriter<true> lock(mutex);
-    return mutationFn(allocator);
+    llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
+    return mutationFn(storageUniquer.allocator);
   }
 
   //===--------------------------------------------------------------------===//
@@ -140,51 +205,15 @@ struct StorageUniquerImpl {
 
   /// Utility to create and initialize a storage instance.
   BaseStorage *
-  initializeStorage(unsigned kind,
+  initializeStorage(unsigned kind, StorageAllocator &allocator,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
     BaseStorage *storage = ctorFn(allocator);
     storage->kind = kind;
     return storage;
   }
 
-  /// Storage info for derived TypeStorage objects.
-  struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
-    static HashedStorage getEmptyKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
-    }
-    static HashedStorage getTombstoneKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
-    }
-
-    static unsigned getHashValue(const HashedStorage &key) {
-      return key.hashValue;
-    }
-    static unsigned getHashValue(LookupKey key) { return key.hashValue; }
-
-    static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
-      return lhs.storage == rhs.storage;
-    }
-    static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
-      if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
-        return false;
-      // If the lookup kind matches the kind of the storage, then invoke the
-      // equality function on the lookup key.
-      return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
-    }
-  };
-
-  /// Unique types with specific hashing or storage constraints.
-  using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
-  StorageTypeSet storageTypes;
-
-  /// Unique types with just the kind.
-  DenseMap<unsigned, BaseStorage *> simpleTypes;
-
-  /// Allocator to use when constructing derived type instances.
-  StorageUniquer::StorageAllocator allocator;
-
-  /// A mutex to keep type uniquing thread-safe.
-  llvm::sys::SmartRWMutex<true> mutex;
+  /// Map of type ids to the storage uniquer to use for registered objects.
+  DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
 
   /// Flag specifying if multi-threading is enabled within the uniquer.
   bool threadingIsEnabled = true;
@@ -200,33 +229,41 @@ void StorageUniquer::disableMultithreading(bool disable) {
   impl->threadingIsEnabled = !disable;
 }
 
+/// Register a new storage object with this uniquer using the given unique type
+/// id.
+void StorageUniquer::registerStorageType(TypeID id) {
+  impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
+}
+
 /// Implementation for getting/creating an instance of a derived type with
 /// complex storage.
 auto StorageUniquer::getImpl(
-    unsigned kind, unsigned hashValue,
+    const TypeID &id, unsigned kind, unsigned hashValue,
     function_ref<bool(const BaseStorage *)> isEqual,
     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
-  return impl->getOrCreate(kind, hashValue, isEqual, ctorFn);
+  return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
 }
 
 /// Implementation for getting/creating an instance of a derived type with
 /// default storage.
 auto StorageUniquer::getImpl(
-    unsigned kind, function_ref<BaseStorage *(StorageAllocator &)> ctorFn)
-    -> BaseStorage * {
-  return impl->getOrCreate(kind, ctorFn);
+    const TypeID &id, unsigned kind,
+    function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
+  return impl->getOrCreate(id, kind, ctorFn);
 }
 
 /// Implementation for erasing an instance of a derived type with complex
 /// storage.
-void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue,
+void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
+                               unsigned hashValue,
                                function_ref<bool(const BaseStorage *)> isEqual,
                                function_ref<void(BaseStorage *)> cleanupFn) {
-  impl->erase(kind, hashValue, isEqual, cleanupFn);
+  impl->erase(id, kind, hashValue, isEqual, cleanupFn);
 }
 
 /// Implementation for mutating an instance of a derived storage.
 LogicalResult StorageUniquer::mutateImpl(
+    const TypeID &id,
     function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
-  return impl->mutate(mutationFn);
+  return impl->mutate(id, mutationFn);
 }


        


More information about the Mlir-commits mailing list