[Openmp-commits] [openmp] [OpenMP] Ensure `Devices` is accessed exclusively (PR #74374)

Johannes Doerfert via Openmp-commits openmp-commits at lists.llvm.org
Mon Dec 4 13:51:07 PST 2023


https://github.com/jdoerfert created https://github.com/llvm/llvm-project/pull/74374

We accessed the `Devices` container most of the time while holding the RTLsMtx, but not always. Sometimes we used the mutex for the size query, but then accessed Devices again unguarded. From now on, we properly encapsulate the container in a ProtectedObj, which ensures exclusive access. We also hide the "isReady" check in the `getDevice` accessor and use an `llvm::Expected` so that errors can be returned. Note that most callers that previously returned a failure value for an unavailable device now abort via FATAL_MESSAGE instead.

>From 7d1f652303f7c83690c4a05be85fccef4ef22773 Mon Sep 17 00:00:00 2001
From: Johannes Doerfert <johannes at jdoerfert.de>
Date: Mon, 4 Dec 2023 13:10:36 -0800
Subject: [PATCH] [OpenMP] Ensure `Devices` is accessed exclusively

We accessed the `Devices` container most of the time while holding the
RTLsMtx, but not always. Sometimes we used the mutex for the size query,
but then accessed Devices again unguarded. From now on, we properly
encapsulate the container in a ProtectedObj, which ensures exclusive
access. We also hide the "isReady" check in the `getDevice` accessor
and use an `llvm::Expected` so that errors can be returned. Note that
most callers that previously returned a failure value for an unavailable
device now abort via FATAL_MESSAGE instead.
---
 openmp/libomptarget/include/PluginManager.h   |  37 ++++-
 openmp/libomptarget/include/Shared/Debug.h    |   7 +-
 openmp/libomptarget/include/device.h          |   7 +-
 openmp/libomptarget/src/OpenMP/InteropAPI.cpp |  29 +++-
 openmp/libomptarget/src/PluginManager.cpp     |  52 ++++++-
 openmp/libomptarget/src/api.cpp               | 108 +++++++-------
 openmp/libomptarget/src/device.cpp            |  37 +----
 openmp/libomptarget/src/interface.cpp         |  52 +++----
 openmp/libomptarget/src/omptarget.cpp         | 134 +++++++-----------
 9 files changed, 237 insertions(+), 226 deletions(-)

diff --git a/openmp/libomptarget/include/PluginManager.h b/openmp/libomptarget/include/PluginManager.h
index 94ecce01ca74c..bc71e5d70474b 100644
--- a/openmp/libomptarget/include/PluginManager.h
+++ b/openmp/libomptarget/include/PluginManager.h
@@ -14,6 +14,7 @@
 #define OMPTARGET_PLUGIN_MANAGER_H
 
 #include "DeviceImage.h"
+#include "ExclusiveAccess.h"
 #include "Shared/APITypes.h"
 #include "Shared/PluginAPI.h"
 #include "Shared/Requirements.h"
@@ -25,6 +26,7 @@
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
 
 #include <cstdint>
 #include <list>
@@ -75,6 +77,13 @@ struct PluginAdaptorTy {
 
 /// Struct for the data required to handle plugins
 struct PluginManager {
+  /// Type of the devices container. We hand out DeviceTy& to queries which are
+  /// stable addresses regardless if the container changes.
+  using DeviceContainerTy = llvm::SmallVector<std::unique_ptr<DeviceTy>>;
+
+  /// Exclusive accessor type for the device container.
+  using ExclusiveDevicesAccessorTy = Accessor<DeviceContainerTy>;
+
   PluginManager() {}
 
   void init();
@@ -89,13 +98,19 @@ struct PluginManager {
     DeviceImages.emplace_back(std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
   }
 
+  /// Return the device presented to the user as device \p DeviceNo if it is
+  /// initialized and ready. Otherwise return an error explaining the problem.
+  llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);
+
+  /// Iterate over all devices in the container guarded by \p DevicesAccessor,
+  /// initialized or not. Keep the accessor alive for the whole iteration.
+  auto devices(ExclusiveDevicesAccessorTy &DevicesAccessor) {
+    return llvm::make_pointee_range(*DevicesAccessor);
+  }
+
   /// Iterate over all device images registered with this plugin.
   auto deviceImages() { return llvm::make_pointee_range(DeviceImages); }
 
-  /// Devices associated with RTLs
-  llvm::SmallVector<std::unique_ptr<DeviceTy>> Devices;
-  std::mutex RTLsMtx; ///< For RTLs and Devices
-
   /// Translation table retreived from the binary
   HostEntriesBeginToTransTableTy HostEntriesBeginToTransTable;
   std::mutex TrlTblMtx; ///< For Translation Table
@@ -124,9 +139,12 @@ struct PluginManager {
     DelayedBinDesc.clear();
   }
 
-  int getNumDevices() {
-    std::lock_guard<decltype(RTLsMtx)> Lock(RTLsMtx);
-    return Devices.size();
+  /// Return the number of usable devices.
+  int getNumDevices() { return getExclusiveDevicesAccessor()->size(); }
+
+  /// Return an exclusive handle to access the devices container.
+  ExclusiveDevicesAccessorTy getExclusiveDevicesAccessor() {
+    return Devices.getExclusiveAccessor();
   }
 
   int getNumUsedPlugins() const {
@@ -166,6 +184,11 @@ struct PluginManager {
 
   /// The user provided requirements.
   RequirementCollection Requirements;
+
+  std::mutex RTLsMtx; ///< For RTLs
+
+  /// Devices associated with plugins, accesses to the container are exclusive.
+  ProtectedObj<DeviceContainerTy> Devices;
 };
 
 extern PluginManager *PM;
diff --git a/openmp/libomptarget/include/Shared/Debug.h b/openmp/libomptarget/include/Shared/Debug.h
index 9f8818429c779..a39626d15386b 100644
--- a/openmp/libomptarget/include/Shared/Debug.h
+++ b/openmp/libomptarget/include/Shared/Debug.h
@@ -115,15 +115,16 @@ inline uint32_t getDebugLevel() {
 /// Print fatal error message with an error string and error identifier
 #define FATAL_MESSAGE0(_num, _str)                                             \
   do {                                                                         \
-    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", _num, _str); \
+    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: %s\n", (int)(_num), \
+            _str);                                                             \
     abort();                                                                   \
   } while (0)

 /// Print fatal error message with a printf string and error identifier
 #define FATAL_MESSAGE(_num, _str, ...)                                         \
   do {                                                                         \
-    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n", _num,  \
-            __VA_ARGS__);                                                      \
+    fprintf(stderr, GETNAME(TARGET_NAME) " fatal error %d: " _str "\n",        \
+            (int)(_num), __VA_ARGS__);                                         \
     abort();                                                                   \
   } while (0)
 
diff --git a/openmp/libomptarget/include/device.h b/openmp/libomptarget/include/device.h
index 05ed6546557a4..5146fc1444b44 100644
--- a/openmp/libomptarget/include/device.h
+++ b/openmp/libomptarget/include/device.h
@@ -202,9 +202,8 @@ struct DeviceTy {
   /// completed and AsyncInfo.isDone() returns true.
   int32_t queryAsync(AsyncInfoTy &AsyncInfo);
 
-  /// Calls the corresponding print in the \p RTLDEVID
-  /// device RTL to obtain the information of the specific device.
-  bool printDeviceInfo(int32_t RTLDevID);
+  /// Calls the corresponding print device info function in the plugin.
+  bool printDeviceInfo();
 
   /// Event related interfaces.
   /// {
@@ -245,6 +244,4 @@ struct DeviceTy {
   llvm::DenseMap<llvm::StringRef, OffloadEntryTy *> DeviceOffloadEntries;
 };
 
-extern bool deviceIsReady(int DeviceNum);
-
 #endif
diff --git a/openmp/libomptarget/src/OpenMP/InteropAPI.cpp b/openmp/libomptarget/src/OpenMP/InteropAPI.cpp
index 6a40dbca87afd..c96ce2ce60b75 100644
--- a/openmp/libomptarget/src/OpenMP/InteropAPI.cpp
+++ b/openmp/libomptarget/src/OpenMP/InteropAPI.cpp
@@ -13,6 +13,9 @@
 #include "PluginManager.h"
 #include "device.h"
 #include "omptarget.h"
+#include "llvm/Support/Error.h"
+#include <cstdlib>
+#include <cstring>
 
 extern "C" {
 
@@ -190,6 +193,14 @@ __OMP_GET_INTEROP_TY3(const char *, type_desc)
 __OMP_GET_INTEROP_TY3(const char *, rc_desc)
 #undef __OMP_GET_INTEROP_TY3
 
+static const char *copyErrorString(llvm::Error &&Err) {
+  // TODO: Use the error string while avoiding leaks.
+  std::string ErrMsg = llvm::toString(std::move(Err));
+  char *UsrMsg = static_cast<char *>(malloc(ErrMsg.size() + 1));
+  // On allocation failure, fall back to the old static message.
+  return UsrMsg ? strcpy(UsrMsg, ErrMsg.c_str()) : "Device not ready!";
+}
+
 extern "C" {
 
 void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
@@ -211,12 +222,14 @@ void __tgt_interop_init(ident_t *LocRef, int32_t Gtid,
   }
 
   InteropPtr = new omp_interop_val_t(DeviceId, InteropType);
-  if (!deviceIsReady(DeviceId)) {
-    InteropPtr->err_str = "Device not ready!";
+
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr) {
+    InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
     return;
   }
 
-  DeviceTy &Device = *PM->Devices[DeviceId];
+  DeviceTy &Device = *DeviceOrErr;
   if (!Device.RTL || !Device.RTL->init_device_info ||
       Device.RTL->init_device_info(DeviceId, &(InteropPtr)->device_info,
                                    &(InteropPtr)->err_str)) {
@@ -248,8 +261,9 @@ void __tgt_interop_use(ident_t *LocRef, int32_t Gtid,
   assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
          "Inconsistent device-id usage!");
 
-  if (!deviceIsReady(DeviceId)) {
-    InteropPtr->err_str = "Device not ready!";
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr) {
+    InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
     return;
   }
 
@@ -277,8 +291,9 @@ void __tgt_interop_destroy(ident_t *LocRef, int32_t Gtid,
 
   assert((DeviceId == -1 || InteropVal->device_id == DeviceId) &&
          "Inconsistent device-id usage!");
-  if (!deviceIsReady(DeviceId)) {
-    InteropPtr->err_str = "Device not ready!";
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr) {
+    InteropPtr->err_str = copyErrorString(DeviceOrErr.takeError());
     return;
   }
 
diff --git a/openmp/libomptarget/src/PluginManager.cpp b/openmp/libomptarget/src/PluginManager.cpp
index e6dedeb699b14..931143ad2347d 100644
--- a/openmp/libomptarget/src/PluginManager.cpp
+++ b/openmp/libomptarget/src/PluginManager.cpp
@@ -11,6 +11,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "PluginManager.h"
+#include "Shared/Debug.h"
+
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace llvm;
 using namespace llvm::sys;
@@ -71,7 +75,12 @@ PluginAdaptorTy::PluginAdaptorTy(const std::string &Name) : Name(Name) {
 
 void PluginAdaptorTy::addOffloadEntries(DeviceImageTy &DI) {
   for (int32_t I = 0; I < NumberOfDevices; ++I) {
-    DeviceTy &Device = *PM->Devices[DeviceOffset + I];
+    auto DeviceOrErr = PM->getDevice(DeviceOffset + I);
+    if (!DeviceOrErr)
+      FATAL_MESSAGE(DeviceOffset + I, "%s",
+                    toString(DeviceOrErr.takeError()).c_str());
+
+    DeviceTy &Device = *DeviceOrErr;
     for (OffloadEntryTy &Entry : DI.entries())
       Device.addOffloadEntry(Entry);
   }
@@ -97,14 +106,15 @@ void PluginManager::initPlugin(PluginAdaptorTy &Plugin) {
     return;
 
   // Initialize the device information for the RTL we are about to use.
-  const size_t Start = Devices.size();
-  Devices.reserve(Start + Plugin.NumberOfDevices);
+  auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+  const size_t Start = ExclusiveDevicesAccessor->size();
+  ExclusiveDevicesAccessor->reserve(Start + Plugin.NumberOfDevices);
   for (int32_t DeviceId = 0; DeviceId < Plugin.NumberOfDevices; DeviceId++) {
-    Devices.push_back(std::make_unique<DeviceTy>(&Plugin));
+    ExclusiveDevicesAccessor->push_back(std::make_unique<DeviceTy>(&Plugin));
     // global device ID
-    Devices[Start + DeviceId]->DeviceID = Start + DeviceId;
+    (*ExclusiveDevicesAccessor)[Start + DeviceId]->DeviceID = Start + DeviceId;
     // RTL local device ID
-    Devices[Start + DeviceId]->RTLDeviceID = DeviceId;
+    (*ExclusiveDevicesAccessor)[Start + DeviceId]->RTLDeviceID = DeviceId;
   }
 
   // Initialize the index of this RTL and save it in the used RTLs.
@@ -254,7 +264,12 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
       // Execute dtors for static objects if the device has been used, i.e.
       // if its PendingCtors list has been emptied.
       for (int32_t I = 0; I < FoundRTL->NumberOfDevices; ++I) {
-        DeviceTy &Device = *PM->Devices[FoundRTL->DeviceOffset + I];
+        auto DeviceOrErr = PM->getDevice(FoundRTL->DeviceOffset + I);
+        if (!DeviceOrErr)
+          FATAL_MESSAGE(FoundRTL->DeviceOffset + I, "%s",
+                        toString(DeviceOrErr.takeError()).c_str());
+
+        DeviceTy &Device = *DeviceOrErr;
         Device.PendingGlobalsMtx.lock();
         if (Device.PendingCtorsDtors[Desc].PendingCtors.empty()) {
           AsyncInfoTy AsyncInfo(Device);
@@ -313,3 +328,26 @@ void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
 
   DP("Done unregistering library!\n");
 }
+
+Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
+  DeviceTy *DevicePtr = nullptr;
+  {
+    // Lock only for the lookup; DeviceTy addresses are stable (unique_ptr).
+    auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
+    if (DeviceNo >= ExclusiveDevicesAccessor->size())
+      return createStringError(
+          inconvertibleErrorCode(),
+          "Device number '%u' out of range, only %zu devices available",
+          DeviceNo, ExclusiveDevicesAccessor->size());
+    DevicePtr = (*ExclusiveDevicesAccessor)[DeviceNo].get();
+  }
+  DP("Is the device %d (local ID %d) initialized? %d\n", DeviceNo,
+     DevicePtr->RTLDeviceID, DevicePtr->IsInit);
+
+  // Init the device if not done before, without holding the devices lock.
+  if (!DevicePtr->IsInit && DevicePtr->initOnce() != OFFLOAD_SUCCESS)
+    return createStringError(inconvertibleErrorCode(),
+                             "Failed to init device %u", DeviceNo);
+  DP("Device %d is ready to use.\n", DeviceNo);
+  return *DevicePtr;
+}
diff --git a/openmp/libomptarget/src/api.cpp b/openmp/libomptarget/src/api.cpp
index cc4cca286df51..0341e0c754649 100644
--- a/openmp/libomptarget/src/api.cpp
+++ b/openmp/libomptarget/src/api.cpp
@@ -110,21 +110,18 @@ EXTERN int omp_target_is_present(const void *Ptr, int DeviceNum) {
     return true;
   }
 
-  size_t NumDevices = PM->getNumDevices();
-  if (NumDevices <= (size_t)DeviceNum) {
-    DP("Call to omp_target_is_present with invalid device ID, returning "
-       "false\n");
-    return false;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
   // omp_target_is_present tests whether a host pointer refers to storage that
   // is mapped to a given device. However, due to the lack of the storage size,
   // only check 1 byte. Cannot set size 0 which checks whether the pointer (zero
   // lengh array) is mapped instead of the referred storage.
-  TargetPointerResultTy TPR = Device.getTgtPtrBegin(const_cast<void *>(Ptr), 1,
-                                                    /*UpdateRefCount=*/false,
-                                                    /*UseHoldRefCount=*/false);
+  TargetPointerResultTy TPR =
+      DeviceOrErr->getTgtPtrBegin(const_cast<void *>(Ptr), 1,
+                                  /*UpdateRefCount=*/false,
+                                  /*UseHoldRefCount=*/false);
   int Rc = TPR.isPresent();
   DP("Call to omp_target_is_present returns %d\n", Rc);
   return Rc;
@@ -150,16 +147,6 @@ EXTERN int omp_target_memcpy(void *Dst, const void *Src, size_t Length,
     return OFFLOAD_FAIL;
   }
 
-  if (SrcDevice != omp_get_initial_device() && !deviceIsReady(SrcDevice)) {
-    REPORT("omp_target_memcpy returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
-
-  if (DstDevice != omp_get_initial_device() && !deviceIsReady(DstDevice)) {
-    REPORT("omp_target_memcpy returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
-
   int Rc = OFFLOAD_SUCCESS;
   void *SrcAddr = (char *)const_cast<void *>(Src) + SrcOffset;
   void *DstAddr = (char *)Dst + DstOffset;
@@ -172,35 +159,49 @@ EXTERN int omp_target_memcpy(void *Dst, const void *Src, size_t Length,
       Rc = OFFLOAD_FAIL;
   } else if (SrcDevice == omp_get_initial_device()) {
     DP("copy from host to device\n");
-    DeviceTy &DstDev = *PM->Devices[DstDevice];
-    AsyncInfoTy AsyncInfo(DstDev);
-    Rc = DstDev.submitData(DstAddr, SrcAddr, Length, AsyncInfo);
+    auto DstDeviceOrErr = PM->getDevice(DstDevice);
+    if (!DstDeviceOrErr)
+      FATAL_MESSAGE(DstDevice, "%s",
+                    toString(DstDeviceOrErr.takeError()).c_str());
+    AsyncInfoTy AsyncInfo(*DstDeviceOrErr);
+    Rc = DstDeviceOrErr->submitData(DstAddr, SrcAddr, Length, AsyncInfo);
   } else if (DstDevice == omp_get_initial_device()) {
     DP("copy from device to host\n");
-    DeviceTy &SrcDev = *PM->Devices[SrcDevice];
-    AsyncInfoTy AsyncInfo(SrcDev);
-    Rc = SrcDev.retrieveData(DstAddr, SrcAddr, Length, AsyncInfo);
+    auto SrcDeviceOrErr = PM->getDevice(SrcDevice);
+    if (!SrcDeviceOrErr)
+      FATAL_MESSAGE(SrcDevice, "%s",
+                    toString(SrcDeviceOrErr.takeError()).c_str());
+    AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+    Rc = SrcDeviceOrErr->retrieveData(DstAddr, SrcAddr, Length, AsyncInfo);
   } else {
     DP("copy from device to device\n");
-    DeviceTy &SrcDev = *PM->Devices[SrcDevice];
-    DeviceTy &DstDev = *PM->Devices[DstDevice];
+    auto SrcDeviceOrErr = PM->getDevice(SrcDevice);
+    if (!SrcDeviceOrErr)
+      FATAL_MESSAGE(SrcDevice, "%s",
+                    toString(SrcDeviceOrErr.takeError()).c_str());
+
+    auto DstDeviceOrErr = PM->getDevice(DstDevice);
+    if (!DstDeviceOrErr)
+      FATAL_MESSAGE(DstDevice, "%s",
+                    toString(DstDeviceOrErr.takeError()).c_str());
     // First try to use D2D memcpy which is more efficient. If fails, fall back
     // to unefficient way.
-    if (SrcDev.isDataExchangable(DstDev)) {
-      AsyncInfoTy AsyncInfo(SrcDev);
-      Rc = SrcDev.dataExchange(SrcAddr, DstDev, DstAddr, Length, AsyncInfo);
+    if (SrcDeviceOrErr->isDataExchangable(*DstDeviceOrErr)) {
+      AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+      Rc = SrcDeviceOrErr->dataExchange(SrcAddr, *DstDeviceOrErr, DstAddr,
+                                        Length, AsyncInfo);
       if (Rc == OFFLOAD_SUCCESS)
         return OFFLOAD_SUCCESS;
     }

     void *Buffer = malloc(Length);
     {
-      AsyncInfoTy AsyncInfo(SrcDev);
-      Rc = SrcDev.retrieveData(Buffer, SrcAddr, Length, AsyncInfo);
+      AsyncInfoTy AsyncInfo(*SrcDeviceOrErr);
+      Rc = SrcDeviceOrErr->retrieveData(Buffer, SrcAddr, Length, AsyncInfo);
     }
     if (Rc == OFFLOAD_SUCCESS) {
-      AsyncInfoTy AsyncInfo(DstDev);
-      Rc = DstDev.submitData(DstAddr, Buffer, Length, AsyncInfo);
+      AsyncInfoTy AsyncInfo(*DstDeviceOrErr);
+      Rc = DstDeviceOrErr->submitData(DstAddr, Buffer, Length, AsyncInfo);
     }
     free(Buffer);
   }
@@ -507,15 +508,13 @@ EXTERN int omp_target_associate_ptr(const void *HostPtr, const void *DevicePtr,
     return OFFLOAD_FAIL;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    REPORT("omp_target_associate_ptr returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
   void *DeviceAddr = (void *)((uint64_t)DevicePtr + (uint64_t)DeviceOffset);
-  int Rc = Device.associatePtr(const_cast<void *>(HostPtr),
-                               const_cast<void *>(DeviceAddr), Size);
+  int Rc = DeviceOrErr->associatePtr(const_cast<void *>(HostPtr),
+                                     const_cast<void *>(DeviceAddr), Size);
   DP("omp_target_associate_ptr returns %d\n", Rc);
   return Rc;
 }
@@ -537,13 +536,11 @@ EXTERN int omp_target_disassociate_ptr(const void *HostPtr, int DeviceNum) {
     return OFFLOAD_FAIL;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    REPORT("omp_target_disassociate_ptr returns OFFLOAD_FAIL\n");
-    return OFFLOAD_FAIL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
-  int Rc = Device.disassociatePtr(const_cast<void *>(HostPtr));
+  int Rc = DeviceOrErr->disassociatePtr(const_cast<void *>(HostPtr));
   DP("omp_target_disassociate_ptr returns %d\n", Rc);
   return Rc;
 }
@@ -570,15 +567,14 @@ EXTERN void *omp_get_mapped_ptr(const void *Ptr, int DeviceNum) {
     return nullptr;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    REPORT("Device %d is not ready, returning nullptr.\n", DeviceNum);
-    return nullptr;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  auto &Device = *PM->Devices[DeviceNum];
-  TargetPointerResultTy TPR = Device.getTgtPtrBegin(const_cast<void *>(Ptr), 1,
-                                                    /*UpdateRefCount=*/false,
-                                                    /*UseHoldRefCount=*/false);
+  TargetPointerResultTy TPR =
+      DeviceOrErr->getTgtPtrBegin(const_cast<void *>(Ptr), 1,
+                                  /*UpdateRefCount=*/false,
+                                  /*UseHoldRefCount=*/false);
   if (!TPR.isPresent()) {
     DP("Ptr " DPxMOD "is not present on device %d, returning nullptr.\n",
        DPxPTR(Ptr), DeviceNum);
diff --git a/openmp/libomptarget/src/device.cpp b/openmp/libomptarget/src/device.cpp
index d3481d42af967..ad9563e04def4 100644
--- a/openmp/libomptarget/src/device.cpp
+++ b/openmp/libomptarget/src/device.cpp
@@ -711,10 +711,10 @@ int32_t DeviceTy::launchKernel(void *TgtEntryPtr, void **TgtVarsPtr,
 }
 
 // Run region on device
-bool DeviceTy::printDeviceInfo(int32_t RTLDevId) {
+bool DeviceTy::printDeviceInfo() {
   if (!RTL->print_device_info)
     return false;
-  RTL->print_device_info(RTLDevId);
+  RTL->print_device_info(RTLDeviceID);
   return true;
 }
 
@@ -778,39 +778,6 @@ int32_t DeviceTy::destroyEvent(void *Event) {
   return OFFLOAD_SUCCESS;
 }
 
-/// Check whether a device has an associated RTL and initialize it if it's not
-/// already initialized.
-bool deviceIsReady(int DeviceNum) {
-  DP("Checking whether device %d is ready.\n", DeviceNum);
-  // Devices.size() can only change while registering a new
-  // library, so try to acquire the lock of RTLs' mutex.
-  size_t DevicesSize;
-  {
-    std::lock_guard<decltype(PM->RTLsMtx)> LG(PM->RTLsMtx);
-    DevicesSize = PM->Devices.size();
-  }
-  if (DevicesSize <= (size_t)DeviceNum) {
-    DP("Device ID  %d does not have a matching RTL\n", DeviceNum);
-    return false;
-  }
-
-  // Get device info
-  DeviceTy &Device = *PM->Devices[DeviceNum];
-
-  DP("Is the device %d (local ID %d) initialized? %d\n", DeviceNum,
-     Device.RTLDeviceID, Device.IsInit);
-
-  // Init the device if not done before
-  if (!Device.IsInit && Device.initOnce() != OFFLOAD_SUCCESS) {
-    DP("Failed to init device %d\n", DeviceNum);
-    return false;
-  }
-
-  DP("Device %d is ready to use.\n", DeviceNum);
-
-  return true;
-}
-
 void DeviceTy::addOffloadEntry(OffloadEntryTy &Entry) {
   std::lock_guard<decltype(PendingGlobalsMtx)> Lock(PendingGlobalsMtx);
   DeviceOffloadEntries[Entry.getName()] = &Entry;
diff --git a/openmp/libomptarget/src/interface.cpp b/openmp/libomptarget/src/interface.cpp
index 62cf2262deb62..d92f40ce1d14e 100644
--- a/openmp/libomptarget/src/interface.cpp
+++ b/openmp/libomptarget/src/interface.cpp
@@ -95,8 +95,11 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
   }
 #endif
 
-  DeviceTy &Device = *PM->Devices[DeviceId];
-  TargetAsyncInfoTy TargetAsyncInfo(Device);
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
+
+  TargetAsyncInfoTy TargetAsyncInfo(*DeviceOrErr);
   AsyncInfoTy &AsyncInfo = TargetAsyncInfo;
 
   /// RAII to establish tool anchors before and after data begin / end / update
@@ -115,7 +118,7 @@ targetData(ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
                                              OMPT_GET_RETURN_ADDRESS(0));)
 
   int Rc = OFFLOAD_SUCCESS;
-  Rc = TargetDataFunction(Loc, Device, ArgNum, ArgsBase, Args, ArgSizes,
+  Rc = TargetDataFunction(Loc, *DeviceOrErr, ArgNum, ArgsBase, Args, ArgSizes,
                           ArgTypes, ArgNames, ArgMappers, AsyncInfo,
                           false /* FromMapper */);
 
@@ -286,8 +289,11 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
   }
 #endif
 
-  DeviceTy &Device = *PM->Devices[DeviceId];
-  TargetAsyncInfoTy TargetAsyncInfo(Device);
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
+
+  TargetAsyncInfoTy TargetAsyncInfo(*DeviceOrErr);
   AsyncInfoTy &AsyncInfo = TargetAsyncInfo;
   /// RAII to establish tool anchors before and after target region
   OMPT_IF_BUILT(InterfaceRAII TargetRAII(
@@ -295,7 +301,7 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
                     /* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)
 
   int Rc = OFFLOAD_SUCCESS;
-  Rc = target(Loc, Device, HostPtr, *KernelArgs, AsyncInfo);
+  Rc = target(Loc, *DeviceOrErr, HostPtr, *KernelArgs, AsyncInfo);
 
   if (Rc == OFFLOAD_SUCCESS)
     Rc = AsyncInfo.synchronize();
@@ -339,14 +345,12 @@ EXTERN int __tgt_activate_record_replay(int64_t DeviceId, uint64_t MemorySize,
                                         void *VAddr, bool IsRecord,
                                         bool SaveOutput,
                                         uint64_t &ReqPtrArgOffset) {
-  if (!deviceIsReady(DeviceId)) {
-    DP("Device %" PRId64 " is not ready\n", DeviceId);
-    return OMP_TGT_FAIL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceId];
   [[maybe_unused]] int Rc = target_activate_rr(
-      Device, MemorySize, VAddr, IsRecord, SaveOutput, ReqPtrArgOffset);
+      *DeviceOrErr, MemorySize, VAddr, IsRecord, SaveOutput, ReqPtrArgOffset);
   assert(Rc == OFFLOAD_SUCCESS &&
          "__tgt_activate_record_replay unexpected failure!");
   return OMP_TGT_SUCCESS;
@@ -380,16 +384,19 @@ EXTERN int __tgt_target_kernel_replay(ident_t *Loc, int64_t DeviceId,
     DP("Not offloading to device %" PRId64 "\n", DeviceId);
     return OMP_TGT_FAIL;
   }
-  DeviceTy &Device = *PM->Devices[DeviceId];
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
+
   /// RAII to establish tool anchors before and after target region
   OMPT_IF_BUILT(InterfaceRAII TargetRAII(
                     RegionInterface.getCallbacks<ompt_target>(), DeviceId,
                     /* CodePtr */ OMPT_GET_RETURN_ADDRESS(0));)
 
-  AsyncInfoTy AsyncInfo(Device);
-  int Rc = target_replay(Loc, Device, HostPtr, DeviceMemory, DeviceMemorySize,
-                         TgtArgs, TgtOffsets, NumArgs, NumTeams, ThreadLimit,
-                         LoopTripCount, AsyncInfo);
+  AsyncInfoTy AsyncInfo(*DeviceOrErr);
+  int Rc = target_replay(Loc, *DeviceOrErr, HostPtr, DeviceMemory,
+                         DeviceMemorySize, TgtArgs, TgtOffsets, NumArgs,
+                         NumTeams, ThreadLimit, LoopTripCount, AsyncInfo);
   if (Rc == OFFLOAD_SUCCESS)
     Rc = AsyncInfo.synchronize();
   handleTargetOutcome(Rc == OFFLOAD_SUCCESS, Loc);
@@ -433,14 +440,11 @@ EXTERN void __tgt_set_info_flag(uint32_t NewInfoLevel) {
 }
 
 EXTERN int __tgt_print_device_info(int64_t DeviceId) {
-  // Make sure the device is ready.
-  if (!deviceIsReady(DeviceId)) {
-    DP("Device %" PRId64 " is not ready\n", DeviceId);
-    return OMP_TGT_FAIL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  return PM->Devices[DeviceId]->printDeviceInfo(
-      PM->Devices[DeviceId]->RTLDeviceID);
+  return DeviceOrErr->printDeviceInfo();
 }
 
 EXTERN void __tgt_target_nowait_query(void **AsyncHandle) {
diff --git a/openmp/libomptarget/src/omptarget.cpp b/openmp/libomptarget/src/omptarget.cpp
index 1fcadc018f72e..a9e22236dca27 100644
--- a/openmp/libomptarget/src/omptarget.cpp
+++ b/openmp/libomptarget/src/omptarget.cpp
@@ -16,6 +16,7 @@
 #include "OpenMP/OMPT/Callback.h"
 #include "OpenMP/OMPT/Interface.h"
 #include "PluginManager.h"
+#include "Shared/Debug.h"
 #include "Shared/EnvironmentVar.h"
 #include "device.h"
 #include "private.h"
@@ -299,10 +300,11 @@ void handleTargetOutcome(bool Success, ident_t *Loc) {
     break;
   case OffloadPolicy::MANDATORY:
     if (!Success) {
-      if (getInfoLevel() & OMP_INFOTYPE_DUMP_TABLE)
-        for (auto &Device : PM->Devices)
-          dumpTargetPointerMappings(Loc, *Device);
-      else
+      if (getInfoLevel() & OMP_INFOTYPE_DUMP_TABLE) {
+        auto ExclusiveDevicesAccessor = PM->getExclusiveDevicesAccessor();
+        for (auto &Device : PM->devices(ExclusiveDevicesAccessor))
+          dumpTargetPointerMappings(Loc, Device);
+      } else
         FAILURE_MESSAGE("Consult https://openmp.llvm.org/design/Runtimes.html "
                         "for debugging options.\n");
 
@@ -325,9 +327,11 @@ void handleTargetOutcome(bool Success, ident_t *Loc) {
       FATAL_MESSAGE0(
           1, "failure of target construct while offloading is mandatory");
     } else {
-      if (getInfoLevel() & OMP_INFOTYPE_DUMP_TABLE)
-        for (auto &Device : PM->Devices)
-          dumpTargetPointerMappings(Loc, *Device);
+      if (getInfoLevel() & OMP_INFOTYPE_DUMP_TABLE) {
+        auto ExclusiveDevicesAccessor = PM->getExclusiveDevicesAccessor();
+        for (auto &Device : PM->devices(ExclusiveDevicesAccessor))
+          dumpTargetPointerMappings(Loc, Device);
+      }
     }
     break;
   }
@@ -369,21 +373,15 @@ bool checkDeviceAndCtors(int64_t &DeviceID, ident_t *Loc) {
     return true;
   }
 
-  // Is device ready?
-  if (!deviceIsReady(DeviceID)) {
-    REPORT("Device %" PRId64 " is not ready.\n", DeviceID);
-    handleTargetOutcome(false, Loc);
-    return true;
-  }
-
-  // Get device info.
-  DeviceTy &Device = *PM->Devices[DeviceID];
+  auto DeviceOrErr = PM->getDevice(DeviceID);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceID, "%s", toString(DeviceOrErr.takeError()).data());
 
   // Check whether global data has been mapped for this device
   {
-    std::lock_guard<decltype(Device.PendingGlobalsMtx)> LG(
-        Device.PendingGlobalsMtx);
-    if (initLibrary(Device) != OFFLOAD_SUCCESS) {
+    std::lock_guard<decltype(DeviceOrErr->PendingGlobalsMtx)> LG(
+        DeviceOrErr->PendingGlobalsMtx);
+    if (initLibrary(*DeviceOrErr) != OFFLOAD_SUCCESS) {
       REPORT("Failed to init globals on device %" PRId64 "\n", DeviceID);
       handleTargetOutcome(false, Loc);
       return true;
@@ -415,13 +413,11 @@ void *targetAllocExplicit(size_t Size, int DeviceNum, int Kind,
     return Rc;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    DP("%s returns NULL ptr\n", Name);
-    return NULL;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  DeviceTy &Device = *PM->Devices[DeviceNum];
-  Rc = Device.allocData(Size, nullptr, Kind);
+  Rc = DeviceOrErr->allocData(Size, nullptr, Kind);
   DP("%s returns device ptr " DPxMOD "\n", Name, DPxPTR(Rc));
   return Rc;
 }
@@ -443,12 +439,11 @@ void targetFreeExplicit(void *DevicePtr, int DeviceNum, int Kind,
     return;
   }
 
-  if (!deviceIsReady(DeviceNum)) {
-    DP("%s returns, nothing to do\n", Name);
-    return;
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-  PM->Devices[DeviceNum]->deleteData(DevicePtr, Kind);
+  DeviceOrErr->deleteData(DevicePtr, Kind);
   DP("omp_target_free deallocated device ptr\n");
 }
 
@@ -464,26 +459,13 @@ void *targetLockExplicit(void *HostPtr, size_t Size, int DeviceNum,
 
   void *RC = NULL;
 
-  if (!deviceIsReady(DeviceNum)) {
-    DP("%s returns NULL ptr\n", Name);
-    return NULL;
-  }
-
-  DeviceTy *DevicePtr = nullptr;
-  {
-    std::lock_guard<decltype(PM->RTLsMtx)> LG(PM->RTLsMtx);
-
-    if (!PM->Devices[DeviceNum]) {
-      DP("%s returns, device %d not available\n", Name, DeviceNum);
-      return nullptr;
-    }
-
-    DevicePtr = PM->Devices[DeviceNum].get();
-  }
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
   int32_t Err = 0;
-  if (DevicePtr->RTL->data_lock) {
-    Err = DevicePtr->RTL->data_lock(DeviceNum, HostPtr, Size, &RC);
+  if (DeviceOrErr->RTL->data_lock) {
+    Err = DeviceOrErr->RTL->data_lock(DeviceNum, HostPtr, Size, &RC);
     if (Err) {
       DP("Could not lock ptr %p\n", HostPtr);
       return nullptr;
@@ -497,31 +479,12 @@ void targetUnlockExplicit(void *HostPtr, int DeviceNum, const char *Name) {
   TIMESCOPE();
   DP("Call to %s for device %d unlocking\n", Name, DeviceNum);
 
-  DeviceTy *DevicePtr = nullptr;
-  {
-    std::lock_guard<decltype(PM->RTLsMtx)> LG(PM->RTLsMtx);
-
-    // Don't check deviceIsReady as it can initialize the device if needed.
-    // Just check if DeviceNum exists as targetUnlockExplicit can be called
-    // during process exit/free (and it may have been already destroyed) and
-    // targetAllocExplicit will have already checked deviceIsReady anyway.
-    size_t DevicesSize = PM->Devices.size();
+  auto DeviceOrErr = PM->getDevice(DeviceNum);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
 
-    if (DevicesSize <= (size_t)DeviceNum) {
-      DP("Device ID  %d does not have a matching RTL\n", DeviceNum);
-      return;
-    }
-
-    if (!PM->Devices[DeviceNum]) {
-      DP("%s returns, device %d not available\n", Name, DeviceNum);
-      return;
-    }
-
-    DevicePtr = PM->Devices[DeviceNum].get();
-  } // unlock RTLsMtx
-
-  if (DevicePtr->RTL->data_unlock)
-    DevicePtr->RTL->data_unlock(DeviceNum, HostPtr);
+  if (DeviceOrErr->RTL->data_unlock)
+    DeviceOrErr->RTL->data_unlock(DeviceNum, HostPtr);
 
   DP("%s returns\n", Name);
 }
@@ -1446,8 +1409,12 @@ static int processDataBefore(ident_t *Loc, int64_t DeviceId, void *HostPtr,
                              PrivateArgumentManagerTy &PrivateArgumentManager,
                              AsyncInfoTy &AsyncInfo) {
   TIMESCOPE_WITH_NAME_AND_IDENT("mappingBeforeTargetRegion", Loc);
-  DeviceTy &Device = *PM->Devices[DeviceId];
-  int Ret = targetDataBegin(Loc, Device, ArgNum, ArgBases, Args, ArgSizes,
+
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
+
+  int Ret = targetDataBegin(Loc, *DeviceOrErr, ArgNum, ArgBases, Args, ArgSizes,
                             ArgTypes, ArgNames, ArgMappers, AsyncInfo);
   if (Ret != OFFLOAD_SUCCESS) {
     REPORT("Call to targetDataBegin failed, abort target.\n");
@@ -1478,7 +1445,7 @@ static int processDataBefore(ident_t *Loc, int64_t DeviceId, void *HostPtr,
         uint64_t Delta = (uint64_t)HstPtrBegin - (uint64_t)HstPtrBase;
         void *TgtPtrBegin = (void *)((uintptr_t)TgtPtrBase + Delta);
         void *&PointerTgtPtrBegin = AsyncInfo.getVoidPtrLocation();
-        TargetPointerResultTy TPR = Device.getTgtPtrBegin(
+        TargetPointerResultTy TPR = DeviceOrErr->getTgtPtrBegin(
             HstPtrVal, ArgSizes[I], /*UpdateRefCount=*/false,
             /*UseHoldRefCount=*/false);
         PointerTgtPtrBegin = TPR.TargetPointer;
@@ -1495,8 +1462,9 @@ static int processDataBefore(ident_t *Loc, int64_t DeviceId, void *HostPtr,
         }
         DP("Update lambda reference (" DPxMOD ") -> [" DPxMOD "]\n",
            DPxPTR(PointerTgtPtrBegin), DPxPTR(TgtPtrBegin));
-        Ret = Device.submitData(TgtPtrBegin, &PointerTgtPtrBegin,
-                                sizeof(void *), AsyncInfo, TPR.getEntry());
+        Ret =
+            DeviceOrErr->submitData(TgtPtrBegin, &PointerTgtPtrBegin,
+                                    sizeof(void *), AsyncInfo, TPR.getEntry());
         if (Ret != OFFLOAD_SUCCESS) {
           REPORT("Copying data to device failed.\n");
           return OFFLOAD_FAIL;
@@ -1535,9 +1503,9 @@ static int processDataBefore(ident_t *Loc, int64_t DeviceId, void *HostPtr,
     } else {
       if (ArgTypes[I] & OMP_TGT_MAPTYPE_PTR_AND_OBJ)
         HstPtrBase = *reinterpret_cast<void **>(HstPtrBase);
-      TPR = Device.getTgtPtrBegin(HstPtrBegin, ArgSizes[I],
-                                  /*UpdateRefCount=*/false,
-                                  /*UseHoldRefCount=*/false);
+      TPR = DeviceOrErr->getTgtPtrBegin(HstPtrBegin, ArgSizes[I],
+                                        /*UpdateRefCount=*/false,
+                                        /*UseHoldRefCount=*/false);
       TgtPtrBegin = TPR.TargetPointer;
       TgtBaseOffset = (intptr_t)HstPtrBase - (intptr_t)HstPtrBegin;
 #ifdef OMPTARGET_DEBUG
@@ -1573,10 +1541,12 @@ static int processDataAfter(ident_t *Loc, int64_t DeviceId, void *HostPtr,
                             PrivateArgumentManagerTy &PrivateArgumentManager,
                             AsyncInfoTy &AsyncInfo) {
   TIMESCOPE_WITH_NAME_AND_IDENT("mappingAfterTargetRegion", Loc);
-  DeviceTy &Device = *PM->Devices[DeviceId];
+  auto DeviceOrErr = PM->getDevice(DeviceId);
+  if (!DeviceOrErr)
+    FATAL_MESSAGE(DeviceId, "%s", toString(DeviceOrErr.takeError()).c_str());
 
   // Move data from device.
-  int Ret = targetDataEnd(Loc, Device, ArgNum, ArgBases, Args, ArgSizes,
+  int Ret = targetDataEnd(Loc, *DeviceOrErr, ArgNum, ArgBases, Args, ArgSizes,
                           ArgTypes, ArgNames, ArgMappers, AsyncInfo);
   if (Ret != OFFLOAD_SUCCESS) {
     REPORT("Call to targetDataEnd failed, abort target.\n");



More information about the Openmp-commits mailing list