[Mlir-commits] [mlir] [mlir][ExecutionEngine] Add LevelZeroRuntimeWrapper. (PR #151038)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Mon Aug 4 14:39:15 PDT 2025


================
@@ -0,0 +1,567 @@
+//===- LevelZeroRuntimeWrappers.cpp - MLIR Level Zero (L0) wrapper library-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements wrappers around the Level Zero (L0) runtime library with C linkage
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Twine.h"
+
+#include "level_zero/ze_api.h"
+#include <cassert>
+#include <cstdint>
+#include <cstdlib>
+#include <deque>
+#include <exception>
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+#include <unordered_set>
+#include <vector>
+
+namespace {
+// Invokes `func` and returns its result. Any exception escaping `func` is
+// reported on stderr and the process is terminated with std::abort(); nothing
+// is rethrown. (Presumably used to keep exceptions from escaping the C-linkage
+// wrappers this file implements — confirm at the call sites.)
+template <typename F>
+auto catchAll(F &&func) {
+  try {
+    return func();
+  } catch (const std::exception &e) {
+    std::cerr << "An exception was thrown: " << e.what() << std::endl;
+    std::abort();
+  } catch (...) {
+    std::cerr << "An unknown exception was thrown." << std::endl;
+    std::abort();
+  }
+}
+
+// Evaluates `call` (an expression yielding ze_result_t) exactly once.
+//
+// On any result other than ZE_RESULT_SUCCESS, it prints the status code and
+// the driver's last error description to stderr, then aborts.
+//
+// The do/while(0) wrapper makes the expansion a single statement. It is
+// therefore safe in an unbraced if/else, and every use must be followed by
+// `;`. The locals have unusual names so they cannot shadow identifiers that
+// appear inside `call`.
+//
+// NOTE(review): no driver handle is available here, so a null handle is passed
+// to zeDriverGetLastErrorDescription. The L0 spec treats a null hDriver as
+// ZE_RESULT_ERROR_INVALID_NULL_HANDLE. For that reason the query's result is
+// checked and the description pointer is never read uninitialized. Confirm
+// that the loader tolerates a null handle.
+//
+// Preprocessor macros are not scoped by the enclosing anonymous namespace.
+#define L0_SAFE_CALL(call)                                                     \
+  do {                                                                         \
+    ze_result_t l0SafeCallStatus = (call);                                     \
+    if (l0SafeCallStatus != ZE_RESULT_SUCCESS) {                               \
+      const char *l0SafeCallErrStr = nullptr;                                  \
+      if (zeDriverGetLastErrorDescription(nullptr, &l0SafeCallErrStr) !=       \
+              ZE_RESULT_SUCCESS ||                                             \
+          !l0SafeCallErrStr)                                                   \
+        l0SafeCallErrStr = "<no error description available>";                \
+      std::cerr << "L0 error " << l0SafeCallStatus << ": "                     \
+                << l0SafeCallErrStr << std::endl;                              \
+      std::abort();                                                            \
+    }                                                                          \
+  } while (0)
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// L0 RT context & device setters
+//===----------------------------------------------------------------------===//
+
+// Returns the L0 GPU driver handle at index `idx`. The default index is 0,
+// i.e. the first available driver.
+//
+// Drivers are enumerated with zeInitDrivers on first use and cached per
+// thread. Later calls with an in-range index return the cached handle without
+// querying the runtime again.
+//
+// Throws std::runtime_error if no driver is found or if `idx` is out of
+// bounds. L0 call failures abort via L0_SAFE_CALL.
+static ze_driver_handle_t getDriver(uint32_t idx = 0) {
+  thread_local static std::vector<ze_driver_handle_t> drivers;
+  thread_local static bool isDriverInitialised{false};
+  if (isDriverInitialised && idx < drivers.size())
+    return drivers[idx];
+
+  ze_init_driver_type_desc_t driverType = {};
+  driverType.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
+  driverType.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
+  driverType.pNext = nullptr;
+  uint32_t driverCount{0};
+  // The first call queries the number of drivers; the second fills in handles.
+  L0_SAFE_CALL(zeInitDrivers(&driverCount, nullptr, &driverType));
+  if (!driverCount)
+    throw std::runtime_error("No L0 drivers found.");
+  drivers.resize(driverCount);
+  L0_SAFE_CALL(zeInitDrivers(&driverCount, drivers.data(), &driverType));
+  // The second call updates `driverCount` to the number of handles actually
+  // written. Shrink the cache so it never holds unfilled (null) slots.
+  drivers.resize(driverCount);
+  isDriverInitialised = true;
+  if (idx >= driverCount)
+    throw std::runtime_error((llvm::Twine("Requested driver idx out-of-bound, "
+                                          "number of available drivers: ") +
+                              std::to_string(driverCount))
+                                 .str());
+  return drivers[idx];
+}
+
+// Returns the L0 device handle at index `devIdx` of the driver at `driverIdx`.
+// By default this is the first device of the first driver.
+//
+// The handle for the most recently requested (driverIdx, devIdx) pair is
+// cached per thread.
+//
+// Throws std::runtime_error if the driver exposes no device, or if `devIdx` is
+// negative or out of bounds. Exceptions from getDriver propagate unchanged.
+static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,
+                                    const int32_t devIdx = 0) {
+  thread_local static ze_device_handle_t l0Device{nullptr};
+  // An explicit "valid" flag replaces a sentinel index, so that no argument
+  // pair can ever hit the cache before a device has actually been queried.
+  thread_local static bool isDeviceCached{false};
+  thread_local static int32_t currDevIdx{0};
+  thread_local static uint32_t currDriverIdx{0};
+  if (isDeviceCached && currDriverIdx == driverIdx && currDevIdx == devIdx)
+    return l0Device;
+  auto driver = getDriver(driverIdx);
+  uint32_t deviceCount{0};
+  L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, nullptr));
+  if (!deviceCount)
+    throw std::runtime_error("getDevice failed: did not find L0 device.");
+  std::vector<ze_device_handle_t> devices(deviceCount);
+  L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
+  // Validate against the count reported by the second query. Negative indices
+  // are rejected explicitly; otherwise they would index `devices` out of range.
+  if (devIdx < 0 || static_cast<uint32_t>(devIdx) >= deviceCount)
+    throw std::runtime_error("getDevice failed: devIdx out-of-bounds.");
+  l0Device = devices[static_cast<uint32_t>(devIdx)];
+  currDriverIdx = driverIdx;
+  currDevIdx = devIdx;
+  isDeviceCached = true;
+  return l0Device;
+}
+
+// Returns an L0 context created on the default driver (getDriver(), index 0).
+// The context is created lazily on the first call from each thread and then
+// reused. This function never destroys it.
+static ze_context_handle_t getDefaultContext() {
+  thread_local static bool contextReady{false};
+  thread_local static ze_context_handle_t cachedContext;
+  if (!contextReady) {
+    ze_context_desc_t contextDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, nullptr,
+                                     0};
+    L0_SAFE_CALL(zeContextCreate(getDriver(), &contextDesc, &cachedContext));
+    contextReady = true;
+  }
+  return cachedContext;
+}
+
+//===----------------------------------------------------------------------===//
+// L0 RT helper structs
+//===----------------------------------------------------------------------===//
+
+// Deleter for owning L0 context handles. A null handle is ignored; a failing
+// zeContextDestroy aborts the process through L0_SAFE_CALL.
+struct ZeContextDeleter {
+  void operator()(ze_context_handle_t context) const {
+    if (!context)
+      return;
+    L0_SAFE_CALL(zeContextDestroy(context));
+  }
+};
+
+// Deleter for owning L0 command-list handles. A null handle is ignored; a
+// failing zeCommandListDestroy aborts the process through L0_SAFE_CALL.
+struct ZeCommandListDeleter {
+  void operator()(ze_command_list_handle_t commandList) const {
+    if (!commandList)
+      return;
+    L0_SAFE_CALL(zeCommandListDestroy(commandList));
+  }
+};
+// Owning smart-pointer aliases for L0 handles. The handle typedefs are pointer
+// types, so the unique_ptr element type is the pointee. Destruction goes
+// through ZeContextDeleter / ZeCommandListDeleter.
+using UniqueZeContext =
+    std::unique_ptr<std::remove_pointer_t<ze_context_handle_t>,
+                    ZeContextDeleter>;
+using UniqueZeCommandList =
+    std::unique_ptr<std::remove_pointer_t<ze_command_list_handle_t>,
+                    ZeCommandListDeleter>;
+struct L0RtContext {
----------------
mshahneo wrote:

Thanks, changed the name to `L0RTContextWrapper`.

https://github.com/llvm/llvm-project/pull/151038


More information about the Mlir-commits mailing list