[llvm] [offload] Add properties parameter to olLaunchKernel (PR #184343)
Łukasz Plewa via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 4 09:50:26 PST 2026
================
@@ -437,27 +453,79 @@ Error L0KernelTy::launchImpl(GenericDeviceTy &GenericDevice,
auto *AsyncQueue =
IsAsync ? static_cast<AsyncQueueTy *>(AsyncInfo->Queue) : nullptr;
auto &KernelPR = getProperties();
-
- L0LaunchEnvTy KEnv(IsAsync, AsyncQueue, KernelPR);
+ bool IsCooperative = KernelArgs.Flags.Cooperative;
+ L0LaunchEnvTy KEnv(IsAsync, IsCooperative, AsyncQueue, KernelPR);
// Protect from kernel preparation to submission as kernels are shared.
KernelPR.Mtx.lock();
if (auto Err = setKernelGroups(l0Device, KEnv, NumThreads, NumBlocks))
return Err;
+ // Validate cooperative kernel launch constraints
+ if (IsCooperative) {
+ uint32_t MaxCooperativeGroupCount = 0;
+ CALL_ZE_RET_ERROR(zeKernelSuggestMaxCooperativeGroupCount, zeKernel,
+ &MaxCooperativeGroupCount);
+
+ uint32_t TotalGroupCount = KEnv.GroupCounts.groupCountX *
+ KEnv.GroupCounts.groupCountY *
+ KEnv.GroupCounts.groupCountZ;
+
+ if (TotalGroupCount > MaxCooperativeGroupCount) {
+ KernelPR.Mtx.unlock();
+ return Plugin::error(
+ ErrorCode::INVALID_ARGUMENT,
+ "Cooperative kernel launch failed: requested %u groups exceeds "
+ "maximum %u cooperative groups supported by device",
+ TotalGroupCount, MaxCooperativeGroupCount);
+ }
+
+ INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
+ "Cooperative kernel validated: using %u groups (max: %u)\n",
+ TotalGroupCount, MaxCooperativeGroupCount);
+ }
+
// Set kernel arguments.
- for (int32_t I = 0; I < NumArgs; I++) {
+
+ uint32_t NumArgs = Properties.NumKernelArgs;
+
+ std::vector<uint32_t> ArgSizes;
+
+ if (NumArgs > 0) {
+ if (KernelArgs.ArgSizes) {
+ for (uint32_t I = 0; I < NumArgs; I++) {
+ ArgSizes.push_back(KernelArgs.ArgSizes[I]);
+ }
+ } else {
+ if (!Context.ZeKernelArgumentSizeExt.Supported) {
+ return Plugin::error(
+ ErrorCode::INVALID_ARGUMENT,
+ "Level zero plugin requires kernel argument sizes.");
+ }
+ for (uint32_t I = 0; I < NumArgs; I++) {
+ uint32_t ArgSize;
+ CALL_ZE_RET_ERROR(
+ Context.ZeKernelArgumentSizeExt.zexKernelGetArgumentSize, zeKernel,
+ I, &ArgSize);
+ ArgSizes.push_back(ArgSize);
----------------
lplewa wrote:
done
https://github.com/llvm/llvm-project/pull/184343
More information about the llvm-commits
mailing list