[libc-commits] [libc] 5f5d27d - [libc] Support array tags in the RPC dispatch helpers (#181395)
via libc-commits
libc-commits at lists.llvm.org
Fri Feb 20 07:35:52 PST 2026
Author: Joseph Huber
Date: 2026-02-20T09:35:47-06:00
New Revision: 5f5d27d7d32e9d85fbdabc0bc2632ce0ab485681
URL: https://github.com/llvm/llvm-project/commit/5f5d27d7d32e9d85fbdabc0bc2632ce0ab485681
DIFF: https://github.com/llvm/llvm-project/commit/5f5d27d7d32e9d85fbdabc0bc2632ce0ab485681.diff
LOG: [libc] Support array tags in the RPC dispatch helpers (#181395)
Summary:
This PR adds support for tagging a pointer argument as an array (by wrapping
it in `rpc::span`) so that the whole array, rather than a single element, is
marshaled between the CPU and GPU.
Added:
Modified:
libc/shared/rpc_dispatch.h
libc/shared/rpc_util.h
offload/test/libc/rpc_callback.cpp
Removed:
################################################################################
diff --git a/libc/shared/rpc_dispatch.h b/libc/shared/rpc_dispatch.h
index d9ea5fe661b49..c9c2502e557bc 100644
--- a/libc/shared/rpc_dispatch.h
+++ b/libc/shared/rpc_dispatch.h
@@ -54,47 +54,73 @@ template <typename... Ts> struct tuple_bytes {
template <typename... Ts>
struct tuple_bytes<rpc::tuple<Ts...>> : tuple_bytes<Ts...> {};
+// Whether a pointer value will be marshalled between the client and server.
+template <typename Ty, typename PtrTy = rpc::remove_reference_t<Ty>,
+ typename ElemTy = rpc::remove_pointer_t<PtrTy>>
+RPC_ATTRS constexpr bool is_marshalled_ptr_v =
+ rpc::is_pointer_v<PtrTy> && rpc::is_complete_v<ElemTy> &&
+ !rpc::is_void_v<ElemTy>;
+
+// Get an index value for the marshalled types in the tuple type.
+template <class Tuple, uint64_t... Is>
+constexpr uint64_t marshalled_index(rpc::index_sequence<Is...>) {
+ return (0u + ... +
+ (rpc::is_marshalled_ptr_v<rpc::tuple_element_t<Is, Tuple>>));
+}
+template <class Tuple, uint64_t I>
+constexpr uint64_t marshalled_index_v =
+ marshalled_index<Tuple>(rpc::make_index_sequence<I>{});
+
+// Storage for the marshalled arguments from the client.
+template <uint32_t NUM_LANES, typename... Ts> struct MarshalledState {
+ static constexpr uint32_t NUM_PTRS =
+ rpc::marshalled_index_v<rpc::tuple<Ts...>, sizeof...(Ts)>;
+
+ uint64_t sizes[NUM_PTRS][NUM_LANES]{};
+ void *ptrs[NUM_PTRS][NUM_LANES]{};
+};
+template <uint32_t NUM_LANES, typename... Ts>
+struct MarshalledState<NUM_LANES, rpc::tuple<Ts...>>
+ : MarshalledState<NUM_LANES, Ts...> {};
+
// Client-side dispatch of pointer values. We copy the memory associated with
// the pointer to the server and receive back a server-side pointer to replace
// the client-side pointer in the argument list.
-template <uint64_t Idx, typename Tuple>
-RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t) {
+template <uint64_t Idx, typename Tuple, typename CallTuple>
+RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t,
+ CallTuple &ct) {
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
- using ElemTy = rpc::remove_pointer_t<ArgTy>;
- if constexpr (rpc::is_pointer_v<ArgTy> && rpc::is_complete_v<ElemTy> &&
- !rpc::is_void_v<ElemTy>) {
+ using CallArgTy = rpc::tuple_element_t<Idx, CallTuple>;
+ if constexpr (rpc::is_marshalled_ptr_v<ArgTy>) {
// We assume all constant character arrays are C-strings.
uint64_t size{};
if constexpr (rpc::is_same_v<ArgTy, const char *>)
size = rpc::string_length(rpc::get<Idx>(t));
+ else if constexpr (rpc::is_span_v<CallArgTy>)
+ size = rpc::get<Idx>(ct).size * sizeof(rpc::remove_pointer_t<ArgTy>);
else
size = sizeof(rpc::remove_pointer_t<ArgTy>);
port.send_n(rpc::get<Idx>(t), size);
port.recv([&](rpc::Buffer *buffer, uint32_t) {
- rpc::get<Idx>(t) = *reinterpret_cast<ArgTy *>(buffer->data);
+ ArgTy val;
+ rpc::rpc_memcpy(&val, buffer->data, sizeof(ArgTy));
+ rpc::get<Idx>(t) = val;
});
}
}
// Server-side handling of pointer arguments. We receive the memory into a
// temporary buffer and pass a pointer to this new memory back to the client.
-template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx>
-RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) {
+template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx, typename State>
+RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port, State &&state) {
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
- using ElemTy = rpc::remove_pointer_t<ArgTy>;
- if constexpr (rpc::is_pointer_v<ArgTy> && rpc::is_complete_v<ElemTy> &&
- !rpc::is_void_v<ElemTy>) {
- void *args[NUM_LANES]{};
- uint64_t sizes[NUM_LANES]{};
- port.recv_n(args, sizes, [](uint64_t size) {
- if constexpr (rpc::is_same_v<ArgTy, const char *>)
- return malloc(size);
- else
- return malloc(
- sizeof(rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>>));
- });
+ if constexpr (rpc::is_marshalled_ptr_v<ArgTy>) {
+ auto &ptrs = state.ptrs[rpc::marshalled_index_v<Tuple, Idx>];
+ auto &sizes = state.sizes[rpc::marshalled_index_v<Tuple, Idx>];
+ port.recv_n(ptrs, sizes, [](uint64_t size) { return malloc(size); });
port.send([&](rpc::Buffer *buffer, uint32_t id) {
- *reinterpret_cast<ArgTy *>(buffer->data) = static_cast<ArgTy>(args[id]);
+ ArgTy val = static_cast<ArgTy>(ptrs[id]);
+ rpc::rpc_memcpy(buffer->data, &val, sizeof(ArgTy));
});
}
}
@@ -105,14 +131,11 @@ RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) {
template <uint64_t Idx, typename Tuple>
RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) {
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
- using ElemTy = rpc::remove_pointer_t<ArgTy>;
- using MemoryTy = rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>> *;
- if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy> &&
- rpc::is_complete_v<ElemTy> && !rpc::is_void_v<ElemTy>) {
+ if constexpr (rpc::is_marshalled_ptr_v<ArgTy> && !rpc::is_const_v<ArgTy>) {
uint64_t size{};
void *buf{};
port.recv_n(&buf, &size, [&](uint64_t) {
- return const_cast<MemoryTy>(rpc::get<Idx>(t));
+ return const_cast<void *>(static_cast<const void *>(rpc::get<Idx>(t)));
});
}
}
@@ -120,30 +143,20 @@ RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) {
// Server-side finalization of pointer arguments. We copy any potential
// modifications to the value back to the client unless it was a constant. We
// can also free the associated memory.
-template <uint32_t NUM_LANES, uint64_t Idx, typename Tuple>
-RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port,
- Tuple (&t)[NUM_LANES]) {
+template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx, typename State>
+RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port, State &&state) {
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
- using ElemTy = rpc::remove_pointer_t<ArgTy>;
- if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy> &&
- rpc::is_complete_v<ElemTy> && !rpc::is_void_v<ElemTy>) {
- const void *buffer[NUM_LANES]{};
- size_t sizes[NUM_LANES]{};
- for (uint32_t id = 0; id < NUM_LANES; ++id) {
- if (port.get_lane_mask() & (uint64_t(1) << id)) {
- buffer[id] = rpc::get<Idx>(t[id]);
- sizes[id] = sizeof(rpc::remove_pointer_t<ArgTy>);
- }
- }
- port.send_n(buffer, sizes);
+ if constexpr (rpc::is_marshalled_ptr_v<ArgTy> && !rpc::is_const_v<ArgTy>) {
+ auto &ptrs = state.ptrs[rpc::marshalled_index_v<Tuple, Idx>];
+ auto &sizes = state.sizes[rpc::marshalled_index_v<Tuple, Idx>];
+ port.send_n(ptrs, sizes);
}
- if constexpr (rpc::is_pointer_v<ArgTy> && rpc::is_complete_v<ElemTy> &&
- !rpc::is_void_v<ElemTy>) {
+ if constexpr (rpc::is_marshalled_ptr_v<ArgTy>) {
+ auto &ptrs = state.ptrs[rpc::marshalled_index_v<Tuple, Idx>];
for (uint32_t id = 0; id < NUM_LANES; ++id) {
if (port.get_lane_mask() & (uint64_t(1) << id))
- free(const_cast<void *>(
- static_cast<const void *>(rpc::get<Idx>(t[id]))));
+ free(const_cast<void *>(static_cast<const void *>(ptrs[id])));
}
}
}
@@ -151,15 +164,16 @@ RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port,
// Iterate over the tuple list of arguments to see if we need to forward any.
// The current forwarding is somewhat inefficient as each pointer is an
// individual RPC call.
-template <typename Tuple, uint64_t... Is>
+template <typename Tuple, typename CallTuple, uint64_t... Is>
RPC_ATTRS constexpr void prepare_args(rpc::Client::Port &port, Tuple &t,
+ CallTuple &ct,
rpc::index_sequence<Is...>) {
- (prepare_arg<Is>(port, t), ...);
+ (prepare_arg<Is>(port, t, ct), ...);
}
-template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
-RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port,
+template <uint32_t NUM_LANES, typename Tuple, typename State, uint64_t... Is>
+RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port, State &&state,
rpc::index_sequence<Is...>) {
- (prepare_arg<NUM_LANES, Tuple, Is>(port), ...);
+ (prepare_arg<NUM_LANES, Tuple, Is>(port, state), ...);
}
// Performs the preparation in reverse, copying back any modified values.
@@ -168,11 +182,10 @@ RPC_ATTRS constexpr void finish_args(rpc::Client::Port &port, Tuple &&t,
rpc::index_sequence<Is...>) {
(finish_arg<Is>(port, t), ...);
}
-template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
-RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port,
- Tuple (&t)[NUM_LANES],
+template <uint32_t NUM_LANES, typename Tuple, typename State, uint64_t... Is>
+RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port, State &&state,
rpc::index_sequence<Is...>) {
- (finish_arg<NUM_LANES, Is>(port, t), ...);
+ (finish_arg<NUM_LANES, Tuple, Is>(port, state), ...);
}
} // namespace
@@ -184,7 +197,7 @@ dispatch(rpc::Client &client, FnTy, CallArgs... args) {
using Traits = function_traits<FnTy>;
using RetTy = typename Traits::return_type;
using TupleTy = typename Traits::arg_types;
- using Bytes = tuple_bytes<CallArgs...>;
+ using Bytes = tuple_bytes<rpc::remove_span_t<CallArgs>...>;
static_assert(sizeof...(CallArgs) == Traits::ARITY,
"Argument count mismatch");
@@ -196,8 +209,10 @@ dispatch(rpc::Client &client, FnTy, CallArgs... args) {
auto port = client.open<OPCODE>();
// Copy over any pointer arguments by walking the argument list.
+ rpc::tuple<CallArgs...> call_args{args...};
TupleTy arg_tuple{rpc::forward<CallArgs>(args)...};
- rpc::prepare_args(port, arg_tuple, rpc::make_index_sequence<Traits::ARITY>{});
+ rpc::prepare_args(port, arg_tuple, call_args,
+ rpc::make_index_sequence<Traits::ARITY>{});
// Compress the argument list to a binary stream and send it to the server.
auto bytes = Bytes::pack(arg_tuple);
@@ -226,8 +241,9 @@ RPC_ATTRS constexpr void invoke(rpc::Server::Port &port, FnTy fn) {
using Bytes = tuple_bytes<TupleTy>;
// Receive pointer data from the host and pack it in server-side memory.
+ MarshalledState<NUM_LANES, TupleTy> state{};
rpc::prepare_args<NUM_LANES, TupleTy>(
- port, rpc::make_index_sequence<Traits::ARITY>{});
+ port, state, rpc::make_index_sequence<Traits::ARITY>{});
// Get the argument list from the client.
typename Bytes::array_type arg_buf[NUM_LANES]{};
@@ -254,8 +270,8 @@ RPC_ATTRS constexpr void invoke(rpc::Server::Port &port, FnTy fn) {
}
// Send any potentially modified pointer arguments back to the client.
- rpc::finish_args<NUM_LANES>(port, args,
- rpc::make_index_sequence<Traits::ARITY>{});
+ rpc::finish_args<NUM_LANES, TupleTy>(
+ port, state, rpc::make_index_sequence<Traits::ARITY>{});
// Copy back the return value of the function if one exists. If the function
// is void we send an empty packet to force synchronous behavior.
diff --git a/libc/shared/rpc_util.h b/libc/shared/rpc_util.h
index 44eb475919947..03ec0fa5030e7 100644
--- a/libc/shared/rpc_util.h
+++ b/libc/shared/rpc_util.h
@@ -63,24 +63,23 @@ using remove_reference_t = typename remove_reference<T>::type;
template <typename T> struct is_const : type_constant<bool, false> {};
template <typename T> struct is_const<const T> : type_constant<bool, true> {};
-template <typename T> RPC_ATTRS constexpr bool is_const_v = is_const<T>::value;
+template <typename T> inline constexpr bool is_const_v = is_const<T>::value;
template <typename T> struct is_pointer : type_constant<bool, false> {};
template <typename T> struct is_pointer<T *> : type_constant<bool, true> {};
template <typename T>
struct is_pointer<T *const> : type_constant<bool, true> {};
-template <typename T>
-RPC_ATTRS constexpr bool is_pointer_v = is_pointer<T>::value;
+template <typename T> inline constexpr bool is_pointer_v = is_pointer<T>::value;
template <typename T, typename U>
struct is_same : type_constant<bool, false> {};
template <typename T> struct is_same<T, T> : type_constant<bool, true> {};
template <typename T, typename U>
-RPC_ATTRS constexpr bool is_same_v = is_same<T, U>::value;
+inline constexpr bool is_same_v = is_same<T, U>::value;
template <typename T> struct is_void : type_constant<bool, false> {};
template <> struct is_void<void> : type_constant<bool, true> {};
-template <typename T> RPC_ATTRS constexpr bool is_void_v = is_void<T>::value;
+template <typename T> inline constexpr bool is_void_v = is_void<T>::value;
// Scary trait that can change within a TU, use with caution.
template <typename...> using void_t = void;
@@ -90,22 +89,36 @@ template <typename T>
struct is_complete<T, void_t<decltype(sizeof(T))>> : type_constant<bool, true> {
};
template <typename T>
-RPC_ATTRS constexpr bool is_complete_v = is_complete<T>::value;
+inline constexpr bool is_complete_v = is_complete<T>::value;
template <typename T>
struct is_trivially_copyable
: public type_constant<bool, __is_trivially_copyable(T)> {};
template <typename T>
-RPC_ATTRS constexpr bool is_trivially_copyable_v =
- is_trivially_copyable<T>::value;
+inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
template <typename T, typename... Args>
struct is_trivially_constructible
: type_constant<bool, __is_trivially_constructible(T, Args...)> {};
template <typename T, typename... Args>
-RPC_ATTRS constexpr bool is_trivially_constructible_v =
+inline constexpr bool is_trivially_constructible_v =
is_trivially_constructible<T>::value;
+/// Tag type to indicate an array of elements being passed through RPC.
+template <typename T> struct span {
+ T *data;
+ uint64_t size;
+ RPC_ATTRS operator T *() const { return data; }
+};
+
+template <typename T> struct is_span : type_constant<bool, false> {};
+template <typename T> struct is_span<span<T>> : type_constant<bool, true> {};
+template <typename T> inline constexpr bool is_span_v = is_span<T>::value;
+
+template <typename T> struct remove_span : type_identity<T> {};
+template <typename T> struct remove_span<span<T>> : type_identity<T *> {};
+template <typename T> using remove_span_t = typename remove_span<T>::type;
+
template <bool B, typename T, typename F>
struct conditional : type_identity<T> {};
template <typename T, typename F>
@@ -449,6 +462,12 @@ template <typename R, typename... Args> struct function_traits<R (*)(Args...)> {
using arg_types = rpc::tuple<Args...>;
static constexpr uint64_t ARITY = sizeof...(Args);
};
+template <typename R, typename... Args>
+struct function_traits<R (*)(Args...) noexcept> {
+ using return_type = R;
+ using arg_types = rpc::tuple<Args...>;
+ static constexpr uint64_t ARITY = sizeof...(Args);
+};
template <typename T> T &&declval();
template <typename T>
struct function_traits
diff --git a/offload/test/libc/rpc_callback.cpp b/offload/test/libc/rpc_callback.cpp
index f313541e58d1d..75a942385082b 100644
--- a/offload/test/libc/rpc_callback.cpp
+++ b/offload/test/libc/rpc_callback.cpp
@@ -1,4 +1,5 @@
-// RUN: %libomptarget-compilexx-run-and-check-generic
+// RUN: %libomptarget-compilexx-generic -fopenmp-cuda-mode
+// RUN: %libomptarget-run-generic
// REQUIRES: libc
// REQUIRES: gpu
@@ -32,6 +33,7 @@ constexpr uint32_t CONST_PTR_OPCODE = 4;
constexpr uint32_t STRING_OPCODE = 5;
constexpr uint32_t EMPTY_OPCODE = 6;
constexpr uint32_t DIVERGENT_OPCODE = 7;
+constexpr uint32_t ARRAY_SUM_OPCODE = 8;
//===------------------------------------------------------------------------===
// Server-side implementations.
@@ -82,6 +84,14 @@ void divergent(int *p) {
*p = *p;
}
+// 8. Array argument via span.
+int sum_array(const int *arr, int n) {
+ int s = 0;
+ for (int i = 0; i < n; ++i)
+ s += arr[i];
+ return s;
+}
+
//===------------------------------------------------------------------------===
// RPC client dispatch.
//===------------------------------------------------------------------------===
@@ -110,6 +120,11 @@ int empty() { return rpc::dispatch<EMPTY_OPCODE>(client, empty); }
void divergent(int *p) {
rpc::dispatch<DIVERGENT_OPCODE>(client, divergent, p);
}
+
+int sum_array(const int *arr, int n) {
+ return rpc::dispatch<ARRAY_SUM_OPCODE>(
+ client, sum_array, rpc::span<const int>{arr, uint64_t(n)}, n);
+}
#pragma omp end declare variant
//===------------------------------------------------------------------------===
@@ -143,6 +158,9 @@ rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) {
*p = *p;
});
break;
+ case ARRAY_SUM_OPCODE:
+ rpc::invoke<NUM_LANES>(Port, sum_array);
+ break;
default:
return rpc::RPC_UNHANDLED_OPCODE;
}
@@ -203,6 +221,11 @@ int main() {
if (id % 2)
divergent(&id);
assert(id == omp_get_thread_num());
+
+ // 8. Array sum via span.
+ int arr[4] = {1, 2, 3, 4};
+ int total = sum_array(arr, 4);
+ assert(total == 10);
}
printf("PASS\n");
More information about the libc-commits mailing list