[libc-commits] [libc] [llvm] [libc] Remove RPC server API and use the header directly (PR #117075)

Joseph Huber via libc-commits libc-commits at lists.llvm.org
Fri Nov 22 13:33:04 PST 2024


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/117075

>From 19aad3b0c5bf09c4c3eb597dbb0a4e14bbec6b38 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Wed, 20 Nov 2024 16:45:37 -0600
Subject: [PATCH] [libc] Remove RPC server API and use the header directly

Summary:
This patch removes most of the `llvmlibc_rpc_server` interface. Nearly all
of that code is deleted, and users now include `rpc.h` directly instead.
We still keep the file so that `libc` can handle its own opcodes, since
those depend on the `printf` implementation.

This will need to be cleaned up more, but I don't want to put too much
into a single patch.
---
 libc/shared/rpc.h                             |   7 +
 libc/utils/gpu/loader/Loader.h                | 204 +++++-----
 .../utils/gpu/loader/amdgpu/amdhsa-loader.cpp | 123 +++----
 libc/utils/gpu/loader/nvptx/nvptx-loader.cpp  |  96 ++---
 libc/utils/gpu/server/llvmlibc_rpc_server.h   |  94 +----
 libc/utils/gpu/server/rpc_server.cpp          | 348 +++++-------------
 offload/plugins-nextgen/common/CMakeLists.txt |   1 +
 offload/plugins-nextgen/common/include/RPC.h  |   2 +-
 offload/plugins-nextgen/common/src/RPC.cpp    | 126 +++----
 9 files changed, 339 insertions(+), 662 deletions(-)

diff --git a/libc/shared/rpc.h b/libc/shared/rpc.h
index 489a8cebfb807c..c5e4277286c399 100644
--- a/libc/shared/rpc.h
+++ b/libc/shared/rpc.h
@@ -42,6 +42,13 @@ namespace rpc {
 #define __scoped_atomic_thread_fence(ord, scp) __atomic_thread_fence(ord)
 #endif
 
+/// Generic status codes that can be used when implementing the server.
+enum Status {
+  SUCCESS = 0x0,
+  ERROR = 0x1000,
+  UNHANDLED_OPCODE = 0x1001,
+};
+
 /// A fixed size channel used to communicate between the RPC client and server.
 struct Buffer {
   uint64_t data[8];
diff --git a/libc/utils/gpu/loader/Loader.h b/libc/utils/gpu/loader/Loader.h
index 8be8c0d5f85532..fd5105b34709e6 100644
--- a/libc/utils/gpu/loader/Loader.h
+++ b/libc/utils/gpu/loader/Loader.h
@@ -13,6 +13,7 @@
 
 #include "include/llvm-libc-types/rpc_opcodes_t.h"
 #include "include/llvm-libc-types/test_rpc_opcodes_t.h"
+#include "shared/rpc.h"
 
 #include <cstddef>
 #include <cstdint>
@@ -103,129 +104,90 @@ inline void handle_error_impl(const char *file, int32_t line, const char *msg) {
   fprintf(stderr, "%s:%d:0: Error: %s\n", file, line, msg);
   exit(EXIT_FAILURE);
 }
-
-inline void handle_error_impl(const char *file, int32_t line,
-                              rpc_status_t err) {
-  fprintf(stderr, "%s:%d:0: Error: %d\n", file, line, err);
-  exit(EXIT_FAILURE);
-}
 #define handle_error(X) handle_error_impl(__FILE__, __LINE__, X)
 
-template <uint32_t lane_size>
-inline void register_rpc_callbacks(rpc_device_t device) {
-  static_assert(lane_size == 32 || lane_size == 64, "Invalid Lane size");
-  // Register the ping test for the `libc` tests.
-  rpc_register_callback(
-      device, static_cast<rpc_opcode_t>(RPC_TEST_INCREMENT),
-      [](rpc_port_t port, void *data) {
-        rpc_recv_and_send(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
-            },
-            data);
-      },
-      nullptr);
-
-  // Register the interface test callbacks.
-  rpc_register_callback(
-      device, static_cast<rpc_opcode_t>(RPC_TEST_INTERFACE),
-      [](rpc_port_t port, void *data) {
-        uint64_t cnt = 0;
-        bool end_with_recv;
-        rpc_recv(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              *reinterpret_cast<bool *>(data) = buffer->data[0];
-            },
-            &end_with_recv);
-        rpc_recv(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              *reinterpret_cast<uint64_t *>(data) = buffer->data[0];
-            },
-            &cnt);
-        rpc_send(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
-              buffer->data[0] = cnt = cnt + 1;
-            },
-            &cnt);
-        rpc_recv(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              *reinterpret_cast<uint64_t *>(data) = buffer->data[0];
-            },
-            &cnt);
-        rpc_send(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
-              buffer->data[0] = cnt = cnt + 1;
-            },
-            &cnt);
-        rpc_recv(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              *reinterpret_cast<uint64_t *>(data) = buffer->data[0];
-            },
-            &cnt);
-        rpc_recv(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              *reinterpret_cast<uint64_t *>(data) = buffer->data[0];
-            },
-            &cnt);
-        rpc_send(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
-              buffer->data[0] = cnt = cnt + 1;
-            },
-            &cnt);
-        rpc_send(
-            port,
-            [](rpc_buffer_t *buffer, void *data) {
-              uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
-              buffer->data[0] = cnt = cnt + 1;
-            },
-            &cnt);
-        if (end_with_recv)
-          rpc_recv(
-              port,
-              [](rpc_buffer_t *buffer, void *data) {
-                *reinterpret_cast<uint64_t *>(data) = buffer->data[0];
-              },
-              &cnt);
-        else
-          rpc_send(
-              port,
-              [](rpc_buffer_t *buffer, void *data) {
-                uint64_t &cnt = *reinterpret_cast<uint64_t *>(data);
-                buffer->data[0] = cnt = cnt + 1;
-              },
-              &cnt);
-      },
-      nullptr);
-
-  // Register the stream test handler.
-  rpc_register_callback(
-      device, static_cast<rpc_opcode_t>(RPC_TEST_STREAM),
-      [](rpc_port_t port, void *data) {
-        uint64_t sizes[lane_size] = {0};
-        void *dst[lane_size] = {nullptr};
-        rpc_recv_n(
-            port, dst, sizes,
-            [](uint64_t size, void *) -> void * { return new char[size]; },
-            nullptr);
-        rpc_send_n(port, dst, sizes);
-        for (uint64_t i = 0; i < lane_size; ++i) {
-          if (dst[i])
-            delete[] reinterpret_cast<uint8_t *>(dst[i]);
-        }
-      },
-      nullptr);
+/// Services one ready RPC port at or after \p index (\p alloc / \p free back
+/// RPC_MALLOC / RPC_FREE). Returns the next index to poll, or 0 if none ready.
+template <uint32_t num_lanes, typename Alloc, typename Free>
+inline uint32_t handle_server(rpc::Server &server, uint32_t index,
+                              Alloc &&alloc, Free &&free) {
+  auto port = server.try_open(num_lanes, index);
+  if (!port)
+    return 0;
+  index = port->get_index() + 1;
+
+  int status = rpc::SUCCESS;
+  switch (port->get_opcode()) {
+  case RPC_TEST_INCREMENT: {
+    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+      reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
+    });
+    break;
+  }
+  case RPC_TEST_INTERFACE: {
+    bool end_with_recv;
+    uint64_t cnt;
+    port->recv([&](rpc::Buffer *buffer, uint32_t) {
+      end_with_recv = buffer->data[0];
+    });
+    port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
+    port->send([&](rpc::Buffer *buffer, uint32_t) {
+      buffer->data[0] = cnt = cnt + 1;
+    });
+    port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
+    port->send([&](rpc::Buffer *buffer, uint32_t) {
+      buffer->data[0] = cnt = cnt + 1;
+    });
+    port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
+    port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
+    port->send([&](rpc::Buffer *buffer, uint32_t) {
+      buffer->data[0] = cnt = cnt + 1;
+    });
+    port->send([&](rpc::Buffer *buffer, uint32_t) {
+      buffer->data[0] = cnt = cnt + 1;
+    });
+    if (end_with_recv)
+      port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
+    else
+      port->send([&](rpc::Buffer *buffer, uint32_t) {
+        buffer->data[0] = cnt = cnt + 1;
+      });
+    break;
+  }
+  case RPC_TEST_STREAM: {
+    uint64_t sizes[num_lanes] = {0};
+    void *dst[num_lanes] = {nullptr};
+    port->recv_n(dst, sizes,
+                 [](uint64_t size) -> void * { return new char[size]; });
+    port->send_n(dst, sizes);
+    for (uint32_t i = 0; i < num_lanes; ++i)
+      delete[] static_cast<char *>(dst[i]); // Null-safe; matches `new char[]`.
+    break;
+  }
+  case RPC_MALLOC: {
+    port->recv_and_send([&](rpc::Buffer *buffer, uint32_t) {
+      buffer->data[0] = reinterpret_cast<uintptr_t>(alloc(buffer->data[0]));
+    });
+    break;
+  }
+  case RPC_FREE: {
+    port->recv([&](rpc::Buffer *buffer, uint32_t) {
+      free(reinterpret_cast<void *>(buffer->data[0]));
+    });
+    break;
+  }
+  // Any opcode not handled above is one of the `libc` specific opcodes.
+  default:
+    status = libc_handle_rpc_port(&*port, num_lanes);
+    break;
+  }
+
+  if (status != rpc::SUCCESS)
+    handle_error("Error handling RPC server");
+
+  port->close();
+
+  return index;
 }
 
 #endif
diff --git a/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp b/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
index d825a6299263ae..13a13668335471 100644
--- a/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
+++ b/libc/utils/gpu/loader/amdgpu/amdhsa-loader.cpp
@@ -160,7 +160,7 @@ template <typename args_t>
 hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
                            hsa_amd_memory_pool_t kernargs_pool,
                            hsa_amd_memory_pool_t coarsegrained_pool,
-                           hsa_queue_t *queue, rpc_device_t device,
+                           hsa_queue_t *queue, rpc::Server &server,
                            const LaunchParameters &params,
                            const char *kernel_name, args_t kernel_args,
                            bool print_resource_usage) {
@@ -170,37 +170,10 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
           executable, kernel_name, &dev_agent, &symbol))
     return err;
 
-  // Register RPC callbacks for the malloc and free functions on HSA.
-  auto tuple = std::make_tuple(dev_agent, coarsegrained_pool);
-  rpc_register_callback(
-      device, RPC_MALLOC,
-      [](rpc_port_t port, void *data) {
-        auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
-          auto &[dev_agent, pool] = *static_cast<decltype(tuple) *>(data);
-          uint64_t size = buffer->data[0];
-          void *dev_ptr = nullptr;
-          if (hsa_status_t err =
-                  hsa_amd_memory_pool_allocate(pool, size,
-                                               /*flags=*/0, &dev_ptr))
-            dev_ptr = nullptr;
-          hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
-          buffer->data[0] = reinterpret_cast<uintptr_t>(dev_ptr);
-        };
-        rpc_recv_and_send(port, malloc_handler, data);
-      },
-      &tuple);
-  rpc_register_callback(
-      device, RPC_FREE,
-      [](rpc_port_t port, void *data) {
-        auto free_handler = [](rpc_buffer_t *buffer, void *) {
-          if (hsa_status_t err = hsa_amd_memory_pool_free(
-                  reinterpret_cast<void *>(buffer->data[0])))
-            handle_error(err);
-        };
-        rpc_recv_and_send(port, free_handler, data);
-      },
-      nullptr);
-
+  uint32_t wavefront_size = 0;
+  if (hsa_status_t err = hsa_agent_get_info(
+          dev_agent, HSA_AGENT_INFO_WAVEFRONT_SIZE, &wavefront_size))
+    handle_error(err);
   // Retrieve different properties of the kernel symbol used for launch.
   uint64_t kernel;
   uint32_t args_size;
@@ -292,14 +265,38 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
   hsa_signal_store_relaxed(queue->doorbell_signal, packet_id);
 
   std::atomic<bool> finished = false;
-  std::thread server(
-      [](std::atomic<bool> *finished, rpc_device_t device) {
-        while (!*finished) {
-          if (rpc_status_t err = rpc_handle_server(device))
+  std::thread server_thread(
+      [](std::atomic<bool> *finished, rpc::Server *server,
+         uint32_t wavefront_size, hsa_agent_t dev_agent,
+         hsa_amd_memory_pool_t coarsegrained_pool) {
+        // Register RPC callbacks for the malloc and free functions on HSA.
+        auto malloc_handler = [&](size_t size) -> void * {
+          void *dev_ptr = nullptr;
+          if (hsa_status_t err =
+                  hsa_amd_memory_pool_allocate(coarsegrained_pool, size,
+                                               /*flags=*/0, &dev_ptr))
+            dev_ptr = nullptr;
+          hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
+          return dev_ptr;
+        };
+
+        auto free_handler = [](void *ptr) -> void {
+          if (hsa_status_t err =
+                  hsa_amd_memory_pool_free(reinterpret_cast<void *>(ptr)))
             handle_error(err);
+        };
+
+        uint32_t index = 0;
+        while (!*finished) {
+          if (wavefront_size == 32)
+            index =
+                handle_server<32>(*server, index, malloc_handler, free_handler);
+          else
+            index =
+                handle_server<64>(*server, index, malloc_handler, free_handler);
         }
       },
-      &finished, device);
+      &finished, &server, wavefront_size, dev_agent, coarsegrained_pool);
 
   // Wait until the kernel has completed execution on the device. Periodically
   // check the RPC client for work to be performed on the server.
@@ -309,8 +306,8 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
     ;
 
   finished = true;
-  if (server.joinable())
-    server.join();
+  if (server_thread.joinable())
+    server_thread.join();
 
   // Destroy the resources acquired to launch the kernel and return.
   if (hsa_status_t err = hsa_amd_memory_pool_free(args))
@@ -457,34 +454,22 @@ int load(int argc, const char **argv, const char **envp, void *image,
     handle_error(err);
 
   // Set up the RPC server.
-  auto tuple = std::make_tuple(dev_agent, finegrained_pool);
-  auto rpc_alloc = [](uint64_t size, void *data) {
-    auto &[dev_agent, finegrained_pool] = *static_cast<decltype(tuple) *>(data);
-    void *dev_ptr = nullptr;
-    if (hsa_status_t err = hsa_amd_memory_pool_allocate(finegrained_pool, size,
-                                                        /*flags=*/0, &dev_ptr))
-      handle_error(err);
-    hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
-    return dev_ptr;
-  };
-  rpc_device_t device;
-  if (rpc_status_t err = rpc_server_init(&device, RPC_MAXIMUM_PORT_COUNT,
-                                         wavefront_size, rpc_alloc, &tuple))
+  void *rpc_buffer;
+  if (hsa_status_t err = hsa_amd_memory_pool_allocate(
+          finegrained_pool,
+          rpc::Server::allocation_size(wavefront_size, rpc::MAX_PORT_COUNT),
+          /*flags=*/0, &rpc_buffer))
     handle_error(err);
+  hsa_amd_agents_allow_access(1, &dev_agent, nullptr, rpc_buffer);
 
-  // Register callbacks for the RPC unit tests.
-  if (wavefront_size == 32)
-    register_rpc_callbacks<32>(device);
-  else if (wavefront_size == 64)
-    register_rpc_callbacks<64>(device);
-  else
-    handle_error("Invalid wavefront size");
+  rpc::Server server(rpc::MAX_PORT_COUNT, rpc_buffer);
+  rpc::Client client(rpc::MAX_PORT_COUNT, rpc_buffer);
 
   // Initialize the RPC client on the device by copying the local data to the
   // device's internal pointer.
   hsa_executable_symbol_t rpc_client_sym;
   if (hsa_status_t err = hsa_executable_get_symbol_by_name(
-          executable, rpc_client_symbol_name, &dev_agent, &rpc_client_sym))
+          executable, "__llvm_libc_rpc_client", &dev_agent, &rpc_client_sym))
     handle_error(err);
 
   void *rpc_client_host;
@@ -507,19 +492,17 @@ int load(int argc, const char **argv, const char **envp, void *image,
 
   void *rpc_client_buffer;
   if (hsa_status_t err =
-          hsa_amd_memory_lock(const_cast<void *>(rpc_get_client_buffer(device)),
-                              rpc_get_client_size(),
+          hsa_amd_memory_lock(&client, sizeof(rpc::Client),
                               /*agents=*/nullptr, 0, &rpc_client_buffer))
     handle_error(err);
 
   // Copy the RPC client buffer to the address pointed to by the symbol.
   if (hsa_status_t err =
           hsa_memcpy(*reinterpret_cast<void **>(rpc_client_host), dev_agent,
-                     rpc_client_buffer, host_agent, rpc_get_client_size()))
+                     rpc_client_buffer, host_agent, sizeof(rpc::Client)))
     handle_error(err);
 
-  if (hsa_status_t err = hsa_amd_memory_unlock(
-          const_cast<void *>(rpc_get_client_buffer(device))))
+  if (hsa_status_t err = hsa_amd_memory_unlock(&client))
     handle_error(err);
   if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host))
     handle_error(err);
@@ -571,7 +554,7 @@ int load(int argc, const char **argv, const char **envp, void *image,
   LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
   begin_args_t init_args = {argc, dev_argv, dev_envp};
   if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
-                                       coarsegrained_pool, queue, device,
+                                       coarsegrained_pool, queue, server,
                                        single_threaded_params, "_begin.kd",
                                        init_args, print_resource_usage))
     handle_error(err);
@@ -579,7 +562,7 @@ int load(int argc, const char **argv, const char **envp, void *image,
   start_args_t args = {argc, dev_argv, dev_envp, dev_ret};
   if (hsa_status_t err = launch_kernel(
           dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
-          device, params, "_start.kd", args, print_resource_usage))
+          server, params, "_start.kd", args, print_resource_usage))
     handle_error(err);
 
   void *host_ret;
@@ -598,14 +581,12 @@ int load(int argc, const char **argv, const char **envp, void *image,
 
   end_args_t fini_args = {ret};
   if (hsa_status_t err = launch_kernel(dev_agent, executable, kernargs_pool,
-                                       coarsegrained_pool, queue, device,
+                                       coarsegrained_pool, queue, server,
                                        single_threaded_params, "_end.kd",
                                        fini_args, print_resource_usage))
     handle_error(err);
 
-  if (rpc_status_t err = rpc_server_shutdown(
-          device, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
-          nullptr))
+  if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_buffer))
     handle_error(err);
 
   // Free the memory allocated for the device.
diff --git a/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp b/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
index 58e5e5f04d0a70..0ba217451feaea 100644
--- a/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
+++ b/libc/utils/gpu/loader/nvptx/nvptx-loader.cpp
@@ -167,10 +167,9 @@ void print_kernel_resources(CUmodule binary, const char *kernel_name) {
 }
 
 template <typename args_t>
-CUresult launch_kernel(CUmodule binary, CUstream stream,
-                       rpc_device_t rpc_device, const LaunchParameters &params,
-                       const char *kernel_name, args_t kernel_args,
-                       bool print_resource_usage) {
+CUresult launch_kernel(CUmodule binary, CUstream stream, rpc::Server &server,
+                       const LaunchParameters &params, const char *kernel_name,
+                       args_t kernel_args, bool print_resource_usage) {
   // look up the '_start' kernel in the loaded module.
   CUfunction function;
   if (CUresult err = cuModuleGetFunction(&function, binary, kernel_name))
@@ -181,23 +180,21 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
   void *args_config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, &kernel_args,
                          CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
                          CU_LAUNCH_PARAM_END};
+  if (print_resource_usage)
+    print_kernel_resources(binary, kernel_name);
 
-  // Initialize a non-blocking CUDA stream to allocate memory if needed. This
-  // needs to be done on a separate stream or else it will deadlock with the
-  // executing kernel.
+  // Initialize a non-blocking CUDA stream to allocate memory if needed.
+  // This needs to be done on a separate stream or else it will deadlock
+  // with the executing kernel.
   CUstream memory_stream;
   if (CUresult err = cuStreamCreate(&memory_stream, CU_STREAM_NON_BLOCKING))
     handle_error(err);
 
-  // Register RPC callbacks for the malloc and free functions on HSA.
-  register_rpc_callbacks<32>(rpc_device);
-
-  rpc_register_callback(
-      rpc_device, RPC_MALLOC,
-      [](rpc_port_t port, void *data) {
-        auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
-          CUstream memory_stream = *static_cast<CUstream *>(data);
-          uint64_t size = buffer->data[0];
+  std::atomic<bool> finished = false;
+  std::thread server_thread(
+      [](std::atomic<bool> *finished, rpc::Server *server,
+         CUstream memory_stream) {
+        auto malloc_handler = [&](size_t size) -> void * {
           CUdeviceptr dev_ptr;
           if (CUresult err = cuMemAllocAsync(&dev_ptr, size, memory_stream))
             dev_ptr = 0UL;
@@ -205,36 +202,22 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
           // Wait until the memory allocation is complete.
           while (cuStreamQuery(memory_stream) == CUDA_ERROR_NOT_READY)
             ;
-          buffer->data[0] = static_cast<uintptr_t>(dev_ptr);
+          return reinterpret_cast<void *>(dev_ptr);
         };
-        rpc_recv_and_send(port, malloc_handler, data);
-      },
-      &memory_stream);
-  rpc_register_callback(
-      rpc_device, RPC_FREE,
-      [](rpc_port_t port, void *data) {
-        auto free_handler = [](rpc_buffer_t *buffer, void *data) {
-          CUstream memory_stream = *static_cast<CUstream *>(data);
-          if (CUresult err = cuMemFreeAsync(
-                  static_cast<CUdeviceptr>(buffer->data[0]), memory_stream))
+
+        auto free_handler = [&](void *ptr) -> void {
+          if (CUresult err = cuMemFreeAsync(reinterpret_cast<CUdeviceptr>(ptr),
+                                            memory_stream))
             handle_error(err);
         };
-        rpc_recv_and_send(port, free_handler, data);
-      },
-      &memory_stream);
 
-  if (print_resource_usage)
-    print_kernel_resources(binary, kernel_name);
-
-  std::atomic<bool> finished = false;
-  std::thread server(
-      [](std::atomic<bool> *finished, rpc_device_t device) {
+        uint32_t index = 0;
         while (!*finished) {
-          if (rpc_status_t err = rpc_handle_server(device))
-            handle_error(err);
+          index =
+              handle_server<32>(*server, index, malloc_handler, free_handler);
         }
       },
-      &finished, rpc_device);
+      &finished, &server, memory_stream);
 
   // Call the kernel with the given arguments.
   if (CUresult err = cuLaunchKernel(
@@ -247,8 +230,8 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
     handle_error(err);
 
   finished = true;
-  if (server.joinable())
-    server.join();
+  if (server_thread.joinable())
+    server_thread.join();
 
   return CUDA_SUCCESS;
 }
@@ -318,23 +301,20 @@ int load(int argc, const char **argv, const char **envp, void *image,
     handle_error(err);
 
   uint32_t warp_size = 32;
-  auto rpc_alloc = [](uint64_t size, void *) -> void * {
-    void *dev_ptr;
-    if (CUresult err = cuMemAllocHost(&dev_ptr, size))
-      handle_error(err);
-    return dev_ptr;
-  };
-  rpc_device_t rpc_device;
-  if (rpc_status_t err = rpc_server_init(&rpc_device, RPC_MAXIMUM_PORT_COUNT,
-                                         warp_size, rpc_alloc, nullptr))
+  void *rpc_buffer = nullptr;
+  if (CUresult err = cuMemAllocHost(
+          &rpc_buffer,
+          rpc::Server::allocation_size(warp_size, rpc::MAX_PORT_COUNT)))
     handle_error(err);
+  rpc::Server server(rpc::MAX_PORT_COUNT, rpc_buffer);
+  rpc::Client client(rpc::MAX_PORT_COUNT, rpc_buffer);
 
   // Initialize the RPC client on the device by copying the local data to the
   // device's internal pointer.
   CUdeviceptr rpc_client_dev = 0;
   uint64_t client_ptr_size = sizeof(void *);
   if (CUresult err = cuModuleGetGlobal(&rpc_client_dev, &client_ptr_size,
-                                       binary, rpc_client_symbol_name))
+                                       binary, "__llvm_libc_rpc_client"))
     handle_error(err);
 
   CUdeviceptr rpc_client_host = 0;
@@ -342,20 +322,19 @@ int load(int argc, const char **argv, const char **envp, void *image,
           cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *)))
     handle_error(err);
   if (CUresult err =
-          cuMemcpyHtoD(rpc_client_host, rpc_get_client_buffer(rpc_device),
-                       rpc_get_client_size()))
+          cuMemcpyHtoD(rpc_client_host, &client, sizeof(rpc::Client)))
     handle_error(err);
 
   LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
   begin_args_t init_args = {argc, dev_argv, dev_envp};
   if (CUresult err =
-          launch_kernel(binary, stream, rpc_device, single_threaded_params,
+          launch_kernel(binary, stream, server, single_threaded_params,
                         "_begin", init_args, print_resource_usage))
     handle_error(err);
 
   start_args_t args = {argc, dev_argv, dev_envp,
                        reinterpret_cast<void *>(dev_ret)};
-  if (CUresult err = launch_kernel(binary, stream, rpc_device, params, "_start",
+  if (CUresult err = launch_kernel(binary, stream, server, params, "_start",
                                    args, print_resource_usage))
     handle_error(err);
 
@@ -369,8 +348,8 @@ int load(int argc, const char **argv, const char **envp, void *image,
 
   end_args_t fini_args = {host_ret};
   if (CUresult err =
-          launch_kernel(binary, stream, rpc_device, single_threaded_params,
-                        "_end", fini_args, print_resource_usage))
+          launch_kernel(binary, stream, server, single_threaded_params, "_end",
+                        fini_args, print_resource_usage))
     handle_error(err);
 
   // Free the memory allocated for the device.
@@ -380,8 +359,7 @@ int load(int argc, const char **argv, const char **envp, void *image,
     handle_error(err);
   if (CUresult err = cuMemFreeHost(dev_argv))
     handle_error(err);
-  if (rpc_status_t err = rpc_server_shutdown(
-          rpc_device, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
+  if (CUresult err = cuMemFreeHost(rpc_buffer))
     handle_error(err);
 
   // Destroy the context and the loaded binary.
diff --git a/libc/utils/gpu/server/llvmlibc_rpc_server.h b/libc/utils/gpu/server/llvmlibc_rpc_server.h
index 98df882afa21cf..b7f173734345c0 100644
--- a/libc/utils/gpu/server/llvmlibc_rpc_server.h
+++ b/libc/utils/gpu/server/llvmlibc_rpc_server.h
@@ -15,99 +15,7 @@
 extern "C" {
 #endif
 
-/// The maximum number of ports that can be opened for any server.
-const uint64_t RPC_MAXIMUM_PORT_COUNT = 4096;
-
-/// The symbol name associated with the client for use with the LLVM C library
-/// implementation.
-const char *const rpc_client_symbol_name = "__llvm_libc_rpc_client";
-
-/// status codes.
-typedef enum {
-  RPC_STATUS_SUCCESS = 0x0,
-  RPC_STATUS_CONTINUE = 0x1,
-  RPC_STATUS_ERROR = 0x1000,
-  RPC_STATUS_UNHANDLED_OPCODE = 0x1001,
-  RPC_STATUS_INVALID_LANE_SIZE = 0x1002,
-} rpc_status_t;
-
-/// A struct containing an opaque handle to an RPC port. This is what allows the
-/// server to communicate with the client.
-typedef struct rpc_port_s {
-  uint64_t handle;
-  uint32_t lane_size;
-} rpc_port_t;
-
-/// A fixed-size buffer containing the payload sent from the client.
-typedef struct rpc_buffer_s {
-  uint64_t data[8];
-} rpc_buffer_t;
-
-/// An opaque handle to an RPC server that can be attached to a device.
-typedef struct rpc_device_s {
-  uintptr_t handle;
-} rpc_device_t;
-
-/// A function used to allocate \p bytes for use by the RPC server and client.
-/// The memory should support asynchronous and atomic access from both the
-/// client and server.
-typedef void *(*rpc_alloc_ty)(uint64_t size, void *data);
-
-/// A function used to free the \p ptr previously allocated.
-typedef void (*rpc_free_ty)(void *ptr, void *data);
-
-/// A callback function provided with a \p port to communicate with the RPC
-/// client. This will be called by the server to handle an opcode.
-typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);
-
-/// A callback function to use the port to receive or send a \p buffer.
-typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);
-
-/// Initialize the server for a given device and return it in \p device.
-rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
-                             uint32_t lane_size, rpc_alloc_ty alloc,
-                             void *data);
-
-/// Shut down the server for a given device.
-rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
-                                 void *data);
-
-/// Queries the RPC clients at least once and performs server-side work if there
-/// are any active requests. Runs until all work on the server is completed.
-rpc_status_t rpc_handle_server(rpc_device_t rpc_device);
-
-/// Register a callback to handle an opcode from the RPC client. The associated
-/// data must remain accessible as long as the user intends to handle the server
-/// with this callback.
-rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint32_t opcode,
-                                   rpc_opcode_callback_ty callback, void *data);
-
-/// Obtain a pointer to a local client buffer that can be copied directly to the
-/// other process using the address stored at the rpc client symbol name.
-const void *rpc_get_client_buffer(rpc_device_t device);
-
-/// Returns the size of the client in bytes to be used for a memory copy.
-uint64_t rpc_get_client_size();
-
-/// Use the \p port to send a buffer using the \p callback.
-void rpc_send(rpc_port_t port, rpc_port_callback_ty callback, void *data);
-
-/// Use the \p port to send \p bytes using the \p callback. The input is an
-/// array of at least the configured lane size.
-void rpc_send_n(rpc_port_t port, const void *const *src, uint64_t *size);
-
-/// Use the \p port to recieve a buffer using the \p callback.
-void rpc_recv(rpc_port_t port, rpc_port_callback_ty callback, void *data);
-
-/// Use the \p port to recieve \p bytes using the \p callback. The inputs is an
-/// array of at least the configured lane size. The \p alloc function allocates
-/// memory for the recieved bytes.
-void rpc_recv_n(rpc_port_t port, void **dst, uint64_t *size, rpc_alloc_ty alloc,
-                void *data);
-
-/// Use the \p port to receive and send a buffer using the \p callback.
-void rpc_recv_and_send(rpc_port_t port, rpc_port_callback_ty callback,
-                       void *data);
+int libc_handle_rpc_port(void *port, uint32_t num_lanes);
 
 #ifdef __cplusplus
 }
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index d877cbc25a13d0..1fdbb79df7e3e0 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -37,12 +37,6 @@
 using namespace LIBC_NAMESPACE;
 using namespace LIBC_NAMESPACE::printf_core;
 
-static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
-              "Buffer size mismatch");
-
-static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
-              "Incorrect maximum port count");
-
 namespace {
 struct TempStorage {
   char *alloc(size_t size) {
@@ -74,9 +68,9 @@ LIBC_INLINE ::FILE *to_stream(uintptr_t f) {
   return stream;
 }
 
-template <bool packed, uint32_t lane_size>
+template <bool packed, uint32_t num_lanes>
 static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
-  FILE *files[lane_size] = {nullptr};
+  FILE *files[num_lanes] = {nullptr};
   // Get the appropriate output stream to use.
   if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
       port.get_opcode() == RPC_PRINTF_TO_STREAM_PACKED)
@@ -85,22 +79,22 @@ static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
     });
   else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT ||
            port.get_opcode() == RPC_PRINTF_TO_STDOUT_PACKED)
-    std::fill(files, files + lane_size, stdout);
+    std::fill(files, files + num_lanes, stdout);
   else
-    std::fill(files, files + lane_size, stderr);
+    std::fill(files, files + num_lanes, stderr);
 
-  uint64_t format_sizes[lane_size] = {0};
-  void *format[lane_size] = {nullptr};
+  uint64_t format_sizes[num_lanes] = {0};
+  void *format[num_lanes] = {nullptr};
 
-  uint64_t args_sizes[lane_size] = {0};
-  void *args[lane_size] = {nullptr};
+  uint64_t args_sizes[num_lanes] = {0};
+  void *args[num_lanes] = {nullptr};
 
   // Recieve the format string and arguments from the client.
   port.recv_n(format, format_sizes,
               [&](uint64_t size) { return temp_storage.alloc(size); });
 
   // Parse the format string to get the expected size of the buffer.
-  for (uint32_t lane = 0; lane < lane_size; ++lane) {
+  for (uint32_t lane = 0; lane < num_lanes; ++lane) {
     if (!format[lane])
       continue;
 
@@ -125,9 +119,9 @@ static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
 
   // Identify any arguments that are actually pointers to strings on the client.
   // Additionally we want to determine how much buffer space we need to print.
-  std::vector<void *> strs_to_copy[lane_size];
-  int buffer_size[lane_size] = {0};
-  for (uint32_t lane = 0; lane < lane_size; ++lane) {
+  std::vector<void *> strs_to_copy[num_lanes];
+  int buffer_size[num_lanes] = {0};
+  for (uint32_t lane = 0; lane < num_lanes; ++lane) {
     if (!format[lane])
       continue;
 
@@ -159,7 +153,7 @@ static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
   }
 
   // Recieve any strings from the client and push them into a buffer.
-  std::vector<void *> copied_strs[lane_size];
+  std::vector<void *> copied_strs[num_lanes];
   while (std::any_of(std::begin(strs_to_copy), std::end(strs_to_copy),
                      [](const auto &v) { return !v.empty() && v.back(); })) {
     port.send([&](rpc::Buffer *buffer, uint32_t id) {
@@ -168,11 +162,11 @@ static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
       if (!strs_to_copy[id].empty())
         strs_to_copy[id].pop_back();
     });
-    uint64_t str_sizes[lane_size] = {0};
-    void *strs[lane_size] = {nullptr};
+    uint64_t str_sizes[num_lanes] = {0};
+    void *strs[num_lanes] = {nullptr};
     port.recv_n(strs, str_sizes,
                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    for (uint32_t lane = 0; lane < lane_size; ++lane) {
+    for (uint32_t lane = 0; lane < num_lanes; ++lane) {
       if (!strs[lane])
         continue;
 
@@ -182,8 +176,8 @@ static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
   }
 
   // Perform the final formatting and printing using the LLVM C library printf.
-  int results[lane_size] = {0};
-  for (uint32_t lane = 0; lane < lane_size; ++lane) {
+  int results[num_lanes] = {0};
+  for (uint32_t lane = 0; lane < num_lanes; ++lane) {
     if (!format[lane])
       continue;
 
@@ -233,42 +227,34 @@ static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
   });
 }
 
-template <uint32_t lane_size>
-rpc_status_t handle_server_impl(
-    rpc::Server &server,
-    const std::unordered_map<uint32_t, rpc_opcode_callback_ty> &callbacks,
-    const std::unordered_map<uint32_t, void *> &callback_data,
-    uint32_t &index) {
-  auto port = server.try_open(lane_size, index);
-  if (!port)
-    return RPC_STATUS_SUCCESS;
-
+template <uint32_t num_lanes>
+rpc::Status handle_port_impl(rpc::Server::Port &port) {
   TempStorage temp_storage;
 
-  switch (port->get_opcode()) {
+  switch (port.get_opcode()) {
   case RPC_WRITE_TO_STREAM:
   case RPC_WRITE_TO_STDERR:
   case RPC_WRITE_TO_STDOUT:
   case RPC_WRITE_TO_STDOUT_NEWLINE: {
-    uint64_t sizes[lane_size] = {0};
-    void *strs[lane_size] = {nullptr};
-    FILE *files[lane_size] = {nullptr};
-    if (port->get_opcode() == RPC_WRITE_TO_STREAM) {
-      port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    void *strs[num_lanes] = {nullptr};
+    FILE *files[num_lanes] = {nullptr};
+    if (port.get_opcode() == RPC_WRITE_TO_STREAM) {
+      port.recv([&](rpc::Buffer *buffer, uint32_t id) {
         files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
       });
-    } else if (port->get_opcode() == RPC_WRITE_TO_STDERR) {
-      std::fill(files, files + lane_size, stderr);
+    } else if (port.get_opcode() == RPC_WRITE_TO_STDERR) {
+      std::fill(files, files + num_lanes, stderr);
     } else {
-      std::fill(files, files + lane_size, stdout);
+      std::fill(files, files + num_lanes, stdout);
     }
 
-    port->recv_n(strs, sizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+    port.recv_n(strs, sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.send([&](rpc::Buffer *buffer, uint32_t id) {
       flockfile(files[id]);
       buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]);
-      if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
+      if (port.get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
           buffer->data[0] == sizes[id])
         buffer->data[0] += fwrite_unlocked("\n", 1, 1, files[id]);
       funlockfile(files[id]);
@@ -276,37 +262,37 @@ rpc_status_t handle_server_impl(
     break;
   }
   case RPC_READ_FROM_STREAM: {
-    uint64_t sizes[lane_size] = {0};
-    void *data[lane_size] = {nullptr};
-    port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    void *data[num_lanes] = {nullptr};
+    port.recv([&](rpc::Buffer *buffer, uint32_t id) {
       data[id] = temp_storage.alloc(buffer->data[0]);
       sizes[id] =
           fread(data[id], 1, buffer->data[0], to_stream(buffer->data[1]));
     });
-    port->send_n(data, sizes);
-    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+    port.send_n(data, sizes);
+    port.send([&](rpc::Buffer *buffer, uint32_t id) {
       std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
     });
     break;
   }
   case RPC_READ_FGETS: {
-    uint64_t sizes[lane_size] = {0};
-    void *data[lane_size] = {nullptr};
-    port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    void *data[num_lanes] = {nullptr};
+    port.recv([&](rpc::Buffer *buffer, uint32_t id) {
       data[id] = temp_storage.alloc(buffer->data[0]);
       const char *str = fgets(reinterpret_cast<char *>(data[id]),
                               buffer->data[0], to_stream(buffer->data[1]));
       sizes[id] = !str ? 0 : std::strlen(str) + 1;
     });
-    port->send_n(data, sizes);
+    port.send_n(data, sizes);
     break;
   }
   case RPC_OPEN_FILE: {
-    uint64_t sizes[lane_size] = {0};
-    void *paths[lane_size] = {nullptr};
-    port->recv_n(paths, sizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    void *paths[num_lanes] = {nullptr};
+    port.recv_n(paths, sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
       FILE *file = fopen(reinterpret_cast<char *>(paths[id]),
                          reinterpret_cast<char *>(buffer->data));
       buffer->data[0] = reinterpret_cast<uintptr_t>(file);
@@ -314,7 +300,7 @@ rpc_status_t handle_server_impl(
     break;
   }
   case RPC_CLOSE_FILE: {
-    port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
+    port.recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
       FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
       buffer->data[0] = fclose(file);
     });
@@ -322,8 +308,8 @@ rpc_status_t handle_server_impl(
   }
   case RPC_EXIT: {
     // Send a response to the client to signal that we are ready to exit.
-    port->recv_and_send([](rpc::Buffer *, uint32_t) {});
-    port->recv([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *, uint32_t) {});
+    port.recv([](rpc::Buffer *buffer, uint32_t) {
       int status = 0;
       std::memcpy(&status, buffer->data, sizeof(int));
       exit(status);
@@ -332,47 +318,47 @@ rpc_status_t handle_server_impl(
   }
   case RPC_ABORT: {
     // Send a response to the client to signal that we are ready to abort.
-    port->recv_and_send([](rpc::Buffer *, uint32_t) {});
-    port->recv([](rpc::Buffer *, uint32_t) {});
+    port.recv_and_send([](rpc::Buffer *, uint32_t) {});
+    port.recv([](rpc::Buffer *, uint32_t) {});
     abort();
     break;
   }
   case RPC_HOST_CALL: {
-    uint64_t sizes[lane_size] = {0};
-    unsigned long long results[lane_size] = {0};
-    void *args[lane_size] = {nullptr};
-    port->recv_n(args, sizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    unsigned long long results[num_lanes] = {0};
+    void *args[num_lanes] = {nullptr};
+    port.recv_n(args, sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.recv([&](rpc::Buffer *buffer, uint32_t id) {
       using func_ptr_t = unsigned long long (*)(void *);
       auto func = reinterpret_cast<func_ptr_t>(buffer->data[0]);
       results[id] = func(args[id]);
     });
-    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+    port.send([&](rpc::Buffer *buffer, uint32_t id) {
       buffer->data[0] = static_cast<uint64_t>(results[id]);
     });
     break;
   }
   case RPC_FEOF: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] = feof(to_stream(buffer->data[0]));
     });
     break;
   }
   case RPC_FERROR: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] = ferror(to_stream(buffer->data[0]));
     });
     break;
   }
   case RPC_CLEARERR: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       clearerr(to_stream(buffer->data[0]));
     });
     break;
   }
   case RPC_FSEEK: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] =
           fseek(to_stream(buffer->data[0]), static_cast<long>(buffer->data[1]),
                 static_cast<int>(buffer->data[2]));
@@ -380,19 +366,19 @@ rpc_status_t handle_server_impl(
     break;
   }
   case RPC_FTELL: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] = ftell(to_stream(buffer->data[0]));
     });
     break;
   }
   case RPC_FFLUSH: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] = fflush(to_stream(buffer->data[0]));
     });
     break;
   }
   case RPC_UNGETC: {
-    port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
+    port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] =
           ungetc(static_cast<int>(buffer->data[0]), to_stream(buffer->data[1]));
     });
@@ -401,36 +387,36 @@ rpc_status_t handle_server_impl(
   case RPC_PRINTF_TO_STREAM_PACKED:
   case RPC_PRINTF_TO_STDOUT_PACKED:
   case RPC_PRINTF_TO_STDERR_PACKED: {
-    handle_printf<true, lane_size>(*port, temp_storage);
+    handle_printf<true, num_lanes>(port, temp_storage);
     break;
   }
   case RPC_PRINTF_TO_STREAM:
   case RPC_PRINTF_TO_STDOUT:
   case RPC_PRINTF_TO_STDERR: {
-    handle_printf<false, lane_size>(*port, temp_storage);
+    handle_printf<false, num_lanes>(port, temp_storage);
     break;
   }
   case RPC_REMOVE: {
-    uint64_t sizes[lane_size] = {0};
-    void *args[lane_size] = {nullptr};
-    port->recv_n(args, sizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    void *args[num_lanes] = {nullptr};
+    port.recv_n(args, sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.send([&](rpc::Buffer *buffer, uint32_t id) {
       buffer->data[0] = static_cast<uint64_t>(
           remove(reinterpret_cast<const char *>(args[id])));
     });
     break;
   }
   case RPC_RENAME: {
-    uint64_t oldsizes[lane_size] = {0};
-    uint64_t newsizes[lane_size] = {0};
-    void *oldpath[lane_size] = {nullptr};
-    void *newpath[lane_size] = {nullptr};
-    port->recv_n(oldpath, oldsizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->recv_n(newpath, newsizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t oldsizes[num_lanes] = {0};
+    uint64_t newsizes[num_lanes] = {0};
+    void *oldpath[num_lanes] = {nullptr};
+    void *newpath[num_lanes] = {nullptr};
+    port.recv_n(oldpath, oldsizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.recv_n(newpath, newsizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.send([&](rpc::Buffer *buffer, uint32_t id) {
       buffer->data[0] = static_cast<uint64_t>(
           rename(reinterpret_cast<const char *>(oldpath[id]),
                  reinterpret_cast<const char *>(newpath[id])));
@@ -438,168 +424,36 @@ rpc_status_t handle_server_impl(
     break;
   }
   case RPC_SYSTEM: {
-    uint64_t sizes[lane_size] = {0};
-    void *args[lane_size] = {nullptr};
-    port->recv_n(args, sizes,
-                 [&](uint64_t size) { return temp_storage.alloc(size); });
-    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+    uint64_t sizes[num_lanes] = {0};
+    void *args[num_lanes] = {nullptr};
+    port.recv_n(args, sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
+    port.send([&](rpc::Buffer *buffer, uint32_t id) {
       buffer->data[0] = static_cast<uint64_t>(
           system(reinterpret_cast<const char *>(args[id])));
     });
     break;
   }
   case RPC_NOOP: {
-    port->recv([](rpc::Buffer *, uint32_t) {});
+    port.recv([](rpc::Buffer *, uint32_t) {});
     break;
   }
-  default: {
-    auto handler =
-        callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));
-
-    // We error out on an unhandled opcode.
-    if (handler == callbacks.end())
-      return RPC_STATUS_UNHANDLED_OPCODE;
-
-    // Invoke the registered callback with a reference to the port.
-    void *data =
-        callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
-    rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
-    (handler->second)(port_ref, data);
+  default:
+    return rpc::UNHANDLED_OPCODE;
   }
-  }
-
-  // Increment the index so we start the scan after this port.
-  index = port->get_index() + 1;
-  port->close();
 
-  return RPC_STATUS_CONTINUE;
+  return rpc::SUCCESS;
 }
 
-struct Device {
-  Device(uint32_t lane_size, uint32_t num_ports, void *buffer)
-      : lane_size(lane_size), buffer(buffer), server(num_ports, buffer),
-        client(num_ports, buffer) {}
-
-  rpc_status_t handle_server(uint32_t &index) {
-    switch (lane_size) {
-    case 1:
-      return handle_server_impl<1>(server, callbacks, callback_data, index);
-    case 32:
-      return handle_server_impl<32>(server, callbacks, callback_data, index);
-    case 64:
-      return handle_server_impl<64>(server, callbacks, callback_data, index);
-    default:
-      return RPC_STATUS_INVALID_LANE_SIZE;
-    }
+int libc_handle_rpc_port(void *port, uint32_t num_lanes) {
+  switch (num_lanes) {
+  case 1:
+    return handle_port_impl<1>(*reinterpret_cast<rpc::Server::Port *>(port));
+  case 32:
+    return handle_port_impl<32>(*reinterpret_cast<rpc::Server::Port *>(port));
+  case 64:
+    return handle_port_impl<64>(*reinterpret_cast<rpc::Server::Port *>(port));
+  default:
+    return rpc::ERROR;
   }
-
-  uint32_t lane_size;
-  void *buffer;
-  rpc::Server server;
-  rpc::Client client;
-  std::unordered_map<uint32_t, rpc_opcode_callback_ty> callbacks;
-  std::unordered_map<uint32_t, void *> callback_data;
-};
-
-rpc_status_t rpc_server_init(rpc_device_t *rpc_device, uint64_t num_ports,
-                             uint32_t lane_size, rpc_alloc_ty alloc,
-                             void *data) {
-  if (!rpc_device)
-    return RPC_STATUS_ERROR;
-  if (lane_size != 1 && lane_size != 32 && lane_size != 64)
-    return RPC_STATUS_INVALID_LANE_SIZE;
-
-  uint64_t size = rpc::Server::allocation_size(lane_size, num_ports);
-  void *buffer = alloc(size, data);
-
-  if (!buffer)
-    return RPC_STATUS_ERROR;
-
-  Device *device = new Device(lane_size, num_ports, buffer);
-  if (!device)
-    return RPC_STATUS_ERROR;
-
-  rpc_device->handle = reinterpret_cast<uintptr_t>(device);
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_server_shutdown(rpc_device_t rpc_device, rpc_free_ty dealloc,
-                                 void *data) {
-  if (!rpc_device.handle)
-    return RPC_STATUS_ERROR;
-
-  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
-  dealloc(device->buffer, data);
-  delete device;
-
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_handle_server(rpc_device_t rpc_device) {
-  if (!rpc_device.handle)
-    return RPC_STATUS_ERROR;
-
-  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
-  uint32_t index = 0;
-  for (;;) {
-    rpc_status_t status = device->handle_server(index);
-    if (status != RPC_STATUS_CONTINUE)
-      return status;
-  }
-}
-
-rpc_status_t rpc_register_callback(rpc_device_t rpc_device, uint32_t opcode,
-                                   rpc_opcode_callback_ty callback,
-                                   void *data) {
-  if (!rpc_device.handle)
-    return RPC_STATUS_ERROR;
-
-  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
-
-  device->callbacks[opcode] = callback;
-  device->callback_data[opcode] = data;
-  return RPC_STATUS_SUCCESS;
-}
-
-const void *rpc_get_client_buffer(rpc_device_t rpc_device) {
-  if (!rpc_device.handle)
-    return nullptr;
-  Device *device = reinterpret_cast<Device *>(rpc_device.handle);
-  return &device->client;
-}
-
-uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
-
-void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
-  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
-  port->send([=](rpc::Buffer *buffer, uint32_t) {
-    callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-  });
-}
-
-void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) {
-  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
-  port->send_n(src, size);
-}
-
-void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
-  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
-  port->recv([=](rpc::Buffer *buffer, uint32_t) {
-    callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-  });
-}
-
-void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc,
-                void *data) {
-  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
-  auto alloc_fn = [=](uint64_t size) { return alloc(size, data); };
-  port->recv_n(dst, size, alloc_fn);
-}
-
-void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
-                       void *data) {
-  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
-  port->recv_and_send([=](rpc::Buffer *buffer, uint32_t) {
-    callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-  });
 }
diff --git a/offload/plugins-nextgen/common/CMakeLists.txt b/offload/plugins-nextgen/common/CMakeLists.txt
index fde4b2f930349e..3ed5c02ed4a3bb 100644
--- a/offload/plugins-nextgen/common/CMakeLists.txt
+++ b/offload/plugins-nextgen/common/CMakeLists.txt
@@ -34,6 +34,7 @@ elseif(${LIBOMPTARGET_GPU_LIBC_SUPPORT})
     # We may need to get the headers directly from the 'libc' source directory.
     target_include_directories(PluginCommon PRIVATE
                                ${CMAKE_SOURCE_DIR}/../libc/utils/gpu/server
+                               ${CMAKE_SOURCE_DIR}/../libc/
                                ${CMAKE_SOURCE_DIR}/../libc/include)
   endif()
 endif()
diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h
index 01bf539bcb3f32..5b9b7ffd086b57 100644
--- a/offload/plugins-nextgen/common/include/RPC.h
+++ b/offload/plugins-nextgen/common/include/RPC.h
@@ -61,7 +61,7 @@ struct RPCServerTy {
 
 private:
-  /// Array from this device's identifier to its attached devices.
+  /// Host buffer backing each device's RPC channel, indexed by device ID.
-  llvm::SmallVector<uintptr_t> Handles;
+  llvm::SmallVector<void *> Buffers;
 };
 
 } // namespace llvm::omp::target
diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp
index faa2cbd4f02fe1..48ed612cf853b6 100644
--- a/offload/plugins-nextgen/common/src/RPC.cpp
+++ b/offload/plugins-nextgen/common/src/RPC.cpp
@@ -12,9 +12,11 @@
 
 #include "PluginInterface.h"
 
+// TODO: This should be included unconditionally and cleaned up.
 #if defined(LIBOMPTARGET_RPC_SUPPORT)
-#include "llvm-libc-types/rpc_opcodes_t.h"
+#include "include/llvm-libc-types/rpc_opcodes_t.h"
 #include "llvmlibc_rpc_server.h"
+#include "shared/rpc.h"
 #endif
 
 using namespace llvm;
@@ -22,14 +24,14 @@ using namespace omp;
 using namespace target;
 
 RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
-    : Handles(Plugin.getNumDevices()) {}
+    : Buffers(Plugin.getNumDevices(), nullptr) {}
 
 llvm::Expected<bool>
 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
                               plugin::GenericGlobalHandlerTy &Handler,
                               plugin::DeviceImageTy &Image) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
-  return Handler.isSymbolInImage(Device, Image, rpc_client_symbol_name);
+  return Handler.isSymbolInImage(Device, Image, "__llvm_libc_rpc_client");
 #else
   return false;
 #endif
@@ -39,59 +41,18 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                               plugin::GenericGlobalHandlerTy &Handler,
                               plugin::DeviceImageTy &Image) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
-  auto Alloc = [](uint64_t Size, void *Data) {
-    plugin::GenericDeviceTy &Device =
-        *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
-    return Device.allocate(Size, nullptr, TARGET_ALLOC_HOST);
-  };
   uint64_t NumPorts =
-      std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT);
-  rpc_device_t RPCDevice;
-  if (rpc_status_t Err = rpc_server_init(&RPCDevice, NumPorts,
-                                         Device.getWarpSize(), Alloc, &Device))
+      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
+  void *RPCBuffer = Device.allocate(
+      rpc::Server::allocation_size(Device.getWarpSize(), NumPorts), nullptr,
+      TARGET_ALLOC_HOST);
+  if (!RPCBuffer)
     return plugin::Plugin::error(
-        "Failed to initialize RPC server for device %d: %d",
-        Device.getDeviceId(), Err);
-
-  // Register a custom opcode handler to perform plugin specific allocation.
-  auto MallocHandler = [](rpc_port_t Port, void *Data) {
-    rpc_recv_and_send(
-        Port,
-        [](rpc_buffer_t *Buffer, void *Data) {
-          plugin::GenericDeviceTy &Device =
-              *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
-          Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
-              Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
-        },
-        Data);
-  };
-  if (rpc_status_t Err =
-          rpc_register_callback(RPCDevice, RPC_MALLOC, MallocHandler, &Device))
-    return plugin::Plugin::error(
-        "Failed to register RPC malloc handler for device %d: %d\n",
-        Device.getDeviceId(), Err);
-
-  // Register a custom opcode handler to perform plugin specific deallocation.
-  auto FreeHandler = [](rpc_port_t Port, void *Data) {
-    rpc_recv(
-        Port,
-        [](rpc_buffer_t *Buffer, void *Data) {
-          plugin::GenericDeviceTy &Device =
-              *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
-          Device.free(reinterpret_cast<void *>(Buffer->data[0]),
-                      TARGET_ALLOC_DEVICE_NON_BLOCKING);
-        },
-        Data);
-  };
-  if (rpc_status_t Err =
-          rpc_register_callback(RPCDevice, RPC_FREE, FreeHandler, &Device))
-    return plugin::Plugin::error(
-        "Failed to register RPC free handler for device %d: %d\n",
-        Device.getDeviceId(), Err);
+        "Failed to initialize RPC server for device %d", Device.getDeviceId());
 
   // Get the address of the RPC client from the device.
   void *ClientPtr;
-  plugin::GlobalTy ClientGlobal(rpc_client_symbol_name, sizeof(void *));
+  plugin::GlobalTy ClientGlobal("__llvm_libc_rpc_client", sizeof(void *));
   if (auto Err =
           Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
     return Err;
@@ -100,38 +61,63 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                                      sizeof(void *), nullptr))
     return Err;
 
-  const void *ClientBuffer = rpc_get_client_buffer(RPCDevice);
-  if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer,
-                                   rpc_get_client_size(), nullptr))
+  rpc::Client client(NumPorts, RPCBuffer);
+  if (auto Err =
+          Device.dataSubmit(ClientPtr, &client, sizeof(rpc::Client), nullptr))
     return Err;
-  Handles[Device.getDeviceId()] = RPCDevice.handle;
+  Buffers[Device.getDeviceId()] = RPCBuffer;
+
+  return Error::success();
+
 #endif
   return Error::success();
 }
 
 Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
-  rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
-  if (rpc_status_t Err = rpc_handle_server(RPCDevice))
-    return plugin::Plugin::error(
-        "Error while running RPC server on device %d: %d", Device.getDeviceId(),
-        Err);
+  uint64_t NumPorts =
+      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
+  rpc::Server Server(NumPorts, Buffers[Device.getDeviceId()]);
+
+  auto Port = Server.try_open(Device.getWarpSize());
+  if (!Port)
+    return Error::success();
+
+  // Plugin-specific opcodes are handled here; everything else is forwarded.
+  int Status = rpc::SUCCESS;
+  switch (Port->get_opcode()) {
+  case RPC_MALLOC:
+    Port->recv_and_send([&](rpc::Buffer *Buffer, uint32_t) {
+      Buffer->data[0] = reinterpret_cast<uintptr_t>(Device.allocate(
+          Buffer->data[0], nullptr, TARGET_ALLOC_DEVICE_NON_BLOCKING));
+    });
+    break;
+  case RPC_FREE:
+    Port->recv([&](rpc::Buffer *Buffer, uint32_t) {
+      Device.free(reinterpret_cast<void *>(Buffer->data[0]),
+                  TARGET_ALLOC_DEVICE_NON_BLOCKING);
+    });
+    break;
+  default:
+    // Let the `libc` library handle any opcodes the plugin does not own.
+    Status = libc_handle_rpc_port(&*Port, Device.getWarpSize());
+    break;
+  }
+
+  // Close the port before reporting errors so it is not left open.
+  Port->close();
+  if (Status != rpc::SUCCESS)
+    return createStringError("RPC server given invalid opcode!");
+
+  return Error::success();
 #endif
   return Error::success();
 }
 
 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
 #ifdef LIBOMPTARGET_RPC_SUPPORT
-  rpc_device_t RPCDevice{Handles[Device.getDeviceId()]};
-  auto Dealloc = [](void *Ptr, void *Data) {
-    plugin::GenericDeviceTy &Device =
-        *reinterpret_cast<plugin::GenericDeviceTy *>(Data);
-    Device.free(Ptr, TARGET_ALLOC_HOST);
-  };
-  if (rpc_status_t Err = rpc_server_shutdown(RPCDevice, Dealloc, &Device))
-    return plugin::Plugin::error(
-        "Failed to shut down RPC server for device %d: %d",
-        Device.getDeviceId(), Err);
+  // Release the host buffer that initDevice allocated for this device.
+  Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
 #endif
   return Error::success();
 }



More information about the libc-commits mailing list