[Openmp-commits] [openmp] [Libomptarget] Factor functions out of 'Plugin' interface (PR #86528)

Joseph Huber via Openmp-commits openmp-commits at lists.llvm.org
Mon Mar 25 09:36:06 PDT 2024


https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/86528

Summary:
This patch factors common functions out of the `Plugin` interface prior
to its removal in a future patch. For now, the class is temporarily
renamed to `PluginTy` so that `Plugin::check` can still be used
internally, since it now needs to be defined statically by each plugin.
We can refactor this later.

That future patch will delete `PluginTy` and `PluginTy::get` entirely.
This patch simply tries to keep the changes minimal so that it is easier
to land.


>From 44c087ef4b0cb8056717e7f2e0ba9be59301c099 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Mon, 25 Mar 2024 11:33:29 -0500
Subject: [PATCH] [Libomptarget] Factor functions out of 'Plugin' interface

Summary:
This patch factors common functions out of the `Plugin` interface prior
to its removal in a future patch. For now, the class is temporarily
renamed to `PluginTy` so that `Plugin::check` can still be used
internally, since it now needs to be defined statically by each plugin.
We can refactor this later.

That future patch will delete `PluginTy` and `PluginTy::get` entirely.
This patch simply tries to keep the changes minimal so that it is easier
to land.
---
 openmp/libomptarget/CMakeLists.txt            |  3 +-
 .../plugins-nextgen/amdgpu/src/rtl.cpp        | 39 +++++-----
 .../common/include/PluginInterface.h          | 70 +++++++++--------
 .../common/src/PluginInterface.cpp            | 78 +++++++++----------
 .../plugins-nextgen/cuda/src/rtl.cpp          | 26 ++++---
 .../plugins-nextgen/host/src/rtl.cpp          | 28 +++----
 6 files changed, 126 insertions(+), 118 deletions(-)

diff --git a/openmp/libomptarget/CMakeLists.txt b/openmp/libomptarget/CMakeLists.txt
index b382137b70ee2a..531198fae01699 100644
--- a/openmp/libomptarget/CMakeLists.txt
+++ b/openmp/libomptarget/CMakeLists.txt
@@ -145,8 +145,7 @@ add_subdirectory(DeviceRTL)
 add_subdirectory(tools)
 
 # Build target agnostic offloading library.
-set(LIBOMPTARGET_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src)
-add_subdirectory(${LIBOMPTARGET_SRC_DIR})
+add_subdirectory(src)
 
 # Add tests.
 add_subdirectory(test)
diff --git a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
index fce7454bf2800d..14461f14dd6706 100644
--- a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -2088,7 +2088,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   /// Allocate and construct an AMDGPU kernel.
   Expected<GenericKernelTy &> constructKernel(const char *Name) override {
     // Allocate and construct the AMDGPU kernel.
-    AMDGPUKernelTy *AMDGPUKernel = Plugin::get().allocate<AMDGPUKernelTy>();
+    AMDGPUKernelTy *AMDGPUKernel = PluginTy::get().allocate<AMDGPUKernelTy>();
     if (!AMDGPUKernel)
       return Plugin::error("Failed to allocate memory for AMDGPU kernel");
 
@@ -2139,7 +2139,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
                                            int32_t ImageId) override {
     // Allocate and initialize the image object.
     AMDGPUDeviceImageTy *AMDImage =
-        Plugin::get().allocate<AMDGPUDeviceImageTy>();
+        PluginTy::get().allocate<AMDGPUDeviceImageTy>();
     new (AMDImage) AMDGPUDeviceImageTy(ImageId, *this, TgtImage);
 
     // Load the HSA executable.
@@ -2697,7 +2697,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   }
   Error setDeviceHeapSize(uint64_t Value) override {
     for (DeviceImageTy *Image : LoadedImages)
-      if (auto Err = setupDeviceMemoryPool(Plugin::get(), *Image, Value))
+      if (auto Err = setupDeviceMemoryPool(PluginTy::get(), *Image, Value))
         return Err;
     DeviceMemoryPoolSize = Value;
     return Plugin::success();
@@ -2737,7 +2737,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     return utils::iterateAgentMemoryPools(
         Agent, [&](hsa_amd_memory_pool_t HSAMemoryPool) {
           AMDGPUMemoryPoolTy *MemoryPool =
-              Plugin::get().allocate<AMDGPUMemoryPoolTy>();
+              PluginTy::get().allocate<AMDGPUMemoryPoolTy>();
           new (MemoryPool) AMDGPUMemoryPoolTy(HSAMemoryPool);
           AllMemoryPools.push_back(MemoryPool);
           return HSA_STATUS_SUCCESS;
@@ -3115,6 +3115,17 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
     return Plugin::check(Status, "Error in hsa_shut_down: %s");
   }
 
+  /// Creates an AMDGPU device.
+  GenericDeviceTy *createDevice(int32_t DeviceId, int32_t NumDevices) override {
+    return new AMDGPUDeviceTy(DeviceId, NumDevices, getHostDevice(),
+                              getKernelAgent(DeviceId));
+  }
+
+  /// Creates an AMDGPU global handler.
+  GenericGlobalHandlerTy *createGlobalHandler() override {
+    return new AMDGPUGlobalHandlerTy();
+  }
+
   Triple::ArchType getTripleArch() const override { return Triple::amdgcn; }
 
   /// Get the ELF code for recognizing the compatible image binary.
@@ -3237,7 +3248,7 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   // 56 bytes per allocation.
   uint32_t AllArgsSize = KernelArgsSize + ImplicitArgsSize;
 
-  AMDHostDeviceTy &HostDevice = Plugin::get<AMDGPUPluginTy>().getHostDevice();
+  AMDHostDeviceTy &HostDevice = PluginTy::get<AMDGPUPluginTy>().getHostDevice();
   AMDGPUMemoryManagerTy &ArgsMemoryManager = HostDevice.getArgsMemoryManager();
 
   void *AllArgs = nullptr;
@@ -3347,20 +3358,10 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
   return Plugin::success();
 }
 
-GenericPluginTy *Plugin::createPlugin() { return new AMDGPUPluginTy(); }
-
-GenericDeviceTy *Plugin::createDevice(int32_t DeviceId, int32_t NumDevices) {
-  AMDGPUPluginTy &Plugin = get<AMDGPUPluginTy &>();
-  return new AMDGPUDeviceTy(DeviceId, NumDevices, Plugin.getHostDevice(),
-                            Plugin.getKernelAgent(DeviceId));
-}
-
-GenericGlobalHandlerTy *Plugin::createGlobalHandler() {
-  return new AMDGPUGlobalHandlerTy();
-}
+GenericPluginTy *PluginTy::createPlugin() { return new AMDGPUPluginTy(); }
 
 template <typename... ArgsTy>
-Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
+static Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
   hsa_status_t ResultCode = static_cast<hsa_status_t>(Code);
   if (ResultCode == HSA_STATUS_SUCCESS || ResultCode == HSA_STATUS_INFO_BREAK)
     return Error::success();
@@ -3384,7 +3385,7 @@ void *AMDGPUMemoryManagerTy::allocate(size_t Size, void *HstPtr,
   }
   assert(Ptr && "Invalid pointer");
 
-  auto &KernelAgents = Plugin::get<AMDGPUPluginTy>().getKernelAgents();
+  auto &KernelAgents = PluginTy::get<AMDGPUPluginTy>().getKernelAgents();
 
   // Allow all kernel agents to access the allocation.
   if (auto Err = MemoryPool->enableAccess(Ptr, Size, KernelAgents)) {
@@ -3427,7 +3428,7 @@ void *AMDGPUDeviceTy::allocate(size_t Size, void *, TargetAllocTy Kind) {
   }
 
   if (Alloc) {
-    auto &KernelAgents = Plugin::get<AMDGPUPluginTy>().getKernelAgents();
+    auto &KernelAgents = PluginTy::get<AMDGPUPluginTy>().getKernelAgents();
     // Inherently necessary for host or shared allocations
     // Also enabled for device memory to allow device to device memcpy
 
diff --git a/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h
index b7be7b645ba33e..1440f8f678d820 100644
--- a/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h
+++ b/openmp/libomptarget/plugins-nextgen/common/include/PluginInterface.h
@@ -976,6 +976,13 @@ struct GenericPluginTy {
   Error deinit();
   virtual Error deinitImpl() = 0;
 
+  /// Create a new device for the underlying plugin.
+  virtual GenericDeviceTy *createDevice(int32_t DeviceID,
+                                        int32_t NumDevices) = 0;
+
+  /// Create a new global handler for the underlying plugin.
+  virtual GenericGlobalHandlerTy *createGlobalHandler() = 0;
+
   /// Get the reference to the device with a certain device id.
   GenericDeviceTy &getDevice(int32_t DeviceId) {
     assert(isValidDeviceId(DeviceId) && "Invalid device id");
@@ -1085,29 +1092,53 @@ struct GenericPluginTy {
   RPCServerTy *RPCServer;
 };
 
+namespace Plugin {
+/// Create a success error. This is the same as calling Error::success(), but
+/// it is recommended to use this one for consistency with Plugin::error() and
+/// Plugin::check().
+static Error success() { return Error::success(); }
+
+/// Create a string error.
+template <typename... ArgsTy>
+static Error error(const char *ErrFmt, ArgsTy... Args) {
+  return createStringError(inconvertibleErrorCode(), ErrFmt, Args...);
+}
+
+/// Check the plugin-specific error code and return an error or success
+/// accordingly. In case of an error, create a string error with the error
+/// description. The ErrFmt should follow the format:
+///     "Error in <function name>[<optional info>]: %s"
+/// The last format specifier "%s" is mandatory and will be used to place the
+/// error code's description. Notice this function should be only called from
+/// the plugin-specific code.
+/// TODO: Refactor this, must be defined individually by each plugin.
+template <typename... ArgsTy>
+static Error check(int32_t ErrorCode, const char *ErrFmt, ArgsTy... Args);
+} // namespace Plugin
+
 /// Class for simplifying the getter operation of the plugin. Anywhere on the
 /// code, the current plugin can be retrieved by Plugin::get(). The class also
 /// declares functions to create plugin-specific object instances. The check(),
 /// createPlugin(), createDevice() and createGlobalHandler() functions should be
 /// defined by each plugin implementation.
-class Plugin {
+class PluginTy {
   // Reference to the plugin instance.
   static GenericPluginTy *SpecificPlugin;
 
-  Plugin() {
+  PluginTy() {
     if (auto Err = init())
       REPORT("Failed to initialize plugin: %s\n",
              toString(std::move(Err)).data());
   }
 
-  ~Plugin() {
+  ~PluginTy() {
     if (auto Err = deinit())
       REPORT("Failed to deinitialize plugin: %s\n",
              toString(std::move(Err)).data());
   }
 
-  Plugin(const Plugin &) = delete;
-  void operator=(const Plugin &) = delete;
+  PluginTy(const PluginTy &) = delete;
+  void operator=(const PluginTy &) = delete;
 
   /// Create and intialize the plugin instance.
   static Error init() {
@@ -1158,7 +1189,7 @@ class Plugin {
     // This static variable will initialize the underlying plugin instance in
     // case there was no previous explicit initialization. The initialization is
     // thread safe.
-    static Plugin Plugin;
+    static PluginTy Plugin;
 
     assert(SpecificPlugin && "Plugin is not active");
     return *SpecificPlugin;
@@ -1170,35 +1201,8 @@ class Plugin {
   /// Indicate whether the plugin is active.
   static bool isActive() { return SpecificPlugin != nullptr; }
 
-  /// Create a success error. This is the same as calling Error::success(), but
-  /// it is recommended to use this one for consistency with Plugin::error() and
-  /// Plugin::check().
-  static Error success() { return Error::success(); }
-
-  /// Create a string error.
-  template <typename... ArgsTy>
-  static Error error(const char *ErrFmt, ArgsTy... Args) {
-    return createStringError(inconvertibleErrorCode(), ErrFmt, Args...);
-  }
-
-  /// Check the plugin-specific error code and return an error or success
-  /// accordingly. In case of an error, create a string error with the error
-  /// description. The ErrFmt should follow the format:
-  ///     "Error in <function name>[<optional info>]: %s"
-  /// The last format specifier "%s" is mandatory and will be used to place the
-  /// error code's description. Notice this function should be only called from
-  /// the plugin-specific code.
-  template <typename... ArgsTy>
-  static Error check(int32_t ErrorCode, const char *ErrFmt, ArgsTy... Args);
-
   /// Create a plugin instance.
   static GenericPluginTy *createPlugin();
-
-  /// Create a plugin-specific device.
-  static GenericDeviceTy *createDevice(int32_t DeviceId, int32_t NumDevices);
-
-  /// Create a plugin-specific global handler.
-  static GenericGlobalHandlerTy *createGlobalHandler();
 };
 
 /// Auxiliary interface class for GenericDeviceResourceManagerTy. This class
diff --git a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
index f39f913d85eec2..fe937e984523fd 100644
--- a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
@@ -39,7 +39,7 @@ using namespace omp;
 using namespace target;
 using namespace plugin;
 
-GenericPluginTy *Plugin::SpecificPlugin = nullptr;
+GenericPluginTy *PluginTy::SpecificPlugin = nullptr;
 
 // TODO: Fix any thread safety issues for multi-threaded kernel recording.
 struct RecordReplayTy {
@@ -438,7 +438,7 @@ Error GenericKernelTy::init(GenericDeviceTy &GenericDevice,
   // Retrieve kernel environment object for the kernel.
   GlobalTy KernelEnv(std::string(Name) + "_kernel_environment",
                      sizeof(KernelEnvironment), &KernelEnvironment);
-  GenericGlobalHandlerTy &GHandler = Plugin::get().getGlobalHandler();
+  GenericGlobalHandlerTy &GHandler = PluginTy::get().getGlobalHandler();
   if (auto Err =
           GHandler.readGlobalFromImage(GenericDevice, *ImagePtr, KernelEnv)) {
     [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
@@ -1488,7 +1488,7 @@ Error GenericPluginTy::init() {
   assert(Devices.size() == 0 && "Plugin already initialized");
   Devices.resize(NumDevices, nullptr);
 
-  GlobalHandler = Plugin::createGlobalHandler();
+  GlobalHandler = createGlobalHandler();
   assert(GlobalHandler && "Invalid global handler");
 
   RPCServer = new RPCServerTy(NumDevices);
@@ -1522,7 +1522,7 @@ Error GenericPluginTy::initDevice(int32_t DeviceId) {
   assert(!Devices[DeviceId] && "Device already initialized");
 
   // Create the device and save the reference.
-  GenericDeviceTy *Device = Plugin::createDevice(DeviceId, NumDevices);
+  GenericDeviceTy *Device = createDevice(DeviceId, NumDevices);
   assert(Device && "Invalid device");
 
   // Save the device reference into the list.
@@ -1581,7 +1581,7 @@ extern "C" {
 #endif
 
 int32_t __tgt_rtl_init_plugin() {
-  auto Err = Plugin::initIfNeeded();
+  auto Err = PluginTy::initIfNeeded();
   if (Err) {
     [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
     DP("Failed to init plugin: %s", ErrStr.c_str());
@@ -1592,7 +1592,7 @@ int32_t __tgt_rtl_init_plugin() {
 }
 
 int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
-  if (!Plugin::isActive())
+  if (!PluginTy::isActive())
     return false;
 
   StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
@@ -1609,13 +1609,13 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
   case file_magic::elf_executable:
   case file_magic::elf_shared_object:
   case file_magic::elf_core: {
-    auto MatchOrErr = Plugin::get().checkELFImage(Buffer);
+    auto MatchOrErr = PluginTy::get().checkELFImage(Buffer);
     if (Error Err = MatchOrErr.takeError())
       return HandleError(std::move(Err));
     return *MatchOrErr;
   }
   case file_magic::bitcode: {
-    auto MatchOrErr = Plugin::get().getJIT().checkBitcodeImage(Buffer);
+    auto MatchOrErr = PluginTy::get().getJIT().checkBitcodeImage(Buffer);
     if (Error Err = MatchOrErr.takeError())
       return HandleError(std::move(Err));
     return *MatchOrErr;
@@ -1626,7 +1626,7 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *Image) {
 }
 
 int32_t __tgt_rtl_init_device(int32_t DeviceId) {
-  auto Err = Plugin::get().initDevice(DeviceId);
+  auto Err = PluginTy::get().initDevice(DeviceId);
   if (Err) {
     REPORT("Failure to initialize device %d: %s\n", DeviceId,
            toString(std::move(Err)).data());
@@ -1636,23 +1636,23 @@ int32_t __tgt_rtl_init_device(int32_t DeviceId) {
   return OFFLOAD_SUCCESS;
 }
 
-int32_t __tgt_rtl_number_of_devices() { return Plugin::get().getNumDevices(); }
+int32_t __tgt_rtl_number_of_devices() { return PluginTy::get().getNumDevices(); }
 
 int64_t __tgt_rtl_init_requires(int64_t RequiresFlags) {
-  Plugin::get().setRequiresFlag(RequiresFlags);
+  PluginTy::get().setRequiresFlag(RequiresFlags);
   return OFFLOAD_SUCCESS;
 }
 
 int32_t __tgt_rtl_is_data_exchangable(int32_t SrcDeviceId,
                                       int32_t DstDeviceId) {
-  return Plugin::get().isDataExchangable(SrcDeviceId, DstDeviceId);
+  return PluginTy::get().isDataExchangable(SrcDeviceId, DstDeviceId);
 }
 
 int32_t __tgt_rtl_initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
                                            void *VAddr, bool isRecord,
                                            bool SaveOutput,
                                            uint64_t &ReqPtrArgOffset) {
-  GenericPluginTy &Plugin = Plugin::get();
+  GenericPluginTy &Plugin = PluginTy::get();
   GenericDeviceTy &Device = Plugin.getDevice(DeviceId);
   RecordReplayTy::RRStatusTy Status =
       isRecord ? RecordReplayTy::RRStatusTy::RRRecording
@@ -1674,7 +1674,7 @@ int32_t __tgt_rtl_initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
 
 int32_t __tgt_rtl_load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
                               __tgt_device_binary *Binary) {
-  GenericPluginTy &Plugin = Plugin::get();
+  GenericPluginTy &Plugin = PluginTy::get();
   GenericDeviceTy &Device = Plugin.getDevice(DeviceId);
 
   auto ImageOrErr = Device.loadBinary(Plugin, TgtImage);
@@ -1695,7 +1695,7 @@ int32_t __tgt_rtl_load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
 
 void *__tgt_rtl_data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
                            int32_t Kind) {
-  auto AllocOrErr = Plugin::get().getDevice(DeviceId).dataAlloc(
+  auto AllocOrErr = PluginTy::get().getDevice(DeviceId).dataAlloc(
       Size, HostPtr, (TargetAllocTy)Kind);
   if (!AllocOrErr) {
     auto Err = AllocOrErr.takeError();
@@ -1710,7 +1710,7 @@ void *__tgt_rtl_data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
 
 int32_t __tgt_rtl_data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind) {
   auto Err =
-      Plugin::get().getDevice(DeviceId).dataDelete(TgtPtr, (TargetAllocTy)Kind);
+      PluginTy::get().getDevice(DeviceId).dataDelete(TgtPtr, (TargetAllocTy)Kind);
   if (Err) {
     REPORT("Failure to deallocate device pointer %p: %s\n", TgtPtr,
            toString(std::move(Err)).data());
@@ -1722,7 +1722,7 @@ int32_t __tgt_rtl_data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind) {
 
 int32_t __tgt_rtl_data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
                             void **LockedPtr) {
-  auto LockedPtrOrErr = Plugin::get().getDevice(DeviceId).dataLock(Ptr, Size);
+  auto LockedPtrOrErr = PluginTy::get().getDevice(DeviceId).dataLock(Ptr, Size);
   if (!LockedPtrOrErr) {
     auto Err = LockedPtrOrErr.takeError();
     REPORT("Failure to lock memory %p: %s\n", Ptr,
@@ -1740,7 +1740,7 @@ int32_t __tgt_rtl_data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
 }
 
 int32_t __tgt_rtl_data_unlock(int32_t DeviceId, void *Ptr) {
-  auto Err = Plugin::get().getDevice(DeviceId).dataUnlock(Ptr);
+  auto Err = PluginTy::get().getDevice(DeviceId).dataUnlock(Ptr);
   if (Err) {
     REPORT("Failure to unlock memory %p: %s\n", Ptr,
            toString(std::move(Err)).data());
@@ -1752,7 +1752,7 @@ int32_t __tgt_rtl_data_unlock(int32_t DeviceId, void *Ptr) {
 
 int32_t __tgt_rtl_data_notify_mapped(int32_t DeviceId, void *HstPtr,
                                      int64_t Size) {
-  auto Err = Plugin::get().getDevice(DeviceId).notifyDataMapped(HstPtr, Size);
+  auto Err = PluginTy::get().getDevice(DeviceId).notifyDataMapped(HstPtr, Size);
   if (Err) {
     REPORT("Failure to notify data mapped %p: %s\n", HstPtr,
            toString(std::move(Err)).data());
@@ -1763,7 +1763,7 @@ int32_t __tgt_rtl_data_notify_mapped(int32_t DeviceId, void *HstPtr,
 }
 
 int32_t __tgt_rtl_data_notify_unmapped(int32_t DeviceId, void *HstPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).notifyDataUnmapped(HstPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).notifyDataUnmapped(HstPtr);
   if (Err) {
     REPORT("Failure to notify data unmapped %p: %s\n", HstPtr,
            toString(std::move(Err)).data());
@@ -1782,7 +1782,7 @@ int32_t __tgt_rtl_data_submit(int32_t DeviceId, void *TgtPtr, void *HstPtr,
 int32_t __tgt_rtl_data_submit_async(int32_t DeviceId, void *TgtPtr,
                                     void *HstPtr, int64_t Size,
                                     __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).dataSubmit(TgtPtr, HstPtr, Size,
+  auto Err = PluginTy::get().getDevice(DeviceId).dataSubmit(TgtPtr, HstPtr, Size,
                                                           AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to copy data from host to device. Pointers: host "
@@ -1804,7 +1804,7 @@ int32_t __tgt_rtl_data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
 int32_t __tgt_rtl_data_retrieve_async(int32_t DeviceId, void *HstPtr,
                                       void *TgtPtr, int64_t Size,
                                       __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).dataRetrieve(HstPtr, TgtPtr,
+  auto Err = PluginTy::get().getDevice(DeviceId).dataRetrieve(HstPtr, TgtPtr,
                                                             Size, AsyncInfoPtr);
   if (Err) {
     REPORT("Faliure to copy data from device to host. Pointers: host "
@@ -1829,8 +1829,8 @@ int32_t __tgt_rtl_data_exchange_async(int32_t SrcDeviceId, void *SrcPtr,
                                       int DstDeviceId, void *DstPtr,
                                       int64_t Size,
                                       __tgt_async_info *AsyncInfo) {
-  GenericDeviceTy &SrcDevice = Plugin::get().getDevice(SrcDeviceId);
-  GenericDeviceTy &DstDevice = Plugin::get().getDevice(DstDeviceId);
+  GenericDeviceTy &SrcDevice = PluginTy::get().getDevice(SrcDeviceId);
+  GenericDeviceTy &DstDevice = PluginTy::get().getDevice(DstDeviceId);
   auto Err = SrcDevice.dataExchange(SrcPtr, DstDevice, DstPtr, Size, AsyncInfo);
   if (Err) {
     REPORT("Failure to copy data from device (%d) to device (%d). Pointers: "
@@ -1847,7 +1847,7 @@ int32_t __tgt_rtl_launch_kernel(int32_t DeviceId, void *TgtEntryPtr,
                                 void **TgtArgs, ptrdiff_t *TgtOffsets,
                                 KernelArgsTy *KernelArgs,
                                 __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).launchKernel(
+  auto Err = PluginTy::get().getDevice(DeviceId).launchKernel(
       TgtEntryPtr, TgtArgs, TgtOffsets, *KernelArgs, AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to run target region " DPxMOD " in device %d: %s\n",
@@ -1860,7 +1860,7 @@ int32_t __tgt_rtl_launch_kernel(int32_t DeviceId, void *TgtEntryPtr,
 
 int32_t __tgt_rtl_synchronize(int32_t DeviceId,
                               __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).synchronize(AsyncInfoPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).synchronize(AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to synchronize stream %p: %s\n", AsyncInfoPtr->Queue,
            toString(std::move(Err)).data());
@@ -1872,7 +1872,7 @@ int32_t __tgt_rtl_synchronize(int32_t DeviceId,
 
 int32_t __tgt_rtl_query_async(int32_t DeviceId,
                               __tgt_async_info *AsyncInfoPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).queryAsync(AsyncInfoPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).queryAsync(AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to query stream %p: %s\n", AsyncInfoPtr->Queue,
            toString(std::move(Err)).data());
@@ -1883,13 +1883,13 @@ int32_t __tgt_rtl_query_async(int32_t DeviceId,
 }
 
 void __tgt_rtl_print_device_info(int32_t DeviceId) {
-  if (auto Err = Plugin::get().getDevice(DeviceId).printInfo())
+  if (auto Err = PluginTy::get().getDevice(DeviceId).printInfo())
     REPORT("Failure to print device %d info: %s\n", DeviceId,
            toString(std::move(Err)).data());
 }
 
 int32_t __tgt_rtl_create_event(int32_t DeviceId, void **EventPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).createEvent(EventPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).createEvent(EventPtr);
   if (Err) {
     REPORT("Failure to create event: %s\n", toString(std::move(Err)).data());
     return OFFLOAD_FAIL;
@@ -1901,7 +1901,7 @@ int32_t __tgt_rtl_create_event(int32_t DeviceId, void **EventPtr) {
 int32_t __tgt_rtl_record_event(int32_t DeviceId, void *EventPtr,
                                __tgt_async_info *AsyncInfoPtr) {
   auto Err =
-      Plugin::get().getDevice(DeviceId).recordEvent(EventPtr, AsyncInfoPtr);
+      PluginTy::get().getDevice(DeviceId).recordEvent(EventPtr, AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to record event %p: %s\n", EventPtr,
            toString(std::move(Err)).data());
@@ -1914,7 +1914,7 @@ int32_t __tgt_rtl_record_event(int32_t DeviceId, void *EventPtr,
 int32_t __tgt_rtl_wait_event(int32_t DeviceId, void *EventPtr,
                              __tgt_async_info *AsyncInfoPtr) {
   auto Err =
-      Plugin::get().getDevice(DeviceId).waitEvent(EventPtr, AsyncInfoPtr);
+      PluginTy::get().getDevice(DeviceId).waitEvent(EventPtr, AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to wait event %p: %s\n", EventPtr,
            toString(std::move(Err)).data());
@@ -1925,7 +1925,7 @@ int32_t __tgt_rtl_wait_event(int32_t DeviceId, void *EventPtr,
 }
 
 int32_t __tgt_rtl_sync_event(int32_t DeviceId, void *EventPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).syncEvent(EventPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).syncEvent(EventPtr);
   if (Err) {
     REPORT("Failure to synchronize event %p: %s\n", EventPtr,
            toString(std::move(Err)).data());
@@ -1936,7 +1936,7 @@ int32_t __tgt_rtl_sync_event(int32_t DeviceId, void *EventPtr) {
 }
 
 int32_t __tgt_rtl_destroy_event(int32_t DeviceId, void *EventPtr) {
-  auto Err = Plugin::get().getDevice(DeviceId).destroyEvent(EventPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).destroyEvent(EventPtr);
   if (Err) {
     REPORT("Failure to destroy event %p: %s\n", EventPtr,
            toString(std::move(Err)).data());
@@ -1955,7 +1955,7 @@ int32_t __tgt_rtl_init_async_info(int32_t DeviceId,
                                   __tgt_async_info **AsyncInfoPtr) {
   assert(AsyncInfoPtr && "Invalid async info");
 
-  auto Err = Plugin::get().getDevice(DeviceId).initAsyncInfo(AsyncInfoPtr);
+  auto Err = PluginTy::get().getDevice(DeviceId).initAsyncInfo(AsyncInfoPtr);
   if (Err) {
     REPORT("Failure to initialize async info at " DPxMOD " on device %d: %s\n",
            DPxPTR(*AsyncInfoPtr), DeviceId, toString(std::move(Err)).data());
@@ -1970,7 +1970,7 @@ int32_t __tgt_rtl_init_device_info(int32_t DeviceId,
                                    const char **ErrStr) {
   *ErrStr = "";
 
-  auto Err = Plugin::get().getDevice(DeviceId).initDeviceInfo(DeviceInfo);
+  auto Err = PluginTy::get().getDevice(DeviceId).initDeviceInfo(DeviceInfo);
   if (Err) {
     REPORT("Failure to initialize device info at " DPxMOD " on device %d: %s\n",
            DPxPTR(DeviceInfo), DeviceId, toString(std::move(Err)).data());
@@ -1981,7 +1981,7 @@ int32_t __tgt_rtl_init_device_info(int32_t DeviceId,
 }
 
 int32_t __tgt_rtl_set_device_offset(int32_t DeviceIdOffset) {
-  Plugin::get().setDeviceIdStartIndex(DeviceIdOffset);
+  PluginTy::get().setDeviceIdStartIndex(DeviceIdOffset);
 
   return OFFLOAD_SUCCESS;
 }
@@ -1990,9 +1990,9 @@ int32_t __tgt_rtl_use_auto_zero_copy(int32_t DeviceId) {
   // Automatic zero-copy only applies to programs that did
   // not request unified_shared_memory and are deployed on an
   // APU with XNACK enabled.
-  if (Plugin::get().getRequiresFlags() & OMP_REQ_UNIFIED_SHARED_MEMORY)
+  if (PluginTy::get().getRequiresFlags() & OMP_REQ_UNIFIED_SHARED_MEMORY)
     return false;
-  return Plugin::get().getDevice(DeviceId).useAutoZeroCopy();
+  return PluginTy::get().getDevice(DeviceId).useAutoZeroCopy();
 }
 
 int32_t __tgt_rtl_get_global(__tgt_device_binary Binary, uint64_t Size,
@@ -2000,7 +2000,7 @@ int32_t __tgt_rtl_get_global(__tgt_device_binary Binary, uint64_t Size,
   assert(Binary.handle && "Invalid device binary handle");
   DeviceImageTy &Image = *reinterpret_cast<DeviceImageTy *>(Binary.handle);
 
-  GenericPluginTy &Plugin = Plugin::get();
+  GenericPluginTy &Plugin = PluginTy::get();
   GenericDeviceTy &Device = Image.getDevice();
 
   GlobalTy DeviceGlobal(Name, Size);
diff --git a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
index b862bc74909257..298783fd35ef98 100644
--- a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
@@ -471,7 +471,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
   /// Allocate and construct a CUDA kernel.
   Expected<GenericKernelTy &> constructKernel(const char *Name) override {
     // Allocate and construct the CUDA kernel.
-    CUDAKernelTy *CUDAKernel = Plugin::get().allocate<CUDAKernelTy>();
+    CUDAKernelTy *CUDAKernel = PluginTy::get().allocate<CUDAKernelTy>();
     if (!CUDAKernel)
       return Plugin::error("Failed to allocate memory for CUDA kernel");
 
@@ -529,7 +529,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
       return std::move(Err);
 
     // Allocate and initialize the image object.
-    CUDADeviceImageTy *CUDAImage = Plugin::get().allocate<CUDADeviceImageTy>();
+    CUDADeviceImageTy *CUDAImage = PluginTy::get().allocate<CUDADeviceImageTy>();
     new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
 
     // Load the CUDA module.
@@ -1371,6 +1371,16 @@ struct CUDAPluginTy final : public GenericPluginTy {
   /// Deinitialize the plugin.
   Error deinitImpl() override { return Plugin::success(); }
 
+  /// Creates a CUDA device to use for offloading.
+  GenericDeviceTy *createDevice(int32_t DeviceId, int32_t NumDevices) override {
+    return new CUDADeviceTy(DeviceId, NumDevices);
+  }
+
+  /// Creates a CUDA global handler.
+  GenericGlobalHandlerTy *createGlobalHandler() override {
+    return new CUDAGlobalHandlerTy();
+  }
+
   /// Get the ELF code for recognizing the compatible image binary.
   uint16_t getMagicElfBits() const override { return ELF::EM_CUDA; }
 
@@ -1484,18 +1494,10 @@ Error CUDADeviceTy::dataExchangeImpl(const void *SrcPtr,
   return Plugin::check(Res, "Error in cuMemcpyDtoDAsync: %s");
 }
 
-GenericPluginTy *Plugin::createPlugin() { return new CUDAPluginTy(); }
-
-GenericDeviceTy *Plugin::createDevice(int32_t DeviceId, int32_t NumDevices) {
-  return new CUDADeviceTy(DeviceId, NumDevices);
-}
-
-GenericGlobalHandlerTy *Plugin::createGlobalHandler() {
-  return new CUDAGlobalHandlerTy();
-}
+GenericPluginTy *PluginTy::createPlugin() { return new CUDAPluginTy(); }
 
 template <typename... ArgsTy>
-Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
+static Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
   CUresult ResultCode = static_cast<CUresult>(Code);
   if (ResultCode == CUDA_SUCCESS)
     return Error::success();
diff --git a/openmp/libomptarget/plugins-nextgen/host/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/host/src/rtl.cpp
index 1ef18814a26ac8..ea8ed8d6a8569e 100644
--- a/openmp/libomptarget/plugins-nextgen/host/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/host/src/rtl.cpp
@@ -66,7 +66,7 @@ struct GenELF64KernelTy : public GenericKernelTy {
     GlobalTy Global(getName(), 0);
 
     // Get the metadata (address) of the kernel function.
-    GenericGlobalHandlerTy &GHandler = Plugin::get().getGlobalHandler();
+    GenericGlobalHandlerTy &GHandler = PluginTy::get().getGlobalHandler();
     if (auto Err = GHandler.getGlobalMetadataFromDevice(Device, Image, Global))
       return Err;
 
@@ -150,7 +150,7 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
   Expected<GenericKernelTy &> constructKernel(const char *Name) override {
     // Allocate and construct the kernel.
     GenELF64KernelTy *GenELF64Kernel =
-        Plugin::get().allocate<GenELF64KernelTy>();
+        PluginTy::get().allocate<GenELF64KernelTy>();
     if (!GenELF64Kernel)
       return Plugin::error("Failed to allocate memory for GenELF64 kernel");
 
@@ -167,7 +167,7 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
                                            int32_t ImageId) override {
     // Allocate and initialize the image object.
     GenELF64DeviceImageTy *Image =
-        Plugin::get().allocate<GenELF64DeviceImageTy>();
+        PluginTy::get().allocate<GenELF64DeviceImageTy>();
     new (Image) GenELF64DeviceImageTy(ImageId, *this, TgtImage);
 
     // Create a temporary file.
@@ -399,6 +399,16 @@ struct GenELF64PluginTy final : public GenericPluginTy {
   /// Deinitialize the plugin.
   Error deinitImpl() override { return Plugin::success(); }
 
+  /// Creates a generic ELF device.
+  GenericDeviceTy *createDevice(int32_t DeviceId, int32_t NumDevices) override {
+    return new GenELF64DeviceTy(DeviceId, NumDevices);
+  }
+
+  /// Creates a generic global handler.
+  GenericGlobalHandlerTy *createGlobalHandler() override {
+    return new GenELF64GlobalHandlerTy();
+  }
+
   /// Get the ELF code to recognize the compatible binary images.
   uint16_t getMagicElfBits() const override { return ELF::TARGET_ELF_ID; }
 
@@ -415,18 +425,10 @@ struct GenELF64PluginTy final : public GenericPluginTy {
   }
 };
 
-GenericPluginTy *Plugin::createPlugin() { return new GenELF64PluginTy(); }
-
-GenericDeviceTy *Plugin::createDevice(int32_t DeviceId, int32_t NumDevices) {
-  return new GenELF64DeviceTy(DeviceId, NumDevices);
-}
-
-GenericGlobalHandlerTy *Plugin::createGlobalHandler() {
-  return new GenELF64GlobalHandlerTy();
-}
+GenericPluginTy *PluginTy::createPlugin() { return new GenELF64PluginTy(); }
 
 template <typename... ArgsTy>
-Error Plugin::check(int32_t Code, const char *ErrMsg, ArgsTy... Args) {
+static Error Plugin::check(int32_t Code, const char *ErrMsg, ArgsTy... Args) {
   if (Code == 0)
     return Error::success();
 



More information about the Openmp-commits mailing list