[Mlir-commits] [mlir] bcc1081 - Fix TSan problem where the per-thread shared_ptr can be locked right before the cache is destroyed, causing a race in which it tries to remove an entry from a destroyed cache.
Benjamin Kramer
llvmlistbot at llvm.org
Fri Jan 27 09:56:54 PST 2023
Author: Qiao Zhang
Date: 2023-01-27T18:49:32+01:00
New Revision: bcc10817d5569172ee065015747e226280e9b698
URL: https://github.com/llvm/llvm-project/commit/bcc10817d5569172ee065015747e226280e9b698
DIFF: https://github.com/llvm/llvm-project/commit/bcc10817d5569172ee065015747e226280e9b698.diff
LOG: Fix TSan problem where the per-thread shared_ptr can be locked right before the cache is destroyed, causing a race in which it tries to remove an entry from a destroyed cache.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D142394
Added:
Modified:
mlir/include/mlir/Support/ThreadLocalCache.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index e98fae6b117ae..1be94ca14bcfa 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -25,12 +25,40 @@ namespace mlir {
/// cache has very large lock contention.
template <typename ValueT>
class ThreadLocalCache {
+ // Keep a separate shared_ptr protected state that can be acquired atomically
+ // instead of using shared_ptr's for each value. This avoids a problem
+ // where the instance shared_ptr is locked() successfully, and then the
+ // ThreadLocalCache gets destroyed before remove() can be called successfully.
+ struct PerInstanceState {
+ /// Remove the given value entry. This is generally called when a thread
+ /// local cache is destructing.
+ void remove(ValueT *value) {
+ // Erase the found value directly, because it is guaranteed to be in the
+ // list.
+ llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
+ auto it =
+ llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
+ return instance.get() == value;
+ });
+ assert(it != instances.end() && "expected value to exist in cache");
+ instances.erase(it);
+ }
+
+ /// Owning pointers to all of the values that have been constructed for this
+ /// object in the static cache.
+ SmallVector<std::unique_ptr<ValueT>, 1> instances;
+
+ /// A mutex used when a new thread instance has been added to the cache for
+ /// this object.
+ llvm::sys::SmartMutex<true> instanceMutex;
+ };
+
/// The type used for the static thread_local cache. This is a map between an
/// instance of the non-static cache and a weak reference to an instance of
/// ValueT. We use a weak reference here so that the object can be destroyed
/// without needing to lock access to the cache itself.
- struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
- std::weak_ptr<ValueT>> {
+ struct CacheType
+ : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
~CacheType() {
// Remove the values of this cache that haven't already expired.
for (auto &it : *this)
@@ -60,15 +88,16 @@ class ThreadLocalCache {
ValueT &get() {
// Check for an already existing instance for this thread.
CacheType &staticCache = getStaticCache();
- std::weak_ptr<ValueT> &threadInstance = staticCache[this];
+ std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
if (std::shared_ptr<ValueT> value = threadInstance.lock())
return *value;
// Otherwise, create a new instance for this thread.
- llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
- instances.push_back(std::make_shared<ValueT>());
- std::shared_ptr<ValueT> &instance = instances.back();
- threadInstance = instance;
+ llvm::sys::SmartScopedLock<true> threadInstanceLock(
+ perInstanceState->instanceMutex);
+ perInstanceState->instances.push_back(std::make_unique<ValueT>());
+ ValueT *instance = perInstanceState->instances.back().get();
+ threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
// Before returning the new instance, take the chance to clear out any used
// entries in the static map. The cache is only cleared within the same
@@ -90,26 +119,8 @@ class ThreadLocalCache {
return cache;
}
- /// Remove the given value entry. This is generally called when a thread local
- /// cache is destructing.
- void remove(ValueT *value) {
- // Erase the found value directly, because it is guaranteed to be in the
- // list.
- llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
- auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
- return instance.get() == value;
- });
- assert(it != instances.end() && "expected value to exist in cache");
- instances.erase(it);
- }
-
- /// Owning pointers to all of the values that have been constructed for this
- /// object in the static cache.
- SmallVector<std::shared_ptr<ValueT>, 1> instances;
-
- /// A mutex used when a new thread instance has been added to the cache for
- /// this object.
- llvm::sys::SmartMutex<true> instanceMutex;
+ std::shared_ptr<PerInstanceState> perInstanceState =
+ std::make_shared<PerInstanceState>();
};
} // namespace mlir
More information about the Mlir-commits
mailing list