[llvm] [offload] Add properties parameter to olLaunchKernel (PR #184343)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 3 05:56:51 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-offload

@llvm/pr-subscribers-backend-amdgpu

Author: Łukasz Plewa (lplewa)

<details>
<summary>Changes</summary>

Introduce a properties argument to olLaunchKernel to enable future extensions.

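This change adds initial extensions for:
- cooperative kernel launch (together with the new `olGetKernelMaxCooperativeGroupCount` query and the `OL_DEVICE_INFO_COOPERATIVE_LAUNCH_SUPPORT` device info)
- kernel argument sizes (required by the Level Zero plugin)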

---

Patch is 54.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184343.diff


24 Files Affected:

- (modified) offload/include/Shared/APITypes.h (+5-4) 
- (modified) offload/liboffload/API/Device.td (+8-1) 
- (modified) offload/liboffload/API/Kernel.td (+69-8) 
- (modified) offload/liboffload/src/OffloadImpl.cpp (+51-3) 
- (modified) offload/plugins-nextgen/amdgpu/src/rtl.cpp (+14) 
- (modified) offload/plugins-nextgen/common/include/PluginInterface.h (+6) 
- (modified) offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp (+2) 
- (modified) offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h (+28) 
- (modified) offload/plugins-nextgen/cuda/src/rtl.cpp (+124-5) 
- (modified) offload/plugins-nextgen/host/src/rtl.cpp (+15) 
- (modified) offload/plugins-nextgen/level_zero/include/L0Context.h (+8) 
- (modified) offload/plugins-nextgen/level_zero/include/L0Kernel.h (+10-2) 
- (modified) offload/plugins-nextgen/level_zero/src/L0Context.cpp (+8) 
- (modified) offload/plugins-nextgen/level_zero/src/L0Device.cpp (+3) 
- (modified) offload/plugins-nextgen/level_zero/src/L0Kernel.cpp (+109-13) 
- (modified) offload/tools/offload-tblgen/PrintGen.cpp (+4-3) 
- (modified) offload/unittests/Conformance/lib/DeviceContext.cpp (+1-1) 
- (modified) offload/unittests/OffloadAPI/CMakeLists.txt (+2-1) 
- (modified) offload/unittests/OffloadAPI/common/Fixtures.hpp (+37) 
- (modified) offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp (+21-58) 
- (added) offload/unittests/OffloadAPI/kernel/olLaunchKernelCooperative.cpp (+133) 
- (modified) offload/unittests/OffloadAPI/memory/olMemcpy.cpp (+3-3) 
- (modified) offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp (+2-2) 
- (modified) offload/unittests/OffloadAPI/queue/olWaitEvents.cpp (+3-3) 


``````````diff
diff --git a/offload/include/Shared/APITypes.h b/offload/include/Shared/APITypes.h
index 8c150b6bfc2d4..9213f2924f1f7 100644
--- a/offload/include/Shared/APITypes.h
+++ b/offload/include/Shared/APITypes.h
@@ -100,10 +100,11 @@ struct KernelArgsTy {
   uint64_t Tripcount =
       0; // Tripcount for the teams / distribute loop, 0 otherwise.
   struct {
-    uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
-    uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
-    uint64_t Unused : 62;
-  } Flags = {0, 0, 0};
+    uint64_t NoWait : 1;      // Was this kernel spawned with a `nowait` clause.
+    uint64_t IsCUDA : 1;      // Was this kernel spawned via CUDA.
+    uint64_t Cooperative : 1; // Was this kernel spawned as cooperative.
+    uint64_t Unused : 61;
+  } Flags = {0, 0, 0, 0};
   // The number of teams (for x,y,z dimension).
   uint32_t NumTeams[3] = {0, 0, 0};
   // The number of threads (for x,y,z dimension).
diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td
index 6ada191089674..7790386ae02e1 100644
--- a/offload/liboffload/API/Device.td
+++ b/offload/liboffload/API/Device.td
@@ -47,7 +47,14 @@ def ol_device_info_t : Enum {
   ];
   list<TaggedEtor> fp_configs = !foreach(type, ["Single", "Double", "Half"], TaggedEtor<type # "_FP_CONFIG", "ol_device_fp_capability_flags_t", type # " precision floating point capability">);
   list<TaggedEtor> native_vec_widths = !foreach(type, ["char","short","int","long","float","double","half"], TaggedEtor<"NATIVE_VECTOR_WIDTH_" # type, "uint32_t", "Native vector width for " # type>);
-  let etors = !listconcat(basic_etors, fp_configs, native_vec_widths);
+  // This list is maintained separately to allow adding new basic etors without
+  // changing the values of previous ones.
+  list<TaggedEtor> basic_etors2 =
+      [TaggedEtor<"COOPERATIVE_LAUNCH_SUPPORT", "bool",
+                  "Is cooperative kernel launch supported">,
+  ];
+  let etors =
+      !listconcat(basic_etors, fp_configs, native_vec_widths, basic_etors2);
 }
 
 def ol_device_fp_capability_flag_t : Enum {
diff --git a/offload/liboffload/API/Kernel.td b/offload/liboffload/API/Kernel.td
index 2f5692a19d712..4a0627604da7d 100644
--- a/offload/liboffload/API/Kernel.td
+++ b/offload/liboffload/API/Kernel.td
@@ -20,20 +20,52 @@ def ol_kernel_launch_size_args_t : Struct {
     ];
 }
 
+def OL_KERNEL_LAUNCH_PROP_END : Macro {
+  let value = "{OL_KERNEL_LAUNCH_PROP_TYPE_NONE, NULL}";
+  let desc = "Sentinel that must terminate every ol_kernel_launch_prop_t array";
+}
+
+def ol_kernel_launch_prop_type_t : Enum {
+  let desc = "Identifies the kind of data a kernel launch property points to";
+  let is_typed = 1;
+  // Etor names become OL_KERNEL_LAUNCH_PROP_TYPE_<NAME>, so they must not
+  // carry stray whitespace.
+  let etors =
+      [TaggedEtor<"none", "void *",
+                  "Terminates the property array (OL_KERNEL_LAUNCH_PROP_END)">,
+       TaggedEtor<"size", "size_t *",
+                  "Array of kernel argument sizes, one entry per argument">,
+       TaggedEtor<"is_cooperative", "bool *",
+                  "Requests a cooperative kernel launch">];
+}
+
+def ol_kernel_launch_prop_t : Struct {
+  let desc = "One optional kernel launch property: a type tag and its data.";
+  let members = [
+    StructMember<"ol_kernel_launch_prop_type_t", "type",
+                 "Type of the data field">,
+    StructMember<"void *", "data", "Pointer to property-specific data.">,
+  ];
+}
+
 def olLaunchKernel : Function {
     let desc = "Enqueue a kernel launch with the specified size and parameters.";
     let details = [
         "If a queue is not specified, kernel execution happens synchronously",
         "ArgumentsData may be set to NULL (to indicate no parameters)"
     ];
-    let params = [
-        Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN_OPTIONAL>,
-        Param<"ol_device_handle_t", "Device", "handle of the device to execute on", PARAM_IN>,
-        Param<"ol_symbol_handle_t", "Kernel", "handle of the kernel", PARAM_IN>,
-        Param<"const void*", "ArgumentsData", "pointer to the kernel argument struct", PARAM_IN_OPTIONAL>,
-        Param<"size_t", "ArgumentsSize", "size of the kernel argument struct", PARAM_IN>,
-        Param<"const ol_kernel_launch_size_args_t*", "LaunchSizeArgs", "pointer to the struct containing launch size parameters", PARAM_IN>,
-    ];
+    let params =
+        [Param<"ol_queue_handle_t", "Queue", "handle of the queue",
+               PARAM_IN_OPTIONAL>,
+         Param<"ol_device_handle_t", "Device",
+               "handle of the device to execute on", PARAM_IN>,
+         Param<"ol_symbol_handle_t", "Kernel", "handle of the kernel",
+               PARAM_IN>,
+         Param<"const void*", "ArgumentsData",
+               "pointer to the kernel argument struct", PARAM_IN_OPTIONAL>,
+         Param<"size_t", "ArgumentsSize", "size of the kernel argument struct",
+               PARAM_IN>,
+         Param<"const ol_kernel_launch_size_args_t*", "LaunchSizeArgs",
+               "pointer to the struct containing launch size parameters",
+               PARAM_IN>,
+         Param<"const ol_kernel_launch_prop_t *", "Properties",
+               "Array of optional properties, last element must be "
+               "OL_KERNEL_LAUNCH_PROP_END",
+               PARAM_IN_OPTIONAL>];
     let returns = [
         Return<"OL_ERRC_INVALID_ARGUMENT", ["`ArgumentsSize > 0 && ArgumentsData == NULL`"]>,
         Return<"OL_ERRC_INVALID_DEVICE", ["If Queue is non-null but does not belong to Device"]>,
@@ -57,3 +89,32 @@ def olCalculateOptimalOccupancy : Function {
         Return<"OL_ERRC_UNSUPPORTED", ["The backend cannot provide this information"]>,
     ];
 }
+
+def olGetKernelMaxCooperativeGroupCount : Function {
+  let desc = "Query the maximum number of work groups that can be launched "
+             "cooperatively for a kernel.";
+  let details =
+      ["This function returns the maximum number of work groups that can "
+       "participate in a cooperative launch for the given kernel.",
+       "The maximum count depends on the work group size and dynamic shared "
+       "memory usage.",
+  ];
+  let params = [Param<"ol_device_handle_t", "Device",
+                      "device intended to run the kernel", PARAM_IN>,
+                Param<"ol_symbol_handle_t", "Kernel", "handle of the kernel",
+                      PARAM_IN>,
+                Param<"uint32_t", "WorkDim",
+                      "number of work dimensions (1, 2, or 3)", PARAM_IN>,
+                Param<"const size_t*", "LocalWorkSize",
+                      "local work group size for each dimension", PARAM_IN>,
+                Param<"size_t", "DynamicSharedMemorySize",
+                      "dynamic shared memory size in bytes", PARAM_IN>,
+                Param<"uint32_t*", "MaxGroupCount",
+                      "maximum number of work groups that can be launched "
+                      "cooperatively", PARAM_OUT>];
+  // WorkDim is documented as 1, 2 or 3; anything else is an invalid argument.
+  let returns =
+      [Return<"OL_ERRC_INVALID_ARGUMENT", ["`WorkDim < 1 || WorkDim > 3`"]>,
+       Return<"OL_ERRC_SYMBOL_KIND", ["The provided symbol is not a kernel"]>,
+       Return<
+           "OL_ERRC_UNSUPPORTED", ["Cooperative launch is not supported or "
+                                   "backend cannot provide this information"]>,
+  ];
+}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index dd3ec0f61b4da..2b542b9c49d91 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -496,7 +496,13 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
                        "plugin returned incorrect type");
     return Info.writeString(std::get<std::string>(Entry->Value).c_str());
   }
-
+  case OL_DEVICE_INFO_COOPERATIVE_LAUNCH_SUPPORT: {
+    // Bool value
+    if (!std::holds_alternative<bool>(Entry->Value))
+      return makeError(ErrorCode::BACKEND_FAILURE,
+                       "plugin returned incorrect type");
+    return Info.write(static_cast<uint8_t>(std::get<bool>(Entry->Value)));
+  }
   case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
   case OL_DEVICE_INFO_MAX_WORK_SIZE:
   case OL_DEVICE_INFO_VENDOR_ID:
@@ -1032,10 +1038,34 @@ Error olCalculateOptimalOccupancy_impl(ol_device_handle_t Device,
   return Error::success();
 }
 
+/// Implements olGetKernelMaxCooperativeGroupCount: validates the symbol kind
+/// and work dimensionality, then forwards the query to the plugin kernel.
+Error olGetKernelMaxCooperativeGroupCount_impl(ol_device_handle_t Device,
+                                               ol_symbol_handle_t Kernel,
+                                               uint32_t WorkDim,
+                                               const size_t *LocalWorkSize,
+                                               size_t DynamicSharedMemorySize,
+                                               uint32_t *MaxGroupCount) {
+  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
+    return createOffloadError(ErrorCode::SYMBOL_KIND,
+                              "provided symbol is not a kernel");
+
+  // The API documents WorkDim as 1, 2 or 3 and LocalWorkSize as one entry per
+  // dimension; reject other values before any plugin reads LocalWorkSize.
+  if (WorkDim < 1 || WorkDim > 3)
+    return createOffloadError(ErrorCode::INVALID_ARGUMENT,
+                              "WorkDim must be 1, 2 or 3, got %u", WorkDim);
+
+  auto *DeviceImpl = Device->Device;
+  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
+
+  auto Res = KernelImpl->getMaxCooperativeGroupCount(
+      *DeviceImpl, WorkDim, LocalWorkSize, DynamicSharedMemorySize);
+  if (!Res)
+    return Res.takeError();
+
+  *MaxGroupCount = *Res;
+  return Error::success();
+}
+
 Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                           ol_symbol_handle_t Kernel, const void *ArgumentsData,
                           size_t ArgumentsSize,
-                          const ol_kernel_launch_size_args_t *LaunchSizeArgs) {
+                          const ol_kernel_launch_size_args_t *LaunchSizeArgs,
+                          const ol_kernel_launch_prop_t *Properties) {
   auto *DeviceImpl = Device->Device;
   if (Queue && Device != Queue->Device) {
     return createOffloadError(
@@ -1048,7 +1078,6 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                               "provided symbol is not a kernel");
 
   auto *QueueImpl = Queue ? Queue->AsyncInfo : nullptr;
-  AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
   KernelArgsTy LaunchArgs{};
   LaunchArgs.NumTeams[0] = LaunchSizeArgs->NumGroups.x;
   LaunchArgs.NumTeams[1] = LaunchSizeArgs->NumGroups.y;
@@ -1058,6 +1087,25 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
   LaunchArgs.ThreadLimit[2] = LaunchSizeArgs->GroupSize.z;
   LaunchArgs.DynCGroupMem = LaunchSizeArgs->DynSharedMemory;
 
+  // Consume the optional property array up to the OL_KERNEL_LAUNCH_PROP_END
+  // sentinel (type NONE).
+  for (; Properties && Properties->type != OL_KERNEL_LAUNCH_PROP_TYPE_NONE;
+       ++Properties) {
+    // Property data is read through its pointer (here or by the plugin), so
+    // a null pointer is rejected instead of being dereferenced.
+    if (!Properties->data)
+      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
+                                "olLaunchKernel property data must not be null");
+    switch (Properties->type) {
+    case OL_KERNEL_LAUNCH_PROP_TYPE_SIZE:
+      // NOTE(review): the API types this data as size_t *, while
+      // KernelArgsTy::ArgSizes is int64_t *; this assumes both are 64-bit.
+      LaunchArgs.ArgSizes = static_cast<int64_t *>(Properties->data);
+      break;
+    case OL_KERNEL_LAUNCH_PROP_TYPE_IS_COOPERATIVE:
+      LaunchArgs.Flags.Cooperative =
+          *static_cast<const bool *>(Properties->data);
+      break;
+    default:
+      return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+                                "olLaunchKernel property enum '%i' is invalid",
+                                static_cast<int>(Properties->type));
+    }
+  }
+
+  AsyncInfoWrapperTy AsyncInfoWrapper(*DeviceImpl, QueueImpl);
   KernelLaunchParamsTy Params;
   Params.Data = const_cast<void *>(ArgumentsData);
   Params.Size = ArgumentsSize;
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 379c8ec11225d..443b703820c93 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -586,6 +586,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
         "occupancy calculations for AMDGPU are not yet implemented");
   }
 
+  /// Cooperative launches are not implemented by the AMDGPU plugin yet, so
+  /// this query always fails with ErrorCode::UNSUPPORTED.
+  Expected<uint32_t>
+  getMaxCooperativeGroupCount(GenericDeviceTy &GenericDevice, uint32_t WorkDim,
+                              const size_t *LocalWorkSize,
+                              size_t DynamicSharedMemorySize) const override {
+    return Plugin::error(
+        ErrorCode::UNSUPPORTED, "cooperative launch not supported for AMDGPU");
+  }
+
   /// Print more elaborate kernel launch info for AMDGPU
   Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
                                KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
@@ -3737,6 +3746,11 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                                  KernelArgsTy &KernelArgs,
                                  KernelLaunchParamsTy LaunchParams,
                                  AsyncInfoWrapperTy &AsyncInfoWrapper) const {
+  // Cooperative kernel launch is not yet supported for AMDGPU
+  if (KernelArgs.Flags.Cooperative)
+    return Plugin::error(ErrorCode::UNSUPPORTED,
+                         "cooperative kernel launch not supported for AMDGPU");
+
   AMDGPUPluginTy &AMDGPUPlugin =
       static_cast<AMDGPUPluginTy &>(GenericDevice.Plugin);
   AMDHostDeviceTy &HostDevice = AMDGPUPlugin.getHostDevice();
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 1c59ed1eda841..a28ef6d9287e2 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -370,6 +370,12 @@ struct GenericKernelTy {
   virtual Expected<uint64_t> maxGroupSize(GenericDeviceTy &GenericDevice,
                                           uint64_t DynamicMemSize) const = 0;
 
+  /// Query how many work groups of this kernel may be launched cooperatively
+  /// on \p GenericDevice, for a \p WorkDim-dimensional group shape given by
+  /// \p LocalWorkSize and \p DynamicSharedMemorySize bytes of dynamic shared
+  /// memory. Backends without support return an error (AMDGPU: UNSUPPORTED).
+  virtual Expected<uint32_t> getMaxCooperativeGroupCount(
+      GenericDeviceTy &GenericDevice, uint32_t WorkDim,
+      const size_t *LocalWorkSize, size_t DynamicSharedMemorySize) const = 0;
+
   /// Get the kernel name.
   const char *getName() const { return Name.c_str(); }
 
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
index 80e3e418ae3fa..16205001a035c 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
@@ -43,6 +43,7 @@ DLWRAP(cuDriverGetVersion, 1)
 
 DLWRAP(cuGetErrorString, 2)
 DLWRAP(cuLaunchKernel, 11)
+DLWRAP(cuLaunchKernelEx, 4) // Arity matches the 4-parameter cuda.h prototype.
 DLWRAP(cuLaunchHostFunc, 3)
 
 DLWRAP(cuMemAlloc, 2)
@@ -83,6 +84,7 @@ DLWRAP(cuDevicePrimaryCtxSetFlags, 2)
 DLWRAP(cuDevicePrimaryCtxRetain, 2)
 DLWRAP(cuModuleLoadDataEx, 5)
 DLWRAP(cuOccupancyMaxPotentialBlockSize, 6)
+// Arity matches the 4-parameter prototype declared in cuda.h.
+DLWRAP(cuOccupancyMaxActiveBlocksPerMultiprocessor, 4)
 DLWRAP(cuFuncGetParamInfo, 4)
 
 DLWRAP(cuDeviceCanAccessPeer, 3)
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
index 7e42c66dddabb..f4087c87048c9 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
@@ -295,6 +295,31 @@ static inline void *CU_LAUNCH_PARAM_BUFFER_SIZE = (void *)0x02;
 typedef void (*CUstreamCallback)(CUstream, CUresult, void *);
 typedef size_t (*CUoccupancyB2DSize)(int);
 
+// Launch configuration structures for cuLaunchKernelEx. These are passed
+// straight to the driver, so their layout must mirror the CUDA 12 driver ABI.
+typedef enum CUlaunchAttributeID_enum {
+  CU_LAUNCH_ATTRIBUTE_COOPERATIVE = 2,
+} CUlaunchAttributeID;
+
+typedef struct CUlaunchAttribute_st {
+  CUlaunchAttributeID id;
+  // The driver expects `value` at offset 8 ...
+  char pad[8 - sizeof(CUlaunchAttributeID)];
+  // ... and reads it as the 64-byte CUlaunchAttributeValue union.
+  union {
+    char pad[64];
+    int cooperative;
+  } value;
+} CUlaunchAttribute;
+
+typedef struct CUlaunchConfig_st {
+  unsigned int gridDimX;
+  unsigned int gridDimY;
+  unsigned int gridDimZ;
+  unsigned int blockDimX;
+  unsigned int blockDimY;
+  unsigned int blockDimZ;
+  unsigned int sharedMemBytes;
+  CUstream hStream;
+  CUlaunchAttribute *attrs;
+  unsigned int numAttrs;
+} CUlaunchConfig;
+
 CUresult cuCtxGetDevice(CUdevice *);
 CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
@@ -313,6 +338,7 @@ CUresult cuInit(unsigned);
 CUresult cuLaunchKernel(CUfunction, unsigned, unsigned, unsigned, unsigned,
                         unsigned, unsigned, unsigned, CUstream, void **,
                         void **);
+// Used by the CUDA plugin for cooperative launches (CU_LAUNCH_ATTRIBUTE_*).
+CUresult cuLaunchKernelEx(const CUlaunchConfig *, CUfunction, void **, void **);
 CUresult cuLaunchHostFunc(CUstream, CUhostFn, void *);
 
 CUresult cuMemAlloc(CUdeviceptr *, size_t);
@@ -390,6 +416,8 @@ CUresult cuMemGetAllocationGranularity(size_t *granularity,
                                        CUmemAllocationGranularity_flags option);
 CUresult cuOccupancyMaxPotentialBlockSize(int *, int *, CUfunction,
                                           CUoccupancyB2DSize, size_t, int);
+// Driver occupancy query. NOTE(review): argument order assumed to follow the
+// CUDA driver API (numBlocks out, func, blockSize, dynamicSMemSize).
+CUresult cuOccupancyMaxActiveBlocksPerMultiprocessor(int *, CUfunction, int,
+                                                     size_t);
 CUresult cuFuncGetParamInfo(CUfunction, size_t, size_t *, size_t *);
 
 #endif
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index d5ab0b3309c86..153a4eb964226 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -174,6 +174,12 @@ struct CUDAKernelTy : public GenericKernelTy {
     return MaxBlockSize;
   }
 
+  /// Maximum number of work groups of this kernel that may take part in a
+  /// cooperative launch for the given group shape and dynamic shared memory.
+  Expected<uint32_t> getMaxCooperativeGroupCount(
+      GenericDeviceTy &GenericDevice, uint32_t WorkDim,
+      const size_t *LocalWorkSize,
+      size_t DynamicSharedMemorySize) const override;
+
 private:
   /// Initialize the size of the arguments.
   Error initArgsSize() {
@@ -1257,8 +1263,21 @@ struct CUDADeviceTy : public GenericDeviceTy {
       Info.add("Preemption Supported", (bool)TmpInt);
 
     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, TmpInt);
-    if (Res == CUDA_SUCCESS)
-      Info.add("Cooperative Launch", (bool)TmpInt);
+    if (Res == CUDA_SUCCESS) {
+      // Cooperative launches go through cuLaunchKernelEx: the older
+      // cuLaunchCooperativeKernel API cannot receive the kernel arguments as
+      // a single buffer (the `extra` launch config this plugin uses).
+      // cuLaunchKernelEx and CUlaunchConfig first appeared in CUDA 12.0, so
+      // only advertise the capability when the driver is at least 12.0.
+      // Query the driver directly instead of parsing the DRIVER_VERSION
+      // string, which could throw from std::stoi on unexpected contents.
+      int DriverVersion = 0;
+      if (cuDriverGetVersion(&DriverVersion) != CUDA_SUCCESS)
+        DriverVersion = 0;
+      bool SupportsCooperative = (bool)TmpInt && DriverVersion >= 12000;
+      Info.add("Cooperative Launch", SupportsCooperative, "",
+               DeviceInfo::COOPERATIVE_LAUNCH_SUPPORT);
+    }
 
     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD, TmpInt);
     if (Res == CUDA_SUCCESS)
@@ -1495,9 +1514,44 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
     MaxDynCGroupMemLimit = MaxDynCGroupMem;
   }
 
-  CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
-                                NumThreads[0], NumThreads[1], NumThreads[2],
-                                MaxDynCGroupMem, Stream, nullptr, Config);
+  CUresult Res;
+  if (KernelArgs.Flags.Cooperative) {
+    CUDADeviceTy &CUDADevice = static_cast<CUDADeviceTy &>(GenericDevice);
+
+    uint32_t SupportsCooperative = 0;
+    if (auto Err = CUDADevice.getDeviceAttr(
+            CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, SupportsCooperative))
+      return Err;
+
+    if (!SupportsCooperative) {
+      return Plugin::error(ErrorCode::UNSUPPORTED,
+                           "Device does not support cooperative launch");
+    }
+
+    // Value-initialize the launch structures so padding and any fields or
+    // union bytes not assigned below reach the driver as zero.
+    CUlaunchAttribute CoopAttr{};
+    CoopAttr.id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
+    CoopAttr.value.cooperative = 1;
+
+    CUlaunchConfig LaunchConfig{};
+    LaunchConfig.gridDimX = NumBlocks[0];
+    LaunchConfig.gridDimY = NumBlocks[1];
+    LaunchConfig.gridDimZ = NumBlocks[2];
+    LaunchConfig.blockDimX = NumThreads[0];
+    LaunchConfig.blockDimY = NumThreads[1];
+    LaunchConfig.blockDimZ = NumThreads[2];
+    LaunchConfig.sharedMemBytes = MaxDynCGroupMem;
+    LaunchConfig.hStream = Stream;
+    LaunchConfig.attrs = &CoopAttr;
+    LaunchConfig.numAttrs = 1;
+
+    // As in the cuLaunchKernel path below, the arguments are passed through
+    // the `extra` parameter (Config); kernelParams stays null.
+    Res = cuLaunchKernelEx(&LaunchConfig, Func, nullptr, Config);
+  } else {
+    // Use regular cuLaunchKernel for non-cooperative launches
+    Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
+                         NumThreads[0], NumThreads[1], NumThreads[2],
+                         MaxDynCGroupMem, Stream, nullptr, Config);
+  }
 
   // Register a callback to indicate when the kernel is complete.
   if (GenericDev...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/184343


More information about the llvm-commits mailing list