[Mlir-commits] [mlir] [MLIR] Add SyclRuntimeWrapper (PR #69648)

Ivan Butygin llvmlistbot at llvm.org
Thu Oct 26 05:46:20 PDT 2023


================
@@ -0,0 +1,209 @@
+//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements wrappers around the sycl runtime library with C linkage
+//
+//===----------------------------------------------------------------------===//
+
+#include <CL/sycl.hpp>
+#include <level_zero/ze_api.h>
+#include <sycl/ext/oneapi/backend/level_zero.hpp>
+
+#ifdef _WIN32
+#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
+#else
+#define SYCL_RUNTIME_EXPORT
+#endif // _WIN32
+
+namespace {
+
+template <typename F>
+auto catchAll(F &&func) {
+  try {
+    return func();
+  } catch (const std::exception &e) {
+    fprintf(stdout, "An exception was thrown: %s\n", e.what());
+    fflush(stdout);
+    abort();
+  } catch (...) {
+    fprintf(stdout, "An unknown exception was thrown\n");
+    fflush(stdout);
+    abort();
+  }
+}
+
+#define L0_SAFE_CALL(call)                                                     \
+  {                                                                            \
+    ze_result_t status = (call);                                               \
+    if (status != ZE_RESULT_SUCCESS) {                                         \
+      fprintf(stdout, "L0 error %d\n", status);                                \
+      fflush(stdout);                                                          \
+      abort();                                                                 \
+    }                                                                          \
+  }
+
+} // namespace
+
+static sycl::device getDefaultDevice() {
+  static sycl::device syclDevice;
+  static bool isDeviceInitialised = false;
+  if (!isDeviceInitialised) {
+    auto platformList = sycl::platform::get_platforms();
+    for (const auto &platform : platformList) {
+      auto platformName = platform.get_info<sycl::info::platform::name>();
+      bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
+      if (!isLevelZero)
+        continue;
+
+      syclDevice = platform.get_devices()[0];
+      isDeviceInitialised = true;
+      return syclDevice;
+    }
+    throw std::runtime_error("getDefaultDevice failed");
+  } else
+    return syclDevice;
+}
+
+static sycl::context getDefaultContext() {
+  static sycl::context syclContext{getDefaultDevice()};
+  return syclContext;
+}
+
+static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
+  void *memPtr = nullptr;
+  if (isShared) {
+    memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
+                                        getDefaultContext());
+  } else {
+    memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
+                                        getDefaultContext());
+  }
+  if (memPtr == nullptr) {
+    throw std::runtime_error("mem allocation failed!");
+  }
+  return memPtr;
+}
+
+static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
+  sycl::free(ptr, *queue);
+}
+
+static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
+  assert(data);
+  ze_module_handle_t zeModule;
+  ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
+                           nullptr,
+                           ZE_MODULE_FORMAT_IL_SPIRV,
+                           dataSize,
+                           (const uint8_t *)data,
+                           nullptr,
+                           nullptr};
+  auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
+      getDefaultDevice());
+  auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
+      getDefaultContext());
+  L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
+  return zeModule;
+}
+
+static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
+  assert(zeModule);
+  assert(name);
+  ze_kernel_handle_t zeKernel;
+  ze_kernel_desc_t desc = {};
+  desc.pKernelName = name;
+
+  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
+  sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
+      sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
+                               sycl::bundle_state::executable>(
+          {zeModule}, getDefaultContext());
+
+  auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
+      {kernelBundle, zeKernel}, getDefaultContext());
+  return new sycl::kernel(kernel);
+}
+
+static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
+                         size_t gridY, size_t gridZ, size_t blockX,
+                         size_t blockY, size_t blockZ, size_t sharedMemBytes,
+                         void **params, size_t paramsCount) {
+  auto syclGlobalRange =
+      sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
+  auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
+  sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
+
+  queue->submit([&](sycl::handler &cgh) {
+    for (size_t i = 0; i < paramsCount; i++) {
+      cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
+    }
+    cgh.parallel_for(syclNdRange, *kernel);
+  });
+}
+
+// Wrappers
+
+extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
+
+  return catchAll([&]() {
+    sycl::queue *queue =
+        new sycl::queue(getDefaultContext(), getDefaultDevice());
+    return queue;
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
+  catchAll([&]() { delete queue; });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void *
+mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
+  return catchAll([&]() {
+    return allocDeviceMemory(queue, static_cast<size_t>(size), true);
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
+  catchAll([&]() {
+    if (ptr) {
+      deallocDeviceMemory(queue, ptr);
+    }
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
+mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
+  return catchAll([&]() { return loadModule(data, gpuBlobSize); });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
+mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
+  return catchAll([&]() { return getKernel(module, name); });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void
+mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
+                 size_t blockX, size_t blockY, size_t blockZ,
+                 size_t sharedMemBytes, sycl::queue *queue, void **params,
+                 void **extra, size_t paramsCount) {
----------------
Hardcode84 wrote:

nit: `extra` is not being used and will result in warning, either do `/*extra*/` or `(void)extra`.

https://github.com/llvm/llvm-project/pull/69648


More information about the Mlir-commits mailing list