[Mlir-commits] [mlir] Revert "[mlir] Fix race condition introduced in ThreadLocalCache (#93280)" (PR #93290)
Kiran Chandramohan
llvmlistbot at llvm.org
Fri May 24 03:20:57 PDT 2024
https://github.com/kiranchandramohan created https://github.com/llvm/llvm-project/pull/93290
Revert "[mlir] Fix race condition introduced in ThreadLocalCache (#93280)"
This reverts commit 6977bfb57c3efb9488aef463cd7ea521fd25a067.
>From 2790b4d9e63c13d1e692cc301bbd373b10f28070 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiranchandramohan at gmail.com>
Date: Fri, 24 May 2024 11:18:55 +0100
Subject: [PATCH] Revert "[mlir] Fix race condition introduced in
ThreadLocalCache (#93280)"
This reverts commit 6977bfb57c3efb9488aef463cd7ea521fd25a067.
---
mlir/include/mlir/Support/ThreadLocalCache.h | 97 +++++---------------
1 file changed, 25 insertions(+), 72 deletions(-)
diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h
index fe6c6fa3cf6bd..d19257bf6e25e 100644
--- a/mlir/include/mlir/Support/ThreadLocalCache.h
+++ b/mlir/include/mlir/Support/ThreadLocalCache.h
@@ -16,6 +16,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Mutex.h"
namespace mlir {
@@ -24,80 +25,28 @@ namespace mlir {
/// cache has very large lock contention.
template <typename ValueT>
class ThreadLocalCache {
- struct PerInstanceState;
-
- /// The "observer" is owned by a thread-local cache instance. It is
- /// constructed the first time a `ThreadLocalCache` instance is accessed by a
- /// thread, unless `perInstanceState` happens to get re-allocated to the same
- /// address as a previous one. This class is destructed the thread in which
- /// the `thread_local` cache lives is destroyed.
- ///
- /// This class is called the "observer" because while values cached in
- /// thread-local caches are owned by `PerInstanceState`, a reference is stored
- /// via this class in the TLC. With a double pointer, it knows when the
- /// referenced value has been destroyed.
- struct Observer {
- /// This is the double pointer, explicitly allocated because we need to keep
- /// the address stable if the TLC map re-allocates. It is owned by the
- /// observer and shared with the value owner.
- std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
- /// Because `Owner` living inside `PerInstanceState` contains a reference to
- /// the double pointer, and livkewise this class contains a reference to the
- /// value, we need to synchronize destruction of the TLC and the
- /// `PerInstanceState` to avoid racing. This weak pointer is acquired during
- /// TLC destruction if the `PerInstanceState` hasn't entered its destructor
- /// yet, and prevents it from happening.
- std::weak_ptr<PerInstanceState> keepalive;
- };
-
- /// This struct owns the cache entries. It contains a reference back to the
- /// reference inside the cache so that it can be written to null to indicate
- /// that the cache entry is invalidated. It needs to do this because
- /// `perInstanceState` could get re-allocated to the same pointer and we don't
- /// remove entries from the TLC when it is deallocated. Thus, we have to reset
- /// the TLC entries to a starting state in case the `ThreadLocalCache` lives
- /// shorter than the threads.
- struct Owner {
- /// Save a pointer to the reference and write it to the newly created entry.
- Owner(Observer &observer)
- : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
- *observer.ptr = value.get();
- }
- ~Owner() {
- if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
- *ptr = nullptr;
- }
-
- Owner(Owner &&) = default;
- Owner &operator=(Owner &&) = default;
-
- std::unique_ptr<ValueT> value;
- std::weak_ptr<ValueT *> ptrRef;
- };
-
// Keep a separate shared_ptr protected state that can be acquired atomically
// instead of using shared_ptr's for each value. This avoids a problem
// where the instance shared_ptr is locked() successfully, and then the
// ThreadLocalCache gets destroyed before remove() can be called successfully.
struct PerInstanceState {
- /// Remove the given value entry. This is called when a thread local cache
- /// is destructing but still contains references to values owned by the
- /// `PerInstanceState`. Removal is required because it prevents writeback to
- /// a pointer that was deallocated.
+ /// Remove the given value entry. This is generally called when a thread
+ /// local cache is destructing.
void remove(ValueT *value) {
// Erase the found value directly, because it is guaranteed to be in the
// list.
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
- auto it = llvm::find_if(instances, [&](Owner &instance) {
- return instance.value.get() == value;
- });
+ auto it =
+ llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
+ return instance.get() == value;
+ });
assert(it != instances.end() && "expected value to exist in cache");
instances.erase(it);
}
/// Owning pointers to all of the values that have been constructed for this
/// object in the static cache.
- SmallVector<Owner, 1> instances;
+ SmallVector<std::unique_ptr<ValueT>, 1> instances;
/// A mutex used when a new thread instance has been added to the cache for
/// this object.
@@ -108,14 +57,14 @@ class ThreadLocalCache {
/// instance of the non-static cache and a weak reference to an instance of
/// ValueT. We use a weak reference here so that the object can be destroyed
/// without needing to lock access to the cache itself.
- struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
+ struct CacheType
+ : public llvm::SmallDenseMap<PerInstanceState *,
+ std::pair<std::weak_ptr<ValueT>, ValueT *>> {
~CacheType() {
- // Remove the values of this cache that haven't already expired. This is
- // required because if we don't remove them, they will contain a reference
- // back to the data here that is being destroyed.
- for (auto &[instance, observer] : *this)
- if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
- state->remove(*observer.ptr);
+ // Remove the values of this cache that haven't already expired.
+ for (auto &it : *this)
+ if (std::shared_ptr<ValueT> value = it.second.first.lock())
+ it.first->remove(value.get());
}
/// Clear out any unused entries within the map. This method is not
@@ -123,7 +72,7 @@ class ThreadLocalCache {
void clearExpiredEntries() {
for (auto it = this->begin(), e = this->end(); it != e;) {
auto curIt = it++;
- if (!*curIt->second.ptr)
+ if (curIt->second.first.expired())
this->erase(curIt);
}
}
@@ -140,23 +89,27 @@ class ThreadLocalCache {
ValueT &get() {
// Check for an already existing instance for this thread.
CacheType &staticCache = getStaticCache();
- Observer &threadInstance = staticCache[perInstanceState.get()];
- if (ValueT *value = *threadInstance.ptr)
+ std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
+ staticCache[perInstanceState.get()];
+ if (ValueT *value = threadInstance.second)
return *value;
// Otherwise, create a new instance for this thread.
{
llvm::sys::SmartScopedLock<true> threadInstanceLock(
perInstanceState->instanceMutex);
- perInstanceState->instances.emplace_back(threadInstance);
+ threadInstance.second =
+ perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
+ .get();
}
- threadInstance.keepalive = perInstanceState;
+ threadInstance.first =
+ std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
// Before returning the new instance, take the chance to clear out any used
// entries in the static map. The cache is only cleared within the same
// thread to remove the need to lock the cache itself.
staticCache.clearExpiredEntries();
- return **threadInstance.ptr;
+ return *threadInstance.second;
}
ValueT &operator*() { return get(); }
ValueT *operator->() { return &get(); }
More information about the Mlir-commits
mailing list