[llvm] [Offload] Ensure to load images when the device is used (PR #103002)
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 13 14:25:49 PDT 2024
================
@@ -286,13 +291,194 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
DP("Done unregistering library!\n");
}
+/// Map global data and execute pending ctors
+static int loadImagesOntoDevice(DeviceTy &Device) {
+ /*
+ * Map global data
+ */
+ int32_t DeviceId = Device.DeviceID;
+ int Rc = OFFLOAD_SUCCESS;
+ {
+ std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
+ for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
+ TranslationTable *TransTable =
+ &PM->HostEntriesBeginToTransTable[HostEntriesBegin];
+ DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
+ TransTable->HostTable.EntriesEnd);
+ if (TransTable->HostTable.EntriesBegin ==
+ TransTable->HostTable.EntriesEnd) {
+ // No host entry, so no need to proceed
+ continue;
+ }
+
+ if (TransTable->TargetsTable[DeviceId] != 0) {
+ // Library entries have already been processed
+ continue;
+ }
+
+ // 1) Get the image.
+ assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
+ "Not expecting a device ID outside the table's bounds!");
+ __tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
+ if (!Img) {
+ REPORT("No image loaded for device id %d.\n", DeviceId);
+ Rc = OFFLOAD_FAIL;
+ break;
+ }
+
+ // 2) Load the image onto the given device.
+ auto BinaryOrErr = Device.loadBinary(Img);
+ if (llvm::Error Err = BinaryOrErr.takeError()) {
+ REPORT("Failed to load image %s\n",
+ llvm::toString(std::move(Err)).c_str());
+ Rc = OFFLOAD_FAIL;
+ break;
+ }
+
+ // 3) Create the translation table.
+ llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
+ TransTable->TargetsEntries[DeviceId];
+ for (__tgt_offload_entry &Entry :
+ llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
+ __tgt_device_binary &Binary = *BinaryOrErr;
+
+ __tgt_offload_entry DeviceEntry = Entry;
+ if (Entry.size) {
+ if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
+ &DeviceEntry.addr) != OFFLOAD_SUCCESS)
+ REPORT("Failed to load symbol %s\n", Entry.name);
+
+ // If unified memory is active, the corresponding global is a device
+ // reference to the host global. We need to initialize the pointer on
+ // the device to point to the memory on the host.
+ if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
+ (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
+ if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
+ Entry.size) != OFFLOAD_SUCCESS)
+ REPORT("Failed to write symbol for USM %s\n", Entry.name);
+ }
+ } else if (Entry.addr) {
+ if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
+ OFFLOAD_SUCCESS)
+ REPORT("Failed to load kernel %s\n", Entry.name);
+ }
+ DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
+ DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
+ DPxPTR(DeviceEntry.addr));
+
+ DeviceEntries.emplace_back(DeviceEntry);
+ }
+
+ // Set the storage for the table and get a pointer to it.
+ __tgt_target_table DeviceTable{&DeviceEntries[0],
+ &DeviceEntries[0] + DeviceEntries.size()};
+ TransTable->DeviceTables[DeviceId] = DeviceTable;
+ __tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
+ &TransTable->DeviceTables[DeviceId];
+
+ // 4) Verify whether the two table sizes match.
+ size_t Hsize =
+ TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
+ size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
+
+ // Invalid image for these host entries!
+ if (Hsize != Tsize) {
+ REPORT(
+ "Host and Target tables mismatch for device id %d [%zx != %zx].\n",
+ DeviceId, Hsize, Tsize);
+ TransTable->TargetsImages[DeviceId] = 0;
+ TransTable->TargetsTable[DeviceId] = 0;
+ Rc = OFFLOAD_FAIL;
+ break;
+ }
+
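+ // 5) Add host-to-device mappings for the global entries so that later
+ // lookups (e.g. getTgtPtrBegin below) find them.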
+ MappingInfoTy::HDTTMapAccessorTy HDTTMap =
+ Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
+
+ __tgt_target_table *HostTable = &TransTable->HostTable;
+ for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
+ *CurrHostEntry = HostTable->EntriesBegin,
+ *EntryDeviceEnd = TargetTable->EntriesEnd;
+ CurrDeviceEntry != EntryDeviceEnd;
+ CurrDeviceEntry++, CurrHostEntry++) {
+ if (CurrDeviceEntry->size == 0)
+ continue;
+
+ assert(CurrDeviceEntry->size == CurrHostEntry->size &&
+ "data size mismatch");
+
+ // Fortran may use multiple weak declarations for the same symbol,
+ // therefore we must allow for multiple weak symbols to be loaded from
+ // the fat binary. Treat these mappings as any other "regular"
+ // mapping. Add entry to map.
+ if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
+ CurrHostEntry->size))
+ continue;
+
+ void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;
+
+ // For indirect mapping, follow the indirection and map the actual
+ // target.
+ if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
+ AsyncInfoTy AsyncInfo(Device);
+ void *DevPtr;
+ Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
+ AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
+ if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
+ return OFFLOAD_FAIL;
+ CurrDeviceEntryAddr = DevPtr;
+ }
+
+ DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
+ ", name \"%s\"\n",
+ DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
+ CurrDeviceEntry->size, CurrDeviceEntry->name);
+ HDTTMap->emplace(new HostDataToTargetTy(
+ (uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
+ (uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
+ (uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
+ (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
+ (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
+ false /*UseHoldRefCount*/, CurrHostEntry->name,
+ true /*IsRefCountINF*/));
----------------
jdoerfert wrote:
I just moved the code. Cleanup is separate.
https://github.com/llvm/llvm-project/pull/103002
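For reference, here is a minimal sketch (not part of this patch) of how such a
once-per-device load could be guarded at the point where a device is first
used. DeviceTy, OFFLOAD_SUCCESS, and OFFLOAD_FAIL are the runtime's existing
types and return codes; ensureImagesLoaded and the HasPendingImages flag are
hypothetical names used only for illustration:

  #include <mutex>

  // Hypothetical wrapper: load all registered images onto this device the
  // first time the device is actually used, and only once.
  static int ensureImagesLoaded(DeviceTy &Device) {
    static std::mutex Mtx; // serialize concurrent first uses
    std::lock_guard<std::mutex> LG(Mtx);
    if (Device.HasPendingImages) { // hypothetical "not yet loaded" flag
      if (loadImagesOntoDevice(Device) != OFFLOAD_SUCCESS)
        return OFFLOAD_FAIL;
      Device.HasPendingImages = false;
    }
    return OFFLOAD_SUCCESS;
  }

Guarding with a plain mutex rather than std::call_once keeps the sketch close
to the locking style already used above (cf. PM->TrlTblMtx).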