[Openmp-commits] [openmp] [OpenMP][DeviceRTL] implemented nteams-var ICV, omp_get_max_teams(), and omp_set_num_teams() (PR #71259)

Khoi Nguyen via Openmp-commits openmp-commits at lists.llvm.org
Thu Dec 7 15:38:34 PST 2023


https://github.com/khoing0810 updated https://github.com/llvm/llvm-project/pull/71259

>From 1c69398a44ec0a3c2212bfd756b3dcf7017b16a7 Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Fri, 3 Nov 2023 17:24:11 -0700
Subject: [PATCH 1/6] Implement omp_get_max_teams(), omp_set_num_teams(), and
 the nteams-var ICV

---
 .../DeviceRTL/include/Configuration.h         |  3 +
 .../DeviceRTL/include/Interface.h             |  4 ++
 openmp/libomptarget/DeviceRTL/include/State.h | 10 ++++
 .../DeviceRTL/src/Configuration.cpp           |  2 +
 openmp/libomptarget/DeviceRTL/src/State.cpp   |  8 +++
 openmp/libomptarget/include/Environment.h     |  1 +
 .../PluginInterface/PluginInterface.cpp       |  3 +-
 .../test/env/omp_get_max_teams_env_var.c      | 54 ++++++++++++++++++
 .../libomptarget/test/env/omp_set_num_teams.c | 57 +++++++++++++++++++
 9 files changed, 141 insertions(+), 1 deletion(-)
 create mode 100644 openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
 create mode 100644 openmp/libomptarget/test/env/omp_set_num_teams.c

diff --git a/openmp/libomptarget/DeviceRTL/include/Configuration.h b/openmp/libomptarget/DeviceRTL/include/Configuration.h
index 45e5cead231f72..9529bc1ddf6198 100644
--- a/openmp/libomptarget/DeviceRTL/include/Configuration.h
+++ b/openmp/libomptarget/DeviceRTL/include/Configuration.h
@@ -55,6 +55,9 @@ bool mayUseThreadStates();
 /// parallelism, or if it was explicitly disabled by the user.
 bool mayUseNestedParallelism();
 
+/// Return max number of teams in the device it's called on.
+uint32_t getMaxTeams(); 
+
 } // namespace config
 } // namespace ompx
 
diff --git a/openmp/libomptarget/DeviceRTL/include/Interface.h b/openmp/libomptarget/DeviceRTL/include/Interface.h
index 24de620759c419..a403561e2bf44b 100644
--- a/openmp/libomptarget/DeviceRTL/include/Interface.h
+++ b/openmp/libomptarget/DeviceRTL/include/Interface.h
@@ -133,6 +133,10 @@ int omp_get_num_teams(void);
 
 int omp_get_team_num();
 
+int omp_get_max_teams(void);
+
+void omp_set_num_teams(int V);
+
 int omp_get_initial_device(void);
 
 void *llvm_omp_target_dynamic_shared_alloc();
diff --git a/openmp/libomptarget/DeviceRTL/include/State.h b/openmp/libomptarget/DeviceRTL/include/State.h
index 1d73bdc4f5409c..ae71fa6159830d 100644
--- a/openmp/libomptarget/DeviceRTL/include/State.h
+++ b/openmp/libomptarget/DeviceRTL/include/State.h
@@ -54,6 +54,7 @@ inline constexpr uint32_t SharedScratchpadSize = SHARED_SCRATCHPAD_SIZE;
 
 struct ICVStateTy {
   uint32_t NThreadsVar;
+  uint32_t NTeamsVar;
   uint32_t LevelVar;
   uint32_t ActiveLevelVar;
   uint32_t Padding0Val;
@@ -131,6 +132,7 @@ KernelLaunchEnvironmentTy &getKernelLaunchEnvironment();
 /// TODO
 enum ValueKind {
   VK_NThreads,
+  VK_NTeams,
   VK_Level,
   VK_ActiveLevel,
   VK_MaxActiveLevels,
@@ -190,6 +192,11 @@ lookup32(ValueKind Kind, bool IsReadonly, IdentTy *Ident, bool ForceTeamState) {
       return lookupImpl(&ICVStateTy::NThreadsVar, ForceTeamState);
     return lookupForModify32Impl(&ICVStateTy::NThreadsVar, Ident,
                                  ForceTeamState);
+  case state::VK_NTeams:
+    if (IsReadonly)
+      return lookupImpl(&ICVStateTy::NTeamsVar, ForceTeamState);
+    return lookupForModify32Impl(&ICVStateTy::NTeamsVar, Ident,
+                                 ForceTeamState);
   case state::VK_Level:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::LevelVar, ForceTeamState);
@@ -360,6 +367,9 @@ namespace icv {
 /// TODO
 inline state::Value<uint32_t, state::VK_NThreads> NThreads;
 
+/// TODO
+inline state::Value<uint32_t, state::VK_NTeams> NTeams;
+
 /// TODO
 inline state::Value<uint32_t, state::VK_Level> Level;
 
diff --git a/openmp/libomptarget/DeviceRTL/src/Configuration.cpp b/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
index ab1608b1cfb0ae..1f1a77c9938077 100644
--- a/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
@@ -75,4 +75,6 @@ bool config::mayUseNestedParallelism() {
   return state::getKernelEnvironment().Configuration.MayUseNestedParallelism;
 }
 
+uint32_t config::getMaxTeams() { return __omp_rtl_device_environment.NumTeams; }
+
 #pragma omp end declare target
diff --git a/openmp/libomptarget/DeviceRTL/src/State.cpp b/openmp/libomptarget/DeviceRTL/src/State.cpp
index f8a6d333df0d9e..43f6a402b47d96 100644
--- a/openmp/libomptarget/DeviceRTL/src/State.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/State.cpp
@@ -199,6 +199,7 @@ void state::ICVStateTy::assertEqual(const ICVStateTy &Other) const {
 
 void state::TeamStateTy::init(bool IsSPMD) {
   ICVState.NThreadsVar = 0;
+  ICVState.NTeamsVar = config::getMaxTeams();
   ICVState.LevelVar = 0;
   ICVState.ActiveLevelVar = 0;
   ICVState.Padding0Val = 0;
@@ -424,6 +425,13 @@ int omp_get_num_teams(void) { return mapping::getNumberOfBlocksInKernel(); }
 
 int omp_get_team_num() { return mapping::getBlockIdInKernel(); }
 
+int omp_get_max_teams(void) { return icv::NTeams; }
+
+void omp_set_num_teams(int V) {
+  icv::NTeams = (V < 0) ? 0 :
+                (V >= config::getMaxTeams()) ? config::getMaxTeams() : V;
+}
+
 int omp_get_initial_device(void) { return -1; }
 }
 
diff --git a/openmp/libomptarget/include/Environment.h b/openmp/libomptarget/include/Environment.h
index bd493e8a0be78f..5a24696b4aab4b 100644
--- a/openmp/libomptarget/include/Environment.h
+++ b/openmp/libomptarget/include/Environment.h
@@ -35,6 +35,7 @@ enum class DeviceDebugKind : uint32_t {
 struct DeviceEnvironmentTy {
   uint32_t DeviceDebugKind;
   uint32_t NumDevices;
+  uint32_t NumTeams;
   uint32_t DeviceNum;
   uint32_t DynamicMemSize;
   uint64_t ClockFrequency;
diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
index 106e7a68cd3ae3..d7d9d0e7e64200 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -628,7 +628,7 @@ uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
 GenericDeviceTy::GenericDeviceTy(int32_t DeviceId, int32_t NumDevices,
                                  const llvm::omp::GV &OMPGridValues)
     : MemoryManager(nullptr), OMP_TeamLimit("OMP_TEAM_LIMIT"),
-      OMP_NumTeams("OMP_NUM_TEAMS"),
+      OMP_NumTeams("OMP_NUM_TEAMS_DEV_" + std::to_string(DeviceId)),
       OMP_TeamsThreadLimit("OMP_TEAMS_THREAD_LIMIT"),
       OMPX_DebugKind("LIBOMPTARGET_DEVICE_RTL_DEBUG"),
       OMPX_SharedMemorySize("LIBOMPTARGET_SHARED_MEMORY_SIZE"),
@@ -854,6 +854,7 @@ Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
   DeviceEnvironmentTy DeviceEnvironment;
   DeviceEnvironment.DeviceDebugKind = OMPX_DebugKind;
   DeviceEnvironment.NumDevices = Plugin.getNumDevices();
+  DeviceEnvironment.NumTeams = (OMP_NumTeams >= 0) ? uint32_t(OMP_NumTeams) : 0;
   // TODO: The device ID used here is not the real device ID used by OpenMP.
   DeviceEnvironment.DeviceNum = DeviceId;
   DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
diff --git a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
new file mode 100644
index 00000000000000..4910dbe4a75b29
--- /dev/null
+++ b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
@@ -0,0 +1,54 @@
+// Test functionality of omp_get_max_teams() with setting
+// environment variable to 2 GPU devices. If there's only
+// one GPU device, remove the device 1 if statement.
+
+// RUN: %libomptarget-compile-generic -fopenmp-offload-mandatory
+// RUN: env OMP_NUM_TEAMS_DEV_0=5 OMP_NUM_TEAMS_DEV_1=-1 \
+// RUN: %libomptarget-run-generic 
+
+// UNSUPPORTED: x86_64-pc-linux-gnu
+// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: aarch64-unknown-linux-gnu
+// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
+
+#include <omp.h>
+#include <stdio.h>
+
+const int EXPECTED_NTEAMS_DEV_0 = 5;
+const int EXPECTED_NTEAMS_DEV_1 = 0;
+
+int omp_get_max_teams(void);
+
+int test_nteams_var_env(void) {
+  int errors = 0;
+  int device_id;
+  int n_devs;
+  int curr_nteams = -1;
+#pragma omp target map(tofrom : n_devs)
+  { n_devs = omp_get_num_devices(); }
+
+  for (int i = 0; i < n_devs; i++) {
+#pragma omp target device(i) map(tofrom : curr_nteams, device_id, errors)
+    {
+      device_id = omp_get_device_num();
+      errors = errors + (device_id != i);
+      curr_nteams = omp_get_max_teams();
+      if (device_id == 0) { errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_0); } // device 0
+      if (device_id == 1) { errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_1); } // device 1
+    }
+    printf("device: %d nteams: %d\n", device_id, curr_nteams);
+  }
+  return errors;
+}
+
+int main() {
+  int errors = 0;
+  errors = test_nteams_var_env();
+  if (errors)
+    printf("FAIL\n");
+  else
+    printf("PASS\n");
+  return errors;
+}
+
+// CHECK: PASS
\ No newline at end of file
diff --git a/openmp/libomptarget/test/env/omp_set_num_teams.c b/openmp/libomptarget/test/env/omp_set_num_teams.c
new file mode 100644
index 00000000000000..9b798518bfc5e9
--- /dev/null
+++ b/openmp/libomptarget/test/env/omp_set_num_teams.c
@@ -0,0 +1,57 @@
+// Test functionality of omp_set_num_teams() with setting
+// environment variable as an upper bound. Test for negative
+// value and value that is larger than the upper bound.
+
+// RUN: %libomptarget-compile-generic -fopenmp-offload-mandatory
+// RUN: env OMP_NUM_TEAMS_DEV_0=3 LIBOMPTARGET_INFO=16\
+// RUN: %libomptarget-run-generic
+
+// UNSUPPORTED: x86_64-pc-linux-gnu
+// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: aarch64-unknown-linux-gnu
+// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
+
+#include <omp.h>
+#include <stdio.h>
+
+const int EXPECTED_NTEAMS = 3;
+
+void omp_set_num_teams(int V);
+int omp_get_max_teams(void);
+
+int test_set_over_max(void) {
+  int errors = 0;
+  int n_devs;
+  int curr_nteams = -1;
+
+#pragma omp target map(tofrom : n_devs)
+  { n_devs = omp_get_num_devices(); }
+
+#pragma omp target device(0) map(tofrom : curr_nteams, errors)
+  {
+    omp_set_num_teams(3 + 1);
+    curr_nteams = omp_get_max_teams();
+    errors = errors + (curr_nteams != 3);
+
+    omp_set_num_teams(-1);
+    curr_nteams = omp_get_max_teams();
+    errors = errors + (curr_nteams != 0);
+
+    omp_set_num_teams(3);
+    curr_nteams = omp_get_max_teams();
+    errors = errors + (curr_nteams != 3);
+  }
+  return errors;
+}
+
+int main() {
+  int errors = 0;
+  errors = test_set_over_max();
+  if (errors)
+    printf("FAIL\n");
+  else
+    printf("PASS\n");
+  return errors;
+}
+
+// CHECK: PASS
\ No newline at end of file

>From 932d89fb7bc19141c9d4461586cd1dec516c0cdc Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Fri, 3 Nov 2023 17:41:33 -0700
Subject: [PATCH 2/6] Apply clang-format fixes

---
 openmp/libomptarget/DeviceRTL/include/Configuration.h  |  2 +-
 openmp/libomptarget/DeviceRTL/include/State.h          |  3 +--
 openmp/libomptarget/DeviceRTL/src/State.cpp            |  5 +++--
 .../libomptarget/test/env/omp_get_max_teams_env_var.c  | 10 +++++++---
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/openmp/libomptarget/DeviceRTL/include/Configuration.h b/openmp/libomptarget/DeviceRTL/include/Configuration.h
index 9529bc1ddf6198..dcc4c7a7cf15be 100644
--- a/openmp/libomptarget/DeviceRTL/include/Configuration.h
+++ b/openmp/libomptarget/DeviceRTL/include/Configuration.h
@@ -56,7 +56,7 @@ bool mayUseThreadStates();
 bool mayUseNestedParallelism();
 
 /// Return max number of teams in the device it's called on.
-uint32_t getMaxTeams(); 
+uint32_t getMaxTeams();
 
 } // namespace config
 } // namespace ompx
diff --git a/openmp/libomptarget/DeviceRTL/include/State.h b/openmp/libomptarget/DeviceRTL/include/State.h
index ae71fa6159830d..2039e993b55a34 100644
--- a/openmp/libomptarget/DeviceRTL/include/State.h
+++ b/openmp/libomptarget/DeviceRTL/include/State.h
@@ -190,8 +190,7 @@ lookup32(ValueKind Kind, bool IsReadonly, IdentTy *Ident, bool ForceTeamState) {
   case state::VK_NThreads:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::NThreadsVar, ForceTeamState);
-    return lookupForModify32Impl(&ICVStateTy::NThreadsVar, Ident,
-                                 ForceTeamState);
+    return lookupForModify32Impl(&ICVStateTy::NTeamsVar, Ident, ForceTeamState);
   case state::VK_NTeams:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::NTeamsVar, ForceTeamState);
diff --git a/openmp/libomptarget/DeviceRTL/src/State.cpp b/openmp/libomptarget/DeviceRTL/src/State.cpp
index 43f6a402b47d96..9d5c6d30652759 100644
--- a/openmp/libomptarget/DeviceRTL/src/State.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/State.cpp
@@ -428,8 +428,9 @@ int omp_get_team_num() { return mapping::getBlockIdInKernel(); }
 int omp_get_max_teams(void) { return icv::NTeams; }
 
 void omp_set_num_teams(int V) {
-  icv::NTeams = (V < 0) ? 0 :
-                (V >= config::getMaxTeams()) ? config::getMaxTeams() : V;
+  icv::NTeams = (V < 0)                        ? 0
+                : (V >= config::getMaxTeams()) ? config::getMaxTeams()
+                                               : V;
 }
 
 int omp_get_initial_device(void) { return -1; }
diff --git a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
index 4910dbe4a75b29..55014434bc9b5c 100644
--- a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
+++ b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
@@ -4,7 +4,7 @@
 
 // RUN: %libomptarget-compile-generic -fopenmp-offload-mandatory
 // RUN: env OMP_NUM_TEAMS_DEV_0=5 OMP_NUM_TEAMS_DEV_1=-1 \
-// RUN: %libomptarget-run-generic 
+// RUN: %libomptarget-run-generic
 
 // UNSUPPORTED: x86_64-pc-linux-gnu
 // UNSUPPORTED: x86_64-pc-linux-gnu-LTO
@@ -33,8 +33,12 @@ int test_nteams_var_env(void) {
       device_id = omp_get_device_num();
       errors = errors + (device_id != i);
       curr_nteams = omp_get_max_teams();
-      if (device_id == 0) { errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_0); } // device 0
-      if (device_id == 1) { errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_1); } // device 1
+      if (device_id == 0) {
+        errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_0);
+      } // device 0
+      if (device_id == 1) {
+        errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_1);
+      } // device 1
     }
     printf("device: %d nteams: %d\n", device_id, curr_nteams);
   }

>From 06dd9cd180610aa3fbbab91b1cefdfc53e157364 Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Fri, 3 Nov 2023 17:47:46 -0700
Subject: [PATCH 3/6] Fix clang-format in State.h and restore the VK_NThreads
 modify lookup to use NThreadsVar

---
 openmp/libomptarget/DeviceRTL/include/State.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/openmp/libomptarget/DeviceRTL/include/State.h b/openmp/libomptarget/DeviceRTL/include/State.h
index 2039e993b55a34..17c7b559f56416 100644
--- a/openmp/libomptarget/DeviceRTL/include/State.h
+++ b/openmp/libomptarget/DeviceRTL/include/State.h
@@ -190,12 +190,12 @@ lookup32(ValueKind Kind, bool IsReadonly, IdentTy *Ident, bool ForceTeamState) {
   case state::VK_NThreads:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::NThreadsVar, ForceTeamState);
-    return lookupForModify32Impl(&ICVStateTy::NTeamsVar, Ident, ForceTeamState);
+    return lookupForModify32Impl(&ICVStateTy::NThreadsVar, Ident,
+                                 ForceTeamState);
   case state::VK_NTeams:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::NTeamsVar, ForceTeamState);
-    return lookupForModify32Impl(&ICVStateTy::NTeamsVar, Ident,
-                                 ForceTeamState);
+    return lookupForModify32Impl(&ICVStateTy::NTeamsVar, Ident, ForceTeamState);
   case state::VK_Level:
     if (IsReadonly)
       return lookupImpl(&ICVStateTy::LevelVar, ForceTeamState);

>From 2403ea712cc6b292054b72e878bfee9dc078b3d5 Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Sat, 4 Nov 2023 01:01:04 -0700
Subject: [PATCH 4/6] Clean up leftover debugging statements and unnecessary
 code in the tests

---
 .../libomptarget/test/env/omp_get_max_teams_env_var.c  |  1 -
 openmp/libomptarget/test/env/omp_set_num_teams.c       | 10 +++-------
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
index 55014434bc9b5c..162b6e5cb01bfa 100644
--- a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
+++ b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
@@ -40,7 +40,6 @@ int test_nteams_var_env(void) {
         errors = errors + (curr_nteams != EXPECTED_NTEAMS_DEV_1);
       } // device 1
     }
-    printf("device: %d nteams: %d\n", device_id, curr_nteams);
   }
   return errors;
 }
diff --git a/openmp/libomptarget/test/env/omp_set_num_teams.c b/openmp/libomptarget/test/env/omp_set_num_teams.c
index 9b798518bfc5e9..e73c2990e913c5 100644
--- a/openmp/libomptarget/test/env/omp_set_num_teams.c
+++ b/openmp/libomptarget/test/env/omp_set_num_teams.c
@@ -21,17 +21,13 @@ int omp_get_max_teams(void);
 
 int test_set_over_max(void) {
   int errors = 0;
-  int n_devs;
   int curr_nteams = -1;
 
-#pragma omp target map(tofrom : n_devs)
-  { n_devs = omp_get_num_devices(); }
-
 #pragma omp target device(0) map(tofrom : curr_nteams, errors)
   {
-    omp_set_num_teams(3 + 1);
+    omp_set_num_teams(EXPECTED_NTEAMS + 1);
     curr_nteams = omp_get_max_teams();
-    errors = errors + (curr_nteams != 3);
+    errors = errors + (curr_nteams != EXPECTED_NTEAMS);
 
     omp_set_num_teams(-1);
     curr_nteams = omp_get_max_teams();
@@ -39,7 +35,7 @@ int test_set_over_max(void) {
 
     omp_set_num_teams(3);
     curr_nteams = omp_get_max_teams();
-    errors = errors + (curr_nteams != 3);
+    errors = errors + (curr_nteams != EXPECTED_NTEAMS);
   }
   return errors;
 }

>From 069ed4d89eece9abc3f6761eb27c63b6f6870fb1 Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Thu, 7 Dec 2023 15:32:42 -0800
Subject: [PATCH 5/6] Offset OMP_NUM_TEAMS_DEV_<n> by the plugin's device ID
 start index; stop special-casing negative values

---
 openmp/libomptarget/DeviceRTL/include/Interface.h      |  2 +-
 openmp/libomptarget/DeviceRTL/include/Utils.h          | 10 ++++++++++
 openmp/libomptarget/DeviceRTL/src/State.cpp            |  6 ++----
 .../common/PluginInterface/PluginInterface.cpp         |  4 ++--
 .../libomptarget/test/env/omp_get_max_teams_env_var.c  |  4 ++--
 openmp/libomptarget/test/env/omp_set_num_teams.c       |  9 +++++----
 6 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/openmp/libomptarget/DeviceRTL/include/Interface.h b/openmp/libomptarget/DeviceRTL/include/Interface.h
index a403561e2bf44b..6f9946dde1041e 100644
--- a/openmp/libomptarget/DeviceRTL/include/Interface.h
+++ b/openmp/libomptarget/DeviceRTL/include/Interface.h
@@ -135,7 +135,7 @@ int omp_get_team_num();
 
 int omp_get_max_teams(void);
 
-void omp_set_num_teams(int V);
+void omp_set_num_teams(uint32_t V);
 
 int omp_get_initial_device(void);
 
diff --git a/openmp/libomptarget/DeviceRTL/include/Utils.h b/openmp/libomptarget/DeviceRTL/include/Utils.h
index 4ab0aea46eea12..ed2ecfee438f0b 100644
--- a/openmp/libomptarget/DeviceRTL/include/Utils.h
+++ b/openmp/libomptarget/DeviceRTL/include/Utils.h
@@ -82,6 +82,16 @@ template <typename DstTy, typename SrcTy> inline DstTy convertViaPun(SrcTy V) {
   return *((DstTy *)(&V));
 }
 
+/// Return minimum value out of 2 value arguments provided
+template <typename Ty> const Ty& min(const Ty& a, const Ty& b) {
+  return (b < a) ? b : a;
+}
+
+/// Return maxmimum value out of 2 value arguments provided
+template <typename Ty> const Ty& max(const Ty& a, const Ty& b) {
+  return (b > a) ? b : a;
+}
+
 /// A  pointer variable that has by design an `undef` value. Use with care.
 [[clang::loader_uninitialized]] static void *const UndefPtr;
 
diff --git a/openmp/libomptarget/DeviceRTL/src/State.cpp b/openmp/libomptarget/DeviceRTL/src/State.cpp
index 9d5c6d30652759..c7dc292728da10 100644
--- a/openmp/libomptarget/DeviceRTL/src/State.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/State.cpp
@@ -427,10 +427,8 @@ int omp_get_team_num() { return mapping::getBlockIdInKernel(); }
 
 int omp_get_max_teams(void) { return icv::NTeams; }
 
-void omp_set_num_teams(int V) {
-  icv::NTeams = (V < 0)                        ? 0
-                : (V >= config::getMaxTeams()) ? config::getMaxTeams()
-                                               : V;
+void omp_set_num_teams(uint32_t V) {
+  icv::NTeams = utils::min(V, config::getMaxTeams());
 }
 
 int omp_get_initial_device(void) { return -1; }
diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
index d7d9d0e7e64200..0efc0cdb78c6d0 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -628,7 +628,7 @@ uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
 GenericDeviceTy::GenericDeviceTy(int32_t DeviceId, int32_t NumDevices,
                                  const llvm::omp::GV &OMPGridValues)
     : MemoryManager(nullptr), OMP_TeamLimit("OMP_TEAM_LIMIT"),
-      OMP_NumTeams("OMP_NUM_TEAMS_DEV_" + std::to_string(DeviceId)),
+      OMP_NumTeams("OMP_NUM_TEAMS_DEV_" + std::to_string(DeviceId + Plugin::get().getDeviceIdStartIndex())),
       OMP_TeamsThreadLimit("OMP_TEAMS_THREAD_LIMIT"),
       OMPX_DebugKind("LIBOMPTARGET_DEVICE_RTL_DEBUG"),
       OMPX_SharedMemorySize("LIBOMPTARGET_SHARED_MEMORY_SIZE"),
@@ -854,7 +854,7 @@ Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
   DeviceEnvironmentTy DeviceEnvironment;
   DeviceEnvironment.DeviceDebugKind = OMPX_DebugKind;
   DeviceEnvironment.NumDevices = Plugin.getNumDevices();
-  DeviceEnvironment.NumTeams = (OMP_NumTeams >= 0) ? uint32_t(OMP_NumTeams) : 0;
+  DeviceEnvironment.NumTeams = uint32_t(OMP_NumTeams);
   // TODO: The device ID used here is not the real device ID used by OpenMP.
   DeviceEnvironment.DeviceNum = DeviceId;
   DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
diff --git a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
index 162b6e5cb01bfa..43da37f86f36f7 100644
--- a/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
+++ b/openmp/libomptarget/test/env/omp_get_max_teams_env_var.c
@@ -3,7 +3,7 @@
 // one GPU device, remove the device 1 if statement.
 
 // RUN: %libomptarget-compile-generic -fopenmp-offload-mandatory
-// RUN: env OMP_NUM_TEAMS_DEV_0=5 OMP_NUM_TEAMS_DEV_1=-1 \
+// RUN: env OMP_NUM_TEAMS_DEV_0=5 OMP_NUM_TEAMS_DEV_1=3 LIBOMPTARGET_INFO=16\
 // RUN: %libomptarget-run-generic
 
 // UNSUPPORTED: x86_64-pc-linux-gnu
@@ -15,7 +15,7 @@
 #include <stdio.h>
 
 const int EXPECTED_NTEAMS_DEV_0 = 5;
-const int EXPECTED_NTEAMS_DEV_1 = 0;
+const int EXPECTED_NTEAMS_DEV_1 = 3;
 
 int omp_get_max_teams(void);
 
diff --git a/openmp/libomptarget/test/env/omp_set_num_teams.c b/openmp/libomptarget/test/env/omp_set_num_teams.c
index e73c2990e913c5..e7b5bd10782213 100644
--- a/openmp/libomptarget/test/env/omp_set_num_teams.c
+++ b/openmp/libomptarget/test/env/omp_set_num_teams.c
@@ -25,17 +25,18 @@ int test_set_over_max(void) {
 
 #pragma omp target device(0) map(tofrom : curr_nteams, errors)
   {
+    // Setting over specified OMP_NUM_TEAMS_DEV_0 value is not allowed
     omp_set_num_teams(EXPECTED_NTEAMS + 1);
     curr_nteams = omp_get_max_teams();
     errors = errors + (curr_nteams != EXPECTED_NTEAMS);
 
-    omp_set_num_teams(-1);
-    curr_nteams = omp_get_max_teams();
-    errors = errors + (curr_nteams != 0);
-
     omp_set_num_teams(3);
     curr_nteams = omp_get_max_teams();
     errors = errors + (curr_nteams != EXPECTED_NTEAMS);
+
+    omp_set_num_teams(2);
+    curr_nteams = omp_get_max_teams();
+    errors = errors + (curr_nteams != 2);
   }
   return errors;
 }

>From 10f84242ab712ab0f453a351fc8d9c7dd79d95f9 Mon Sep 17 00:00:00 2001
From: "Khoi D. Nguyen" <nguyen155 at llnl.gov>
Date: Thu, 7 Dec 2023 15:35:15 -0800
Subject: [PATCH 6/6] Apply clang-format fixes

---
 openmp/libomptarget/DeviceRTL/include/Utils.h                 | 4 ++--
 .../common/PluginInterface/PluginInterface.cpp                | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/openmp/libomptarget/DeviceRTL/include/Utils.h b/openmp/libomptarget/DeviceRTL/include/Utils.h
index ed2ecfee438f0b..a5ea213e7db046 100644
--- a/openmp/libomptarget/DeviceRTL/include/Utils.h
+++ b/openmp/libomptarget/DeviceRTL/include/Utils.h
@@ -83,12 +83,12 @@ template <typename DstTy, typename SrcTy> inline DstTy convertViaPun(SrcTy V) {
 }
 
 /// Return minimum value out of 2 value arguments provided
-template <typename Ty> const Ty& min(const Ty& a, const Ty& b) {
+template <typename Ty> const Ty &min(const Ty &a, const Ty &b) {
   return (b < a) ? b : a;
 }
 
 /// Return maxmimum value out of 2 value arguments provided
-template <typename Ty> const Ty& max(const Ty& a, const Ty& b) {
+template <typename Ty> const Ty &max(const Ty &a, const Ty &b) {
   return (b > a) ? b : a;
 }
 
diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
index 0efc0cdb78c6d0..73d90a5f59a648 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -628,7 +628,9 @@ uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
 GenericDeviceTy::GenericDeviceTy(int32_t DeviceId, int32_t NumDevices,
                                  const llvm::omp::GV &OMPGridValues)
     : MemoryManager(nullptr), OMP_TeamLimit("OMP_TEAM_LIMIT"),
-      OMP_NumTeams("OMP_NUM_TEAMS_DEV_" + std::to_string(DeviceId + Plugin::get().getDeviceIdStartIndex())),
+      OMP_NumTeams(
+          "OMP_NUM_TEAMS_DEV_" +
+          std::to_string(DeviceId + Plugin::get().getDeviceIdStartIndex())),
       OMP_TeamsThreadLimit("OMP_TEAMS_THREAD_LIMIT"),
       OMPX_DebugKind("LIBOMPTARGET_DEVICE_RTL_DEBUG"),
       OMPX_SharedMemorySize("LIBOMPTARGET_SHARED_MEMORY_SIZE"),



More information about the Openmp-commits mailing list