[Mlir-commits] [mlir] a2223b0 - [mlir:async] Fix data races in AsyncRuntime

Eugene Zhulenev llvmlistbot at llvm.org
Wed Jan 20 13:23:48 PST 2021


Author: Eugene Zhulenev
Date: 2021-01-20T13:23:39-08:00
New Revision: a2223b09b10a4cc87b5e9c4a36ab9401c46610f6

URL: https://github.com/llvm/llvm-project/commit/a2223b09b10a4cc87b5e9c4a36ab9401c46610f6
DIFF: https://github.com/llvm/llvm-project/commit/a2223b09b10a4cc87b5e9c4a36ab9401c46610f6.diff

LOG: [mlir:async] Fix data races in AsyncRuntime

A resumed coroutine can potentially deallocate the token/value/group, and thereby destroy its mutex, before the destructor of the std::unique_lock that still holds that mutex has run.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D95037

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index a20bd6d1e996..e38ebf92cd84 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -136,13 +136,14 @@ struct AsyncToken : public RefCounted {
   // asynchronously executed task. If the caller immediately will drop its
   // reference we must ensure that the token will be alive until the
   // asynchronous operation is completed.
-  AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
+  AsyncToken(AsyncRuntime *runtime)
+      : RefCounted(runtime, /*count=*/2), ready(false) {}
 
-  // Internal state below guarded by a mutex.
+  std::atomic<bool> ready;
+
+  // Pending awaiters are guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
-
-  bool ready = false;
   std::vector<std::function<void()>> awaiters;
 };
 
@@ -152,17 +153,17 @@ struct AsyncToken : public RefCounted {
 struct AsyncValue : public RefCounted {
   // AsyncValue similar to an AsyncToken created with a reference count of 2.
   AsyncValue(AsyncRuntime *runtime, int32_t size)
-      : RefCounted(runtime, /*count=*/2), storage(size) {}
-
-  // Internal state below guarded by a mutex.
-  std::mutex mu;
-  std::condition_variable cv;
+      : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}
 
-  bool ready = false;
-  std::vector<std::function<void()>> awaiters;
+  std::atomic<bool> ready;
 
   // Use vector of bytes to store async value payload.
   std::vector<int8_t> storage;
+
+  // Pending awaiters are guarded by a mutex.
+  std::mutex mu;
+  std::condition_variable cv;
+  std::vector<std::function<void()>> awaiters;
 };
 
 // Async group provides a mechanism to group together multiple async tokens or
@@ -175,10 +176,9 @@ struct AsyncGroup : public RefCounted {
   std::atomic<int> pendingTokens;
   std::atomic<int> rank;
 
-  // Internal state below guarded by a mutex.
+  // Pending awaiters are guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
-
   std::vector<std::function<void()>> awaiters;
 };
 
@@ -291,13 +291,13 @@ extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
   std::unique_lock<std::mutex> lock(token->mu);
   if (!token->ready)
-    token->cv.wait(lock, [token] { return token->ready; });
+    token->cv.wait(lock, [token] { return token->ready.load(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
   std::unique_lock<std::mutex> lock(value->mu);
   if (!value->ready)
-    value->cv.wait(lock, [value] { return value->ready; });
+    value->cv.wait(lock, [value] { return value->ready.load(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
@@ -319,34 +319,37 @@ extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                      CoroHandle handle,
                                                      CoroResume resume) {
-  std::unique_lock<std::mutex> lock(token->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (token->ready)
+  if (token->ready) {
     execute();
-  else
+  } else {
+    std::unique_lock<std::mutex> lock(token->mu);
     token->awaiters.push_back([execute]() { execute(); });
+  }
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                      CoroHandle handle,
                                                      CoroResume resume) {
-  std::unique_lock<std::mutex> lock(value->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (value->ready)
+  if (value->ready) {
     execute();
-  else
+  } else {
+    std::unique_lock<std::mutex> lock(value->mu);
     value->awaiters.push_back([execute]() { execute(); });
+  }
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                           CoroHandle handle,
                                                           CoroResume resume) {
-  std::unique_lock<std::mutex> lock(group->mu);
   auto execute = [handle, resume]() { (*resume)(handle); };
-  if (group->pendingTokens == 0)
+  if (group->pendingTokens == 0) {
     execute();
-  else
+  } else {
+    std::unique_lock<std::mutex> lock(group->mu);
     group->awaiters.push_back([execute]() { execute(); });
+  }
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list