[Mlir-commits] [mlir] [mlir] Optimize ThreadLocalCache by removing atomic bottleneck (PR #93270)

Jeff Niu llvmlistbot at llvm.org
Thu May 23 20:37:04 PDT 2024


https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/93270

The ThreadLocalCache implementation is used by the MLIRContext (among other things) to manage thread contention in the StorageUniquers. A fancy arrangement of shared and weak pointers keeps everything alive across threads at the right times, but a huge bottleneck is the `weak_ptr::lock` call inside the `::get` method.

This is because the `lock` method has to hit the atomic refcount several times, which bottlenecks performance across many threads. Yet all the call is doing is checking whether the storage is initialized. The weak pointer cannot be expired, because the thread-local cache object we're calling into owns the memory and must still be alive for the method call to be valid. Thus, we can store an extra `ValueT *` inside the thread-local cache for speedy retrieval when the cache is already initialized for the thread, which is the common case.
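The idea can be sketched in isolation. This is a minimal, self-contained illustration of the fast-path pattern, not the actual MLIR code: `Value`, `Entry`, and `get` are hypothetical names, and the real class also handles per-instance locking and expired-entry cleanup as shown in the diff below.

```cpp
#include <memory>

struct Value {
  int data = 0;
};

// Hypothetical cache entry: the weak_ptr handles cross-thread lifetime
// bookkeeping, while the raw pointer serves as a non-atomic fast path.
struct Entry {
  std::weak_ptr<Value> ref;
  Value *fast = nullptr;
};

Value &get(Entry &entry, std::shared_ptr<Value> &storage) {
  // Fast path: a plain pointer load, no atomic refcount traffic. This is
  // valid because the owner of `entry` outlives the call, so a non-null
  // pointer can never be dangling here.
  if (Value *v = entry.fast)
    return *v;

  // Slow path (first call on this thread): create the storage, then
  // publish both the owning reference and the fast-path pointer.
  storage = std::make_shared<Value>();
  entry.fast = storage.get();
  entry.ref = storage;
  return *entry.fast;
}
```

Only the first call per entry pays for shared-pointer construction; every subsequent `get` is a single pointer check.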

Before:

<img width="560" alt="image" src="https://github.com/llvm/llvm-project/assets/15016832/f4ea3f32-6649-4c10-88c4-b7522031e8c9">

After:

<img width="344" alt="image" src="https://github.com/llvm/llvm-project/assets/15016832/1216db25-3dc1-4b0f-be89-caeff622dd35">


From 69905172f7c37f3ce70098b136e4437841f75a44 Mon Sep 17 00:00:00 2001
From: Mogball <jeff at modular.com>
Date: Thu, 23 May 2024 20:32:13 -0700
Subject: [PATCH] [mlir] Optimize ThreadLocalCache by removing atomic
 bottleneck

The ThreadLocalCache implementation is used by the MLIRContext (among
other things) to try to manage thread contention in the StorageUniquers.
There is a bunch of fancy shared pointer/weak pointer setup that
basically keeps everything alive across threads at the right time, but a
huge bottleneck is the `weak_ptr::lock` call inside the `::get` method.

This is because the `lock` method has to hit the atomic refcount several
times, and this is bottlenecking performance across many threads.
However, all this is doing is checking whether the storage is
initialized. We know that it cannot be an expired weak pointer because
the thread local cache object we're calling into owns the memory and is
still alive for the method call to be valid. Thus, we can store an
extra `ValueT *` inside the thread local cache for speedy retrieval if
the cache is already initialized for the thread, which is the common
case.
---
 mlir/include/mlir/Support/ThreadLocalCache.h | 28 ++++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index 1be94ca14bcfa..d19257bf6e25e 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -58,11 +58,12 @@ class ThreadLocalCache {
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
   struct CacheType
-      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
+      : public llvm::SmallDenseMap<PerInstanceState *,
+                                   std::pair<std::weak_ptr<ValueT>, ValueT *>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
-        if (std::shared_ptr<ValueT> value = it.second.lock())
+        if (std::shared_ptr<ValueT> value = it.second.first.lock())
           it.first->remove(value.get());
     }
 
@@ -71,7 +72,7 @@ class ThreadLocalCache {
     void clearExpiredEntries() {
       for (auto it = this->begin(), e = this->end(); it != e;) {
         auto curIt = it++;
-        if (curIt->second.expired())
+        if (curIt->second.first.expired())
           this->erase(curIt);
       }
     }
@@ -88,22 +89,27 @@ class ThreadLocalCache {
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
-    if (std::shared_ptr<ValueT> value = threadInstance.lock())
+    std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
+        staticCache[perInstanceState.get()];
+    if (ValueT *value = threadInstance.second)
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(
-        perInstanceState->instanceMutex);
-    perInstanceState->instances.push_back(std::make_unique<ValueT>());
-    ValueT *instance = perInstanceState->instances.back().get();
-    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
+    {
+      llvm::sys::SmartScopedLock<true> threadInstanceLock(
+          perInstanceState->instanceMutex);
+      threadInstance.second =
+          perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
+              .get();
+    }
+    threadInstance.first =
+        std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
     // thread to remove the need to lock the cache itself.
     staticCache.clearExpiredEntries();
-    return *instance;
+    return *threadInstance.second;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }


