[Mlir-commits] [mlir] 3d95d1b - [mlir] AsyncRuntime: fix concurrency bugs + fix exports in methods definitions
Eugene Zhulenev
llvmlistbot at llvm.org
Tue Nov 24 03:53:21 PST 2020
Author: Eugene Zhulenev
Date: 2020-11-24T03:53:13-08:00
New Revision: 3d95d1b477dee6c1a01f6802527b60ba74271ed5
URL: https://github.com/llvm/llvm-project/commit/3d95d1b477dee6c1a01f6802527b60ba74271ed5
DIFF: https://github.com/llvm/llvm-project/commit/3d95d1b477dee6c1a01f6802527b60ba74271ed5.diff
LOG: [mlir] AsyncRuntime: fix concurrency bugs + fix exports in methods definitions
1. Move ThreadPool ownership to the runtime, and wait for the async tasks completion in the destructor.
2. Remove MLIR_ASYNCRUNTIME_EXPORT from method definitions because it is unnecessary in .cpp files: only function declarations need to be exported, not their definitions.
3. Fix concurrency bugs in group emplace and potential use-after-free in token emplace.
Tested internally with 10k runs of `async.mlir` and `async-group.mlir`.
Fixed: https://bugs.llvm.org/show_bug.cgi?id=48267
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D91988
Added:
Modified:
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 0a98e72382e5..6bf59f86208d 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -45,6 +45,7 @@ class AsyncRuntime {
AsyncRuntime() : numRefCountedObjects(0) {}
~AsyncRuntime() {
+ threadPool.wait(); // wait for the completion of all async tasks
assert(getNumRefCountedObjects() == 0 &&
"all ref counted objects must be destroyed");
}
@@ -53,6 +54,8 @@ class AsyncRuntime {
return numRefCountedObjects.load(std::memory_order_relaxed);
}
+ llvm::ThreadPool &getThreadPool() { return threadPool; }
+
private:
friend class RefCounted;
@@ -66,6 +69,8 @@ class AsyncRuntime {
}
std::atomic<int32_t> numRefCountedObjects;
+
+ llvm::ThreadPool threadPool;
};
// Returns the default per-process instance of an async runtime.
@@ -143,15 +148,13 @@ struct AsyncGroup : public RefCounted {
};
// Adds references to reference counted runtime object.
-extern "C" MLIR_ASYNCRUNTIME_EXPORT void
-mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
+extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
refCounted->addRef(count);
}
// Drops references from reference counted runtime object.
-extern "C" MLIR_ASYNCRUNTIME_EXPORT void
-mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
+extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
refCounted->dropRef(count);
}
@@ -163,13 +166,13 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
}
// Create a new `async.group` in empty state.
-extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
return group;
}
-extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
-mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
+extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
+ AsyncGroup *group) {
std::unique_lock<std::mutex> lockToken(token->mu);
std::unique_lock<std::mutex> lockGroup(group->mu);
@@ -177,27 +180,33 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
- auto onTokenReady = [group, token](bool dropRef) {
+ auto onTokenReady = [group]() {
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
for (auto &awaiter : group->awaiters)
awaiter();
}
-
- // We no longer need the token or the group, drop references on them.
- if (dropRef) {
- group->dropRef();
- token->dropRef();
- }
};
if (token->ready) {
- onTokenReady(false);
+ // Update group pending tokens immediately and maybe run awaiters.
+ onTokenReady();
+
} else {
+ // Update group pending tokens when token will become ready. Because this
+ // will happen asynchronously we must ensure that `group` is alive until
+ // then, and re-acquire the lock.
group->addRef();
- token->addRef();
- token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
+
+ token->awaiters.push_back([group, onTokenReady]() {
+ // Make sure that `dropRef` does not destroy the mutex owned by the lock.
+ {
+ std::unique_lock<std::mutex> lockGroup(group->mu);
+ onTokenReady();
+ }
+ group->dropRef();
+ });
}
return rank;
@@ -205,11 +214,14 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
// Switches `async.token` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
- std::unique_lock<std::mutex> lock(token->mu);
- token->ready = true;
- token->cv.notify_all();
- for (auto &awaiter : token->awaiters)
- awaiter();
+ // Make sure that `dropRef` does not destroy the mutex owned by the lock.
+ {
+ std::unique_lock<std::mutex> lock(token->mu);
+ token->ready = true;
+ token->cv.notify_all();
+ for (auto &awaiter : token->awaiters)
+ awaiter();
+ }
// Async tokens created with a ref count `2` to keep token alive until the
// async task completes. Drop this reference explicitly when token emplaced.
@@ -222,58 +234,37 @@ extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
token->cv.wait(lock, [token] { return token->ready; });
}
-extern "C" MLIR_ASYNCRUNTIME_EXPORT void
-mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
+extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
std::unique_lock<std::mutex> lock(group->mu);
if (group->pendingTokens != 0)
group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
-#if LLVM_ENABLE_THREADS
- static llvm::ThreadPool *threadPool = new llvm::ThreadPool();
- threadPool->async([handle, resume]() { (*resume)(handle); });
-#else
- (*resume)(handle);
-#endif
+ auto *runtime = getDefaultAsyncRuntimeInstance();
+ runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
CoroHandle handle,
CoroResume resume) {
std::unique_lock<std::mutex> lock(token->mu);
-
- auto execute = [handle, resume, token](bool dropRef) {
- if (dropRef)
- token->dropRef();
- mlirAsyncRuntimeExecute(handle, resume);
- };
-
- if (token->ready) {
- execute(false);
- } else {
- token->addRef();
- token->awaiters.push_back([execute]() { execute(true); });
- }
+ auto execute = [handle, resume]() { (*resume)(handle); };
+ if (token->ready)
+ execute();
+ else
+ token->awaiters.push_back([execute]() { execute(); });
}
-extern "C" MLIR_ASYNCRUNTIME_EXPORT void
-mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
- CoroResume resume) {
+extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
+ CoroHandle handle,
+ CoroResume resume) {
std::unique_lock<std::mutex> lock(group->mu);
-
- auto execute = [handle, resume, group](bool dropRef) {
- if (dropRef)
- group->dropRef();
- mlirAsyncRuntimeExecute(handle, resume);
- };
-
- if (group->pendingTokens == 0) {
- execute(false);
- } else {
- group->addRef();
- group->awaiters.push_back([execute]() { execute(true); });
- }
+ auto execute = [handle, resume]() { (*resume)(handle); };
+ if (group->pendingTokens == 0)
+ execute();
+ else
+ group->awaiters.push_back([execute]() { execute(); });
}
//===----------------------------------------------------------------------===//
@@ -282,7 +273,7 @@ mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
static thread_local std::thread::id thisId = std::this_thread::get_id();
- std::cout << "Current thread id: " << thisId << "\n";
+ std::cout << "Current thread id: " << thisId << std::endl;
}
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
More information about the Mlir-commits
mailing list