[llvm] [offload] Add properties parameter to olLaunchKernel (PR #184343)

Łukasz Plewa via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 4 09:50:26 PST 2026


================
@@ -437,27 +453,79 @@ Error L0KernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   auto *AsyncQueue =
       IsAsync ? static_cast<AsyncQueueTy *>(AsyncInfo->Queue) : nullptr;
   auto &KernelPR = getProperties();
-
-  L0LaunchEnvTy KEnv(IsAsync, AsyncQueue, KernelPR);
+  bool IsCooperative = KernelArgs.Flags.Cooperative;
+  L0LaunchEnvTy KEnv(IsAsync, IsCooperative, AsyncQueue, KernelPR);
 
   // Protect from kernel preparation to submission as kernels are shared.
   KernelPR.Mtx.lock();
 
   if (auto Err = setKernelGroups(l0Device, KEnv, NumThreads, NumBlocks))
     return Err;
 
+  // Validate cooperative kernel launch constraints
+  if (IsCooperative) {
+    uint32_t MaxCooperativeGroupCount = 0;
+    CALL_ZE_RET_ERROR(zeKernelSuggestMaxCooperativeGroupCount, zeKernel,
+                      &MaxCooperativeGroupCount);
+
+    uint32_t TotalGroupCount = KEnv.GroupCounts.groupCountX *
+                               KEnv.GroupCounts.groupCountY *
+                               KEnv.GroupCounts.groupCountZ;
+
+    if (TotalGroupCount > MaxCooperativeGroupCount) {
+      KernelPR.Mtx.unlock();
+      return Plugin::error(
+          ErrorCode::INVALID_ARGUMENT,
+          "Cooperative kernel launch failed: requested %u groups exceeds "
+          "maximum %u cooperative groups supported by device",
+          TotalGroupCount, MaxCooperativeGroupCount);
+    }
+
+    INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
+         "Cooperative kernel validated: using %u groups (max: %u)\n",
+         TotalGroupCount, MaxCooperativeGroupCount);
+  }
+
   // Set kernel arguments.
-  for (int32_t I = 0; I < NumArgs; I++) {
+
+  uint32_t NumArgs = Properties.NumKernelArgs;
+
+  std::vector<uint32_t> ArgSizes;
+
+  if (NumArgs > 0) {
+    if (KernelArgs.ArgSizes) {
+      for (uint32_t I = 0; I < NumArgs; I++) {
+        ArgSizes.push_back(KernelArgs.ArgSizes[I]);
+      }
+    } else {
+      if (!Context.ZeKernelArgumentSizeExt.Supported) {
+        return Plugin::error(
+            ErrorCode::INVALID_ARGUMENT,
+            "Level zero plugin requires kernel argument sizes.");
+      }
+      for (uint32_t I = 0; I < NumArgs; I++) {
+        uint32_t ArgSize;
+        CALL_ZE_RET_ERROR(
+            Context.ZeKernelArgumentSizeExt.zexKernelGetArgumentSize, zeKernel,
+            I, &ArgSize);
+        ArgSizes.push_back(ArgSize);
----------------
lplewa wrote:

done

https://github.com/llvm/llvm-project/pull/184343


More information about the llvm-commits mailing list