[Mlir-commits] [mlir] bcc1081 - Fix tsan problem where the per-thread shared_ptr() can be locked right before the cache is destroyed causing a race where it tries to remove an entry from a destroyed cache.

Benjamin Kramer llvmlistbot at llvm.org
Fri Jan 27 09:56:54 PST 2023


Author: Qiao Zhang
Date: 2023-01-27T18:49:32+01:00
New Revision: bcc10817d5569172ee065015747e226280e9b698

URL: https://github.com/llvm/llvm-project/commit/bcc10817d5569172ee065015747e226280e9b698
DIFF: https://github.com/llvm/llvm-project/commit/bcc10817d5569172ee065015747e226280e9b698.diff

LOG: Fix a TSan-reported race in which the per-thread shared_ptr can be locked right before the cache is destroyed, so that the thread then tries to remove an entry from the already-destroyed cache.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D142394

Added: 
    

Modified: 
    mlir/include/mlir/Support/ThreadLocalCache.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index e98fae6b117ae..1be94ca14bcfa 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -25,12 +25,40 @@ namespace mlir {
 /// cache has very large lock contention.
 template <typename ValueT>
 class ThreadLocalCache {
+  // Keep the per-instance state behind a single shared_ptr that can be acquired
+  // atomically, instead of using a separate shared_ptr for each value. This
+  // avoids a race where a value's weak_ptr is lock()ed successfully, and then
+  // the ThreadLocalCache gets destroyed before remove() can be called.
+  struct PerInstanceState {
+    /// Remove the given value entry. This is generally called when a thread
+    /// local cache is destructing.
+    void remove(ValueT *value) {
+      // Erase the found value directly, because it is guaranteed to be in the
+      // list.
+      llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+      auto it =
+          llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
+            return instance.get() == value;
+          });
+      assert(it != instances.end() && "expected value to exist in cache");
+      instances.erase(it);
+    }
+
+    /// Owning pointers to all of the values that have been constructed for this
+    /// object in the static cache.
+    SmallVector<std::unique_ptr<ValueT>, 1> instances;
+
+    /// A mutex used when a new thread instance has been added to the cache for
+    /// this object.
+    llvm::sys::SmartMutex<true> instanceMutex;
+  };
+
   /// The type used for the static thread_local cache. This is a map between an
   /// instance of the non-static cache and a weak reference to an instance of
   /// ValueT. We use a weak reference here so that the object can be destroyed
   /// without needing to lock access to the cache itself.
-  struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
-                                                std::weak_ptr<ValueT>> {
+  struct CacheType
+      : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
     ~CacheType() {
       // Remove the values of this cache that haven't already expired.
       for (auto &it : *this)
@@ -60,15 +88,16 @@ class ThreadLocalCache {
   ValueT &get() {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
-    std::weak_ptr<ValueT> &threadInstance = staticCache[this];
+    std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
     if (std::shared_ptr<ValueT> value = threadInstance.lock())
       return *value;
 
     // Otherwise, create a new instance for this thread.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
-    instances.push_back(std::make_shared<ValueT>());
-    std::shared_ptr<ValueT> &instance = instances.back();
-    threadInstance = instance;
+    llvm::sys::SmartScopedLock<true> threadInstanceLock(
+        perInstanceState->instanceMutex);
+    perInstanceState->instances.push_back(std::make_unique<ValueT>());
+    ValueT *instance = perInstanceState->instances.back().get();
+    threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
 
     // Before returning the new instance, take the chance to clear out any used
     // entries in the static map. The cache is only cleared within the same
@@ -90,26 +119,8 @@ class ThreadLocalCache {
     return cache;
   }
 
-  /// Remove the given value entry. This is generally called when a thread local
-  /// cache is destructing.
-  void remove(ValueT *value) {
-    // Erase the found value directly, because it is guaranteed to be in the
-    // list.
-    llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
-    auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
-      return instance.get() == value;
-    });
-    assert(it != instances.end() && "expected value to exist in cache");
-    instances.erase(it);
-  }
-
-  /// Owning pointers to all of the values that have been constructed for this
-  /// object in the static cache.
-  SmallVector<std::shared_ptr<ValueT>, 1> instances;
-
-  /// A mutex used when a new thread instance has been added to the cache for
-  /// this object.
-  llvm::sys::SmartMutex<true> instanceMutex;
+  std::shared_ptr<PerInstanceState> perInstanceState =
+      std::make_shared<PerInstanceState>();
 };
 } // namespace mlir
 


        


More information about the Mlir-commits mailing list