[libc-commits] [libc] [libc] Implement (v|f)printf on the GPU (PR #96369)

Joseph Huber via libc-commits libc-commits at lists.llvm.org
Fri Jul 12 15:35:18 PDT 2024


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/96369

>From 8e6a5500ff8ed365ef4ac2a0cfc90cc867190818 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 21 Jun 2024 19:10:40 -0500
Subject: [PATCH] [libc] Implement (v|f)printf on the GPU

Summary:
This patch implements the `printf` family of functions on the GPU using
the new variadic support. This patch adapts the old handling in the
`rpc_fprintf` placeholder, but adds an extra RPC call to get the size of
the buffer to copy. This prevents the GPU from needing to parse the
string. While it's theoretically possible for the pass to know the size
of the struct, it's prohibitively difficult to do while maintaining ABI
compatibility with NVIDIA's varargs.

Depends on https://github.com/llvm/llvm-project/pull/96015.
---
 libc/config/gpu/entrypoints.txt               |  4 +
 libc/include/llvm-libc-types/rpc_opcodes_t.h  |  3 +
 libc/src/__support/arg_list.h                 | 51 +++++++++--
 libc/src/gpu/rpc_fprintf.cpp                  |  5 +-
 libc/src/stdio/CMakeLists.txt                 | 26 +-----
 libc/src/stdio/generic/CMakeLists.txt         | 39 ++++++---
 libc/src/stdio/{ => generic}/fprintf.cpp      |  0
 libc/src/stdio/{ => generic}/vfprintf.cpp     |  0
 libc/src/stdio/gpu/CMakeLists.txt             | 48 +++++++++++
 libc/src/stdio/gpu/fprintf.cpp                | 31 +++++++
 libc/src/stdio/gpu/printf.cpp                 | 29 +++++++
 libc/src/stdio/gpu/vfprintf.cpp               | 28 ++++++
 libc/src/stdio/gpu/vfprintf_utils.h           | 86 +++++++++++++++++++
 libc/src/stdio/gpu/vprintf.cpp                | 27 ++++++
 .../integration/src/stdio/gpu/CMakeLists.txt  |  4 +-
 .../stdio/gpu/{printf.cpp => printf_test.cpp} | 46 +++-------
 libc/test/src/stdio/CMakeLists.txt            |  2 +
 libc/utils/gpu/server/rpc_server.cpp          | 47 ++++++++--
 18 files changed, 387 insertions(+), 89 deletions(-)
 rename libc/src/stdio/{ => generic}/fprintf.cpp (100%)
 rename libc/src/stdio/{ => generic}/vfprintf.cpp (100%)
 create mode 100644 libc/src/stdio/gpu/fprintf.cpp
 create mode 100644 libc/src/stdio/gpu/printf.cpp
 create mode 100644 libc/src/stdio/gpu/vfprintf.cpp
 create mode 100644 libc/src/stdio/gpu/vfprintf_utils.h
 create mode 100644 libc/src/stdio/gpu/vprintf.cpp
 rename libc/test/integration/src/stdio/gpu/{printf.cpp => printf_test.cpp} (54%)

diff --git a/libc/config/gpu/entrypoints.txt b/libc/config/gpu/entrypoints.txt
index be4af4d21689..624ac2715579 100644
--- a/libc/config/gpu/entrypoints.txt
+++ b/libc/config/gpu/entrypoints.txt
@@ -177,6 +177,10 @@ set(TARGET_LIBC_ENTRYPOINTS
     # stdio.h entrypoints
     libc.src.stdio.clearerr
     libc.src.stdio.fclose
+    libc.src.stdio.printf
+    libc.src.stdio.vprintf
+    libc.src.stdio.fprintf
+    libc.src.stdio.vfprintf
     libc.src.stdio.sprintf
     libc.src.stdio.snprintf
     libc.src.stdio.vsprintf
diff --git a/libc/include/llvm-libc-types/rpc_opcodes_t.h b/libc/include/llvm-libc-types/rpc_opcodes_t.h
index fb3bc7b68f71..45050e8521f7 100644
--- a/libc/include/llvm-libc-types/rpc_opcodes_t.h
+++ b/libc/include/llvm-libc-types/rpc_opcodes_t.h
@@ -34,6 +34,9 @@ typedef enum {
   RPC_PRINTF_TO_STDOUT,
   RPC_PRINTF_TO_STDERR,
   RPC_PRINTF_TO_STREAM,
+  RPC_PRINTF_TO_STDOUT_PACKED,
+  RPC_PRINTF_TO_STDERR_PACKED,
+  RPC_PRINTF_TO_STREAM_PACKED,
   RPC_REMOVE,
   RPC_LAST = 0xFFFF,
 } rpc_opcode_t;
diff --git a/libc/src/__support/arg_list.h b/libc/src/__support/arg_list.h
index a57973273c9f..66afa67e320b 100644
--- a/libc/src/__support/arg_list.h
+++ b/libc/src/__support/arg_list.h
@@ -19,6 +19,11 @@
 namespace LIBC_NAMESPACE_DECL {
 namespace internal {
 
+template <typename V, typename A>
+LIBC_INLINE constexpr V align_up(V val, A align) {
+  return ((val + V(align) - 1) / V(align)) * V(align);
+}
+
 class ArgList {
   va_list vlist;
 
@@ -55,7 +60,34 @@ class MockArgList {
   }
 
   template <class T> LIBC_INLINE T next_var() {
-    ++arg_counter;
+    arg_counter++;
+    return T(arg_counter);
+  }
+
+  size_t read_count() const { return arg_counter; }
+};
+
+// Used by the GPU implementation to parse how many bytes need to be read from
+// the variadic argument buffer.
+template <bool packed> class DummyArgList {
+  size_t arg_counter = 0;
+
+public:
+  LIBC_INLINE DummyArgList() = default;
+  LIBC_INLINE DummyArgList(va_list) { ; }
+  LIBC_INLINE DummyArgList(DummyArgList &other) {
+    arg_counter = other.arg_counter;
+  }
+  LIBC_INLINE ~DummyArgList() = default;
+
+  LIBC_INLINE DummyArgList &operator=(DummyArgList &rhs) {
+    arg_counter = rhs.arg_counter;
+    return *this;
+  }
+
+  template <class T> LIBC_INLINE T next_var() {
+    arg_counter = packed ? arg_counter + sizeof(T)
+                         : align_up(arg_counter, alignof(T)) + sizeof(T);
     return T(arg_counter);
   }
 
@@ -64,7 +96,7 @@ class MockArgList {
 
 // Used for the GPU implementation of `printf`. This models a variadic list as a
 // simple array of pointers that are built manually by the implementation.
-class StructArgList {
+template <bool packed> class StructArgList {
   void *ptr;
   void *end;
 
@@ -86,15 +118,18 @@ class StructArgList {
   LIBC_INLINE void *get_ptr() const { return ptr; }
 
   template <class T> LIBC_INLINE T next_var() {
-    ptr = reinterpret_cast<void *>(
-        ((reinterpret_cast<uintptr_t>(ptr) + alignof(T) - 1) / alignof(T)) *
-        alignof(T));
-
+    if (!packed)
+      ptr = reinterpret_cast<void *>(
+          align_up(reinterpret_cast<uintptr_t>(ptr), alignof(T)));
     if (ptr >= end)
       return T(-1);
 
-    T val = *reinterpret_cast<T *>(ptr);
-    ptr = reinterpret_cast<unsigned char *>(ptr) + sizeof(T);
+    // Use memcpy: in a packed struct `ptr` may not be suitably aligned for T.
+    T val;
+    __builtin_memcpy(&val, ptr, sizeof(T));
+
+    ptr =
+        reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(ptr) + sizeof(T));
     return val;
   }
 };
diff --git a/libc/src/gpu/rpc_fprintf.cpp b/libc/src/gpu/rpc_fprintf.cpp
index 321137ef2495..70056daa25e2 100644
--- a/libc/src/gpu/rpc_fprintf.cpp
+++ b/libc/src/gpu/rpc_fprintf.cpp
@@ -30,6 +30,9 @@ int fprintf_impl(::FILE *__restrict file, const char *__restrict format,
   }
 
   port.send_n(format, format_size);
+  port.recv([&](rpc::Buffer *buffer) {
+    args_size = static_cast<size_t>(buffer->data[0]);
+  });
   port.send_n(args, args_size);
 
   uint32_t ret = 0;
@@ -51,7 +54,7 @@ int fprintf_impl(::FILE *__restrict file, const char *__restrict format,
   return ret;
 }
 
-// TODO: This is a stand-in function that uses a struct pointer and size in
+// TODO: Delete this and port OpenMP to use `printf`.
 // place of varargs. Once varargs support is added we will use that to
 // implement the real version.
 LLVM_LIBC_FUNCTION(int, rpc_fprintf,
diff --git a/libc/src/stdio/CMakeLists.txt b/libc/src/stdio/CMakeLists.txt
index 7c61238e6957..2d528a903cc2 100644
--- a/libc/src/stdio/CMakeLists.txt
+++ b/libc/src/stdio/CMakeLists.txt
@@ -163,18 +163,6 @@ add_entrypoint_object(
     libc.src.stdio.printf_core.writer
 )
 
-add_entrypoint_object(
-  fprintf
-  SRCS
-    fprintf.cpp
-  HDRS
-    fprintf.h
-  DEPENDS
-    libc.hdr.types.FILE
-    libc.src.__support.arg_list
-    libc.src.stdio.printf_core.vfprintf_internal
-)
-
 add_entrypoint_object(
   vsprintf
   SRCS
@@ -197,18 +185,6 @@ add_entrypoint_object(
     libc.src.stdio.printf_core.writer
 )
 
-add_entrypoint_object(
-  vfprintf
-  SRCS
-    vfprintf.cpp
-  HDRS
-    vfprintf.h
-  DEPENDS
-    libc.hdr.types.FILE
-    libc.src.__support.arg_list
-    libc.src.stdio.printf_core.vfprintf_internal
-)
-
 add_subdirectory(printf_core)
 add_subdirectory(scanf_core)
 
@@ -258,6 +234,7 @@ add_stdio_entrypoint_object(fputc)
 add_stdio_entrypoint_object(putc)
 add_stdio_entrypoint_object(putchar)
 add_stdio_entrypoint_object(printf)
+add_stdio_entrypoint_object(fprintf)
 add_stdio_entrypoint_object(fgetc)
 add_stdio_entrypoint_object(fgetc_unlocked)
 add_stdio_entrypoint_object(getc)
@@ -270,3 +247,4 @@ add_stdio_entrypoint_object(stdin)
 add_stdio_entrypoint_object(stdout)
 add_stdio_entrypoint_object(stderr)
 add_stdio_entrypoint_object(vprintf)
+add_stdio_entrypoint_object(vfprintf)
diff --git a/libc/src/stdio/generic/CMakeLists.txt b/libc/src/stdio/generic/CMakeLists.txt
index b0ce8ccc08bd..a86b86a32dce 100644
--- a/libc/src/stdio/generic/CMakeLists.txt
+++ b/libc/src/stdio/generic/CMakeLists.txt
@@ -363,19 +363,6 @@ add_entrypoint_object(
     libc.src.__support.File.platform_file
 )
 
-list(APPEND printf_deps
-      libc.src.__support.arg_list
-      libc.src.stdio.printf_core.vfprintf_internal
-)
-
-if(LLVM_LIBC_FULL_BUILD)
-  list(APPEND printf_deps
-      libc.src.__support.File.file
-      libc.src.__support.File.platform_file
-      libc.src.__support.File.platform_stdout
-  )
-endif()
-
 add_entrypoint_object(
   printf
   SRCS
@@ -396,6 +383,32 @@ add_entrypoint_object(
     ${printf_deps}
 )
 
+add_entrypoint_object(
+  fprintf
+  SRCS
+    fprintf.cpp
+  HDRS
+    ../fprintf.h
+  DEPENDS
+    libc.hdr.types.FILE
+    libc.src.__support.arg_list
+    libc.src.stdio.printf_core.vfprintf_internal
+    ${printf_deps}
+)
+
+add_entrypoint_object(
+  vfprintf
+  SRCS
+    vfprintf.cpp
+  HDRS
+    ../vfprintf.h
+  DEPENDS
+    libc.hdr.types.FILE
+    libc.src.__support.arg_list
+    libc.src.stdio.printf_core.vfprintf_internal
+    ${printf_deps}
+)
+
 add_entrypoint_object(
   fgets
   SRCS
diff --git a/libc/src/stdio/fprintf.cpp b/libc/src/stdio/generic/fprintf.cpp
similarity index 100%
rename from libc/src/stdio/fprintf.cpp
rename to libc/src/stdio/generic/fprintf.cpp
diff --git a/libc/src/stdio/vfprintf.cpp b/libc/src/stdio/generic/vfprintf.cpp
similarity index 100%
rename from libc/src/stdio/vfprintf.cpp
rename to libc/src/stdio/generic/vfprintf.cpp
diff --git a/libc/src/stdio/gpu/CMakeLists.txt b/libc/src/stdio/gpu/CMakeLists.txt
index 64bae7dc02d5..86470b8425e9 100644
--- a/libc/src/stdio/gpu/CMakeLists.txt
+++ b/libc/src/stdio/gpu/CMakeLists.txt
@@ -11,6 +11,14 @@ add_header_library(
     .stderr
 )
 
+add_header_library(
+  vfprintf_utils
+  HDRS
+    vfprintf_utils.h
+  DEPENDS
+    .gpu_file
+)
+
 add_entrypoint_object(
   feof
   SRCS
@@ -246,6 +254,46 @@ add_entrypoint_object(
     .gpu_file
 )
 
+add_entrypoint_object(
+  printf
+  SRCS
+    printf.cpp
+  HDRS
+    ../printf.h
+  DEPENDS
+    .vfprintf_utils
+)
+
+add_entrypoint_object(
+  vprintf
+  SRCS
+    vprintf.cpp
+  HDRS
+    ../vprintf.h
+  DEPENDS
+    .vfprintf_utils
+)
+
+add_entrypoint_object(
+  fprintf
+  SRCS
+    fprintf.cpp
+  HDRS
+    ../fprintf.h
+  DEPENDS
+    .vfprintf_utils
+)
+
+add_entrypoint_object(
+  vfprintf
+  SRCS
+    vfprintf.cpp
+  HDRS
+    ../vfprintf.h
+  DEPENDS
+    .vfprintf_utils
+)
+
 add_entrypoint_object(
   stdin
   SRCS
diff --git a/libc/src/stdio/gpu/fprintf.cpp b/libc/src/stdio/gpu/fprintf.cpp
new file mode 100644
index 000000000000..42d6ad008777
--- /dev/null
+++ b/libc/src/stdio/gpu/fprintf.cpp
@@ -0,0 +1,31 @@
+//===-- GPU Implementation of fprintf -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/stdio/fprintf.h"
+
+#include "src/__support/CPP/string_view.h"
+#include "src/__support/arg_list.h"
+#include "src/errno/libc_errno.h"
+#include "src/stdio/gpu/vfprintf_utils.h"
+
+#include <stdio.h>
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(int, fprintf,
+                   (::FILE *__restrict stream, const char *__restrict format,
+                    ...)) {
+  va_list vlist;
+  va_start(vlist, format);
+  cpp::string_view str_view(format);
+  int ret_val = vfprintf_internal(stream, format, str_view.size() + 1, vlist);
+  va_end(vlist);
+  return ret_val;
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/stdio/gpu/printf.cpp b/libc/src/stdio/gpu/printf.cpp
new file mode 100644
index 000000000000..63af6fffeea7
--- /dev/null
+++ b/libc/src/stdio/gpu/printf.cpp
@@ -0,0 +1,29 @@
+//===-- GPU Implementation of printf --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/stdio/printf.h"
+
+#include "src/__support/CPP/string_view.h"
+#include "src/__support/arg_list.h"
+#include "src/errno/libc_errno.h"
+#include "src/stdio/gpu/vfprintf_utils.h"
+
+#include <stdio.h>
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(int, printf, (const char *__restrict format, ...)) {
+  va_list vlist;
+  va_start(vlist, format);
+  cpp::string_view str_view(format);
+  int ret_val = vfprintf_internal(stdout, format, str_view.size() + 1, vlist);
+  va_end(vlist);
+  return ret_val;
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/stdio/gpu/vfprintf.cpp b/libc/src/stdio/gpu/vfprintf.cpp
new file mode 100644
index 000000000000..f314f6872ad0
--- /dev/null
+++ b/libc/src/stdio/gpu/vfprintf.cpp
@@ -0,0 +1,28 @@
+//===-- GPU Implementation of vfprintf ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/stdio/vfprintf.h"
+
+#include "src/__support/CPP/string_view.h"
+#include "src/__support/arg_list.h"
+#include "src/errno/libc_errno.h"
+#include "src/stdio/gpu/vfprintf_utils.h"
+
+#include <stdio.h>
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(int, vfprintf,
+                   (::FILE *__restrict stream, const char *__restrict format,
+                    va_list vlist)) {
+  cpp::string_view str_view(format);
+  int ret_val = vfprintf_internal(stream, format, str_view.size() + 1, vlist);
+  return ret_val;
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/stdio/gpu/vfprintf_utils.h b/libc/src/stdio/gpu/vfprintf_utils.h
new file mode 100644
index 000000000000..f364646fcea5
--- /dev/null
+++ b/libc/src/stdio/gpu/vfprintf_utils.h
@@ -0,0 +1,86 @@
+//===--- GPU helper functions for printf using RPC ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/RPC/rpc_client.h"
+#include "src/__support/arg_list.h"
+#include "src/stdio/gpu/file.h"
+#include "src/string/string_utils.h"
+
+#include <stdio.h>
+
+namespace LIBC_NAMESPACE {
+
+template <uint16_t opcode>
+LIBC_INLINE int vfprintf_impl(::FILE *__restrict file,
+                              const char *__restrict format, size_t format_size,
+                              va_list vlist) {
+  uint64_t mask = gpu::get_lane_mask();
+  rpc::Client::Port port = rpc::client.open<opcode>();
+
+  if constexpr (opcode == RPC_PRINTF_TO_STREAM ||
+                opcode == RPC_PRINTF_TO_STREAM_PACKED) {
+    port.send([&](rpc::Buffer *buffer) {
+      buffer->data[0] = reinterpret_cast<uintptr_t>(file);
+    });
+  }
+
+  size_t args_size = 0;
+  port.send_n(format, format_size);
+  port.recv([&](rpc::Buffer *buffer) {
+    args_size = static_cast<size_t>(buffer->data[0]);
+  });
+  port.send_n(vlist, args_size);
+
+  uint32_t ret = 0;
+  for (;;) {
+    const char *str = nullptr;
+    port.recv([&](rpc::Buffer *buffer) {
+      ret = static_cast<uint32_t>(buffer->data[0]);
+      str = reinterpret_cast<const char *>(buffer->data[1]);
+    });
+    // If any lanes have a string argument it needs to be copied back.
+    if (!gpu::ballot(mask, str))
+      break;
+
+    uint64_t size = str ? internal::string_length(str) + 1 : 0;
+    port.send_n(str, size);
+  }
+
+  port.close();
+  return ret;
+}
+
+LIBC_INLINE int vfprintf_internal(::FILE *__restrict stream,
+                                  const char *__restrict format,
+                                  size_t format_size, va_list vlist) {
+  // The AMDGPU backend uses a packed struct for its varargs. We pass it as a
+  // separate opcode so the server knows how much to advance the pointers.
+#if defined(LIBC_TARGET_ARCH_IS_AMDGPU)
+  if (stream == stdout)
+    return vfprintf_impl<RPC_PRINTF_TO_STDOUT_PACKED>(stream, format,
+                                                      format_size, vlist);
+  else if (stream == stderr)
+    return vfprintf_impl<RPC_PRINTF_TO_STDERR_PACKED>(stream, format,
+                                                      format_size, vlist);
+  else
+    return vfprintf_impl<RPC_PRINTF_TO_STREAM_PACKED>(stream, format,
+                                                      format_size, vlist);
+#else
+  if (stream == stdout)
+    return vfprintf_impl<RPC_PRINTF_TO_STDOUT>(stream, format, format_size,
+                                               vlist);
+  else if (stream == stderr)
+    return vfprintf_impl<RPC_PRINTF_TO_STDERR>(stream, format, format_size,
+                                               vlist);
+  else
+    return vfprintf_impl<RPC_PRINTF_TO_STREAM>(stream, format, format_size,
+                                               vlist);
+#endif
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/src/stdio/gpu/vprintf.cpp b/libc/src/stdio/gpu/vprintf.cpp
new file mode 100644
index 000000000000..1356aceeb51c
--- /dev/null
+++ b/libc/src/stdio/gpu/vprintf.cpp
@@ -0,0 +1,27 @@
+//===-- GPU Implementation of vprintf -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/stdio/vprintf.h"
+
+#include "src/__support/CPP/string_view.h"
+#include "src/__support/arg_list.h"
+#include "src/errno/libc_errno.h"
+#include "src/stdio/gpu/vfprintf_utils.h"
+
+#include <stdio.h>
+
+namespace LIBC_NAMESPACE {
+
+LLVM_LIBC_FUNCTION(int, vprintf,
+                   (const char *__restrict format, va_list vlist)) {
+  cpp::string_view str_view(format);
+  int ret_val = vfprintf_internal(stdout, format, str_view.size() + 1, vlist);
+  return ret_val;
+}
+
+} // namespace LIBC_NAMESPACE
diff --git a/libc/test/integration/src/stdio/gpu/CMakeLists.txt b/libc/test/integration/src/stdio/gpu/CMakeLists.txt
index 6327c45e1ea5..1b896b5ddbe8 100644
--- a/libc/test/integration/src/stdio/gpu/CMakeLists.txt
+++ b/libc/test/integration/src/stdio/gpu/CMakeLists.txt
@@ -11,9 +11,9 @@ add_integration_test(
   SUITE
     stdio-gpu-integration-tests
   SRCS
-    printf.cpp
+    printf_test.cpp
   DEPENDS
-    libc.src.gpu.rpc_fprintf
+    libc.src.stdio.fprintf
     libc.src.stdio.fopen
   LOADER_ARGS
     --threads 32
diff --git a/libc/test/integration/src/stdio/gpu/printf.cpp b/libc/test/integration/src/stdio/gpu/printf_test.cpp
similarity index 54%
rename from libc/test/integration/src/stdio/gpu/printf.cpp
rename to libc/test/integration/src/stdio/gpu/printf_test.cpp
index 97ad4ace1dca..5f78737a2aca 100644
--- a/libc/test/integration/src/stdio/gpu/printf.cpp
+++ b/libc/test/integration/src/stdio/gpu/printf_test.cpp
@@ -9,8 +9,8 @@
 #include "test/IntegrationTest/test.h"
 
 #include "src/__support/GPU/utils.h"
-#include "src/gpu/rpc_fprintf.h"
 #include "src/stdio/fopen.h"
+#include "src/stdio/fprintf.h"
 
 using namespace LIBC_NAMESPACE;
 
@@ -20,68 +20,48 @@ TEST_MAIN(int argc, char **argv, char **envp) {
   ASSERT_TRUE(file && "failed to open file");
   // Check basic printing.
   int written = 0;
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "A simple string\n", nullptr, 0);
+  written = LIBC_NAMESPACE::fprintf(file, "A simple string\n");
   ASSERT_EQ(written, 16);
 
   const char *str = "A simple string\n";
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%s", &str, sizeof(void *));
+  written = LIBC_NAMESPACE::fprintf(file, "%s", str);
   ASSERT_EQ(written, 16);
 
   // Check printing a different value with each thread.
   uint64_t thread_id = gpu::get_thread_id();
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%8ld\n", &thread_id,
-                                        sizeof(thread_id));
+  written = LIBC_NAMESPACE::fprintf(file, "%8ld\n", thread_id);
   ASSERT_EQ(written, 9);
 
-  struct {
-    uint32_t x = 1;
-    char c = 'c';
-    double f = 1.0;
-  } args1;
-  written =
-      LIBC_NAMESPACE::rpc_fprintf(file, "%d%c%.1f\n", &args1, sizeof(args1));
+  written = LIBC_NAMESPACE::fprintf(file, "%d%c%.1f\n", 1, 'c', 1.0);
   ASSERT_EQ(written, 6);
 
-  struct {
-    uint32_t x = 1;
-    const char *str = "A simple string\n";
-  } args2;
-  written =
-      LIBC_NAMESPACE::rpc_fprintf(file, "%032b%s\n", &args2, sizeof(args2));
+  written = LIBC_NAMESPACE::fprintf(file, "%032b%s\n", 1, "A simple string\n");
   ASSERT_EQ(written, 49);
 
   // Check that the server correctly handles divergent numbers of arguments.
   const char *format = gpu::get_thread_id() % 2 ? "%s" : "%20ld\n";
-  written = LIBC_NAMESPACE::rpc_fprintf(file, format, &str, sizeof(void *));
+  written = LIBC_NAMESPACE::fprintf(file, format, str);
   ASSERT_EQ(written, gpu::get_thread_id() % 2 ? 16 : 21);
 
   format = gpu::get_thread_id() % 2 ? "%s" : str;
-  written = LIBC_NAMESPACE::rpc_fprintf(file, format, &str, sizeof(void *));
+  written = LIBC_NAMESPACE::fprintf(file, format, str);
   ASSERT_EQ(written, 16);
 
   // Check that we handle null arguments correctly.
-  struct {
-    void *null = nullptr;
-  } args3;
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%p", &args3, sizeof(args3));
+  written = LIBC_NAMESPACE::fprintf(file, "%p", nullptr);
   ASSERT_EQ(written, 9);
 
 #ifndef LIBC_COPT_PRINTF_NO_NULLPTR_CHECKS
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%s", &args3, sizeof(args3));
+  written = LIBC_NAMESPACE::fprintf(file, "%s", nullptr);
   ASSERT_EQ(written, 6);
 #endif // LIBC_COPT_PRINTF_NO_NULLPTR_CHECKS
 
   // Check for extremely abused variable width arguments
-  struct {
-    uint32_t x = 1;
-    uint32_t y = 2;
-    double f = 1.0;
-  } args4;
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%**d", &args4, sizeof(args4));
+  written = LIBC_NAMESPACE::fprintf(file, "%**d", 1, 2, 1.0);
   ASSERT_EQ(written, 4);
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%**d%6d", &args4, sizeof(args4));
+  written = LIBC_NAMESPACE::fprintf(file, "%**d%6d", 1, 2, 1.0);
   ASSERT_EQ(written, 10);
-  written = LIBC_NAMESPACE::rpc_fprintf(file, "%**.**f", &args4, sizeof(args4));
+  written = LIBC_NAMESPACE::fprintf(file, "%**.**f", 1, 2, 1.0);
   ASSERT_EQ(written, 7);
 
   return 0;
diff --git a/libc/test/src/stdio/CMakeLists.txt b/libc/test/src/stdio/CMakeLists.txt
index 5eb8c9577893..a01deae5d93a 100644
--- a/libc/test/src/stdio/CMakeLists.txt
+++ b/libc/test/src/stdio/CMakeLists.txt
@@ -189,6 +189,7 @@ add_libc_test(
     printf_test.cpp
   DEPENDS
     libc.src.stdio.printf
+    libc.src.stdio.stdout
 )
 
 add_fp_unittest(
@@ -234,6 +235,7 @@ add_libc_test(
     vprintf_test.cpp
   DEPENDS
     libc.src.stdio.vprintf
+    libc.src.stdio.stdout
 )
 
 
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index a05fc8457a9c..119539e3cad4 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -39,14 +39,17 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
 static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
               "Incorrect maximum port count");
 
-template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
+template <bool packed, uint32_t lane_size>
+void handle_printf(rpc::Server::Port &port) {
   FILE *files[lane_size] = {nullptr};
   // Get the appropriate output stream to use.
-  if (port.get_opcode() == RPC_PRINTF_TO_STREAM)
+  if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
+      port.get_opcode() == RPC_PRINTF_TO_STREAM_PACKED)
     port.recv([&](rpc::Buffer *buffer, uint32_t id) {
       files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
     });
-  else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT)
+  else if (port.get_opcode() == RPC_PRINTF_TO_STDOUT ||
+           port.get_opcode() == RPC_PRINTF_TO_STDOUT_PACKED)
     std::fill(files, files + lane_size, stdout);
   else
     std::fill(files, files + lane_size, stderr);
@@ -60,6 +63,28 @@ template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
   // Recieve the format string and arguments from the client.
   port.recv_n(format, format_sizes,
               [&](uint64_t size) { return new char[size]; });
+
+  // Parse the format string to get the expected size of the buffer.
+  for (uint32_t lane = 0; lane < lane_size; ++lane) {
+    if (!format[lane])
+      continue;
+
+    WriteBuffer wb(nullptr, 0);
+    Writer writer(&wb);
+
+    internal::DummyArgList<packed> printf_args;
+    Parser<internal::DummyArgList<packed> &> parser(
+        reinterpret_cast<const char *>(format[lane]), printf_args);
+
+    for (FormatSection cur_section = parser.get_next_section();
+         !cur_section.raw_string.empty();
+         cur_section = parser.get_next_section())
+      ;
+    args_sizes[lane] = printf_args.read_count();
+  }
+  port.send([&](rpc::Buffer *buffer, uint32_t id) {
+    buffer->data[0] = args_sizes[id];
+  });
   port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });
 
   // Identify any arguments that are actually pointers to strings on the client.
@@ -73,8 +98,8 @@ template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
     WriteBuffer wb(nullptr, 0);
     Writer writer(&wb);
 
-    internal::StructArgList printf_args(args[lane], args_sizes[lane]);
-    Parser<internal::StructArgList> parser(
+    internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
+    Parser<internal::StructArgList<packed>> parser(
         reinterpret_cast<const char *>(format[lane]), printf_args);
 
     for (FormatSection cur_section = parser.get_next_section();
@@ -126,8 +151,8 @@ template <uint32_t lane_size> void handle_printf(rpc::Server::Port &port) {
     WriteBuffer wb(buffer.get(), buffer_size[lane]);
     Writer writer(&wb);
 
-    internal::StructArgList printf_args(args[lane], args_sizes[lane]);
-    Parser<internal::StructArgList> parser(
+    internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
+    Parser<internal::StructArgList<packed>> parser(
         reinterpret_cast<const char *>(format[lane]), printf_args);
 
     // Parse and print the format string using the arguments we copied from
@@ -337,10 +362,16 @@ rpc_status_t handle_server_impl(
     });
     break;
   }
+  case RPC_PRINTF_TO_STREAM_PACKED:
+  case RPC_PRINTF_TO_STDOUT_PACKED:
+  case RPC_PRINTF_TO_STDERR_PACKED: {
+    handle_printf<true, lane_size>(*port);
+    break;
+  }
   case RPC_PRINTF_TO_STREAM:
   case RPC_PRINTF_TO_STDOUT:
   case RPC_PRINTF_TO_STDERR: {
-    handle_printf<lane_size>(*port);
+    handle_printf<false, lane_size>(*port);
     break;
   }
   case RPC_REMOVE: {



More information about the libc-commits mailing list