[Openmp-commits] [openmp] 10aa83f - [OpenMP] Allow to explicitly deinitialize device resources

Johannes Doerfert via Openmp-commits openmp-commits at lists.llvm.org
Mon Mar 7 21:43:30 PST 2022


Author: Johannes Doerfert
Date: 2022-03-07T23:43:04-06:00
New Revision: 10aa83ff74b48d441aa2141047cf8674e069d4f6

URL: https://github.com/llvm/llvm-project/commit/10aa83ff74b48d441aa2141047cf8674e069d4f6
DIFF: https://github.com/llvm/llvm-project/commit/10aa83ff74b48d441aa2141047cf8674e069d4f6.diff

LOG: [OpenMP] Allow to explicitly deinitialize device resources

There are two problems this patch tries to address:
1) We currently free resources in an arbitrary order with respect to
   plugin and libomptarget destruction. This patch should make the CUDA
   plugin less fragile if something goes wrong during deinitialization.
2) We need to support (hard) pause runtime calls eventually. This patch
   allows us to free all associated resources, though we cannot
   reinitialize the device yet.

A follow-up patch will associate one event pool with each device/context.

Differential Revision: https://reviews.llvm.org/D120089

Added: 
    

Modified: 
    openmp/libomptarget/include/device.h
    openmp/libomptarget/include/omptargetplugin.h
    openmp/libomptarget/include/rtl.h
    openmp/libomptarget/plugins/cuda/src/rtl.cpp
    openmp/libomptarget/plugins/exports
    openmp/libomptarget/src/device.cpp
    openmp/libomptarget/src/rtl.cpp

Removed: 
    


################################################################################
diff  --git a/openmp/libomptarget/include/device.h b/openmp/libomptarget/include/device.h
index 43d05bb50b223..17ab2ed4e2ce5 100644
--- a/openmp/libomptarget/include/device.h
+++ b/openmp/libomptarget/include/device.h
@@ -408,6 +408,9 @@ struct DeviceTy {
 private:
   // Call to RTL
   void init(); // To be called only via DeviceTy::initOnce()
+
+  /// Deinitialize the device (and plugin).
+  void deinit();
 };
 
 extern bool device_is_ready(int device_num);

diff  --git a/openmp/libomptarget/include/omptargetplugin.h b/openmp/libomptarget/include/omptargetplugin.h
index e404d55d064fb..ceec68b4c2346 100644
--- a/openmp/libomptarget/include/omptargetplugin.h
+++ b/openmp/libomptarget/include/omptargetplugin.h
@@ -48,6 +48,10 @@ int64_t __tgt_rtl_init_requires(int64_t RequiresFlags);
 // return an error code.
 int32_t __tgt_rtl_init_device(int32_t ID);
 
+// Deinitialize the specified device. In case of success return 0; otherwise
+// return an error code.
+int32_t __tgt_rtl_deinit_device(int32_t ID);
+
 // Pass an executable image section described by image to the specified
 // device and prepare an address table of target entities. In case of error,
 // return NULL. Otherwise, return a pointer to the built address table.

diff  --git a/openmp/libomptarget/include/rtl.h b/openmp/libomptarget/include/rtl.h
index e742ca0205e61..9019939b52fc2 100644
--- a/openmp/libomptarget/include/rtl.h
+++ b/openmp/libomptarget/include/rtl.h
@@ -29,6 +29,7 @@ struct RTLInfoTy {
   typedef int32_t(is_data_exchangable_ty)(int32_t, int32_t);
   typedef int32_t(number_of_devices_ty)();
   typedef int32_t(init_device_ty)(int32_t);
+  typedef int32_t(deinit_device_ty)(int32_t);
   typedef __tgt_target_table *(load_binary_ty)(int32_t, void *);
   typedef void *(data_alloc_ty)(int32_t, int64_t, void *, int32_t);
   typedef int32_t(data_submit_ty)(int32_t, void *, void *, int64_t);
@@ -84,6 +85,7 @@ struct RTLInfoTy {
   is_data_exchangable_ty *is_data_exchangable = nullptr;
   number_of_devices_ty *number_of_devices = nullptr;
   init_device_ty *init_device = nullptr;
+  deinit_device_ty *deinit_device = nullptr;
   load_binary_ty *load_binary = nullptr;
   data_alloc_ty *data_alloc = nullptr;
   data_submit_ty *data_submit = nullptr;

diff  --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
index b688fe11ef5a7..cb511240ac2c6 100644
--- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cuda.h>
@@ -22,6 +23,7 @@
 
 #include "Debug.h"
 #include "DeviceEnvironment.h"
+#include "omptarget.h"
 #include "omptargetplugin.h"
 
 #define TARGET_NAME CUDA
@@ -339,6 +341,10 @@ class DeviceRTLTy {
   std::vector<DeviceDataTy> DeviceData;
   std::vector<CUmodule> Modules;
 
+  /// Vector of flags indicating the initialization status of all associated
+  /// devices.
+  std::vector<bool> InitializedFlags;
+
   /// A class responsible for interacting with device native runtime library to
   /// allocate and free memory.
   class CUDADeviceAllocatorTy : public DeviceAllocatorTy {
@@ -467,7 +473,6 @@ class DeviceRTLTy {
   }
 
 public:
-
   CUstream getStream(const int DeviceId, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
@@ -555,36 +560,14 @@ class DeviceRTLTy {
       for (int I = 0; I < NumberOfDevices; ++I)
         MemoryManagers.emplace_back(std::make_unique<MemoryManagerTy>(
             DeviceAllocators[I], MemoryManagerThreshold));
+
+    // We lazily initialize all devices later.
+    InitializedFlags.assign(NumberOfDevices, false);
   }
 
   ~DeviceRTLTy() {
-    // We first destruct memory managers in case that its dependent data are
-    // destroyed before it.
-    for (auto &M : MemoryManagers)
-      M.release();
-
-    for (CUmodule &M : Modules)
-      // Close module
-      if (M)
-        checkResult(cuModuleUnload(M), "Error returned from cuModuleUnload\n");
-
-    for (auto &S : StreamPool)
-      S.reset();
-
-    EventPool.clear();
-
-    for (DeviceDataTy &D : DeviceData) {
-      // Destroy context
-      if (D.Context) {
-        checkResult(cuCtxSetCurrent(D.Context),
-                    "Error returned from cuCtxSetCurrent\n");
-        CUdevice Device;
-        checkResult(cuCtxGetDevice(&Device),
-                    "Error returned from cuCtxGetDevice\n");
-        checkResult(cuDevicePrimaryCtxRelease(Device),
-                    "Error returned from cuDevicePrimaryCtxRelease\n");
-      }
-    }
+    for (int DeviceId = 0; DeviceId < NumberOfDevices; ++DeviceId)
+      deinitDevice(DeviceId);
   }
 
   // Check whether a given DeviceId is valid
@@ -604,6 +587,9 @@ class DeviceRTLTy {
     if (!checkResult(Err, "Error returned from cuDeviceGet\n"))
       return OFFLOAD_FAIL;
 
+    assert(InitializedFlags[DeviceId] == false && "Reinitializing device!");
+    InitializedFlags[DeviceId] = true;
+
     // Query the current flags of the primary context and set its flags if
     // it is inactive
     unsigned int FormerPrimaryCtxFlags = 0;
@@ -761,6 +747,42 @@ class DeviceRTLTy {
     return OFFLOAD_SUCCESS;
   }
 
+  int deinitDevice(const int DeviceId) {
+    if (!InitializedFlags[DeviceId])
+      return OFFLOAD_SUCCESS;
+    InitializedFlags[DeviceId] = false;
+    int Ret = OFFLOAD_SUCCESS; // Best effort: keep going, report failures.
+    if (UseMemoryManager)
+      MemoryManagers[DeviceId].release();
+    StreamPool[DeviceId].reset();
+    // The shared event pool is cleared once all devices are deinitialized.
+    if (std::none_of(InitializedFlags.begin(), InitializedFlags.end(),
+                     [](bool IsInitialized) { return IsInitialized; }))
+      EventPool.clear();
+
+    // cuModuleUnload acts on the current context, so make ours current first.
+    // NOTE(review): assumes Modules holds one entry per device; confirm.
+    DeviceDataTy &D = DeviceData[DeviceId];
+    const bool CtxIsCurrent =
+        D.Context && checkResult(cuCtxSetCurrent(D.Context),
+                                 "Error returned from cuCtxSetCurrent\n");
+    if (D.Context && !CtxIsCurrent)
+      Ret = OFFLOAD_FAIL;
+    else if (CUmodule &M = Modules[DeviceId])
+      if (!checkResult(cuModuleUnload(M),
+                       "Error returned from cuModuleUnload\n"))
+        Ret = OFFLOAD_FAIL;
+    // Release the primary context only if it was successfully made current.
+    CUdevice Device;
+    if (CtxIsCurrent &&
+        !(checkResult(cuCtxGetDevice(&Device),
+                      "Error returned from cuCtxGetDevice\n") &&
+          checkResult(cuDevicePrimaryCtxRelease(Device),
+                      "Error returned from cuDevicePrimaryCtxRelease\n")))
+      Ret = OFFLOAD_FAIL;
+    return Ret;
+  }
+
   __tgt_target_table *loadBinary(const int DeviceId,
                                  const __tgt_device_image *Image) {
     // Set the context we are using
@@ -1496,6 +1518,12 @@ int32_t __tgt_rtl_init_device(int32_t device_id) {
   return DeviceRTL.initDevice(device_id);
 }
 
+int32_t __tgt_rtl_deinit_device(int32_t device_id) {
+  assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+
+  return DeviceRTL.deinitDevice(device_id);
+}
+
 __tgt_target_table *__tgt_rtl_load_binary(int32_t device_id,
                                           __tgt_device_image *image) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");

diff  --git a/openmp/libomptarget/plugins/exports b/openmp/libomptarget/plugins/exports
index 8664a2e493ee2..b4582f1f25c03 100644
--- a/openmp/libomptarget/plugins/exports
+++ b/openmp/libomptarget/plugins/exports
@@ -5,6 +5,7 @@ VERS1.0 {
     __tgt_rtl_number_of_devices;
     __tgt_rtl_init_requires;
     __tgt_rtl_init_device;
+    __tgt_rtl_deinit_device;
     __tgt_rtl_load_binary;
     __tgt_rtl_data_alloc;
     __tgt_rtl_data_submit;

diff  --git a/openmp/libomptarget/src/device.cpp b/openmp/libomptarget/src/device.cpp
index 09470444ed1c3..bcad793e73e9e 100644
--- a/openmp/libomptarget/src/device.cpp
+++ b/openmp/libomptarget/src/device.cpp
@@ -468,6 +468,11 @@ int32_t DeviceTy::initOnce() {
     return OFFLOAD_FAIL;
 }
 
+void DeviceTy::deinit() {
+  if (RTL->deinit_device)
+    RTL->deinit_device(RTLDeviceID);
+}
+
 // Load binary to device.
 __tgt_target_table *DeviceTy::load_binary(void *Img) {
   std::lock_guard<decltype(RTL->Mtx)> LG(RTL->Mtx);

diff  --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp
index 8a256ab7bd9bd..f02196f770c03 100644
--- a/openmp/libomptarget/src/rtl.cpp
+++ b/openmp/libomptarget/src/rtl.cpp
@@ -165,6 +165,8 @@ void RTLsTy::LoadRTLs() {
        R.NumberOfDevices);
 
     // Optional functions
+    *((void **)&R.deinit_device) =
+        dlsym(dynlib_handle, "__tgt_rtl_deinit_device");
     *((void **)&R.init_requires) =
         dlsym(dynlib_handle, "__tgt_rtl_init_requires");
     *((void **)&R.data_submit_async) =


        


More information about the Openmp-commits mailing list