[Mlir-commits] [mlir] 7fa19e6 - [MLIR] Add SyclRuntimeWrapper (#69648)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 26 10:41:13 PDT 2023
Author: Nishant Patel
Date: 2023-10-26T19:41:09+02:00
New Revision: 7fa19e6f4b87623b0ca1a23bf6b6293c1b5e5799
URL: https://github.com/llvm/llvm-project/commit/7fa19e6f4b87623b0ca1a23bf6b6293c1b5e5799
DIFF: https://github.com/llvm/llvm-project/commit/7fa19e6f4b87623b0ca1a23bf6b6293c1b5e5799.diff
LOG: [MLIR] Add SyclRuntimeWrapper (#69648)
Added:
mlir/cmake/modules/FindLevelZero.cmake
mlir/cmake/modules/FindSyclRuntime.cmake
mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
Modified:
mlir/CMakeLists.txt
mlir/lib/ExecutionEngine/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index ac120aad0d1eda7..16ff950089734b7 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -126,6 +126,7 @@ add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ENABLE_ROCM_CONVERSIONS})
set(MLIR_ENABLE_DEPRECATED_GPU_SERIALIZATION 0 CACHE BOOL "Enable deprecated GPU serialization passes")
set(MLIR_ENABLE_CUDA_RUNNER 0 CACHE BOOL "Enable building the mlir CUDA runner")
set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "Enable building the mlir ROCm runner")
+set(MLIR_ENABLE_SYCL_RUNNER 0 CACHE BOOL "Enable building the mlir Sycl runner")
set(MLIR_ENABLE_SPIRV_CPU_RUNNER 0 CACHE BOOL "Enable building the mlir SPIR-V cpu runner")
set(MLIR_ENABLE_VULKAN_RUNNER 0 CACHE BOOL "Enable building the mlir Vulkan runner")
set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
diff --git a/mlir/cmake/modules/FindLevelZero.cmake b/mlir/cmake/modules/FindLevelZero.cmake
new file mode 100644
index 000000000000000..012187f0afc0b07
--- /dev/null
+++ b/mlir/cmake/modules/FindLevelZero.cmake
@@ -0,0 +1,221 @@
+# CMake find_package() module for level-zero
+#
+# Example usage:
+#
+# find_package(LevelZero)
+#
+# If successful, the following variables will be defined:
+# LevelZero_FOUND
+# LevelZero_INCLUDE_DIRS
+# LevelZero_LIBRARY
+# LevelZero_LIBRARIES_DIR
+#
+# By default, the module searches the standard paths to locate the "ze_api.h"
+# header and the ze_loader shared library. When using a custom level-zero
+# installation, set either the CMake variable or the environment variable
+# "LEVEL_ZERO_DIR" so that the module takes the level-zero library and headers
+# from that location. The CMake variable takes priority over the environment.
+
+include(FindPackageHandleStandardArgs)
+
+# Search path priority
+# 1. CMake Variable LEVEL_ZERO_DIR
+# 2. Environment Variable LEVEL_ZERO_DIR
+
+if(NOT LEVEL_ZERO_DIR)
+  if(DEFINED ENV{LEVEL_ZERO_DIR})
+    set(LEVEL_ZERO_DIR "$ENV{LEVEL_ZERO_DIR}")
+  endif()
+endif()
+
+if(LEVEL_ZERO_DIR)
+  # A custom installation was requested: look only inside it and never fall
+  # back to the system locations.
+  find_path(LevelZero_INCLUDE_DIR
+    NAMES level_zero/ze_api.h
+    PATHS "${LEVEL_ZERO_DIR}/include"
+    NO_DEFAULT_PATH
+  )
+
+  set(LevelZero_LIBRARY_SEARCH_DIRS "${LEVEL_ZERO_DIR}/lib")
+  # The LINUX variable only exists in CMake 3.25 and newer, so also test
+  # CMAKE_SYSTEM_NAME; otherwise the multiarch directory would silently be
+  # skipped with older CMake releases.
+  if(LINUX OR CMAKE_SYSTEM_NAME STREQUAL "Linux")
+    list(APPEND LevelZero_LIBRARY_SEARCH_DIRS
+      "${LEVEL_ZERO_DIR}/lib/x86_64-linux-gnu"
+    )
+  endif()
+
+  find_library(LevelZero_LIBRARY
+    NAMES ze_loader
+    PATHS ${LevelZero_LIBRARY_SEARCH_DIRS}
+    NO_DEFAULT_PATH
+  )
+  unset(LevelZero_LIBRARY_SEARCH_DIRS)
+else()
+  # No custom location given: rely on the default CMake search paths.
+  find_path(LevelZero_INCLUDE_DIR
+    NAMES level_zero/ze_api.h
+  )
+
+  find_library(LevelZero_LIBRARY
+    NAMES ze_loader
+  )
+endif()
+
+# Compares two version strings that are expected to be in x.y.z format and
+# reports whether VERSION_STR1 is greater than or equal to VERSION_STR2.
+#
+# Arguments:
+#   VERSION_STR1 - the version being checked.
+#   VERSION_STR2 - the version to compare against.
+#   OUTPUT       - name of the variable that is set to TRUE or FALSE in the
+#                  caller's scope.
+#
+# The comparison is delegated to CMake's VERSION_GREATER_EQUAL operator, which
+# compares the versions component by component as integers and treats
+# components missing from the shorter string as zero (so "1.2" equals "1.2.0"
+# and is lower than "1.2.3").
+function(compare_versions VERSION_STR1 VERSION_STR2 OUTPUT)
+  if("${VERSION_STR1}" VERSION_GREATER_EQUAL "${VERSION_STR2}")
+    set(${OUTPUT} TRUE PARENT_SCOPE)
+  else()
+    set(${OUTPUT} FALSE PARENT_SCOPE)
+  endif()
+endfunction()
+
+# Compiles and runs a small program against the Level Zero loader and stores
+# the loader version it prints ("major.minor.patch") in LevelZero_VERSION in
+# the caller's scope.
+#
+# Takes no arguments; reads LevelZero_INCLUDE_DIR and LevelZero_LIBRARY.
+# Stops the configure step with FATAL_ERROR when the program cannot be
+# compiled or when it does not exit successfully (for example because zeInit
+# fails on this machine), since its output is not a version in that case.
+function(get_l0_loader_version)
+
+  set(L0_VERSIONEER_SRC
+    [====[
+    #include <iostream>
+    #include <level_zero/loader/ze_loader.h>
+    #include <string>
+    int main() {
+      ze_result_t result;
+      std::string loader("loader");
+      zel_component_version_t *versions;
+      size_t size = 0;
+      result = zeInit(0);
+      if (result != ZE_RESULT_SUCCESS) {
+        std::cerr << "Failed to init ze driver" << std::endl;
+        return -1;
+      }
+      zelLoaderGetVersions(&size, nullptr);
+      versions = new zel_component_version_t[size];
+      zelLoaderGetVersions(&size, versions);
+      for (size_t i = 0; i < size; i++) {
+        if (loader.compare(versions[i].component_name) == 0) {
+          std::cout << versions[i].component_lib_version.major << "."
+                    << versions[i].component_lib_version.minor << "."
+                    << versions[i].component_lib_version.patch;
+          break;
+        }
+      }
+      delete[] versions;
+      return 0;
+    }
+    ]====]
+  )
+
+  set(L0_VERSIONEER_FILE "${CMAKE_BINARY_DIR}/temp/l0_versioneer.cpp")
+
+  file(WRITE "${L0_VERSIONEER_FILE}" "${L0_VERSIONEER_SRC}")
+
+  # We need both the directories in the include path as ze_loader.h
+  # includes "ze_api.h" and not "level_zero/ze_api.h".
+  # set() is used instead of list(APPEND) so that an INCLUDE_DIRS variable
+  # inherited from the calling scope cannot leak into the probe.
+  set(INCLUDE_DIRS
+    "${LevelZero_INCLUDE_DIR}"
+    "${LevelZero_INCLUDE_DIR}/level_zero"
+  )
+  try_run(L0_VERSIONEER_RUN L0_VERSIONEER_COMPILE
+    "${CMAKE_BINARY_DIR}"
+    "${L0_VERSIONEER_FILE}"
+    LINK_LIBRARIES ${LevelZero_LIBRARY}
+    CMAKE_FLAGS
+      "-DINCLUDE_DIRECTORIES=${INCLUDE_DIRS}"
+    RUN_OUTPUT_VARIABLE L0_VERSION
+  )
+
+  if(NOT L0_VERSIONEER_COMPILE)
+    message(FATAL_ERROR
+      "Could not compile a level-zero program to extract loader version"
+    )
+  endif()
+
+  # The run result holds the exit code of the program (or a FAILED_TO_RUN
+  # marker). Anything other than 0 means the captured output is an error
+  # message rather than a version string.
+  if(NOT "${L0_VERSIONEER_RUN}" STREQUAL "0")
+    message(FATAL_ERROR
+      "Could not run a level-zero program to extract loader version: "
+      "${L0_VERSION}"
+    )
+  endif()
+
+  string(STRIP "${L0_VERSION}" L0_VERSION)
+  set(LevelZero_VERSION "${L0_VERSION}" PARENT_SCOPE)
+  message(STATUS "Found Level Zero of version: ${L0_VERSION}")
+endfunction()
+
+# Once both the header directory and the loader library have been located,
+# publish the result variables and define the LevelZero::LevelZero target.
+if(LevelZero_INCLUDE_DIR AND LevelZero_LIBRARY)
+  list(APPEND LevelZero_LIBRARIES "${LevelZero_LIBRARY}")
+  list(APPEND LevelZero_INCLUDE_DIRS ${LevelZero_INCLUDE_DIR})
+  if(OpenCL_FOUND)
+    list(APPEND LevelZero_INCLUDE_DIRS ${OpenCL_INCLUDE_DIRS})
+  endif()
+
+  # LevelZero_LIBRARIES_DIR is the directory that contains the loader library.
+  cmake_path(GET LevelZero_LIBRARY PARENT_PATH LevelZero_LIBRARIES_PATH)
+  set(LevelZero_LIBRARIES_DIR ${LevelZero_LIBRARIES_PATH})
+
+  if(NOT TARGET LevelZero::LevelZero)
+    add_library(LevelZero::LevelZero INTERFACE IMPORTED)
+    set_target_properties(LevelZero::LevelZero PROPERTIES
+      INTERFACE_LINK_LIBRARIES "${LevelZero_LIBRARIES}"
+      INTERFACE_INCLUDE_DIRECTORIES "${LevelZero_INCLUDE_DIRS}"
+    )
+  endif()
+endif()
+
+# Check if a specific version of Level Zero is required.
+# LevelZero_FOUND is set here as a preliminary verdict:
+#   - no version requested: TRUE;
+#   - version requested and header + library located: TRUE only when the
+#     loader version is at least the requested one;
+#   - version requested but header or library missing: FALSE. The version is
+#     not probed in that case because get_l0_loader_version() aborts the
+#     configure step when its test program cannot be compiled.
+if(NOT LevelZero_FIND_VERSION)
+  set(LevelZero_FOUND TRUE)
+elseif(LevelZero_INCLUDE_DIR AND LevelZero_LIBRARY)
+  get_l0_loader_version()
+  set(VERSION_GT_FIND_VERSION FALSE)
+  # Quoted so that an empty version string is still passed as one argument.
+  compare_versions(
+    "${LevelZero_VERSION}"
+    "${LevelZero_FIND_VERSION}"
+    VERSION_GT_FIND_VERSION
+  )
+  if(VERSION_GT_FIND_VERSION)
+    set(LevelZero_FOUND TRUE)
+  else()
+    set(LevelZero_FOUND FALSE)
+  endif()
+else()
+  set(LevelZero_FOUND FALSE)
+endif()
+
+# Let the standard helper validate the result variables and set the final
+# value of LevelZero_FOUND.
+find_package_handle_standard_args(LevelZero
+  REQUIRED_VARS
+    LevelZero_FOUND
+    LevelZero_INCLUDE_DIRS
+    LevelZero_LIBRARY
+    LevelZero_LIBRARIES_DIR
+  HANDLE_COMPONENTS
+)
+# Only cache entries can be marked as advanced. The entries created by
+# find_path()/find_library() in this module are LevelZero_INCLUDE_DIR and
+# LevelZero_LIBRARY (LevelZero_INCLUDE_DIRS is a normal variable).
+mark_as_advanced(LevelZero_LIBRARY LevelZero_INCLUDE_DIR)
+
+if(LevelZero_FOUND)
+  find_package_message(LevelZero "Found LevelZero: ${LevelZero_LIBRARY}"
+    "(found version ${LevelZero_VERSION})"
+  )
+else()
+  find_package_message(LevelZero "Could not find LevelZero" "")
+endif()
diff --git a/mlir/cmake/modules/FindSyclRuntime.cmake b/mlir/cmake/modules/FindSyclRuntime.cmake
new file mode 100644
index 000000000000000..38b065a3f284c2c
--- /dev/null
+++ b/mlir/cmake/modules/FindSyclRuntime.cmake
@@ -0,0 +1,68 @@
+# CMake find_package() module for SYCL Runtime
+#
+# Example usage:
+#
+# find_package(SyclRuntime)
+#
+# If successful, the following variables will be defined:
+# SyclRuntime_FOUND
+# SyclRuntime_INCLUDE_DIRS
+# SyclRuntime_LIBRARY
+# SyclRuntime_LIBRARIES_DIR
+#
+
+include(FindPackageHandleStandardArgs)
+
+# The SYCL runtime is located relative to the CMPLR_ROOT environment variable.
+# Without it nothing is searched and the package is reported as not found.
+if(NOT DEFINED ENV{CMPLR_ROOT})
+  message(WARNING "Please make sure to install Intel DPC++ Compiler and run setvars.(sh/bat)")
+  message(WARNING "You can download standalone Intel DPC++ Compiler from https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#compilers")
+else()
+  # LINUX is only defined by CMake 3.25 and newer, hence the additional
+  # CMAKE_SYSTEM_NAME test. The variable name is given unexpanded so that
+  # if() dereferences it exactly once.
+  if(LINUX OR CMAKE_SYSTEM_NAME MATCHES "Linux")
+    set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/linux")
+  elseif(WIN32)
+    set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/windows")
+  endif()
+  list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include")
+  list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include/sycl")
+
+  set(SyclRuntime_LIBRARY_DIR "${SyclRuntime_ROOT}/lib")
+  # SyclRuntime_LIBRARIES_DIR is the name listed in this module's header as a
+  # result variable; define it as well so that users of the documented name
+  # do not read an empty value.
+  set(SyclRuntime_LIBRARIES_DIR "${SyclRuntime_LIBRARY_DIR}")
+
+  message(STATUS "SyclRuntime_LIBRARY_DIR: ${SyclRuntime_LIBRARY_DIR}")
+  find_library(SyclRuntime_LIBRARY
+    NAMES sycl
+    PATHS ${SyclRuntime_LIBRARY_DIR}
+    NO_DEFAULT_PATH
+  )
+endif()
+
+# Record whether the sycl library was located and, if so, expose it through
+# the SyclRuntime::SyclRuntime imported target.
+if(NOT SyclRuntime_LIBRARY)
+  set(SyclRuntime_FOUND FALSE)
+else()
+  set(SyclRuntime_FOUND TRUE)
+  if(NOT TARGET SyclRuntime::SyclRuntime)
+    add_library(SyclRuntime::SyclRuntime INTERFACE IMPORTED)
+    set_target_properties(SyclRuntime::SyclRuntime PROPERTIES
+      INTERFACE_LINK_LIBRARIES "${SyclRuntime_LIBRARY}"
+      INTERFACE_INCLUDE_DIRECTORIES "${SyclRuntime_INCLUDE_DIRS}"
+    )
+  endif()
+endif()
+
+# Validate the result variables and report the outcome of the search.
+find_package_handle_standard_args(SyclRuntime
+  REQUIRED_VARS
+    SyclRuntime_FOUND
+    SyclRuntime_INCLUDE_DIRS
+    SyclRuntime_LIBRARY
+    SyclRuntime_LIBRARY_DIR
+  HANDLE_COMPONENTS
+)
+
+mark_as_advanced(SyclRuntime_LIBRARY SyclRuntime_INCLUDE_DIRS)
+
+if(NOT SyclRuntime_FOUND)
+  find_package_message(SyclRuntime "Could not find SyclRuntime" "")
+else()
+  find_package_message(SyclRuntime "Found SyclRuntime: ${SyclRuntime_LIBRARY}" "")
+endif()
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index ea33c2c6ed261e1..fdc797763ae3a41 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -12,6 +12,7 @@ set(LLVM_OPTIONAL_SOURCES
RunnerUtils.cpp
OptUtils.cpp
JitRunner.cpp
+ SyclRuntimeWrappers.cpp
)
# Use a separate library for OptUtils, to avoid pulling in the entire JIT and
@@ -328,4 +329,39 @@ if(LLVM_ENABLE_PIC)
hip::host hip::amdhip64
)
endif()
+
+  # Builds the mlir_sycl_runtime shared library from SyclRuntimeWrappers.cpp.
+  # Both the SYCL runtime and the Level Zero loader are mandatory for it.
+  if(MLIR_ENABLE_SYCL_RUNNER)
+    find_package(SyclRuntime)
+
+    if(NOT SyclRuntime_FOUND)
+      message(FATAL_ERROR "syclRuntime not found. Please check the oneAPI installation and run setvars.sh.")
+    endif()
+
+    find_package(LevelZero)
+
+    if(NOT LevelZero_FOUND)
+      message(FATAL_ERROR "LevelZero not found. Please set LEVEL_ZERO_DIR.")
+    endif()
+
+    add_mlir_library(mlir_sycl_runtime
+      SHARED
+      SyclRuntimeWrappers.cpp
+
+      EXCLUDE_FROM_LIBMLIR
+    )
+
+    # The wrapper library is compiled with exceptions and RTTI enabled.
+    check_cxx_compiler_flag("-frtti" CXX_HAS_FRTTI_FLAG)
+    if(NOT CXX_HAS_FRTTI_FLAG)
+      message(FATAL_ERROR "CXX compiler does not accept flag -frtti")
+    endif()
+    target_compile_options(mlir_sycl_runtime PUBLIC -fexceptions -frtti)
+
+    target_include_directories(mlir_sycl_runtime PRIVATE
+      ${MLIR_INCLUDE_DIRS}
+    )
+
+    target_link_libraries(mlir_sycl_runtime PRIVATE LevelZero::LevelZero SyclRuntime::SyclRuntime)
+
+    # FindSyclRuntime.cmake stores the library directory in
+    # SyclRuntime_LIBRARY_DIR, so that is the variable used for the run path.
+    set_property(TARGET mlir_sycl_runtime APPEND PROPERTY BUILD_RPATH "${LevelZero_LIBRARIES_DIR}" "${SyclRuntime_LIBRARY_DIR}")
+  endif()
endif()
diff --git a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
new file mode 100644
index 000000000000000..c250340c38fc77d
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
@@ -0,0 +1,209 @@
+//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements wrappers around the sycl runtime library with C linkage
+//
+//===----------------------------------------------------------------------===//
+
+#include <CL/sycl.hpp>
+#include <level_zero/ze_api.h>
+#include <sycl/ext/oneapi/backend/level_zero.hpp>
+
+#ifdef _WIN32
+#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
+#else
+#define SYCL_RUNTIME_EXPORT
+#endif // _WIN32
+
+namespace {
+
+// Invokes `func` and returns whatever it returns. If `func` throws, a message
+// is printed to stdout (including e.what() for std::exception), stdout is
+// flushed and the process is terminated with abort(), so no exception
+// propagates out of catchAll.
+template <typename F>
+auto catchAll(F &&func) {
+ try {
+ return func();
+ } catch (const std::exception &e) {
+ fprintf(stdout, "An exception was thrown: %s\n", e.what());
+ fflush(stdout);
+ abort();
+ } catch (...) {
+ fprintf(stdout, "An unknown exception was thrown\n");
+ fflush(stdout);
+ abort();
+ }
+}
+
+// Evaluates the Level Zero call `call` once; if the returned ze_result_t is
+// not ZE_RESULT_SUCCESS, prints the numeric status to stdout, flushes stdout
+// and terminates the process with abort().
+#define L0_SAFE_CALL(call) \
+ { \
+ ze_result_t status = (call); \
+ if (status != ZE_RESULT_SUCCESS) { \
+ fprintf(stdout, "L0 error %d\n", status); \
+ fflush(stdout); \
+ abort(); \
+ } \
+ }
+
+} // namespace
+
+// Returns the device shared by every helper in this file: the first device of
+// the first platform whose name contains "Level-Zero" and that exposes at
+// least one device. Throws std::runtime_error when no such device exists.
+static sycl::device getDefaultDevice() {
+  // A function-local static initialised from a lambda is initialised exactly
+  // once, even with concurrent callers. If the lambda throws, the variable
+  // stays uninitialised and the lookup is attempted again on the next call.
+  static sycl::device syclDevice = [] {
+    for (const auto &platform : sycl::platform::get_platforms()) {
+      auto platformName = platform.get_info<sycl::info::platform::name>();
+      if (platformName.find("Level-Zero") == std::string::npos)
+        continue;
+
+      // Skip platforms without devices instead of indexing an empty list.
+      auto devices = platform.get_devices();
+      if (devices.empty())
+        continue;
+
+      return devices.front();
+    }
+    throw std::runtime_error("getDefaultDevice failed");
+  }();
+  return syclDevice;
+}
+
+// Returns the context built for the default device. The context object is
+// created on the first call and the same object is handed out afterwards.
+static sycl::context getDefaultContext() {
+  static sycl::context defaultContext{getDefaultDevice()};
+  return defaultContext;
+}
+
+// Allocates `size` bytes with 64-byte alignment on the default device and
+// context: a shared allocation when `isShared` is true, a device allocation
+// otherwise. Throws std::runtime_error if the allocation returns null.
+// `queue` is not used by this function.
+static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
+  constexpr size_t alignment = 64;
+  void *allocation =
+      isShared ? sycl::aligned_alloc_shared(alignment, size, getDefaultDevice(),
+                                            getDefaultContext())
+               : sycl::aligned_alloc_device(alignment, size, getDefaultDevice(),
+                                            getDefaultContext());
+  if (!allocation)
+    throw std::runtime_error("mem allocation failed!");
+  return allocation;
+}
+
+// Releases `ptr` through sycl::free using the given queue.
+static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
+  sycl::queue &owningQueue = *queue;
+  sycl::free(ptr, owningQueue);
+}
+
+// Creates a Level Zero module from the SPIR-V binary at `data` (`dataSize`
+// bytes) for the default device and context, and returns its handle. Aborts
+// through L0_SAFE_CALL if module creation does not succeed.
+static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
+  assert(data);
+
+  auto nativeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
+      getDefaultDevice());
+  auto nativeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
+      getDefaultContext());
+
+  // Fields that are not assigned below (pNext, pBuildFlags, pConstants) stay
+  // null thanks to the value-initialisation.
+  ze_module_desc_t moduleDesc = {};
+  moduleDesc.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
+  moduleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
+  moduleDesc.inputSize = dataSize;
+  moduleDesc.pInputModule = static_cast<const uint8_t *>(data);
+
+  ze_module_handle_t moduleHandle;
+  L0_SAFE_CALL(zeModuleCreate(nativeContext, nativeDevice, &moduleDesc,
+                              &moduleHandle, nullptr));
+  return moduleHandle;
+}
+
+// Looks up the kernel called `name` in `zeModule` and wraps it in a
+// heap-allocated sycl::kernel bound to the default context. The returned
+// pointer comes from `new`. Aborts through L0_SAFE_CALL if the Level Zero
+// kernel cannot be created.
+static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
+  assert(zeModule);
+  assert(name);
+  ze_kernel_handle_t zeKernel;
+  ze_kernel_desc_t desc = {};
+  // Every Level Zero descriptor has to announce its own structure type;
+  // value-initialisation alone leaves stype at 0.
+  desc.stype = ZE_STRUCTURE_TYPE_KERNEL_DESC;
+  desc.pKernelName = name;
+
+  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
+  sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
+      sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
+                               sycl::bundle_state::executable>(
+          {zeModule}, getDefaultContext());
+
+  auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
+      {kernelBundle, zeKernel}, getDefaultContext());
+  return new sycl::kernel(kernel);
+}
+
+// Submits `kernel` to `queue` over a 3-D nd_range. The local range is
+// (blockZ, blockY, blockX) and the global range is the element-wise product
+// of block and grid sizes in the same z, y, x order. Each of the
+// `paramsCount` entries of `params` is read as a pointer to a pointer-sized
+// value, and that value becomes the kernel argument with the same index.
+// `sharedMemBytes` is not used.
+static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
+                         size_t gridY, size_t gridZ, size_t blockX,
+                         size_t blockY, size_t blockZ, size_t sharedMemBytes,
+                         void **params, size_t paramsCount) {
+  const sycl::range<3> localRange(blockZ, blockY, blockX);
+  const sycl::range<3> globalRange(blockZ * gridZ, blockY * gridY,
+                                   blockX * gridX);
+  const sycl::nd_range<3> launchRange(globalRange, localRange);
+
+  queue->submit([&](sycl::handler &cgh) {
+    for (size_t argIdx = 0; argIdx != paramsCount; ++argIdx) {
+      void *argValue = *static_cast<void **>(params[argIdx]);
+      cgh.set_arg(static_cast<uint32_t>(argIdx), argValue);
+    }
+    cgh.parallel_for(launchRange, *kernel);
+  });
+}
+
+// Wrappers
+
+// Creates a new queue on the default context and device and returns it as a
+// heap-allocated object. Any exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
+  return catchAll(
+      [] { return new sycl::queue(getDefaultContext(), getDefaultDevice()); });
+}
+
+// Deletes a queue object. Any exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
+  catchAll([queue] { delete queue; });
+}
+
+// Allocates `size` bytes through allocDeviceMemory and returns the pointer.
+// `isShared` selects a shared allocation (true) or a device allocation
+// (false); it is forwarded to allocDeviceMemory instead of being replaced by
+// a hard-coded `true`, so callers that ask for device memory now get it.
+// Any exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT void *
+mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
+  return catchAll([&]() {
+    return allocDeviceMemory(queue, static_cast<size_t>(size), isShared);
+  });
+}
+
+// Frees `ptr` through deallocDeviceMemory using `queue`; a null `ptr` is
+// ignored. Any exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
+  catchAll([=] {
+    if (!ptr)
+      return;
+    deallocDeviceMemory(queue, ptr);
+  });
+}
+
+// Builds a Level Zero module from the `gpuBlobSize` bytes at `data` via
+// loadModule and returns its handle. Any exception aborts the process via
+// catchAll.
+extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
+mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
+  return catchAll([=] { return loadModule(data, gpuBlobSize); });
+}
+
+// Returns the kernel called `name` from `module`, as produced by getKernel.
+// Any exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
+mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
+  return catchAll([=] { return getKernel(module, name); });
+}
+
+// Forwards a kernel launch to launchKernel with the given grid and block
+// sizes, queue and argument array. The `extra` argument is ignored. Any
+// exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT void
+mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
+                 size_t blockX, size_t blockY, size_t blockZ,
+                 size_t sharedMemBytes, sycl::queue *queue, void **params,
+                 void ** /*extra*/, size_t paramsCount) {
+  catchAll([&] {
+    launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
+                 sharedMemBytes, params, paramsCount);
+  });
+}
+
+// Calls wait() on `queue`. Any exception aborts the process via catchAll.
+extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
+  catchAll([queue] { queue->wait(); });
+}
+
+// Destroys a Level Zero module handle with zeModuleDestroy. A failing call
+// or any exception aborts the process.
+extern "C" SYCL_RUNTIME_EXPORT void
+mgpuModuleUnload(ze_module_handle_t module) {
+  catchAll([module] { L0_SAFE_CALL(zeModuleDestroy(module)); });
+}
More information about the Mlir-commits
mailing list