[libc-commits] [libc] [libc] Add RPC helpers for dispatching functions to the host (PR #179085)

Joseph Huber via libc-commits libc-commits at lists.llvm.org
Sat Jan 31 19:58:33 PST 2026


https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/179085

Summary:
The RPC interface is useful for forwarding functions. This PR adds
helper functions for doing a completely bare forwarding of a function
from the client to the server. This is intended to facilitate
heterogeneous libraries that implement host functions on the GPU (like
MPI or Fortran).


>From 5038e2a22aee9e9f101e510e49854d3d39f96b66 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Sat, 31 Jan 2026 21:56:01 -0600
Subject: [PATCH] [libc] Add RPC helpers for dispatching functions to the host

Summary:
The RPC interface is useful for forwarding functions. This PR adds
helper functions for doing a completely bare forwarding of a function
from the client to the server. This is intended to facilitate
heterogeneous libraries that implement host functions on the GPU (like
MPI or Fortran).
---
 libc/docs/gpu/rpc.rst               |  46 ++++++++-
 libc/shared/rpc.h                   |   1 -
 libc/shared/rpc_dispatch.h          | 142 ++++++++++++++++++++++++++++
 libc/shared/rpc_util.h              | 103 +++++++++++++++++++-
 libc/src/__support/RPC/rpc_client.h |   2 +
 5 files changed, 286 insertions(+), 8 deletions(-)
 create mode 100644 libc/shared/rpc_dispatch.h

diff --git a/libc/docs/gpu/rpc.rst b/libc/docs/gpu/rpc.rst
index 4ac3786cfa085..f42b06090632c 100644
--- a/libc/docs/gpu/rpc.rst
+++ b/libc/docs/gpu/rpc.rst
@@ -113,10 +113,10 @@ done. It can be omitted if asynchronous execution is desired.
   void rpc_host_call(void *fn, void *data, size_t size) {
     rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
     port.send_n(data, size);
-    port.send([=](rpc::Buffer *buffer) {
+    port.send([=](rpc::Buffer *buffer, uint32_t) {
       buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
     });
-    port.recv([](rpc::Buffer *) {});
+    port.recv([](rpc::Buffer *, uint32_t) {});
     port.close();
   }
 
@@ -131,7 +131,7 @@ call a function pointer provided by the client.
 In this example, the server simply runs forever in a separate thread for
 brevity's sake. Because the client is a GPU potentially handling several threads
 at once, the server needs to loop over all the active threads on the GPU. We
-abstract this into the ``lane_size`` variable, which is simply the device's warp
+abstract this into the ``num_lanes`` variable, which is simply the device's warp
 or wavefront size. The identifier is simply the threads index into the current
 warp or wavefront. We allocate memory to copy the struct data into, and then
 call the given function pointer with that copied data. The final send simply
@@ -147,8 +147,8 @@ data.
 
     switch(port->get_opcode()) {
     case RPC_HOST_CALL: {
-      uint64_t sizes[LANE_SIZE];
-      void *args[LANE_SIZE];
+      uint64_t sizes[NUM_LANES];
+      void *args[NUM_LANES];
       port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
       port->recv([&](rpc::Buffer *buffer, uint32_t id) {
         reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
@@ -162,8 +162,44 @@ data.
       port->recv([](rpc::Buffer *) {});
       break;
     }
+    port->close();
   }
 
+Function Dispatch
+-----------------
+
+There are cases where the client wants to simply execute a function as-is on the
+server. The ``rpc::dispatch`` (client) and ``rpc::invoke`` (server) helpers make
+this case almost automatic. Arguments and return values are copied byte-wise, so
+pointer arguments must be opaque handles accessible by the server. Functions
+returning void wait for the server to finish rather than running asynchronously.
+
+.. code-block:: c++
+
+  double fn(int x, long y, char c, double d);
+
+  // Client-side dispatch.
+  double fn(int x, long y, char c, double d) {
+    return rpc::dispatch<OPCODE>(client, fn, x, y, c, d);
+  }
+
+  // Server-side handling.
+  for(;;) {
+    auto port = server.try_open(index);
+    if (!port)
+      continue;
+
+    switch(port->get_opcode()) {
+    case OPCODE:
+      rpc::invoke<NUM_LANES>(fn, *port);
+      break;
+    default:
+      port->recv([](rpc::Buffer *, uint32_t) {});
+      break;
+    }
+    port->close();
+  }
+
 CUDA Server Example
 -------------------
 
diff --git a/libc/shared/rpc.h b/libc/shared/rpc.h
index dac2a7949a906..567b22054b35b 100644
--- a/libc/shared/rpc.h
+++ b/libc/shared/rpc.h
@@ -590,7 +590,6 @@ RPC_ATTRS Server::Port Server::open(uint32_t lane_size) {
   }
 }
 
-#undef RPC_ATTRS
 #if !__has_builtin(__scoped_atomic_load_n)
 #undef __scoped_atomic_load_n
 #undef __scoped_atomic_store_n
diff --git a/libc/shared/rpc_dispatch.h b/libc/shared/rpc_dispatch.h
new file mode 100644
index 0000000000000..ce1ee938a5fea
--- /dev/null
+++ b/libc/shared/rpc_dispatch.h
@@ -0,0 +1,190 @@
+//===-- Helper functions for client / server dispatch -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SHARED_RPC_DISPATCH_H
+#define LLVM_LIBC_SHARED_RPC_DISPATCH_H
+
+#include "rpc.h"
+#include "rpc_util.h"
+
+namespace rpc {
+namespace detail {
+/// Packs a list of values into a flat byte array and back. Each value is
+/// copied byte-wise, in order and without padding, so all types must be
+/// trivially copyable and both sides must use the exact same list of types.
+template <typename... Ts> struct tuple_bytes {
+  /// Number of payload bytes, zero for an empty list.
+  static constexpr uint64_t SIZE = (uint64_t(0) + ... + sizeof(Ts));
+
+  /// Zero-length arrays are not valid C++ and a client port must send before
+  /// it receives, so an empty list still occupies one (unused) byte.
+  using array_type = rpc::array<uint8_t, (SIZE > 0 ? SIZE : 1)>;
+
+  template <uint64_t... Is>
+  RPC_ATTRS static array_type pack_impl(const rpc::tuple<Ts...> &t,
+                                        rpc::index_sequence<Is...>) {
+    array_type out{};
+    uint8_t *p = out.data();
+    ((rpc::rpc_memcpy(p, &rpc::get<Is>(t), sizeof(Ts)), p += sizeof(Ts)), ...);
+    (void)p; // Unused when the list is empty.
+    return out;
+  }
+
+  /// Serializes \p t into a byte array.
+  RPC_ATTRS static array_type pack(const rpc::tuple<Ts...> &t) {
+    return pack_impl(t, rpc::index_sequence_for<Ts...>{});
+  }
+
+  template <uint64_t... Is>
+  RPC_ATTRS static rpc::tuple<Ts...> unpack_impl(const uint8_t *data,
+                                                 rpc::index_sequence<Is...>) {
+    rpc::tuple<Ts...> t{};
+    const uint8_t *p = data;
+    ((rpc::rpc_memcpy(&rpc::get<Is>(t), p, sizeof(Ts)), p += sizeof(Ts)), ...);
+    (void)p; // Unused when the list is empty.
+    return t;
+  }
+
+  /// Deserializes a tuple previously serialized by pack().
+  RPC_ATTRS static rpc::tuple<Ts...> unpack(const array_type &a) {
+    return unpack_impl(a.data(), rpc::index_sequence_for<Ts...>{});
+  }
+};
+
+/// Maps an rpc::tuple type to the matching tuple_bytes instantiation.
+template <typename Tuple> struct tuple_bytes_from_tuple;
+
+template <typename... Ts>
+struct tuple_bytes_from_tuple<rpc::tuple<Ts...>> : tuple_bytes<Ts...> {};
+
+/// Exposes the return type, parameter types, and arity of a function pointer.
+template <typename> struct function_traits;
+
+template <typename R, typename... Args> struct function_traits<R (*)(Args...)> {
+  using return_type = R;
+  using arg_types = rpc::tuple<Args...>;
+  static constexpr uint64_t ARITY = sizeof...(Args);
+};
+
+// 'noexcept' is part of the function type since C++17, so pointers to noexcept
+// functions need their own specialization.
+template <typename R, typename... Args>
+struct function_traits<R (*)(Args...) noexcept>
+    : function_traits<R (*)(Args...)> {};
+} // namespace detail
+
+/// Calls a function on the server through the RPC interface and returns its
+/// result. The arguments are converted to the parameter types of \p FnTy,
+/// copied byte-wise to the server, and unpacked there by rpc::invoke with the
+/// same layout, so all argument and return types must be trivially copyable
+/// and pointer arguments must be valid on the server. Only the type of the
+/// function pointer is used; its value is never sent. Calls to functions
+/// returning void still wait for the server to finish executing them.
+template <uint32_t Opcode, typename FnTy, typename... CallArgs>
+RPC_ATTRS typename detail::function_traits<FnTy>::return_type
+dispatch(rpc::Client &client, FnTy, CallArgs &&...args) {
+  using Traits = detail::function_traits<FnTy>;
+  using Ret = typename Traits::return_type;
+  using ArgTuple = typename Traits::arg_types;
+  // Serialize with the callee's parameter types, not the deduced call-site
+  // types, so the byte layout always matches what rpc::invoke unpacks.
+  using Bytes = detail::tuple_bytes_from_tuple<ArgTuple>;
+
+  static_assert(sizeof...(CallArgs) == Traits::ARITY,
+                "Argument count mismatch");
+
+  ArgTuple arg_tuple{rpc::forward<CallArgs>(args)...};
+  const auto bytes = Bytes::pack(arg_tuple);
+
+  // Send the arguments one buffer at a time. The packed array is never empty,
+  // so the port always sends before it receives.
+  auto port = client.open<Opcode>();
+  for (uint64_t i = 0; i < bytes.size(); i += sizeof(rpc::Buffer)) {
+    const uint64_t n = rpc::min(bytes.size() - i, sizeof(rpc::Buffer));
+    port.send([&](rpc::Buffer *b, uint32_t) {
+      rpc::rpc_memcpy(b->data, bytes.data() + i, n);
+    });
+  }
+
+  if constexpr (!rpc::is_void<Ret>::value) {
+    // Receive the return value one buffer at a time.
+    rpc::array<uint8_t, sizeof(Ret)> ret_bytes{};
+    for (uint64_t i = 0; i < ret_bytes.size(); i += sizeof(rpc::Buffer)) {
+      const uint64_t n = rpc::min(ret_bytes.size() - i, sizeof(rpc::Buffer));
+      port.recv([&](rpc::Buffer *b, uint32_t) {
+        rpc::rpc_memcpy(ret_bytes.data() + i, b->data, n);
+      });
+    }
+    port.close();
+    Ret ret{};
+    rpc::rpc_memcpy(&ret, ret_bytes.data(), sizeof(Ret));
+    return ret;
+  } else {
+    // Wait for the server's acknowledgement so void calls stay synchronous.
+    port.recv([](rpc::Buffer *, uint32_t) {});
+    port.close();
+  }
+}
+
+/// Server-side counterpart of rpc::dispatch. Receives the packed arguments for
+/// each lane, calls \p fn with them, and sends back the results, or an empty
+/// acknowledgement for functions returning void. Lane ids index arrays of
+/// \p LANES entries, so \p LANES must be at least the port's lane count.
+template <uint32_t LANES, typename FnTy>
+RPC_ATTRS void invoke(FnTy fn, rpc::Server::Port &port) {
+  using Traits = detail::function_traits<FnTy>;
+  using Ret = typename Traits::return_type;
+  using ArgTuple = typename Traits::arg_types;
+  using Bytes = detail::tuple_bytes_from_tuple<ArgTuple>;
+
+  constexpr uint64_t ARRAY_SIZE = Bytes::array_type::size();
+  constexpr uint64_t BUFFER_SIZE = sizeof(rpc::Buffer);
+
+  typename Bytes::array_type arg_buf[LANES]{};
+  // Record which lanes the port delivered arguments for. Other slots hold no
+  // request and calling 'fn' on their zeroed arguments would be spurious.
+  bool active[LANES]{};
+
+  for (uint64_t i = 0; i < ARRAY_SIZE; i += BUFFER_SIZE) {
+    const uint64_t n = rpc::min(ARRAY_SIZE - i, BUFFER_SIZE);
+    port.recv([&](rpc::Buffer *b, uint32_t id) {
+      active[id] = true;
+      rpc::rpc_memcpy(arg_buf[id].data() + i, b->data, n);
+    });
+  }
+
+  if constexpr (!rpc::is_void<Ret>::value) {
+    Ret rets[LANES]{};
+    for (uint32_t id = 0; id < LANES; ++id) {
+      if (!active[id])
+        continue;
+      ArgTuple args = Bytes::unpack(arg_buf[id]);
+      rets[id] = rpc::apply(fn, args);
+    }
+
+    for (uint64_t i = 0; i < sizeof(Ret); i += BUFFER_SIZE) {
+      const uint64_t n = rpc::min(sizeof(Ret) - i, BUFFER_SIZE);
+      port.send([&](rpc::Buffer *b, uint32_t id) {
+        const uint8_t *ret = reinterpret_cast<const uint8_t *>(&rets[id]);
+        rpc::rpc_memcpy(b->data, ret + i, n);
+      });
+    }
+  } else {
+    for (uint32_t id = 0; id < LANES; ++id) {
+      if (!active[id])
+        continue;
+      ArgTuple args = Bytes::unpack(arg_buf[id]);
+      rpc::apply(fn, args);
+    }
+    // Acknowledge completion so the client's synchronous call can return.
+    port.send([](rpc::Buffer *, uint32_t) {});
+  }
+}
+} // namespace rpc
+
+#endif // LLVM_LIBC_SHARED_RPC_DISPATCH_H
diff --git a/libc/shared/rpc_util.h b/libc/shared/rpc_util.h
index 687814b7ff2ae..de49f8c410769 100644
--- a/libc/shared/rpc_util.h
+++ b/libc/shared/rpc_util.h
@@ -45,16 +45,44 @@ template <class T, T v> struct type_constant {
 template <class T> struct remove_reference : type_identity<T> {};
 template <class T> struct remove_reference<T &> : type_identity<T> {};
 template <class T> struct remove_reference<T &&> : type_identity<T> {};
+template <class T>
+using remove_reference_t = typename remove_reference<T>::type;
 
 template <class T> struct is_const : type_constant<bool, false> {};
 template <class T> struct is_const<const T> : type_constant<bool, true> {};
 
+template <class T> struct is_void : type_constant<bool, false> {};
+template <> struct is_void<void> : type_constant<bool, true> {};
+
 /// Freestanding implementation of std::move.
 template <class T>
 RPC_ATTRS constexpr typename remove_reference<T>::type &&move(T &&t) {
   return static_cast<typename remove_reference<T>::type &&>(t);
 }
 
+/// Freestanding integer_sequence; make_index_sequence<N> yields 0, ..., N-1.
+template <typename T, T... Ints> struct integer_sequence {
+  template <T Next> using append = integer_sequence<T, Ints..., Next>;
+};
+
+namespace detail {
+template <typename T, int N> struct make_integer_sequence {
+  using type =
+      typename make_integer_sequence<T, N - 1>::type::template append<N>;
+};
+template <typename T> struct make_integer_sequence<T, -1> { // Base case.
+  using type = integer_sequence<T>;
+};
+} // namespace detail
+
+template <uint64_t... Ints>
+using index_sequence = integer_sequence<uint64_t, Ints...>;
+template <int N>
+using make_index_sequence =
+    typename detail::make_integer_sequence<uint64_t, N - 1>::type;
+template <typename... Ts>
+using index_sequence_for = make_index_sequence<sizeof...(Ts)>;
+
 /// Freestanding implementation of std::forward.
 template <typename T>
 RPC_ATTRS constexpr T &&forward(typename remove_reference<T>::type &value) {
@@ -150,6 +178,73 @@ template <typename T> class optional {
   RPC_ATTRS constexpr T &&operator*() && { return move(storage.stored_value); }
 };
 
+/// Minimal fixed-size array, a freestanding stand-in for std::array.
+template <typename T, uint64_t N> struct array {
+  T elems[N];
+
+  RPC_ATTRS constexpr T *data() { return elems; }
+  RPC_ATTRS constexpr const T *data() const { return elems; }
+  RPC_ATTRS static constexpr uint64_t size() { return N; }
+
+  RPC_ATTRS constexpr T &operator[](uint64_t i) { return elems[i]; }
+  RPC_ATTRS constexpr const T &operator[](uint64_t i) const { return elems[i]; }
+};
+
+/// Minimal tuple: `head` stores the first element, the base class the rest.
+template <typename... Ts> struct tuple;
+template <> struct tuple<> {}; // Terminates the recursive definition.
+
+template <typename Head, typename... Tail>
+struct tuple<Head, Tail...> : tuple<Tail...> {
+  Head head;
+
+  RPC_ATTRS constexpr tuple() = default;
+
+  template <typename OHead, typename... OTail>
+  RPC_ATTRS constexpr tuple &operator=(const tuple<OHead, OTail...> &other) {
+    head = other.get_head();
+    this->get_tail() = other.get_tail();
+    return *this;
+  }
+
+  RPC_ATTRS constexpr tuple(const Head &h, const Tail &...t)
+      : tuple<Tail...>(t...), head(h) {}
+
+  RPC_ATTRS constexpr Head &get_head() { return head; }
+  RPC_ATTRS constexpr const Head &get_head() const { return head; }
+
+  RPC_ATTRS constexpr tuple<Tail...> &get_tail() { return *this; }
+  RPC_ATTRS constexpr const tuple<Tail...> &get_tail() const { return *this; }
+};
+
+template <uint64_t Idx, typename Head, typename... Tail>
+RPC_ATTRS constexpr auto &get(tuple<Head, Tail...> &t) {
+  if constexpr (Idx == 0)
+    return t.get_head();
+  else
+    return get<Idx - 1>(t.get_tail());
+}
+template <uint64_t Idx, typename Head, typename... Tail>
+RPC_ATTRS constexpr const auto &get(const tuple<Head, Tail...> &t) {
+  if constexpr (Idx == 0)
+    return t.get_head();
+  else
+    return get<Idx - 1>(t.get_tail());
+}
+
+namespace detail {
+template <typename F, typename Tuple, uint64_t... Is>
+RPC_ATTRS auto apply(F &&f, Tuple &&t, index_sequence<Is...>) {
+  return f(get<Is>(static_cast<Tuple &&>(t))...);
+}
+} // namespace detail
+
+template <typename F, typename... Ts>
+RPC_ATTRS auto apply(F &&f, tuple<Ts...> &t) {
+  return detail::apply(static_cast<F &&>(f), t,
+                       make_index_sequence<sizeof...(Ts)>{});
+}
+
 /// Suspend the thread briefly to assist the thread scheduler during busy loops.
 RPC_ATTRS void sleep_briefly() {
 #if __has_builtin(__nvvm_reflect)
@@ -263,14 +358,20 @@ template <typename T, typename U> RPC_ATTRS T *advance(T *ptr, U bytes) {
 }
 
 /// Wrapper around the optimal memory copy implementation for the target.
-RPC_ATTRS void rpc_memcpy(void *dst, const void *src, size_t count) {
+RPC_ATTRS void rpc_memcpy(void *dst, const void *src, uint64_t count) {
   __builtin_memcpy(dst, src, count);
 }
 
-template <class T> RPC_ATTRS constexpr const T &max(const T &a, const T &b) {
+/// Returns the larger of \p a and \p b, converted to the type of \p a.
+template <class T, class U> RPC_ATTRS constexpr T max(const T &a, const U &b) {
   return (a < b) ? b : a;
 }
 
+/// Returns the smaller of \p a and \p b, converted to the type of \p a.
+template <class T, class U> RPC_ATTRS constexpr T min(const T &a, const U &b) {
+  return (a < b) ? a : b;
+}
+
 } // namespace rpc
 
 #endif // LLVM_LIBC_SHARED_RPC_UTIL_H
diff --git a/libc/src/__support/RPC/rpc_client.h b/libc/src/__support/RPC/rpc_client.h
index 199803badf1a9..29e46e733c031 100644
--- a/libc/src/__support/RPC/rpc_client.h
+++ b/libc/src/__support/RPC/rpc_client.h
@@ -10,6 +10,7 @@
 #define LLVM_LIBC_SRC___SUPPORT_RPC_RPC_CLIENT_H
 
 #include "shared/rpc.h"
+#include "shared/rpc_dispatch.h"
 #include "shared/rpc_opcodes.h"
 
 #include "src/__support/CPP/type_traits.h"
@@ -20,6 +21,7 @@ namespace rpc {
 
 using ::rpc::Buffer;
 using ::rpc::Client;
+using ::rpc::dispatch;
 using ::rpc::Port;
 using ::rpc::Process;
 using ::rpc::Server;



More information about the libc-commits mailing list