[Mlir-commits] [mlir] f3df3b5 - [mlir] Add a utility class, ThreadLocalCache, for storing non static thread local objects.

River Riddle llvmlistbot at llvm.org
Fri Oct 16 12:09:21 PDT 2020


Author: River Riddle
Date: 2020-10-16T12:08:48-07:00
New Revision: f3df3b58e7dd7c400f9c18d16d92631823705ebd

URL: https://github.com/llvm/llvm-project/commit/f3df3b58e7dd7c400f9c18d16d92631823705ebd
DIFF: https://github.com/llvm/llvm-project/commit/f3df3b58e7dd7c400f9c18d16d92631823705ebd.diff

LOG: [mlir] Add a utility class, ThreadLocalCache, for storing non static thread local objects.

(Note: This is a reland of D82597)

This class allows for defining thread local objects that have a set non-static lifetime. The internals of the cache use a static thread_local map between the various non-static objects and the desired value type. When a non-static object destructs, it simply nulls out the entry in the static map. This will leave an entry in the map, but erase any of the data for the associated value. The current use cases for this are in the MLIRContext, meaning that the number of items in the static map is ~1-2, which isn't costly enough to warrant the complexity of pruning. If a use case arises that requires pruning of the map, the functionality can be added.

This is especially useful in the context of MLIR for implementing thread-local caching of context level objects that would otherwise have very high lock contention. This revision adds a thread local cache in the MLIRContext for attributes, identifiers, and types to reduce some of the locking burden. This led to a speedup of several seconds when compiling a somewhat large mlir module.

Differential Revision: https://reviews.llvm.org/D89504

Added: 
    mlir/include/mlir/Support/ThreadLocalCache.h

Modified: 
    mlir/include/mlir/Support/StorageUniquer.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Support/StorageUniquer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index d0a6170805bf..a3429ac14e56 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -231,28 +231,6 @@ class StorageUniquer {
     return mutateImpl(id, mutationFn);
   }
 
-  /// Erases a uniqued instance of 'Storage'. This function is used for derived
-  /// types that have complex storage or uniquing constraints.
-  template <typename Storage, typename Arg, typename... Args>
-  void erase(TypeID id, Arg &&arg, Args &&...args) {
-    // Construct a value of the derived key type.
-    auto derivedKey =
-        getKey<Storage>(std::forward<Arg>(arg), std::forward<Args>(args)...);
-
-    // Create a hash of the derived key.
-    unsigned hashValue = getHash<Storage>(derivedKey);
-
-    // Generate an equality function for the derived storage.
-    auto isEqual = [&derivedKey](const BaseStorage *existing) {
-      return static_cast<const Storage &>(*existing) == derivedKey;
-    };
-
-    // Attempt to erase the storage instance.
-    eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) {
-      static_cast<Storage *>(storage)->cleanup();
-    });
-  }
-
 private:
   /// Implementation for getting/creating an instance of a derived type with
   /// parametric storage.
@@ -275,12 +253,6 @@ class StorageUniquer {
   registerSingletonImpl(TypeID id,
                         function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
 
-  /// Implementation for erasing an instance of a derived type with complex
-  /// storage.
-  void eraseImpl(TypeID id, unsigned hashValue,
-                 function_ref<bool(const BaseStorage *)> isEqual,
-                 function_ref<void(BaseStorage *)> cleanupFn);
-
   /// Implementation for mutating an instance of a derived storage.
   LogicalResult
   mutateImpl(TypeID id,

diff  --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
new file mode 100644
index 000000000000..3b5d6f0f424f
--- /dev/null
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -0,0 +1,117 @@
+//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a definition of the ThreadLocalCache class. This class
+// provides support for defining thread local objects with non-static duration.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
+#define MLIR_SUPPORT_THREADLOCALCACHE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/ThreadLocal.h"
+
+namespace mlir {
+/// This class provides support for defining a thread local object with
+/// non-static storage duration. This is very useful for situations in which a
+/// data cache has very high lock contention.
+template <typename ValueT>
+class ThreadLocalCache {
+  /// The type used for the static thread_local cache. This is a map from an
+  /// instance of the non-static cache to a weak reference to the ValueT that
+  /// it owns for this thread. A weak reference is used so that the object can
+  /// be destroyed without needing to lock access to the cache itself.
+  struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
+                                                std::weak_ptr<ValueT>> {
+    ~CacheType() {
+      // Have each owning cache drop the values that haven't already expired.
+      for (auto &it : *this)
+        if (std::shared_ptr<ValueT> value = it.second.lock())
+          it.first->remove(value.get());
+    }
+
+    /// Erase every entry whose weak reference has expired. This method is not
+    /// thread-safe, and should only be called by the same thread as the cache.
+    void clearExpiredEntries() {
+      for (auto it = this->begin(), e = this->end(); it != e;) {
+        auto curIt = it++;
+        if (curIt->second.expired())
+          this->erase(curIt);
+      }
+    }
+  };
+
+public:
+  ThreadLocalCache() = default;
+  ~ThreadLocalCache() {
+    // No cleanup is necessary here: the shared_ptrs owned by `instances` are
+    // destroyed, which expires the weak pointers in the thread_local caches.
+  }
+
+  /// Return the current thread's instance of the value type, creating it if
+  /// this thread has no live instance yet.
+  ValueT &get() {
+    // Check for an already existing instance for this thread.
+    CacheType &staticCache = getStaticCache();
+    std::weak_ptr<ValueT> &threadInstance = staticCache[this];
+    if (std::shared_ptr<ValueT> value = threadInstance.lock())
+      return *value;
+
+    // Otherwise, create a new instance owned by this object (under the lock).
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+    instances.push_back(std::make_shared<ValueT>());
+    std::shared_ptr<ValueT> &instance = instances.back();
+    threadInstance = instance;
+
+    // Before returning the new instance, take the chance to clear out any
+    // expired entries in the static map. The cache is only cleared within the
+    // same thread to remove the need to lock the cache itself.
+    staticCache.clearExpiredEntries();
+    return *instance;
+  }
+  ValueT &operator*() { return get(); }
+  ValueT *operator->() { return &get(); }
+
+private:
+  ThreadLocalCache(ThreadLocalCache &&) = delete;
+  ThreadLocalCache(const ThreadLocalCache &) = delete;
+  ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
+
+  /// Return the static thread local instance of the cache type.
+  static CacheType &getStaticCache() {
+    static LLVM_THREAD_LOCAL CacheType cache;
+    return cache;
+  }
+
+  /// Remove the given value from `instances`. This is called by the destructor
+  /// of a thread_local CacheType for each of its entries that is still alive.
+  void remove(ValueT *value) {
+    // Erase the found value directly, because it is guaranteed to be in the
+    // list.
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+    auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
+      return instance.get() == value;
+    });
+    assert(it != instances.end() && "expected value to exist in cache");
+    instances.erase(it);
+  }
+
+  /// Owning pointers to all of the values that have been constructed for this
+  /// object in the static cache.
+  SmallVector<std::shared_ptr<ValueT>, 1> instances;
+
+  /// A mutex guarding `instances`; it is held while a thread instance is added
+  /// to or removed from this object.
+  llvm::sys::SmartMutex<true> instanceMutex;
+};
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_THREADLOCALCACHE_H

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 7fffb51a1d1a..7551bb929970 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -291,8 +292,12 @@ class MLIRContextImpl {
   /// operations.
   llvm::StringMap<AbstractOperation> registeredOperations;
 
-  /// These are identifiers uniqued into this MLIRContext.
+  /// Identifiers are uniqued by string value and use the internal string set for
+  /// storage.
   llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
+  /// A thread local cache of identifiers to reduce lock contention.
+  ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
+      localIdentifierCache;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -703,16 +708,6 @@ const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
 
 /// Return an identifier for the specified string.
 Identifier Identifier::get(StringRef str, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  // Check for an existing identifier in read-only mode.
-  if (context->isMultithreadingEnabled()) {
-    llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
-    auto it = impl.identifiers.find(str);
-    if (it != impl.identifiers.end())
-      return Identifier(&*it);
-  }
-
   // Check invariants after seeing if we already have something in the
   // identifier table - if we already had it in the table, then it already
   // passed invariant checks.
@@ -720,10 +715,30 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
   assert(str.find('\0') == StringRef::npos &&
          "Cannot create an identifier with a nul character");
 
+  auto &impl = context->getImpl();
+  if (!context->isMultithreadingEnabled())
+    return Identifier(&*impl.identifiers.insert(str).first);
+
+  // Check for an existing instance in the local cache.
+  auto *&localEntry = (*impl.localIdentifierCache)[str];
+  if (localEntry)
+    return Identifier(localEntry);
+
+  // Check for an existing identifier in read-only mode.
+  {
+    llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
+    auto it = impl.identifiers.find(str);
+    if (it != impl.identifiers.end()) {
+      localEntry = &*it;
+      return Identifier(localEntry);
+    }
+  }
+
   // Acquire a writer-lock so that we can safely create the new instance.
-  ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled);
+  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
   auto it = impl.identifiers.insert(str).first;
-  return Identifier(&*it);
+  localEntry = &*it;
+  return Identifier(localEntry);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index a3e296e99e73..8e0ef6b8f276 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Support/StorageUniquer.h"
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/Support/RWMutex.h"
 
@@ -37,6 +38,8 @@ struct ParametricStorageUniquer {
   /// A utility wrapper object representing a hashed storage object. This class
   /// contains a storage object and an existing computed hash value.
   struct HashedStorage {
+    HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
+        : hashValue(hashValue), storage(storage) {}
     unsigned hashValue;
     BaseStorage *storage;
   };
@@ -44,10 +47,10 @@ struct ParametricStorageUniquer {
   /// Storage info for derived TypeStorage objects.
   struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
     static HashedStorage getEmptyKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+      return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
     }
     static HashedStorage getTombstoneKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+      return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
     }
 
     static unsigned getHashValue(const HashedStorage &key) {
@@ -70,6 +73,10 @@ struct ParametricStorageUniquer {
   using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
   StorageTypeSet instances;
 
+  /// A thread local cache for storage objects. This helps to reduce the lock
+  /// contention when an object already exists in the cache.
+  ThreadLocalCache<StorageTypeSet> localCache;
+
   /// Allocator to use when constructing derived instances.
   StorageAllocator allocator;
 
@@ -104,25 +111,31 @@ struct StorageUniquerImpl {
     if (!threadingIsEnabled)
       return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
 
+    // Check for an instance of this object in the local cache.
+    auto localIt = storageUniquer.localCache->insert_as({hashValue}, lookupKey);
+    BaseStorage *&localInst = localIt.first->storage;
+    if (localInst)
+      return localInst;
+
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
       auto it = storageUniquer.instances.find_as(lookupKey);
       if (it != storageUniquer.instances.end())
-        return it->storage;
+        return localInst = it->storage;
     }
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
+    return localInst = getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn);
   }
-  /// Get or create an instance of a complex derived type in an thread-unsafe
+  /// Get or create an instance of a param derived type in a thread-unsafe
   /// fashion.
   BaseStorage *
   getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer,
-                    ParametricStorageUniquer::LookupKey &lookupKey,
+                    ParametricStorageUniquer::LookupKey &key,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto existing = storageUniquer.instances.insert_as({}, lookupKey);
+    auto existing = storageUniquer.instances.insert_as({key.hashValue}, key);
     if (!existing.second)
       return existing.first->storage;
 
@@ -130,30 +143,10 @@ struct StorageUniquerImpl {
     // instance.
     BaseStorage *storage = ctorFn(storageUniquer.allocator);
     *existing.first =
-        ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage};
+        ParametricStorageUniquer::HashedStorage{key.hashValue, storage};
     return storage;
   }
 
-  /// Erase an instance of a parametric derived type.
-  void erase(TypeID id, unsigned hashValue,
-             function_ref<bool(const BaseStorage *)> isEqual,
-             function_ref<void(BaseStorage *)> cleanupFn) {
-    assert(parametricUniquers.count(id) &&
-           "erasing unregistered storage instance");
-    ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
-    ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
-
-    // Acquire a writer-lock so that we can safely erase the type instance.
-    llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
-    auto existing = storageUniquer.instances.find_as(lookupKey);
-    if (existing == storageUniquer.instances.end())
-      return;
-
-    // Cleanup the storage and remove it from the map.
-    cleanupFn(existing->storage);
-    storageUniquer.instances.erase(existing);
-  }
-
   /// Mutates an instance of a derived storage in a thread-safe way.
   LogicalResult
   mutate(TypeID id,
@@ -252,14 +245,6 @@ void StorageUniquer::registerSingletonImpl(
   impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
 }
 
-/// Implementation for erasing an instance of a derived type with parametric
-/// storage.
-void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue,
-                               function_ref<bool(const BaseStorage *)> isEqual,
-                               function_ref<void(BaseStorage *)> cleanupFn) {
-  impl->erase(id, hashValue, isEqual, cleanupFn);
-}
-
 /// Implementation for mutating an instance of a derived storage.
 LogicalResult StorageUniquer::mutateImpl(
     TypeID id, function_ref<LogicalResult(StorageAllocator &)> mutationFn) {


        


More information about the Mlir-commits mailing list