[Mlir-commits] [mlir] [mlir] ThreadLocalCache: make TSAN happy about destructors (PR #106170)

Jeff Niu llvmlistbot at llvm.org
Mon Aug 26 19:07:31 PDT 2024


https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/106170

TSAN warns that the value pointed to by `ptr` is read and written without synchronization in `clearExpiredEntries` and in the destructor of `Owner`. Add an atomic bool flag to synchronize these accesses without adding an atomic operation to the fast path of `get`.

>From d0482487d2d74bbc0d11386ec5f3d00a63ea7ff4 Mon Sep 17 00:00:00 2001
From: Mogball <jeff at modular.com>
Date: Mon, 26 Aug 2024 22:05:18 -0400
Subject: [PATCH] [mlir] ThreadLocalCache: make TSAN happy about destructors

TSAN warns that `ptr` is read and write without protection in
`clearExpiredEntries` and in the destructor of `Owner`. Add an atomic
bool to synchronize these without incurring a cost when calling `get`.
---
 mlir/include/mlir/Support/ThreadLocalCache.h | 24 ++++++++++++--------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index 87cc52cc56ac4f..53b6d31a09555b 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -27,6 +27,8 @@ template <typename ValueT>
 class ThreadLocalCache {
   struct PerInstanceState;
 
+  using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>;
+
   /// The "observer" is owned by a thread-local cache instance. It is
   /// constructed the first time a `ThreadLocalCache` instance is accessed by a
   /// thread, unless `perInstanceState` happens to get re-allocated to the same
@@ -41,7 +43,8 @@ class ThreadLocalCache {
     /// This is the double pointer, explicitly allocated because we need to keep
     /// the address stable if the TLC map re-allocates. It is owned by the
     /// observer and shared with the value owner.
-    std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
+    std::shared_ptr<PointerAndFlag> ptr =
+        std::make_shared<PointerAndFlag>(std::make_pair(nullptr, false));
     /// Because the `Owner` instance that lives inside `PerInstanceState`
     /// contains a reference to the double pointer, and likewise this class
     /// contains a reference to the value, we need to synchronize destruction of
@@ -62,18 +65,21 @@ class ThreadLocalCache {
     /// Save a pointer to the reference and write it to the newly created entry.
     Owner(Observer &observer)
         : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
-      *observer.ptr = value.get();
+      observer.ptr->second = true;
+      observer.ptr->first = value.get();
     }
     ~Owner() {
-      if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
-        *ptr = nullptr;
+      if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) {
+        ptr->first = nullptr;
+        ptr->second = false;
+      }
     }
 
     Owner(Owner &&) = default;
     Owner &operator=(Owner &&) = default;
 
     std::unique_ptr<ValueT> value;
-    std::weak_ptr<ValueT *> ptrRef;
+    std::weak_ptr<PointerAndFlag> ptrRef;
   };
 
   // Keep a separate shared_ptr protected state that can be acquired atomically
@@ -116,7 +122,7 @@ class ThreadLocalCache {
       // back to the data here that is being destroyed.
       for (auto &[instance, observer] : *this)
         if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
-          state->remove(*observer.ptr);
+          state->remove(observer.ptr->first);
     }
 
     /// Clear out any unused entries within the map. This method is not
@@ -124,7 +130,7 @@ class ThreadLocalCache {
     void clearExpiredEntries() {
       for (auto it = this->begin(), e = this->end(); it != e;) {
         auto curIt = it++;
-        if (!*curIt->second.ptr)
+        if (!curIt->second.ptr->second)
           this->erase(curIt);
       }
     }
@@ -142,7 +148,7 @@ class ThreadLocalCache {
     // Check for an already existing instance for this thread.
     CacheType &staticCache = getStaticCache();
     Observer &threadInstance = staticCache[perInstanceState.get()];
-    if (ValueT *value = *threadInstance.ptr)
+    if (ValueT *value = threadInstance.ptr->first)
       return *value;
 
     // Otherwise, create a new instance for this thread.
@@ -157,7 +163,7 @@ class ThreadLocalCache {
     // entries in the static map. The cache is only cleared within the same
     // thread to remove the need to lock the cache itself.
     staticCache.clearExpiredEntries();
-    return **threadInstance.ptr;
+    return *threadInstance.ptr->first;
   }
   ValueT &operator*() { return get(); }
   ValueT *operator->() { return &get(); }



More information about the Mlir-commits mailing list