[libc-commits] [libc] [llvm] [libc] Add RPC helpers for dispatching functions to the host (PR #179085)
Joseph Huber via libc-commits
libc-commits at lists.llvm.org
Sun Feb 8 15:06:34 PST 2026
https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/179085
>From dae0e88aa02f438186bf0459c53378b6693a8397 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Sat, 31 Jan 2026 21:56:01 -0600
Subject: [PATCH] [libc] Add RPC helpers for dispatching functions to the host
Summary:
The RPC interface is useful for forwarding functions. This PR adds
helper functions for doing a completely bare forwarding of a function
from the client to the server. This is intended to facilitate
heterogeneous libraries that implement host functions on the GPU (like
MPI or Fortran).
Add support for regular handling
update doc
---
libc/docs/gpu/rpc.rst | 47 +++++-
libc/shared/rpc.h | 38 ++++-
libc/shared/rpc_dispatch.h | 255 +++++++++++++++++++++++++++++
libc/shared/rpc_util.h | 170 ++++++++++++++++++-
offload/test/libc/rpc_callback.c | 66 --------
offload/test/libc/rpc_callback.cpp | 171 +++++++++++++++++++
6 files changed, 671 insertions(+), 76 deletions(-)
create mode 100644 libc/shared/rpc_dispatch.h
delete mode 100644 offload/test/libc/rpc_callback.c
create mode 100644 offload/test/libc/rpc_callback.cpp
diff --git a/libc/docs/gpu/rpc.rst b/libc/docs/gpu/rpc.rst
index 4ac3786cfa085..a9d63b734572c 100644
--- a/libc/docs/gpu/rpc.rst
+++ b/libc/docs/gpu/rpc.rst
@@ -113,10 +113,10 @@ done. It can be omitted if asynchronous execution is desired.
void rpc_host_call(void *fn, void *data, size_t size) {
rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
port.send_n(data, size);
- port.send([=](rpc::Buffer *buffer) {
+ port.send([=](rpc::Buffer *buffer, uint32_t) {
buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
});
- port.recv([](rpc::Buffer *) {});
+ port.recv([](rpc::Buffer *, uint32_t) {});
port.close();
}
@@ -131,7 +131,7 @@ call a function pointer provided by the client.
In this example, the server simply runs forever in a separate thread for
brevity's sake. Because the client is a GPU potentially handling several threads
at once, the server needs to loop over all the active threads on the GPU. We
-abstract this into the ``lane_size`` variable, which is simply the device's warp
+abstract this into the ``num_lanes`` variable, which is simply the device's warp
or wavefront size. The identifier is simply the threads index into the current
warp or wavefront. We allocate memory to copy the struct data into, and then
call the given function pointer with that copied data. The final send simply
@@ -147,8 +147,8 @@ data.
switch(port->get_opcode()) {
case RPC_HOST_CALL: {
- uint64_t sizes[LANE_SIZE];
- void *args[LANE_SIZE];
+ uint64_t sizes[NUM_LANES];
+ void *args[NUM_LANES];
port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
@@ -162,8 +162,45 @@ data.
port->recv([](rpc::Buffer *) {});
break;
}
+ port->close();
}
+Function Dispatch
+-----------------
+
+There are cases where the client wants to simply execute a function as-is on the
+server. A helper function is provided to make this case almost automatic. By
+default, all memory is assumed to live privately on the client. Pointer
+arguments will be copied between the client and server for correctness.
+Functions returning void will wait for the server to complete execution rather
+than submitting asynchronously.
+
+.. code-block:: c++
+
+ double fn(int x, long y, char c, double d);
+
+ // Client-side dispatch.
+ double fn(int x, long y, char c, double d) {
+ return rpc::dispatch<OPCODE>(client, fn, x, y, c, d);
+ }
+
+ // Server-side handling.
+ for(;;) {
+ auto port = server.try_open(index);
+ if (!port)
+ continue;
+
+ switch(port->get_opcode()) {
+ case OPCODE:
+ rpc::invoke<NUM_LANES>(fn, *port);
+ default:
+ port->recv([](rpc::Buffer *) {});
+ break;
+ }
+ port->close();
+ }
+
+
CUDA Server Example
-------------------
diff --git a/libc/shared/rpc.h b/libc/shared/rpc.h
index dac2a7949a906..febee0147a8d2 100644
--- a/libc/shared/rpc.h
+++ b/libc/shared/rpc.h
@@ -318,10 +318,19 @@ template <bool T> struct Port {
template <typename A>
RPC_ATTRS void recv_n(void **dst, uint64_t *size, A &&alloc);
+ template <typename Ty> RPC_ATTRS void send_n(const Ty *src);
+ template <typename Ty> RPC_ATTRS void recv_n(Ty *dst);
+
RPC_ATTRS uint32_t get_opcode() const { return process.header[index].opcode; }
RPC_ATTRS uint32_t get_index() const { return index; }
+ RPC_ATTRS uint64_t get_lane_mask() const {
+ if constexpr (T)
+ return process.header[index].mask;
+ return lane_mask;
+ }
+
RPC_ATTRS void close() {
// Wait for all lanes to finish using the port.
rpc::sync_lane(lane_mask);
@@ -392,7 +401,7 @@ template <bool T> template <typename F> RPC_ATTRS void Port<T>::send(F fill) {
process.wait_for_ownership(lane_mask, index, out, in);
// Apply the \p fill function to initialize the buffer and release the memory.
- invoke_rpc(fill, lane_size, process.header[index].mask,
+ invoke_rpc(fill, lane_size, get_lane_mask(),
process.get_packet(index, lane_size));
out = process.invert_outbox(index, out);
owns_buffer = false;
@@ -414,7 +423,7 @@ template <bool T> template <typename U> RPC_ATTRS void Port<T>::recv(U use) {
process.wait_for_ownership(lane_mask, index, out, in);
// Apply the \p use function to read the memory out of the buffer.
- invoke_rpc(use, lane_size, process.header[index].mask,
+ invoke_rpc(use, lane_size, get_lane_mask(),
process.get_packet(index, lane_size));
receive = true;
owns_buffer = true;
@@ -509,6 +518,30 @@ RPC_ATTRS void Port<T>::recv_n(void **dst, uint64_t *size, A &&alloc) {
}
}
+/// Simplified version of `send_n` where the size is a known constant.
+template <bool T>
+template <typename Ty>
+RPC_ATTRS void Port<T>::send_n(const Ty *src) {
+ for (uint64_t idx = 0; idx < sizeof(Ty); idx += sizeof(Buffer::data)) {
+ const uint64_t bytes = rpc::min(sizeof(Ty) - idx, sizeof(Buffer::data));
+ send([&](Buffer *buffer, uint32_t id) {
+ rpc_memcpy(buffer->data, &lane_value(src, id) + idx, bytes);
+ });
+ }
+}
+
+/// Simplified version of `recv_n` where the size is a known constant.
+template <bool T>
+template <typename Ty>
+RPC_ATTRS void Port<T>::recv_n(Ty *dst) {
+ for (uint64_t idx = 0; idx < sizeof(Ty); idx += sizeof(Buffer::data)) {
+ const uint64_t bytes = rpc::min(sizeof(Ty) - idx, sizeof(Buffer::data));
+ recv([&](Buffer *buffer, uint32_t id) {
+ rpc_memcpy(&lane_value(dst, id) + idx, buffer->data, bytes);
+ });
+ }
+}
+
/// Continually attempts to open a port to use as the client. The client can
/// only open a port if we find an index that is in a valid sending state. That
/// is, there are send operations pending that haven't been serviced on this
@@ -590,7 +623,6 @@ RPC_ATTRS Server::Port Server::open(uint32_t lane_size) {
}
}
-#undef RPC_ATTRS
#if !__has_builtin(__scoped_atomic_load_n)
#undef __scoped_atomic_load_n
#undef __scoped_atomic_store_n
diff --git a/libc/shared/rpc_dispatch.h b/libc/shared/rpc_dispatch.h
new file mode 100644
index 0000000000000..eea603e6c7d3a
--- /dev/null
+++ b/libc/shared/rpc_dispatch.h
@@ -0,0 +1,255 @@
+//===-- Helper functions for client / server dispatch -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "rpc.h"
+#include "rpc_util.h"
+
+namespace rpc {
+namespace {
+// Traits to convert between a tuple and binary representation of an argument
+// list.
+template <typename... Ts> struct tuple_bytes {
+ static constexpr uint64_t SIZE = (sizeof(Ts) + ...);
+ using array_type = rpc::array<uint8_t, SIZE>;
+
+ template <uint64_t... Is>
+ RPC_ATTRS static constexpr array_type pack_impl(rpc::tuple<Ts...> t,
+ rpc::index_sequence<Is...>) {
+ array_type out{};
+ uint8_t *p = out.data();
+ ((rpc::rpc_memcpy(p, &rpc::get<Is>(t), sizeof(Ts)), p += sizeof(Ts)), ...);
+ return out;
+ }
+
+ RPC_ATTRS static constexpr array_type pack(rpc::tuple<Ts...> t) {
+ return pack_impl(t, rpc::index_sequence_for<Ts...>{});
+ }
+
+ template <uint64_t... Is>
+ RPC_ATTRS static constexpr rpc::tuple<Ts...>
+ unpack_impl(const uint8_t *data, rpc::index_sequence<Is...>) {
+ rpc::tuple<Ts...> t{};
+ const uint8_t *p = data;
+ ((rpc::rpc_memcpy(&rpc::get<Is>(t), p, sizeof(Ts)), p += sizeof(Ts)), ...);
+ return t;
+ }
+
+ RPC_ATTRS static constexpr rpc::tuple<Ts...> unpack(const array_type &a) {
+ return unpack_impl(a.data(), rpc::index_sequence_for<Ts...>{});
+ }
+};
+template <typename... Ts>
+struct tuple_bytes<rpc::tuple<Ts...>> : tuple_bytes<Ts...> {};
+
+template <typename> struct function_traits;
+template <typename R, typename... Args> struct function_traits<R (*)(Args...)> {
+ using return_type = R;
+ using arg_types = rpc::tuple<Args...>;
+ static constexpr uint64_t ARITY = sizeof...(Args);
+};
+
+// Client-side dispatch of pointer values. We copy the memory associated with
+// the pointer to the server and receive back a server-side pointer to replace
+// the client-side pointer in the argument list.
+template <uint64_t Idx, typename Tuple>
+RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t) {
+ using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
+ if constexpr (rpc::is_pointer_v<ArgTy>) {
+ // We assume all constant character arrays are C-strings.
+ uint64_t size{};
+ if constexpr (rpc::is_same_v<ArgTy, const char *>)
+ size = rpc::string_length(rpc::get<Idx>(t));
+ else
+ size = sizeof(rpc::remove_pointer_t<ArgTy>);
+ port.send_n(rpc::get<Idx>(t), size);
+ port.recv([&](rpc::Buffer *buffer, uint32_t) {
+ rpc::get<Idx>(t) = *reinterpret_cast<ArgTy *>(buffer->data);
+ });
+ }
+}
+
+// Server-side handling of pointer arguments. We receive the memory into a
+// temporary buffer and pass a pointer to this new memory back to the client.
+template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx>
+RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) {
+ using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
+ if constexpr (rpc::is_pointer_v<ArgTy>) {
+ void *args[NUM_LANES]{};
+ uint64_t sizes[NUM_LANES]{};
+ port.recv_n(args, sizes, [](uint64_t size) {
+ if constexpr (rpc::is_same_v<ArgTy, const char *>)
+ return new rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>>[size];
+ else
+ return new rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>>;
+ });
+ port.send([&](rpc::Buffer *buffer, uint32_t id) {
+ *reinterpret_cast<ArgTy *>(buffer->data) = static_cast<ArgTy>(args[id]);
+ });
+ }
+}
+
+// Client-side finalization of pointer arguments. If the type is not constant we
+// must copy back any potential modifications the invoked function made to that
+// pointer.
+template <uint64_t Idx, typename Tuple>
+RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) {
+ using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
+ using MemoryTy = rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>> *;
+ if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy>) {
+ uint64_t size{};
+ void *buf{};
+ port.recv_n(&buf, &size, [&](uint64_t) {
+ return const_cast<MemoryTy>(rpc::get<Idx>(t));
+ });
+ }
+}
+
+// Server-side finalization of pointer arguments. We copy any potential
+// modifications to the value back to the client unless it was a constant. We
+// can also free the associated memory.
+template <uint32_t NUM_LANES, uint64_t Idx, typename Tuple>
+RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port,
+ Tuple (&t)[NUM_LANES]) {
+ using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
+ if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy>) {
+ const void *buffer[NUM_LANES]{};
+ size_t sizes[NUM_LANES]{};
+ for (uint32_t id = 0; id < NUM_LANES; ++id) {
+ if (port.get_lane_mask() & (uint64_t(1) << id)) {
+ buffer[id] = rpc::get<Idx>(t[id]);
+ sizes[id] = sizeof(rpc::remove_pointer_t<ArgTy>);
+ }
+ }
+ port.send_n(buffer, sizes);
+ }
+
+ if constexpr (rpc::is_pointer_v<ArgTy>) {
+ for (uint32_t id = 0; id < NUM_LANES; ++id) {
+ if (port.get_lane_mask() & (uint64_t(1) << id)) {
+ if constexpr (rpc::is_same_v<ArgTy, const char *>)
+ delete[] rpc::get<Idx>(t[id]);
+ else
+ delete rpc::get<Idx>(t[id]);
+ }
+ }
+ }
+}
+
+// Iterate over the tuple list of arguments to see if we need to forward any.
+// The current forwarding is somewhat inefficient as each pointer is an
+// individual RPC call.
+template <typename Tuple, uint64_t... Is>
+RPC_ATTRS constexpr void prepare_args(rpc::Client::Port &port, Tuple &t,
+ rpc::index_sequence<Is...>) {
+ (prepare_arg<Is>(port, t), ...);
+}
+template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
+RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port,
+ rpc::index_sequence<Is...>) {
+ (prepare_arg<NUM_LANES, Tuple, Is>(port), ...);
+}
+
+// Performs the preparation in reverse, copying back any modified values.
+template <typename Tuple, uint64_t... Is>
+RPC_ATTRS constexpr void finish_args(rpc::Client::Port &port, Tuple &&t,
+ rpc::index_sequence<Is...>) {
+ (finish_arg<Is>(port, t), ...);
+}
+template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
+RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port,
+ Tuple (&t)[NUM_LANES],
+ rpc::index_sequence<Is...>) {
+ (finish_arg<NUM_LANES, Is>(port, t), ...);
+}
+} // namespace
+
+// Dispatch a function call to the server through the RPC mechanism. Copies the
+// argument list through the RPC interface.
+template <uint32_t OPCODE, typename FnTy, typename... CallArgs>
+RPC_ATTRS constexpr typename function_traits<FnTy>::return_type
+dispatch(rpc::Client &client, FnTy, CallArgs... args) {
+ using Traits = function_traits<FnTy>;
+ using RetTy = typename Traits::return_type;
+ using TupleTy = typename Traits::arg_types;
+ using Bytes = tuple_bytes<CallArgs...>;
+
+ static_assert(sizeof...(CallArgs) == Traits::ARITY,
+ "Argument count mismatch");
+ static_assert(((rpc::is_trivially_constructible_v<CallArgs> &&
+ rpc::is_trivially_copyable_v<CallArgs>) &&
+ ...),
+ "Must be a trivial type");
+
+ auto port = client.open<OPCODE>();
+
+ // Copy over any pointer arguments by walking the argument list.
+ TupleTy arg_tuple{rpc::forward<CallArgs>(args)...};
+ rpc::prepare_args(port, arg_tuple, rpc::make_index_sequence<Traits::ARITY>{});
+
+ // Compress the argument list to a binary stream and send it to the server.
+ auto bytes = Bytes::pack(arg_tuple);
+ port.send_n(&bytes);
+
+ // Copy back any potentially modified pointer arguments and the return value.
+ rpc::finish_args(port, TupleTy{rpc::forward<CallArgs>(args)...},
+ rpc::make_index_sequence<Traits::ARITY>{});
+
+ // Copy back the final function return value.
+ using BufferTy = rpc::conditional_t<rpc::is_void_v<RetTy>, uint8_t, RetTy>;
+ BufferTy ret{};
+ port.recv_n(&ret);
+ port.close();
+
+ if constexpr (!rpc::is_void_v<RetTy>)
+ return ret;
+}
+
+// Invoke a function on the server on behalf of the client. Receives the
+// arguments through the interface and forwards them to the function.
+template <uint32_t NUM_LANES, typename FnTy>
+RPC_ATTRS constexpr void invoke(FnTy fn, rpc::Server::Port &port) {
+ using Traits = function_traits<FnTy>;
+ using RetTy = typename Traits::return_type;
+ using TupleTy = typename Traits::arg_types;
+ using Bytes = tuple_bytes<TupleTy>;
+
+ // Receive pointer data and arguments from the client and pack them in
+ // server-side memory.
+ rpc::prepare_args<NUM_LANES, TupleTy>(
+ port, rpc::make_index_sequence<Traits::ARITY>{});
+
+ typename Bytes::array_type arg_buf[NUM_LANES]{};
+ port.recv_n(arg_buf);
+
+ // Convert the received arguments into an invocable argument list.
+ TupleTy args[NUM_LANES];
+ for (uint32_t id = 0; id < NUM_LANES; ++id)
+ args[id] = Bytes::unpack(arg_buf[id]);
+
+ // Execute the function with the provided arguments and send back any copies
+ // made for pointer data.
+ using BufferTy = rpc::conditional_t<rpc::is_void_v<RetTy>, uint8_t, RetTy>;
+ BufferTy rets[NUM_LANES]{};
+ for (uint32_t id = 0; id < NUM_LANES; ++id) {
+ if (port.get_lane_mask() & (uint64_t(1) << id)) {
+ if constexpr (rpc::is_void_v<RetTy>)
+ rpc::apply(fn, args[id]);
+ else
+ rets[id] = rpc::apply(fn, args[id]);
+ }
+ }
+
+ // Send any potentially modified pointer arguments back to the client.
+ rpc::finish_args<NUM_LANES>(port, args,
+ rpc::make_index_sequence<Traits::ARITY>{});
+
+ // Copy back the return value of the function if one exists. If the function
+ // is void we send an empty packet to force synchronous behavior.
+ port.send_n(rets);
+}
+} // namespace rpc
diff --git a/libc/shared/rpc_util.h b/libc/shared/rpc_util.h
index 687814b7ff2ae..4f08258b3a341 100644
--- a/libc/shared/rpc_util.h
+++ b/libc/shared/rpc_util.h
@@ -42,12 +42,65 @@ template <class T, T v> struct type_constant {
static inline constexpr T value = v;
};
+/// Freestanding type trait helpers.
+template <class T> struct remove_cv : type_identity<T> {};
+template <class T> struct remove_cv<const T> : type_identity<T> {};
+template <class T> using remove_cv_t = typename remove_cv<T>::type;
+
+template <class T> struct remove_pointer : type_identity<T> {};
+template <class T> struct remove_pointer<T *> : type_identity<T> {};
+template <class T> using remove_pointer_t = typename remove_pointer<T>::type;
+
+template <class T> struct remove_const : type_identity<T> {};
+template <class T> struct remove_const<const T> : type_identity<T> {};
+template <class T> using remove_const_t = typename remove_const<T>::type;
+
template <class T> struct remove_reference : type_identity<T> {};
template <class T> struct remove_reference<T &> : type_identity<T> {};
template <class T> struct remove_reference<T &&> : type_identity<T> {};
+template <class T>
+using remove_reference_t = typename remove_reference<T>::type;
template <class T> struct is_const : type_constant<bool, false> {};
template <class T> struct is_const<const T> : type_constant<bool, true> {};
+template <class T> RPC_ATTRS constexpr bool is_const_v = is_const<T>::value;
+
+template <typename T> struct is_pointer : type_constant<bool, false> {};
+template <typename T> struct is_pointer<T *> : type_constant<bool, true> {};
+template <typename T>
+struct is_pointer<T *const> : type_constant<bool, true> {};
+template <typename T>
+RPC_ATTRS constexpr bool is_pointer_v = is_pointer<T>::value;
+
+template <typename T, typename U>
+struct is_same : type_constant<bool, false> {};
+template <typename T> struct is_same<T, T> : type_constant<bool, true> {};
+template <typename T, typename U>
+RPC_ATTRS constexpr bool is_same_v = is_same<T, U>::value;
+
+template <class T> struct is_void : type_constant<bool, false> {};
+template <> struct is_void<void> : type_constant<bool, true> {};
+template <typename T> RPC_ATTRS constexpr bool is_void_v = is_void<T>::value;
+
+template <class T>
+struct is_trivially_copyable
+ : public type_constant<bool, __is_trivially_copyable(T)> {};
+template <class T>
+RPC_ATTRS constexpr bool is_trivially_copyable_v =
+ is_trivially_copyable<T>::value;
+
+template <class T, class... Args>
+struct is_trivially_constructible
+ : type_constant<bool, __is_trivially_constructible(T, Args...)> {};
+template <class T, class... Args>
+RPC_ATTRS constexpr bool is_trivially_constructible_v =
+ is_trivially_constructible<T>::value;
+
+template <bool B, class T, class F> struct conditional : type_identity<T> {};
+template <class T, class F>
+struct conditional<false, T, F> : type_identity<F> {};
+template <bool B, class T, class F>
+using conditional_t = typename conditional<B, T, F>::type;
/// Freestanding implementation of std::move.
template <class T>
@@ -55,6 +108,29 @@ RPC_ATTRS constexpr typename remove_reference<T>::type &&move(T &&t) {
return static_cast<typename remove_reference<T>::type &&>(t);
}
+/// Freestanding integer sequence.
+template <typename T, T... Ints> struct integer_sequence {
+ template <T Next> using append = integer_sequence<T, Ints..., Next>;
+};
+
+namespace detail {
+template <typename T, int N> struct make_integer_sequence {
+ using type =
+ typename make_integer_sequence<T, N - 1>::type::template append<N>;
+};
+template <typename T> struct make_integer_sequence<T, -1> {
+ using type = integer_sequence<T>;
+};
+} // namespace detail
+
+template <uint64_t... Ints>
+using index_sequence = integer_sequence<uint64_t, Ints...>;
+template <int N>
+using make_index_sequence =
+ typename detail::make_integer_sequence<uint64_t, N - 1>::type;
+template <typename... Ts>
+using index_sequence_for = make_index_sequence<sizeof...(Ts)>;
+
/// Freestanding implementation of std::forward.
template <typename T>
RPC_ATTRS constexpr T &&forward(typename remove_reference<T>::type &value) {
@@ -150,6 +226,84 @@ template <typename T> class optional {
RPC_ATTRS constexpr T &&operator*() && { return move(storage.stored_value); }
};
+/// Minimal array type.
+template <typename T, uint64_t N> struct array {
+ T elems[N];
+
+ RPC_ATTRS constexpr T *data() { return elems; }
+ RPC_ATTRS constexpr const T *data() const { return elems; }
+ RPC_ATTRS static constexpr uint64_t size() { return N; }
+
+ RPC_ATTRS constexpr T &operator[](uint64_t i) { return elems[i]; }
+ RPC_ATTRS constexpr const T &operator[](uint64_t i) const { return elems[i]; }
+};
+
+/// Minimal tuple type.
+template <typename... Ts> struct tuple;
+template <> struct tuple<> {};
+
+template <typename Head, typename... Tail>
+struct tuple<Head, Tail...> : tuple<Tail...> {
+ Head head;
+
+ RPC_ATTRS constexpr tuple() = default;
+
+ template <typename OHead, typename... OTail>
+ RPC_ATTRS constexpr tuple &operator=(const tuple<OHead, OTail...> &other) {
+ head = other.get_head();
+ this->get_tail() = other.get_tail();
+ return *this;
+ }
+
+ RPC_ATTRS constexpr tuple(const Head &h, const Tail &...t)
+ : tuple<Tail...>(t...), head(h) {}
+
+ RPC_ATTRS constexpr Head &get_head() { return head; }
+ RPC_ATTRS constexpr const Head &get_head() const { return head; }
+
+ RPC_ATTRS constexpr tuple<Tail...> &get_tail() { return *this; }
+ RPC_ATTRS constexpr const tuple<Tail...> &get_tail() const { return *this; }
+};
+
+template <size_t Idx, typename T> struct tuple_element;
+template <size_t Idx, typename Head, typename... Tail>
+struct tuple_element<Idx, tuple<Head, Tail...>>
+ : tuple_element<Idx - 1, tuple<Tail...>> {};
+template <typename Head, typename... Tail>
+struct tuple_element<0, tuple<Head, Tail...>> {
+ using type = remove_cv_t<remove_reference_t<Head>>;
+};
+template <size_t Idx, typename T>
+using tuple_element_t = typename tuple_element<Idx, T>::type;
+
+template <uint64_t Idx, typename Head, typename... Tail>
+RPC_ATTRS constexpr auto &get(tuple<Head, Tail...> &t) {
+ if constexpr (Idx == 0)
+ return t.get_head();
+ else
+ return get<Idx - 1>(t.get_tail());
+}
+template <uint64_t Idx, typename Head, typename... Tail>
+RPC_ATTRS constexpr const auto &get(const tuple<Head, Tail...> &t) {
+ if constexpr (Idx == 0)
+ return t.get_head();
+ else
+ return get<Idx - 1>(t.get_tail());
+}
+
+namespace detail {
+template <typename F, typename Tuple, uint64_t... Is>
+RPC_ATTRS auto apply(F &&f, Tuple &&t, index_sequence<Is...>) {
+ return f(get<Is>(static_cast<Tuple &&>(t))...);
+}
+} // namespace detail
+
+template <typename F, typename... Ts>
+RPC_ATTRS auto apply(F &&f, tuple<Ts...> &t) {
+ return detail::apply(static_cast<F &&>(f), t,
+ make_index_sequence<sizeof...(Ts)>{});
+}
+
/// Suspend the thread briefly to assist the thread scheduler during busy loops.
RPC_ATTRS void sleep_briefly() {
#if __has_builtin(__nvvm_reflect)
@@ -263,14 +417,26 @@ template <typename T, typename U> RPC_ATTRS T *advance(T *ptr, U bytes) {
}
/// Wrapper around the optimal memory copy implementation for the target.
-RPC_ATTRS void rpc_memcpy(void *dst, const void *src, size_t count) {
+RPC_ATTRS void rpc_memcpy(void *dst, const void *src, uint64_t count) {
__builtin_memcpy(dst, src, count);
}
-template <class T> RPC_ATTRS constexpr const T &max(const T &a, const T &b) {
+/// Minimal string length function.
+RPC_ATTRS constexpr uint64_t string_length(const char *s) {
+ const char *end = s;
+ for (; *end != '\0'; ++end)
+ ;
+ return static_cast<uint64_t>(end - s + 1);
+}
+
+template <class T, class U> RPC_ATTRS constexpr T max(const T &a, const U &b) {
return (a < b) ? b : a;
}
+template <class T, class U> RPC_ATTRS constexpr T min(const T &a, const U &b) {
+ return (a < b) ? a : b;
+}
+
} // namespace rpc
#endif // LLVM_LIBC_SHARED_RPC_UTIL_H
diff --git a/offload/test/libc/rpc_callback.c b/offload/test/libc/rpc_callback.c
deleted file mode 100644
index 223b54eddd81a..0000000000000
--- a/offload/test/libc/rpc_callback.c
+++ /dev/null
@@ -1,66 +0,0 @@
-// RUN: %libomptarget-compilexx-run-and-check-generic
-
-// REQUIRES: libc
-// REQUIRES: gpu
-
-#include <assert.h>
-#include <stdint.h>
-#include <stdio.h>
-
-// CHECK: PASS
-
-// This should be present in-tree relative to the test directory. If someone is
-// using a partial tree just pass the test.
-#if !__has_include(<../../libc/shared/rpc.h>)
-int main() { printf("PASS\n"); }
-#else
-#include <../../libc/shared/rpc.h>
-
-extern "C" void __tgt_register_rpc_callback(unsigned (*Callback)(void *,
- unsigned));
-constexpr uint32_t RPC_TEST_OPCODE = 1;
-
-template<uint32_t NumLanes> rpc::Status handleOpcodes(rpc::Server::Port &Port) {
- switch (Port.get_opcode()) {
- case RPC_TEST_OPCODE: {
- Port.recv(
- [&](rpc::Buffer *Buffer, uint32_t) { assert(Buffer->data[0] == 42); });
- Port.send([&](rpc::Buffer *, uint32_t) {});
- break;
- }
- default:
- return rpc::RPC_UNHANDLED_OPCODE;
- break;
- }
- return rpc::RPC_SUCCESS;
-}
-
-static uint32_t handleOffloadOpcodes(void *Raw, uint32_t NumLanes) {
- rpc::Server::Port &Port = *reinterpret_cast<rpc::Server::Port *>(Raw);
- if (NumLanes == 1)
- return handleOpcodes<1>(Port);
- else if (NumLanes == 32)
- return handleOpcodes<32>(Port);
- else if (NumLanes == 64)
- return handleOpcodes<64>(Port);
- else
- return rpc::RPC_ERROR;
-}
-
-[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client");
-#pragma omp declare target to(client) device_type(nohost)
-
-void __tgt_register_rpc_callback(unsigned (*Callback)(void *, unsigned));
-
-int main() {
- __tgt_register_rpc_callback(&handleOffloadOpcodes);
-#pragma omp target
- {
- rpc::Client::Port Port = client.open<RPC_TEST_OPCODE>();
- Port.send([=](rpc::Buffer *buffer, uint32_t) { buffer->data[0] = 42; });
- Port.recv([](rpc::Buffer *, uint32_t) {});
- Port.close();
- }
- printf("PASS\n");
-}
-#endif
diff --git a/offload/test/libc/rpc_callback.cpp b/offload/test/libc/rpc_callback.cpp
new file mode 100644
index 0000000000000..882aed47e4262
--- /dev/null
+++ b/offload/test/libc/rpc_callback.cpp
@@ -0,0 +1,171 @@
+// RUN: %libomptarget-compilexx-run-and-check-generic
+// REQUIRES: libc
+// REQUIRES: gpu
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+// CHECK: PASS
+
+// If the RPC headers are not present, just pass the test.
+#if !__has_include(<../../libc/shared/rpc.h>)
+int main() { printf("PASS\n"); }
+#else
+
+#include <../../libc/shared/rpc.h>
+#include <../../libc/shared/rpc_dispatch.h>
+
+[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client");
+#pragma omp declare target to(client) device_type(nohost)
+
+//===------------------------------------------------------------------------===
+// Opcodes.
+//===------------------------------------------------------------------------===
+
+constexpr uint32_t FOO_OPCODE = 1;
+constexpr uint32_t VOID_OPCODE = 2;
+constexpr uint32_t WRITEBACK_OPCODE = 3;
+constexpr uint32_t CONST_PTR_OPCODE = 4;
+constexpr uint32_t STRING_OPCODE = 5;
+
+//===------------------------------------------------------------------------===
+// Server-side implementations.
+//===------------------------------------------------------------------------===
+
+struct S {
+ int arr[4];
+};
+
+// 1. Non-pointer arguments, non-void return.
+int foo(int x, double d, char c) {
+ assert(x == 42);
+ assert(d == 0.0);
+ assert(c == 'c');
+ return -1;
+}
+
+// 2. Void return type.
+void void_fn(int x) { assert(x == 7); }
+
+// 3. Write-back pointer.
+void writeback_fn(int *out) {
+ assert(out != nullptr && *out == 42);
+ *out = 99;
+}
+
+// 4. Const pointer.
+int sum_const(const S *p) {
+ int s = 0;
+ for (int i = 0; i < 4; ++i)
+ s += p->arr[i];
+ return s;
+}
+
+// 5. const char * string.
+int c_string(const char *s) {
+ assert(s != nullptr);
+ assert(strcmp(s, "hello") == 0);
+ return strlen(s);
+}
+
+//===------------------------------------------------------------------------===
+// RPC client dispatch.
+//===------------------------------------------------------------------------===
+
+#pragma omp begin declare variant match(device = {kind(gpu)})
+int foo(int x, double d, char c) {
+ return rpc::dispatch<FOO_OPCODE>(client, foo, x, d, c);
+}
+
+void void_fn(int x) { rpc::dispatch<VOID_OPCODE>(client, void_fn, x); }
+
+void writeback_fn(int *out) {
+ rpc::dispatch<WRITEBACK_OPCODE>(client, writeback_fn, out);
+}
+
+int sum_const(const S *p) {
+ return rpc::dispatch<CONST_PTR_OPCODE>(client, sum_const, p);
+}
+
+int c_string(const char *s) {
+ return rpc::dispatch<STRING_OPCODE>(client, c_string, s);
+}
+#pragma omp end declare variant
+
+//===------------------------------------------------------------------------===
+// RPC server dispatch.
+//===------------------------------------------------------------------------===
+
+template <uint32_t NUM_LANES>
+rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) {
+ switch (Port.get_opcode()) {
+ case FOO_OPCODE:
+ rpc::invoke<NUM_LANES>(foo, Port);
+ break;
+ case VOID_OPCODE:
+ rpc::invoke<NUM_LANES>(void_fn, Port);
+ break;
+ case WRITEBACK_OPCODE:
+ rpc::invoke<NUM_LANES>(writeback_fn, Port);
+ break;
+ case CONST_PTR_OPCODE:
+ rpc::invoke<NUM_LANES>(sum_const, Port);
+ break;
+ case STRING_OPCODE:
+ rpc::invoke<NUM_LANES>(c_string, Port);
+ break;
+ default:
+ return rpc::RPC_UNHANDLED_OPCODE;
+ }
+ return rpc::RPC_SUCCESS;
+}
+
+static uint32_t handleOpcodes(void *raw, uint32_t numLanes) {
+ rpc::Server::Port &Port = *reinterpret_cast<rpc::Server::Port *>(raw);
+ if (numLanes == 1)
+ return handleOpcodesImpl<1>(Port);
+ else if (numLanes == 32)
+ return handleOpcodesImpl<32>(Port);
+ else if (numLanes == 64)
+ return handleOpcodesImpl<64>(Port);
+ else
+ return rpc::RPC_ERROR;
+}
+
+extern "C" void __tgt_register_rpc_callback(unsigned (*callback)(void *,
+ unsigned));
+
+int main() {
+ __tgt_register_rpc_callback(&handleOpcodes);
+
+#pragma omp target
+#pragma omp parallel num_threads(32)
+ {
+ // 1. Non-pointer return.
+ assert(foo(42, 0.0, 'c') == -1);
+
+ // 2. Void return.
+ void_fn(7);
+
+ // 3. Write-back pointer.
+ int value = 42;
+ writeback_fn(&value);
+ assert(value == 99);
+
+ // 4. Const pointer.
+ S s{1, 2, 3, 4};
+ int sum = sum_const(&s);
+ assert(sum == 10);
+
+ // 5. const char * string.
+ const char *msg = "hello";
+ int len = c_string(msg);
+ assert(len == 5);
+ }
+
+ printf("PASS\n");
+}
+
+#endif
More information about the libc-commits
mailing list