[Mlir-commits] [mlir] 9f24640 - [mlir] Add a utility class, ThreadLocalCache, for storing non static thread local objects.

River Riddle llvmlistbot at llvm.org
Fri Aug 7 13:43:44 PDT 2020


Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: 9f24640b7e6e61b0f293c724155a90a5e446dd7a

URL: https://github.com/llvm/llvm-project/commit/9f24640b7e6e61b0f293c724155a90a5e446dd7a
DIFF: https://github.com/llvm/llvm-project/commit/9f24640b7e6e61b0f293c724155a90a5e446dd7a.diff

LOG: [mlir] Add a utility class, ThreadLocalCache, for storing non static thread local objects.

This class allows for defining thread local objects that have a set non-static lifetime. This internals of the cache use a static thread_local map between the various different non-static objects and the desired value type. When a non-static object destructs, it simply nulls out the entry in the static map. This will leave an entry in the map, but erase any of the data for the associated value. The current use cases for this are in the MLIRContext, meaning that the number of items in the static map is ~1-2 which aren't particularly costly enough to warrant the complexity of pruning. If a use case arises that requires pruning of the map, the functionality can be added.

This is especially useful in the context of MLIR for implementing thread-local caching of context level objects that would otherwise have very high lock contention. This revision adds a thread local cache in the MLIRContext for attributes, identifiers, and types to reduce some of the locking burden. This led to a speedup of several hundred miliseconds when compiling a conversion pass on a very large mlir module(>300K operations).

Differential Revision: https://reviews.llvm.org/D82597

Added: 
    mlir/include/mlir/Support/ThreadLocalCache.h

Modified: 
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Support/StorageUniquer.cpp
    mlir/test/EDSC/builder-api-test.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
new file mode 100644
index 000000000000..3b5d6f0f424f
--- /dev/null
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -0,0 +1,117 @@
+//===- ThreadLocalCache.h - ThreadLocalCache class --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a definition of the ThreadLocalCache class. This class
+// provides support for defining thread local objects with non-static duration.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_THREADLOCALCACHE_H
+#define MLIR_SUPPORT_THREADLOCALCACHE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/ThreadLocal.h"
+
+namespace mlir {
+/// This class provides support for defining a thread local object with non
+/// static storage duration. This is very useful for situations in which a data
+/// cache has very large lock contention.
+template <typename ValueT>
+class ThreadLocalCache {
+  /// The type used for the static thread_local cache. This is a map between an
+  /// instance of the non-static cache and a weak reference to an instance of
+  /// ValueT. We use a weak reference here so that the object can be destroyed
+  /// without needing to lock access to the cache itself.
+  struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
+                                                std::weak_ptr<ValueT>> {
+    ~CacheType() {
+      // Remove the values of this cache that haven't already expired.
+      for (auto &it : *this)
+        if (std::shared_ptr<ValueT> value = it.second.lock())
+          it.first->remove(value.get());
+    }
+
+    /// Clear out any unused entries within the map. This method is not
+    /// thread-safe, and should only be called by the same thread as the cache.
+    void clearExpiredEntries() {
+      for (auto it = this->begin(), e = this->end(); it != e;) {
+        auto curIt = it++;
+        if (curIt->second.expired())
+          this->erase(curIt);
+      }
+    }
+  };
+
+public:
+  ThreadLocalCache() = default;
+  ~ThreadLocalCache() {
+    // No cleanup is necessary here as the shared_pointer memory will go out of
+    // scope and invalidate the weak pointers held by the thread_local caches.
+  }
+
+  /// Return an instance of the value type for the current thread.
+  ValueT &get() {
+    // Check for an already existing instance for this thread.
+    CacheType &staticCache = getStaticCache();
+    std::weak_ptr<ValueT> &threadInstance = staticCache[this];
+    if (std::shared_ptr<ValueT> value = threadInstance.lock())
+      return *value;
+
+    // Otherwise, create a new instance for this thread.
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+    instances.push_back(std::make_shared<ValueT>());
+    std::shared_ptr<ValueT> &instance = instances.back();
+    threadInstance = instance;
+
+    // Before returning the new instance, take the chance to clear out any used
+    // entries in the static map. The cache is only cleared within the same
+    // thread to remove the need to lock the cache itself.
+    staticCache.clearExpiredEntries();
+    return *instance;
+  }
+  ValueT &operator*() { return get(); }
+  ValueT *operator->() { return &get(); }
+
+private:
+  ThreadLocalCache(ThreadLocalCache &&) = delete;
+  ThreadLocalCache(const ThreadLocalCache &) = delete;
+  ThreadLocalCache &operator=(const ThreadLocalCache &) = delete;
+
+  /// Return the static thread local instance of the cache type.
+  static CacheType &getStaticCache() {
+    static LLVM_THREAD_LOCAL CacheType cache;
+    return cache;
+  }
+
+  /// Remove the given value entry. This is generally called when a thread local
+  /// cache is destructing.
+  void remove(ValueT *value) {
+    // Erase the found value directly, because it is guaranteed to be in the
+    // list.
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+    auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
+      return instance.get() == value;
+    });
+    assert(it != instances.end() && "expected value to exist in cache");
+    instances.erase(it);
+  }
+
+  /// Owning pointers to all of the values that have been constructed for this
+  /// object in the static cache.
+  SmallVector<std::shared_ptr<ValueT>, 1> instances;
+
+  /// A mutex used when a new thread instance has been added to the cache for
+  /// this object.
+  llvm::sys::SmartMutex<true> instanceMutex;
+};
+} // end namespace mlir
+
+#endif // MLIR_SUPPORT_THREADLOCALCACHE_H

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index df58b957bc32..42c4d4855e50 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -280,8 +281,12 @@ class MLIRContextImpl {
   /// operations.
   llvm::StringMap<AbstractOperation> registeredOperations;
 
-  /// These are identifiers uniqued into this MLIRContext.
+  /// Identifers are uniqued by string value and use the internal string set for
+  /// storage.
   llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
+  /// A thread local cache of identifiers to reduce lock contention.
+  ThreadLocalCache<llvm::StringMap<llvm::StringMapEntry<llvm::NoneType> *>>
+      localIdentifierCache;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -629,27 +634,37 @@ const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
 
 /// Return an identifier for the specified string.
 Identifier Identifier::get(StringRef str, MLIRContext *context) {
+  // Check invariants after seeing if we already have something in the
+  // identifier table - if we already had it in the table, then it already
+  // passed invariant checks.
+  assert(!str.empty() && "Cannot create an empty identifier");
+  assert(str.find('\0') == StringRef::npos &&
+         "Cannot create an identifier with a nul character");
+
   auto &impl = context->getImpl();
+  if (!context->isMultithreadingEnabled())
+    return Identifier(&*impl.identifiers.insert(str).first);
+
+  // Check for an existing instance in the local cache.
+  auto *&localEntry = (*impl.localIdentifierCache)[str];
+  if (localEntry)
+    return Identifier(localEntry);
 
   // Check for an existing identifier in read-only mode.
   if (context->isMultithreadingEnabled()) {
     llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
     auto it = impl.identifiers.find(str);
-    if (it != impl.identifiers.end())
-      return Identifier(&*it);
+    if (it != impl.identifiers.end()) {
+      localEntry = &*it;
+      return Identifier(localEntry);
+    }
   }
 
-  // Check invariants after seeing if we already have something in the
-  // identifier table - if we already had it in the table, then it already
-  // passed invariant checks.
-  assert(!str.empty() && "Cannot create an empty identifier");
-  assert(str.find('\0') == StringRef::npos &&
-         "Cannot create an identifier with a nul character");
-
   // Acquire a writer-lock so that we can safely create the new instance.
-  ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled);
+  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
   auto it = impl.identifiers.insert(str).first;
-  return Identifier(&*it);
+  localEntry = &*it;
+  return Identifier(localEntry);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp
index 49e7272091fb..3fb6d3733b23 100644
--- a/mlir/lib/Support/StorageUniquer.cpp
+++ b/mlir/lib/Support/StorageUniquer.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Support/StorageUniquer.h"
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "mlir/Support/TypeID.h"
 #include "llvm/Support/RWMutex.h"
 
@@ -39,6 +40,8 @@ struct InstSpecificUniquer {
   /// A utility wrapper object representing a hashed storage object. This class
   /// contains a storage object and an existing computed hash value.
   struct HashedStorage {
+    HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
+        : hashValue(hashValue), storage(storage) {}
     unsigned hashValue;
     BaseStorage *storage;
   };
@@ -46,10 +49,10 @@ struct InstSpecificUniquer {
   /// Storage info for derived TypeStorage objects.
   struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
     static HashedStorage getEmptyKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getEmptyKey()};
+      return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
     }
     static HashedStorage getTombstoneKey() {
-      return HashedStorage{0, DenseMapInfo<BaseStorage *>::getTombstoneKey()};
+      return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
     }
 
     static unsigned getHashValue(const HashedStorage &key) {
@@ -102,25 +105,34 @@ struct StorageUniquerImpl {
     if (!threadingIsEnabled)
       return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
 
+    // Check for a instance of this object in the local cache.
+    auto localIt = complexStorageLocalCache->insert_as(
+        InstSpecificUniquer::HashedStorage(lookupKey.hashValue), lookupKey);
+    BaseStorage *&localInst = localIt.first->storage;
+    if (localInst)
+      return localInst;
+
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
       auto it = storageUniquer.complexInstances.find_as(lookupKey);
       if (it != storageUniquer.complexInstances.end())
-        return it->storage;
+        return localInst = it->storage;
     }
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
+    return localInst =
+               getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
   }
   /// Get or create an instance of a complex derived type in an thread-unsafe
   /// fashion.
   BaseStorage *
   getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
-                    InstSpecificUniquer::LookupKey &lookupKey,
+                    InstSpecificUniquer::LookupKey &key,
                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
-    auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey);
+    auto existing =
+        storageUniquer.complexInstances.insert_as({key.hashValue}, key);
     if (!existing.second)
       return existing.first->storage;
 
@@ -128,9 +140,7 @@ struct StorageUniquerImpl {
     // instance.
     BaseStorage *storage =
         initializeStorage(kind, storageUniquer.allocator, ctorFn);
-    *existing.first =
-        InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage};
-    return storage;
+    return existing.first->storage = storage;
   }
 
   /// Get or create an instance of a simple derived type.
@@ -142,6 +152,11 @@ struct StorageUniquerImpl {
     if (!threadingIsEnabled)
       return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
 
+    // Check for a instance of this object in the local cache.
+    BaseStorage *&localInst = (*simpleStorageLocalCache)[kind];
+    if (localInst)
+      return localInst;
+
     // Check for an existing instance in read-only mode.
     {
       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
@@ -152,7 +167,7 @@ struct StorageUniquerImpl {
 
     // Acquire a writer-lock so that we can safely create the new type instance.
     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
-    return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
+    return localInst = getOrCreateUnsafe(storageUniquer, kind, ctorFn);
   }
   /// Get or create an instance of a simple derived type in an thread-unsafe
   /// fashion.
@@ -215,6 +230,12 @@ struct StorageUniquerImpl {
   /// Map of type ids to the storage uniquer to use for registered objects.
   DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
 
+  /// A thread local cache for simple and complex storage objects. This helps to
+  /// reduce the lock contention when an object already existing in the cache.
+  ThreadLocalCache<DenseMap<unsigned, BaseStorage *>> simpleStorageLocalCache;
+  ThreadLocalCache<InstSpecificUniquer::StorageTypeSet>
+      complexStorageLocalCache;
+
   /// Flag specifying if multi-threading is enabled within the uniquer.
   bool threadingIsEnabled = true;
 };

diff  --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 3fcfcf24ef8f..b620062e2238 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-// RUN: mlir-edsc-builder-api-test | FileCheck %s
+// RUN: mlir-edsc-builder-api-test
 
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"


        


More information about the Mlir-commits mailing list