[libc-commits] [libc] 964a535 - [libc] Remove flexible array and replace with a template

Joseph Huber via libc-commits libc-commits at lists.llvm.org
Tue Jun 20 13:22:46 PDT 2023


Author: Joseph Huber
Date: 2023-06-20T15:22:37-05:00
New Revision: 964a535bfa77b491f5ad11222b419710795f0ffb

URL: https://github.com/llvm/llvm-project/commit/964a535bfa77b491f5ad11222b419710795f0ffb
DIFF: https://github.com/llvm/llvm-project/commit/964a535bfa77b491f5ad11222b419710795f0ffb.diff

LOG: [libc] Remove flexible array and replace with a template

Currently the implementation of the RPC interface requires a struct with a
flexible array member. This caused problems when compiling the RPC server with
GCC, as would be required if trying to export the RPC server interface. This
required that we either move to the `x[1]` workaround or make the array size a
template parameter. While just using `x[1]` would be much less noisy, this is
technically undefined behavior. For this reason I elected to use templates.

The downside to using templates is that the server code must now be able
to handle multiple different types at runtime. I was unable to find a
good solution that didn't rely on type erasure, so I simply branch on
the lane size value given at runtime.

Reviewed By: JonChesterfield

Differential Revision: https://reviews.llvm.org/D153304

Added: 
    

Modified: 
    libc/src/__support/RPC/rpc.h
    libc/src/__support/RPC/rpc_client.cpp
    libc/src/gpu/rpc_reset.cpp
    libc/startup/gpu/amdgpu/start.cpp
    libc/startup/gpu/nvptx/start.cpp
    libc/utils/gpu/server/Server.cpp
    libc/utils/gpu/server/Server.h

Removed: 
    


################################################################################
diff  --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h
index 0d7911403cf48..c8427a7835e22 100644
--- a/libc/src/__support/RPC/rpc.h
+++ b/libc/src/__support/RPC/rpc.h
@@ -45,20 +45,15 @@ struct Header {
 
 /// The data payload for the associated packet. We provide enough space for each
 /// thread in the cooperating lane to have a buffer.
-struct Payload {
-#if defined(LIBC_TARGET_ARCH_IS_GPU)
-  Buffer slot[gpu::LANE_SIZE];
-#else
-  // Flexible array size allocated at runtime to the appropriate size.
-  Buffer slot[];
-#endif
+template <uint32_t lane_size = gpu::LANE_SIZE> struct Payload {
+  Buffer slot[lane_size];
 };
 
 /// A packet used to share data between the client and server across an entire
 /// lane. We use a lane as the minimum granularity for execution.
-struct alignas(64) Packet {
+template <uint32_t lane_size = gpu::LANE_SIZE> struct alignas(64) Packet {
   Header header;
-  Payload payload;
+  Payload<lane_size> payload;
 };
 
 // TODO: This should be configured by the server and passed in. The general rule
@@ -79,7 +74,7 @@ constexpr uint64_t DEFAULT_PORT_COUNT = 64;
 ///   - The client will always start with a 'send' operation.
 ///   - The server will always start with a 'recv' operation.
 ///   - Every 'send' or 'recv' call is mirrored by the other process.
-template <bool Invert> struct Process {
+template <bool Invert, uint32_t lane_size> struct Process {
   LIBC_INLINE Process() = default;
   LIBC_INLINE Process(const Process &) = delete;
   LIBC_INLINE Process &operator=(const Process &) = delete;
@@ -87,29 +82,26 @@ template <bool Invert> struct Process {
   LIBC_INLINE Process &operator=(Process &&) = default;
   LIBC_INLINE ~Process() = default;
 
-  template <bool T> friend struct Port;
+  template <bool T, uint32_t S> friend struct Port;
 
 protected:
   uint64_t port_count;
-  uint32_t lane_size;
   cpp::Atomic<uint32_t> *inbox;
   cpp::Atomic<uint32_t> *outbox;
-  Packet *packet;
+  Packet<lane_size> *packet;
 
   cpp::Atomic<uint32_t> lock[DEFAULT_PORT_COUNT] = {0};
 
 public:
   /// Initialize the communication channels.
-  LIBC_INLINE void reset(uint64_t port_count, uint32_t lane_size,
-                         void *buffer) {
+  LIBC_INLINE void reset(uint64_t port_count, void *buffer) {
     this->port_count = port_count;
-    this->lane_size = lane_size;
     this->inbox = reinterpret_cast<cpp::Atomic<uint32_t> *>(
         advance(buffer, inbox_offset(port_count)));
     this->outbox = reinterpret_cast<cpp::Atomic<uint32_t> *>(
         advance(buffer, outbox_offset(port_count)));
-    this->packet =
-        reinterpret_cast<Packet *>(advance(buffer, buffer_offset(port_count)));
+    this->packet = reinterpret_cast<Packet<lane_size> *>(
+        advance(buffer, buffer_offset(port_count)));
   }
 
   /// Returns the beginning of the unified buffer. Intended for initializing the
@@ -124,20 +116,11 @@ template <bool Invert> struct Process {
   ///   Atomic<uint32_t> secondary[port_count];
   ///   Packet buffer[port_count];
   /// };
-  LIBC_INLINE static uint64_t allocation_size(uint64_t port_count,
-                                              uint32_t lane_size) {
-    return buffer_offset(port_count) + buffer_bytes(port_count, lane_size);
+  LIBC_INLINE static uint64_t allocation_size(uint64_t port_count) {
+    return buffer_offset(port_count) + buffer_bytes(port_count);
   }
 
 protected:
-  /// The length of the packet is flexible because the server needs to look up
-  /// the lane size at runtime. This helper indexes at the proper offset.
-  LIBC_INLINE Packet &get_packet(uint64_t index) {
-    return *reinterpret_cast<Packet *>(advance(
-        packet, index * align_up(sizeof(Header) + lane_size * sizeof(Buffer),
-                                 alignof(Packet))));
-  }
-
   /// Retrieve the inbox state from memory shared between processes.
   LIBC_INLINE uint32_t load_inbox(uint64_t index) {
     return inbox[index].load(cpp::MemoryOrder::RELAXED);
@@ -222,7 +205,7 @@ template <bool Invert> struct Process {
 
   /// Invokes a function accross every active buffer across the total lane size.
   LIBC_INLINE void invoke_rpc(cpp::function<void(Buffer *)> fn,
-                              Packet &packet) {
+                              Packet<lane_size> &packet) {
     if constexpr (is_process_gpu()) {
       fn(&packet.payload.slot[gpu::get_lane_id()]);
     } else {
@@ -234,7 +217,7 @@ template <bool Invert> struct Process {
 
   /// Alternate version that also provides the index of the current lane.
   LIBC_INLINE void invoke_rpc(cpp::function<void(Buffer *, uint32_t)> fn,
-                              Packet &packet) {
+                              Packet<lane_size> &packet) {
     if constexpr (is_process_gpu()) {
       fn(&packet.payload.slot[gpu::get_lane_id()], gpu::get_lane_id());
     } else {
@@ -250,13 +233,8 @@ template <bool Invert> struct Process {
   }
 
   /// Number of bytes to allocate for the buffer containing the packets.
-  LIBC_INLINE static uint64_t buffer_bytes(uint64_t port_count,
-                                           uint32_t lane_size) {
-    return is_process_gpu()
-               ? port_count * sizeof(Packet)
-               : port_count *
-                     align_up(sizeof(Header) + (lane_size * sizeof(Buffer)),
-                              alignof(Packet));
+  LIBC_INLINE static uint64_t buffer_bytes(uint64_t port_count) {
+    return port_count * sizeof(Packet<lane_size>);
   }
 
   /// Offset of the inbox in memory. This is the same as the outbox if inverted.
@@ -271,15 +249,15 @@ template <bool Invert> struct Process {
 
   /// Offset of the buffer containing the packets after the inbox and outbox.
   LIBC_INLINE static uint64_t buffer_offset(uint64_t port_count) {
-    return align_up(2 * mailbox_bytes(port_count), alignof(Packet));
+    return align_up(2 * mailbox_bytes(port_count), alignof(Packet<lane_size>));
   }
 };
 
 /// The port provides the interface to communicate between the multiple
 /// processes. A port is conceptually an index into the memory provided by the
 /// underlying process that is guarded by a lock bit.
-template <bool T> struct Port {
-  LIBC_INLINE Port(Process<T> &process, uint64_t lane_mask, uint64_t index,
+template <bool T, uint32_t S> struct Port {
+  LIBC_INLINE Port(Process<T, S> &process, uint64_t lane_mask, uint64_t index,
                    uint32_t out)
       : process(process), lane_mask(lane_mask), index(index), out(out),
         receive(false), owns_buffer(true) {}
@@ -292,8 +270,8 @@ template <bool T> struct Port {
   LIBC_INLINE Port &operator=(Port &&) = default;
 
   friend struct Client;
-  friend struct Server;
-  friend class cpp::optional<Port<T>>;
+  template <uint32_t U> friend struct Server;
+  friend class cpp::optional<Port<T, S>>;
 
 public:
   template <typename U> LIBC_INLINE void recv(U use);
@@ -307,7 +285,7 @@ template <bool T> struct Port {
   LIBC_INLINE void recv_n(void **dst, uint64_t *size, A &&alloc);
 
   LIBC_INLINE uint16_t get_opcode() const {
-    return process.get_packet(index).header.opcode;
+    return process.packet[index].header.opcode;
   }
 
   LIBC_INLINE void close() {
@@ -319,7 +297,7 @@ template <bool T> struct Port {
   }
 
 private:
-  Process<T> &process;
+  Process<T, S> &process;
   uint64_t lane_mask;
   uint64_t index;
   uint32_t out;
@@ -328,41 +306,43 @@ template <bool T> struct Port {
 };
 
 /// The RPC client used to make requests to the server.
-struct Client : public Process<false> {
+struct Client : public Process<false, gpu::LANE_SIZE> {
   LIBC_INLINE Client() = default;
   LIBC_INLINE Client(const Client &) = delete;
   LIBC_INLINE Client &operator=(const Client &) = delete;
   LIBC_INLINE ~Client() = default;
 
-  using Port = rpc::Port<false>;
+  using Port = rpc::Port<false, gpu::LANE_SIZE>;
   template <uint16_t opcode> LIBC_INLINE cpp::optional<Port> try_open();
   template <uint16_t opcode> LIBC_INLINE Port open();
 };
 
 /// The RPC server used to respond to the client.
-struct Server : public Process<true> {
+template <uint32_t lane_size> struct Server : public Process<true, lane_size> {
   LIBC_INLINE Server() = default;
   LIBC_INLINE Server(const Server &) = delete;
   LIBC_INLINE Server &operator=(const Server &) = delete;
   LIBC_INLINE ~Server() = default;
 
-  using Port = rpc::Port<true>;
+  using Port = rpc::Port<true, lane_size>;
   LIBC_INLINE cpp::optional<Port> try_open();
   LIBC_INLINE Port open();
 };
 
 /// Applies \p fill to the shared buffer and initiates a send operation.
-template <bool T> template <typename F> LIBC_INLINE void Port<T>::send(F fill) {
+template <bool T, uint32_t S>
+template <typename F>
+LIBC_INLINE void Port<T, S>::send(F fill) {
   uint32_t in = owns_buffer ? out ^ T : process.load_inbox(index);
 
   // We need to wait until we own the buffer before sending.
-  while (Process<T>::buffer_unavailable(in, out)) {
+  while (Process<T, S>::buffer_unavailable(in, out)) {
     sleep_briefly();
     in = process.load_inbox(index);
   }
 
   // Apply the \p fill function to initialize the buffer and release the memory.
-  process.invoke_rpc(fill, process.get_packet(index));
+  process.invoke_rpc(fill, process.packet[index]);
   atomic_thread_fence(cpp::MemoryOrder::RELEASE);
   out = process.invert_outbox(index, out);
   owns_buffer = false;
@@ -370,7 +350,9 @@ template <bool T> template <typename F> LIBC_INLINE void Port<T>::send(F fill) {
 }
 
 /// Applies \p use to the shared buffer and acknowledges the send.
-template <bool T> template <typename U> LIBC_INLINE void Port<T>::recv(U use) {
+template <bool T, uint32_t S>
+template <typename U>
+LIBC_INLINE void Port<T, S>::recv(U use) {
   // We only exchange ownership of the buffer during a receive if we are waiting
   // for a previous receive to finish.
   if (receive) {
@@ -381,22 +363,22 @@ template <bool T> template <typename U> LIBC_INLINE void Port<T>::recv(U use) {
   uint32_t in = owns_buffer ? out ^ T : process.load_inbox(index);
 
   // We need to wait until we own the buffer before receiving.
-  while (Process<T>::buffer_unavailable(in, out)) {
+  while (Process<T, S>::buffer_unavailable(in, out)) {
     sleep_briefly();
     in = process.load_inbox(index);
   }
   atomic_thread_fence(cpp::MemoryOrder::ACQUIRE);
 
   // Apply the \p use function to read the memory out of the buffer.
-  process.invoke_rpc(use, process.get_packet(index));
+  process.invoke_rpc(use, process.packet[index]);
   receive = true;
   owns_buffer = true;
 }
 
 /// Combines a send and receive into a single function.
-template <bool T>
+template <bool T, uint32_t S>
 template <typename F, typename U>
-LIBC_INLINE void Port<T>::send_and_recv(F fill, U use) {
+LIBC_INLINE void Port<T, S>::send_and_recv(F fill, U use) {
   send(fill);
   recv(use);
 }
@@ -404,17 +386,17 @@ LIBC_INLINE void Port<T>::send_and_recv(F fill, U use) {
 /// Combines a receive and send operation into a single function. The \p work
 /// function modifies the buffer in-place and the send is only used to initiate
 /// the copy back.
-template <bool T>
+template <bool T, uint32_t S>
 template <typename W>
-LIBC_INLINE void Port<T>::recv_and_send(W work) {
+LIBC_INLINE void Port<T, S>::recv_and_send(W work) {
   recv(work);
   send([](Buffer *) { /* no-op */ });
 }
 
 /// Helper routine to simplify the interface when sending from the GPU using
 /// thread private pointers to the underlying value.
-template <bool T>
-LIBC_INLINE void Port<T>::send_n(const void *src, uint64_t size) {
+template <bool T, uint32_t S>
+LIBC_INLINE void Port<T, S>::send_n(const void *src, uint64_t size) {
   static_assert(is_process_gpu(), "Only valid when running on the GPU");
   const void **src_ptr = &src;
   uint64_t *size_ptr = &size;
@@ -423,8 +405,8 @@ LIBC_INLINE void Port<T>::send_n(const void *src, uint64_t size) {
 
 /// Sends an arbitrarily sized data buffer \p src across the shared channel in
 /// multiples of the packet length.
-template <bool T>
-LIBC_INLINE void Port<T>::send_n(const void *const *src, uint64_t *size) {
+template <bool T, uint32_t S>
+LIBC_INLINE void Port<T, S>::send_n(const void *const *src, uint64_t *size) {
   uint64_t num_sends = 0;
   send([&](Buffer *buffer, uint32_t id) {
     reinterpret_cast<uint64_t *>(buffer->data)[0] = lane_value(size, id);
@@ -437,7 +419,7 @@ LIBC_INLINE void Port<T>::send_n(const void *const *src, uint64_t *size) {
     inline_memcpy(&buffer->data[1], lane_value(src, id), len);
   });
   uint64_t idx = sizeof(Buffer::data) - sizeof(uint64_t);
-  uint64_t mask = process.get_packet(index).header.mask;
+  uint64_t mask = process.packet[index].header.mask;
   while (gpu::ballot(mask, idx < num_sends)) {
     send([=](Buffer *buffer, uint32_t id) {
       uint64_t len = lane_value(size, id) - idx > sizeof(Buffer::data)
@@ -453,9 +435,9 @@ LIBC_INLINE void Port<T>::send_n(const void *const *src, uint64_t *size) {
 /// Receives an arbitrarily sized data buffer across the shared channel in
 /// multiples of the packet length. The \p alloc function is called with the
 /// size of the data so that we can initialize the size of the \p dst buffer.
-template <bool T>
+template <bool T, uint32_t S>
 template <typename A>
-LIBC_INLINE void Port<T>::recv_n(void **dst, uint64_t *size, A &&alloc) {
+LIBC_INLINE void Port<T, S>::recv_n(void **dst, uint64_t *size, A &&alloc) {
   uint64_t num_recvs = 0;
   recv([&](Buffer *buffer, uint32_t id) {
     lane_value(size, id) = reinterpret_cast<uint64_t *>(buffer->data)[0];
@@ -470,7 +452,7 @@ LIBC_INLINE void Port<T>::recv_n(void **dst, uint64_t *size, A &&alloc) {
     inline_memcpy(lane_value(dst, id), &buffer->data[1], len);
   });
   uint64_t idx = sizeof(Buffer::data) - sizeof(uint64_t);
-  uint64_t mask = process.get_packet(index).header.mask;
+  uint64_t mask = process.packet[index].header.mask;
   while (gpu::ballot(mask, idx < num_recvs)) {
     recv([=](Buffer *buffer, uint32_t id) {
       uint64_t len = lane_value(size, id) - idx > sizeof(Buffer::data)
@@ -491,28 +473,28 @@ template <uint16_t opcode>
 [[clang::convergent]] LIBC_INLINE cpp::optional<Client::Port>
 Client::try_open() {
   // Perform a naive linear scan for a port that can be opened to send data.
-  for (uint64_t index = 0; index < port_count; ++index) {
+  for (uint64_t index = 0; index < this->port_count; ++index) {
     // Attempt to acquire the lock on this index.
     uint64_t lane_mask = gpu::get_lane_mask();
-    if (!try_lock(lane_mask, index))
+    if (!this->try_lock(lane_mask, index))
       continue;
 
     // The mailbox state must be read with the lock held.
     atomic_thread_fence(cpp::MemoryOrder::ACQUIRE);
 
-    uint32_t in = load_inbox(index);
-    uint32_t out = load_outbox(index);
+    uint32_t in = this->load_inbox(index);
+    uint32_t out = this->load_outbox(index);
 
     // Once we acquire the index we need to check if we are in a valid sending
     // state.
-    if (buffer_unavailable(in, out)) {
-      unlock(lane_mask, index);
+    if (this->buffer_unavailable(in, out)) {
+      this->unlock(lane_mask, index);
       continue;
     }
 
     if (is_first_lane(lane_mask)) {
-      get_packet(index).header.opcode = opcode;
-      get_packet(index).header.mask = lane_mask;
+      this->packet[index].header.opcode = opcode;
+      this->packet[index].header.mask = lane_mask;
     }
     gpu::sync_lane(lane_mask);
     return Port(*this, lane_mask, index, out);
@@ -530,32 +512,34 @@ template <uint16_t opcode> LIBC_INLINE Client::Port Client::open() {
 
 /// Attempts to open a port to use as the server. The server can only open a
 /// port if it has a pending receive operation
-[[clang::convergent]] LIBC_INLINE cpp::optional<Server::Port>
-Server::try_open() {
+template <uint32_t lane_size>
+[[clang::convergent]] LIBC_INLINE
+    cpp::optional<typename Server<lane_size>::Port>
+    Server<lane_size>::try_open() {
   // Perform a naive linear scan for a port that has a pending request.
-  for (uint64_t index = 0; index < port_count; ++index) {
-    uint32_t in = load_inbox(index);
-    uint32_t out = load_outbox(index);
+  for (uint64_t index = 0; index < this->port_count; ++index) {
+    uint32_t in = this->load_inbox(index);
+    uint32_t out = this->load_outbox(index);
 
     // The server is passive, if there is no work pending don't bother
     // opening a port.
-    if (buffer_unavailable(in, out))
+    if (this->buffer_unavailable(in, out))
       continue;
 
     // Attempt to acquire the lock on this index.
     uint64_t lane_mask = gpu::get_lane_mask();
     // Attempt to acquire the lock on this index.
-    if (!try_lock(lane_mask, index))
+    if (!this->try_lock(lane_mask, index))
       continue;
 
     // The mailbox state must be read with the lock held.
     atomic_thread_fence(cpp::MemoryOrder::ACQUIRE);
 
-    in = load_inbox(index);
-    out = load_outbox(index);
+    in = this->load_inbox(index);
+    out = this->load_outbox(index);
 
-    if (buffer_unavailable(in, out)) {
-      unlock(lane_mask, index);
+    if (this->buffer_unavailable(in, out)) {
+      this->unlock(lane_mask, index);
       continue;
     }
 
@@ -564,7 +548,8 @@ Server::try_open() {
   return cpp::nullopt;
 }
 
-LIBC_INLINE Server::Port Server::open() {
+template <uint32_t lane_size>
+LIBC_INLINE typename Server<lane_size>::Port Server<lane_size>::open() {
   for (;;) {
     if (cpp::optional<Server::Port> p = try_open())
       return cpp::move(p.value());

diff  --git a/libc/src/__support/RPC/rpc_client.cpp b/libc/src/__support/RPC/rpc_client.cpp
index 8f71b654627f6..c03b75be1c01e 100644
--- a/libc/src/__support/RPC/rpc_client.cpp
+++ b/libc/src/__support/RPC/rpc_client.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "rpc_client.h"
 #include "rpc.h"
 
 namespace __llvm_libc {

diff  --git a/libc/src/gpu/rpc_reset.cpp b/libc/src/gpu/rpc_reset.cpp
index 949df7ccd4e32..ba5a097d1a1bc 100644
--- a/libc/src/gpu/rpc_reset.cpp
+++ b/libc/src/gpu/rpc_reset.cpp
@@ -18,8 +18,7 @@ namespace __llvm_libc {
 // shared buffer.
 LLVM_LIBC_FUNCTION(void, rpc_reset,
                    (unsigned int num_ports, void *rpc_shared_buffer)) {
-  __llvm_libc::rpc::client.reset(num_ports, __llvm_libc::gpu::get_lane_size(),
-                                 rpc_shared_buffer);
+  __llvm_libc::rpc::client.reset(num_ports, rpc_shared_buffer);
 }
 
 } // namespace __llvm_libc

diff  --git a/libc/startup/gpu/amdgpu/start.cpp b/libc/startup/gpu/amdgpu/start.cpp
index 8d85fe1a4b03d..9a9995512f30c 100644
--- a/libc/startup/gpu/amdgpu/start.cpp
+++ b/libc/startup/gpu/amdgpu/start.cpp
@@ -42,7 +42,6 @@ _begin(int argc, char **argv, char **env, void *rpc_shared_buffer) {
   // We need to set up the RPC client first in case any of the constructors
   // require it.
   __llvm_libc::rpc::client.reset(__llvm_libc::rpc::DEFAULT_PORT_COUNT,
-                                 __llvm_libc::gpu::get_lane_size(),
                                  rpc_shared_buffer);
 
   // We want the fini array callbacks to be run after other atexit

diff  --git a/libc/startup/gpu/nvptx/start.cpp b/libc/startup/gpu/nvptx/start.cpp
index fdd3c439530e1..90f00496a6b2d 100644
--- a/libc/startup/gpu/nvptx/start.cpp
+++ b/libc/startup/gpu/nvptx/start.cpp
@@ -46,7 +46,6 @@ _begin(int argc, char **argv, char **env, void *rpc_shared_buffer) {
   // We need to set up the RPC client first in case any of the constructors
   // require it.
   __llvm_libc::rpc::client.reset(__llvm_libc::rpc::DEFAULT_PORT_COUNT,
-                                 __llvm_libc::gpu::get_lane_size(),
                                  rpc_shared_buffer);
 
   // We want the fini array callbacks to be run after other atexit

diff  --git a/libc/utils/gpu/server/Server.cpp b/libc/utils/gpu/server/Server.cpp
index 11781b90114cc..f37fc13851477 100644
--- a/libc/utils/gpu/server/Server.cpp
+++ b/libc/utils/gpu/server/Server.cpp
@@ -14,6 +14,8 @@
 #include <memory>
 #include <mutex>
 #include <unordered_map>
+#include <variant>
+#include <vector>
 
 using namespace __llvm_libc;
 
@@ -22,81 +24,51 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
 
 static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::DEFAULT_PORT_COUNT,
               "Incorrect maximum port count");
-struct Device {
-  rpc::Server server;
-  std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
-  std::unordered_map<rpc_opcode_t, void *> callback_data;
-};
-
-// A struct containing all the runtime state required to run the RPC server.
-struct State {
-  State(uint32_t num_devices)
-      : num_devices(num_devices),
-        devices(std::unique_ptr<Device[]>(new Device[num_devices])),
-        reference_count(0u) {}
-  uint32_t num_devices;
-  std::unique_ptr<Device[]> devices;
-  std::atomic_uint32_t reference_count;
-};
-
-static std::mutex startup_mutex;
-
-static State *state;
-
-rpc_status_t rpc_init(uint32_t num_devices) {
-  std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
-  if (!state)
-    state = new State(num_devices);
-
-  if (state->reference_count == std::numeric_limits<uint32_t>::max())
-    return RPC_STATUS_ERROR;
-
-  state->reference_count++;
-
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_shutdown(void) {
-  if (state->reference_count-- == 1)
-    delete state;
-
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
-                             uint32_t lane_size, rpc_alloc_ty alloc,
-                             void *data) {
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
-
-  uint64_t buffer_size =
-      __llvm_libc::rpc::Server::allocation_size(num_ports, lane_size);
-  void *buffer = alloc(buffer_size, data);
-
-  if (!buffer)
-    return RPC_STATUS_ERROR;
 
-  state->devices[device_id].server.reset(num_ports, lane_size, buffer);
+// The client needs to support different lane sizes for the SIMT model. Because
+// of this we need to select between the possible sizes that the client can use.
+struct Server {
+  template <uint32_t lane_size>
+  Server(std::unique_ptr<rpc::Server<lane_size>> &&server)
+      : server(std::move(server)) {}
 
-  return RPC_STATUS_SUCCESS;
-}
-
-rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
-                                 void *data) {
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
+  void reset(uint64_t port_count, void *buffer) {
+    std::visit([&](auto &server) { server->reset(port_count, buffer); },
+               server);
+  }
 
-  dealloc(rpc_get_buffer(device_id), data);
+  uint64_t allocation_size(uint64_t port_count) {
+    uint64_t ret = 0;
+    std::visit([&](auto &server) { ret = server->allocation_size(port_count); },
+               server);
+    return ret;
+  }
 
-  return RPC_STATUS_SUCCESS;
-}
+  void *get_buffer_start() const {
+    void *ret = nullptr;
+    std::visit([&](auto &server) { ret = server->get_buffer_start(); }, server);
+    return ret;
+  }
 
-rpc_status_t rpc_handle_server(uint32_t device_id) {
-  if (device_id >= state->num_devices)
-    return RPC_STATUS_OUT_OF_RANGE;
+  rpc_status_t handle_server(
+      std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
+      std::unordered_map<rpc_opcode_t, void *> &callback_data) {
+    rpc_status_t ret = RPC_STATUS_SUCCESS;
+    std::visit(
+        [&](auto &server) {
+          ret = handle_server(*server, callbacks, callback_data);
+        },
+        server);
+    return ret;
+  }
 
-  for (;;) {
-    auto port = state->devices[device_id].server.try_open();
+private:
+  template <uint32_t lane_size>
+  rpc_status_t handle_server(
+      rpc::Server<lane_size> &server,
+      std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
+      std::unordered_map<rpc_opcode_t, void *> &callback_data) {
+    auto port = server.try_open();
     if (!port)
       return RPC_STATUS_SUCCESS;
 
@@ -175,21 +147,133 @@ rpc_status_t rpc_handle_server(uint32_t device_id) {
       break;
     }
     default: {
-      auto handler = state->devices[device_id].callbacks.find(
-          static_cast<rpc_opcode_t>(port->get_opcode()));
+      auto handler =
+          callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));
 
       // We error out on an unhandled opcode.
-      if (handler == state->devices[device_id].callbacks.end())
+      if (handler == callbacks.end())
         return RPC_STATUS_UNHANDLED_OPCODE;
 
       // Invoke the registered callback with a reference to the port.
-      void *data = state->devices[device_id].callback_data.at(
-          static_cast<rpc_opcode_t>(port->get_opcode()));
-      rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port)};
+      void *data =
+          callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
+      rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
       (handler->second)(port_ref, data);
     }
     }
     port->close();
+    return RPC_STATUS_CONTINUE;
+  }
+
+  std::variant<std::unique_ptr<rpc::Server<1>>,
+               std::unique_ptr<rpc::Server<32>>,
+               std::unique_ptr<rpc::Server<64>>>
+      server;
+};
+
+struct Device {
+  template <typename T>
+  Device(std::unique_ptr<T> &&server) : server(std::move(server)) {}
+  Server server;
+  std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
+  std::unordered_map<rpc_opcode_t, void *> callback_data;
+};
+
+// A struct containing all the runtime state required to run the RPC server.
+struct State {
+  State(uint32_t num_devices)
+      : num_devices(num_devices), devices(num_devices), reference_count(0u) {}
+  uint32_t num_devices;
+  std::vector<std::unique_ptr<Device>> devices;
+  std::atomic_uint32_t reference_count;
+};
+
+static std::mutex startup_mutex;
+
+static State *state;
+
+rpc_status_t rpc_init(uint32_t num_devices) {
+  std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
+  if (!state)
+    state = new State(num_devices);
+
+  if (state->reference_count == std::numeric_limits<uint32_t>::max())
+    return RPC_STATUS_ERROR;
+
+  state->reference_count++;
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_shutdown(void) {
+  if (state->reference_count-- == 1)
+    delete state;
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
+                             uint32_t lane_size, rpc_alloc_ty alloc,
+                             void *data) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+
+  if (!state->devices[device_id]) {
+    switch (lane_size) {
+    case 1:
+      state->devices[device_id] =
+          std::make_unique<Device>(std::make_unique<rpc::Server<1>>());
+      break;
+    case 32:
+      state->devices[device_id] =
+          std::make_unique<Device>(std::make_unique<rpc::Server<32>>());
+      break;
+    case 64:
+      state->devices[device_id] =
+          std::make_unique<Device>(std::make_unique<rpc::Server<64>>());
+      break;
+    default:
+      return RPC_STATUS_INVALID_LANE_SIZE;
+    }
+  }
+
+  uint64_t size = state->devices[device_id]->server.allocation_size(num_ports);
+  void *buffer = alloc(size, data);
+
+  if (!buffer)
+    return RPC_STATUS_ERROR;
+
+  state->devices[device_id]->server.reset(num_ports, buffer);
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
+                                 void *data) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+  if (!state->devices[device_id])
+    return RPC_STATUS_ERROR;
+
+  dealloc(rpc_get_buffer(device_id), data);
+  if (state->devices[device_id])
+    state->devices[device_id].release();
+
+  return RPC_STATUS_SUCCESS;
+}
+
+rpc_status_t rpc_handle_server(uint32_t device_id) {
+  if (device_id >= state->num_devices)
+    return RPC_STATUS_OUT_OF_RANGE;
+  if (!state->devices[device_id])
+    return RPC_STATUS_ERROR;
+
+  for (;;) {
+    auto &device = *state->devices[device_id];
+    rpc_status_t status =
+        device.server.handle_server(device.callbacks, device.callback_data);
+    if (status != RPC_STATUS_CONTINUE)
+      return status;
   }
 }
 
@@ -198,22 +282,41 @@ rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
                                    void *data) {
   if (device_id >= state->num_devices)
     return RPC_STATUS_OUT_OF_RANGE;
+  if (!state->devices[device_id])
+    return RPC_STATUS_ERROR;
 
-  state->devices[device_id].callbacks[opcode] = callback;
-  state->devices[device_id].callback_data[opcode] = data;
+  state->devices[device_id]->callbacks[opcode] = callback;
+  state->devices[device_id]->callback_data[opcode] = data;
   return RPC_STATUS_SUCCESS;
 }
 
 void *rpc_get_buffer(uint32_t device_id) {
   if (device_id >= state->num_devices)
     return nullptr;
-  return state->devices[device_id].server.get_buffer_start();
+  if (!state->devices[device_id])
+    return nullptr;
+  return state->devices[device_id]->server.get_buffer_start();
 }
 
 void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
                        void *data) {
-  rpc::Server::Port *port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
-  port->recv_and_send([=](rpc::Buffer *buffer) {
-    callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
-  });
+  if (ref.lane_size == 1) {
+    rpc::Server<1>::Port *port =
+        reinterpret_cast<rpc::Server<1>::Port *>(ref.handle);
+    port->recv_and_send([=](rpc::Buffer *buffer) {
+      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+    });
+  } else if (ref.lane_size == 32) {
+    rpc::Server<32>::Port *port =
+        reinterpret_cast<rpc::Server<32>::Port *>(ref.handle);
+    port->recv_and_send([=](rpc::Buffer *buffer) {
+      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+    });
+  } else if (ref.lane_size == 64) {
+    rpc::Server<64>::Port *port =
+        reinterpret_cast<rpc::Server<64>::Port *>(ref.handle);
+    port->recv_and_send([=](rpc::Buffer *buffer) {
+      callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
+    });
+  }
 }

diff  --git a/libc/utils/gpu/server/Server.h b/libc/utils/gpu/server/Server.h
index 81162ccc164d8..adc38e0961006 100644
--- a/libc/utils/gpu/server/Server.h
+++ b/libc/utils/gpu/server/Server.h
@@ -23,15 +23,18 @@ const uint64_t RPC_MAXIMUM_PORT_COUNT = 64;
 /// status codes.
 typedef enum {
   RPC_STATUS_SUCCESS = 0x0,
+  RPC_STATUS_CONTINUE = 0x1,
   RPC_STATUS_ERROR = 0x1000,
   RPC_STATUS_OUT_OF_RANGE = 0x1001,
   RPC_STATUS_UNHANDLED_OPCODE = 0x1002,
+  RPC_STATUS_INVALID_LANE_SIZE = 0x1003,
 } rpc_status_t;
 
 /// A struct containing an opaque handle to an RPC port. This is what allows the
 /// server to communicate with the client.
 typedef struct rpc_port_s {
   uint64_t handle;
+  uint32_t lane_size;
 } rpc_port_t;
 
 /// A fixed-size buffer containing the payload sent from the client.


        


More information about the libc-commits mailing list