[llvm] [offload] Add properties parameter to olLaunchKernel (PR #184343)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 3 09:56:22 PST 2026


================
@@ -1512,6 +1566,71 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   return Plugin::check(Res, "error in cuLaunchKernel for '%s': %s", getName());
 }
 
+Expected<uint32_t> CUDAKernelTy::getMaxCooperativeGroupCount(
+    GenericDeviceTy &GenericDevice, uint32_t WorkDim,
+    const size_t *LocalWorkSize, size_t DynamicSharedMemorySize) const {
+  CUDADeviceTy &CUDADevice = static_cast<CUDADeviceTy &>(GenericDevice);
+
+  uint32_t SupportsCooperative = 0;
+  if (auto Err = CUDADevice.getDeviceAttr(
+          CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, SupportsCooperative))
+    return Err;
+
+  if (!SupportsCooperative) {
+    return Plugin::error(ErrorCode::UNSUPPORTED,
+                         "Device does not support cooperative launch");
+  }
+
+  // Calculate total local work size
+  size_t LocalWorkSizeTotal = LocalWorkSize[0];
+  LocalWorkSizeTotal *= (WorkDim >= 2 ? LocalWorkSize[1] : 1);
+  LocalWorkSizeTotal *= (WorkDim == 3 ? LocalWorkSize[2] : 1);
+
+  // Query max active blocks per multiprocessor
+  int MaxNumActiveGroupsPerCU = 0;
+  CUresult Res = cuOccupancyMaxActiveBlocksPerMultiprocessor(
+      &MaxNumActiveGroupsPerCU, Func, LocalWorkSizeTotal,
+      DynamicSharedMemorySize);
+  if (auto Err = Plugin::check(
+          Res, "error in cuOccupancyMaxActiveBlocksPerMultiprocessor: %s"))
+    return Err;
+
+  assert(MaxNumActiveGroupsPerCU >= 0);
+
+  // Handle the case where we can't have all SMs active with at least 1 group
+  // per SM. In that case, the device is still able to run 1 work-group, hence
+  // we will manually check if it is possible with the available HW resources.
+  if (MaxNumActiveGroupsPerCU == 0) {
+    // Get max threads per block for this kernel
+    int MaxThreads;
----------------
jhuber6 wrote:

What type does the API define these sizes as? I'd usually expect a more explicit `uint32_t`, with the runtime stating that each element the pointer refers to is always 4 bytes, or something along those lines.

https://github.com/llvm/llvm-project/pull/184343


More information about the llvm-commits mailing list