[Openmp-commits] [openmp] [OpenMP] Simplify accessing num-teams and team-num (PR #131618)

Hansang Bae via Openmp-commits openmp-commits at lists.llvm.org
Mon Mar 17 07:09:11 PDT 2025


https://github.com/hansangbae updated https://github.com/llvm/llvm-project/pull/131618

>From 793af09c91cfca10f39e56ebc6a0d4ce0684ffc9 Mon Sep 17 00:00:00 2001
From: Hansang Bae <hansang.bae at intel.com>
Date: Fri, 14 Mar 2025 17:20:44 -0500
Subject: [PATCH 1/2] [OpenMP] Simplify accessing num-teams and team-num

We found an issue with accessing the correct number of teams
and the team number when the enclosing region is serialized due to
the use of an if clause. The existing method does not appear to
handle such cases, so this change proposes a simpler way
of accessing the team struct bound to the implicit task invoked
by each OpenMP team in the league.
---
 openmp/runtime/src/kmp_runtime.cpp            | 62 +++-----------
 openmp/runtime/src/kmp_sched.cpp              | 13 +--
 openmp/runtime/test/teams/teams_parallel_if.c | 81 +++++++++++++++++++
 3 files changed, 95 insertions(+), 61 deletions(-)
 create mode 100644 openmp/runtime/test/teams/teams_parallel_if.c

diff --git a/openmp/runtime/src/kmp_runtime.cpp b/openmp/runtime/src/kmp_runtime.cpp
index 3e5d671cb7a48..3bbdba98d1137 100644
--- a/openmp/runtime/src/kmp_runtime.cpp
+++ b/openmp/runtime/src/kmp_runtime.cpp
@@ -8516,60 +8516,22 @@ void __kmp_aux_set_library(enum library_type arg) {
   }
 }
 
-/* Getting team information common for all team API */
-// Returns NULL if not in teams construct
-static kmp_team_t *__kmp_aux_get_team_info(int &teams_serialized) {
-  kmp_info_t *thr = __kmp_entry_thread();
-  teams_serialized = 0;
-  if (thr->th.th_teams_microtask) {
-    kmp_team_t *team = thr->th.th_team;
-    int tlevel = thr->th.th_teams_level; // the level of the teams construct
-    int ii = team->t.t_level;
-    teams_serialized = team->t.t_serialized;
-    int level = tlevel + 1;
-    KMP_DEBUG_ASSERT(ii >= tlevel);
-    while (ii > level) {
-      for (teams_serialized = team->t.t_serialized;
-           (teams_serialized > 0) && (ii > level); teams_serialized--, ii--) {
-      }
-      if (team->t.t_serialized && (!teams_serialized)) {
-        team = team->t.t_parent;
-        continue;
-      }
-      if (ii > level) {
-        team = team->t.t_parent;
-        ii--;
-      }
-    }
-    return team;
-  }
-  return NULL;
-}
-
 int __kmp_aux_get_team_num() {
-  int serialized;
-  kmp_team_t *team = __kmp_aux_get_team_info(serialized);
-  if (team) {
-    if (serialized > 1) {
-      return 0; // teams region is serialized ( 1 team of 1 thread ).
-    } else {
-      return team->t.t_master_tid;
-    }
-  }
-  return 0;
+  auto *team = __kmp_entry_thread()->th.th_team;
+  while (team && team->t.t_parent &&
+         team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
+    team = team->t.t_parent;
+  return team ? team->t.t_master_tid : 0;
 }
 
 int __kmp_aux_get_num_teams() {
-  int serialized;
-  kmp_team_t *team = __kmp_aux_get_team_info(serialized);
-  if (team) {
-    if (serialized > 1) {
-      return 1;
-    } else {
-      return team->t.t_parent->t.t_nproc;
-    }
-  }
-  return 1;
+  auto *team = __kmp_entry_thread()->th.th_team;
+  while (team && team->t.t_parent &&
+         team->t.t_parent->t.t_pkfn != (microtask_t)__kmp_teams_master)
+    team = team->t.t_parent;
+  if (!team || !team->t.t_parent)
+    return 1;
+  return team->t.t_parent->t.t_nproc;
 }
 
 /* ------------------------------------------------------------------------ */
diff --git a/openmp/runtime/src/kmp_sched.cpp b/openmp/runtime/src/kmp_sched.cpp
index 2b1bb6f595f9a..3ae08cc899478 100644
--- a/openmp/runtime/src/kmp_sched.cpp
+++ b/openmp/runtime/src/kmp_sched.cpp
@@ -497,7 +497,6 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
   kmp_uint32 team_id;
   kmp_uint32 nteams;
   UT trip_count;
-  kmp_team_t *team;
   kmp_info_t *th;
 
   KMP_DEBUG_ASSERT(plastiter && plower && pupper && pupperDist && pstride);
@@ -540,17 +539,9 @@ static void __kmp_dist_for_static_init(ident_t *loc, kmp_int32 gtid,
   tid = __kmp_tid_from_gtid(gtid);
   th = __kmp_threads[gtid];
   nth = th->th.th_team_nproc;
-  team = th->th.th_team;
   KMP_DEBUG_ASSERT(th->th.th_teams_microtask); // we are in the teams construct
-  // skip optional serialized teams to prevent this from using the wrong teams
-  // information when called after __kmp_serialized_parallel
-  // TODO: make __kmp_serialized_parallel eventually call __kmp_fork_in_teams
-  // to address this edge case
-  while (team->t.t_parent && team->t.t_serialized)
-    team = team->t.t_parent;
-  nteams = th->th.th_teams_size.nteams;
-  team_id = team->t.t_master_tid;
-  KMP_DEBUG_ASSERT(nteams == (kmp_uint32)team->t.t_parent->t.t_nproc);
+  nteams = __kmp_aux_get_num_teams();
+  team_id = __kmp_aux_get_team_num();
 
   // compute global trip count
   if (incr == 1) {
diff --git a/openmp/runtime/test/teams/teams_parallel_if.c b/openmp/runtime/test/teams/teams_parallel_if.c
new file mode 100644
index 0000000000000..2a7d50c235328
--- /dev/null
+++ b/openmp/runtime/test/teams/teams_parallel_if.c
@@ -0,0 +1,81 @@
+// RUN: %libomp-compile -fopenmp-version=52 && %libomp-run
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <omp.h>
+
+typedef struct {
+  int team_num;
+  int thread_num;
+} omp_id_t;
+
+/// Test if each worker thread can retrieve the correct ICV values.
+void test_api(int nteams, int nthreads, int par_if) {
+  int expected_nteams = nteams;
+  int expected_nthreads = par_if ? nthreads : 1;
+  int expected_size = expected_nteams * expected_nthreads;
+  omp_id_t *expected = (omp_id_t *)malloc(expected_size * sizeof(omp_id_t));
+  omp_id_t *observed = (omp_id_t *)malloc(expected_size * sizeof(omp_id_t));
+
+  for (int i = 0; i < expected_size; i++) {
+    expected[i].team_num = i / expected_nthreads;
+    expected[i].thread_num = i % expected_nthreads;
+    observed[i].team_num = -1;
+    observed[i].thread_num = -1;
+  }
+
+#pragma omp teams num_teams(nteams)
+#pragma omp parallel num_threads(nthreads) if(par_if)
+  {
+    omp_id_t id = {omp_get_team_num(), omp_get_thread_num()};
+    if (omp_get_num_teams() == expected_nteams &&
+        omp_get_num_threads() == expected_nthreads &&
+        id.team_num >= 0 && id.team_num < expected_nteams &&
+        id.thread_num >= 0 && id.thread_num < expected_nthreads) {
+      int flat_id = id.thread_num + id.team_num * expected_nthreads;
+      observed[flat_id] = id;
+    }
+  }
+
+  for (int i = 0; i < expected_size; i++) {
+    if (expected[i].team_num != observed[i].team_num ||
+        expected[i].thread_num != observed[i].thread_num) {
+      printf("failed at nteams=%d, nthreads=%d, par_if=%d\n",
+             nteams, nthreads, par_if);
+      exit(EXIT_FAILURE);
+    }
+  }
+}
+
+/// Test if __kmpc_dist_for_static_init works correctly.
+void test_dist(int nteams, int nthreads, int par_if) {
+  int ub = 1000;
+  int index_sum_expected = ub * (ub + 1) / 2;
+  int index_sum = 0;
+#pragma omp teams distribute parallel for num_teams(nteams)                    \
+                                          num_threads(nthreads) if(par_if)
+  for (int i = 1; i <= ub; i++)
+#pragma omp atomic update
+    index_sum += i;
+
+  if (index_sum != index_sum_expected) {
+    printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
+           par_if);
+    exit(EXIT_FAILURE);
+  }
+}
+
+int main() {
+  for (int par_if = 0; par_if < 2; par_if++) {
+    for (int nteams = 1; nteams <= 16; nteams++) {
+      for (int nthreads = 1; nthreads <= 16; nthreads++) {
+        if (omp_get_max_threads() < nteams * nthreads)
+          continue; // make sure requested resources are granted
+        test_api(nteams, nthreads, par_if);
+        test_dist(nteams, nthreads, par_if);
+      }
+    }
+  }
+  printf("passed\n");
+  return EXIT_SUCCESS;
+}

>From a1a4bcb5eecb5f7350b166b7a5a816cace585124 Mon Sep 17 00:00:00 2001
From: Hansang Bae <hansang.bae at intel.com>
Date: Mon, 17 Mar 2025 09:08:40 -0500
Subject: [PATCH 2/2] Clang-format

---
 openmp/runtime/test/teams/teams_parallel_if.c | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/openmp/runtime/test/teams/teams_parallel_if.c b/openmp/runtime/test/teams/teams_parallel_if.c
index 2a7d50c235328..7b5d80279b9c9 100644
--- a/openmp/runtime/test/teams/teams_parallel_if.c
+++ b/openmp/runtime/test/teams/teams_parallel_if.c
@@ -25,13 +25,13 @@ void test_api(int nteams, int nthreads, int par_if) {
   }
 
 #pragma omp teams num_teams(nteams)
-#pragma omp parallel num_threads(nthreads) if(par_if)
+#pragma omp parallel num_threads(nthreads) if (par_if)
   {
     omp_id_t id = {omp_get_team_num(), omp_get_thread_num()};
     if (omp_get_num_teams() == expected_nteams &&
-        omp_get_num_threads() == expected_nthreads &&
-        id.team_num >= 0 && id.team_num < expected_nteams &&
-        id.thread_num >= 0 && id.thread_num < expected_nthreads) {
+        omp_get_num_threads() == expected_nthreads && id.team_num >= 0 &&
+        id.team_num < expected_nteams && id.thread_num >= 0 &&
+        id.thread_num < expected_nthreads) {
       int flat_id = id.thread_num + id.team_num * expected_nthreads;
       observed[flat_id] = id;
     }
@@ -40,8 +40,8 @@ void test_api(int nteams, int nthreads, int par_if) {
   for (int i = 0; i < expected_size; i++) {
     if (expected[i].team_num != observed[i].team_num ||
         expected[i].thread_num != observed[i].thread_num) {
-      printf("failed at nteams=%d, nthreads=%d, par_if=%d\n",
-             nteams, nthreads, par_if);
+      printf("failed at nteams=%d, nthreads=%d, par_if=%d\n", nteams, nthreads,
+             par_if);
       exit(EXIT_FAILURE);
     }
   }
@@ -53,7 +53,7 @@ void test_dist(int nteams, int nthreads, int par_if) {
   int index_sum_expected = ub * (ub + 1) / 2;
   int index_sum = 0;
 #pragma omp teams distribute parallel for num_teams(nteams)                    \
-                                          num_threads(nthreads) if(par_if)
+    num_threads(nthreads) if(par_if)
   for (int i = 1; i <= ub; i++)
 #pragma omp atomic update
     index_sum += i;



More information about the Openmp-commits mailing list