[Mlir-commits] [mlir] Revert "[mlir] Optimize ThreadLocalCache by removing atomic bottleneck" (PR #93306)

Jacques Pienaar llvmlistbot at llvm.org
Fri May 24 07:46:34 PDT 2024


https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/93306

Reverts llvm/llvm-project#93270

This change was found to have a race, and the forward fix for it was reverted; reverting this change as well until a correct forward fix can be landed.

>From 50cb2160464ef916146c199ab8c9fe3ea9f054d6 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jacques+gh at japienaar.info>
Date: Fri, 24 May 2024 07:45:37 -0700
Subject: [PATCH] Revert "[mlir] Optimize ThreadLocalCache by removing atomic
 bottleneck (#93270)"

This reverts commit 1b803fe53dda11ca63f008f5ab8e206f3954ce48.
---
 mlir/include/mlir/Support/ThreadLocalCache.h | 28 ++++++++------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index d19257bf6e25e..1be94ca14bcfa 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -58,12 +58,11 @@ class ThreadLocalCache {
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
   struct CacheType
-      : public llvm::SmallDenseMap<PerInstanceState *,
-                                   std::pair<std::weak_ptr<ValueT>, ValueT *>> {
+      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
-        if (std::shared_ptr<ValueT> value = it.second.first.lock())
+        if (std::shared_ptr<ValueT> value = it.second.lock())
           it.first->remove(value.get());
     }
 
@@ -72,7 +71,7 @@ class ThreadLocalCache {
     void clearExpiredEntries() {
       for (auto it = this->begin(), e = this->end(); it != e;) {
         auto curIt = it++;
-        if (curIt->second.first.expired())
+        if (curIt->second.expired())
           this->erase(curIt);
       }
     }
@@ -89,27 +88,22 @@ class ThreadLocalCache {
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
-        staticCache[perInstanceState.get()];
-    if (ValueT *value = threadInstance.second)
+    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
+    if (std::shared_ptr<ValueT> value = threadInstance.lock())
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    {
-      llvm::sys::SmartScopedLock<true> threadInstanceLock(
-          perInstanceState->instanceMutex);
-      threadInstance.second =
-          perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
-              .get();
-    }
-    threadInstance.first =
-        std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(
+        perInstanceState->instanceMutex);
+    perInstanceState->instances.push_back(std::make_unique<ValueT>());
+    ValueT *instance = perInstanceState->instances.back().get();
+    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
     // thread to remove the need to lock the cache itself.
     staticCache.clearExpiredEntries();
-    return *threadInstance.second;
+    return *instance;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }



More information about the Mlir-commits mailing list