[Openmp-commits] [openmp] 39d3283 - [OpenMP][CUDA] Avoid calling `cuCtxSetCurrent` redundantly

Shilei Tian via Openmp-commits openmp-commits at lists.llvm.org
Wed Mar 9 13:32:52 PST 2022


Author: Shilei Tian
Date: 2022-03-09T16:32:47-05:00
New Revision: 39d3283a08ba0d687eba3f74ecb85d60b7c71355

URL: https://github.com/llvm/llvm-project/commit/39d3283a08ba0d687eba3f74ecb85d60b7c71355
DIFF: https://github.com/llvm/llvm-project/commit/39d3283a08ba0d687eba3f74ecb85d60b7c71355.diff

LOG: [OpenMP][CUDA] Avoid calling `cuCtxSetCurrent` redundantly

Currently we set the context everywhere accordingly, but that causes many
unnecessary function calls. For example, in the resource pool, if we need to
resize the pool, we need to get resources from the allocator. Each call to
allocate sets the current context once, which is unnecessary. In this patch, we
set the context only in the entry interface functions, if needed. Ideally this
should be implemented via RAII, but since `cuCtxSetCurrent` could return an
error, and we don't use exceptions, we can't stop the execution if RAII fails.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D121322

Added: 
    

Modified: 
    openmp/libomptarget/plugins/cuda/src/rtl.cpp

Removed: 
    


################################################################################
diff --git a/openmp/libomptarget/plugins/cuda/src/rtl.cpp b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
index 9c3488790bed2..4f2c99feb84b3 100644
--- a/openmp/libomptarget/plugins/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins/cuda/src/rtl.cpp
@@ -167,30 +167,17 @@ struct DeviceDataTy {
 /// Functions \p create and \p destroy return OFFLOAD_SUCCESS and OFFLOAD_FAIL
 /// accordingly. The implementation should not raise any exception.
 template <typename T> struct AllocatorTy {
-  AllocatorTy(CUcontext C) noexcept : Context(C) {}
   using ElementTy = T;
-
-  virtual ~AllocatorTy() {}
-
   /// Create a resource and assign to R.
   virtual int create(T &R) noexcept = 0;
   /// Destroy the resource.
   virtual int destroy(T) noexcept = 0;
-
-protected:
-  CUcontext Context;
 };
 
 /// Allocator for CUstream.
 struct StreamAllocatorTy final : public AllocatorTy<CUstream> {
-  StreamAllocatorTy(CUcontext C) noexcept : AllocatorTy<CUstream>(C) {}
-
   /// See AllocatorTy<T>::create.
   int create(CUstream &Stream) noexcept override {
-    if (!checkResult(cuCtxSetCurrent(Context),
-                     "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     if (!checkResult(cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING),
                      "Error returned from cuStreamCreate\n"))
       return OFFLOAD_FAIL;
@@ -200,9 +187,6 @@ struct StreamAllocatorTy final : public AllocatorTy<CUstream> {
 
   /// See AllocatorTy<T>::destroy.
   int destroy(CUstream Stream) noexcept override {
-    if (!checkResult(cuCtxSetCurrent(Context),
-                     "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
     if (!checkResult(cuStreamDestroy(Stream),
                      "Error returned from cuStreamDestroy\n"))
       return OFFLOAD_FAIL;
@@ -213,8 +197,6 @@ struct StreamAllocatorTy final : public AllocatorTy<CUstream> {
 
 /// Allocator for CUevent.
 struct EventAllocatorTy final : public AllocatorTy<CUevent> {
-  EventAllocatorTy(CUcontext C) noexcept : AllocatorTy<CUevent>(C) {}
-
   /// See AllocatorTy<T>::create.
   int create(CUevent &Event) noexcept override {
     if (!checkResult(cuEventCreate(&Event, CU_EVENT_DEFAULT),
@@ -363,23 +345,15 @@ class DeviceRTLTy {
   /// A class responsible for interacting with device native runtime library to
   /// allocate and free memory.
   class CUDADeviceAllocatorTy : public DeviceAllocatorTy {
-    const int DeviceId;
-    const std::vector<DeviceDataTy> &DeviceData;
     std::unordered_map<void *, TargetAllocTy> HostPinnedAllocs;
 
   public:
-    CUDADeviceAllocatorTy(int DeviceId, std::vector<DeviceDataTy> &DeviceData)
-        : DeviceId(DeviceId), DeviceData(DeviceData) {}
-
     void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
       if (Size == 0)
         return nullptr;
 
-      CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-      if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-        return nullptr;
-
       void *MemAlloc = nullptr;
+      CUresult Err;
       switch (Kind) {
       case TARGET_ALLOC_DEFAULT:
       case TARGET_ALLOC_DEVICE:
@@ -410,10 +384,7 @@ class DeviceRTLTy {
     }
 
     int free(void *TgtPtr) override {
-      CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-      if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-        return OFFLOAD_FAIL;
-
+      CUresult Err;
      // Host pinned memory must be freed differently.
       TargetAllocTy Kind =
           (HostPinnedAllocs.find(TgtPtr) == HostPinnedAllocs.end())
@@ -566,7 +537,7 @@ class DeviceRTLTy {
     }
 
     for (int I = 0; I < NumberOfDevices; ++I)
-      DeviceAllocators.emplace_back(I, DeviceData);
+      DeviceAllocators.emplace_back();
 
     // Get the size threshold from environment variable
     std::pair<size_t, bool> Res = MemoryManagerTy::getSizeThresholdFromEnv();
@@ -641,13 +612,13 @@ class DeviceRTLTy {
 
     // Initialize the stream pool.
     if (!StreamPool[DeviceId])
-      StreamPool[DeviceId] = std::make_unique<StreamPoolTy>(
-          StreamAllocatorTy(DeviceData[DeviceId].Context), NumInitialStreams);
+      StreamPool[DeviceId] = std::make_unique<StreamPoolTy>(StreamAllocatorTy(),
+                                                            NumInitialStreams);
 
     // Initialize the event pool.
     if (!EventPool[DeviceId])
-      EventPool[DeviceId] = std::make_unique<EventPoolTy>(
-          EventAllocatorTy(DeviceData[DeviceId].Context), NumInitialEvents);
+      EventPool[DeviceId] =
+          std::make_unique<EventPoolTy>(EventAllocatorTy(), NumInitialEvents);
 
     // Query attributes to determine number of threads/block and blocks/grid.
     int MaxGridDimX;
@@ -806,18 +777,14 @@ class DeviceRTLTy {
 
   __tgt_target_table *loadBinary(const int DeviceId,
                                  const __tgt_device_image *Image) {
-    // Set the context we are using
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return nullptr;
-
     // Clear the offload table as we are going to create a new one.
     clearOffloadEntriesTable(DeviceId);
 
     // Create the module and extract the function pointers.
     CUmodule Module;
     DP("Load data from image " DPxMOD "\n", DPxPTR(Image->ImageStart));
-    Err = cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr);
+    CUresult Err =
+        cuModuleLoadDataEx(&Module, Image->ImageStart, 0, nullptr, nullptr);
     if (!checkResult(Err, "Error returned from cuModuleLoadDataEx\n"))
       return nullptr;
 
@@ -1004,13 +971,8 @@ class DeviceRTLTy {
                  const int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     CUstream Stream = getStream(DeviceId, AsyncInfo);
-
-    Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
+    CUresult Err = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
     if (Err != CUDA_SUCCESS) {
       DP("Error when copying data from host to device. Pointers: host "
          "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n",
@@ -1026,13 +988,8 @@ class DeviceRTLTy {
                    const int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     CUstream Stream = getStream(DeviceId, AsyncInfo);
-
-    Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
+    CUresult Err = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
     if (Err != CUDA_SUCCESS) {
       DP("Error when copying data from device to host. Pointers: host "
          "= " DPxMOD ", device = " DPxMOD ", size = %" PRId64 "\n",
@@ -1048,10 +1005,7 @@ class DeviceRTLTy {
                    int64_t Size, __tgt_async_info *AsyncInfo) const {
     assert(AsyncInfo && "AsyncInfo is nullptr");
 
-    CUresult Err = cuCtxSetCurrent(DeviceData[SrcDevId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
+    CUresult Err;
     CUstream Stream = getStream(SrcDevId, AsyncInfo);
 
     // If they are two devices, we try peer to peer copy first
@@ -1107,10 +1061,6 @@ class DeviceRTLTy {
                           const int TeamNum, const int ThreadLimit,
                           const unsigned int LoopTripCount,
                           __tgt_async_info *AsyncInfo) const {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "Error returned from cuCtxSetCurrent\n"))
-      return OFFLOAD_FAIL;
-
     // All args are references.
     std::vector<void *> Args(ArgNum);
     std::vector<void *> Ptrs(ArgNum);
@@ -1150,6 +1100,7 @@ class DeviceRTLTy {
       CudaThreadsPerBlock = DeviceData[DeviceId].ThreadsPerBlock;
     }
 
+    CUresult Err;
     if (!KernelInfo->MaxThreadsPerBlock) {
       Err = cuFuncGetAttribute(&KernelInfo->MaxThreadsPerBlock,
                                CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK,
@@ -1476,10 +1427,6 @@ class DeviceRTLTy {
   }
 
   int initAsyncInfo(int DeviceId, __tgt_async_info **AsyncInfo) const {
-    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
-    if (!checkResult(Err, "error returned from cuCtxSetCurrent"))
-      return OFFLOAD_FAIL;
-
     *AsyncInfo = new __tgt_async_info;
     getStream(DeviceId, *AsyncInfo);
     return OFFLOAD_SUCCESS;
@@ -1503,6 +1450,16 @@ class DeviceRTLTy {
     }
     return OFFLOAD_SUCCESS;
   }
+
+  int setContext(int DeviceId) {
+    assert(InitializedFlags[DeviceId] && "Device is not initialized");
+
+    CUresult Err = cuCtxSetCurrent(DeviceData[DeviceId].Context);
+    if (!checkResult(Err, "error returned from cuCtxSetCurrent"))
+      return OFFLOAD_FAIL;
+
+    return OFFLOAD_SUCCESS;
+  }
 };
 
 DeviceRTLTy DeviceRTL;
@@ -1535,12 +1492,14 @@ int32_t __tgt_rtl_is_data_exchangable(int32_t src_dev_id, int dst_dev_id) {
 
 int32_t __tgt_rtl_init_device(int32_t device_id) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set when init the device.
 
   return DeviceRTL.initDevice(device_id);
 }
 
 int32_t __tgt_rtl_deinit_device(int32_t device_id) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set when deinit the device.
 
   return DeviceRTL.deinitDevice(device_id);
 }
@@ -1549,6 +1508,9 @@ __tgt_target_table *__tgt_rtl_load_binary(int32_t device_id,
                                           __tgt_device_image *image) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (!DeviceRTL.setContext(device_id))
+    return nullptr;
+
   return DeviceRTL.loadBinary(device_id, image);
 }
 
@@ -1556,12 +1518,16 @@ void *__tgt_rtl_data_alloc(int32_t device_id, int64_t size, void *,
                            int32_t kind) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (!DeviceRTL.setContext(device_id))
+    return nullptr;
+
   return DeviceRTL.dataAlloc(device_id, size, (TargetAllocTy)kind);
 }
 
 int32_t __tgt_rtl_data_submit(int32_t device_id, void *tgt_ptr, void *hst_ptr,
                               int64_t size) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_data_submit_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_data_submit_async(device_id, tgt_ptr, hst_ptr,
@@ -1578,6 +1544,9 @@ int32_t __tgt_rtl_data_submit_async(int32_t device_id, void *tgt_ptr,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataSubmit(device_id, tgt_ptr, hst_ptr, size,
                               async_info_ptr);
 }
@@ -1585,6 +1554,7 @@ int32_t __tgt_rtl_data_submit_async(int32_t device_id, void *tgt_ptr,
 int32_t __tgt_rtl_data_retrieve(int32_t device_id, void *hst_ptr, void *tgt_ptr,
                                 int64_t size) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_data_retrieve_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_data_retrieve_async(device_id, hst_ptr, tgt_ptr,
@@ -1601,6 +1571,9 @@ int32_t __tgt_rtl_data_retrieve_async(int32_t device_id, void *hst_ptr,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataRetrieve(device_id, hst_ptr, tgt_ptr, size,
                                 async_info_ptr);
 }
@@ -1612,7 +1585,8 @@ int32_t __tgt_rtl_data_exchange_async(int32_t src_dev_id, void *src_ptr,
   assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid");
   assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid");
   assert(AsyncInfo && "AsyncInfo is nullptr");
-
+  // NOTE: We don't need to set context for data exchange as the device contexts
+  // are passed to CUDA function directly.
   return DeviceRTL.dataExchange(src_dev_id, src_ptr, dst_dev_id, dst_ptr, size,
                                 AsyncInfo);
 }
@@ -1622,6 +1596,7 @@ int32_t __tgt_rtl_data_exchange(int32_t src_dev_id, void *src_ptr,
                                 int64_t size) {
   assert(DeviceRTL.isValidDeviceId(src_dev_id) && "src_dev_id is invalid");
   assert(DeviceRTL.isValidDeviceId(dst_dev_id) && "dst_dev_id is invalid");
+  // Context is set in __tgt_rtl_data_exchange_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_data_exchange_async(
@@ -1635,6 +1610,9 @@ int32_t __tgt_rtl_data_exchange(int32_t src_dev_id, void *src_ptr,
 int32_t __tgt_rtl_data_delete(int32_t device_id, void *tgt_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.dataDelete(device_id, tgt_ptr);
 }
 
@@ -1645,6 +1623,7 @@ int32_t __tgt_rtl_run_target_team_region(int32_t device_id, void *tgt_entry_ptr,
                                          int32_t thread_limit,
                                          uint64_t loop_tripcount) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_run_target_team_region_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_run_target_team_region_async(
@@ -1663,6 +1642,9 @@ int32_t __tgt_rtl_run_target_team_region_async(
     __tgt_async_info *async_info_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.runTargetTeamRegion(
       device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num, team_num,
       thread_limit, loop_tripcount, async_info_ptr);
@@ -1672,6 +1654,7 @@ int32_t __tgt_rtl_run_target_region(int32_t device_id, void *tgt_entry_ptr,
                                    void **tgt_args, ptrdiff_t *tgt_offsets,
                                     int32_t arg_num) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // Context is set in __tgt_rtl_run_target_region_async.
 
   __tgt_async_info AsyncInfo;
   const int32_t rc = __tgt_rtl_run_target_region_async(
@@ -1688,7 +1671,7 @@ int32_t __tgt_rtl_run_target_region_async(int32_t device_id,
                                           int32_t arg_num,
                                           __tgt_async_info *async_info_ptr) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
-
+  // Context is set in __tgt_rtl_run_target_team_region_async.
   return __tgt_rtl_run_target_team_region_async(
       device_id, tgt_entry_ptr, tgt_args, tgt_offsets, arg_num,
       /* team num*/ 1, /* thread_limit */ 1, /* loop_tripcount */ 0,
@@ -1700,7 +1683,7 @@ int32_t __tgt_rtl_synchronize(int32_t device_id,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
   assert(async_info_ptr->Queue && "async_info_ptr->Queue is nullptr");
-
+  // NOTE: We don't need to set context for stream sync.
   return DeviceRTL.synchronize(device_id, async_info_ptr);
 }
 
@@ -1711,11 +1694,16 @@ void __tgt_rtl_set_info_flag(uint32_t NewInfoLevel) {
 
 void __tgt_rtl_print_device_info(int32_t device_id) {
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
+  // NOTE: We don't need to set context for print device info.
   DeviceRTL.printDeviceInfo(device_id);
 }
 
 int32_t __tgt_rtl_create_event(int32_t device_id, void **event) {
   assert(event && "event is nullptr");
+
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.createEvent(device_id, event);
 }
 
@@ -1724,7 +1712,7 @@ int32_t __tgt_rtl_record_event(int32_t device_id, void *event_ptr,
   assert(async_info_ptr && "async_info_ptr is nullptr");
   assert(async_info_ptr->Queue && "async_info_ptr->Queue is nullptr");
   assert(event_ptr && "event_ptr is nullptr");
-
+  // NOTE: We might not need to set context for event record.
   return recordEvent(event_ptr, async_info_ptr);
 }
 
@@ -1733,19 +1721,22 @@ int32_t __tgt_rtl_wait_event(int32_t device_id, void *event_ptr,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info_ptr && "async_info_ptr is nullptr");
   assert(event_ptr && "event is nullptr");
-
+  // NOTE: We might not need to set context for event sync.
   return DeviceRTL.waitEvent(device_id, async_info_ptr, event_ptr);
 }
 
 int32_t __tgt_rtl_sync_event(int32_t device_id, void *event_ptr) {
   assert(event_ptr && "event is nullptr");
-
+  // NOTE: We might not need to set context for event sync.
   return syncEvent(event_ptr);
 }
 
 int32_t __tgt_rtl_destroy_event(int32_t device_id, void *event_ptr) {
   assert(event_ptr && "event is nullptr");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.destroyEvent(device_id, event_ptr);
 }
 
@@ -1754,6 +1745,9 @@ int32_t __tgt_rtl_release_async_info(int32_t device_id,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info && "async_info is nullptr");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.releaseAsyncInfo(device_id, async_info);
 }
 
@@ -1762,6 +1756,9 @@ int32_t __tgt_rtl_init_async_info(int32_t device_id,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(async_info && "async_info is nullptr");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.initAsyncInfo(device_id, async_info);
 }
 
@@ -1771,6 +1768,9 @@ int32_t __tgt_rtl_init_device_info(int32_t device_id,
   assert(DeviceRTL.isValidDeviceId(device_id) && "device_id is invalid");
   assert(device_info_ptr && "device_info_ptr is nullptr");
 
+  if (!DeviceRTL.setContext(device_id))
+    return OFFLOAD_FAIL;
+
   return DeviceRTL.initDeviceInfo(device_id, device_info_ptr, err_str);
 }
 


        


More information about the Openmp-commits mailing list