[Mlir-commits] [mlir] ce2b488 - [mlir] ThreadLocalCache: make TSAN happy about destructors (#106170)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 26 19:59:53 PDT 2024
Author: Jeff Niu
Date: 2024-08-26T19:59:49-07:00
New Revision: ce2b488e90d6d5a5c8fe495ede8238938827da39
URL: https://github.com/llvm/llvm-project/commit/ce2b488e90d6d5a5c8fe495ede8238938827da39
DIFF: https://github.com/llvm/llvm-project/commit/ce2b488e90d6d5a5c8fe495ede8238938827da39.diff
LOG: [mlir] ThreadLocalCache: make TSAN happy about destructors (#106170)
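TSAN warns that `ptr` is read and written without synchronization in
`clearExpiredEntries` and in the destructor of `Owner`. Add an atomic
bool to synchronize these accesses without incurring a cost when calling
`get`: `clearExpiredEntries` now checks only the atomic flag, while `get`
still reads the plain pointer.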
Added:
Modified:
mlir/include/mlir/Support/ThreadLocalCache.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index 87cc52cc56ac4f..53b6d31a09555b 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -27,6 +27,8 @@ template <typename ValueT>
class ThreadLocalCache {
struct PerInstanceState;
+ using PointerAndFlag = std::pair<ValueT *, std::atomic<bool>>;
+
/// The "observer" is owned by a thread-local cache instance. It is
/// constructed the first time a `ThreadLocalCache` instance is accessed by a
/// thread, unless `perInstanceState` happens to get re-allocated to the same
@@ -41,7 +43,8 @@ class ThreadLocalCache {
/// This is the double pointer, explicitly allocated because we need to keep
/// the address stable if the TLC map re-allocates. It is owned by the
/// observer and shared with the value owner.
- std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
+ std::shared_ptr<PointerAndFlag> ptr =
+ std::make_shared<PointerAndFlag>(std::make_pair(nullptr, false));
/// Because the `Owner` instance that lives inside `PerInstanceState`
/// contains a reference to the double pointer, and likewise this class
/// contains a reference to the value, we need to synchronize destruction of
@@ -62,18 +65,21 @@ class ThreadLocalCache {
/// Save a pointer to the reference and write it to the newly created entry.
Owner(Observer &observer)
: value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
- *observer.ptr = value.get();
+ observer.ptr->second = true;
+ observer.ptr->first = value.get();
}
~Owner() {
- if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
- *ptr = nullptr;
+ if (std::shared_ptr<PointerAndFlag> ptr = ptrRef.lock()) {
+ ptr->first = nullptr;
+ ptr->second = false;
+ }
}
Owner(Owner &&) = default;
Owner &operator=(Owner &&) = default;
std::unique_ptr<ValueT> value;
- std::weak_ptr<ValueT *> ptrRef;
+ std::weak_ptr<PointerAndFlag> ptrRef;
};
// Keep a separate shared_ptr protected state that can be acquired atomically
@@ -116,7 +122,7 @@ class ThreadLocalCache {
// back to the data here that is being destroyed.
for (auto &[instance, observer] : *this)
if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
- state->remove(*observer.ptr);
+ state->remove(observer.ptr->first);
}
/// Clear out any unused entries within the map. This method is not
@@ -124,7 +130,7 @@ class ThreadLocalCache {
void clearExpiredEntries() {
for (auto it = this->begin(), e = this->end(); it != e;) {
auto curIt = it++;
- if (!*curIt->second.ptr)
+ if (!curIt->second.ptr->second)
this->erase(curIt);
}
}
@@ -142,7 +148,7 @@ class ThreadLocalCache {
// Check for an already existing instance for this thread.
CacheType &staticCache = getStaticCache();
Observer &threadInstance = staticCache[perInstanceState.get()];
- if (ValueT *value = *threadInstance.ptr)
+ if (ValueT *value = threadInstance.ptr->first)
return *value;
// Otherwise, create a new instance for this thread.
@@ -157,7 +163,7 @@ class ThreadLocalCache {
// entries in the static map. The cache is only cleared within the same
// thread to remove the need to lock the cache itself.
staticCache.clearExpiredEntries();
- return **threadInstance.ptr;
+ return *threadInstance.ptr->first;
}
ValueT &operator*() { return get(); }
ValueT *operator->() { return &get(); }
More information about the Mlir-commits
mailing list