[Mlir-commits] [mlir] [mlir] SYCL runtime wrapper: add handler for set default device. (PR #90878)
Sang Ik Lee
llvmlistbot at llvm.org
Thu May 2 10:15:28 PDT 2024
https://github.com/silee2 created https://github.com/llvm/llvm-project/pull/90878
Add missing support for setting the default device.
Thread-local storage is used to keep a per-thread default device.
The default context is redefined as a context containing all GPU devices.
>From fb762abd6a0ff98129ed502ca97da3af6bdb9dd6 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 2 May 2024 17:11:35 +0000
Subject: [PATCH] [mlir] SYCL runtime wrapper: add handler for set default
device.
---
.../ExecutionEngine/SyclRuntimeWrappers.cpp | 113 +++++++++++++-----
1 file changed, 80 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
index c250340c38fc77..806d49f342d292 100644
--- a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
@@ -11,8 +11,11 @@
//===----------------------------------------------------------------------===//
#include <CL/sycl.hpp>
+#include <cstdio>
+#include <cstdlib>
#include <level_zero/ze_api.h>
#include <sycl/ext/oneapi/backend/level_zero.hpp>
+#include <vector>
#ifdef _WIN32
#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
@@ -49,39 +52,53 @@ auto catchAll(F &&func) {
} // namespace
-static sycl::device getDefaultDevice() {
- static sycl::device syclDevice;
- static bool isDeviceInitialised = false;
- if (!isDeviceInitialised) {
- auto platformList = sycl::platform::get_platforms();
- for (const auto &platform : platformList) {
- auto platformName = platform.get_info<sycl::info::platform::name>();
- bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
- if (!isLevelZero)
- continue;
-
- syclDevice = platform.get_devices()[0];
- isDeviceInitialised = true;
- return syclDevice;
+thread_local static int32_t defaultDevice = 0;
+thread_local static bool isGpuPoolInitialized = false;
+thread_local static bool isDefaultContextInitialized = false;
+thread_local static std::vector<sycl::device> *pGpuPool = nullptr;
+thread_local static sycl::context *pDefaultContext = nullptr;
+
+static void initGpuPool() {
+ if (isGpuPoolInitialized)
+ return;
+ auto platformList = sycl::platform::get_platforms();
+ for (const auto &platform : platformList) {
+ if (platform.get_backend() == sycl::backend::ext_oneapi_level_zero) {
+ auto gpuDevices = platform.get_devices(sycl::info::device_type::gpu);
+ if (gpuDevices.empty()) {
+ throw std::runtime_error("SyclRuntime: No GPU devices found!");
+ }
+ pGpuPool = new std::vector<sycl::device>{gpuDevices};
+ isGpuPoolInitialized = true;
+ return;
}
- throw std::runtime_error("getDefaultDevice failed");
- } else
- return syclDevice;
+ }
+ throw std::runtime_error("SyclRuntime: No GPU devices found!");
}
-static sycl::context getDefaultContext() {
- static sycl::context syclContext{getDefaultDevice()};
- return syclContext;
+static sycl::device *getDefaultDevicePtr() {
+ initGpuPool();
+ return &((*pGpuPool)[defaultDevice]);
+}
+
+static sycl::context *getDefaultContextPtr() {
+ if (isDefaultContextInitialized) {
+ return pDefaultContext;
+ }
+ initGpuPool();
+ pDefaultContext = new sycl::context(*pGpuPool);
+ isDefaultContextInitialized = true;
+ return pDefaultContext;
}
static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
void *memPtr = nullptr;
if (isShared) {
- memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
- getDefaultContext());
+ memPtr = sycl::aligned_alloc_shared(64, size, *getDefaultDevicePtr(),
+ *getDefaultContextPtr());
} else {
- memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
- getDefaultContext());
+ memPtr = sycl::aligned_alloc_device(64, size, *getDefaultDevicePtr(),
+ *getDefaultContextPtr());
}
if (memPtr == nullptr) {
throw std::runtime_error("mem allocation failed!");
@@ -90,7 +107,13 @@ static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
}
static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
- sycl::free(ptr, *queue);
+ if (queue == nullptr) {
+ queue = new sycl::queue(*getDefaultContextPtr(), *getDefaultDevicePtr());
+ sycl::free(ptr, *queue);
+ delete queue;
+ } else {
+ sycl::free(ptr, *queue);
+ }
}
static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
@@ -104,9 +127,9 @@ static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
nullptr,
nullptr};
auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
- getDefaultDevice());
+ *getDefaultDevicePtr());
auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
- getDefaultContext());
+ *getDefaultContextPtr());
L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
return zeModule;
}
@@ -115,17 +138,33 @@ static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
assert(zeModule);
assert(name);
ze_kernel_handle_t zeKernel;
- ze_kernel_desc_t desc = {};
- desc.pKernelName = name;
+ ze_kernel_desc_t desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr,
+ 0, // flags
+ name};
+
+ ze_result_t result = zeKernelCreate(zeModule, &desc, &zeKernel);
+
+ // Check if there are unresolved imports
+ if (result == ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED) {
+ fprintf(stdout, "Unresolved imports!!!\n");
+ fflush(stdout);
+ abort();
+ }
+
+ // Check to see if the kernel name was found in the supplied module
+ if (result == ZE_RESULT_ERROR_INVALID_KERNEL_NAME) {
+ fprintf(stdout, "Invalid kernel name: %s !!!\n", name);
+ fflush(stdout);
+ abort();
+ }
- L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
sycl::bundle_state::executable>(
- {zeModule}, getDefaultContext());
+ {zeModule}, *getDefaultContextPtr());
auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
- {kernelBundle, zeKernel}, getDefaultContext());
+ {kernelBundle, zeKernel}, *getDefaultContextPtr());
return new sycl::kernel(kernel);
}
@@ -152,7 +191,7 @@ extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
return catchAll([&]() {
sycl::queue *queue =
- new sycl::queue(getDefaultContext(), getDefaultDevice());
+ new sycl::queue(*getDefaultContextPtr(), *getDefaultDevicePtr());
return queue;
});
}
@@ -207,3 +246,11 @@ mgpuModuleUnload(ze_module_handle_t module) {
catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
}
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuSetDefaultDevice(int32_t device) {
+ initGpuPool();
+ if (device >= pGpuPool->size()) {
+ throw std::runtime_error("SyclRuntime: Invalid device index!");
+ }
+ defaultDevice = device;
+}
More information about the Mlir-commits
mailing list