[llvm] [Offload][OMPX] Add the runtime support for multi-dim grid and block (PR #118042)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 12:19:52 PST 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/118042

>From cd4c7f54502a09cbf507050b5c1b15f43e4af41c Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Sun, 1 Dec 2024 21:02:12 -0500
Subject: [PATCH] [Offload][OMPX] Add the runtime support for multi-dim grid
 and block

---
 offload/plugins-nextgen/amdgpu/src/rtl.cpp    | 69 ++++++++++---------
 .../common/include/PluginInterface.h          | 15 ++--
 .../common/src/PluginInterface.cpp            | 48 +++++++------
 offload/plugins-nextgen/cuda/src/rtl.cpp      | 19 +++--
 offload/plugins-nextgen/host/src/rtl.cpp      |  4 +-
 offload/src/interface.cpp                     | 25 ++++---
 offload/src/omptarget.cpp                     |  2 -
 offload/test/api/omp_env_vars.c               |  2 +-
 offload/test/offloading/info.c                |  2 +-
 offload/test/offloading/ompx_bare.c           |  2 +-
 .../test/offloading/ompx_bare_multi_dim.cpp   | 56 +++++++++++++++
 offload/test/offloading/small_trip_count.c    | 16 ++---
 .../small_trip_count_thread_limit.cpp         |  2 +-
 13 files changed, 167 insertions(+), 95 deletions(-)
 create mode 100644 offload/test/offloading/ompx_bare_multi_dim.cpp

diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index d74e65d4165679..e10c58f1a32259 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -559,15 +559,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
   }
 
   /// Launch the AMDGPU kernel function.
-  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                   uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
+                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                    KernelLaunchParamsTy LaunchParams,
                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
 
   /// Print more elaborate kernel launch info for AMDGPU
   Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
-                               KernelArgsTy &KernelArgs, uint32_t NumThreads,
-                               uint64_t NumBlocks) const override;
+                               KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
+                               uint32_t NumBlocks[3]) const override;
 
   /// Get group and private segment kernel size.
   uint32_t getGroupSize() const { return GroupSize; }
@@ -719,7 +719,7 @@ struct AMDGPUQueueTy {
   /// Push a kernel launch to the queue. The kernel launch requires an output
   /// signal and can define an optional input signal (nullptr if none).
   Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
-                         uint32_t NumThreads, uint64_t NumBlocks,
+                         uint32_t NumThreads[3], uint32_t NumBlocks[3],
                          uint32_t GroupSize, uint64_t StackSize,
                          AMDGPUSignalTy *OutputSignal,
                          AMDGPUSignalTy *InputSignal) {
@@ -746,14 +746,18 @@ struct AMDGPUQueueTy {
     assert(Packet && "Invalid packet");
 
     // The first 32 bits of the packet are written after the other fields
-    uint16_t Setup = UINT16_C(1) << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
-    Packet->workgroup_size_x = NumThreads;
-    Packet->workgroup_size_y = 1;
-    Packet->workgroup_size_z = 1;
+    uint16_t Dims = NumBlocks[2] * NumThreads[2] > 1
+                        ? 3
+                        : 1 + (NumBlocks[1] * NumThreads[1] != 1);
+    uint16_t Setup = UINT16_C(Dims)
+                     << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
+    Packet->workgroup_size_x = NumThreads[0];
+    Packet->workgroup_size_y = NumThreads[1];
+    Packet->workgroup_size_z = NumThreads[2];
     Packet->reserved0 = 0;
-    Packet->grid_size_x = NumBlocks * NumThreads;
-    Packet->grid_size_y = 1;
-    Packet->grid_size_z = 1;
+    Packet->grid_size_x = NumBlocks[0] * NumThreads[0];
+    Packet->grid_size_y = NumBlocks[1] * NumThreads[1];
+    Packet->grid_size_z = NumBlocks[2] * NumThreads[2];
     Packet->private_segment_size =
         Kernel.usesDynamicStack() ? StackSize : Kernel.getPrivateSize();
     Packet->group_segment_size = GroupSize;
@@ -1240,7 +1244,7 @@ struct AMDGPUStreamTy {
   /// the kernel finalizes. Once the kernel is finished, the stream will release
   /// the kernel args buffer to the specified memory manager.
   Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
-                         uint32_t NumThreads, uint64_t NumBlocks,
+                         uint32_t NumThreads[3], uint32_t NumBlocks[3],
                          uint32_t GroupSize, uint64_t StackSize,
                          AMDGPUMemoryManagerTy &MemoryManager) {
     if (Queue == nullptr)
@@ -2827,10 +2831,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
 
     KernelArgsTy KernelArgs = {};
-    if (auto Err =
-            AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
-                                    /*NumBlocks=*/1ul, KernelArgs,
-                                    KernelLaunchParamsTy{}, AsyncInfoWrapper))
+    uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
+    if (auto Err = AMDGPUKernel.launchImpl(
+            *this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
+            KernelLaunchParamsTy{}, AsyncInfoWrapper))
       return Err;
 
     Error Err = Plugin::success();
@@ -3328,7 +3332,7 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
 };
 
 Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
-                                 uint32_t NumThreads, uint64_t NumBlocks,
+                                 uint32_t NumThreads[3], uint32_t NumBlocks[3],
                                  KernelArgsTy &KernelArgs,
                                  KernelLaunchParamsTy LaunchParams,
                                  AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -3385,13 +3389,15 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
   if (ImplArgs &&
       getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {
-    ImplArgs->BlockCountX = NumBlocks;
-    ImplArgs->BlockCountY = 1;
-    ImplArgs->BlockCountZ = 1;
-    ImplArgs->GroupSizeX = NumThreads;
-    ImplArgs->GroupSizeY = 1;
-    ImplArgs->GroupSizeZ = 1;
-    ImplArgs->GridDims = 1;
+    ImplArgs->BlockCountX = NumBlocks[0];
+    ImplArgs->BlockCountY = NumBlocks[1];
+    ImplArgs->BlockCountZ = NumBlocks[2];
+    ImplArgs->GroupSizeX = NumThreads[0];
+    ImplArgs->GroupSizeY = NumThreads[1];
+    ImplArgs->GroupSizeZ = NumThreads[2];
+    ImplArgs->GridDims = NumBlocks[2] * NumThreads[2] > 1
+                             ? 3
+                             : 1 + (NumBlocks[1] * NumThreads[1] != 1);
     ImplArgs->DynamicLdsSize = KernelArgs.DynCGroupMem;
   }
 
@@ -3402,8 +3408,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
 
 Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                              KernelArgsTy &KernelArgs,
-                                             uint32_t NumThreads,
-                                             uint64_t NumBlocks) const {
+                                             uint32_t NumThreads[3],
+                                             uint32_t NumBlocks[3]) const {
   // Only do all this when the output is requested
   if (!(getInfoLevel() & OMP_INFOTYPE_PLUGIN_KERNEL))
     return Plugin::success();
@@ -3440,12 +3446,13 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
   // S/VGPR Spill Count: how many S/VGPRs are spilled by the kernel
   // Tripcount: loop tripcount for the kernel
   INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
-       "#Args: %d Teams x Thrds: %4lux%4u (MaxFlatWorkGroupSize: %u) LDS "
+       "#Args: %d Teams x Thrds: %4ux%4u (MaxFlatWorkGroupSize: %u) LDS "
        "Usage: %uB #SGPRs/VGPRs: %u/%u #SGPR/VGPR Spills: %u/%u Tripcount: "
        "%lu\n",
-       ArgNum, NumGroups, ThreadsPerGroup, MaxFlatWorkgroupSize,
-       GroupSegmentSize, SGPRCount, VGPRCount, SGPRSpillCount, VGPRSpillCount,
-       LoopTripCount);
+       ArgNum, NumGroups[0] * NumGroups[1] * NumGroups[2],
+       ThreadsPerGroup[0] * ThreadsPerGroup[1] * ThreadsPerGroup[2],
+       MaxFlatWorkgroupSize, GroupSegmentSize, SGPRCount, VGPRCount,
+       SGPRSpillCount, VGPRSpillCount, LoopTripCount);
 
   return Plugin::success();
 }
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 63e2f80302c306..eb266e8d4d451a 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -269,8 +269,9 @@ struct GenericKernelTy {
   Error launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
                ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
                AsyncInfoWrapperTy &AsyncInfoWrapper) const;
-  virtual Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                           uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  virtual Error launchImpl(GenericDeviceTy &GenericDevice,
+                           uint32_t NumThreads[3], uint32_t NumBlocks[3],
+                           KernelArgsTy &KernelArgs,
                            KernelLaunchParamsTy LaunchParams,
                            AsyncInfoWrapperTy &AsyncInfoWrapper) const = 0;
 
@@ -320,15 +321,15 @@ struct GenericKernelTy {
 
   /// Prints generic kernel launch information.
   Error printLaunchInfo(GenericDeviceTy &GenericDevice,
-                        KernelArgsTy &KernelArgs, uint32_t NumThreads,
-                        uint64_t NumBlocks) const;
+                        KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
+                        uint32_t NumBlocks[3]) const;
 
   /// Prints plugin-specific kernel launch information after generic kernel
   /// launch information
   virtual Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                        KernelArgsTy &KernelArgs,
-                                       uint32_t NumThreads,
-                                       uint64_t NumBlocks) const;
+                                       uint32_t NumThreads[3],
+                                       uint32_t NumBlocks[3]) const;
 
 private:
   /// Prepare the arguments before launching the kernel.
@@ -347,7 +348,7 @@ struct GenericKernelTy {
   /// The number of threads \p NumThreads can be adjusted by this method.
   /// \p IsNumThreadsFromUser is true is \p NumThreads is defined by user via
   /// thread_limit clause.
-  uint64_t getNumBlocks(GenericDeviceTy &GenericDevice,
+  uint32_t getNumBlocks(GenericDeviceTy &GenericDevice,
                         uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
                         uint32_t &NumThreads, bool IsNumThreadsFromUser) const;
 
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 5cdf12176a0d66..bd58d1d6e0d96d 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -526,20 +526,21 @@ GenericKernelTy::getKernelLaunchEnvironment(
 
 Error GenericKernelTy::printLaunchInfo(GenericDeviceTy &GenericDevice,
                                        KernelArgsTy &KernelArgs,
-                                       uint32_t NumThreads,
-                                       uint64_t NumBlocks) const {
+                                       uint32_t NumThreads[3],
+                                       uint32_t NumBlocks[3]) const {
   INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
-       "Launching kernel %s with %" PRIu64
-       " blocks and %d threads in %s mode\n",
-       getName(), NumBlocks, NumThreads, getExecutionModeName());
+       "Launching kernel %s with [%u,%u,%u] blocks and [%u,%u,%u] threads in "
+       "%s mode\n",
+       getName(), NumBlocks[0], NumBlocks[1], NumBlocks[2], NumThreads[0],
+       NumThreads[1], NumThreads[2], getExecutionModeName());
   return printLaunchInfoDetails(GenericDevice, KernelArgs, NumThreads,
                                 NumBlocks);
 }
 
 Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                               KernelArgsTy &KernelArgs,
-                                              uint32_t NumThreads,
-                                              uint64_t NumBlocks) const {
+                                              uint32_t NumThreads[3],
+                                              uint32_t NumBlocks[3]) const {
   return Plugin::success();
 }
 
@@ -566,10 +567,16 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
                     Args, Ptrs, *KernelLaunchEnvOrErr);
   }
 
-  uint32_t NumThreads = getNumThreads(GenericDevice, KernelArgs.ThreadLimit);
-  uint64_t NumBlocks =
-      getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
-                   NumThreads, KernelArgs.ThreadLimit[0] > 0);
+  uint32_t NumThreads[3] = {KernelArgs.ThreadLimit[0],
+                            KernelArgs.ThreadLimit[1],
+                            KernelArgs.ThreadLimit[2]};
+  uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
+                           KernelArgs.NumTeams[2]};
+  if (!IsBareKernel) {
+    NumThreads[0] = getNumThreads(GenericDevice, NumThreads);
+    NumBlocks[0] = getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount,
+                                NumThreads[0], KernelArgs.ThreadLimit[0] > 0);
+  }
 
   // Record the kernel description after we modified the argument count and num
   // blocks/threads.
@@ -578,7 +585,8 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
     RecordReplay.saveImage(getName(), getImage());
     RecordReplay.saveKernelInput(getName(), getImage());
     RecordReplay.saveKernelDescr(getName(), LaunchParams, KernelArgs.NumArgs,
-                                 NumBlocks, NumThreads, KernelArgs.Tripcount);
+                                 NumBlocks[0], NumThreads[0],
+                                 KernelArgs.Tripcount);
   }
 
   if (auto Err =
@@ -618,11 +626,10 @@ KernelLaunchParamsTy GenericKernelTy::prepareArgs(
 
 uint32_t GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
                                         uint32_t ThreadLimitClause[3]) const {
-  assert(ThreadLimitClause[1] == 0 && ThreadLimitClause[2] == 0 &&
-         "Multi dimensional launch not supported yet.");
+  assert(!IsBareKernel && "bare kernel should not call this function");
 
-  if (IsBareKernel && ThreadLimitClause[0] > 0)
-    return ThreadLimitClause[0];
+  assert(ThreadLimitClause[1] == 1 && ThreadLimitClause[2] == 1 &&
+         "Multi dimensional launch not supported yet.");
 
   if (ThreadLimitClause[0] > 0 && isGenericMode())
     ThreadLimitClause[0] += GenericDevice.getWarpSize();
@@ -632,16 +639,15 @@ uint32_t GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
                                      : PreferredNumThreads);
 }
 
-uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
+uint32_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
                                        uint32_t NumTeamsClause[3],
                                        uint64_t LoopTripCount,
                                        uint32_t &NumThreads,
                                        bool IsNumThreadsFromUser) const {
-  assert(NumTeamsClause[1] == 0 && NumTeamsClause[2] == 0 &&
-         "Multi dimensional launch not supported yet.");
+  assert(!IsBareKernel && "bare kernel should not call this function");
 
-  if (IsBareKernel && NumTeamsClause[0] > 0)
-    return NumTeamsClause[0];
+  assert(NumTeamsClause[1] == 1 && NumTeamsClause[2] == 1 &&
+         "Multi dimensional launch not supported yet.");
 
   if (NumTeamsClause[0] > 0) {
     // TODO: We need to honor any value and consequently allow more than the
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 9af71b06ce97d3..894d1c2214b972 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -149,8 +149,8 @@ struct CUDAKernelTy : public GenericKernelTy {
   }
 
   /// Launch the CUDA kernel function.
-  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                   uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
+                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                    KernelLaunchParamsTy LaunchParams,
                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
 
@@ -1228,10 +1228,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
     AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
 
     KernelArgsTy KernelArgs = {};
-    if (auto Err =
-            CUDAKernel.launchImpl(*this, /*NumThread=*/1u,
-                                  /*NumBlocks=*/1ul, KernelArgs,
-                                  KernelLaunchParamsTy{}, AsyncInfoWrapper))
+    uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
+    if (auto Err = CUDAKernel.launchImpl(
+            *this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
+            KernelLaunchParamsTy{}, AsyncInfoWrapper))
       return Err;
 
     Error Err = Plugin::success();
@@ -1274,7 +1274,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
 };
 
 Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
-                               uint32_t NumThreads, uint64_t NumBlocks,
+                               uint32_t NumThreads[3], uint32_t NumBlocks[3],
                                KernelArgsTy &KernelArgs,
                                KernelLaunchParamsTy LaunchParams,
                                AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -1292,9 +1292,8 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                     reinterpret_cast<void *>(&LaunchParams.Size),
                     CU_LAUNCH_PARAM_END};
 
-  CUresult Res = cuLaunchKernel(Func, NumBlocks, /*gridDimY=*/1,
-                                /*gridDimZ=*/1, NumThreads,
-                                /*blockDimY=*/1, /*blockDimZ=*/1,
+  CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
+                                NumThreads[0], NumThreads[1], NumThreads[2],
                                 MaxDynCGroupMem, Stream, nullptr, Config);
   return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
 }
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 6f2e3d8604ec82..915c41e88c5828 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -89,8 +89,8 @@ struct GenELF64KernelTy : public GenericKernelTy {
   }
 
   /// Launch the kernel using the libffi.
-  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                   uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
+                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                    KernelLaunchParamsTy LaunchParams,
                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override {
     // Create a vector of ffi_types, one per argument.
diff --git a/offload/src/interface.cpp b/offload/src/interface.cpp
index 21f9114ac2b088..ad84a43cef8af4 100644
--- a/offload/src/interface.cpp
+++ b/offload/src/interface.cpp
@@ -284,14 +284,25 @@ static KernelArgsTy *upgradeKernelArgs(KernelArgsTy *KernelArgs,
     LocalKernelArgs.Flags = KernelArgs->Flags;
     LocalKernelArgs.DynCGroupMem = 0;
     LocalKernelArgs.NumTeams[0] = NumTeams;
-    LocalKernelArgs.NumTeams[1] = 0;
-    LocalKernelArgs.NumTeams[2] = 0;
+    LocalKernelArgs.NumTeams[1] = 1;
+    LocalKernelArgs.NumTeams[2] = 1;
     LocalKernelArgs.ThreadLimit[0] = ThreadLimit;
-    LocalKernelArgs.ThreadLimit[1] = 0;
-    LocalKernelArgs.ThreadLimit[2] = 0;
+    LocalKernelArgs.ThreadLimit[1] = 1;
+    LocalKernelArgs.ThreadLimit[2] = 1;
     return &LocalKernelArgs;
   }
 
+  // FIXME: This is a workaround to "calibrate" the bad work done in the front
+  // end. Delete this ugly code after the front end emits proper values.
+  auto CorrectMultiDim = [](uint32_t(&Val)[3]) {
+    if (Val[1] == 0)
+      Val[1] = 1;
+    if (Val[2] == 0)
+      Val[2] = 1;
+  };
+  CorrectMultiDim(KernelArgs->ThreadLimit);
+  CorrectMultiDim(KernelArgs->NumTeams);
+
   return KernelArgs;
 }
 
@@ -320,12 +331,6 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
   KernelArgs =
       upgradeKernelArgs(KernelArgs, LocalKernelArgs, NumTeams, ThreadLimit);
 
-  assert(KernelArgs->NumTeams[0] == static_cast<uint32_t>(NumTeams) &&
-         !KernelArgs->NumTeams[1] && !KernelArgs->NumTeams[2] &&
-         "OpenMP interface should not use multiple dimensions");
-  assert(KernelArgs->ThreadLimit[0] == static_cast<uint32_t>(ThreadLimit) &&
-         !KernelArgs->ThreadLimit[1] && !KernelArgs->ThreadLimit[2] &&
-         "OpenMP interface should not use multiple dimensions");
   TIMESCOPE_WITH_DETAILS_AND_IDENT(
       "Runtime: target exe",
       "NumTeams=" + std::to_string(NumTeams) +
diff --git a/offload/src/omptarget.cpp b/offload/src/omptarget.cpp
index 66137b53b0cb4e..1a7af5649b9e22 100644
--- a/offload/src/omptarget.cpp
+++ b/offload/src/omptarget.cpp
@@ -1451,8 +1451,6 @@ int target(ident_t *Loc, DeviceTy &Device, void *HostPtr,
         Loc);
 
 #ifdef OMPT_SUPPORT
-    assert(KernelArgs.NumTeams[1] == 0 && KernelArgs.NumTeams[2] == 0 &&
-           "Multi dimensional launch not supported yet.");
     /// RAII to establish tool anchors before and after kernel launch
     int32_t NumTeams = KernelArgs.NumTeams[0];
     // No need to guard this with OMPT_IF_BUILT
diff --git a/offload/test/api/omp_env_vars.c b/offload/test/api/omp_env_vars.c
index 2e78bb115beed4..91d0487427607e 100644
--- a/offload/test/api/omp_env_vars.c
+++ b/offload/test/api/omp_env_vars.c
@@ -5,7 +5,7 @@
 #define N 256
 
 int main() {
-  // CHECK: Launching kernel [[KERNEL:.+_main_.+]] with 1 blocks and 1 threads
+  // CHECK: Launching kernel [[KERNEL:.+_main_.+]] with [1,1,1] blocks and [1,1,1] threads
 #pragma omp target teams
 #pragma omp parallel
   {}
diff --git a/offload/test/offloading/info.c b/offload/test/offloading/info.c
index da8e4c44c5accb..d86644b871e258 100644
--- a/offload/test/offloading/info.c
+++ b/offload/test/offloading/info.c
@@ -42,7 +42,7 @@ int main() {
 // INFO: info: {{.*}}             {{.*}}             256      1           0            A[0:64] at info.c:{{[0-9]+}}:{{[0-9]+}}
 // INFO: info: Entering OpenMP kernel at info.c:{{[0-9]+}}:{{[0-9]+}} with 1 arguments:
 // INFO: info: firstprivate(val)[4]
-// INFO: info: Launching kernel __omp_offloading_{{.*}}main{{.*}} with {{[0-9]+}} blocks and {{[0-9]+}} threads in Generic mode
+// INFO: info: Launching kernel __omp_offloading_{{.*}}main{{.*}} with [{{[0-9]+}},1,1] blocks and [{{[0-9]+}},1,1] threads in Generic mode
 // AMDGPU: AMDGPU device {{[0-9]}} info: #Args: {{[0-9]}} Teams x Thrds: {{[0-9]+}}x {{[0-9]+}} (MaxFlatWorkGroupSize: {{[0-9]+}}) LDS Usage: {{[0-9]+}}B #SGPRs/VGPRs: {{[0-9]+}}/{{[0-9]+}} #SGPR/VGPR Spills: {{[0-9]+}}/{{[0-9]+}} Tripcount: {{[0-9]+}}
 // INFO: info: OpenMP Host-Device pointer mappings after block at info.c:{{[0-9]+}}:{{[0-9]+}}:
 // INFO: info: Host Ptr           Target Ptr         Size (B) DynRefCount HoldRefCount Declaration
diff --git a/offload/test/offloading/ompx_bare.c b/offload/test/offloading/ompx_bare.c
index b9a8759db1de14..6a6ada9617cf5b 100644
--- a/offload/test/offloading/ompx_bare.c
+++ b/offload/test/offloading/ompx_bare.c
@@ -15,7 +15,7 @@ int main(int argc, char *argv[]) {
   const int N = num_blocks * block_size;
   int *data = (int *)malloc(N * sizeof(int));
 
-  // CHECK: "PluginInterface" device 0 info: Launching kernel __omp_offloading_{{.*}} with 64 blocks and 64 threads in SPMD mode
+  // CHECK: "PluginInterface" device 0 info: Launching kernel __omp_offloading_{{.*}} with [64,1,1] blocks and [64,1,1] threads in SPMD mode
 
 #pragma omp target teams ompx_bare num_teams(num_blocks) thread_limit(block_size) map(from: data[0:N])
   {
diff --git a/offload/test/offloading/ompx_bare_multi_dim.cpp b/offload/test/offloading/ompx_bare_multi_dim.cpp
new file mode 100644
index 00000000000000..d37278525fdb0e
--- /dev/null
+++ b/offload/test/offloading/ompx_bare_multi_dim.cpp
@@ -0,0 +1,56 @@
+// RUN: %libomptarget-compilexx-generic
+// RUN: env LIBOMPTARGET_INFO=63 %libomptarget-run-generic 2>&1 | %fcheck-generic
+// REQUIRES: gpu
+
+#include <ompx.h>
+
+#include <cassert>
+#include <vector>
+
+// CHECK: "PluginInterface" device 0 info: Launching kernel __omp_offloading_{{.*}} with [2,4,6] blocks and [32,4,2] threads in SPMD mode
+
+int main(int argc, char *argv[]) {
+  int bs[3] = {32u, 4u, 2u};
+  int gs[3] = {2u, 4u, 6u};
+  int n = bs[0] * bs[1] * bs[2] * gs[0] * gs[1] * gs[2];
+  std::vector<int> x_buf(n);
+  std::vector<int> y_buf(n);
+  std::vector<int> z_buf(n);
+
+  auto x = x_buf.data();
+  auto y = y_buf.data();
+  auto z = z_buf.data();
+  for (int i = 0; i < n; ++i) {
+    x[i] = i;
+    y[i] = i + 1;
+  }
+
+#pragma omp target teams ompx_bare num_teams(gs[0], gs[1], gs[2])              \
+    thread_limit(bs[0], bs[1], bs[2]) map(to : x[ : n], y[ : n])               \
+    map(from : z[ : n])
+  {
+    int tid_x = ompx_thread_id_x();
+    int tid_y = ompx_thread_id_y();
+    int tid_z = ompx_thread_id_z();
+    int gid_x = ompx_block_id_x();
+    int gid_y = ompx_block_id_y();
+    int gid_z = ompx_block_id_z();
+    int bs_x = ompx_block_dim_x();
+    int bs_y = ompx_block_dim_y();
+    int bs_z = ompx_block_dim_z();
+    int bs = bs_x * bs_y * bs_z;
+    int gs_x = ompx_grid_dim_x();
+    int gs_y = ompx_grid_dim_y();
+    int gid = gid_z * gs_y * gs_x + gid_y * gs_x + gid_x;
+    int tid = tid_z * bs_x * bs_y + tid_y * bs_x + tid_x;
+    int i = gid * bs + tid;
+    z[i] = x[i] + y[i];
+  }
+
+  for (int i = 0; i < n; ++i) {
+    if (z[i] != (2 * i + 1))
+      return 1;
+  }
+
+  return 0;
+}
diff --git a/offload/test/offloading/small_trip_count.c b/offload/test/offloading/small_trip_count.c
index 78750411ff8f49..e9ec8b7103d66b 100644
--- a/offload/test/offloading/small_trip_count.c
+++ b/offload/test/offloading/small_trip_count.c
@@ -12,26 +12,26 @@
 __attribute__((optnone)) void optnone() {}
 
 int main() {
-  // DEFAULT: Launching kernel {{.+_main_.+}} with 4 blocks and 32 threads in SPMD mode
-  // EIGHT: Launching kernel {{.+_main_.+}} with 16 blocks and 8 threads in SPMD mode
+  // DEFAULT: Launching kernel {{.+_main_.+}} with [4,1,1] blocks and [32,1,1] threads in SPMD mode
+  // EIGHT: Launching kernel {{.+_main_.+}} with [16,1,1] blocks and [8,1,1] threads in SPMD mode
 #pragma omp target teams distribute parallel for simd
   for (int i = 0; i < N; ++i) {
     optnone();
   }
-  // DEFAULT: Launching kernel {{.+_main_.+}} with 4 blocks and 32 threads in SPMD mode
-  // EIGHT: Launching kernel {{.+_main_.+}} with 16 blocks and 8 threads in SPMD mode
+  // DEFAULT: Launching kernel {{.+_main_.+}} with [4,1,1] blocks and [32,1,1] threads in SPMD mode
+  // EIGHT: Launching kernel {{.+_main_.+}} with [16,1,1] blocks and [8,1,1] threads in SPMD mode
 #pragma omp target teams distribute parallel for simd
   for (int i = 0; i < N - 1; ++i) {
     optnone();
   }
-  // DEFAULT: Launching kernel {{.+_main_.+}} with 5 blocks and 32 threads in SPMD mode
-  // EIGHT: Launching kernel {{.+_main_.+}} with 17 blocks and 8 threads in SPMD mode
+  // DEFAULT: Launching kernel {{.+_main_.+}} with [5,1,1] blocks and [32,1,1] threads in SPMD mode
+  // EIGHT: Launching kernel {{.+_main_.+}} with [17,1,1] blocks and [8,1,1] threads in SPMD mode
 #pragma omp target teams distribute parallel for simd
   for (int i = 0; i < N + 1; ++i) {
     optnone();
   }
-  // DEFAULT: Launching kernel {{.+_main_.+}} with 32 blocks and 4 threads in SPMD mode
-  // EIGHT: Launching kernel {{.+_main_.+}} with 32 blocks and 4 threads in SPMD mode
+  // DEFAULT: Launching kernel {{.+_main_.+}} with [32,1,1] blocks and [4,1,1] threads in SPMD mode
+  // EIGHT: Launching kernel {{.+_main_.+}} with [32,1,1] blocks and [4,1,1] threads in SPMD mode
 #pragma omp target teams distribute parallel for simd thread_limit(4)
   for (int i = 0; i < N; ++i) {
     optnone();
diff --git a/offload/test/offloading/small_trip_count_thread_limit.cpp b/offload/test/offloading/small_trip_count_thread_limit.cpp
index cfb9fe712d270f..fbd7fe9175d705 100644
--- a/offload/test/offloading/small_trip_count_thread_limit.cpp
+++ b/offload/test/offloading/small_trip_count_thread_limit.cpp
@@ -25,4 +25,4 @@ int main(int argc, char *argv[]) {
   return 0;
 }
 
-// CHECK: Launching kernel {{.*}} with 4 blocks and 256 threads in SPMD mode
+// CHECK: Launching kernel {{.*}} with [4,1,1] blocks and [256,1,1] threads in SPMD mode



More information about the llvm-commits mailing list