[Openmp-commits] [openmp] [openmp] Added omp_get_max_teams() on device (PR #67867)

Khoi Nguyen via Openmp-commits openmp-commits at lists.llvm.org
Fri Sep 29 16:50:34 PDT 2023


https://github.com/khoing0810 created https://github.com/llvm/llvm-project/pull/67867

Add omp_get_max_teams() to the device runtime and add a unit test that exercises it.

>From e9dc1c01ea132343032ee131a9f04058ffbf4119 Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Fri, 29 Sep 2023 16:16:09 -0700
Subject: [PATCH] [OpenMP] Implement omp_get_max_teams on the device and add a test

---
 openmp/libomptarget/DeviceRTL/include/State.h |  9 ++++
 openmp/libomptarget/DeviceRTL/src/State.cpp   |  7 ++-
 .../libomptarget/test/api/omp_get_max_teams.c | 43 +++++++++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 openmp/libomptarget/test/api/omp_get_max_teams.c

diff --git a/openmp/libomptarget/DeviceRTL/include/State.h b/openmp/libomptarget/DeviceRTL/include/State.h
index 60dc439f9551c21..ea65101d3cc4e52 100644
--- a/openmp/libomptarget/DeviceRTL/include/State.h
+++ b/openmp/libomptarget/DeviceRTL/include/State.h
@@ -53,6 +53,7 @@ inline constexpr uint32_t SharedScratchpadSize = SHARED_SCRATCHPAD_SIZE;
 
 struct ICVStateTy {
   uint32_t NThreadsVar;
+  uint32_t NTeamsVar;
   uint32_t LevelVar;
   uint32_t ActiveLevelVar;
   uint32_t Padding0Val;
@@ -125,6 +126,7 @@ KernelEnvironmentTy &getKernelEnvironment();
 /// TODO
 enum ValueKind {
   VK_NThreads,
+  VK_NTeams,
   VK_Level,
   VK_ActiveLevel,
   VK_MaxActiveLevels,
@@ -184,6 +186,11 @@ lookup32(ValueKind Kind, bool IsReadonly, IdentTy *Ident, bool ForceTeamState) {
       return lookupImpl(&ICVStateTy::NThreadsVar, ForceTeamState);
     return lookupForModify32Impl(&ICVStateTy::NThreadsVar, Ident,
                                  ForceTeamState);
+  case state::VK_NTeams:
+    if (IsReadonly)
+      return lookupImpl(&ICVStateTy::NTeamsVar, ForceTeamState);
+    return lookupForModify32Impl(&ICVStateTy::NTeamsVar, Ident,
+                                 ForceTeamState);
   case state::VK_Level:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::LevelVar, ForceTeamState);
@@ -356,6 +363,8 @@ namespace icv {
 /// TODO
 inline state::Value<uint32_t, state::VK_NThreads> NThreads;

+/// Value read by omp_get_max_teams(); looked up via ICVStateTy::NTeamsVar.
+inline state::Value<uint32_t, state::VK_NTeams> NTeams;
 /// TODO
 inline state::Value<uint32_t, state::VK_Level> Level;
 
diff --git a/openmp/libomptarget/DeviceRTL/src/State.cpp b/openmp/libomptarget/DeviceRTL/src/State.cpp
index 721137cb95d658b..995866d51223b63 100644
--- a/openmp/libomptarget/DeviceRTL/src/State.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/State.cpp
@@ -184,7 +184,8 @@ void memory::freeGlobal(void *Ptr, const char *Reason) { free(Ptr); }
 ///}
 
 bool state::ICVStateTy::operator==(const ICVStateTy &Other) const {
-  return (NThreadsVar == Other.NThreadsVar) & (LevelVar == Other.LevelVar) &
+  return (NThreadsVar == Other.NThreadsVar) & (NTeamsVar == Other.NTeamsVar) &
+         (LevelVar == Other.LevelVar) &
          (ActiveLevelVar == Other.ActiveLevelVar) &
          (MaxActiveLevelsVar == Other.MaxActiveLevelsVar) &
          (RunSchedVar == Other.RunSchedVar) &
@@ -193,6 +194,7 @@ bool state::ICVStateTy::operator==(const ICVStateTy &Other) const {
 
 void state::ICVStateTy::assertEqual(const ICVStateTy &Other) const {
   ASSERT(NThreadsVar == Other.NThreadsVar, nullptr);
+  ASSERT(NTeamsVar == Other.NTeamsVar, nullptr);
   ASSERT(LevelVar == Other.LevelVar, nullptr);
   ASSERT(ActiveLevelVar == Other.ActiveLevelVar, nullptr);
   ASSERT(MaxActiveLevelsVar == Other.MaxActiveLevelsVar, nullptr);
@@ -202,6 +204,7 @@ void state::ICVStateTy::assertEqual(const ICVStateTy &Other) const {
 
 void state::TeamStateTy::init(bool IsSPMD) {
   ICVState.NThreadsVar = 0;
+  ICVState.NTeamsVar = 0;
   ICVState.LevelVar = 0;
   ICVState.ActiveLevelVar = 0;
   ICVState.Padding0Val = 0;
@@ -417,6 +420,8 @@ int omp_get_device_num(void) { return config::getDeviceNum(); }
 
 int omp_get_num_teams(void) { return mapping::getNumberOfBlocksInKernel(); }

+/// NOTE(review): patch only sets NTeamsVar to 0; confirm what updates it.
+int omp_get_max_teams(void) { return icv::NTeams; }
 int omp_get_team_num() { return mapping::getBlockIdInKernel(); }

 int omp_get_initial_device(void) { return -1; }
diff --git a/openmp/libomptarget/test/api/omp_get_max_teams.c b/openmp/libomptarget/test/api/omp_get_max_teams.c
new file mode 100644
index 000000000000000..ee9ed546a0b0563
--- /dev/null
+++ b/openmp/libomptarget/test/api/omp_get_max_teams.c
@@ -0,0 +1,43 @@
+// RUN: %libomptarget-compile-run-and-check-generic
+
+#include <stdio.h>
+#include <omp.h>
+
+// Checks that omp_get_max_teams() is non-negative and, when positive, that the
+// number of teams launched by a plain teams region does not exceed it. With
+// offload == 0 the if clause is false, so both target regions run on the host.
+int test_get_max_teams(int offload) {
+  int errors = 0;
+  int max_teams = 0;
+  int num_teams = 0;
+
+#pragma omp target map(tofrom : max_teams) if (offload)
+  { max_teams = omp_get_max_teams(); }
+
+#pragma omp target map(tofrom : num_teams) if (offload)
+#pragma omp teams
+  {
+    if (omp_get_team_num() == 0)
+      num_teams = omp_get_num_teams();
+  }
+
+  if (max_teams < 0 || (max_teams > 0 && num_teams > max_teams)) {
+    printf("offload=%d: num_teams=%d is inconsistent with max_teams=%d\n",
+           offload, num_teams, max_teams);
+    errors++;
+  }
+  return errors;
+}
+
+int main() {
+  // Run once on the host (offload disabled) and once on the device.
+  int errors = test_get_max_teams(/*offload=*/0);
+  errors += test_get_max_teams(/*offload=*/1);
+  if (errors == 0)
+    printf("PASS\n");
+  else
+    printf("FAIL\n");
+  return errors;
+}
+
+// CHECK: PASS



More information about the Openmp-commits mailing list