[llvm] [Offload][OMPX] Add the runtime support for multi-dim grid and block (PR #118042)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 29 17:32:06 PST 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/118042

>From cca25b5b51bf074ffdde9cc7803e0566079d19f7 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Fri, 29 Nov 2024 18:54:23 -0500
Subject: [PATCH] [Offload][OMPX] Add the runtime support for multi-dim grid
 and block

---
 offload/plugins-nextgen/amdgpu/src/rtl.cpp    | 69 ++++++++++--------
 .../common/include/PluginInterface.h          | 23 +++---
 .../common/src/PluginInterface.cpp            | 72 ++++++++++---------
 offload/plugins-nextgen/cuda/src/rtl.cpp      | 19 +++--
 offload/plugins-nextgen/host/src/rtl.cpp      |  4 +-
 offload/src/interface.cpp                     | 14 ++--
 offload/src/omptarget.cpp                     |  2 -
 .../test/offloading/ompx_bare_multi_dim.cpp   | 54 ++++++++++++++
 8 files changed, 159 insertions(+), 98 deletions(-)
 create mode 100644 offload/test/offloading/ompx_bare_multi_dim.cpp

diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 6356fa0554a9c1..2920435cd8b5c8 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -559,15 +559,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
   }
 
   /// Launch the AMDGPU kernel function.
-  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                   uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
+                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                    KernelLaunchParamsTy LaunchParams,
                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
 
   /// Print more elaborate kernel launch info for AMDGPU
   Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
-                               KernelArgsTy &KernelArgs, uint32_t NumThreads,
-                               uint64_t NumBlocks) const override;
+                               KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
+                               uint32_t NumBlocks[3]) const override;
 
   /// Get group and private segment kernel size.
   uint32_t getGroupSize() const { return GroupSize; }
@@ -719,7 +719,7 @@ struct AMDGPUQueueTy {
   /// Push a kernel launch to the queue. The kernel launch requires an output
   /// signal and can define an optional input signal (nullptr if none).
   Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
-                         uint32_t NumThreads, uint64_t NumBlocks,
+                         uint32_t NumThreads[3], uint32_t NumBlocks[3],
                          uint32_t GroupSize, uint64_t StackSize,
                          AMDGPUSignalTy *OutputSignal,
                          AMDGPUSignalTy *InputSignal) {
@@ -746,14 +746,18 @@ struct AMDGPUQueueTy {
     assert(Packet && "Invalid packet");
 
     // The first 32 bits of the packet are written after the other fields
-    uint16_t Setup = UINT16_C(1) << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
-    Packet->workgroup_size_x = NumThreads;
-    Packet->workgroup_size_y = 1;
-    Packet->workgroup_size_z = 1;
+    uint16_t Dims = NumBlocks[2] * NumThreads[2] > 1
+                        ? 3
+                        : 1 + (NumBlocks[1] * NumThreads[1] != 1);
+    uint16_t Setup = static_cast<uint16_t>(
+        Dims << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS);
+    Packet->workgroup_size_x = NumThreads[0];
+    Packet->workgroup_size_y = NumThreads[1];
+    Packet->workgroup_size_z = NumThreads[2];
     Packet->reserved0 = 0;
-    Packet->grid_size_x = NumBlocks * NumThreads;
-    Packet->grid_size_y = 1;
-    Packet->grid_size_z = 1;
+    Packet->grid_size_x = NumBlocks[0] * NumThreads[0];
+    Packet->grid_size_y = NumBlocks[1] * NumThreads[1];
+    Packet->grid_size_z = NumBlocks[2] * NumThreads[2];
     Packet->private_segment_size =
         Kernel.usesDynamicStack() ? StackSize : Kernel.getPrivateSize();
     Packet->group_segment_size = GroupSize;
@@ -1240,7 +1244,7 @@ struct AMDGPUStreamTy {
   /// the kernel finalizes. Once the kernel is finished, the stream will release
   /// the kernel args buffer to the specified memory manager.
   Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
-                         uint32_t NumThreads, uint64_t NumBlocks,
+                         uint32_t NumThreads[3], uint32_t NumBlocks[3],
                          uint32_t GroupSize, uint64_t StackSize,
                          AMDGPUMemoryManagerTy &MemoryManager) {
     if (Queue == nullptr)
@@ -2829,10 +2833,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
 
     KernelArgsTy KernelArgs = {};
-    if (auto Err =
-            AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
-                                    /*NumBlocks=*/1ul, KernelArgs,
-                                    KernelLaunchParamsTy{}, AsyncInfoWrapper))
+    uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u}; // One block of one thread.
+    if (auto Err = AMDGPUKernel.launchImpl(
+            *this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
+            KernelLaunchParamsTy{}, AsyncInfoWrapper))
       return Err;
 
     Error Err = Plugin::success();
@@ -3330,7 +3334,7 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
 };
 
 Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
-                                 uint32_t NumThreads, uint64_t NumBlocks,
+                                 uint32_t NumThreads[3], uint32_t NumBlocks[3],
                                  KernelArgsTy &KernelArgs,
                                  KernelLaunchParamsTy LaunchParams,
                                  AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -3387,13 +3391,15 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
   if (ImplArgs &&
       getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {
-    ImplArgs->BlockCountX = NumBlocks;
-    ImplArgs->BlockCountY = 1;
-    ImplArgs->BlockCountZ = 1;
-    ImplArgs->GroupSizeX = NumThreads;
-    ImplArgs->GroupSizeY = 1;
-    ImplArgs->GroupSizeZ = 1;
-    ImplArgs->GridDims = 1;
+    ImplArgs->BlockCountX = NumBlocks[0];
+    ImplArgs->BlockCountY = NumBlocks[1];
+    ImplArgs->BlockCountZ = NumBlocks[2];
+    ImplArgs->GroupSizeX = NumThreads[0];
+    ImplArgs->GroupSizeY = NumThreads[1];
+    ImplArgs->GroupSizeZ = NumThreads[2];
+    ImplArgs->GridDims = NumBlocks[2] * NumThreads[2] > 1
+                             ? 3
+                             : 1 + (NumBlocks[1] * NumThreads[1] != 1);
     ImplArgs->DynamicLdsSize = KernelArgs.DynCGroupMem;
   }
 
@@ -3404,8 +3410,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
 
 Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                              KernelArgsTy &KernelArgs,
-                                             uint32_t NumThreads,
-                                             uint64_t NumBlocks) const {
+                                             uint32_t NumThreads[3],
+                                             uint32_t NumBlocks[3]) const {
   // Only do all this when the output is requested
   if (!(getInfoLevel() & OMP_INFOTYPE_PLUGIN_KERNEL))
     return Plugin::success();
@@ -3442,12 +3448,13 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
   // S/VGPR Spill Count: how many S/VGPRs are spilled by the kernel
   // Tripcount: loop tripcount for the kernel
   INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
-       "#Args: %d Teams x Thrds: %4lux%4u (MaxFlatWorkGroupSize: %u) LDS "
+       "#Args: %d Teams x Thrds: %4ux%4u (MaxFlatWorkGroupSize: %u) LDS "
        "Usage: %uB #SGPRs/VGPRs: %u/%u #SGPR/VGPR Spills: %u/%u Tripcount: "
        "%lu\n",
-       ArgNum, NumGroups, ThreadsPerGroup, MaxFlatWorkgroupSize,
-       GroupSegmentSize, SGPRCount, VGPRCount, SGPRSpillCount, VGPRSpillCount,
-       LoopTripCount);
+       ArgNum, NumGroups[0] * NumGroups[1] * NumGroups[2],
+       ThreadsPerGroup[0] * ThreadsPerGroup[1] * ThreadsPerGroup[2],
+       MaxFlatWorkgroupSize, GroupSegmentSize, SGPRCount, VGPRCount,
+       SGPRSpillCount, VGPRSpillCount, LoopTripCount);
 
   return Plugin::success();
 }
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 41cc0f286a581f..be3467c3f7098f 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -265,8 +265,9 @@ struct GenericKernelTy {
   Error launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
                ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
                AsyncInfoWrapperTy &AsyncInfoWrapper) const;
-  virtual Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                           uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  virtual Error launchImpl(GenericDeviceTy &GenericDevice,
+                           uint32_t NumThreads[3], uint32_t NumBlocks[3],
+                           KernelArgsTy &KernelArgs,
                            KernelLaunchParamsTy LaunchParams,
                            AsyncInfoWrapperTy &AsyncInfoWrapper) const = 0;
 
@@ -316,15 +317,15 @@ struct GenericKernelTy {
 
   /// Prints generic kernel launch information.
   Error printLaunchInfo(GenericDeviceTy &GenericDevice,
-                        KernelArgsTy &KernelArgs, uint32_t NumThreads,
-                        uint64_t NumBlocks) const;
+                        KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
+                        uint32_t NumBlocks[3]) const;
 
   /// Prints plugin-specific kernel launch information after generic kernel
   /// launch information
   virtual Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                        KernelArgsTy &KernelArgs,
-                                       uint32_t NumThreads,
-                                       uint64_t NumBlocks) const;
+                                       uint32_t NumThreads[3],
+                                       uint32_t NumBlocks[3]) const;
 
 private:
   /// Prepare the arguments before launching the kernel.
@@ -337,15 +338,19 @@ struct GenericKernelTy {
 
   /// Get the number of threads and blocks for the kernel based on the
   /// user-defined threads and block clauses.
-  uint32_t getNumThreads(GenericDeviceTy &GenericDevice,
-                         uint32_t ThreadLimitClause[3]) const;
+  /// The chosen thread count is written back to \p ThreadLimitClause[0];
+  /// bare kernels leave the clause untouched.
+  void getNumThreads(GenericDeviceTy &GenericDevice,
+                     uint32_t ThreadLimitClause[3]) const;
 
   /// The number of threads \p NumThreads can be adjusted by this method.
   /// \p IsNumThreadsFromUser is true is \p NumThreads is defined by user via
   /// thread_limit clause.
-  uint64_t getNumBlocks(GenericDeviceTy &GenericDevice,
-                        uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
-                        uint32_t &NumThreads, bool IsNumThreadsFromUser) const;
+  /// The chosen block count is written back to \p BlockLimitClause[0];
+  /// bare kernels leave the clause untouched.
+  void getNumBlocks(GenericDeviceTy &GenericDevice,
+                    uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
+                    uint32_t &NumThreads, bool IsNumThreadsFromUser) const;
 
   /// Indicate if the kernel works in Generic SPMD, Generic or SPMD mode.
   bool isGenericSPMDMode() const {
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 25b815b7f96694..cea4a177211fd1 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -526,20 +526,21 @@ GenericKernelTy::getKernelLaunchEnvironment(
 
 Error GenericKernelTy::printLaunchInfo(GenericDeviceTy &GenericDevice,
                                        KernelArgsTy &KernelArgs,
-                                       uint32_t NumThreads,
-                                       uint64_t NumBlocks) const {
+                                       uint32_t NumThreads[3],
+                                       uint32_t NumBlocks[3]) const {
   INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
-       "Launching kernel %s with %" PRIu64
-       " blocks and %d threads in %s mode\n",
-       getName(), NumBlocks, NumThreads, getExecutionModeName());
+       "Launching kernel %s with [%u,%u,%u] blocks and [%u,%u,%u] threads in "
+       "%s mode\n",
+       getName(), NumBlocks[0], NumBlocks[1], NumBlocks[2], NumThreads[0],
+       NumThreads[1], NumThreads[2], getExecutionModeName());
   return printLaunchInfoDetails(GenericDevice, KernelArgs, NumThreads,
                                 NumBlocks);
 }
 
 Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                               KernelArgsTy &KernelArgs,
-                                              uint32_t NumThreads,
-                                              uint64_t NumBlocks) const {
+                                              uint32_t NumThreads[3],
+                                              uint32_t NumBlocks[3]) const {
   return Plugin::success();
 }
 
@@ -566,10 +567,16 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
                     Args, Ptrs, *KernelLaunchEnvOrErr);
   }
 
-  uint32_t NumThreads = getNumThreads(GenericDevice, KernelArgs.ThreadLimit);
-  uint64_t NumBlocks =
-      getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
-                   NumThreads, KernelArgs.ThreadLimit[0] > 0);
+  uint32_t NumThreads[3] = {KernelArgs.ThreadLimit[0],
+                            KernelArgs.ThreadLimit[1],
+                            KernelArgs.ThreadLimit[2]};
+  uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
+                           KernelArgs.NumTeams[2]};
+  getNumThreads(GenericDevice, NumThreads);
+  // Test the user's thread_limit clause, not NumThreads[0]: getNumThreads may
+  // already have replaced NumThreads[0] with its own choice.
+  getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount, NumThreads[0],
+               KernelArgs.ThreadLimit[0] > 0);
 
   // Record the kernel description after we modified the argument count and num
   // blocks/threads.
@@ -578,7 +585,9 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
     RecordReplay.saveImage(getName(), getImage());
     RecordReplay.saveKernelInput(getName(), getImage());
+    // TODO: Only the x dimension of the blocks and threads is recorded.
     RecordReplay.saveKernelDescr(getName(), LaunchParams, KernelArgs.NumArgs,
-                                 NumBlocks, NumThreads, KernelArgs.Tripcount);
+                                 NumBlocks[0], NumThreads[0],
+                                 KernelArgs.Tripcount);
   }
 
   if (auto Err =
@@ -616,38 +625,39 @@ KernelLaunchParamsTy GenericKernelTy::prepareArgs(
   return KernelLaunchParamsTy{sizeof(void *) * NumArgs, &Args[0], &Ptrs[0]};
 }
 
-uint32_t GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
-                                        uint32_t ThreadLimitClause[3]) const {
-  assert(ThreadLimitClause[1] == 0 && ThreadLimitClause[2] == 0 &&
-         "Multi dimensional launch not supported yet.");
+void GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
+                                    uint32_t ThreadLimitClause[3]) const {
+  if (IsBareKernel)
+    return;
 
-  if (IsBareKernel && ThreadLimitClause[0] > 0)
-    return ThreadLimitClause[0];
+  assert(ThreadLimitClause[1] == 1 && ThreadLimitClause[2] == 1 &&
+         "Multi dimensional launch not supported yet.");
 
   if (ThreadLimitClause[0] > 0 && isGenericMode())
     ThreadLimitClause[0] += GenericDevice.getWarpSize();
 
-  return std::min(MaxNumThreads, (ThreadLimitClause[0] > 0)
-                                     ? ThreadLimitClause[0]
-                                     : PreferredNumThreads);
+  ThreadLimitClause[0] =
+      std::min(MaxNumThreads, (ThreadLimitClause[0] > 0) ? ThreadLimitClause[0]
+                                                         : PreferredNumThreads);
 }
 
-uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
-                                       uint32_t NumTeamsClause[3],
-                                       uint64_t LoopTripCount,
-                                       uint32_t &NumThreads,
-                                       bool IsNumThreadsFromUser) const {
-  assert(NumTeamsClause[1] == 0 && NumTeamsClause[2] == 0 &&
-         "Multi dimensional launch not supported yet.");
+void GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
+                                   uint32_t NumTeamsClause[3],
+                                   uint64_t LoopTripCount, uint32_t &NumThreads,
+                                   bool IsNumThreadsFromUser) const {
+  if (IsBareKernel)
+    return;
 
-  if (IsBareKernel && NumTeamsClause[0] > 0)
-    return NumTeamsClause[0];
+  assert(NumTeamsClause[1] == 1 && NumTeamsClause[2] == 1 &&
+         "Multi dimensional launch not supported yet.");
 
   if (NumTeamsClause[0] > 0) {
     // TODO: We need to honor any value and consequently allow more than the
     // block limit. For this we might need to start multiple kernels or let the
     // blocks start again until the requested number has been started.
-    return std::min(NumTeamsClause[0], GenericDevice.getBlockLimit());
+    NumTeamsClause[0] =
+        std::min(NumTeamsClause[0], GenericDevice.getBlockLimit());
+    return;
   }
 
   uint64_t DefaultNumBlocks = GenericDevice.getDefaultNumBlocks();
@@ -719,7 +729,8 @@ uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
   // If the loops are long running we rather reuse blocks than spawn too many.
   if (GenericDevice.getReuseBlocksForHighTripCount())
     PreferredNumBlocks = std::min(TripCountNumBlocks, DefaultNumBlocks);
-  return std::min(PreferredNumBlocks, GenericDevice.getBlockLimit());
+  NumTeamsClause[0] =
+      std::min(PreferredNumBlocks, GenericDevice.getBlockLimit());
 }
 
 GenericDeviceTy::GenericDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId,
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 015c7775ba3513..1e3d79c9554c67 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -149,8 +149,8 @@ struct CUDAKernelTy : public GenericKernelTy {
   }
 
   /// Launch the CUDA kernel function.
-  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                   uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
+                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                    KernelLaunchParamsTy LaunchParams,
                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
 
@@ -1230,10 +1230,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
     AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
 
     KernelArgsTy KernelArgs = {};
-    if (auto Err =
-            CUDAKernel.launchImpl(*this, /*NumThread=*/1u,
-                                  /*NumBlocks=*/1ul, KernelArgs,
-                                  KernelLaunchParamsTy{}, AsyncInfoWrapper))
+    uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u}; // One block of one thread.
+    if (auto Err = CUDAKernel.launchImpl(
+            *this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
+            KernelLaunchParamsTy{}, AsyncInfoWrapper))
       return Err;
 
     Error Err = Plugin::success();
@@ -1276,7 +1276,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
 };
 
 Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
-                               uint32_t NumThreads, uint64_t NumBlocks,
+                               uint32_t NumThreads[3], uint32_t NumBlocks[3],
                                KernelArgsTy &KernelArgs,
                                KernelLaunchParamsTy LaunchParams,
                                AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -1294,9 +1294,8 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                     reinterpret_cast<void *>(&LaunchParams.Size),
                     CU_LAUNCH_PARAM_END};
 
-  CUresult Res = cuLaunchKernel(Func, NumBlocks, /*gridDimY=*/1,
-                                /*gridDimZ=*/1, NumThreads,
-                                /*blockDimY=*/1, /*blockDimZ=*/1,
+  CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
+                                NumThreads[0], NumThreads[1], NumThreads[2],
                                 MaxDynCGroupMem, Stream, nullptr, Config);
   return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
 }
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 6f2e3d8604ec82..915c41e88c5828 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -89,8 +89,8 @@ struct GenELF64KernelTy : public GenericKernelTy {
   }
 
   /// Launch the kernel using the libffi.
-  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
-                   uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
+                   uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
                    KernelLaunchParamsTy LaunchParams,
                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override {
     // Create a vector of ffi_types, one per argument.
diff --git a/offload/src/interface.cpp b/offload/src/interface.cpp
index 21f9114ac2b088..6c9d161468df4b 100644
--- a/offload/src/interface.cpp
+++ b/offload/src/interface.cpp
@@ -284,11 +284,11 @@ static KernelArgsTy *upgradeKernelArgs(KernelArgsTy *KernelArgs,
     LocalKernelArgs.Flags = KernelArgs->Flags;
     LocalKernelArgs.DynCGroupMem = 0;
     LocalKernelArgs.NumTeams[0] = NumTeams;
-    LocalKernelArgs.NumTeams[1] = 0;
-    LocalKernelArgs.NumTeams[2] = 0;
+    LocalKernelArgs.NumTeams[1] = 1; // Y/Z extents of a 1-D launch are 1.
+    LocalKernelArgs.NumTeams[2] = 1;
     LocalKernelArgs.ThreadLimit[0] = ThreadLimit;
-    LocalKernelArgs.ThreadLimit[1] = 0;
-    LocalKernelArgs.ThreadLimit[2] = 0;
+    LocalKernelArgs.ThreadLimit[1] = 1; // Y/Z extents of a 1-D launch are 1.
+    LocalKernelArgs.ThreadLimit[2] = 1;
     return &LocalKernelArgs;
   }
 
@@ -320,12 +320,6 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
   KernelArgs =
       upgradeKernelArgs(KernelArgs, LocalKernelArgs, NumTeams, ThreadLimit);
 
-  assert(KernelArgs->NumTeams[0] == static_cast<uint32_t>(NumTeams) &&
-         !KernelArgs->NumTeams[1] && !KernelArgs->NumTeams[2] &&
-         "OpenMP interface should not use multiple dimensions");
-  assert(KernelArgs->ThreadLimit[0] == static_cast<uint32_t>(ThreadLimit) &&
-         !KernelArgs->ThreadLimit[1] && !KernelArgs->ThreadLimit[2] &&
-         "OpenMP interface should not use multiple dimensions");
   TIMESCOPE_WITH_DETAILS_AND_IDENT(
       "Runtime: target exe",
       "NumTeams=" + std::to_string(NumTeams) +
diff --git a/offload/src/omptarget.cpp b/offload/src/omptarget.cpp
index 66137b53b0cb4e..1a7af5649b9e22 100644
--- a/offload/src/omptarget.cpp
+++ b/offload/src/omptarget.cpp
@@ -1451,8 +1451,7 @@ int target(ident_t *Loc, DeviceTy &Device, void *HostPtr,
         Loc);
 
 #ifdef OMPT_SUPPORT
-    assert(KernelArgs.NumTeams[1] == 0 && KernelArgs.NumTeams[2] == 0 &&
-           "Multi dimensional launch not supported yet.");
+    // TODO: Only NumTeams[0] is used below; y/z team extents are ignored.
     /// RAII to establish tool anchors before and after kernel launch
     int32_t NumTeams = KernelArgs.NumTeams[0];
     // No need to guard this with OMPT_IF_BUILT
diff --git a/offload/test/offloading/ompx_bare_multi_dim.cpp b/offload/test/offloading/ompx_bare_multi_dim.cpp
new file mode 100644
index 00000000000000..150e805da2a175
--- /dev/null
+++ b/offload/test/offloading/ompx_bare_multi_dim.cpp
@@ -0,0 +1,57 @@
+// RUN: %libomptarget-compilexx-run-and-check-generic
+//
+// REQUIRES: gpu
+
+#include <ompx.h>
+
+#include <cstdio>
+#include <vector>
+
+int main(int argc, char *argv[]) {
+  int bs[3] = {32, 4, 2};
+  int gs[3] = {2, 4, 6};
+  int n = bs[0] * bs[1] * bs[2] * gs[0] * gs[1] * gs[2];
+  std::vector<int> x_buf(n);
+  std::vector<int> y_buf(n);
+  std::vector<int> z_buf(n);
+
+  auto x = x_buf.data();
+  auto y = y_buf.data();
+  auto z = z_buf.data();
+  for (int i = 0; i < n; ++i) {
+    x[i] = i;
+    y[i] = i + 1;
+  }
+
+#pragma omp target teams ompx_bare num_teams(gs[0], gs[1], gs[2])              \
+    thread_limit(bs[0], bs[1], bs[2]) map(to : x[ : n], y[ : n])               \
+    map(from : z[ : n])
+  {
+    int tid_x = ompx_thread_id_x();
+    int tid_y = ompx_thread_id_y();
+    int tid_z = ompx_thread_id_z();
+    int gid_x = ompx_block_id_x();
+    int gid_y = ompx_block_id_y();
+    int gid_z = ompx_block_id_z();
+    int bs_x = ompx_block_dim_x();
+    int bs_y = ompx_block_dim_y();
+    int bs_z = ompx_block_dim_z();
+    int block_size = bs_x * bs_y * bs_z;
+    int gs_x = ompx_grid_dim_x();
+    int gs_y = ompx_grid_dim_y();
+    int gid = gid_z * gs_y * gs_x + gid_y * gs_x + gid_x;
+    int tid = tid_z * bs_x * bs_y + tid_y * bs_x + tid_x;
+    int i = gid * block_size + tid;
+    z[i] = x[i] + y[i];
+  }
+
+  for (int i = 0; i < n; ++i) {
+    if (z[i] != (2 * i + 1))
+      return 1;
+  }
+
+  // CHECK: PASS
+  printf("PASS\n");
+
+  return 0;
+}



More information about the llvm-commits mailing list