[libc-commits] [libc] 719d77e - [libc] Begin implementing a library for the RPC server
    Joseph Huber via libc-commits 
    libc-commits at lists.llvm.org
       
    Thu Jun 15 09:02:31 PDT 2023
    
    
  
Author: Joseph Huber
Date: 2023-06-15T11:02:23-05:00
New Revision: 719d77ed28b69101aaa03cf82d64ede3b29cafcd
URL: https://github.com/llvm/llvm-project/commit/719d77ed28b69101aaa03cf82d64ede3b29cafcd
DIFF: https://github.com/llvm/llvm-project/commit/719d77ed28b69101aaa03cf82d64ede3b29cafcd.diff
LOG: [libc] Begin implementing a library for the RPC server
This patch begins providing a generic static library that wraps around
the raw `rpc.h` interface. As discussed in the corresponding RFC,
https://discourse.llvm.org/t/rfc-libc-exporting-the-rpc-interface-for-the-gpu-libc/71030,
we want to begin exporting RPC services to external users. To do this, we
decided not to expose the `rpc.h` header directly and instead to wrap its
functionality. The wrapper is a C interface because we make heavy use of
callbacks, and because it lets us provide a predictable interface.
Reviewed By: JonChesterfield, sivachandra
Differential Revision: https://reviews.llvm.org/D147054
Added: 
    libc/utils/gpu/server/CMakeLists.txt
    libc/utils/gpu/server/Server.cpp
    libc/utils/gpu/server/Server.h
Modified: 
    libc/src/__support/RPC/rpc.h
    libc/utils/gpu/CMakeLists.txt
    libc/utils/gpu/loader/Loader.h
    libc/utils/gpu/loader/amdgpu/CMakeLists.txt
    libc/utils/gpu/loader/amdgpu/Loader.cpp
    libc/utils/gpu/loader/nvptx/CMakeLists.txt
    libc/utils/gpu/loader/nvptx/Loader.cpp
Removed: 
    libc/utils/gpu/loader/Server.h
################################################################################
diff  --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h
index 8be4e5dfb8ac5..3c4613b36a973 100644
--- a/libc/src/__support/RPC/rpc.h
+++ b/libc/src/__support/RPC/rpc.h
@@ -126,6 +126,10 @@ template <bool Invert> struct Process {
         reinterpret_cast<Packet *>(advance(buffer, buffer_offset(port_count)));
   }
 
+  /// Returns the beginning of the unified buffer. Intended for initializing the
+  /// client after the server has been started.
+  LIBC_INLINE void *get_buffer_start() const { return Invert ? outbox : inbox; }
+
   /// Allocate a memory buffer sufficient to store the following equivalent
   /// representation in memory.
   ///
diff  --git a/libc/utils/gpu/CMakeLists.txt b/libc/utils/gpu/CMakeLists.txt
index e529646a1206e..7c15f36052cf3 100644
--- a/libc/utils/gpu/CMakeLists.txt
+++ b/libc/utils/gpu/CMakeLists.txt
@@ -1 +1,2 @@
+add_subdirectory(server)
 add_subdirectory(loader)
diff  --git a/libc/utils/gpu/loader/Loader.h b/libc/utils/gpu/loader/Loader.h
index fcff0ec1516e2..dd043b46b3c90 100644
--- a/libc/utils/gpu/loader/Loader.h
+++ b/libc/utils/gpu/loader/Loader.h
@@ -9,9 +9,12 @@
 #ifndef LLVM_LIBC_UTILS_GPU_LOADER_LOADER_H
 #define LLVM_LIBC_UTILS_GPU_LOADER_LOADER_H
 
+#include "utils/gpu/server/Server.h"
+#include <cstddef>
 #include <cstdint>
+#include <cstdio>
+#include <cstdlib>
 #include <cstring>
-#include <stddef.h>
 
 /// Generic launch parameters for configuring the number of blocks / threads.
 struct LaunchParameters {
@@ -92,4 +95,13 @@ void *copy_environment(char **envp, Allocator alloc) {
   return copy_argument_vector(envc, envp, alloc);
 };
 
+inline void handle_error(const char *msg) {
+  fprintf(stderr, "%s\n", msg);
+  exit(EXIT_FAILURE);
+}
+
+inline void handle_error(rpc_status_t) {
+  handle_error("Failure in the RPC server\n");
+}
+
 #endif
diff  --git a/libc/utils/gpu/loader/Server.h b/libc/utils/gpu/loader/Server.h
deleted file mode 100644
index 9286e55604c39..0000000000000
--- a/libc/utils/gpu/loader/Server.h
+++ /dev/null
@@ -1,123 +0,0 @@
-//===-- Generic RPC server interface --------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_LIBC_UTILS_GPU_LOADER_RPC_H
-#define LLVM_LIBC_UTILS_GPU_LOADER_RPC_H
-
-#include <cstdint>
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
-#include <stddef.h>
-
-#include "src/__support/RPC/rpc.h"
-
-static __llvm_libc::rpc::Server server;
-
-/// Queries the RPC client at least once and performs server-side work if there
-/// are any active requests.
-template <typename Alloc, typename Dealloc>
-void handle_server(Alloc allocator, Dealloc deallocator) {
-  using namespace __llvm_libc;
-
-  // Continue servicing the client until there is no work left and we return.
-  for (;;) {
-    auto port = server.try_open();
-    if (!port)
-      return;
-
-    switch (port->get_opcode()) {
-    case rpc::Opcode::WRITE_TO_STREAM:
-    case rpc::Opcode::WRITE_TO_STDERR:
-    case rpc::Opcode::WRITE_TO_STDOUT: {
-      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
-      void *strs[rpc::MAX_LANE_SIZE] = {nullptr};
-      FILE *files[rpc::MAX_LANE_SIZE] = {nullptr};
-      if (port->get_opcode() == rpc::Opcode::WRITE_TO_STREAM)
-        port->recv([&](rpc::Buffer *buffer, uint32_t id) {
-          files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
-        });
-      port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
-      port->send([&](rpc::Buffer *buffer, uint32_t id) {
-        FILE *file = port->get_opcode() == rpc::Opcode::WRITE_TO_STDOUT
-                         ? stdout
-                         : (port->get_opcode() == rpc::Opcode::WRITE_TO_STDERR
-                                ? stderr
-                                : files[id]);
-        int ret = fwrite(strs[id], sizes[id], 1, file);
-        reinterpret_cast<int *>(buffer->data)[0] = ret >= 0 ? sizes[id] : ret;
-      });
-      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
-        if (strs[i])
-          delete[] reinterpret_cast<uint8_t *>(strs[i]);
-      }
-      break;
-    }
-    case rpc::Opcode::EXIT: {
-      port->recv([](rpc::Buffer *buffer) {
-        exit(reinterpret_cast<uint32_t *>(buffer->data)[0]);
-      });
-      break;
-    }
-    case rpc::Opcode::MALLOC: {
-      port->recv_and_send([&](rpc::Buffer *buffer) {
-        buffer->data[0] =
-            reinterpret_cast<uintptr_t>(allocator(buffer->data[0]));
-      });
-      break;
-    }
-    case rpc::Opcode::FREE: {
-      port->recv([&](rpc::Buffer *buffer) {
-        deallocator(reinterpret_cast<void *>(buffer->data[0]));
-      });
-      break;
-    }
-    case rpc::Opcode::TEST_INCREMENT: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
-      });
-      break;
-    }
-    case rpc::Opcode::TEST_INTERFACE: {
-      uint64_t cnt = 0;
-      bool end_with_recv;
-      port->recv([&](rpc::Buffer *buffer) { end_with_recv = buffer->data[0]; });
-      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
-      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
-      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
-      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
-      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
-      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
-      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
-      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
-      if (end_with_recv)
-        port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
-      else
-        port->send(
-            [&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
-      break;
-    }
-    case rpc::Opcode::TEST_STREAM: {
-      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
-      void *dst[rpc::MAX_LANE_SIZE] = {nullptr};
-      port->recv_n(dst, sizes, [](uint64_t size) { return new char[size]; });
-      port->send_n(dst, sizes);
-      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
-        if (dst[i])
-          delete[] reinterpret_cast<uint8_t *>(dst[i]);
-      }
-      break;
-    }
-    default:
-      port->recv([](rpc::Buffer *buffer) {});
-    }
-    port->close();
-  }
-}
-
-#endif
diff  --git a/libc/utils/gpu/loader/amdgpu/CMakeLists.txt b/libc/utils/gpu/loader/amdgpu/CMakeLists.txt
index 83c61002c8ae7..9cd5fe4ce2c1b 100644
--- a/libc/utils/gpu/loader/amdgpu/CMakeLists.txt
+++ b/libc/utils/gpu/loader/amdgpu/CMakeLists.txt
@@ -5,4 +5,5 @@ target_link_libraries(amdhsa_loader
   PRIVATE
   hsa-runtime64::hsa-runtime64
   gpu_loader
+  rpc_server
 )
diff  --git a/libc/utils/gpu/loader/amdgpu/Loader.cpp b/libc/utils/gpu/loader/amdgpu/Loader.cpp
index 3d1f95a05b3c1..ca6f6ab315a2d 100644
--- a/libc/utils/gpu/loader/amdgpu/Loader.cpp
+++ b/libc/utils/gpu/loader/amdgpu/Loader.cpp
@@ -14,7 +14,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "Loader.h"
-#include "Server.h"
 
 #include <hsa/hsa.h>
 #include <hsa/hsa_ext_amd.h>
@@ -22,6 +21,7 @@
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
+#include <tuple>
 #include <utility>
 
 /// Print the error code and exit if \p code indicates an error.
@@ -36,11 +36,6 @@ static void handle_error(hsa_status_t code) {
   exit(EXIT_FAILURE);
 }
 
-static void handle_error(const char *msg) {
-  fprintf(stderr, "%s\n", msg);
-  exit(EXIT_FAILURE);
-}
-
 /// Generic interface for iterating using the HSA callbacks.
 template <typename elem_ty, typename func_ty, typename callback_ty>
 hsa_status_t iterate(func_ty func, callback_ty cb) {
@@ -143,20 +138,37 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
           executable, kernel_name, &dev_agent, &symbol))
     return err;
 
-  auto allocator = [&](uint64_t size) -> void * {
-    void *dev_ptr = nullptr;
-    if (hsa_status_t err =
-            hsa_amd_memory_pool_allocate(coarsegrained_pool, size,
-                                         /*flags=*/0, &dev_ptr))
-      handle_error(err);
-    hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
-    return dev_ptr;
-  };
-
-  auto deallocator = [](void *ptr) -> void {
-    if (hsa_status_t err = hsa_amd_memory_pool_free(ptr))
-      handle_error(err);
-  };
+  // Register RPC callbacks for the malloc and free functions on HSA.
+  uint32_t device_id = 0;
+  auto tuple = std::make_tuple(dev_agent, coarsegrained_pool);
+  rpc_register_callback(
+      device_id, RPC_MALLOC,
+      [](rpc_port_t port, void *data) {
+        auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
+          auto &[dev_agent, pool] = *static_cast<decltype(tuple) *>(data);
+          uint64_t size = buffer->data[0];
+          void *dev_ptr = nullptr;
+          if (hsa_status_t err =
+                  hsa_amd_memory_pool_allocate(pool, size,
+                                               /*flags=*/0, &dev_ptr))
+            handle_error(err);
+          hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
+          buffer->data[0] = reinterpret_cast<uintptr_t>(dev_ptr);
+        };
+        rpc_recv_and_send(port, malloc_handler, data);
+      },
+      &tuple);
+  rpc_register_callback(
+      device_id, RPC_FREE,
+      [](rpc_port_t port, void *data) {
+        auto free_handler = [](rpc_buffer_t *buffer, void *) {
+          if (hsa_status_t err = hsa_amd_memory_pool_free(
+                  reinterpret_cast<void *>(buffer->data[0])))
+            handle_error(err);
+        };
+        rpc_recv_and_send(port, free_handler, data);
+      },
+      nullptr);
 
  // Retrieve different properties of the kernel symbol used for launch.
   uint64_t kernel;
@@ -235,11 +247,13 @@ hsa_status_t launch_kernel(hsa_agent_t dev_agent, hsa_executable_t executable,
   while (hsa_signal_wait_scacquire(
              packet->completion_signal, HSA_SIGNAL_CONDITION_EQ, 0,
              /*timeout_hint=*/1024, HSA_WAIT_STATE_ACTIVE) != 0)
-    handle_server(allocator, deallocator);
+    if (rpc_status_t err = rpc_handle_server(device_id))
+      handle_error(err);
 
   // Handle the server one more time in case the kernel exited with a pending
   // send still in flight.
-  handle_server(allocator, deallocator);
+  if (rpc_status_t err = rpc_handle_server(device_id))
+    handle_error(err);
 
   // Destroy the resources acquired to launch the kernel and return.
   if (hsa_status_t err = hsa_amd_memory_pool_free(args))
@@ -266,7 +280,9 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
           nullptr))
     handle_error(err);
 
-  // Obtain an agent for the device and host to use the HSA memory model.
+  // Obtain a single agent for the device and host to use the HSA memory model.
+  uint32_t num_devices = 1;
+  uint32_t device_id = 0;
   hsa_agent_t dev_agent;
   hsa_agent_t host_agent;
   if (hsa_status_t err = get_agent<HSA_DEVICE_TYPE_GPU>(&dev_agent))
@@ -350,23 +366,27 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   hsa_amd_memory_fill(dev_ret, 0, sizeof(int));
 
   // Allocate finegrained memory for the RPC server and client to share.
-  uint64_t port_size = __llvm_libc::rpc::DEFAULT_PORT_COUNT;
   uint32_t wavefront_size = 0;
   if (hsa_status_t err = hsa_agent_get_info(
           dev_agent, HSA_AGENT_INFO_WAVEFRONT_SIZE, &wavefront_size))
     handle_error(err);
 
-  uint64_t rpc_shared_buffer_size =
-      __llvm_libc::rpc::Server::allocation_size(port_size, wavefront_size);
-  void *rpc_shared_buffer;
-  if (hsa_status_t err =
-          hsa_amd_memory_pool_allocate(finegrained_pool, rpc_shared_buffer_size,
-                                       /*flags=*/0, &rpc_shared_buffer))
+  // Set up the RPC server.
+  if (rpc_status_t err = rpc_init(num_devices))
+    handle_error(err);
+  auto tuple = std::make_tuple(dev_agent, finegrained_pool);
+  auto rpc_alloc = [](uint64_t size, void *data) {
+    auto &[dev_agent, finegrained_pool] = *static_cast<decltype(tuple) *>(data);
+    void *dev_ptr = nullptr;
+    if (hsa_status_t err = hsa_amd_memory_pool_allocate(finegrained_pool, size,
+                                                        /*flags=*/0, &dev_ptr))
+      handle_error(err);
+    hsa_amd_agents_allow_access(1, &dev_agent, nullptr, dev_ptr);
+    return dev_ptr;
+  };
+  if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
+                                         wavefront_size, rpc_alloc, &tuple))
     handle_error(err);
-  hsa_amd_agents_allow_access(1, &dev_agent, nullptr, rpc_shared_buffer);
-
-  // Initialize the RPC server's buffer for host-device communication.
-  server.reset(port_size, wavefront_size, rpc_shared_buffer);
 
   // Obtain a queue with the minimum (power of two) size, used to send commands
   // to the HSA runtime and launch execution on the device.
@@ -381,7 +401,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
 
   LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
-  begin_args_t init_args = {argc, dev_argv, dev_envp, rpc_shared_buffer};
+  begin_args_t init_args = {argc, dev_argv, dev_envp,
+                            rpc_get_buffer(device_id)};
   if (hsa_status_t err = launch_kernel(
           dev_agent, executable, kernargs_pool, coarsegrained_pool, queue,
           single_threaded_params, "_begin.kd", init_args))
@@ -424,13 +445,16 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
           single_threaded_params, "_end.kd", fini_args))
     handle_error(err);
 
+  if (rpc_status_t err = rpc_server_shutdown(
+          device_id, [](void *ptr, void *) { hsa_amd_memory_pool_free(ptr); },
+          nullptr))
+    handle_error(err);
+
   // Free the memory allocated for the device.
   if (hsa_status_t err = hsa_amd_memory_pool_free(dev_argv))
     handle_error(err);
   if (hsa_status_t err = hsa_amd_memory_pool_free(dev_ret))
     handle_error(err);
-  if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_shared_buffer))
-    handle_error(err);
   if (hsa_status_t err = hsa_amd_memory_pool_free(host_ret))
     handle_error(err);
 
@@ -445,6 +469,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   if (hsa_status_t err = hsa_code_object_destroy(object))
     handle_error(err);
 
+  if (rpc_status_t err = rpc_shutdown())
+    handle_error(err);
   if (hsa_status_t err = hsa_shut_down())
     handle_error(err);
 
diff  --git a/libc/utils/gpu/loader/nvptx/CMakeLists.txt b/libc/utils/gpu/loader/nvptx/CMakeLists.txt
index 9e85357920678..c14d60d3fd1fe 100644
--- a/libc/utils/gpu/loader/nvptx/CMakeLists.txt
+++ b/libc/utils/gpu/loader/nvptx/CMakeLists.txt
@@ -8,6 +8,7 @@ target_include_directories(nvptx_loader PRIVATE ${LLVM_INCLUDE_DIRS})
 target_link_libraries(nvptx_loader
   PRIVATE
   gpu_loader
+  rpc_server
   CUDA::cuda_driver
   LLVMObject
   LLVMSupport
diff  --git a/libc/utils/gpu/loader/nvptx/Loader.cpp b/libc/utils/gpu/loader/nvptx/Loader.cpp
index af1b5e69a88b3..7526381d0622f 100644
--- a/libc/utils/gpu/loader/nvptx/Loader.cpp
+++ b/libc/utils/gpu/loader/nvptx/Loader.cpp
@@ -14,7 +14,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "Loader.h"
-#include "Server.h"
 
 #include "cuda.h"
 
@@ -43,11 +42,6 @@ static void handle_error(CUresult err) {
   exit(1);
 }
 
-static void handle_error(const char *msg) {
-  fprintf(stderr, "%s\n", msg);
-  exit(EXIT_FAILURE);
-}
-
 // Gets the names of all the globals that contain functions to initialize or
 // deinitialize. We need to do this manually because the NVPTX toolchain does
 // not contain the necessary binary manipulation tools.
@@ -181,21 +175,37 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
   if (CUresult err = cuStreamCreate(&memory_stream, CU_STREAM_NON_BLOCKING))
     handle_error(err);
 
-  auto allocator = [&](uint64_t size) -> void * {
-    CUdeviceptr dev_ptr;
-    if (CUresult err = cuMemAllocAsync(&dev_ptr, size, memory_stream))
-      handle_error(err);
-
-    // Wait until the memory allocation is complete.
-    while (cuStreamQuery(memory_stream) == CUDA_ERROR_NOT_READY)
-      ;
-    return reinterpret_cast<void *>(dev_ptr);
-  };
-  auto deallocator = [&](void *ptr) -> void {
-    if (CUresult err =
-            cuMemFreeAsync(reinterpret_cast<CUdeviceptr>(ptr), memory_stream))
-      handle_error(err);
-  };
+  // Register RPC callbacks for the malloc and free functions on CUDA.
+  uint32_t device_id = 0;
+  rpc_register_callback(
+      device_id, RPC_MALLOC,
+      [](rpc_port_t port, void *data) {
+        auto malloc_handler = [](rpc_buffer_t *buffer, void *data) -> void {
+          CUstream memory_stream = *static_cast<CUstream *>(data);
+          uint64_t size = buffer->data[0];
+          CUdeviceptr dev_ptr;
+          if (CUresult err = cuMemAllocAsync(&dev_ptr, size, memory_stream))
+            handle_error(err);
+
+          // Wait until the memory allocation is complete.
+          while (cuStreamQuery(memory_stream) == CUDA_ERROR_NOT_READY)
+            ;
+        };
+        rpc_recv_and_send(port, malloc_handler, data);
+      },
+      &memory_stream);
+  rpc_register_callback(
+      device_id, RPC_FREE,
+      [](rpc_port_t port, void *data) {
+        auto free_handler = [](rpc_buffer_t *buffer, void *data) {
+          CUstream memory_stream = *static_cast<CUstream *>(data);
+          if (CUresult err = cuMemFreeAsync(
+                  static_cast<CUdeviceptr>(buffer->data[0]), memory_stream))
+            handle_error(err);
+        };
+        rpc_recv_and_send(port, free_handler, data);
+      },
+      &memory_stream);
 
   // Call the kernel with the given arguments.
   if (CUresult err = cuLaunchKernel(
@@ -207,23 +217,26 @@ CUresult launch_kernel(CUmodule binary, CUstream stream,
   // Wait until the kernel has completed execution on the device. Periodically
   // check the RPC client for work to be performed on the server.
   while (cuStreamQuery(stream) == CUDA_ERROR_NOT_READY)
-    handle_server(allocator, deallocator);
+    if (rpc_status_t err = rpc_handle_server(device_id))
+      handle_error(err);
 
   // Handle the server one more time in case the kernel exited with a pending
   // send still in flight.
-  handle_server(allocator, deallocator);
+  if (rpc_status_t err = rpc_handle_server(device_id))
+    handle_error(err);
 
   return CUDA_SUCCESS;
 }
 
 int load(int argc, char **argv, char **envp, void *image, size_t size,
          const LaunchParameters &params) {
-
   if (CUresult err = cuInit(0))
     handle_error(err);
   // Obtain the first device found on the system.
+  uint32_t num_devices = 1;
+  uint32_t device_id = 0;
   CUdevice device;
-  if (CUresult err = cuDeviceGet(&device, 0))
+  if (CUresult err = cuDeviceGet(&device, device_id))
     handle_error(err);
 
   // Initialize the CUDA context and claim it for this execution.
@@ -279,22 +292,24 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
   if (CUresult err = cuMemsetD32(dev_ret, 0, 1))
     handle_error(err);
 
-  uint64_t port_size = __llvm_libc::rpc::DEFAULT_PORT_COUNT;
-  uint32_t warp_size = 32;
-
-  uint64_t rpc_shared_buffer_size =
-      __llvm_libc::rpc::Server::allocation_size(port_size, warp_size);
-  void *rpc_shared_buffer = allocator(rpc_shared_buffer_size);
-
-  if (!rpc_shared_buffer)
-    handle_error("Failed to allocate memory the RPC client / server.");
+  if (rpc_status_t err = rpc_init(num_devices))
+    handle_error(err);
 
-  // Initialize the RPC server's buffer for host-device communication.
-  server.reset(port_size, warp_size, rpc_shared_buffer);
+  uint32_t warp_size = 32;
+  auto rpc_alloc = [](uint64_t size, void *) -> void * {
+    void *dev_ptr;
+    if (CUresult err = cuMemAllocHost(&dev_ptr, size))
+      handle_error(err);
+    return dev_ptr;
+  };
+  if (rpc_status_t err = rpc_server_init(device_id, RPC_MAXIMUM_PORT_COUNT,
+                                         warp_size, rpc_alloc, nullptr))
+    handle_error(err);
 
   LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};
   // Call the kernel to
-  begin_args_t init_args = {argc, dev_argv, dev_envp, rpc_shared_buffer};
+  begin_args_t init_args = {argc, dev_argv, dev_envp,
+                            rpc_get_buffer(device_id)};
   if (CUresult err = launch_kernel(binary, stream, single_threaded_params,
                                    "_begin", init_args))
     handle_error(err);
@@ -324,7 +339,8 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
   if (CUresult err = cuMemFreeHost(dev_argv))
     handle_error(err);
-  if (CUresult err = cuMemFreeHost(rpc_shared_buffer))
+  if (rpc_status_t err = rpc_server_shutdown(
+          device_id, [](void *ptr, void *) { cuMemFreeHost(ptr); }, nullptr))
     handle_error(err);
 
   // Destroy the context and the loaded binary.
@@ -332,5 +348,7 @@ int load(int argc, char **argv, char **envp, void *image, size_t size,
     handle_error(err);
   if (CUresult err = cuDevicePrimaryCtxRelease(device))
     handle_error(err);
+  if (rpc_status_t err = rpc_shutdown())
+    handle_error(err);
   return host_ret;
 }
diff  --git a/libc/utils/gpu/server/CMakeLists.txt b/libc/utils/gpu/server/CMakeLists.txt
new file mode 100644
index 0000000000000..fa006b36ea16c
--- /dev/null
+++ b/libc/utils/gpu/server/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_library(rpc_server STATIC Server.cpp)
+
+# Include the RPC implementation from libc.
+add_dependencies(rpc_server libc.src.__support.RPC.rpc)
+target_include_directories(rpc_server PRIVATE ${LIBC_SOURCE_DIR})
+target_include_directories(rpc_server PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
diff  --git a/libc/utils/gpu/server/Server.cpp b/libc/utils/gpu/server/Server.cpp
new file mode 100644
index 0000000000000..f3ce2fd9b2851
--- /dev/null
+++ b/libc/utils/gpu/server/Server.cpp
@@ -0,0 +1,219 @@
+//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Server.h"
+
+#include "src/__support/RPC/rpc.h"
+#include <atomic>
+#include <cstdio>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+
+using namespace __llvm_libc;
+
+static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
+              "Buffer size mismatch");
+
+static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT,
+              "Incorrect maximum port count");
+struct Device {
+  rpc::Server server;
+  std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
+  std::unordered_map<rpc_opcode_t, void *> callback_data;
+};
+
+// A struct containing all the runtime state required to run the RPC server.
+struct State {
+  State(uint32_t num_devices)
+      : num_devices(num_devices),
+        devices(std::unique_ptr<Device[]>(new Device[num_devices])),
+        reference_count(0u) {}
+  uint32_t num_devices;
+  std::unique_ptr<Device[]> devices;
+  std::atomic_uint32_t reference_count;
+};
+
+static std::mutex startup_mutex;
+
+static State *state;
+
+rpc_status_t rpc_init(uint32_t num_devices) {
+  std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
+  if (!state)
+    state = new State(num_devices);
+
+  if (state->reference_count == std::numeric_limits<uint32_t>::max())
+    return RPC_STATUS_ERROR;
+
+  state->reference_count++;
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_shutdown(void) {
+  if (state->reference_count-- == 1)
+    delete state;
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
+                             uint32_t lane_size, rpc_alloc_ty alloc,
+                             void *data) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+
+  uint64_t buffer_size =
+      __llvm_libc::rpc::Server::allocation_size(num_ports, lane_size);
+  void *buffer = alloc(buffer_size, data);
+
+  if (!buffer)
+    return RPC_STATUS_ERROR;
+
+  state->devices[device_id].server.reset(num_ports, lane_size, buffer);
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
+                                 void *data) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+
+  dealloc(rpc_get_buffer(device_id), data);
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_handle_server(uint32_t device_id) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+
+  for (;;) {
+    auto port = state->devices[device_id].server.try_open();
+    if (!port)
+      return RPC_STATUS_SUCCESS;
+
+    switch (port->get_opcode()) {
+    case rpc::Opcode::WRITE_TO_STREAM:
+    case rpc::Opcode::WRITE_TO_STDERR:
+    case rpc::Opcode::WRITE_TO_STDOUT: {
+      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
+      void *strs[rpc::MAX_LANE_SIZE] = {nullptr};
+      FILE *files[rpc::MAX_LANE_SIZE] = {nullptr};
+      if (port->get_opcode() == rpc::Opcode::WRITE_TO_STREAM)
+        port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+          files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
+        });
+      port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
+      port->send([&](rpc::Buffer *buffer, uint32_t id) {
+        FILE *file = port->get_opcode() == rpc::Opcode::WRITE_TO_STDOUT
+                         ? stdout
+                         : (port->get_opcode() == rpc::Opcode::WRITE_TO_STDERR
+                                ? stderr
+                                : files[id]);
+        int ret = fwrite(strs[id], sizes[id], 1, file);
+        reinterpret_cast<int *>(buffer->data)[0] = ret >= 0 ? sizes[id] : ret;
+      });
+      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
+        if (strs[i])
+          delete[] reinterpret_cast<uint8_t *>(strs[i]);
+      }
+      break;
+    }
+    case rpc::Opcode::EXIT: {
+      port->recv([](rpc::Buffer *buffer) {
+        exit(reinterpret_cast<uint32_t *>(buffer->data)[0]);
+      });
+      break;
+    }
+    // TODO: Move handling of these  test cases to the loader implementation.
+    case rpc::Opcode::TEST_INCREMENT: {
+      port->recv_and_send([](rpc::Buffer *buffer) {
+        reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
+      });
+      break;
+    }
+    case rpc::Opcode::TEST_INTERFACE: {
+      uint64_t cnt = 0;
+      bool end_with_recv;
+      port->recv([&](rpc::Buffer *buffer) { end_with_recv = buffer->data[0]; });
+      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+      port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+      port->send([&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+      if (end_with_recv)
+        port->recv([&](rpc::Buffer *buffer) { cnt = buffer->data[0]; });
+      else
+        port->send(
+            [&](rpc::Buffer *buffer) { buffer->data[0] = cnt = cnt + 1; });
+      break;
+    }
+    case rpc::Opcode::TEST_STREAM: {
+      uint64_t sizes[rpc::MAX_LANE_SIZE] = {0};
+      void *dst[rpc::MAX_LANE_SIZE] = {nullptr};
+      port->recv_n(dst, sizes, [](uint64_t size) { return new char[size]; });
+      port->send_n(dst, sizes);
+      for (uint64_t i = 0; i < rpc::MAX_LANE_SIZE; ++i) {
+        if (dst[i])
+          delete[] reinterpret_cast<uint8_t *>(dst[i]);
+      }
+      break;
+    }
+    case rpc::Opcode::NOOP: {
+      port->recv([](rpc::Buffer *buffer) {});
+      break;
+    }
+    default: {
+      auto handler = state->devices[device_id].callbacks.find(
+          static_cast<rpc_opcode_t>(port->get_opcode()));
+
+      // We error out on an unhandled opcode.
+      if (handler == state->devices[device_id].callbacks.end())
+        return RPC_STATUS_UNHANDLED_OPCODE;
+
+      // Invoke the registered callback with a reference to the port.
+      void *data = state->devices[device_id].callback_data.at(
+          static_cast<rpc_opcode_t>(port->get_opcode()));
+      rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port)};
+      (handler->second)(port_ref, data);
+    }
+    }
+    port->close();
+  }
+}
+
+rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
+                                   rpc_opcode_callback_ty callback,
+                                   void *data) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+
+  state->devices[device_id].callbacks[opcode] = callback;
+  state->devices[device_id].callback_data[opcode] = data;
+  return RPC_STATUS_SUCCESS;
+}
+
+void *rpc_get_buffer(uint32_t device_id) {
+  if (device_id >= state->num_devices)
+    return nullptr;
+  return state->devices[device_id].server.get_buffer_start();
+}
+
+void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
+                       void *data) {
+  rpc::Server::Port *port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
+  port->recv_and_send([=](rpc::Buffer *buffer) {
+    callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+  });
+}
diff  --git a/libc/utils/gpu/server/Server.h b/libc/utils/gpu/server/Server.h
new file mode 100644
index 0000000000000..5a15371ff77ce
--- /dev/null
+++ b/libc/utils/gpu/server/Server.h
@@ -0,0 +1,102 @@
+//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_UTILS_GPU_SERVER_RPC_SERVER_H
+#define LLVM_LIBC_UTILS_GPU_SERVER_RPC_SERVER_H
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// The maximum number of ports that can be opened for any server.
+const uint64_t RPC_MAXIMUM_PORT_COUNT = 64;
+
+// TODO: Move these to a header exported by the C library.
+typedef enum : uint16_t {
+  RPC_NOOP = 0,
+  RPC_EXIT = 1,
+  RPC_WRITE_TO_STDOUT = 2,
+  RPC_WRITE_TO_STDERR = 3,
+  RPC_WRITE_TO_STREAM = 4,
+  RPC_MALLOC = 5,
+  RPC_FREE = 6,
+} rpc_opcode_t;
+
+/// status codes.
+typedef enum {
+  RPC_STATUS_SUCCESS = 0x0,
+  RPC_STATUS_ERROR = 0x1000,
+  RPC_STATUS_OUT_OF_RANGE = 0x1001,
+  RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
+} rpc_status_t;
+
+/// A struct containing an opaque handle to an RPC port. This is what allows the
+/// server to communicate with the client.
+typedef struct rpc_port_s {
+  uint64_t handle;
+} rpc_port_t;
+
+/// A fixed-size buffer containing the payload sent from the client.
+typedef struct rpc_buffer_s {
+  uint64_t data[8];
+} rpc_buffer_t;
+
+/// A function used to allocate \p size bytes for use by the RPC server and
+/// client. The memory should support asynchronous and atomic access from both
+/// the client and server.
+typedef void *(*rpc_alloc_ty)(uint64_t size, void *data);
+
+/// A function used to free the \p ptr previously allocated.
+typedef void (*rpc_free_ty)(void *ptr, void *data);
+
+/// A callback function provided with a \p port to communicate with the RPC
+/// client. This will be called by the server to handle an opcode.
+typedef void (*rpc_opcode_callback_ty)(rpc_port_t port, void *data);
+
+/// A callback function to use the port to receive or send a \p buffer.
+typedef void (*rpc_port_callback_ty)(rpc_buffer_t *buffer, void *data);
+
+/// Initialize the rpc library for general use on \p num_devices.
+rpc_status_t rpc_init(uint32_t num_devices);
+
+/// Shut down the rpc interface.
+rpc_status_t rpc_shutdown(void);
+
+/// Initialize the server for a given device.
+rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
+                             uint32_t lane_size, rpc_alloc_ty alloc,
+                             void *data);
+
+/// Shut down the server for a given device.
+rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
+                                 void *data);
+
+/// Queries the RPC clients at least once and performs server-side work if there
+/// are any active requests. Runs until all work on the server is completed.
+rpc_status_t rpc_handle_server(uint32_t device_id);
+
+/// Register a callback to handle an opcode from the RPC client. The associated
+/// data must remain accessible as long as the user intends to handle the server
+/// with this callback.
+rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
+                                   rpc_opcode_callback_ty callback, void *data);
+
+/// Obtain a pointer to the memory buffer used to run the RPC client and server.
+void *rpc_get_buffer(uint32_t device_id);
+
+/// Use the \p port to receive and send a buffer using the \p callback.
+void rpc_recv_and_send(rpc_port_t port, rpc_port_callback_ty callback,
+                       void *data);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
        
    
    
More information about the libc-commits
mailing list