[llvm] [Offload] Ensure to load images when the device is used (PR #103002)

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 14:25:49 PDT 2024


================
@@ -286,13 +291,194 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
   DP("Done unregistering library!\n");
 }
 
+/// Map global data and execute pending ctors
+static int loadImagesOntoDevice(DeviceTy &Device) {
+  /*
+   * Map global data
+   */
+  int32_t DeviceId = Device.DeviceID;
+  int Rc = OFFLOAD_SUCCESS;
+  {
+    std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
+    for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
+      TranslationTable *TransTable =
+          &PM->HostEntriesBeginToTransTable[HostEntriesBegin];
+      DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
+         TransTable->HostTable.EntriesEnd);
+      if (TransTable->HostTable.EntriesBegin ==
+          TransTable->HostTable.EntriesEnd) {
+        // No host entry so no need to proceed
+        continue;
+      }
+
+      if (TransTable->TargetsTable[DeviceId] != 0) {
+        // Library entries have already been processed
+        continue;
+      }
+
+      // 1) get image.
+      assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
+             "Not expecting a device ID outside the table's bounds!");
+      __tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
+      if (!Img) {
+        REPORT("No image loaded for device id %d.\n", DeviceId);
+        Rc = OFFLOAD_FAIL;
+        break;
+      }
+
+      // 2) Load the image onto the given device.
+      auto BinaryOrErr = Device.loadBinary(Img);
+      if (llvm::Error Err = BinaryOrErr.takeError()) {
+        REPORT("Failed to load image %s\n",
+               llvm::toString(std::move(Err)).c_str());
+        Rc = OFFLOAD_FAIL;
+        break;
+      }
+
+      // 3) Create the translation table.
+      llvm::SmallVector<__tgt_offload_entry> &DeviceEntries =
+          TransTable->TargetsEntries[DeviceId];
+      for (__tgt_offload_entry &Entry :
+           llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
+        __tgt_device_binary &Binary = *BinaryOrErr;
+
+        __tgt_offload_entry DeviceEntry = Entry;
+        if (Entry.size) {
+          if (Device.RTL->get_global(Binary, Entry.size, Entry.name,
+                                     &DeviceEntry.addr) != OFFLOAD_SUCCESS)
+            REPORT("Failed to load symbol %s\n", Entry.name);
+
+          // If unified memory is active, the corresponding global is a device
+          // reference to the host global. We need to initialize the pointer on
+          // the device to point to the memory on the host.
+          if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
+              (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
+            if (Device.RTL->data_submit(DeviceId, DeviceEntry.addr, Entry.addr,
+                                        Entry.size) != OFFLOAD_SUCCESS)
+              REPORT("Failed to write symbol for USM %s\n", Entry.name);
+          }
+        } else if (Entry.addr) {
+          if (Device.RTL->get_function(Binary, Entry.name, &DeviceEntry.addr) !=
+              OFFLOAD_SUCCESS)
+            REPORT("Failed to load kernel %s\n", Entry.name);
+        }
+        DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
+           DPxPTR(Entry.addr), (Entry.size) ? " global" : "", Entry.name,
+           DPxPTR(DeviceEntry.addr));
+
+        DeviceEntries.emplace_back(DeviceEntry);
+      }
+
+      // Set the storage for the table and get a pointer to it.
+      __tgt_target_table DeviceTable{&DeviceEntries[0],
+                                     &DeviceEntries[0] + DeviceEntries.size()};
+      TransTable->DeviceTables[DeviceId] = DeviceTable;
+      __tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
+          &TransTable->DeviceTables[DeviceId];
+
+      // 4) Verify whether the two table sizes match.
+      size_t Hsize =
+          TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
+      size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
+
+      // Invalid image for these host entries!
+      if (Hsize != Tsize) {
+        REPORT(
+            "Host and Target tables mismatch for device id %d [%zx != %zx].\n",
+            DeviceId, Hsize, Tsize);
+        TransTable->TargetsImages[DeviceId] = 0;
+        TransTable->TargetsTable[DeviceId] = 0;
+        Rc = OFFLOAD_FAIL;
+        break;
+      }
+
+      MappingInfoTy::HDTTMapAccessorTy HDTTMap =
+          Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
+
+      __tgt_target_table *HostTable = &TransTable->HostTable;
+      for (__tgt_offload_entry *CurrDeviceEntry = TargetTable->EntriesBegin,
+                               *CurrHostEntry = HostTable->EntriesBegin,
+                               *EntryDeviceEnd = TargetTable->EntriesEnd;
+           CurrDeviceEntry != EntryDeviceEnd;
+           CurrDeviceEntry++, CurrHostEntry++) {
+        if (CurrDeviceEntry->size == 0)
+          continue;
+
+        assert(CurrDeviceEntry->size == CurrHostEntry->size &&
+               "data size mismatch");
+
+        // Fortran may use multiple weak declarations for the same symbol,
+        // therefore we must allow for multiple weak symbols to be loaded from
+        // the fat binary. Treat these mappings as any other "regular"
+        // mapping. Add entry to map.
+        if (Device.getMappingInfo().getTgtPtrBegin(HDTTMap, CurrHostEntry->addr,
+                                                   CurrHostEntry->size))
+          continue;
+
+        void *CurrDeviceEntryAddr = CurrDeviceEntry->addr;
+
+        // For indirect mapping, follow the indirection and map the actual
+        // target.
+        if (CurrDeviceEntry->flags & OMP_DECLARE_TARGET_INDIRECT) {
+          AsyncInfoTy AsyncInfo(Device);
+          void *DevPtr;
+          Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
+                              AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
+          if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
+            return OFFLOAD_FAIL;
+          CurrDeviceEntryAddr = DevPtr;
+        }
+
+        DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
+           ", name \"%s\"\n",
+           DPxPTR(CurrHostEntry->addr), DPxPTR(CurrDeviceEntry->addr),
+           CurrDeviceEntry->size, CurrDeviceEntry->name);
+        HDTTMap->emplace(new HostDataToTargetTy(
+            (uintptr_t)CurrHostEntry->addr /*HstPtrBase*/,
+            (uintptr_t)CurrHostEntry->addr /*HstPtrBegin*/,
+            (uintptr_t)CurrHostEntry->addr + CurrHostEntry->size /*HstPtrEnd*/,
+            (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
+            (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
+            false /*UseHoldRefCount*/, CurrHostEntry->name,
+            true /*IsRefCountINF*/));
----------------
jdoerfert wrote:

I just moved the code. Cleanup is separate. 

https://github.com/llvm/llvm-project/pull/103002


More information about the llvm-commits mailing list