[llvm] [Offload] Ensure to load images when the device is used (PR #103002)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 12 21:51:32 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-offload
Author: Johannes Doerfert (jdoerfert)
<details>
<summary>Changes</summary>
When we use the device, e.g., with an API that interacts with it, we need to ensure the image is loaded and the constructors are executed. Two tests are included to verify that we 1) load images and run constructors when needed, and 2) do so lazily, only if the device is actually used.
---
Patch is 24.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/103002.diff
7 Files Affected:
- (modified) offload/include/device.h (+9)
- (modified) offload/src/PluginManager.cpp (+193-8)
- (modified) offload/src/interface.cpp (+43-3)
- (modified) offload/src/omptarget.cpp (-217)
- (modified) offload/src/private.h (-1)
- (added) offload/test/offloading/ctor_dtor_api.cpp (+24)
- (added) offload/test/offloading/ctor_dtor_lazy.cpp (+33)
``````````diff
diff --git a/offload/include/device.h b/offload/include/device.h
index fd6e5fba5fc530..103825c81059a5 100644
--- a/offload/include/device.h
+++ b/offload/include/device.h
@@ -152,6 +152,12 @@ struct DeviceTy {
/// Ask the device whether the runtime should use auto zero-copy.
bool useAutoZeroCopy();
+ /// Check if there are pending images for this device.
+ bool hasPendingImages() const { return HasPendingImages; }
+
+ /// Indicate that there are pending images for this device.
+ void setHasPendingImages() { HasPendingImages = true; }
+
private:
/// Deinitialize the device (and plugin).
void deinit();
@@ -163,6 +169,9 @@ struct DeviceTy {
/// Handler to collect and organize host-2-device mapping information.
MappingInfoTy MappingInfo;
+
+ /// Flag to indicate pending images (true after construction).
+ bool HasPendingImages = true;
};
#endif
diff --git a/offload/src/PluginManager.cpp b/offload/src/PluginManager.cpp
index c6117782fbab66..5fac6fa1fdfe6d 100644
--- a/offload/src/PluginManager.cpp
+++ b/offload/src/PluginManager.cpp
@@ -78,8 +78,13 @@ bool PluginManager::initializePlugin(GenericPluginTy &Plugin) {
bool PluginManager::initializeDevice(GenericPluginTy &Plugin,
int32_t DeviceId) {
- if (Plugin.is_device_initialized(DeviceId))
+ if (Plugin.is_device_initialized(DeviceId)) {
+ auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+ (*ExclusiveDevicesAccessor)[PM->DeviceIds[std::make_pair(&Plugin,
+ DeviceId)]]
+ ->setHasPendingImages();
return true;
+ }
// Initialize the device information for the RTL we are about to use.
auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
@@ -286,13 +291,193 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
DP("Done unregistering library!\n");
}
+/// Map global data and execute pending ctors
+static int loadImagesOntoDevice(DeviceTy &Device) {
+ /*
+ * Map global data
+ */
+ int32_t DeviceId = Device.DeviceID;
+ int Rc = OFFLOAD_SUCCESS;
+ {
+ std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
+ for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
+ TranslationTable *TransTable =
+ &PM->HostEntriesBeginToTransTable[HostEntriesBegin];
+ DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
+ TransTable->HostTable.EntriesEnd);
+ if (TransTable->HostTable.EntriesBegin ==
+ TransTable->HostTable.EntriesEnd) {
+ // No host entry so no need to proceed
+ continue;
+ }
+
+ if (TransTable->TargetsTable[DeviceId] != 0) {
+ // Library entries have already been processed
+ continue;
+ }
+
+ // 1) get image.
+ assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
+ "Not expecting a device ID outside the table's bounds!");
+ __tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
+ if (!Img) {
+ REPORT("No image loaded for device id %d.\n", DeviceId);
+ Rc = OFFLOAD_FAIL;
+ break;
+ }
+
+ // 2) Load the image onto the given device.
+ auto BinaryOrErr = Device.loadBinary(Img);
+ if (llvm::Error Err = BinaryOrErr.takeError()) {
+ REPORT("Failed to load image %s\n",
+ llvm::toString(std::move(Err)).c_str());
+ Rc = OFFLOAD_FAIL;
+ break;
+ }
+
+ // 3) Create the translation table.
+ llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
+ TransTable->TargetsEntries[DeviceId];
+ for (__tgt_offload_entry &Entry :
+ llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
+ __tgt_device_binary &Binary = *BinaryOrErr;
+
+ __tgt_offload_entry DeviceEntry = Entry;
+ if (Entry.size) {
+ if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
+ &DeviceEntry.addr) != OFFLOAD_SUCCESS)
+ REPORT("Failed to load symbol %s\n", Entry.name);
+
+ // If unified memory is active, the corresponding global is a device
+ // reference to the host global. We need to initialize the pointer on
+ // the device to point to the memory on the host.
+ if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
+ (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
+ if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
+ Entry.size) != OFFLOAD_SUCCESS)
+ REPORT("Failed to write symbol for USM %s\n", Entry.name);
+ }
+ } else if (Entry.addr) {
+ if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
+ OFFLOAD_SUCCESS)
+ REPORT("Failed to load kernel %s\n", Entry.name);
+ }
+ DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
+ DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
+ DPxPTR(DeviceEntry.addr));
+
+ DeviceEntries.emplace_back(DeviceEntry);
+ }
+
+ // Set the storage for the table and get a pointer to it.
+ __tgt_target_table DeviceTable{&DeviceEntries[0],
+ &DeviceEntries[0] + DeviceEntries.size()};
+ TransTable->DeviceTables[DeviceId] = DeviceTable;
+ __tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
+ &TransTable->DeviceTables[DeviceId];
+
+ // 4) Verify whether the two table sizes match.
+ size_t Hsize =
+ TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
+ size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
+
+ // Invalid image for these host entries!
+ if (Hsize != Tsize) {
+ REPORT(
+ "Host and Target tables mismatch for device id %d [%zx != %zx].\n",
+ DeviceId, Hsize, Tsize);
+ TransTable->TargetsImages[DeviceId] = 0;
+ TransTable->TargetsTable[DeviceId] = 0;
+ Rc = OFFLOAD_FAIL;
+ break;
+ }
+
+ MappingInfoTy::HDTTMapAccessorTy HDTTMap =
+ Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
+
+ __tgt_target_table *HostTable = &TransTable->HostTable;
+ for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
+ *CurrHostEntry = HostTable->EntriesBegin,
+ *EntryDeviceEnd = TargetTable->EntriesEnd;
+ CurrDeviceEntry != EntryDeviceEnd;
+ CurrDeviceEntry++, CurrHostEntry++) {
+ if (CurrDeviceEntry->size == 0)
+ continue;
+
+ assert(CurrDeviceEntry->size == CurrHostEntry->size &&
+ "data size mismatch");
+
+ // Fortran may use multiple weak declarations for the same symbol,
+ // therefore we must allow for multiple weak symbols to be loaded from
+ // the fat binary. Treat these mappings as any other "regular"
+ // mapping. Add entry to map.
+ if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
+ CurrHostEntry->size))
+ continue;
+
+ void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;
+
+ // For indirect mapping, follow the indirection and map the actual
+ // target.
+ if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
+ AsyncInfoTy AsyncInfo(Device);
+ void *DevPtr;
+ Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
+ AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
+ if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
+ return OFFLOAD_FAIL;
+ CurrDeviceEntryAddr = DevPtr;
+ }
+
+ DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
+ ", name \"%s\"\n",
+ DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
+ CurrDeviceEntry->size, CurrDeviceEntry->name);
+ HDTTMap->emplace(new HostDataToTargetTy(
+ (uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
+ (uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
+ (uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
+ (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
+ (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
+ false /*UseHoldRefCount*/, CurrHostEntry->name,
+ true /*IsRefCountINF*/));
+
+ // Notify about the new mapping.
+ if (Device.notifyDataMapped(CurrHostEntry->addr, CurrHostEntry->size))
+ return OFFLOAD_FAIL;
+ }
+ }
+ }
+
+ if (Rc != OFFLOAD_SUCCESS)
+ return Rc;
+
+ static Int32Envar DumpOffloadEntries =
+ Int32Envar("OMPTARGET_DUMP_OFFLOAD_ENTRIES", -1);
+ if (DumpOffloadEntries.get() == DeviceId)
+ Device.dumpOffloadEntries();
+
+ return OFFLOAD_SUCCESS;
+}
+
Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
- auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
- if (DeviceNo >= ExclusiveDevicesAccessor->size())
- return createStringError(
- inconvertibleErrorCode(),
- "Device number '%i' out of range, only %i devices available", DeviceNo,
- ExclusiveDevicesAccessor->size());
+ DeviceTy *DevicePtr;
+ {
+ auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+ if (DeviceNo >= ExclusiveDevicesAccessor->size())
+ return createStringError(
+ inconvertibleErrorCode(),
+ "Device number '%i' out of range, only %i devices available",
+ DeviceNo, ExclusiveDevicesAccessor->size());
+
+ DevicePtr = &*(*ExclusiveDevicesAccessor)[DeviceNo];
+ }
- return *(*ExclusiveDevicesAccessor)[DeviceNo];
+ // Check whether global data has been mapped for this device
+ if (DevicePtr->hasPendingImages())
+ if (loadImagesOntoDevice(*DevicePtr) != OFFLOAD_SUCCESS)
+ return createStringError(inconvertibleErrorCode(),
+ "Failed to load images on device '%i'",
+ DeviceNo);
+ return *DevicePtr;
}
diff --git a/offload/src/interface.cpp b/offload/src/interface.cpp
index b73e92f17ddb02..21f9114ac2b088 100644
--- a/offload/src/interface.cpp
+++ b/offload/src/interface.cpp
@@ -12,8 +12,11 @@
//===----------------------------------------------------------------------===//
#include "OpenMP/OMPT/Interface.h"
+#include "OffloadPolicy.h"
#include "OpenMP/OMPT/Callback.h"
+#include "OpenMP/omp.h"
#include "PluginManager.h"
+#include "omptarget.h"
#include "private.h"
#include "Shared/EnvironmentVar.h"
@@ -32,6 +35,43 @@
using namespace llvm::omp::target::ompt;
#endif
+// If offload is enabled, ensure that device DeviceID has been initialized.
+//
+// The return bool indicates if the offload is to the host device
+// There are three possible results:
+// - Return false if the target device is ready for offload
+// - Return true without reporting a runtime error if offload is
+// disabled, perhaps because the initial device was specified.
+// - Report a runtime error and return true.
+//
+// If DeviceID == OFFLOAD_DEVICE_DEFAULT, set DeviceID to the default device.
+// This step might be skipped if offload is disabled.
+bool checkDevice(int64_t &DeviceID, ident_t *Loc) {
+ if (OffloadPolicy::get(*PM).Kind == OffloadPolicy::DISABLED) {
+ DP("Offload is disabled\n");
+ return true;
+ }
+
+ if (DeviceID == OFFLOAD_DEVICE_DEFAULT) {
+ DeviceID = omp_get_default_device();
+ DP("Use default device id %" PRId64 "\n", DeviceID);
+ }
+
+ // Proposed behavior for OpenMP 5.2 in OpenMP spec github issue 2669.
+ if (omp_get_num_devices() == 0) {
+ DP("omp_get_num_devices() == 0 but offload is manadatory\n");
+ handleTargetOutcome(false, Loc);
+ return true;
+ }
+
+ if (DeviceID == omp_get_initial_device()) {
+ DP("Device is host (%" PRId64 "), returning as if offload is disabled\n",
+ DeviceID);
+ return true;
+ }
+ return false;
+}
+
////////////////////////////////////////////////////////////////////////////////
/// adds requires flags
EXTERN void __tgt_register_requires(int64_t Flags) {
@@ -85,7 +125,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
DP("Entering data %s region for device %" PRId64 " with %d mappings\n",
RegionName, DeviceId, ArgNum);
- if (checkDeviceAndCtors(DeviceId, Loc)) {
+ if (checkDevice(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
return;
}
@@ -266,7 +306,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
"\n",
DeviceId, DPxPTR(HostPtr));
- if (checkDeviceAndCtors(DeviceId, Loc)) {
+ if (checkDevice(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
return OMP_TGT_FAIL;
}
@@ -404,7 +444,7 @@ EXTERN int __tgt_target_kernel_replay(ident_t *Loc, int64_t DeviceId,
uint64_t LoopTripCount) {
assert(PM && "Runtime not initialized");
OMPT_IF_BUILT(ReturnAddressSetterRAII RA(__builtin_return_address(0)));
- if (checkDeviceAndCtors(DeviceId, Loc)) {
+ if (checkDevice(DeviceId, Loc)) {
DP("Not offloading to device %" PRId64 "\n", DeviceId);
return OMP_TGT_FAIL;
}
diff --git a/offload/src/omptarget.cpp b/offload/src/omptarget.cpp
index 3b627d257a0694..7a2ee1303d68c4 100644
--- a/offload/src/omptarget.cpp
+++ b/offload/src/omptarget.cpp
@@ -131,173 +131,6 @@ static uint64_t getPartialStructRequiredAlignment(void *HstPtrBase) {
return MaxAlignment < BaseAlignment ? MaxAlignment : BaseAlignment;
}
-/// Map global data and execute pending ctors
-static int initLibrary(DeviceTy &Device) {
- /*
- * Map global data
- */
- int32_t DeviceId = Device.DeviceID;
- int Rc = OFFLOAD_SUCCESS;
- {
- std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
- for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
- TranslationTable *TransTable =
- &PM->HostEntriesBeginToTransTable[HostEntriesBegin];
- if (TransTable->HostTable.EntriesBegin ==
- TransTable->HostTable.EntriesEnd) {
- // No host entry so no need to proceed
- continue;
- }
-
- if (TransTable->TargetsTable[DeviceId] != 0) {
- // Library entries have already been processed
- continue;
- }
-
- // 1) get image.
- assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
- "Not expecting a device ID outside the table's bounds!");
- __tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
- if (!Img) {
- REPORT("No image loaded for device id %d.\n", DeviceId);
- Rc = OFFLOAD_FAIL;
- break;
- }
-
- // 2) Load the image onto the given device.
- auto BinaryOrErr = Device.loadBinary(Img);
- if (llvm::Error Err = BinaryOrErr.takeError()) {
- REPORT("Failed to load image %s\n",
- llvm::toString(std::move(Err)).c_str());
- Rc = OFFLOAD_FAIL;
- break;
- }
-
- // 3) Create the translation table.
- llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
- TransTable->TargetsEntries[DeviceId];
- for (__tgt_offload_entry &Entry :
- llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
- __tgt_device_binary &Binary = *BinaryOrErr;
-
- __tgt_offload_entry DeviceEntry = Entry;
- if (Entry.size) {
- if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
- &DeviceEntry.addr) != OFFLOAD_SUCCESS)
- REPORT("Failed to load symbol %s\n", Entry.name);
-
- // If unified memory is active, the corresponding global is a device
- // reference to the host global. We need to initialize the pointer on
- // the deive to point to the memory on the host.
- if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
- (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
- if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
- Entry.size) != OFFLOAD_SUCCESS)
- REPORT("Failed to write symbol for USM %s\n", Entry.name);
- }
- } else if (Entry.addr) {
- if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
- OFFLOAD_SUCCESS)
- REPORT("Failed to load kernel %s\n", Entry.name);
- }
- DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
- DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
- DPxPTR(DeviceEntry.addr));
-
- DeviceEntries.emplace_back(DeviceEntry);
- }
-
- // Set the storage for the table and get a pointer to it.
- __tgt_target_table DeviceTable{&DeviceEntries[0],
- &DeviceEntries[0] + DeviceEntries.size()};
- TransTable->DeviceTables[DeviceId] = DeviceTable;
- __tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
- &TransTable->DeviceTables[DeviceId];
-
- // 4) Verify whether the two table sizes match.
- size_t Hsize =
- TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
- size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
-
- // Invalid image for these host entries!
- if (Hsize != Tsize) {
- REPORT(
- "Host and Target tables mismatch for device id %d [%zx != %zx].\n",
- DeviceId, Hsize, Tsize);
- TransTable->TargetsImages[DeviceId] = 0;
- TransTable->TargetsTable[DeviceId] = 0;
- Rc = OFFLOAD_FAIL;
- break;
- }
-
- MappingInfoTy::HDTTMapAccessorTy HDTTMap =
- Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
-
- __tgt_target_table *HostTable = &TransTable->HostTable;
- for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
- *CurrHostEntry = HostTable->EntriesBegin,
- *EntryDeviceEnd = TargetTable->EntriesEnd;
- CurrDeviceEntry != EntryDeviceEnd;
- CurrDeviceEntry++, CurrHostEntry++) {
- if (CurrDeviceEntry->size == 0)
- continue;
-
- assert(CurrDeviceEntry->size == CurrHostEntry->size &&
- "data size mismatch");
-
- // Fortran may use multiple weak declarations for the same symbol,
- // therefore we must allow for multiple weak symbols to be loaded from
- // the fat binary. Treat these mappings as any other "regular"
- // mapping. Add entry to map.
- if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
- CurrHostEntry->size))
- continue;
-
- void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;
-
- // For indirect mapping, follow the indirection and map the actual
- // target.
- if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
- AsyncInfoTy AsyncInfo(Device);
- void *DevPtr;
- Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
- AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
- if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
- return OFFLOAD_FAIL;
- CurrDeviceEntryAddr = DevPtr;
- }
-
- DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
- ", name \"%s\"\n",
- DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
- CurrDeviceEntry->size, CurrDeviceEntry->name);
- HDTTMap->emplace(new HostDataToTargetTy(
- (uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
- (uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
- (uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
- (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
- (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
- false /*UseHoldRefCount*/, CurrHostEntry->name,
- true /*IsRefCountINF*/));
-
- // Notify about the new mapping.
- if (Device.notifyDataMapped(CurrHostEntry->addr, CurrHostEntry->size))
- ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/103002
More information about the llvm-commits
mailing list