[llvm] [Libomptarget] Rework device initialization and image registration (PR #93844)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Fri May 31 15:41:04 PDT 2024


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/93844

>From a719e4d9658a62bffcbce8cb97aec9ad71743586 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Wed, 29 May 2024 16:45:26 -0500
Subject: [PATCH] [Libomptarget] Rework device initialization and image
 registration

Summary:
Currently, we register images into a linear table according to the
logical OpenMP device identifier. We then initialize all of these images
as one block. This logic requires that each image be compatible with *all*
devices instead of just the ones it can actually run on. This prevents us
from running on systems with heterogeneous devices (e.g. image 1 runs on
device 0 while image 0 runs on device 1).

This patch reworks the logic by making the compatibility check a
per-device query. For each image we scan every device of a compatible
plugin; each device the image can run on is initialized on demand and the
image is registered for that device only.
---
 offload/include/PluginManager.h               |  27 +-
 offload/plugins-nextgen/amdgpu/src/rtl.cpp    |  21 +-
 .../common/include/PluginInterface.h          |  32 ++-
 .../common/src/PluginInterface.cpp            |  72 ++++--
 offload/plugins-nextgen/cuda/src/rtl.cpp      |  55 ++--
 offload/plugins-nextgen/host/src/rtl.cpp      |   4 +-
 offload/src/PluginManager.cpp                 | 234 ++++++++----------
 7 files changed, 229 insertions(+), 216 deletions(-)

diff --git a/offload/include/PluginManager.h b/offload/include/PluginManager.h
index 1d6804da75d92..0e8229ea63d9e 100644
--- a/offload/include/PluginManager.h
+++ b/offload/include/PluginManager.h
@@ -64,10 +64,6 @@ struct PluginManager {
         std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
   }
 
-  /// Initialize as many devices as possible for this plugin. Devices that fail
-  /// to initialize are ignored.
-  void initDevices(GenericPluginTy &RTL);
-
   /// Return the device presented to the user as device \p DeviceNo if it is
   /// initialized and ready. Otherwise return an error explaining the problem.
   llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);
@@ -117,20 +113,31 @@ struct PluginManager {
     return Devices.getExclusiveAccessor();
   }
 
-  int getNumUsedPlugins() const { return DeviceOffsets.size(); }
-
   // Initialize all plugins.
   void initAllPlugins();
 
   /// Iterator range for all plugins (in use or not, but always valid).
   auto plugins() { return llvm::make_pointee_range(Plugins); }
 
+  /// Iterator range for all plugins (in use or not, but always valid).
+  auto plugins() const { return llvm::make_pointee_range(Plugins); }
+
   /// Return the user provided requirements.
   int64_t getRequirements() const { return Requirements.getRequirements(); }
 
   /// Add \p Flags to the user provided requirements.
   void addRequirements(int64_t Flags) { Requirements.addRequirements(Flags); }
 
+  /// Returns the number of plugins that are active.
+  int getNumUsedPlugins() const {
+    int count = 0;
+    for (auto &R : plugins())
+      if (R.is_initialized())
+        ++count;
+
+    return count;
+  }
+
 private:
   bool RTLsLoaded = false;
   llvm::SmallVector<__tgt_bin_desc *> DelayedBinDesc;
@@ -138,11 +145,9 @@ struct PluginManager {
   // List of all plugins, in use or not.
   llvm::SmallVector<std::unique_ptr<GenericPluginTy>> Plugins;
 
-  // Mapping of plugins to offsets in the device table.
-  llvm::DenseMap<const GenericPluginTy *, int32_t> DeviceOffsets;
-
-  // Mapping of plugins to the number of used devices.
-  llvm::DenseMap<const GenericPluginTy *, int32_t> DeviceUsed;
+  // Mapping of plugins to the OpenMP device identifier.
+  llvm::DenseMap<std::pair<const GenericPluginTy *, int32_t>, int32_t>
+      DeviceIds;
 
   // Set of all device images currently in use.
   llvm::DenseSet<const __tgt_device_image *> UsedImages;
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index c6dd954746e4a..f088d5d1685df 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -3163,25 +3163,24 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
   uint16_t getMagicElfBits() const override { return ELF::EM_AMDGPU; }
 
   /// Check whether the image is compatible with an AMDGPU device.
-  Expected<bool> isELFCompatible(StringRef Image) const override {
+  Expected<bool> isELFCompatible(uint32_t DeviceId,
+                                 StringRef Image) const override {
     // Get the associated architecture and flags from the ELF.
     auto ElfOrErr = ELF64LEObjectFile::create(
         MemoryBufferRef(Image, /*Identifier=*/""), /*InitContent=*/false);
     if (!ElfOrErr)
       return ElfOrErr.takeError();
     std::optional<StringRef> Processor = ElfOrErr->tryGetCPUName();
+    if (!Processor)
+      return false;
 
-    for (hsa_agent_t Agent : KernelAgents) {
-      auto TargeTripleAndFeaturesOrError =
-          utils::getTargetTripleAndFeatures(Agent);
-      if (!TargeTripleAndFeaturesOrError)
-        return TargeTripleAndFeaturesOrError.takeError();
-      if (!utils::isImageCompatibleWithEnv(Processor ? *Processor : "",
+    auto TargeTripleAndFeaturesOrError =
+        utils::getTargetTripleAndFeatures(getKernelAgent(DeviceId));
+    if (!TargeTripleAndFeaturesOrError)
+      return TargeTripleAndFeaturesOrError.takeError();
+    return utils::isImageCompatibleWithEnv(Processor ? *Processor : "",
                                            ElfOrErr->getPlatformFlags(),
-                                           *TargeTripleAndFeaturesOrError))
-        return false;
-    }
-    return true;
+                                           *TargeTripleAndFeaturesOrError);
   }
 
   bool isDataExchangable(int32_t SrcDeviceId, int32_t DstDeviceId) override {
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index eda6a4fd541e9..4523771e84b6a 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -994,10 +994,10 @@ struct GenericPluginTy {
   int32_t getNumDevices() const { return NumDevices; }
 
   /// Get the plugin-specific device identifier offset.
-  int32_t getDeviceIdStartIndex() const { return DeviceIdStartIndex; }
-
-  /// Set the plugin-specific device identifier offset.
-  void setDeviceIdStartIndex(int32_t Offset) { DeviceIdStartIndex = Offset; }
+  int32_t getUserId(int32_t DeviceId) const {
+    assert(UserDeviceIds.contains(DeviceId) && "No user-id registered");
+    return UserDeviceIds.at(DeviceId);
+  }
 
   /// Get the ELF code to recognize the binary image of this plugin.
   virtual uint16_t getMagicElfBits() const = 0;
@@ -1059,7 +1059,8 @@ struct GenericPluginTy {
   /// Indicate if an image is compatible with the plugin devices. Notice that
   /// this function may be called before actually initializing the devices. So
   /// we could not move this function into GenericDeviceTy.
-  virtual Expected<bool> isELFCompatible(StringRef Image) const = 0;
+  virtual Expected<bool> isELFCompatible(uint32_t DeviceID,
+                                         StringRef Image) const = 0;
 
 protected:
   /// Indicate whether a device id is valid.
@@ -1070,11 +1071,18 @@ struct GenericPluginTy {
 public:
   // TODO: This plugin interface needs to be cleaned up.
 
-  /// Returns true if the plugin has been initialized.
+  /// Returns non-zero if the plugin runtime has been initialized.
   int32_t is_initialized() const;
 
-  /// Returns non-zero if the provided \p Image can be executed by the runtime.
-  int32_t is_valid_binary(__tgt_device_image *Image, bool Initialized = true);
+  /// Returns non-zero if the \p Image is compatible with the plugin. This
+  /// function does not require the plugin to be initialized before use.
+  int32_t is_plugin_compatible(__tgt_device_image *Image);
+
+  /// Returns non-zero if the \p Image is compatible with the device.
+  int32_t is_device_compatible(int32_t DeviceId, __tgt_device_image *Image);
+
+  /// Returns non-zero if the plugin device has been initialized.
+  int32_t is_device_initialized(int32_t DeviceId) const;
 
   /// Initialize the device inside of the plugin.
   int32_t init_device(int32_t DeviceId);
@@ -1180,7 +1188,7 @@ struct GenericPluginTy {
                            const char **ErrStr);
 
   /// Sets the offset into the devices for use by OMPT.
-  int32_t set_device_offset(int32_t DeviceIdOffset);
+  int32_t set_device_identifier(int32_t UserId, int32_t DeviceId);
 
   /// Returns if the plugin can support auotmatic copy.
   int32_t use_auto_zero_copy(int32_t DeviceId);
@@ -1200,10 +1208,8 @@ struct GenericPluginTy {
   /// Number of devices available for the plugin.
   int32_t NumDevices = 0;
 
-  /// Index offset, which when added to a DeviceId, will yield a unique
-  /// user-observable device identifier. This is especially important when
-  /// DeviceIds of multiple plugins / RTLs need to be distinguishable.
-  int32_t DeviceIdStartIndex = 0;
+  /// Map of plugin device identifiers to the user device identifier.
+  llvm::DenseMap<int32_t, int32_t> UserDeviceIds;
 
   /// Array of pointers to the devices. Initially, they are all set to nullptr.
   /// Once a device is initialized, the pointer is stored in the position given
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 913721a15d713..5a53c479e33d0 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -748,8 +748,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
   if (ompt::Initialized) {
     bool ExpectedStatus = false;
     if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true))
-      performOmptCallback(device_initialize, /*device_num=*/DeviceId +
-                                                 Plugin.getDeviceIdStartIndex(),
+      performOmptCallback(device_initialize, Plugin.getUserId(DeviceId),
                           /*type=*/getComputeUnitKind().c_str(),
                           /*device=*/reinterpret_cast<ompt_device_t *>(this),
                           /*lookup=*/ompt::lookupCallbackByName,
@@ -847,9 +846,7 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
   if (ompt::Initialized) {
     bool ExpectedStatus = true;
     if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false))
-      performOmptCallback(device_finalize,
-                          /*device_num=*/DeviceId +
-                              Plugin.getDeviceIdStartIndex());
+      performOmptCallback(device_finalize, Plugin.getUserId(DeviceId));
   }
 #endif
 
@@ -908,7 +905,7 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
     size_t Bytes =
         getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart);
     performOmptCallback(
-        device_load, /*device_num=*/DeviceId + Plugin.getDeviceIdStartIndex(),
+        device_load, Plugin.getUserId(DeviceId),
         /*FileName=*/nullptr, /*FileOffset=*/0, /*VmaInFile=*/nullptr,
         /*ImgSize=*/Bytes, /*HostAddr=*/InputTgtImage->ImageStart,
         /*DeviceAddr=*/nullptr, /* FIXME: ModuleId */ 0);
@@ -1492,11 +1489,14 @@ Error GenericDeviceTy::syncEvent(void *EventPtr) {
 bool GenericDeviceTy::useAutoZeroCopy() { return useAutoZeroCopyImpl(); }
 
 Error GenericPluginTy::init() {
+  if (Initialized)
+    return Plugin::success();
+
   auto NumDevicesOrErr = initImpl();
   if (!NumDevicesOrErr)
     return NumDevicesOrErr.takeError();
-
   Initialized = true;
+
   NumDevices = *NumDevicesOrErr;
   if (NumDevices == 0)
     return Plugin::success();
@@ -1517,6 +1517,8 @@ Error GenericPluginTy::init() {
 }
 
 Error GenericPluginTy::deinit() {
+  assert(Initialized && "Plugin was not initialized!");
+
   // Deinitialize all active devices.
   for (int32_t DeviceId = 0; DeviceId < NumDevices; ++DeviceId) {
     if (Devices[DeviceId]) {
@@ -1537,7 +1539,11 @@ Error GenericPluginTy::deinit() {
     delete RecordReplay;
 
   // Perform last deinitializations on the plugin.
-  return deinitImpl();
+  if (Error Err = deinitImpl())
+    return Err;
+  Initialized = false;
+
+  return Plugin::success();
 }
 
 Error GenericPluginTy::initDevice(int32_t DeviceId) {
@@ -1599,8 +1605,7 @@ Expected<bool> GenericPluginTy::checkBitcodeImage(StringRef Image) const {
 
 int32_t GenericPluginTy::is_initialized() const { return Initialized; }
 
-int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
-                                         bool Initialized) {
+int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) {
   StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
                    target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
 
@@ -1618,11 +1623,43 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
     auto MatchOrErr = checkELFImage(Buffer);
     if (Error Err = MatchOrErr.takeError())
       return HandleError(std::move(Err));
-    if (!Initialized || !*MatchOrErr)
-      return *MatchOrErr;
+    return *MatchOrErr;
+  }
+  case file_magic::bitcode: {
+    auto MatchOrErr = checkBitcodeImage(Buffer);
+    if (Error Err = MatchOrErr.takeError())
+      return HandleError(std::move(Err));
+    return *MatchOrErr;
+  }
+  default:
+    return false;
+  }
+}
+
+int32_t GenericPluginTy::is_device_compatible(int32_t DeviceId,
+                                              __tgt_device_image *Image) {
+  StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
+                   target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
+
+  auto HandleError = [&](Error Err) -> bool {
+    [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
+    DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
+    return false;
+  };
+  switch (identify_magic(Buffer)) {
+  case file_magic::elf:
+  case file_magic::elf_relocatable:
+  case file_magic::elf_executable:
+  case file_magic::elf_shared_object:
+  case file_magic::elf_core: {
+    auto MatchOrErr = checkELFImage(Buffer);
+    if (Error Err = MatchOrErr.takeError())
+      return HandleError(std::move(Err));
+    if (!*MatchOrErr)
+      return false;
 
     // Perform plugin-dependent checks for the specific architecture if needed.
-    auto CompatibleOrErr = isELFCompatible(Buffer);
+    auto CompatibleOrErr = isELFCompatible(DeviceId, Buffer);
     if (Error Err = CompatibleOrErr.takeError())
       return HandleError(std::move(Err));
     return *CompatibleOrErr;
@@ -1638,6 +1675,10 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
   }
 }
 
+int32_t GenericPluginTy::is_device_initialized(int32_t DeviceId) const {
+  return isValidDeviceId(DeviceId) && Devices[DeviceId] != nullptr;
+}
+
 int32_t GenericPluginTy::init_device(int32_t DeviceId) {
   auto Err = initDevice(DeviceId);
   if (Err) {
@@ -1985,8 +2026,9 @@ int32_t GenericPluginTy::init_device_info(int32_t DeviceId,
   return OFFLOAD_SUCCESS;
 }
 
-int32_t GenericPluginTy::set_device_offset(int32_t DeviceIdOffset) {
-  setDeviceIdStartIndex(DeviceIdOffset);
+int32_t GenericPluginTy::set_device_identifier(int32_t UserId,
+                                               int32_t DeviceId) {
+  UserDeviceIds[DeviceId] = UserId;
 
   return OFFLOAD_SUCCESS;
 }
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index b260334baa18b..62460c07415be 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -1388,8 +1388,9 @@ struct CUDAPluginTy final : public GenericPluginTy {
 
   const char *getName() const override { return GETNAME(TARGET_NAME); }
 
-  /// Check whether the image is compatible with the available CUDA devices.
-  Expected<bool> isELFCompatible(StringRef Image) const override {
+  /// Check whether the image is compatible with a CUDA device.
+  Expected<bool> isELFCompatible(uint32_t DeviceId,
+                                 StringRef Image) const override {
     auto ElfOrErr =
         ELF64LEObjectFile::create(MemoryBufferRef(Image, /*Identifier=*/""),
                                   /*InitContent=*/false);
@@ -1399,33 +1400,29 @@ struct CUDAPluginTy final : public GenericPluginTy {
     // Get the numeric value for the image's `sm_` value.
     auto SM = ElfOrErr->getPlatformFlags() & ELF::EF_CUDA_SM;
 
-    for (int32_t DevId = 0; DevId < getNumDevices(); ++DevId) {
-      CUdevice Device;
-      CUresult Res = cuDeviceGet(&Device, DevId);
-      if (auto Err = Plugin::check(Res, "Error in cuDeviceGet: %s"))
-        return std::move(Err);
-
-      int32_t Major, Minor;
-      Res = cuDeviceGetAttribute(
-          &Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, Device);
-      if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
-        return std::move(Err);
-
-      Res = cuDeviceGetAttribute(
-          &Minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, Device);
-      if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
-        return std::move(Err);
-
-      int32_t ImageMajor = SM / 10;
-      int32_t ImageMinor = SM % 10;
-
-      // A cubin generated for a certain compute capability is supported to
-      // run on any GPU with the same major revision and same or higher minor
-      // revision.
-      if (Major != ImageMajor || Minor < ImageMinor)
-        return false;
-    }
-    return true;
+    CUdevice Device;
+    CUresult Res = cuDeviceGet(&Device, DeviceId);
+    if (auto Err = Plugin::check(Res, "Error in cuDeviceGet: %s"))
+      return std::move(Err);
+
+    int32_t Major, Minor;
+    Res = cuDeviceGetAttribute(
+        &Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, Device);
+    if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
+      return std::move(Err);
+
+    Res = cuDeviceGetAttribute(
+        &Minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, Device);
+    if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
+      return std::move(Err);
+
+    int32_t ImageMajor = SM / 10;
+    int32_t ImageMinor = SM % 10;
+
+    // A cubin generated for a certain compute capability is supported to
+    // run on any GPU with the same major revision and same or higher minor
+    // revision.
+    return Major == ImageMajor && Minor >= ImageMinor;
   }
 };
 
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index ef84cbaf54588..aa59ea618e399 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -418,7 +418,9 @@ struct GenELF64PluginTy final : public GenericPluginTy {
   }
 
   /// All images (ELF-compatible) should be compatible with this plugin.
-  Expected<bool> isELFCompatible(StringRef) const override { return true; }
+  Expected<bool> isELFCompatible(uint32_t, StringRef) const override {
+    return true;
+  }
 
   Triple::ArchType getTripleArch() const override {
 #if defined(__x86_64__)
diff --git a/offload/src/PluginManager.cpp b/offload/src/PluginManager.cpp
index 13f08b142b876..9c1a7d1356c2f 100644
--- a/offload/src/PluginManager.cpp
+++ b/offload/src/PluginManager.cpp
@@ -47,6 +47,9 @@ void PluginManager::deinit() {
   DP("Unloading RTLs...\n");
 
   for (auto &Plugin : Plugins) {
+    if (!Plugin->is_initialized())
+      continue;
+
     if (auto Err = Plugin->deinit()) {
       [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
       DP("Failed to deinit plugin: %s\n", InfoMsg.c_str());
@@ -57,90 +60,15 @@ void PluginManager::deinit() {
   DP("RTLs unloaded!\n");
 }
 
-void PluginManager::initDevices(GenericPluginTy &RTL) {
-  // If this RTL has already been initialized.
-  if (PM->DeviceOffsets.contains(&RTL))
-    return;
-  TIMESCOPE();
-
-  // If this RTL is not already in use, initialize it.
-  assert(RTL.number_of_devices() > 0 && "Tried to initialize useless plugin!");
-
-  // Initialize the device information for the RTL we are about to use.
-  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
-
-  // Initialize the index of this RTL and save it in the used RTLs.
-  int32_t DeviceOffset = ExclusiveDevicesAccessor->size();
-
-  // Set the device identifier offset in the plugin.
-  RTL.set_device_offset(DeviceOffset);
-
-  int32_t NumberOfUserDevices = 0;
-  int32_t NumPD = RTL.number_of_devices();
-  ExclusiveDevicesAccessor->reserve(DeviceOffset + NumPD);
-  // Auto zero-copy is a per-device property. We need to ensure
-  // that all devices are suggesting to use it.
-  bool UseAutoZeroCopy = !(NumPD == 0);
-  for (int32_t PDevI = 0, UserDevId = DeviceOffset; PDevI < NumPD; PDevI++) {
-    auto Device = std::make_unique<DeviceTy>(&RTL, UserDevId, PDevI);
-    if (auto Err = Device->init()) {
-      DP("Skip plugin known device %d: %s\n", PDevI,
-         toString(std::move(Err)).c_str());
-      continue;
-    }
-    UseAutoZeroCopy = UseAutoZeroCopy && Device->useAutoZeroCopy();
-
-    ExclusiveDevicesAccessor->push_back(std::move(Device));
-    ++NumberOfUserDevices;
-    ++UserDevId;
-  }
-
-  // Auto Zero-Copy can only be currently triggered when the system is an
-  // homogeneous APU architecture without attached discrete GPUs.
-  // If all devices suggest to use it, change requirment flags to trigger
-  // zero-copy behavior when mapping memory.
-  if (UseAutoZeroCopy)
-    addRequirements(OMPX_REQ_AUTO_ZERO_COPY);
-
-  DeviceOffsets[&RTL] = DeviceOffset;
-  DeviceUsed[&RTL] = NumberOfUserDevices;
-  DP("Plugin has index %d, exposes %d out of %d devices!\n", DeviceOffset,
-     NumberOfUserDevices, RTL.number_of_devices());
-}
-
 void PluginManager::initAllPlugins() {
-  for (auto &R : Plugins)
-    initDevices(*R);
-}
-
-static void registerImageIntoTranslationTable(TranslationTable &TT,
-                                              int32_t DeviceOffset,
-                                              int32_t NumberOfUserDevices,
-                                              __tgt_device_image *Image) {
-
-  // same size, as when we increase one, we also increase the other.
-  assert(TT.TargetsTable.size() == TT.TargetsImages.size() &&
-         "We should have as many images as we have tables!");
-
-  // Resize the Targets Table and Images to accommodate the new targets if
-  // required
-  unsigned TargetsTableMinimumSize = DeviceOffset + NumberOfUserDevices;
-
-  if (TT.TargetsTable.size() < TargetsTableMinimumSize) {
-    TT.DeviceTables.resize(TargetsTableMinimumSize, {});
-    TT.TargetsImages.resize(TargetsTableMinimumSize, 0);
-    TT.TargetsEntries.resize(TargetsTableMinimumSize, {});
-    TT.TargetsTable.resize(TargetsTableMinimumSize, 0);
-  }
-
-  // Register the image in all devices for this target type.
-  for (int32_t I = 0; I < NumberOfUserDevices; ++I) {
-    // If we are changing the image we are also invalidating the target table.
-    if (TT.TargetsImages[DeviceOffset + I] != Image) {
-      TT.TargetsImages[DeviceOffset + I] = Image;
-      TT.TargetsTable[DeviceOffset + I] =
-          0; // lazy initialization of target table.
+  for (auto &R : plugins()) {
+    if (auto Err = R.init()) {
+      [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
+      DP("Failed to init plugin: %s\n", InfoMsg.c_str());
+      continue;
     }
+    DP("Registered plugin %s with %d visible device(s)\n", R.getName(),
+       R.number_of_devices());
   }
 }
 
@@ -153,27 +81,6 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
     if (Entry.flags == OMP_REGISTER_REQUIRES)
       PM->addRequirements(Entry.data);
 
-  // Initialize all the plugins that have associated images.
-  for (auto &Plugin : Plugins) {
-    // Extract the exectuable image and extra information if availible.
-    for (int32_t i = 0; i < Desc->NumDeviceImages; ++i) {
-      if (Plugin->is_initialized())
-        continue;
-
-      if (!Plugin->is_valid_binary(&Desc->DeviceImages[i],
-                                   /*Initialized=*/false))
-        continue;
-
-      if (auto Err = Plugin->init()) {
-        [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
-        DP("Failed to init plugin: %s\n", InfoMsg.c_str());
-      } else {
-        DP("Registered plugin %s with %d visible device(s)\n",
-           Plugin->getName(), Plugin->number_of_devices());
-      }
-    }
-  }
-
   // Extract the exectuable image and extra information if availible.
   for (int32_t i = 0; i < Desc->NumDeviceImages; ++i)
     PM->addDeviceImage(*Desc, Desc->DeviceImages[i]);
@@ -188,54 +95,109 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
     // Scan the RTLs that have associated images until we find one that supports
     // the current image.
     for (auto &R : PM->plugins()) {
-      if (!R.number_of_devices())
+      if (!R.is_plugin_compatible(Img))
         continue;
 
-      if (!R.is_valid_binary(Img, /*Initialized=*/true)) {
-        DP("Image " DPxMOD " is NOT compatible with RTL %s!\n",
-           DPxPTR(Img->ImageStart), R.getName());
-        continue;
+      if (!R.is_initialized()) {
+        if (auto Err = R.init()) {
+          [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
+          DP("Failed to init plugin: %s\n", InfoMsg.c_str());
+          continue;
+        }
+        DP("Registered plugin %s with %d visible device(s)\n", R.getName(),
+           R.number_of_devices());
       }
 
-      DP("Image " DPxMOD " is compatible with RTL %s!\n",
-         DPxPTR(Img->ImageStart), R.getName());
-
-      PM->initDevices(R);
-
-      // Initialize (if necessary) translation table for this library.
-      PM->TrlTblMtx.lock();
-      if (!PM->HostEntriesBeginToTransTable.count(Desc->HostEntriesBegin)) {
-        PM->HostEntriesBeginRegistrationOrder.push_back(Desc->HostEntriesBegin);
-        TranslationTable &TransTable =
-            (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
-        TransTable.HostTable.EntriesBegin = Desc->HostEntriesBegin;
-        TransTable.HostTable.EntriesEnd = Desc->HostEntriesEnd;
+      if (!R.number_of_devices()) {
+        DP("Skipping plugin %s with no visible devices\n", R.getName());
+        continue;
       }
 
-      // Retrieve translation table for this library.
-      TranslationTable &TransTable =
-          (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
+      for (int32_t DeviceId = 0; DeviceId < R.number_of_devices(); ++DeviceId) {
+        if (!R.is_device_compatible(DeviceId, Img))
+          continue;
+
+        DP("Image " DPxMOD " is compatible with RTL %s device %d!\n",
+           DPxPTR(Img->ImageStart), R.getName(), DeviceId);
+
+        if (!R.is_device_initialized(DeviceId)) {
+          // Initialize the device information for the RTL we are about to use.
+          auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+
+          int32_t UserId = ExclusiveDevicesAccessor->size();
+
+          // Set the device identifier offset in the plugin.
+#ifdef OMPT_SUPPORT
+          R.set_device_identifier(UserId, DeviceId);
+#endif
+
+          auto Device = std::make_unique<DeviceTy>(&R, UserId, DeviceId);
+          if (auto Err = Device->init()) {
+            [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
+            DP("Failed to init device %d: %s\n", DeviceId, InfoMsg.c_str());
+            continue;
+          }
+
+          ExclusiveDevicesAccessor->push_back(std::move(Device));
+
+          // We need to map between the plugin's device identifier and the one
+          // that OpenMP will use.
+          PM->DeviceIds[std::make_pair(&R, DeviceId)] = UserId;
+        }
+
+        // Initialize (if necessary) translation table for this library.
+        PM->TrlTblMtx.lock();
+        if (!PM->HostEntriesBeginToTransTable.count(Desc->HostEntriesBegin)) {
+          PM->HostEntriesBeginRegistrationOrder.push_back(
+              Desc->HostEntriesBegin);
+          TranslationTable &TT =
+              (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
+          TT.HostTable.EntriesBegin = Desc->HostEntriesBegin;
+          TT.HostTable.EntriesEnd = Desc->HostEntriesEnd;
+        }
+
+        // Retrieve translation table for this library.
+        TranslationTable &TT =
+            (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
 
-      DP("Registering image " DPxMOD " with RTL %s!\n", DPxPTR(Img->ImageStart),
-         R.getName());
+        DP("Registering image " DPxMOD " with RTL %s!\n",
+           DPxPTR(Img->ImageStart), R.getName());
 
-      registerImageIntoTranslationTable(TransTable, PM->DeviceOffsets[&R],
-                                        PM->DeviceUsed[&R], Img);
+        auto UserId = PM->DeviceIds[std::make_pair(&R, DeviceId)];
+        if (TT.TargetsTable.size() < static_cast<size_t>(UserId + 1)) {
+          TT.DeviceTables.resize(UserId + 1, {});
+          TT.TargetsImages.resize(UserId + 1, nullptr);
+          TT.TargetsEntries.resize(UserId + 1, {});
+          TT.TargetsTable.resize(UserId + 1, nullptr);
+        }
+
+        // Register the image for this target type and invalidate the table.
+        TT.TargetsImages[UserId] = Img;
+        TT.TargetsTable[UserId] = nullptr;
+        PM->TrlTblMtx.unlock();
+      }
       PM->UsedImages.insert(Img);
 
-      PM->TrlTblMtx.unlock();
       FoundRTL = &R;
-
-      // if an RTL was found we are done - proceed to register the next image
-      break;
     }
-
-    if (!FoundRTL) {
+    if (!FoundRTL)
       DP("No RTL found for image " DPxMOD "!\n", DPxPTR(Img->ImageStart));
-    }
   }
   PM->RTLsMtx.unlock();
 
+  bool UseAutoZeroCopy = Plugins.size() > 0;
+
+  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+  for (const auto &Device : *ExclusiveDevicesAccessor)
+    UseAutoZeroCopy &= Device->useAutoZeroCopy();
+
+  // Auto Zero-Copy can only be currently triggered when the system is a
+  // homogeneous APU architecture without attached discrete GPUs.
+  // If all devices suggest to use it, change requirement flags to trigger
+  // zero-copy behavior when mapping memory.
+  if (UseAutoZeroCopy)
+    addRequirements(OMPX_REQ_AUTO_ZERO_COPY);
+
   DP("Done registering entries!\n");
 }
 
@@ -257,7 +219,7 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
     // Scan the RTLs that have associated images until we find one that supports
     // the current image. We only need to scan RTLs that are already being used.
     for (auto &R : PM->plugins()) {
-      if (!DeviceOffsets.contains(&R))
+      if (!R.is_initialized())
         continue;
 
       // Ensure that we do not use any unused images associated with this RTL.



More information about the llvm-commits mailing list