[Openmp-commits] [openmp] 53e4c7c - [OpenMP][libomptarget] Improving plugin resource managers

Kevin Sala via Openmp-commits openmp-commits at lists.llvm.org
Thu Jul 27 15:39:40 PDT 2023


Author: Kevin Sala
Date: 2023-07-28T00:37:08+02:00
New Revision: 53e4c7c30936564b71fe7f581691fe8356eafeda

URL: https://github.com/llvm/llvm-project/commit/53e4c7c30936564b71fe7f581691fe8356eafeda
DIFF: https://github.com/llvm/llvm-project/commit/53e4c7c30936564b71fe7f581691fe8356eafeda.diff

LOG: [OpenMP][libomptarget] Improving plugin resource managers

This patch improves the resource managers in the plugins by properly handling
the errors. Until now, errors when creating and destroying resources were not
propagated and were directly handled inside the resource managers. Now, all
errors are propagated as in the rest of the plugin infrastructure.

The code is now ready to implement the request/return of multiple resources in
a single getResource/returnResource call.

Differential Revision: https://reviews.llvm.org/D155621

Added: 
    

Modified: 
    openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
    openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
    openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp

Removed: 
    


################################################################################
diff  --git a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
index e443f8be557845..905bc51bcd2142 100644
--- a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -135,11 +135,14 @@ Error iterateAgentMemoryPools(hsa_agent_t Agent, CallbackTy Cb) {
 /// Utility class representing generic resource references to AMDGPU resources.
 template <typename ResourceTy>
 struct AMDGPUResourceRef : public GenericDeviceResourceRef {
+  /// The underlying handle type for resources.
+  using HandleTy = ResourceTy *;
+
   /// Create an empty reference to an invalid resource.
   AMDGPUResourceRef() : Resource(nullptr) {}
 
   /// Create a reference to an existing resource.
-  AMDGPUResourceRef(ResourceTy *Resource) : Resource(Resource) {}
+  AMDGPUResourceRef(HandleTy Resource) : Resource(Resource) {}
 
   virtual ~AMDGPUResourceRef() {}
 
@@ -148,7 +151,7 @@ struct AMDGPUResourceRef : public GenericDeviceResourceRef {
   Error create(GenericDeviceTy &Device) override;
 
   /// Destroy the referenced resource and invalidate the reference. The
-  /// reference must be to a valid event before calling to this function.
+  /// reference must be to a valid resource before calling to this function.
   Error destroy(GenericDeviceTy &Device) override {
     if (!Resource)
       return Plugin::error("Destroying an invalid resource");
@@ -162,12 +165,12 @@ struct AMDGPUResourceRef : public GenericDeviceResourceRef {
     return Plugin::success();
   }
 
-  /// Get the underlying AMDGPUSignalTy reference.
-  operator ResourceTy *() const { return Resource; }
+  /// Get the underlying resource handle.
+  operator HandleTy() const { return Resource; }
 
 private:
-  /// The reference to the actual resource.
-  ResourceTy *Resource;
+  /// The handle to the actual resource.
+  HandleTy Resource;
 };
 
 /// Class holding an HSA memory pool.
@@ -955,7 +958,8 @@ struct AMDGPUStreamTy {
 
       // Release the slot's signal if possible. Otherwise, another user will.
       if (Slots[Slot].Signal->decreaseUseCount())
-        SignalManager.returnResource(Slots[Slot].Signal);
+        if (auto Err = SignalManager.returnResource(Slots[Slot].Signal))
+          return Err;
 
       Slots[Slot].Signal = nullptr;
     }
@@ -981,7 +985,9 @@ struct AMDGPUStreamTy {
     OtherSignal->increaseUseCount();
 
     // Retrieve an available signal for the operation's output.
-    AMDGPUSignalTy *OutputSignal = SignalManager.getResource();
+    AMDGPUSignalTy *OutputSignal = nullptr;
+    if (auto Err = SignalManager.getResource(OutputSignal))
+      return Err;
     OutputSignal->reset();
     OutputSignal->increaseUseCount();
 
@@ -1052,7 +1058,8 @@ struct AMDGPUStreamTy {
 
     // Release the signal if needed.
     if (Args->Signal->decreaseUseCount())
-      Args->SignalManager->returnResource(Args->Signal);
+      if (auto Err = Args->SignalManager->returnResource(Args->Signal))
+        return Err;
 
     return Plugin::success();
   }
@@ -1079,7 +1086,9 @@ struct AMDGPUStreamTy {
                          uint32_t GroupSize,
                          AMDGPUMemoryManagerTy &MemoryManager) {
     // Retrieve an available signal for the operation's output.
-    AMDGPUSignalTy *OutputSignal = SignalManager.getResource();
+    AMDGPUSignalTy *OutputSignal = nullptr;
+    if (auto Err = SignalManager.getResource(OutputSignal))
+      return Err;
     OutputSignal->reset();
     OutputSignal->increaseUseCount();
 
@@ -1101,7 +1110,9 @@ struct AMDGPUStreamTy {
   Error pushPinnedMemoryCopyAsync(void *Dst, const void *Src,
                                   uint64_t CopySize) {
     // Retrieve an available signal for the operation's output.
-    AMDGPUSignalTy *OutputSignal = SignalManager.getResource();
+    AMDGPUSignalTy *OutputSignal = nullptr;
+    if (auto Err = SignalManager.getResource(OutputSignal))
+      return Err;
     OutputSignal->reset();
     OutputSignal->increaseUseCount();
 
@@ -1138,17 +1149,18 @@ struct AMDGPUStreamTy {
     // TODO: Managers should define a function to retrieve multiple resources
     // in a single call.
     // Retrieve available signals for the operation's outputs.
-    AMDGPUSignalTy *OutputSignal1 = SignalManager.getResource();
-    AMDGPUSignalTy *OutputSignal2 = SignalManager.getResource();
-    OutputSignal1->reset();
-    OutputSignal2->reset();
-    OutputSignal1->increaseUseCount();
-    OutputSignal2->increaseUseCount();
+    AMDGPUSignalTy *OutputSignals[2] = {};
+    for (auto &Signal : OutputSignals) {
+      if (auto Err = SignalManager.getResource(Signal))
+        return Err;
+      Signal->reset();
+      Signal->increaseUseCount();
+    }
 
     std::lock_guard<std::mutex> Lock(Mutex);
 
     // Consume stream slot and compute dependencies.
-    auto [Curr, InputSignal] = consume(OutputSignal1);
+    auto [Curr, InputSignal] = consume(OutputSignals[0]);
 
     // Avoid defining the input dependency if already satisfied.
     if (InputSignal && !InputSignal->load())
@@ -1163,11 +1175,12 @@ struct AMDGPUStreamTy {
     hsa_status_t Status;
     if (InputSignal) {
       hsa_signal_t InputSignalRaw = InputSignal->get();
-      Status = hsa_amd_memory_async_copy(Inter, Agent, Src, Agent, CopySize, 1,
-                                         &InputSignalRaw, OutputSignal1->get());
+      Status =
+          hsa_amd_memory_async_copy(Inter, Agent, Src, Agent, CopySize, 1,
+                                    &InputSignalRaw, OutputSignals[0]->get());
     } else {
       Status = hsa_amd_memory_async_copy(Inter, Agent, Src, Agent, CopySize, 0,
-                                         nullptr, OutputSignal1->get());
+                                         nullptr, OutputSignals[0]->get());
     }
 
     if (auto Err =
@@ -1175,7 +1188,7 @@ struct AMDGPUStreamTy {
       return Err;
 
     // Consume another stream slot and compute dependencies.
-    std::tie(Curr, InputSignal) = consume(OutputSignal2);
+    std::tie(Curr, InputSignal) = consume(OutputSignals[1]);
     assert(InputSignal && "Invalid input signal");
 
     // The std::memcpy is done asynchronously using an async handler. We store
@@ -1204,14 +1217,15 @@ struct AMDGPUStreamTy {
                                uint64_t CopySize,
                                AMDGPUMemoryManagerTy &MemoryManager) {
     // Retrieve available signals for the operation's outputs.
-    AMDGPUSignalTy *OutputSignal1 = SignalManager.getResource();
-    AMDGPUSignalTy *OutputSignal2 = SignalManager.getResource();
-    OutputSignal1->reset();
-    OutputSignal2->reset();
-    OutputSignal1->increaseUseCount();
-    OutputSignal2->increaseUseCount();
+    AMDGPUSignalTy *OutputSignals[2] = {};
+    for (auto &Signal : OutputSignals) {
+      if (auto Err = SignalManager.getResource(Signal))
+        return Err;
+      Signal->reset();
+      Signal->increaseUseCount();
+    }
 
-    AMDGPUSignalTy *OutputSignal = OutputSignal1;
+    AMDGPUSignalTy *OutputSignal = OutputSignals[0];
 
     std::lock_guard<std::mutex> Lock(Mutex);
 
@@ -1242,7 +1256,7 @@ struct AMDGPUStreamTy {
         return Err;
 
       // Let's use now the second output signal.
-      OutputSignal = OutputSignal2;
+      OutputSignal = OutputSignals[1];
 
       // Consume another stream slot and compute dependencies.
       std::tie(Curr, InputSignal) = consume(OutputSignal);
@@ -1251,8 +1265,9 @@ struct AMDGPUStreamTy {
       std::memcpy(Inter, Src, CopySize);
 
       // Return the second signal because it will not be used.
-      OutputSignal2->decreaseUseCount();
-      SignalManager.returnResource(OutputSignal2);
+      OutputSignals[1]->decreaseUseCount();
+      if (auto Err = SignalManager.returnResource(OutputSignals[1]))
+        return Err;
     }
 
     // Setup the post action to release the intermediate pinned buffer.
@@ -1814,11 +1829,19 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   }
 
   /// Get the stream of the asynchronous info sructure or get a new one.
-  AMDGPUStreamTy &getStream(AsyncInfoWrapperTy &AsyncInfoWrapper) {
-    AMDGPUStreamTy *&Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
-    if (!Stream)
-      Stream = AMDGPUStreamManager.getResource();
-    return *Stream;
+  Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
+                  AMDGPUStreamTy *&Stream) {
+    // Get the stream (if any) from the async info.
+    Stream = AsyncInfoWrapper.getQueueAs<AMDGPUStreamTy *>();
+    if (!Stream) {
+      // There was no stream; get an idle one.
+      if (auto Err = AMDGPUStreamManager.getResource(Stream))
+        return Err;
+
+      // Modify the async info's stream.
+      AsyncInfoWrapper.setQueueAs<AMDGPUStreamTy *>(Stream);
+    }
+    return Plugin::success();
   }
 
   /// Load the binary image into the device and allocate an image object.
@@ -1883,10 +1906,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     // Once the stream is synchronized, return it to stream pool and reset
     // AsyncInfo. This is to make sure the synchronization only works for its
     // own tasks.
-    AMDGPUStreamManager.returnResource(Stream);
     AsyncInfo.Queue = nullptr;
-
-    return Plugin::success();
+    return AMDGPUStreamManager.returnResource(Stream);
   }
 
   /// Query for the completion of the pending operations on the async info.
@@ -1906,10 +1927,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     // Once the stream is completed, return it to stream pool and reset
     // AsyncInfo. This is to make sure the synchronization only works for its
     // own tasks.
-    AMDGPUStreamManager.returnResource(Stream);
     AsyncInfo.Queue = nullptr;
-
-    return Plugin::success();
+    return AMDGPUStreamManager.returnResource(Stream);
   }
 
   /// Pin the host buffer and return the device pointer that should be used for
@@ -1966,15 +1985,17 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// Submit data to the device (host to device transfer).
   Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
                        AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    AMDGPUStreamTy *Stream = nullptr;
+    void *PinnedPtr = nullptr;
+
     // Use one-step asynchronous operation when host memory is already pinned.
     if (void *PinnedPtr =
             PinnedAllocs.getDeviceAccessiblePtrFromPinnedBuffer(HstPtr)) {
-      AMDGPUStreamTy &Stream = getStream(AsyncInfoWrapper);
-      return Stream.pushPinnedMemoryCopyAsync(TgtPtr, PinnedPtr, Size);
+      if (auto Err = getStream(AsyncInfoWrapper, Stream))
+        return Err;
+      return Stream->pushPinnedMemoryCopyAsync(TgtPtr, PinnedPtr, Size);
     }
 
-    void *PinnedHstPtr = nullptr;
-
     // For large transfers use synchronous behavior.
     if (Size >= OMPX_MaxAsyncCopyBytes) {
       if (AsyncInfoWrapper.hasQueue())
@@ -1983,7 +2004,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
 
       hsa_status_t Status;
       Status = hsa_amd_memory_lock(const_cast<void *>(HstPtr), Size, nullptr, 0,
-                                   &PinnedHstPtr);
+                                   &PinnedPtr);
       if (auto Err =
               Plugin::check(Status, "Error in hsa_amd_memory_lock: %s\n"))
         return Err;
@@ -1992,8 +2013,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
       if (auto Err = Signal.init())
         return Err;
 
-      Status = hsa_amd_memory_async_copy(TgtPtr, Agent, PinnedHstPtr, Agent,
-                                         Size, 0, nullptr, Signal.get());
+      Status = hsa_amd_memory_async_copy(TgtPtr, Agent, PinnedPtr, Agent, Size,
+                                         0, nullptr, Signal.get());
       if (auto Err =
               Plugin::check(Status, "Error in hsa_amd_memory_async_copy: %s"))
         return Err;
@@ -2011,26 +2032,30 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     // Otherwise, use two-step copy with an intermediate pinned host buffer.
     AMDGPUMemoryManagerTy &PinnedMemoryManager =
         HostDevice.getPinnedMemoryManager();
-    if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedHstPtr))
+    if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedPtr))
       return Err;
 
-    AMDGPUStreamTy &Stream = getStream(AsyncInfoWrapper);
-    return Stream.pushMemoryCopyH2DAsync(TgtPtr, HstPtr, PinnedHstPtr, Size,
-                                         PinnedMemoryManager);
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
+
+    return Stream->pushMemoryCopyH2DAsync(TgtPtr, HstPtr, PinnedPtr, Size,
+                                          PinnedMemoryManager);
   }
 
   /// Retrieve data from the device (device to host transfer).
   Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
                          AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    AMDGPUStreamTy *Stream = nullptr;
+    void *PinnedPtr = nullptr;
 
     // Use one-step asynchronous operation when host memory is already pinned.
     if (void *PinnedPtr =
             PinnedAllocs.getDeviceAccessiblePtrFromPinnedBuffer(HstPtr)) {
-      AMDGPUStreamTy &Stream = getStream(AsyncInfoWrapper);
-      return Stream.pushPinnedMemoryCopyAsync(PinnedPtr, TgtPtr, Size);
-    }
+      if (auto Err = getStream(AsyncInfoWrapper, Stream))
+        return Err;
 
-    void *PinnedHstPtr = nullptr;
+      return Stream->pushPinnedMemoryCopyAsync(PinnedPtr, TgtPtr, Size);
+    }
 
     // For large transfers use synchronous behavior.
     if (Size >= OMPX_MaxAsyncCopyBytes) {
@@ -2040,7 +2065,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
 
       hsa_status_t Status;
       Status = hsa_amd_memory_lock(const_cast<void *>(HstPtr), Size, nullptr, 0,
-                                   &PinnedHstPtr);
+                                   &PinnedPtr);
       if (auto Err =
               Plugin::check(Status, "Error in hsa_amd_memory_lock: %s\n"))
         return Err;
@@ -2049,8 +2074,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
       if (auto Err = Signal.init())
         return Err;
 
-      Status = hsa_amd_memory_async_copy(PinnedHstPtr, Agent, TgtPtr, Agent,
-                                         Size, 0, nullptr, Signal.get());
+      Status = hsa_amd_memory_async_copy(PinnedPtr, Agent, TgtPtr, Agent, Size,
+                                         0, nullptr, Signal.get());
       if (auto Err =
               Plugin::check(Status, "Error in hsa_amd_memory_async_copy: %s"))
         return Err;
@@ -2068,12 +2093,14 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     // Otherwise, use two-step copy with an intermediate pinned host buffer.
     AMDGPUMemoryManagerTy &PinnedMemoryManager =
         HostDevice.getPinnedMemoryManager();
-    if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedHstPtr))
+    if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedPtr))
+      return Err;
+
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
       return Err;
 
-    AMDGPUStreamTy &Stream = getStream(AsyncInfoWrapper);
-    return Stream.pushMemoryCopyD2HAsync(HstPtr, TgtPtr, PinnedHstPtr, Size,
-                                         PinnedMemoryManager);
+    return Stream->pushMemoryCopyD2HAsync(HstPtr, TgtPtr, PinnedPtr, Size,
+                                          PinnedMemoryManager);
   }
 
   /// Exchange data between two devices within the plugin. This function is not
@@ -2105,15 +2132,13 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// Create an event.
   Error createEventImpl(void **EventPtrStorage) override {
     AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);
-    *Event = AMDGPUEventManager.getResource();
-    return Plugin::success();
+    return AMDGPUEventManager.getResource(*Event);
   }
 
   /// Destroy a previously created event.
   Error destroyEventImpl(void *EventPtr) override {
     AMDGPUEventTy *Event = reinterpret_cast<AMDGPUEventTy *>(EventPtr);
-    AMDGPUEventManager.returnResource(Event);
-    return Plugin::success();
+    return AMDGPUEventManager.returnResource(Event);
   }
 
   /// Record the event.
@@ -2122,9 +2147,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     AMDGPUEventTy *Event = reinterpret_cast<AMDGPUEventTy *>(EventPtr);
     assert(Event && "Invalid event");
 
-    AMDGPUStreamTy &Stream = getStream(AsyncInfoWrapper);
+    AMDGPUStreamTy *Stream = nullptr;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
-    return Event->record(Stream);
+    return Event->record(*Stream);
   }
 
   /// Make the stream wait on the event.
@@ -2132,9 +2159,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
                       AsyncInfoWrapperTy &AsyncInfoWrapper) override {
     AMDGPUEventTy *Event = reinterpret_cast<AMDGPUEventTy *>(EventPtr);
 
-    AMDGPUStreamTy &Stream = getStream(AsyncInfoWrapper);
+    AMDGPUStreamTy *Stream = nullptr;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
-    return Event->wait(Stream);
+    return Event->wait(*Stream);
   }
 
   /// Synchronize the current thread with the event.
@@ -2850,15 +2879,18 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                 sizeof(void *) * KernelArgs.NumArgs);
 
   AMDGPUDeviceTy &AMDGPUDevice = static_cast<AMDGPUDeviceTy &>(GenericDevice);
-  AMDGPUStreamTy &Stream = AMDGPUDevice.getStream(AsyncInfoWrapper);
+
+  AMDGPUStreamTy *Stream = nullptr;
+  if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
+    return Err;
 
   // If this kernel requires an RPC server we attach its pointer to the stream.
   if (GenericDevice.getRPCServer())
-    Stream.setRPCServer(GenericDevice.getRPCServer());
+    Stream->setRPCServer(GenericDevice.getRPCServer());
 
   // Push the kernel launch into the stream.
-  return Stream.pushKernelLaunch(*this, AllArgs, NumThreads, NumBlocks,
-                                 GroupSize, ArgsMemoryManager);
+  return Stream->pushKernelLaunch(*this, AllArgs, NumThreads, NumBlocks,
+                                  GroupSize, ArgsMemoryManager);
 }
 
 Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,

diff  --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
index 15e840b92b3f14..46f1e7863981ab 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
@@ -67,15 +67,23 @@ struct AsyncInfoWrapperTy {
   /// Get the raw __tgt_async_info pointer.
   operator __tgt_async_info *() const { return AsyncInfoPtr; }
 
-  /// Get a reference to the underlying plugin-specific queue type.
-  template <typename Ty> Ty &getQueueAs() const {
+  /// Indicate whether there is a queue.
+  bool hasQueue() const { return (AsyncInfoPtr->Queue != nullptr); }
+
+  /// Get the queue.
+  template <typename Ty> Ty getQueueAs() {
     static_assert(sizeof(Ty) == sizeof(AsyncInfoPtr->Queue),
                   "Queue is not of the same size as target type");
-    return reinterpret_cast<Ty &>(AsyncInfoPtr->Queue);
+    return static_cast<Ty>(AsyncInfoPtr->Queue);
   }
 
-  /// Indicate whether there is queue.
-  bool hasQueue() const { return (AsyncInfoPtr->Queue != nullptr); }
+  /// Set the queue.
+  template <typename Ty> void setQueueAs(Ty Queue) {
+    static_assert(sizeof(Ty) == sizeof(AsyncInfoPtr->Queue),
+                  "Queue is not of the same size as target type");
+    assert(!AsyncInfoPtr->Queue && "Overwriting queue");
+    AsyncInfoPtr->Queue = Queue;
+  }
 
   /// Synchronize with the __tgt_async_info's pending operations if it's the
   /// internal async info. The error associated to the aysnchronous operations
@@ -1118,6 +1126,10 @@ class Plugin {
 /// some basic functions to be implemented. The derived class should define an
 /// empty constructor that creates an empty and invalid resource reference. Do
 /// not create a new resource on the ctor, but on the create() function instead.
+///
+/// The derived class should also define the type HandleTy as the underlying
+/// resource handle type. For instance, in a CUDA stream it would be:
+///   using HandleTy = CUstream;
 struct GenericDeviceResourceRef {
   /// Create a new resource and stores a reference.
   virtual Error create(GenericDeviceTy &Device) = 0;
@@ -1135,6 +1147,7 @@ struct GenericDeviceResourceRef {
 /// and destroy virtual functions.
 template <typename ResourceRef> class GenericDeviceResourceManagerTy {
   using ResourcePoolTy = GenericDeviceResourceManagerTy<ResourceRef>;
+  using ResourceHandleTy = typename ResourceRef::HandleTy;
 
 public:
   /// Create an empty resource pool for a specific device.
@@ -1169,31 +1182,33 @@ template <typename ResourceRef> class GenericDeviceResourceManagerTy {
     return Plugin::success();
   }
 
-  /// Get resource from the pool or create new resources.
-  ResourceRef getResource() {
+  /// Get resource from the pool or create new resources. If the function
+  /// succeeds, the handle to the resource is saved in \p Handle.
+  Error getResource(ResourceHandleTy &Handle) {
     const std::lock_guard<std::mutex> Lock(Mutex);
 
     assert(NextAvailable <= ResourcePool.size() &&
            "Resource pool is corrupted");
 
-    if (NextAvailable == ResourcePool.size()) {
+    if (NextAvailable == ResourcePool.size())
       // By default we double the resource pool every time.
-      if (auto Err = ResourcePoolTy::resizeResourcePool(NextAvailable * 2)) {
-        REPORT("Failure to resize the resource pool: %s",
-               toString(std::move(Err)).data());
-        // Return an empty reference.
-        return ResourceRef();
-      }
-    }
-    return ResourcePool[NextAvailable++];
+      if (auto Err = ResourcePoolTy::resizeResourcePool(NextAvailable * 2))
+        return Err;
+
+    // Save the handle in the output parameter.
+    Handle = ResourcePool[NextAvailable++];
+
+    return Plugin::success();
   }
 
   /// Return resource to the pool.
-  void returnResource(ResourceRef Resource) {
+  Error returnResource(ResourceHandleTy Handle) {
     const std::lock_guard<std::mutex> Lock(Mutex);
 
     assert(NextAvailable > 0 && "Resource pool is corrupted");
-    ResourcePool[--NextAvailable] = Resource;
+    ResourcePool[--NextAvailable] = Handle;
+
+    return Plugin::success();
   }
 
 private:

diff  --git a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
index eb1400b7f38610..d3c82280f08881 100644
--- a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
@@ -83,16 +83,15 @@ struct CUDAKernelTy : public GenericKernelTy {
 
 /// Class wrapping a CUDA stream reference. These are the objects handled by the
 /// Stream Manager for the CUDA plugin.
-class CUDAStreamRef final : public GenericDeviceResourceRef {
-  /// The reference to the CUDA stream.
-  CUstream Stream;
+struct CUDAStreamRef final : public GenericDeviceResourceRef {
+  /// The underlying handle type for streams.
+  using HandleTy = CUstream;
 
-public:
   /// Create an empty reference to an invalid stream.
   CUDAStreamRef() : Stream(nullptr) {}
 
   /// Create a reference to an existing stream.
-  CUDAStreamRef(CUstream Stream) : Stream(Stream) {}
+  CUDAStreamRef(HandleTy Stream) : Stream(Stream) {}
 
   /// Create a new stream and save the reference. The reference must be empty
   /// before calling to this function.
@@ -121,21 +120,25 @@ class CUDAStreamRef final : public GenericDeviceResourceRef {
     return Plugin::success();
   }
 
-  /// Get the underlying CUstream.
-  operator CUstream() const { return Stream; }
+  /// Get the underlying CUDA stream.
+  operator HandleTy() const { return Stream; }
+
+private:
+  /// The reference to the CUDA stream.
+  HandleTy Stream;
 };
 
 /// Class wrapping a CUDA event reference. These are the objects handled by the
 /// Event Manager for the CUDA plugin.
-class CUDAEventRef final : public GenericDeviceResourceRef {
-  CUevent Event;
+struct CUDAEventRef final : public GenericDeviceResourceRef {
+  /// The underlying handle type for events.
+  using HandleTy = CUevent;
 
-public:
   /// Create an empty reference to an invalid event.
   CUDAEventRef() : Event(nullptr) {}
 
   /// Create a reference to an existing event.
-  CUDAEventRef(CUevent Event) : Event(Event) {}
+  CUDAEventRef(HandleTy Event) : Event(Event) {}
 
   /// Create a new event and save the reference. The reference must be empty
   /// before calling to this function.
@@ -165,7 +168,11 @@ class CUDAEventRef final : public GenericDeviceResourceRef {
   }
 
   /// Get the underlying CUevent.
-  operator CUevent() const { return Event; }
+  operator HandleTy() const { return Event; }
+
+private:
+  /// The reference to the CUDA event.
+  HandleTy Event;
 };
 
 /// Class implementing the CUDA device images properties.
@@ -374,11 +381,18 @@ struct CUDADeviceTy : public GenericDeviceTy {
   }
 
   /// Get the stream of the asynchronous info sructure or get a new one.
-  CUstream getStream(AsyncInfoWrapperTy &AsyncInfoWrapper) {
-    CUstream &Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
-    if (!Stream)
-      Stream = CUDAStreamManager.getResource();
-    return Stream;
+  Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
+    // Get the stream (if any) from the async info.
+    Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
+    if (!Stream) {
+      // There was no stream; get an idle one.
+      if (auto Err = CUDAStreamManager.getResource(Stream))
+        return Err;
+
+      // Modify the async info's stream.
+      AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
+    }
+    return Plugin::success();
   }
 
   /// Getters of CUDA references.
@@ -487,8 +501,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
     // Once the stream is synchronized, return it to stream pool and reset
     // AsyncInfo. This is to make sure the synchronization only works for its
     // own tasks.
-    CUDAStreamManager.returnResource(Stream);
     AsyncInfo.Queue = nullptr;
+    if (auto Err = CUDAStreamManager.returnResource(Stream))
+      return Err;
 
     return Plugin::check(Res, "Error in cuStreamSynchronize: %s");
   }
@@ -505,8 +520,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
     // Once the stream is synchronized and the operations completed (or an error
     // occurs), return it to stream pool and reset AsyncInfo. This is to make
     // sure the synchronization only works for its own tasks.
-    CUDAStreamManager.returnResource(Stream);
     AsyncInfo.Queue = nullptr;
+    if (auto Err = CUDAStreamManager.returnResource(Stream))
+      return Err;
 
     return Plugin::check(Res, "Error in cuStreamQuery: %s");
   }
@@ -531,9 +547,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
     if (auto Err = setContext())
       return Err;
 
-    CUstream Stream = getStream(AsyncInfoWrapper);
-    if (!Stream)
-      return Plugin::error("Failure to get stream");
+    CUstream Stream;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
     CUresult Res = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
     return Plugin::check(Res, "Error in cuMemcpyHtoDAsync: %s");
@@ -545,9 +561,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
     if (auto Err = setContext())
       return Err;
 
-    CUstream Stream = getStream(AsyncInfoWrapper);
-    if (!Stream)
-      return Plugin::error("Failure to get stream");
+    CUstream Stream;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
     CUresult Res = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
     return Plugin::check(Res, "Error in cuMemcpyDtoHAsync: %s");
@@ -564,8 +580,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
     if (auto Err = setContext())
       return Err;
 
-    if (!getStream(AsyncInfoWrapper))
-      return Plugin::error("Failure to get stream");
+    CUstream Stream;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
     return Plugin::success();
   }
@@ -590,15 +607,13 @@ struct CUDADeviceTy : public GenericDeviceTy {
   /// Create an event.
   Error createEventImpl(void **EventPtrStorage) override {
     CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);
-    *Event = CUDAEventManager.getResource();
-    return Plugin::success();
+    return CUDAEventManager.getResource(*Event);
   }
 
   /// Destroy a previously created event.
   Error destroyEventImpl(void *EventPtr) override {
     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
-    CUDAEventManager.returnResource(Event);
-    return Plugin::success();
+    return CUDAEventManager.returnResource(Event);
   }
 
   /// Record the event.
@@ -606,9 +621,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
                         AsyncInfoWrapperTy &AsyncInfoWrapper) override {
     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
 
-    CUstream Stream = getStream(AsyncInfoWrapper);
-    if (!Stream)
-      return Plugin::error("Failure to get stream");
+    CUstream Stream;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
     CUresult Res = cuEventRecord(Event, Stream);
     return Plugin::check(Res, "Error in cuEventRecord: %s");
@@ -619,9 +634,9 @@ struct CUDADeviceTy : public GenericDeviceTy {
                       AsyncInfoWrapperTy &AsyncInfoWrapper) override {
     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
 
-    CUstream Stream = getStream(AsyncInfoWrapper);
-    if (!Stream)
-      return Plugin::error("Failure to get stream");
+    CUstream Stream;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
 
     // Do not use CU_EVENT_WAIT_DEFAULT here as it is only available from
     // specific CUDA version, and defined as 0x0. In previous version, per CUDA
@@ -883,9 +898,9 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                                AsyncInfoWrapperTy &AsyncInfoWrapper) const {
   CUDADeviceTy &CUDADevice = static_cast<CUDADeviceTy &>(GenericDevice);
 
-  CUstream Stream = CUDADevice.getStream(AsyncInfoWrapper);
-  if (!Stream)
-    return Plugin::error("Failure to get stream");
+  CUstream Stream;
+  if (auto Err = CUDADevice.getStream(AsyncInfoWrapper, Stream))
+    return Err;
 
   uint32_t MaxDynCGroupMem =
       std::max(KernelArgs.DynCGroupMem, GenericDevice.getDynamicMemorySize());
@@ -1069,9 +1084,9 @@ Error CUDADeviceTy::dataExchangeImpl(const void *SrcPtr,
     }
   }
 
-  CUstream Stream = getStream(AsyncInfoWrapper);
-  if (!Stream)
-    return Plugin::error("Failure to get stream");
+  CUstream Stream;
+  if (auto Err = getStream(AsyncInfoWrapper, Stream))
+    return Err;
 
   if (CanAccessPeer) {
     // TODO: Should we fallback to D2D if peer access fails?


        


More information about the Openmp-commits mailing list