[libc-commits] [libc] 5f5d27d - [libc] Support array tags in the RPC dispatch helpers (#181395)

via libc-commits libc-commits at lists.llvm.org
Fri Feb 20 07:35:52 PST 2026


Author: Joseph Huber
Date: 2026-02-20T09:35:47-06:00
New Revision: 5f5d27d7d32e9d85fbdabc0bc2632ce0ab485681

URL: https://github.com/llvm/llvm-project/commit/5f5d27d7d32e9d85fbdabc0bc2632ce0ab485681
DIFF: https://github.com/llvm/llvm-project/commit/5f5d27d7d32e9d85fbdabc0bc2632ce0ab485681.diff

LOG: [libc] Support array tags in the RPC dispatch helpers (#181395)

Summary:
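This PR adds support for tagging a pointer as an array when marshaling
between the CPU and GPU. Wrapping a pointer argument in the new
`rpc::span<T>` tag (a data pointer plus an element count) makes the client
dispatch helper copy `size * sizeof(T)` bytes instead of a single element.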

Added: 
    

Modified: 
    libc/shared/rpc_dispatch.h
    libc/shared/rpc_util.h
    offload/test/libc/rpc_callback.cpp

Removed: 
    


################################################################################
diff  --git a/libc/shared/rpc_dispatch.h b/libc/shared/rpc_dispatch.h
index d9ea5fe661b49..c9c2502e557bc 100644
--- a/libc/shared/rpc_dispatch.h
+++ b/libc/shared/rpc_dispatch.h
@@ -54,47 +54,73 @@ template <typename... Ts> struct tuple_bytes {
 template <typename... Ts>
 struct tuple_bytes<rpc::tuple<Ts...>> : tuple_bytes<Ts...> {};
 
+// Whether a pointer value will be marshalled between the client and server.
+template <typename Ty, typename PtrTy = rpc::remove_reference_t<Ty>,
+          typename ElemTy = rpc::remove_pointer_t<PtrTy>>
+RPC_ATTRS constexpr bool is_marshalled_ptr_v =
+    rpc::is_pointer_v<PtrTy> && rpc::is_complete_v<ElemTy> &&
+    !rpc::is_void_v<ElemTy>;
+
+// Get an index value for the marshalled types in the tuple type.
+template <class Tuple, uint64_t... Is>
+constexpr uint64_t marshalled_index(rpc::index_sequence<Is...>) {
+  return (0u + ... +
+          (rpc::is_marshalled_ptr_v<rpc::tuple_element_t<Is, Tuple>>));
+}
+template <class Tuple, uint64_t I>
+constexpr uint64_t marshalled_index_v =
+    marshalled_index<Tuple>(rpc::make_index_sequence<I>{});
+
+// Storage for the marshalled arguments from the client.
+template <uint32_t NUM_LANES, typename... Ts> struct MarshalledState {
+  static constexpr uint32_t NUM_PTRS =
+      rpc::marshalled_index_v<rpc::tuple<Ts...>, sizeof...(Ts)>;
+
+  uint64_t sizes[NUM_PTRS][NUM_LANES]{};
+  void *ptrs[NUM_PTRS][NUM_LANES]{};
+};
+template <uint32_t NUM_LANES, typename... Ts>
+struct MarshalledState<NUM_LANES, rpc::tuple<Ts...>>
+    : MarshalledState<NUM_LANES, Ts...> {};
+
 // Client-side dispatch of pointer values. We copy the memory associated with
 // the pointer to the server and receive back a server-side pointer to replace
 // the client-side pointer in the argument list.
-template <uint64_t Idx, typename Tuple>
-RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t) {
+template <uint64_t Idx, typename Tuple, typename CallTuple>
+RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t,
+                                     CallTuple &ct) {
   using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
-  using ElemTy = rpc::remove_pointer_t<ArgTy>;
-  if constexpr (rpc::is_pointer_v<ArgTy> && rpc::is_complete_v<ElemTy> &&
-                !rpc::is_void_v<ElemTy>) {
+  using CallArgTy = rpc::tuple_element_t<Idx, CallTuple>;
+  if constexpr (rpc::is_marshalled_ptr_v<ArgTy>) {
     // We assume all constant character arrays are C-strings.
     uint64_t size{};
     if constexpr (rpc::is_same_v<ArgTy, const char *>)
       size = rpc::string_length(rpc::get<Idx>(t));
+    else if constexpr (rpc::is_span_v<CallArgTy>)
+      size = rpc::get<Idx>(ct).size * sizeof(rpc::remove_pointer_t<ArgTy>);
     else
       size = sizeof(rpc::remove_pointer_t<ArgTy>);
     port.send_n(rpc::get<Idx>(t), size);
     port.recv([&](rpc::Buffer *buffer, uint32_t) {
-      rpc::get<Idx>(t) = *reinterpret_cast<ArgTy *>(buffer->data);
+      ArgTy val;
+      rpc::rpc_memcpy(&val, buffer->data, sizeof(ArgTy));
+      rpc::get<Idx>(t) = val;
     });
   }
 }
 
 // Server-side handling of pointer arguments. We receive the memory into a
 // temporary buffer and pass a pointer to this new memory back to the client.
-template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx>
-RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) {
+template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx, typename State>
+RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port, State &&state) {
   using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
-  using ElemTy = rpc::remove_pointer_t<ArgTy>;
-  if constexpr (rpc::is_pointer_v<ArgTy> && rpc::is_complete_v<ElemTy> &&
-                !rpc::is_void_v<ElemTy>) {
-    void *args[NUM_LANES]{};
-    uint64_t sizes[NUM_LANES]{};
-    port.recv_n(args, sizes, [](uint64_t size) {
-      if constexpr (rpc::is_same_v<ArgTy, const char *>)
-        return malloc(size);
-      else
-        return malloc(
-            sizeof(rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>>));
-    });
+  if constexpr (rpc::is_marshalled_ptr_v<ArgTy>) {
+    auto &ptrs = state.ptrs[rpc::marshalled_index_v<Tuple, Idx>];
+    auto &sizes = state.sizes[rpc::marshalled_index_v<Tuple, Idx>];
+    port.recv_n(ptrs, sizes, [](uint64_t size) { return malloc(size); });
     port.send([&](rpc::Buffer *buffer, uint32_t id) {
-      *reinterpret_cast<ArgTy *>(buffer->data) = static_cast<ArgTy>(args[id]);
+      ArgTy val = static_cast<ArgTy>(ptrs[id]);
+      rpc::rpc_memcpy(buffer->data, &val, sizeof(ArgTy));
     });
   }
 }
@@ -105,14 +131,11 @@ RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) {
 template <uint64_t Idx, typename Tuple>
 RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) {
   using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
-  using ElemTy = rpc::remove_pointer_t<ArgTy>;
-  using MemoryTy = rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>> *;
-  if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy> &&
-                rpc::is_complete_v<ElemTy> && !rpc::is_void_v<ElemTy>) {
+  if constexpr (rpc::is_marshalled_ptr_v<ArgTy> && !rpc::is_const_v<ArgTy>) {
     uint64_t size{};
     void *buf{};
     port.recv_n(&buf, &size, [&](uint64_t) {
-      return const_cast<MemoryTy>(rpc::get<Idx>(t));
+      return const_cast<void *>(static_cast<const void *>(rpc::get<Idx>(t)));
     });
   }
 }
@@ -120,30 +143,20 @@ RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) {
 // Server-side finalization of pointer arguments. We copy any potential
 // modifications to the value back to the client unless it was a constant. We
 // can also free the associated memory.
-template <uint32_t NUM_LANES, uint64_t Idx, typename Tuple>
-RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port,
-                                    Tuple (&t)[NUM_LANES]) {
+template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx, typename State>
+RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port, State &&state) {
   using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
-  using ElemTy = rpc::remove_pointer_t<ArgTy>;
-  if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy> &&
-                rpc::is_complete_v<ElemTy> && !rpc::is_void_v<ElemTy>) {
-    const void *buffer[NUM_LANES]{};
-    size_t sizes[NUM_LANES]{};
-    for (uint32_t id = 0; id < NUM_LANES; ++id) {
-      if (port.get_lane_mask() & (uint64_t(1) << id)) {
-        buffer[id] = rpc::get<Idx>(t[id]);
-        sizes[id] = sizeof(rpc::remove_pointer_t<ArgTy>);
-      }
-    }
-    port.send_n(buffer, sizes);
+  if constexpr (rpc::is_marshalled_ptr_v<ArgTy> && !rpc::is_const_v<ArgTy>) {
+    auto &ptrs = state.ptrs[rpc::marshalled_index_v<Tuple, Idx>];
+    auto &sizes = state.sizes[rpc::marshalled_index_v<Tuple, Idx>];
+    port.send_n(ptrs, sizes);
   }
 
-  if constexpr (rpc::is_pointer_v<ArgTy> && rpc::is_complete_v<ElemTy> &&
-                !rpc::is_void_v<ElemTy>) {
+  if constexpr (rpc::is_marshalled_ptr_v<ArgTy>) {
+    auto &ptrs = state.ptrs[rpc::marshalled_index_v<Tuple, Idx>];
     for (uint32_t id = 0; id < NUM_LANES; ++id) {
       if (port.get_lane_mask() & (uint64_t(1) << id))
-        free(const_cast<void *>(
-            static_cast<const void *>(rpc::get<Idx>(t[id]))));
+        free(const_cast<void *>(static_cast<const void *>(ptrs[id])));
     }
   }
 }
@@ -151,15 +164,16 @@ RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port,
 // Iterate over the tuple list of arguments to see if we need to forward any.
 // The current forwarding is somewhat inefficient as each pointer is an
 // individual RPC call.
-template <typename Tuple, uint64_t... Is>
+template <typename Tuple, typename CallTuple, uint64_t... Is>
 RPC_ATTRS constexpr void prepare_args(rpc::Client::Port &port, Tuple &t,
+                                      CallTuple &ct,
                                       rpc::index_sequence<Is...>) {
-  (prepare_arg<Is>(port, t), ...);
+  (prepare_arg<Is>(port, t, ct), ...);
 }
-template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
-RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port,
+template <uint32_t NUM_LANES, typename Tuple, typename State, uint64_t... Is>
+RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port, State &&state,
                                       rpc::index_sequence<Is...>) {
-  (prepare_arg<NUM_LANES, Tuple, Is>(port), ...);
+  (prepare_arg<NUM_LANES, Tuple, Is>(port, state), ...);
 }
 
 // Performs the preparation in reverse, copying back any modified values.
@@ -168,11 +182,10 @@ RPC_ATTRS constexpr void finish_args(rpc::Client::Port &port, Tuple &&t,
                                      rpc::index_sequence<Is...>) {
   (finish_arg<Is>(port, t), ...);
 }
-template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
-RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port,
-                                     Tuple (&t)[NUM_LANES],
+template <uint32_t NUM_LANES, typename Tuple, typename State, uint64_t... Is>
+RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port, State &&state,
                                      rpc::index_sequence<Is...>) {
-  (finish_arg<NUM_LANES, Is>(port, t), ...);
+  (finish_arg<NUM_LANES, Tuple, Is>(port, state), ...);
 }
 } // namespace
 
@@ -184,7 +197,7 @@ dispatch(rpc::Client &client, FnTy, CallArgs... args) {
   using Traits = function_traits<FnTy>;
   using RetTy = typename Traits::return_type;
   using TupleTy = typename Traits::arg_types;
-  using Bytes = tuple_bytes<CallArgs...>;
+  using Bytes = tuple_bytes<rpc::remove_span_t<CallArgs>...>;
 
   static_assert(sizeof...(CallArgs) == Traits::ARITY,
                 "Argument count mismatch");
@@ -196,8 +209,10 @@ dispatch(rpc::Client &client, FnTy, CallArgs... args) {
   auto port = client.open<OPCODE>();
 
   // Copy over any pointer arguments by walking the argument list.
+  rpc::tuple<CallArgs...> call_args{args...};
   TupleTy arg_tuple{rpc::forward<CallArgs>(args)...};
-  rpc::prepare_args(port, arg_tuple, rpc::make_index_sequence<Traits::ARITY>{});
+  rpc::prepare_args(port, arg_tuple, call_args,
+                    rpc::make_index_sequence<Traits::ARITY>{});
 
   // Compress the argument list to a binary stream and send it to the server.
   auto bytes = Bytes::pack(arg_tuple);
@@ -226,8 +241,9 @@ RPC_ATTRS constexpr void invoke(rpc::Server::Port &port, FnTy fn) {
   using Bytes = tuple_bytes<TupleTy>;
 
   // Receive pointer data from the host and pack it in server-side memory.
+  MarshalledState<NUM_LANES, TupleTy> state{};
   rpc::prepare_args<NUM_LANES, TupleTy>(
-      port, rpc::make_index_sequence<Traits::ARITY>{});
+      port, state, rpc::make_index_sequence<Traits::ARITY>{});
 
   // Get the argument list from the client.
   typename Bytes::array_type arg_buf[NUM_LANES]{};
@@ -254,8 +270,8 @@ RPC_ATTRS constexpr void invoke(rpc::Server::Port &port, FnTy fn) {
   }
 
   // Send any potentially modified pointer arguments back to the client.
-  rpc::finish_args<NUM_LANES>(port, args,
-                              rpc::make_index_sequence<Traits::ARITY>{});
+  rpc::finish_args<NUM_LANES, TupleTy>(
+      port, state, rpc::make_index_sequence<Traits::ARITY>{});
 
   // Copy back the return value of the function if one exists. If the function
   // is void we send an empty packet to force synchronous behavior.

diff  --git a/libc/shared/rpc_util.h b/libc/shared/rpc_util.h
index 44eb475919947..03ec0fa5030e7 100644
--- a/libc/shared/rpc_util.h
+++ b/libc/shared/rpc_util.h
@@ -63,24 +63,23 @@ using remove_reference_t = typename remove_reference<T>::type;
 
 template <typename T> struct is_const : type_constant<bool, false> {};
 template <typename T> struct is_const<const T> : type_constant<bool, true> {};
-template <typename T> RPC_ATTRS constexpr bool is_const_v = is_const<T>::value;
+template <typename T> inline constexpr bool is_const_v = is_const<T>::value;
 
 template <typename T> struct is_pointer : type_constant<bool, false> {};
 template <typename T> struct is_pointer<T *> : type_constant<bool, true> {};
 template <typename T>
 struct is_pointer<T *const> : type_constant<bool, true> {};
-template <typename T>
-RPC_ATTRS constexpr bool is_pointer_v = is_pointer<T>::value;
+template <typename T> inline constexpr bool is_pointer_v = is_pointer<T>::value;
 
 template <typename T, typename U>
 struct is_same : type_constant<bool, false> {};
 template <typename T> struct is_same<T, T> : type_constant<bool, true> {};
 template <typename T, typename U>
-RPC_ATTRS constexpr bool is_same_v = is_same<T, U>::value;
+inline constexpr bool is_same_v = is_same<T, U>::value;
 
 template <typename T> struct is_void : type_constant<bool, false> {};
 template <> struct is_void<void> : type_constant<bool, true> {};
-template <typename T> RPC_ATTRS constexpr bool is_void_v = is_void<T>::value;
+template <typename T> inline constexpr bool is_void_v = is_void<T>::value;
 
 // Scary trait that can change within a TU, use with caution.
 template <typename...> using void_t = void;
@@ -90,22 +89,36 @@ template <typename T>
 struct is_complete<T, void_t<decltype(sizeof(T))>> : type_constant<bool, true> {
 };
 template <typename T>
-RPC_ATTRS constexpr bool is_complete_v = is_complete<T>::value;
+inline constexpr bool is_complete_v = is_complete<T>::value;
 
 template <typename T>
 struct is_trivially_copyable
     : public type_constant<bool, __is_trivially_copyable(T)> {};
 template <typename T>
-RPC_ATTRS constexpr bool is_trivially_copyable_v =
-    is_trivially_copyable<T>::value;
+inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
 
 template <typename T, typename... Args>
 struct is_trivially_constructible
     : type_constant<bool, __is_trivially_constructible(T, Args...)> {};
 template <typename T, typename... Args>
-RPC_ATTRS constexpr bool is_trivially_constructible_v =
+inline constexpr bool is_trivially_constructible_v =
     is_trivially_constructible<T>::value;
 
+/// Tag type to indicate an array of elements being passed through RPC.
+template <typename T> struct span {
+  T *data;
+  uint64_t size;
+  RPC_ATTRS operator T *() const { return data; }
+};
+
+template <typename T> struct is_span : type_constant<bool, false> {};
+template <typename T> struct is_span<span<T>> : type_constant<bool, true> {};
+template <typename T> inline constexpr bool is_span_v = is_span<T>::value;
+
+template <typename T> struct remove_span : type_identity<T> {};
+template <typename T> struct remove_span<span<T>> : type_identity<T *> {};
+template <typename T> using remove_span_t = typename remove_span<T>::type;
+
 template <bool B, typename T, typename F>
 struct conditional : type_identity<T> {};
 template <typename T, typename F>
@@ -449,6 +462,12 @@ template <typename R, typename... Args> struct function_traits<R (*)(Args...)> {
   using arg_types = rpc::tuple<Args...>;
   static constexpr uint64_t ARITY = sizeof...(Args);
 };
+template <typename R, typename... Args>
+struct function_traits<R (*)(Args...) noexcept> {
+  using return_type = R;
+  using arg_types = rpc::tuple<Args...>;
+  static constexpr uint64_t ARITY = sizeof...(Args);
+};
 template <typename T> T &&declval();
 template <typename T>
 struct function_traits

diff  --git a/offload/test/libc/rpc_callback.cpp b/offload/test/libc/rpc_callback.cpp
index f313541e58d1d..75a942385082b 100644
--- a/offload/test/libc/rpc_callback.cpp
+++ b/offload/test/libc/rpc_callback.cpp
@@ -1,4 +1,5 @@
-// RUN: %libomptarget-compilexx-run-and-check-generic
+// RUN: %libomptarget-compilexx-generic -fopenmp-cuda-mode
+// RUN: %libomptarget-run-generic
 // REQUIRES: libc
 // REQUIRES: gpu
 
@@ -32,6 +33,7 @@ constexpr uint32_t CONST_PTR_OPCODE = 4;
 constexpr uint32_t STRING_OPCODE = 5;
 constexpr uint32_t EMPTY_OPCODE = 6;
 constexpr uint32_t DIVERGENT_OPCODE = 7;
+constexpr uint32_t ARRAY_SUM_OPCODE = 8;
 
 //===------------------------------------------------------------------------===
 // Server-side implementations.
@@ -82,6 +84,14 @@ void divergent(int *p) {
   *p = *p;
 }
 
+// 8. Array argument via span.
+int sum_array(const int *arr, int n) {
+  int s = 0;
+  for (int i = 0; i < n; ++i)
+    s += arr[i];
+  return s;
+}
+
 //===------------------------------------------------------------------------===
 // RPC client dispatch.
 //===------------------------------------------------------------------------===
@@ -110,6 +120,11 @@ int empty() { return rpc::dispatch<EMPTY_OPCODE>(client, empty); }
 void divergent(int *p) {
   rpc::dispatch<DIVERGENT_OPCODE>(client, divergent, p);
 }
+
+int sum_array(const int *arr, int n) {
+  return rpc::dispatch<ARRAY_SUM_OPCODE>(
+      client, sum_array, rpc::span<const int>{arr, uint64_t(n)}, n);
+}
 #pragma omp end declare variant
 
 //===------------------------------------------------------------------------===
@@ -143,6 +158,9 @@ rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) {
       *p = *p;
     });
     break;
+  case ARRAY_SUM_OPCODE:
+    rpc::invoke<NUM_LANES>(Port, sum_array);
+    break;
   default:
     return rpc::RPC_UNHANDLED_OPCODE;
   }
@@ -203,6 +221,11 @@ int main() {
     if (id % 2)
       divergent(&id);
     assert(id == omp_get_thread_num());
+
+    // 8. Array sum via span.
+    int arr[4] = {1, 2, 3, 4};
+    int total = sum_array(arr, 4);
+    assert(total == 10);
   }
 
   printf("PASS\n");


        


More information about the libc-commits mailing list