[Openmp-commits] [libc] [openmp] [libc] Change RPC interface to not use device ids (PR #87087)

Joseph Huber via Openmp-commits openmp-commits at lists.llvm.org
Fri Mar 29 10:23:36 PDT 2024


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/87087

>From 174629cb45999156cf9531a0959bade7b7ab621e Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 29 Mar 2024 11:51:02 -0500
Subject: [PATCH] [libc] Change RPC interface to not use device ids

Summary:
The current implementation of RPC tied everything to device IDs and
forced us to do init / shutdown to manage some global state. This turned
out to be a bad idea in situations where we want to track multiple
heterogeneous devices that may report the same device ID in the same
process.

This patch changes the interface to instead create an opaque handle to
the internal device and simply allocates it via `new`. The user will
then take this device and store it to interface with the attached
device. This interface puts the burden of tracking the device identifier
to mapped devices onto the user, but in return heavily simplifies the
implementation.
---
 libc/utils/gpu/loader/Loader.h                |   8 +-
 libc/utils/gpu/loader/amdgpu/Loader.cpp       |  47 ++++----
 libc/utils/gpu/loader/nvptx/Loader.cpp        |  39 +++---
 libc/utils/gpu/server/llvmlibc_rpc_server.h   |  29 ++---
 libc/utils/gpu/server/rpc_server.cpp          | 112 +++++-------------
 .../plugins-nextgen/common/include/RPC.h      |   7 +-
 .../common/src/PluginInterface.cpp            |   2 +-
 .../plugins-nextgen/common/src/RPC.cpp        |  44 +++----
 8 files changed, 107 insertions(+), 181 deletions(-)

diff --git a/libc/utils/gpu/loader/Loader.h b/libc/utils/gpu/loader/Loader.h
index 93380383701978..9c7d328930c239 100644
--- a/libc/utils/gpu/loader/Loader.h
+++ b/libc/utils/gpu/loader/Loader.h
@@ -108,11 +108,11 @@ inline void handle_error(rpc_status_t) {
 }
 
 template <uint32_t lane_size>
-inline void register_rpc_callbacks(uint32_t device_id) {
+inline void register_rpc_callbacks(rpc_device_t device) {
   static_assert(lane_size == 32 || lane_size == 64, "Invalid Lane size");
   // Register the ping test for the `libc` tests.
   rpc_register_callback(
-      device_id, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
+      device, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
       [](rpc_port_t port, void *data) {
         rpc_recv_and_send(
             port,
@@ -125,7 +125,7 @@ inline void register_rpc_callbacks(uint32_t device_id) {
 
   // Register the interface test callbacks.
   rpc_register_callback(
-      device_id, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
+      device, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
       [](rpc_port_t port, void *data) {
         uint64_t cnt = 0;
         bool end_with_recv;
@@ -207,7 +207,7 @@ inline void register_rpc_callbacks(uint32_t device_id) {
 
   // Register the stream test handler.
   rpc_register_callback(
-      device_id, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
+      device, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
       [](rpc_port_t port, void *data) {
         uint64_t sizes[lane_size] = {0};
         void *dst[lane_size] = {nullptr};
diff --git a/libc/utils/gpu/loader/amdgpu/Loader.cpp b/libc/utils/gpu/loader/amdgpu/Loader.cpp
index e3911eda2bd82a..35840c6910bd8e 100644
--- a/libc/utils/gpu/loader/amdgpu/Loader.cpp
+++ b/libc/utils/gpu/loader/amdgpu/Loader.cpp
@@ -153,7 +153,8 @@ template <typename args_t>
 hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
                            hsa_amd_memory_pool_t kernargs_pool,
                            hsa_amd_memory_pool_t coarsegrained_pool,
-                           hsa_queue_t *queue, const LaunchParameters &params,
+                           hsa_queue_t *queue, rpc_device_t device,
+                           const LaunchParameters &params,
                            const char *kernel_name, args_t kernel_args) {
   // Look up the '_start' kernel in the loaded executable.
   hsa_executable_symbol_t symbol;
@@ -162,10 +163,9 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
     return err;
 
   // Register RPC callbacks for the malloc and free functions on HSA.
-  uint32_t device_id = 0;
   auto tuple = std::make_tuple(dev_agent, coarsegrained_pool);
   rpc_register_callback(
-      device_id, RPC_MALLOC,
+      device, RPC_MALLOC,
       [](rpc_port_t port, void *data) {
         auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
           auto &[dev_agent, pool] = *static_cast<decltype(tuple) *>(data);
@@ -182,7 +182,7 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
       },
       &tuple);
   rpc_register_callback(
-      device_id, RPC_FREE,
+      device, RPC_FREE,
       [](rpc_port_t port, void *data) {
         auto free_handler = [](rpc_buffer_t *buffer, void *) {
           if (hsa_status_t err = hsa_amd_memory_pool_free(
@@ -284,12 +284,12 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
   while (hsa_signal_wait_scacquire(
              packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
              /*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
-    if (rpc_status_t err = rpc_handle_server(device_id))
+    if (rpc_status_t err = rpc_handle_server(device))
       handle_error(err);
 
   // Handle the server one more time in case the kernel exited with a pending
   // send still in flight.
-  if (rpc_status_t err = rpc_handle_server(device_id))
+  if (rpc_status_t err = rpc_handle_server(device))
     handle_error(err);
 
   // Destroy the resources acquired to launch the kernel and return.
@@ -342,8 +342,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
 
   // Obtain a single agent for the device and host to use the HSA memory model.
-  uint32_t num_devices = 1;
-  uint32_t device_id = 0;
   hsa_agent_t dev_agent;
   hsa_agent_t host_agent;
   if (hsa_status_t err = get_agent<HSA_DEVICE_TYPE_GPU>(&dev_agent))
@@ -433,8 +431,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
 
   // Set up the RPC server.
-  if (rpc_status_t err = rpc_init(num_devices))
-    handle_error(err);
   auto tuple = std::make_tuple(dev_agent, finegrained_pool);
   auto rpc_alloc = [](uint64_t size, void *data) {
     auto &[dev_agent, finegrained_pool] = *static_cast<decltype(tuple) *>(data);
@@ -445,15 +441,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
     return dev_ptr;
   };
-  if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
+  rpc_device_t device;
+  if (rpc_status_t err = rpc_server_init(&device, RPC_MAXIMUM_PORT_COUNT,
                                          wavefront_size, rpc_alloc, &tuple))
     handle_error(err);
 
   // Register callbacks for the RPC unit tests.
   if (wavefront_size == 32)
-    register_rpc_callbacks<32>(device_id);
+    register_rpc_callbacks<32>(device);
   else if (wavefront_size == 64)
-    register_rpc_callbacks<64>(device_id);
+    register_rpc_callbacks<64>(device);
   else
     handle_error("Invalid wavefront size");
 
@@ -483,10 +480,10 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
 
   void *rpc_client_buffer;
-  if (hsa_status_t err = hsa_amd_memory_lock(
-          const_cast<void *>(rpc_get_client_buffer(device_id)),
-          rpc_get_client_size(),
-          /*agents=*/nullptr, 0, &rpc_client_buffer))
+  if (hsa_status_t err =
+          hsa_amd_memory_lock(const_cast<void *>(rpc_get_client_buffer(device)),
+                              rpc_get_client_size(),
+                              /*agents=*/nullptr, 0, &rpc_client_buffer))
     handle_error(err);
 
   // Copy the RPC client buffer to the address pointed to by the symbol.
@@ -496,7 +493,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
 
   if (hsa_status_t err = hsa_amd_memory_unlock(
-          const_cast<void *>(rpc_get_client_buffer(device_id))))
+          const_cast<void *>(rpc_get_client_buffer(device))))
     handle_error(err);
   if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host))
     handle_error(err);
@@ -549,13 +546,13 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   begin_args_t init_args = {argc, dev_argv, dev_envp};
   if (hsa_status_t err = launch_kernel(
           dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
-          single_threaded_params, "_begin.kd", init_args))
+          device, single_threaded_params, "_begin.kd", init_args))
     handle_error(err);
 
   start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
-  if (hsa_status_t err =
-          launch_kernel(dev_agent, executable, kernargs_pool,
-                        coarsegrained_pool, queue, params, "_start.kd", args))
+  if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
+                                       coarsegrained_pool, queue, device,
+                                       params, "_start.kd", args))
     handle_error(err);
 
   void *host_ret;
@@ -575,11 +572,11 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   end_args_t fini_args = {ret};
   if (hsa_status_t err = launch_kernel(
           dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
-          single_threaded_params, "_end.kd", fini_args))
+          device, single_threaded_params, "_end.kd", fini_args))
     handle_error(err);
 
   if (rpc_status_t err = rpc_server_shutdown(
-          device_id, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
+          device, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
           nullptr))
     handle_error(err);
 
@@ -600,8 +597,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   if (hsa_status_t err = hsa_code_object_destroy(object))
     handle_error(err);
 
-  if (rpc_status_t err = rpc_shutdown())
-    handle_error(err);
   if (hsa_status_t err = hsa_shut_down())
     handle_error(err);
 
diff --git a/libc/utils/gpu/loader/nvptx/Loader.cpp b/libc/utils/gpu/loader/nvptx/Loader.cpp
index 5388f287063b7f..1818932f0a966b 100644
--- a/libc/utils/gpu/loader/nvptx/Loader.cpp
+++ b/libc/utils/gpu/loader/nvptx/Loader.cpp
@@ -154,8 +154,8 @@ Expected<void *> get_ctor_dtor_array(const void *image, const size_t size,
 
 template <typename args_t>
 CUresult launch_kernel(CUmodule binary, CUstream stream,
-                       const LaunchParameters &params, const char *kernel_name,
-                       args_t kernel_args) {
+                       rpc_device_t rpc_device, const LaunchParameters &params,
+                       const char *kernel_name, args_t kernel_args) {
   // look up the '_start' kernel in the loaded module.
   CUfunction function;
   if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name))
@@ -175,11 +175,10 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
     handle_error(err);
 
   // Register RPC callbacks for the malloc and free functions on HSA.
-  uint32_t device_id = 0;
-  register_rpc_callbacks<32>(device_id);
+  register_rpc_callbacks<32>(rpc_device);
 
   rpc_register_callback(
-      device_id, RPC_MALLOC,
+      rpc_device, RPC_MALLOC,
       [](rpc_port_t port, void *data) {
         auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
           CUstream memory_stream = *static_cast<CUstream *>(data);
@@ -197,7 +196,7 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
       },
       &memory_stream);
   rpc_register_callback(
-      device_id, RPC_FREE,
+      rpc_device, RPC_FREE,
       [](rpc_port_t port, void *data) {
         auto free_handler = [](rpc_buffer_t *buffer, void *data) {
           CUstream memory_stream = *static_cast<CUstream *>(data);
@@ -219,12 +218,12 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
   // Wait until the kernel has completed execution on the device. Periodically
   // check the RPC client for work to be performed on the server.
   while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
-    if (rpc_status_t err = rpc_handle_server(device_id))
+    if (rpc_status_t err = rpc_handle_server(rpc_device))
       handle_error(err);
 
   // Handle the server one more time in case the kernel exited with a pending
   // send still in flight.
-  if (rpc_status_t err = rpc_handle_server(device_id))
+  if (rpc_status_t err = rpc_handle_server(rpc_device))
     handle_error(err);
 
   return CUDA_SUCCESS;
@@ -235,7 +234,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   if (CUresult err = cuInit(0))
     handle_error(err);
   // Obtain the first device found on the system.
-  uint32_t num_devices = 1;
   uint32_t device_id = 0;
   CUdevice device;
   if (CUresult err = cuDeviceGet(&device, device_id))
@@ -294,9 +292,6 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   if (CUresult err = cuMemsetD32(dev_ret, 0, 1))
     handle_error(err);
 
-  if (rpc_status_t err = rpc_init(num_devices))
-    handle_error(err);
-
   uint32_t warp_size = 32;
   auto rpc_alloc = [](uint64_t size, void *) -> void * {
     void *dev_ptr;
@@ -304,7 +299,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
       handle_error(err);
     return dev_ptr;
   };
-  if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
+  rpc_device_t rpc_device;
+  if (rpc_status_t err = rpc_server_init(&rpc_device, RPC_MAXIMUM_PORT_COUNT,
                                          warp_size, rpc_alloc, nullptr))
     handle_error(err);
 
@@ -321,19 +317,20 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
           cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *)))
     handle_error(err);
   if (CUresult err =
-          cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(device_id),
+          cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(rpc_device),
                        rpc_get_client_size()))
     handle_error(err);
 
   LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
   begin_args_t init_args = {argc, dev_argv, dev_envp};
-  if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
-                                   "_begin", init_args))
+  if (CUresult err = launch_kernel(binary, stream, rpc_device,
+                                   single_threaded_params, "_begin", init_args))
     handle_error(err);
 
   start_args_t args = {argc, dev_argv, dev_envp,
                        reinterpret_cast<void *>(dev_ret)};
-  if (CUresult err = launch_kernel(binary, stream, params, "_start", args))
+  if (CUresult err =
+          launch_kernel(binary, stream, rpc_device, params, "_start", args))
     handle_error(err);
 
   // Copy the return value back from the kernel and wait.
@@ -345,8 +342,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
 
   end_args_t fini_args = {host_ret};
-  if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
-                                   "_end", fini_args))
+  if (CUresult err = launch_kernel(binary, stream, rpc_device,
+                                   single_threaded_params, "_end", fini_args))
     handle_error(err);
 
   // Free the memory allocated for the device.
@@ -357,7 +354,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   if (CUresult err = cuMemFreeHost(dev_argv))
     handle_error(err);
   if (rpc_status_t err = rpc_server_shutdown(
-          device_id, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
+          rpc_device, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
     handle_error(err);
 
   // Destroy the context and the loaded binary.
@@ -365,7 +362,5 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
   if (CUresult err = cuDevicePrimaryCtxRelease(device))
     handle_error(err);
-  if (rpc_status_t err = rpc_shutdown())
-    handle_error(err);
   return host_ret;
 }
diff --git a/libc/utils/gpu/server/llvmlibc_rpc_server.h b/libc/utils/gpu/server/llvmlibc_rpc_server.h
index b7f2a463b1f5c8..b0cf2f916b3856 100644
--- a/libc/utils/gpu/server/llvmlibc_rpc_server.h
+++ b/libc/utils/gpu/server/llvmlibc_rpc_server.h
@@ -27,10 +27,8 @@ typedef enum {
   RPC_STATUS_SUCCESS = 0x0,
   RPC_STATUS_CONTINUE = 0x1,
   RPC_STATUS_ERROR = 0x1000,
-  RPC_STATUS_OUT_OF_RANGE = 0x1001,
-  RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
-  RPC_STATUS_INVALID_LANE_SIZE = 0x1003,
-  RPC_STATUS_NOT_INITIALIZED = 0x1004,
+  RPC_STATUS_UNHANDLED_OPCODE = 0x1001,
+  RPC_STATUS_INVALID_LANE_SIZE = 0x1002,
 } rpc_status_t;
 
 /// A struct containing an opaque handle to an RPC port. This is what allows the
@@ -45,6 +43,11 @@ typedef struct rpc_buffer_s {
   uint64_t data[8];
 } rpc_buffer_t;
 
+/// An opaque handle to an RPC server that can be attached to a device.
+typedef struct rpc_device_s {
+  uintptr_t handle;
+} rpc_device_t;
+
 /// A function used to allocate \p bytes for use by the RPC server and client.
 /// The memory should support asynchronous and atomic access from both the
 /// client and server.
@@ -60,34 +63,28 @@ typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);
 /// A callback function to use the port to receive or send a \p buffer.
 typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);
 
-/// Initialize the rpc library for general use on \p num_devices.
-rpc_status_t rpc_init(uint32_t num_devices);
-
-/// Shut down the rpc interface.
-rpc_status_t rpc_shutdown(void);
-
-/// Initialize the server for a given device.
-rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
+/// Initialize the server for a given device and return it in \p device.
+rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
                              uint32_t lane_size, rpc_alloc_ty alloc,
                              void *data);
 
 /// Shut down the server for a given device.
-rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
+rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
                                  void *data);
 
 /// Queries the RPC clients at least once and performs server-side work if there
 /// are any active requests. Runs until all work on the server is completed.
-rpc_status_t rpc_handle_server(uint32_t device_id);
+rpc_status_t rpc_handle_server(rpc_device_t rpc_device);
 
 /// Register a callback to handle an opcode from the RPC client. The associated
 /// data must remain accessible as long as the user intends to handle the server
 /// with this callback.
-rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode,
+rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode,
                                    rpc_opcode_callback_ty callback, void *data);
 
 /// Obtain a pointer to a local client buffer that can be copied directly to the
 /// other process using the address stored at the rpc client symbol name.
-const void *rpc_get_client_buffer(uint32_t device_id);
+const void *rpc_get_client_buffer(rpc_device_t device);
 
 /// Returns the size of the client in bytes to be used for a memory copy.
 uint64_t rpc_get_client_size();
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index 46ad98fa02cc51..fd306642fdcc4c 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -248,127 +248,75 @@ struct Device {
   std::unordered_map<uint16_t, void *> callback_data;
 };
 
-// A struct containing all the runtime state required to run the RPC server.
-struct State {
-  State(uint32_t num_devices)
-      : num_devices(num_devices), devices(num_devices), reference_count(0u) {}
-  uint32_t num_devices;
-  std::vector<std::unique_ptr<Device>> devices;
-  std::atomic_uint32_t reference_count;
-};
-
-static std::mutex startup_mutex;
-
-static State *state;
-
-rpc_status_t rpc_init(uint32_t num_devices) {
-  std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
-  if (!state)
-    state = new State(num_devices);
-
-  if (state->reference_count == std::numeric_limits<uint32_t>::max())
-    return RPC_STATUS_ERROR;
-
-  state->reference_count++;
-
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_shutdown(void) {
-  if (state && state->reference_count-- == 1)
-    delete state;
-
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
+rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
                              uint32_t lane_size, rpc_alloc_ty alloc,
                              void *data) {
-  if (!state)
-    return RPC_STATUS_NOT_INITIALIZED;
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
+  if (!rpc_device)
+    return RPC_STATUS_ERROR;
   if (lane_size != 1 && lane_size != 32 && lane_size != 64)
     return RPC_STATUS_INVALID_LANE_SIZE;
 
-  if (!state->devices[device_id]) {
-    uint64_t size = rpc::Server::allocation_size(lane_size, num_ports);
-    void *buffer = alloc(size, data);
+  uint64_t size = rpc::Server::allocation_size(lane_size, num_ports);
+  void *buffer = alloc(size, data);
 
-    if (!buffer)
-      return RPC_STATUS_ERROR;
+  if (!buffer)
+    return RPC_STATUS_ERROR;
 
-    state->devices[device_id] =
-        std::make_unique<Device>(lane_size, num_ports, buffer);
-    if (!state->devices[device_id])
-      return RPC_STATUS_ERROR;
-  }
+  Device *device = new Device(lane_size, num_ports, buffer);
+  if (!device)
+    return RPC_STATUS_ERROR;
 
+  rpc_device->handle = reinterpret_cast<uintptr_t>(device);
   return RPC_STATUS_SUCCESS;
 }
 
-rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
+rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
                                  void *data) {
-  if (!state)
-    return RPC_STATUS_NOT_INITIALIZED;
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
-  if (!state->devices[device_id])
+  if (!rpc_device.handle)
     return RPC_STATUS_ERROR;
 
-  dealloc(state->devices[device_id]->buffer, data);
-  if (state->devices[device_id])
-    state->devices[device_id].release();
+  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
+  dealloc(device->buffer, data);
+  delete device;
 
   return RPC_STATUS_SUCCESS;
 }
 
-rpc_status_t rpc_handle_server(uint32_t device_id) {
-  if (!state)
-    return RPC_STATUS_NOT_INITIALIZED;
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
-  if (!state->devices[device_id])
+rpc_status_t rpc_handle_server(rpc_device_t rpc_device) {
+  if (!rpc_device.handle)
     return RPC_STATUS_ERROR;
 
+  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
   uint32_t index = 0;
   for (;;) {
-    Device &device = *state->devices[device_id];
-    rpc_status_t status = device.handle_server(index);
+    rpc_status_t status = device->handle_server(index);
     if (status != RPC_STATUS_CONTINUE)
       return status;
   }
 }
 
-rpc_status_t rpc_register_callback(uint32_t device_id, uint16_t opcode,
+rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint16_t opcode,
                                    rpc_opcode_callback_ty callback,
                                    void *data) {
-  if (!state)
-    return RPC_STATUS_NOT_INITIALIZED;
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
-  if (!state->devices[device_id])
+  if (!rpc_device.handle)
     return RPC_STATUS_ERROR;
 
-  state->devices[device_id]->callbacks[opcode] = callback;
-  state->devices[device_id]->callback_data[opcode] = data;
+  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
+
+  device->callbacks[opcode] = callback;
+  device->callback_data[opcode] = data;
   return RPC_STATUS_SUCCESS;
 }
 
-const void *rpc_get_client_buffer(uint32_t device_id) {
-  if (!state || device_id >= state->num_devices || !state->devices[device_id])
+const void *rpc_get_client_buffer(rpc_device_t rpc_device) {
+  if (!rpc_device.handle)
     return nullptr;
-  return &state->devices[device_id]->client;
+  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
+  return &device->client;
 }
 
 uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
 
-using ServerPort = std::variant<rpc::Server::Port *>;
-
-ServerPort get_port(rpc_port_t ref) {
-  return reinterpret_cast<rpc::Server::Port *>(ref.handle);
-}
-
 void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
   auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
   port->send([=](rpc::Buffer *buffer) {
diff --git a/openmp/libomptarget/plugins-nextgen/common/include/RPC.h b/openmp/libomptarget/plugins-nextgen/common/include/RPC.h
index 2e39b3f299c888..b621cc0da4587d 100644
--- a/openmp/libomptarget/plugins-nextgen/common/include/RPC.h
+++ b/openmp/libomptarget/plugins-nextgen/common/include/RPC.h
@@ -16,6 +16,7 @@
 #ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_RPC_H
 #define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_RPC_H
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Error.h"
 
 #include <cstdint>
@@ -32,8 +33,6 @@ class DeviceImageTy;
 /// these routines will perform no action.
 struct RPCServerTy {
 public:
-  RPCServerTy(uint32_t NumDevices);
-
   /// Check if this device image is using an RPC server. This checks for the
   /// precense of an externally visible symbol in the device image that will
   /// be present whenever RPC code is called.
@@ -56,7 +55,9 @@ struct RPCServerTy {
   /// memory associated with the k
   llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
 
-  ~RPCServerTy();
+private:
+  /// Array from this device's identifier to its attached devices.
+  llvm::SmallVector<uintptr_t> Handles;
 };
 
 } // namespace llvm::omp::target
diff --git a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
index a4e6c93192159a..55e2865d6aae42 100644
--- a/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/src/PluginInterface.cpp
@@ -1492,7 +1492,7 @@ Error GenericPluginTy::init() {
   GlobalHandler = createGlobalHandler();
   assert(GlobalHandler && "Invalid global handler");
 
-  RPCServer = new RPCServerTy(NumDevices);
+  RPCServer = new RPCServerTy();
   assert(RPCServer && "Invalid RPC server");
 
   return Plugin::success();
diff --git a/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp b/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp
index f46b27701b5b91..fab0f6838f4a87 100644
--- a/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/src/RPC.cpp
@@ -21,14 +21,6 @@ using namespace llvm;
 using namespace omp;
 using namespace target;
 
-RPCServerTy::RPCServerTy(uint32_t NumDevices) {
-#ifdef LIBOMPTARGET_RPC_SUPPORT
-  // If this fails then something is catastrophically wrong, just exit.
-  if (rpc_status_t Err = rpc_init(NumDevices))
-    FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err);
-#endif
-}
-
 llvm::Expected<bool>
 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
                               plugin::GenericGlobalHandlerTy &Handler,
@@ -44,7 +36,6 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                               plugin::GenericGlobalHandlerTy &Handler,
                               plugin::DeviceImageTy &Image) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
-  uint32_t DeviceId = Device.getDeviceId();
   auto Alloc = [](uint64_t Size, void *Data) {
     plugin::GenericDeviceTy &Device =
         *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
@@ -52,10 +43,12 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
   };
   uint64_t NumPorts =
       std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT);
-  if (rpc_status_t Err = rpc_server_init(DeviceId, NumPorts,
+  rpc_device_t RPCDevice;
+  if (rpc_status_t Err = rpc_server_init(&RPCDevice, NumPorts,
                                          Device.getWarpSize(), Alloc, &Device))
     return plugin::Plugin::error(
-        "Failed to initialize RPC server for device %d: %d", DeviceId, Err);
+        "Failed to initialize RPC server for device %d: %d",
+        Device.getDeviceId(), Err);
 
   // Register a custom opcode handler to perform plugin specific allocation.
   auto MallocHandler = [](rpc_port_t Port, void *Data) {
@@ -70,10 +63,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
         Data);
   };
   if (rpc_status_t Err =
-          rpc_register_callback(DeviceId, RPC_MALLOC, MallocHandler, &Device))
+          rpc_register_callback(RPCDevice, RPC_MALLOC, MallocHandler, &Device))
     return plugin::Plugin::error(
-        "Failed to register RPC malloc handler for device %d: %d\n", DeviceId,
-        Err);
+        "Failed to register RPC malloc handler for device %d: %d\n",
+        Device.getDeviceId(), Err);
 
   // Register a custom opcode handler to perform plugin specific deallocation.
   auto FreeHandler = [](rpc_port_t Port, void *Data) {
@@ -88,10 +81,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
         Data);
   };
   if (rpc_status_t Err =
-          rpc_register_callback(DeviceId, RPC_FREE, FreeHandler, &Device))
+          rpc_register_callback(RPCDevice, RPC_FREE, FreeHandler, &Device))
     return plugin::Plugin::error(
-        "Failed to register RPC free handler for device %d: %d\n", DeviceId,
-        Err);
+        "Failed to register RPC free handler for device %d: %d\n",
+        Device.getDeviceId(), Err);
 
   // Get the address of the RPC client from the device.
   void *ClientPtr;
@@ -104,17 +97,20 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                                      sizeof(void *), nullptr))
     return Err;
 
-  const void *ClientBuffer = rpc_get_client_buffer(DeviceId);
+  const void *ClientBuffer = rpc_get_client_buffer(RPCDevice);
   if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
                                    rpc_get_client_size(), nullptr))
     return Err;
+  Handles.resize(Device.getDeviceId() + 1);
+  Handles[Device.getDeviceId()] = RPCDevice.handle;
 #endif
   return Error::success();
 }
 
 Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
-  if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId()))
+  rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
+  if (rpc_status_t Err = rpc_handle_server(RPCDevice))
     return plugin::Plugin::error(
         "Error while running RPC server on device %d: %d", Device.getDeviceId(),
         Err);
@@ -124,22 +120,16 @@ Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
 
 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
+  rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
   auto Dealloc = [](void *Ptr, void *Data) {
     plugin::GenericDeviceTy &Device =
         *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
     Device.free(Ptr, TARGET_ALLOC_HOST);
   };
-  if (rpc_status_t Err =
-          rpc_server_shutdown(Device.getDeviceId(), Dealloc, &Device))
+  if (rpc_status_t Err = rpc_server_shutdown(RPCDevice, Dealloc, &Device))
     return plugin::Plugin::error(
         "Failed to shut down RPC server for device %d: %d",
         Device.getDeviceId(), Err);
 #endif
   return Error::success();
 }
-
-RPCServerTy::~RPCServerTy() {
-#ifdef LIBOMPTARGET_RPC_SUPPORT
-  rpc_shutdown();
-#endif
-}



More information about the Openmp-commits mailing list