[Openmp-commits] [openmp] [OpenMP] Fix work-stealing stack clobber with taskwait (PR #126049)

Julian Brown via Openmp-commits openmp-commits at lists.llvm.org
Mon Feb 10 08:01:56 PST 2025


https://github.com/jtb20 updated https://github.com/llvm/llvm-project/pull/126049

>From df7ffe87f4ee2fe98b8b49f30e6c5cf89ec84b72 Mon Sep 17 00:00:00 2001
From: Julian Brown <julian.brown at amd.com>
Date: Mon, 10 Feb 2025 04:20:01 -0600
Subject: [PATCH] [OpenMP] Fix crash with task stealing and task dependencies

This patch fixes a bug that can cause crashes with OpenMP 'taskwait'
directives that have dependences, in heavily multi-threaded scenarios.

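Task stealing can lead to a situation where references to an on-stack
'taskwait' dependency node still remain even on the early-exit path of
__kmpc_omp_taskwait_deps_51 (the path taken when __kmp_check_deps returns
false).  If the function returns at that point, those references point
into a stack frame that no longer exists.  This patch adds a wait loop on
that path so the function does not return until the node's reference count
has dropped back to 1 (the function's own reference), along similar lines
to the fix for PR85963.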

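Several new assertions are also added for safety.  They borrow bit 0 of
the depnode refcount as a low-cost way of distinguishing heap-allocated
from stack-allocated depnodes.  The refcount is now initialized to 2 for
heap nodes and 3 for stack nodes, and it is always incremented and
decremented in steps of two.  As a result, bit 0 never changes, and a
logical count of 1 is stored as 2 (heap) or 3 (stack).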
---
 openmp/runtime/src/kmp_taskdeps.cpp | 34 +++++++++++++++++++++++------
 openmp/runtime/src/kmp_taskdeps.h   |  7 ++++--
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/openmp/runtime/src/kmp_taskdeps.cpp b/openmp/runtime/src/kmp_taskdeps.cpp
index 39cf3496c5a1855..392c8a2b19333d9 100644
--- a/openmp/runtime/src/kmp_taskdeps.cpp
+++ b/openmp/runtime/src/kmp_taskdeps.cpp
@@ -33,7 +33,7 @@
 static std::atomic<kmp_int32> kmp_node_id_seed = 0;
 #endif
 
-static void __kmp_init_node(kmp_depnode_t *node) {
+static void __kmp_init_node(kmp_depnode_t *node, bool on_stack) {
   node->dn.successors = NULL;
   node->dn.task = NULL; // will point to the right task
   // once dependences have been processed
@@ -41,7 +41,11 @@ static void __kmp_init_node(kmp_depnode_t *node) {
     node->dn.mtx_locks[i] = NULL;
   node->dn.mtx_num_locks = 0;
   __kmp_init_lock(&node->dn.lock);
-  KMP_ATOMIC_ST_RLX(&node->dn.nrefs, 1); // init creates the first reference
+  // Init creates the first reference.  Bit 0 indicates that this node
+  // resides on the stack.  The refcount is incremented and decremented in
+  // steps of two, maintaining use of even numbers for heap nodes and odd
+  // numbers for stack nodes.
+  KMP_ATOMIC_ST_RLX(&node->dn.nrefs, on_stack ? 3 : 2);
 #ifdef KMP_SUPPORT_GRAPH_OUTPUT
   node->dn.id = KMP_ATOMIC_INC(&kmp_node_id_seed);
 #endif
@@ -51,7 +55,7 @@ static void __kmp_init_node(kmp_depnode_t *node) {
 }
 
 static inline kmp_depnode_t *__kmp_node_ref(kmp_depnode_t *node) {
-  KMP_ATOMIC_INC(&node->dn.nrefs);
+  KMP_ATOMIC_ADD(&node->dn.nrefs, 2);
   return node;
 }
 
@@ -825,7 +829,7 @@ kmp_int32 __kmpc_omp_task_with_deps(ident_t *loc_ref, kmp_int32 gtid,
         (kmp_depnode_t *)__kmp_thread_malloc(thread, sizeof(kmp_depnode_t));
 #endif
 
-    __kmp_init_node(node);
+    __kmp_init_node(node, /*on_stack=*/false);
     new_taskdata->td_depnode = node;
 
     if (__kmp_check_deps(gtid, node, new_task, &current_task->td_dephash,
@@ -1007,7 +1011,7 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
   }
 
   kmp_depnode_t node = {0};
-  __kmp_init_node(&node);
+  __kmp_init_node(&node, /*on_stack=*/true);
 
   if (!__kmp_check_deps(gtid, &node, NULL, &current_task->td_dephash,
                         DEP_BARRIER, ndeps, dep_list, ndeps_noalias,
@@ -1018,6 +1022,16 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
 #if OMPT_SUPPORT
     __ompt_taskwait_dep_finish(current_task, taskwait_task_data);
 #endif /* OMPT_SUPPORT */
+
+    // There may still be references to this node here, due to task stealing.
+    // Wait for them to be released.
+    kmp_int32 nrefs;
+    while ((nrefs = node.dn.nrefs) > 3) {
+      KMP_DEBUG_ASSERT((nrefs & 1) == 1);
+      KMP_YIELD(TRUE);
+    }
+    KMP_DEBUG_ASSERT(nrefs == 3);
+
     return;
   }
 
@@ -1032,9 +1046,15 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
 
   // Wait until the last __kmp_release_deps is finished before we free the
   // current stack frame holding the "node" variable; once its nrefs count
-  // reaches 1, we're sure nobody else can try to reference it again.
-  while (node.dn.nrefs > 1)
+  // reaches 3 (meaning 1, since bit zero of the refcount indicates a stack
+  // rather than a heap address), we're sure nobody else can try to reference
+  // it again.
+  kmp_int32 nrefs;
+  while ((nrefs = node.dn.nrefs) > 3) {
+    KMP_DEBUG_ASSERT((nrefs & 1) == 1);
     KMP_YIELD(TRUE);
+  }
+  KMP_DEBUG_ASSERT(nrefs == 3);
 
 #if OMPT_SUPPORT
   __ompt_taskwait_dep_finish(current_task, taskwait_task_data);
diff --git a/openmp/runtime/src/kmp_taskdeps.h b/openmp/runtime/src/kmp_taskdeps.h
index d2ab515158011a1..893688bec9c80c9 100644
--- a/openmp/runtime/src/kmp_taskdeps.h
+++ b/openmp/runtime/src/kmp_taskdeps.h
@@ -22,12 +22,15 @@ static inline void __kmp_node_deref(kmp_info_t *thread, kmp_depnode_t *node) {
   if (!node)
     return;
 
-  kmp_int32 n = KMP_ATOMIC_DEC(&node->dn.nrefs) - 1;
+  kmp_int32 n = KMP_ATOMIC_SUB(&node->dn.nrefs, 2) - 2;
   KMP_DEBUG_ASSERT(n >= 0);
-  if (n == 0) {
+  if ((n & ~1) == 0) {
 #if USE_ITT_BUILD && USE_ITT_NOTIFY
     __itt_sync_destroy(node);
 #endif
+    // These two assertions are somewhat redundant.  The first is intended to
+    // detect if we are trying to free a depnode on the stack.
+    KMP_DEBUG_ASSERT((node->dn.nrefs & 1) == 0);
     KMP_ASSERT(node->dn.nrefs == 0);
 #if USE_FAST_MEMORY
     __kmp_fast_free(thread, node);



More information about the Openmp-commits mailing list