[llvm] [OFFLOAD] Add plugin with support for Intel oneAPI Level Zero (PR #158900)
Alex Duran via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 20 01:31:59 PDT 2025
================
@@ -0,0 +1,1066 @@
+//===--- Level Zero Target RTL Implementation -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// GenericDevice instantiation for SPIR-V/Xe machine
+//
+//===----------------------------------------------------------------------===//
+
+#include "L0Device.h"
+#include "L0Defs.h"
+#include "L0Interop.h"
+#include "L0Plugin.h"
+#include "L0Program.h"
+#include "L0Trace.h"
+
+namespace llvm::omp::target::plugin {
+
+/// Return the L0DeviceTLSTy object that the plugin associates with this
+/// device's id.
+L0DeviceTLSTy &L0DeviceTy::getTLS() {
+  auto &OwningPlugin = getPlugin();
+  return OwningPlugin.getDeviceTLS(getDeviceId());
+}
+
+// clang-format off
+/// Mapping from device arch to GPU runtime's device identifiers.
+///
+/// Each row pairs one DeviceArchTy value with a list of PCIIdTy values.
+/// computeArch() walks a row until it reaches PCIIdTy::None, so every row
+/// must end with that sentinel, leaving room for at most 9 real ids per row.
+/// Table ids are compared against the device PCI id masked with 0xFF00.
+/// DeviceArch_Gen intentionally or not appears in two rows; both rows map to
+/// the same architecture value.
+static struct {
+ // Architecture reported for every id listed in `ids`.
+ DeviceArchTy arch;
+ // PCIIdTy::None-terminated list of ids belonging to `arch`.
+ PCIIdTy ids[10];
+} DeviceArchMap[] = {{DeviceArchTy::DeviceArch_Gen,
+ {PCIIdTy::SKL,
+ PCIIdTy::KBL,
+ PCIIdTy::CFL, PCIIdTy::CFL_2,
+ PCIIdTy::ICX,
+ PCIIdTy::None}},
+ {DeviceArchTy::DeviceArch_Gen,
+ {PCIIdTy::TGL, PCIIdTy::TGL_2,
+ PCIIdTy::DG1,
+ PCIIdTy::RKL,
+ PCIIdTy::ADLS,
+ PCIIdTy::RTL,
+ PCIIdTy::None}},
+ {DeviceArchTy::DeviceArch_XeLPG,
+ {PCIIdTy::MTL,
+ PCIIdTy::None}},
+ {DeviceArchTy::DeviceArch_XeHPC,
+ {PCIIdTy::PVC,
+ PCIIdTy::None}},
+ {DeviceArchTy::DeviceArch_XeHPG,
+ {PCIIdTy::DG2_ATS_M,
+ PCIIdTy::DG2_ATS_M_2,
+ PCIIdTy::None}},
+ {DeviceArchTy::DeviceArch_Xe2LP,
+ {PCIIdTy::LNL,
+ PCIIdTy::None}},
+ {DeviceArchTy::DeviceArch_Xe2HP,
+ {PCIIdTy::BMG,
+ PCIIdTy::None}},
+};
+// Number of rows in DeviceArchMap.
+constexpr int DeviceArchMapSize = sizeof(DeviceArchMap) / sizeof(DeviceArchMap[0]);
+// clang-format on
+
+/// Map this device's PCI id to a DeviceArchTy using DeviceArchMap.
+///
+/// Only the bits selected by the 0xFF00 mask take part in the comparison, so
+/// every PCI id that shares those bits with a table entry maps to that
+/// entry's architecture.
+///
+/// \returns the architecture of the first matching table row, or
+/// DeviceArchTy::DeviceArch_None (after emitting a debug warning) when the
+/// PCI id is 0 or no row matches.
+DeviceArchTy L0DeviceTy::computeArch() const {
+  const auto PCIDeviceId = getPCIId();
+  if (PCIDeviceId != 0) {
+    // The masked id does not depend on the table position; compute it once.
+    const auto MaskedId = static_cast<PCIIdTy>(PCIDeviceId & 0xFF00);
+    for (int ArchIndex = 0; ArchIndex < DeviceArchMapSize; ArchIndex++) {
+      const auto &Entry = DeviceArchMap[ArchIndex];
+      // Iterating over the array itself bounds the scan to the row's storage
+      // even if a row is ever written without the PCIIdTy::None sentinel.
+      for (const PCIIdTy Id : Entry.ids) {
+        if (Id == PCIIdTy::None)
+          break; // End of this row's id list.
+        if (MaskedId == Id)
+          return Entry.arch; // Match on the masked PCI id.
+      }
+    }
+  }
+
+  DP("Warning: Cannot decide device arch for %s.\n", getNameCStr());
+  return DeviceArchTy::DeviceArch_None;
+}
+
+/// Query the device properties with the IP-version extension structure
+/// chained in and report whether the returned ipVersion is at least
+/// \p Version. Yields false when the property query itself fails.
+bool L0DeviceTy::isDeviceIPorNewer(uint32_t Version) const {
+  // Extension structure that receives the IP version.
+  ze_device_ip_version_ext_t IPVersion{};
+  ze_device_properties_t DevicePR{};
+
+  IPVersion.pNext = nullptr;
+  IPVersion.stype = ZE_STRUCTURE_TYPE_DEVICE_IP_VERSION_EXT;
+
+  // Hook the extension structure into the property query through pNext.
+  DevicePR.pNext = &IPVersion;
+  DevicePR.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES;
+
+  CALL_ZE_RET(false, zeDeviceGetProperties, zeDevice, &DevicePR);
+
+  const bool IsSameOrNewer = !(IPVersion.ipVersion < Version);
+  return IsSameOrNewer;
+}
+
+/// Locate the first command queue group that advertises the compute flag.
+///
+/// \returns an {ordinal, number of queues} pair. The ordinal stays UINT32_MAX
+/// (and the queue count 0) when a property query fails or no group carries
+/// the compute flag.
+std::pair<uint32_t, uint32_t> L0DeviceTy::findComputeOrdinal() {
+  std::pair<uint32_t, uint32_t> Ordinal{UINT32_MAX, 0};
+  const auto zeDevice = getZeDevice();
+
+  // First call only retrieves the number of groups.
+  uint32_t Count = 0;
+  CALL_ZE_RET(Ordinal, zeDeviceGetCommandQueueGroupProperties, zeDevice, &Count,
+              nullptr);
+
+  // Second call fills one pre-initialized property record per group.
+  ze_command_queue_group_properties_t Init{
+      ZE_STRUCTURE_TYPE_COMMAND_QUEUE_GROUP_PROPERTIES, nullptr, 0, 0, 0};
+  std::vector<ze_command_queue_group_properties_t> Properties(Count, Init);
+  CALL_ZE_RET(Ordinal, zeDeviceGetCommandQueueGroupProperties, zeDevice, &Count,
+              Properties.data());
+
+  // TODO: add a separate set of ordinals for compute queue groups which
+  // support cooperative kernels
+  for (uint32_t GroupIdx = 0; GroupIdx != Count; ++GroupIdx) {
+    const auto &Group = Properties[GroupIdx];
+    if (!(Group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE))
+      continue;
+    Ordinal.first = GroupIdx;
+    Ordinal.second = Group.numQueues;
+    break;
+  }
+
+  const bool Found = Ordinal.first != UINT32_MAX;
+  if (!Found)
+    DP("Error: no command queues are found\n");
+
+  return Ordinal;
+}
+
+/// Locate a copy-only command queue group (copy flag set, compute flag clear).
+///
+/// With \p LinkCopy set, the first such group exposing more than one queue is
+/// chosen; otherwise the first such group exposing exactly one queue is
+/// chosen.
+///
+/// \returns an {ordinal, number of queues} pair; the ordinal stays UINT32_MAX
+/// (and the queue count 0) when a property query fails or nothing qualifies.
+std::pair<uint32_t, uint32_t> L0DeviceTy::findCopyOrdinal(bool LinkCopy) {
+  std::pair<uint32_t, uint32_t> Ordinal{UINT32_MAX, 0};
+  const auto zeDevice = getZeDevice();
+
+  // First call only retrieves the number of groups.
+  uint32_t Count = 0;
+  CALL_ZE_RET(Ordinal, zeDeviceGetCommandQueueGroupProperties, zeDevice, &Count,
+              nullptr);
+
+  // Second call fills one pre-initialized property record per group.
+  ze_command_queue_group_properties_t Init{
+      ZE_STRUCTURE_TYPE_COMMAND_QUEUE_GROUP_PROPERTIES, nullptr, 0, 0, 0};
+  std::vector<ze_command_queue_group_properties_t> Properties(Count, Init);
+  CALL_ZE_RET(Ordinal, zeDeviceGetCommandQueueGroupProperties, zeDevice, &Count,
+              Properties.data());
+
+  for (uint32_t GroupIdx = 0; GroupIdx != Count; ++GroupIdx) {
+    const auto &GroupFlags = Properties[GroupIdx].flags;
+    const bool IsCopyOnly =
+        (GroupFlags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) &&
+        !(GroupFlags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE);
+    if (!IsCopyOnly)
+      continue;
+
+    const auto QueueCount = Properties[GroupIdx].numQueues;
+    if (LinkCopy) {
+      // Link copy groups are the ones with several queues.
+      if (QueueCount <= 1)
+        continue;
+      Ordinal = {GroupIdx, QueueCount};
+      DP("Found link copy command queue for device " DPxMOD
+         ", ordinal = %" PRIu32 ", number of queues = %" PRIu32 "\n",
+         DPxPTR(zeDevice), Ordinal.first, Ordinal.second);
+      break;
+    }
+
+    // Plain copy group: exactly one queue.
+    if (QueueCount != 1)
+      continue;
+    Ordinal = {GroupIdx, QueueCount};
+    DP("Found copy command queue for device " DPxMOD ", ordinal = %" PRIu32
+       "\n",
+       DPxPTR(zeDevice), Ordinal.first);
+    break;
+  }
+  return Ordinal;
+}
+
+/// Emit this device's identification, execution-unit topology, memory sizes
+/// and clock rate through the DP debug-print macro, one line per property.
+/// The function only reads accessors; it does not modify the device.
+void L0DeviceTy::reportDeviceInfo() const {
+ DP("Device %" PRIu32 "\n", DeviceId);
+ DP("-- Name : %s\n", getNameCStr());
+ DP("-- PCI ID : 0x%" PRIx32 "\n", getPCIId());
+ DP("-- UUID : %s\n", getUuid().c_str());
+ // Execution-unit topology.
+ DP("-- Number of total EUs : %" PRIu32 "\n", getNumEUs());
+ DP("-- Number of threads per EU : %" PRIu32 "\n", getNumThreadsPerEU());
+ DP("-- EU SIMD width : %" PRIu32 "\n", getSIMDWidth());
+ DP("-- Number of EUs per subslice : %" PRIu32 "\n", getNumEUsPerSubslice());
+ DP("-- Number of subslices per slice: %" PRIu32 "\n",
+ getNumSubslicesPerSlice());
+ DP("-- Number of slices : %" PRIu32 "\n", getNumSlices());
+ // Memory sizes and clock.
+ DP("-- Local memory size (bytes) : %" PRIu32 "\n",
+ getMaxSharedLocalMemory());
+ DP("-- Global memory size (bytes) : %" PRIu64 "\n", getGlobalMemorySize());
+ DP("-- Cache size (bytes) : %" PRIu64 "\n", getCacheSize());
+ DP("-- Max clock frequency (MHz) : %" PRIu32 "\n", getClockRate());
+}
+
+/// Query the Level Zero device and populate this object's cached state:
+/// device/compute/memory/cache properties, name, architecture, default
+/// allocation kind, indirect access flags, UUID string, command queue group
+/// ordinals and the async-enabled flag. Finally initializes the device memory
+/// pools and updates the host allocator's maximum allocation size.
+///
+/// \returns an error as produced by CALL_ZE_RET_ERROR if any property query
+/// fails, Plugin::success() otherwise.
+Error L0DeviceTy::internalInit() {
+  const auto &Options = getPlugin().getOptions();
+
+  const auto zeDevice = getZeDevice();
+  CALL_ZE_RET_ERROR(zeDeviceGetProperties, zeDevice, &DeviceProperties);
+  CALL_ZE_RET_ERROR(zeDeviceGetComputeProperties, zeDevice, &ComputeProperties);
+
+  // Count is an in/out parameter: the driver may overwrite it with the number
+  // of entries it actually has. Re-arm it before every query so that the
+  // result of the memory query cannot shrink the request of the cache query.
+  uint32_t Count = 1;
+  CALL_ZE_RET_ERROR(zeDeviceGetMemoryProperties, zeDevice, &Count,
+                    &MemoryProperties);
+  Count = 1;
+  CALL_ZE_RET_ERROR(zeDeviceGetCacheProperties, zeDevice, &Count,
+                    &CacheProperties);
+
+  // DeviceProperties.name is a fixed-size buffer. Copy it bounded by the
+  // buffer size, then drop everything from the first NUL on so that the
+  // resulting string does not carry the buffer's trailing padding.
+  std::string Name(DeviceProperties.name, sizeof(DeviceProperties.name));
+  const auto NulPos = Name.find('\0');
+  if (NulPos != std::string::npos)
+    Name.resize(NulPos);
+  DeviceName = std::move(Name);
+
+  DP("Found a GPU device, Name = %s\n", DeviceProperties.name);
+
+  DeviceArch = computeArch();
+  // Default allocation kind for this device
+  AllocKind = isDiscreteDevice() ? TARGET_ALLOC_DEVICE : TARGET_ALLOC_SHARED;
+
+  // Indirect access flags follow the default allocation kind.
+  ze_kernel_indirect_access_flags_t Flags =
+      (AllocKind == TARGET_ALLOC_DEVICE)
+          ? ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
+          : ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED;
+  IndirectAccessFlags = Flags;
+
+  // Get the UUID: every byte is appended as a decimal number, without a
+  // separator.
+  // NOTE(review): the decimal encoding is not fixed-width, so different UUIDs
+  // could produce the same string - confirm whether a hex encoding is wanted.
+  std::string uid = "";
+  for (int n = 0; n < ZE_MAX_DEVICE_UUID_SIZE; n++)
+    uid += std::to_string(DeviceProperties.uuid.id[n]);
+  DeviceUuid = std::move(uid);
+
+  ComputeOrdinal = findComputeOrdinal();
+
+  CopyOrdinal = findCopyOrdinal();
+
+  LinkCopyOrdinal = findCopyOrdinal(true);
+
+  // Asynchronous command submission is only used on discrete devices and only
+  // when the configured command mode is not Sync.
+  IsAsyncEnabled =
+      isDiscreteDevice() && Options.CommandMode != CommandModeTy::Sync;
+  MemAllocator.initDevicePools(*this, Options);
+  l0Context.getHostMemAllocator().updateMaxAllocSize(*this);
+  return Plugin::success();
+}
+
+/// Device initialization hook; performs no work here and always reports
+/// success. \p Plugin is unused.
+Error L0DeviceTy::initImpl(GenericPluginTy &Plugin) {
+  // Nothing to set up in this hook.
+  return Plugin::success();
+}
+
+/// Wait for the events recorded in the async queue attached to \p AsyncInfo,
+/// release them, and then perform the host-side copies that were deferred
+/// until completion (USM2MList and H2MList).
+///
+/// \param AsyncInfo async info whose Queue field holds the AsyncQueueTy.
+/// \param ReleaseQueue when true, the queue is handed back to the plugin, the
+///        staging buffer is reset and AsyncInfo->Queue is cleared.
+/// \returns OFFLOAD_SUCCESS, or the failure value produced by
+/// CALL_ZE_RET_FAIL when an event wait fails. Returns OFFLOAD_SUCCESS
+/// immediately when there is no async info, async execution is disabled, or
+/// no queue is attached.
+int32_t L0DeviceTy::synchronize(__tgt_async_info *AsyncInfo,
+                                bool ReleaseQueue) {
+  bool IsAsync = AsyncInfo && asyncEnabled();
+  if (!IsAsync)
+    return OFFLOAD_SUCCESS;
+
+  // No queue was ever attached to this async info, so nothing was submitted
+  // through it and there is nothing to dereference or wait for.
+  if (!AsyncInfo->Queue)
+    return OFFLOAD_SUCCESS;
+
+  auto &Plugin = getPlugin();
+
+  auto *AsyncQueue = static_cast<AsyncQueueTy *>(AsyncInfo->Queue);
+
+  if (!AsyncQueue->WaitEvents.empty()) {
+    const auto &WaitEvents = AsyncQueue->WaitEvents;
+    if (Plugin.getOptions().CommandMode == CommandModeTy::AsyncOrdered) {
+      // Only need to wait for the last event
+      CALL_ZE_RET_FAIL(zeEventHostSynchronize, WaitEvents.back(), UINT64_MAX);
+      // Synchronize on kernel event to support printf()
+      auto KE = AsyncQueue->KernelEvent;
+      if (KE && KE != WaitEvents.back()) {
+        CALL_ZE_RET_FAIL(zeEventHostSynchronize, KE, UINT64_MAX);
+      }
+      for (auto &Event : WaitEvents) {
+        releaseEvent(Event);
+      }
+    } else { // Async
+      // Wait for all events. We should wait and reset events in reverse order
+      // to avoid premature event reset. If we have a kernel event in the
+      // queue, it is the last event to wait for since all wait events of the
+      // kernel are signaled before the kernel is invoked. We always invoke
+      // synchronization on kernel event to support printf().
+      bool WaitDone = false;
+      for (auto Itr = WaitEvents.rbegin(); Itr != WaitEvents.rend(); Itr++) {
+        if (!WaitDone) {
+          CALL_ZE_RET_FAIL(zeEventHostSynchronize, *Itr, UINT64_MAX);
+          if (*Itr == AsyncQueue->KernelEvent)
+            WaitDone = true;
+        }
+        // Events after the kernel event (in reverse order) are released
+        // without an explicit wait.
+        releaseEvent(*Itr);
+      }
+    }
+  }
+
+  // Commit delayed USM2M copies; each entry is (source, destination, size).
+  for (auto &USM2M : AsyncQueue->USM2MList) {
+    std::copy_n(static_cast<const char *>(std::get<0>(USM2M)),
+                std::get<2>(USM2M), static_cast<char *>(std::get<1>(USM2M)));
+  }
+  // Commit delayed H2M copies; each entry is (source, destination, size).
+  for (auto &H2M : AsyncQueue->H2MList) {
+    std::copy_n(static_cast<char *>(std::get<0>(H2M)), std::get<2>(H2M),
+                static_cast<char *>(std::get<1>(H2M)));
+  }
+  if (ReleaseQueue) {
+    Plugin.releaseAsyncQueue(AsyncQueue);
+    getStagingBuffer().reset();
+    AsyncInfo->Queue = nullptr;
+  }
+  return OFFLOAD_SUCCESS;
+}
+
+/// Copy \p Size bytes from host pointer \p HstPtr to target pointer
+/// \p TgtPtr.
+///
+/// Targets whose allocation type is shared or host are written directly with
+/// a host-side copy. Any other target goes through enqueueMemCopy /
+/// enqueueMemCopyAsync; on discrete devices, sources that are not host
+/// allocations and fit into the configured staging buffer size are first
+/// copied into the staging buffer.
+///
+/// \returns OFFLOAD_SUCCESS, or the code returned by the enqueue call when it
+/// fails. A zero \p Size is a successful no-op.
+int32_t L0DeviceTy::submitData(void *TgtPtr, const void *HstPtr, int64_t Size,
+                               __tgt_async_info *AsyncInfo) {
+  if (Size == 0)
+    return OFFLOAD_SUCCESS;
+
+  auto &Plugin = getPlugin();
+  const auto DeviceId = getDeviceId();
+
+  bool IsAsync = AsyncInfo && asyncEnabled();
+  if (IsAsync && !AsyncInfo->Queue) {
+    AsyncInfo->Queue = reinterpret_cast<void *>(Plugin.getAsyncQueue());
+    // Couldn't get a queue, revert to sync
+    IsAsync = AsyncInfo->Queue != nullptr;
+  }
+
+  const auto TgtPtrType = getMemAllocType(TgtPtr);
+  const bool TgtIsHostWritable = TgtPtrType == ZE_MEMORY_TYPE_SHARED ||
+                                 TgtPtrType == ZE_MEMORY_TYPE_HOST;
+  const char *HostBytes = static_cast<const char *>(HstPtr);
+
+  if (TgtIsHostWritable) {
+    std::copy_n(HostBytes, Size, static_cast<char *>(TgtPtr));
+  } else {
+    const void *SrcPtr = HstPtr;
+    const bool UseStaging =
+        isDiscreteDevice() &&
+        static_cast<size_t>(Size) <= Plugin.getOptions().StagingBufferSize &&
+        getMemAllocType(HstPtr) != ZE_MEMORY_TYPE_HOST;
+    if (UseStaging) {
+      // Stage the host data and submit the copy from the staging buffer.
+      void *Staging = getStagingBuffer().get(IsAsync);
+      std::copy_n(HostBytes, Size, static_cast<char *>(Staging));
+      SrcPtr = Staging;
+    }
+
+    const int32_t RC = IsAsync
+                           ? enqueueMemCopyAsync(TgtPtr, SrcPtr, Size, AsyncInfo)
+                           : enqueueMemCopy(TgtPtr, SrcPtr, Size, AsyncInfo);
+    if (RC != OFFLOAD_SUCCESS)
+      return RC;
+  }
+
+  INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
+       "%s %" PRId64 " bytes (hst:" DPxMOD ") -> (tgt:" DPxMOD ")\n",
+       IsAsync ? "Submitted copy" : "Copied", Size, DPxPTR(HstPtr),
+       DPxPTR(TgtPtr));
+
+  return OFFLOAD_SUCCESS;
+}
+
+/// Copy \p Size bytes from target pointer \p TgtPtr to host pointer
+/// \p HstPtr.
+///
+/// Host/shared allocations are read with a host-side copy; in async mode that
+/// copy is queued on USM2MList when the queue has a kernel event. Any other
+/// source goes through enqueueMemCopy / enqueueMemCopyAsync, optionally into
+/// the staging buffer, in which case the final staging-to-host copy is either
+/// done right away (sync) or queued on H2MList (async).
+///
+/// \returns OFFLOAD_SUCCESS, or the code returned by the enqueue call when it
+/// fails. A zero \p Size is a successful no-op.
+int32_t L0DeviceTy::retrieveData(void *HstPtr, const void *TgtPtr, int64_t Size,
+                                 __tgt_async_info *AsyncInfo) {
+  if (Size == 0)
+    return OFFLOAD_SUCCESS;
+
+  auto &Plugin = getPlugin();
+  const auto DeviceId = getDeviceId();
+
+  bool IsAsync = AsyncInfo && asyncEnabled();
+  if (IsAsync && !AsyncInfo->Queue) {
+    AsyncInfo->Queue = Plugin.getAsyncQueue();
+    // Couldn't get a queue, revert to sync
+    IsAsync = AsyncInfo->Queue != nullptr;
+  }
+
+  AsyncQueueTy *AsyncQueue = nullptr;
+  if (IsAsync)
+    AsyncQueue = static_cast<AsyncQueueTy *>(AsyncInfo->Queue);
+
+  auto TgtPtrType = getMemAllocType(TgtPtr);
+  const bool TgtIsHostReadable = TgtPtrType == ZE_MEMORY_TYPE_HOST ||
+                                 TgtPtrType == ZE_MEMORY_TYPE_SHARED;
+
+  if (TgtIsHostReadable) {
+    if (IsAsync && AsyncQueue->KernelEvent) {
+      // Delay Host/Shared USM to host memory copy since it must wait for
+      // kernel completion.
+      AsyncQueue->USM2MList.emplace_back(TgtPtr, HstPtr, Size);
+    } else {
+      std::copy_n(static_cast<const char *>(TgtPtr), Size,
+                  static_cast<char *>(HstPtr));
+    }
+  } else {
+    const bool UseStaging =
+        isDiscreteDevice() &&
+        static_cast<size_t>(Size) <= getPlugin().getOptions().StagingBufferSize &&
+        getMemAllocType(HstPtr) != ZE_MEMORY_TYPE_HOST;
+    void *DstPtr = UseStaging ? getStagingBuffer().get(IsAsync) : HstPtr;
+
+    const int32_t RC =
+        IsAsync ? enqueueMemCopyAsync(DstPtr, TgtPtr, Size, AsyncInfo,
+                                      /* CopyTo */ false)
+                : enqueueMemCopy(DstPtr, TgtPtr, Size, AsyncInfo);
+    if (RC != OFFLOAD_SUCCESS)
+      return RC;
+
+    // The data landed somewhere other than the user's buffer: finish the
+    // transfer now, or record it for later when running asynchronously.
+    if (DstPtr != HstPtr) {
+      if (IsAsync) {
+        // Store delayed H2M data copies
+        AsyncQueue->H2MList.emplace_back(DstPtr, HstPtr,
+                                         static_cast<size_t>(Size));
+      } else {
+        std::copy_n(static_cast<char *>(DstPtr), Size,
+                    static_cast<char *>(HstPtr));
+      }
+    }
+  }
+
+  INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
+       "%s %" PRId64 " bytes (tgt:" DPxMOD ") -> (hst:" DPxMOD ")\n",
+       IsAsync ? "Submitted copy" : "Copied", Size, DPxPTR(TgtPtr),
+       DPxPTR(HstPtr));
+
+  return OFFLOAD_SUCCESS;
+}
+
+/// Load the device image \p TgtImage as a program on this device.
+///
+/// If a program for the image already exists it is returned as is. Otherwise
+/// the compilation options are assembled from the plugin options (base, user,
+/// then internal options), a new program is registered under \p ImageId, and
+/// its modules are built, linked and its kernels loaded.
+///
+/// \returns the program for the image, or an error naming the step (build,
+/// link, kernel load) that failed.
+Expected<DeviceImageTy *>
+L0DeviceTy::loadBinaryImpl(const __tgt_device_image *TgtImage,
+                           int32_t ImageId) {
+  auto *PGM = getProgramFromImage(TgtImage);
+  if (PGM) {
+    // Program already exists
+    return PGM;
+  }
+
+  INFO(OMP_INFOTYPE_PLUGIN_KERNEL, getDeviceId(),
+       "Device %" PRId32 ": Loading binary from " DPxMOD "\n", getDeviceId(),
+       DPxPTR(TgtImage->ImageStart));
+
+  const size_t NumEntries =
+      static_cast<size_t>(TgtImage->EntriesEnd - TgtImage->EntriesBegin);
+
+  INFO(OMP_INFOTYPE_PLUGIN_KERNEL, getDeviceId(),
+       "Expecting to have %zu entries defined\n", NumEntries);
+  (void)NumEntries; // silence warning
+
+  const auto &Options = getPlugin().getOptions();
+  std::string CompilationOptions(Options.CompilationOptions);
+  CompilationOptions += " " + Options.UserCompilationOptions;
+
+  // Only the base and user options are reported; the internal options are
+  // appended after this message.
+  INFO(OMP_INFOTYPE_PLUGIN_KERNEL, getDeviceId(),
+       "Base L0 module compilation options: %s\n", CompilationOptions.c_str());
+
+  CompilationOptions += " ";
+  CompilationOptions += Options.InternalCompilationOptions;
+
+  // NOTE(review): the program is registered before it is built, so it appears
+  // to stay registered when one of the steps below fails - confirm whether a
+  // later load of the same image should be able to find it.
+  auto &Program = addProgram(ImageId, TgtImage);
+
+  int32_t RC = Program.buildModules(CompilationOptions);
+  if (RC != OFFLOAD_SUCCESS)
+    return Plugin::check(RC, "Error in buildModules %d", RC);
+
+  RC = Program.linkModules();
+  if (RC != OFFLOAD_SUCCESS)
+    return Plugin::check(RC, "Error in linkModules %d", RC);
+
+  RC = Program.loadModuleKernels();
+  if (RC != OFFLOAD_SUCCESS)
+    // Name the step that actually failed, like the two messages above do.
+    return Plugin::check(RC, "Error in loadModuleKernels %d", RC);
+
+  return &Program;
+}
+
+/// Unload hook for a device image. Currently a no-op that always reports
+/// success; \p Image is left untouched.
+Error L0DeviceTy::unloadBinaryImpl(DeviceImageTy *Image) {
+  // TODO: call properly L0Program unload
+  // Until then the request is ignored.
+  return Plugin::success();
+}
+
+Error L0DeviceTy::synchronizeImpl(__tgt_async_info &AsyncInfo,
+ bool ReleaseQueue) {
+ if (!ReleaseQueue) {
+ return Plugin::error(ErrorCode::UNIMPLEMENTED,
+ "Support for ReleaseQueue=false in %s"
+ " not implemented yet\n",
+ __func__);
+ }
+ int32_t RC = synchronize(&AsyncInfo);
+ return Plugin::check(RC, "Error in synchronizeImpl %d", RC);
+}
+
+/// Report whether the async queue behind \p AsyncInfoWrapper still holds wait
+/// events. Yields false when no queue is attached or async execution is
+/// disabled.
+Expected<bool>
+L0DeviceTy::hasPendingWorkImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) {
+  auto &AsyncInfo = *static_cast<__tgt_async_info *>(AsyncInfoWrapper);
+  if (!AsyncInfo.Queue || !asyncEnabled())
+    return false;
+
+  // Work is pending exactly when at least one wait event is recorded.
+  const auto *Queue = static_cast<AsyncQueueTy *>(AsyncInfo.Queue);
+  return !Queue->WaitEvents.empty();
+}
+
+/// Non-blocking completion check for \p AsyncInfo.
+///
+/// Returns success without doing anything when no queue is attached, async
+/// execution is disabled, or the queue still records wait events. Otherwise
+/// the deferred host copies (USM2MList, then H2MList) are performed, the
+/// queue is handed back to the plugin, the staging buffer is reset and
+/// AsyncInfo.Queue is cleared.
+Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) {
+  if (!AsyncInfo.Queue || !asyncEnabled())
+    return Plugin::success();
+
+  auto &Plugin = getPlugin();
+  auto *AsyncQueue = static_cast<AsyncQueueTy *>(AsyncInfo.Queue);
+
+  // Wait events are still recorded: leave the queue as it is.
+  const bool HasWaitEvents = !AsyncQueue->WaitEvents.empty();
+  if (HasWaitEvents)
+    return Plugin::success();
+
+  // Commit delayed USM2M copies
+  for (auto &[Src, Dst, NumBytes] : AsyncQueue->USM2MList)
+    std::copy_n(static_cast<const char *>(Src), NumBytes,
+                static_cast<char *>(Dst));
+
+  // Commit delayed H2M copies
+  for (auto &[Src, Dst, NumBytes] : AsyncQueue->H2MList)
+    std::copy_n(static_cast<char *>(Src), NumBytes, static_cast<char *>(Dst));
+
+  Plugin.releaseAsyncQueue(AsyncQueue);
+  getStagingBuffer().reset();
+  AsyncInfo.Queue = nullptr;
+
+  return Plugin::success();
+}
+
+/// Allocate \p Size bytes of kind \p Kind through dataAlloc with zero
+/// alignment and offset. The allocation is flagged as a user allocation
+/// exactly when \p HstPtr is null.
+void *L0DeviceTy::allocate(size_t Size, void *HstPtr, TargetAllocTy Kind) {
+  const bool IsUserAlloc = (HstPtr == nullptr);
+  return dataAlloc(Size, /*Align=*/0, Kind, /*Offset=*/0,
+                   /*UserAlloc=*/IsUserAlloc, /*DevMalloc=*/false);
+}
+
+/// Release \p TgtPtr through dataDelete and forward its result. \p Kind is
+/// not consulted.
+int L0DeviceTy::free(void *TgtPtr, TargetAllocTy Kind) {
+  const auto Result = dataDelete(TgtPtr);
+  return Result;
+}
+
+Error L0DeviceTy::dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) {
+ int32_t RC = submitData(TgtPtr, HstPtr, Size, AsyncInfoWrapper);
+ return Plugin::check(RC, "Error in dataSubmitImpl %d", RC);
+}
+
+Error L0DeviceTy::dataRetrieveImpl(void *HstPtr, const void *TgtPtr,
+ int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) {
+ int32_t RC = retrieveData(HstPtr, TgtPtr, Size, AsyncInfoWrapper);
+ return Plugin::check(RC, "Error in dataRetrieveImpl %d", RC);
+}
+
+/// Copy \p Size bytes from \p SrcPtr on this device to \p DstPtr on
+/// \p DstDev.
+///
+/// When async execution is enabled and the wrapper has a queue, the copy is
+/// submitted with enqueueMemCopyAsync; otherwise enqueueMemCopy is used, and
+/// the copy engine is requested only when source and destination are
+/// different Level Zero devices.
+///
+/// \returns an UNKNOWN error when the enqueue call reports a non-zero result,
+/// Plugin::success() otherwise.
+Error L0DeviceTy::dataExchangeImpl(const void *SrcPtr, GenericDeviceTy &DstDev,
+                                   void *DstPtr, int64_t Size,
+                                   AsyncInfoWrapperTy &AsyncInfoWrapper) {
+  L0DeviceTy &L0DstDev = L0DeviceTy::makeL0Device(DstDev);
+  // Use copy engine only for across-tile/device copies.
+  const bool UseCopyEngine = getZeDevice() != L0DstDev.getZeDevice();
+
+  int32_t RC;
+  if (asyncEnabled() && AsyncInfoWrapper.hasQueue())
+    RC = enqueueMemCopyAsync(DstPtr, SrcPtr, Size,
+                             static_cast<__tgt_async_info *>(AsyncInfoWrapper));
+  else
+    RC = enqueueMemCopy(DstPtr, SrcPtr, Size,
+                        /* AsyncInfo */ nullptr,
+                        /* Locked */ false, UseCopyEngine);
+
+  if (RC)
+    return Plugin::error(ErrorCode::UNKNOWN, "dataExchangeImpl failed");
+  return Plugin::success();
+}
+
+/// Make sure \p AsyncInfoWrapper carries an async queue: if none is set yet,
+/// one is requested from the plugin and stored in the wrapper. Always reports
+/// success.
+Error L0DeviceTy::initAsyncInfoImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) {
+  // Keep a queue that is already attached.
+  if (AsyncInfoWrapper.getQueueAs<AsyncQueueTy *>())
+    return Plugin::success();
+
+  AsyncQueueTy *FreshQueue = getPlugin().getAsyncQueue();
+  AsyncInfoWrapper.setQueueAs<AsyncQueueTy *>(FreshQueue);
+  return Plugin::success();
+}
+
+/// Fill the Context and Device fields of \p Info with this device's Level
+/// Zero context and device handle. Fields that are already set are kept.
+/// Always reports success.
+Error L0DeviceTy::initDeviceInfoImpl(__tgt_device_info *Info) {
+  auto &DevInfo = *Info;
+
+  if (!DevInfo.Context)
+    DevInfo.Context = getZeContext();
+
+  if (!DevInfo.Device)
+    DevInfo.Device = reinterpret_cast<void *>(getZeDevice());
+
+  return Plugin::success();
+}
+
+Expected<InfoTreeNode> L0DeviceTy::obtainInfoImpl() {
+ InfoTreeNode Info;
+ Info.add("Device Number", getDeviceId());
+ Info.add("Device Name", getNameCStr());
----------------
adurang wrote:
I'm sorry but I don't quite follow what you want me to do. Can you point me somewhere (example, documentation, ...)?
https://github.com/llvm/llvm-project/pull/158900
More information about the llvm-commits
mailing list