[Mlir-commits] [mlir] [mlir] SYCL runtime wrapper: add handler for set default device. (PR #90878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 2 10:15:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-execution-engine

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

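Add the missing support for setting the default device (`mgpuSetDefaultDevice`).
Thread-local storage is used to keep a per-thread default device; the GPU device pool and the default context are likewise stored per thread.
The default context is redefined as a context containing all Level Zero GPU devices.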

---
Full diff: https://github.com/llvm/llvm-project/pull/90878.diff


1 File Affected:

- (modified) mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp (+80-33) 


``````````diff
diff --git a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
index c250340c38fc77..806d49f342d292 100644
--- a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
@@ -11,8 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include <CL/sycl.hpp>
+#include <cstdio>
+#include <cstdlib>
 #include <level_zero/ze_api.h>
 #include <sycl/ext/oneapi/backend/level_zero.hpp>
+#include <vector>
 
 #ifdef _WIN32
 #define SYCL_RUNTIME_EXPORT __declspec(dllexport)
@@ -49,39 +52,53 @@ auto catchAll(F &&func) {
 
 } // namespace
 
-static sycl::device getDefaultDevice() {
-  static sycl::device syclDevice;
-  static bool isDeviceInitialised = false;
-  if (!isDeviceInitialised) {
-    auto platformList = sycl::platform::get_platforms();
-    for (const auto &platform : platformList) {
-      auto platformName = platform.get_info<sycl::info::platform::name>();
-      bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
-      if (!isLevelZero)
-        continue;
-
-      syclDevice = platform.get_devices()[0];
-      isDeviceInitialised = true;
-      return syclDevice;
+thread_local static int32_t defaultDevice = 0;
+thread_local static bool isGpuPoolInitialized = false;
+thread_local static bool isDefaultContextInitialized = false;
+thread_local static std::vector<sycl::device> *pGpuPool = nullptr;
+thread_local static sycl::context *pDefaultContext = nullptr;
+
+static void initGpuPool() {
+  if (isGpuPoolInitialized)
+    return;
+  auto platformList = sycl::platform::get_platforms();
+  for (const auto &platform : platformList) {
+    if (platform.get_backend() == sycl::backend::ext_oneapi_level_zero) {
+      auto gpuDevices = platform.get_devices(sycl::info::device_type::gpu);
+      if (gpuDevices.empty()) {
+        throw std::runtime_error("SyclRuntime: No GPU devices found!");
+      }
+      pGpuPool = new std::vector<sycl::device>{gpuDevices};
+      isGpuPoolInitialized = true;
+      return;
     }
-    throw std::runtime_error("getDefaultDevice failed");
-  } else
-    return syclDevice;
+  }
+  throw std::runtime_error("SyclRuntime: No GPU devices found!");
 }
 
-static sycl::context getDefaultContext() {
-  static sycl::context syclContext{getDefaultDevice()};
-  return syclContext;
+static sycl::device *getDefaultDevicePtr() {
+  initGpuPool();
+  return &((*pGpuPool)[defaultDevice]);
+}
+
+static sycl::context *getDefaultContextPtr() {
+  if (isDefaultContextInitialized) {
+    return pDefaultContext;
+  }
+  initGpuPool();
+  pDefaultContext = new sycl::context(*pGpuPool);
+  isDefaultContextInitialized = true;
+  return pDefaultContext;
 }
 
 static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
   void *memPtr = nullptr;
   if (isShared) {
-    memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
-                                        getDefaultContext());
+    memPtr = sycl::aligned_alloc_shared(64, size, *getDefaultDevicePtr(),
+                                        *getDefaultContextPtr());
   } else {
-    memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
-                                        getDefaultContext());
+    memPtr = sycl::aligned_alloc_device(64, size, *getDefaultDevicePtr(),
+                                        *getDefaultContextPtr());
   }
   if (memPtr == nullptr) {
     throw std::runtime_error("mem allocation failed!");
@@ -90,7 +107,13 @@ static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
 }
 
 static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
-  sycl::free(ptr, *queue);
+  if (queue == nullptr) {
+    queue = new sycl::queue(*getDefaultContextPtr(), *getDefaultDevicePtr());
+    sycl::free(ptr, *queue);
+    delete queue;
+  } else {
+    sycl::free(ptr, *queue);
+  }
 }
 
 static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
@@ -104,9 +127,9 @@ static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
                            nullptr,
                            nullptr};
   auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
-      getDefaultDevice());
+      *getDefaultDevicePtr());
   auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
-      getDefaultContext());
+      *getDefaultContextPtr());
   L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
   return zeModule;
 }
@@ -115,17 +138,33 @@ static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
   assert(zeModule);
   assert(name);
   ze_kernel_handle_t zeKernel;
-  ze_kernel_desc_t desc = {};
-  desc.pKernelName = name;
+  ze_kernel_desc_t desc = {ZE_STRUCTURE_TYPE_KERNEL_DESC, nullptr,
+                           0, // flags
+                           name};
+
+  ze_result_t result = zeKernelCreate(zeModule, &desc, &zeKernel);
+
+  // Check if there are unresolved imports
+  if (result == ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED) {
+    fprintf(stdout, "Unresolved imports!!!\n");
+    fflush(stdout);
+    abort();
+  }
+
+  // Check to see if the kernel name was found in the supplied module
+  if (result == ZE_RESULT_ERROR_INVALID_KERNEL_NAME) {
+    fprintf(stdout, "Invalid kernel name: %s !!!\n", name);
+    fflush(stdout);
+    abort();
+  }
 
-  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
   sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
       sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
                                sycl::bundle_state::executable>(
-          {zeModule}, getDefaultContext());
+          {zeModule}, *getDefaultContextPtr());
 
   auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
-      {kernelBundle, zeKernel}, getDefaultContext());
+      {kernelBundle, zeKernel}, *getDefaultContextPtr());
   return new sycl::kernel(kernel);
 }
 
@@ -152,7 +191,7 @@ extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
 
   return catchAll([&]() {
     sycl::queue *queue =
-        new sycl::queue(getDefaultContext(), getDefaultDevice());
+        new sycl::queue(*getDefaultContextPtr(), *getDefaultDevicePtr());
     return queue;
   });
 }
@@ -207,3 +246,11 @@ mgpuModuleUnload(ze_module_handle_t module) {
 
   catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
 }
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuSetDefaultDevice(int32_t device) {
+  initGpuPool();
+  if (device >= pGpuPool->size()) {
+    throw std::runtime_error("SyclRuntime: Invalid device index!");
+  }
+  defaultDevice = device;
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/90878


More information about the Mlir-commits mailing list