[Mlir-commits] [mlir] Revert "[mlir] Optimize ThreadLocalCache by removing atomic bottleneck" (PR #93306)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 24 07:47:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

Reverts llvm/llvm-project#93270

This change was found to introduce a race, and the attempted forward fix was itself reverted; this change is therefore being reverted until a working forward fix is available.

---
Full diff: https://github.com/llvm/llvm-project/pull/93306.diff


1 Files Affected:

- (modified) mlir/include/mlir/Support/ThreadLocalCache.h (+11-17) 


``````````diff
diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index d19257bf6e25e..1be94ca14bcfa 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -58,12 +58,11 @@ class ThreadLocalCache {
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
   struct CacheType
-      : public llvm::SmallDenseMap<PerInstanceState *,
-                                   std::pair<std::weak_ptr<ValueT>, ValueT *>> {
+      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
-        if (std::shared_ptr<ValueT> value = it.second.first.lock())
+        if (std::shared_ptr<ValueT> value = it.second.lock())
           it.first->remove(value.get());
     }
 
@@ -72,7 +71,7 @@ class ThreadLocalCache {
     void clearExpiredEntries() {
       for (auto it = this->begin(), e = this->end(); it != e;) {
         auto curIt = it++;
-        if (curIt->second.first.expired())
+        if (curIt->second.expired())
           this->erase(curIt);
       }
     }
@@ -89,27 +88,22 @@ class ThreadLocalCache {
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
-        staticCache[perInstanceState.get()];
-    if (ValueT *value = threadInstance.second)
+    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
+    if (std::shared_ptr<ValueT> value = threadInstance.lock())
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    {
-      llvm::sys::SmartScopedLock<true> threadInstanceLock(
-          perInstanceState->instanceMutex);
-      threadInstance.second =
-          perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
-              .get();
-    }
-    threadInstance.first =
-        std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(
+        perInstanceState->instanceMutex);
+    perInstanceState->instances.push_back(std::make_unique<ValueT>());
+    ValueT *instance = perInstanceState->instances.back().get();
+    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
     // thread to remove the need to lock the cache itself.
     staticCache.clearExpiredEntries();
-    return *threadInstance.second;
+    return *instance;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }

``````````

</details>


https://github.com/llvm/llvm-project/pull/93306


More information about the Mlir-commits mailing list