[Openmp-commits] [openmp] 2fdf191 - [OpenMP] Fix crash with task stealing and task dependencies (#126049)

Joachim Jenke via Openmp-commits openmp-commits at lists.llvm.org
Fri Feb 14 01:56:24 PST 2025


Author: Julian Brown
Date: 2025-02-14T10:55:59+01:00
New Revision: 2fdf191e244b62409fd73fa9bb717466d6e683b5

URL: https://github.com/llvm/llvm-project/commit/2fdf191e244b62409fd73fa9bb717466d6e683b5
DIFF: https://github.com/llvm/llvm-project/commit/2fdf191e244b62409fd73fa9bb717466d6e683b5.diff

LOG: [OpenMP] Fix crash with task stealing and task dependencies (#126049)

This patch series demonstrates and fixes a bug that causes crashes with
OpenMP `taskwait` directives that carry dependences, in heavily
multi-threaded scenarios.

TL;DR: The early return from `__kmpc_omp_taskwait_deps_51` skipped the
synchronization that the late (normal) return path performs: waiting until no
other thread still holds a reference to the stack-allocated depnode.

Additional debug assertions check the invariants implied by the code.

@jpeyton52 identified the timing hole as the following sequence of events:
>
> 1. THREAD 1: A regular task with dependences is created, call it T1
> 2. THREAD 1: Calls into `__kmpc_omp_taskwait_deps_51()` and creates a
>    stack-based depnode (`NULL` task), call it T2 (stack)
> 3. THREAD 2: Steals task T1 and executes it, getting to the
>    `__kmp_release_deps()` region.
> 4. THREAD 1: While processing the dependences for T2 (stack) (within the
>    `__kmp_check_deps()` region), a link T1 -> T2 is created. This increases
>    T2's (stack) `nrefs` count.
> 5. THREAD 2: Iterates through the successors list and decrements T2's (stack)
>    `npredecessors` count, BUT HASN'T YET `__kmp_node_deref()`-ed it.
> 6. THREAD 1: Now, when finished with `__kmp_check_deps()`, it returns false
>    because the `npredecessors` count is 0, but T2's (stack) `nrefs` count is 2
>    because THREAD 2 still references it!
> 7. THREAD 1: Because `__kmp_check_deps()` returns false, early exit.
>    _Now the stack based depnode is invalid, but THREAD 2 still references it._
>
> We've reached improper stack-referencing behavior. Varied results/crashes/
> asserts can occur if THREAD 1 comes back and recreates the exact same depnode
> at the exact same stack address at the same time that THREAD 2 calls
> `__kmp_node_deref()`.

Added: 
    

Modified: 
    openmp/runtime/src/kmp_taskdeps.cpp
    openmp/runtime/src/kmp_taskdeps.h

Removed: 
    


################################################################################
diff  --git a/openmp/runtime/src/kmp_taskdeps.cpp b/openmp/runtime/src/kmp_taskdeps.cpp
index 39cf3496c5a18..392c8a2b19333 100644
--- a/openmp/runtime/src/kmp_taskdeps.cpp
+++ b/openmp/runtime/src/kmp_taskdeps.cpp
@@ -33,7 +33,7 @@
 static std::atomic<kmp_int32> kmp_node_id_seed = 0;
 #endif
 
-static void __kmp_init_node(kmp_depnode_t *node) {
+static void __kmp_init_node(kmp_depnode_t *node, bool on_stack) {
   node->dn.successors = NULL;
   node->dn.task = NULL; // will point to the right task
   // once dependences have been processed
@@ -41,7 +41,11 @@ static void __kmp_init_node(kmp_depnode_t *node) {
     node->dn.mtx_locks[i] = NULL;
   node->dn.mtx_num_locks = 0;
   __kmp_init_lock(&node->dn.lock);
-  KMP_ATOMIC_ST_RLX(&node->dn.nrefs, 1); // init creates the first reference
+  // Init creates the first reference.  Bit 0 indicates that this node
+  // resides on the stack.  The refcount is incremented and decremented in
+  // steps of two, maintaining use of even numbers for heap nodes and odd
+  // numbers for stack nodes.
+  KMP_ATOMIC_ST_RLX(&node->dn.nrefs, on_stack ? 3 : 2);
 #ifdef KMP_SUPPORT_GRAPH_OUTPUT
   node->dn.id = KMP_ATOMIC_INC(&kmp_node_id_seed);
 #endif
@@ -51,7 +55,7 @@ static void __kmp_init_node(kmp_depnode_t *node) {
 }
 
 static inline kmp_depnode_t *__kmp_node_ref(kmp_depnode_t *node) {
-  KMP_ATOMIC_INC(&node->dn.nrefs);
+  KMP_ATOMIC_ADD(&node->dn.nrefs, 2);
   return node;
 }
 
@@ -825,7 +829,7 @@ kmp_int32 __kmpc_omp_task_with_deps(ident_t *loc_ref, kmp_int32 gtid,
         (kmp_depnode_t *)__kmp_thread_malloc(thread, sizeof(kmp_depnode_t));
 #endif
 
-    __kmp_init_node(node);
+    __kmp_init_node(node, /*on_stack=*/false);
     new_taskdata->td_depnode = node;
 
     if (__kmp_check_deps(gtid, node, new_task, &current_task->td_dephash,
@@ -1007,7 +1011,7 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
   }
 
   kmp_depnode_t node = {0};
-  __kmp_init_node(&node);
+  __kmp_init_node(&node, /*on_stack=*/true);
 
   if (!__kmp_check_deps(gtid, &node, NULL, &current_task->td_dephash,
                         DEP_BARRIER, ndeps, dep_list, ndeps_noalias,
@@ -1018,6 +1022,16 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
 #if OMPT_SUPPORT
     __ompt_taskwait_dep_finish(current_task, taskwait_task_data);
 #endif /* OMPT_SUPPORT */
+
+    // There may still be references to this node here, due to task stealing.
+    // Wait for them to be released.
+    kmp_int32 nrefs;
+    while ((nrefs = node.dn.nrefs) > 3) {
+      KMP_DEBUG_ASSERT((nrefs & 1) == 1);
+      KMP_YIELD(TRUE);
+    }
+    KMP_DEBUG_ASSERT(nrefs == 3);
+
     return;
   }
 
@@ -1032,9 +1046,15 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
 
   // Wait until the last __kmp_release_deps is finished before we free the
   // current stack frame holding the "node" variable; once its nrefs count
-  // reaches 1, we're sure nobody else can try to reference it again.
-  while (node.dn.nrefs > 1)
+  // reaches 3 (meaning 1, since bit zero of the refcount indicates a stack
+  // rather than a heap address), we're sure nobody else can try to reference
+  // it again.
+  kmp_int32 nrefs;
+  while ((nrefs = node.dn.nrefs) > 3) {
+    KMP_DEBUG_ASSERT((nrefs & 1) == 1);
     KMP_YIELD(TRUE);
+  }
+  KMP_DEBUG_ASSERT(nrefs == 3);
 
 #if OMPT_SUPPORT
   __ompt_taskwait_dep_finish(current_task, taskwait_task_data);

diff  --git a/openmp/runtime/src/kmp_taskdeps.h b/openmp/runtime/src/kmp_taskdeps.h
index d2ab515158011..893688bec9c80 100644
--- a/openmp/runtime/src/kmp_taskdeps.h
+++ b/openmp/runtime/src/kmp_taskdeps.h
@@ -22,12 +22,15 @@ static inline void __kmp_node_deref(kmp_info_t *thread, kmp_depnode_t *node) {
   if (!node)
     return;
 
-  kmp_int32 n = KMP_ATOMIC_DEC(&node->dn.nrefs) - 1;
+  kmp_int32 n = KMP_ATOMIC_SUB(&node->dn.nrefs, 2) - 2;
   KMP_DEBUG_ASSERT(n >= 0);
-  if (n == 0) {
+  if ((n & ~1) == 0) {
 #if USE_ITT_BUILD && USE_ITT_NOTIFY
     __itt_sync_destroy(node);
 #endif
+    // These two assertions are somewhat redundant.  The first is intended to
+    // detect if we are trying to free a depnode on the stack.
+    KMP_DEBUG_ASSERT((node->dn.nrefs & 1) == 0);
     KMP_ASSERT(node->dn.nrefs == 0);
 #if USE_FAST_MEMORY
     __kmp_fast_free(thread, node);


        


More information about the Openmp-commits mailing list