[libc-commits] [libc] [libs] Use RAII alloc in gpu rpc printf impl (PR #110352)

Ivan Butygin via libc-commits libc-commits at lists.llvm.org
Sat Sep 28 02:31:31 PDT 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/110352

(No description provided.)

From 3c032339785fda22357d63f3ce17b57297737ba2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 28 Sep 2024 11:29:07 +0200
Subject: [PATCH] [libs] Use RAII alloc in gpu rpc printf impl

---
 libc/utils/gpu/server/rpc_server.cpp | 49 +++++++++++++---------------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index 888fca6cb0bb30..6951c5ae147df7 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -42,8 +42,19 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
 static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
               "Incorrect maximum port count");
 
+namespace {
+struct TempStorage { // Owns every buffer returned by alloc().
+  char *alloc(size_t size) { // Pointer stays valid while `storage` holds it.
+    storage.emplace_back(std::make_unique<char[]>(size)); // Zero-filled.
+    return storage.back().get();
+  }
+
+  std::vector<std::unique_ptr<char[]>> storage; // Freed on destruction.
+};
+} // namespace
+
 template <bool packed, uint32_t lane_size>
-void handle_printf(rpc::Server::Port &port) {
+static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
   FILE *files[lane_size] = {nullptr};
   // Get the appropriate output stream to use.
   if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
@@ -65,7 +76,7 @@ void handle_printf(rpc::Server::Port &port) {
 
   // Recieve the format string and arguments from the client.
   port.recv_n(format, format_sizes,
-              [&](uint64_t size) { return new char[size]; });
+              [&](uint64_t size) { return temp_storage.alloc(size); });
 
   // Parse the format string to get the expected size of the buffer.
   for (uint32_t lane = 0; lane < lane_size; ++lane) {
@@ -88,7 +99,8 @@ void handle_printf(rpc::Server::Port &port) {
   port.send([&](rpc::Buffer *buffer, uint32_t id) {
     buffer->data[0] = args_sizes[id];
   });
-  port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });
+  port.recv_n(args, args_sizes,
+              [&](uint64_t size) { return temp_storage.alloc(size); });
 
   // Identify any arguments that are actually pointers to strings on the client.
   // Additionally we want to determine how much buffer space we need to print.
@@ -137,7 +149,8 @@ void handle_printf(rpc::Server::Port &port) {
     });
     uint64_t str_sizes[lane_size] = {0};
     void *strs[lane_size] = {nullptr};
-    port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; });
+    port.recv_n(strs, str_sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
     for (uint32_t lane = 0; lane < lane_size; ++lane) {
       if (!strs[lane])
         continue;
@@ -149,13 +162,12 @@ void handle_printf(rpc::Server::Port &port) {
 
   // Perform the final formatting and printing using the LLVM C library printf.
   int results[lane_size] = {0};
-  std::vector<void *> to_be_deleted;
   for (uint32_t lane = 0; lane < lane_size; ++lane) {
     if (!format[lane])
       continue;
 
-    std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]);
-    WriteBuffer wb(buffer.get(), buffer_size[lane]);
+    char *buffer = temp_storage.alloc(buffer_size[lane]);
+    WriteBuffer wb(buffer, buffer_size[lane]);
     Writer writer(&wb);
 
     internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
@@ -173,7 +185,6 @@ void handle_printf(rpc::Server::Port &port) {
       if (cur_section.has_conv && cur_section.conv_name == 's') {
         if (!copied_strs[lane].empty()) {
           cur_section.conv_val_ptr = copied_strs[lane].back();
-          to_be_deleted.push_back(copied_strs[lane].back());
           copied_strs[lane].pop_back();
         } else {
           cur_section.conv_val_ptr = nullptr;
@@ -188,8 +199,7 @@ void handle_printf(rpc::Server::Port &port) {
       }
     }
 
-    results[lane] =
-        fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]);
+    results[lane] = fwrite(buffer, 1, writer.get_chars_written(), files[lane]);
     if (results[lane] != writer.get_chars_written() || ret == -1)
       results[lane] = -1;
   }
@@ -199,24 +209,9 @@ void handle_printf(rpc::Server::Port &port) {
   port.send([&](rpc::Buffer *buffer, uint32_t id) {
     buffer->data[0] = static_cast<uint64_t>(results[id]);
     buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
-    delete[] reinterpret_cast<char *>(format[id]);
-    delete[] reinterpret_cast<char *>(args[id]);
   });
-  for (void *ptr : to_be_deleted)
-    delete[] reinterpret_cast<char *>(ptr);
 }
 
-namespace {
-struct TempStorage {
-  char *alloc(size_t size) {
-    storage.emplace_back(std::make_unique<char[]>(size));
-    return storage.back().get();
-  }
-
-  std::vector<std::unique_ptr<char[]>> storage;
-};
-} // namespace
-
 template <uint32_t lane_size>
 rpc_status_t handle_server_impl(
     rpc::Server &server,
@@ -381,13 +376,13 @@ rpc_status_t handle_server_impl(
   case RPC_PRINTF_TO_STREAM_PACKED:
   case RPC_PRINTF_TO_STDOUT_PACKED:
   case RPC_PRINTF_TO_STDERR_PACKED: {
-    handle_printf<true, lane_size>(*port);
+    handle_printf<true, lane_size>(*port, temp_storage);
     break;
   }
   case RPC_PRINTF_TO_STREAM:
   case RPC_PRINTF_TO_STDOUT:
   case RPC_PRINTF_TO_STDERR: {
-    handle_printf<false, lane_size>(*port);
+    handle_printf<false, lane_size>(*port, temp_storage);
     break;
   }
   case RPC_REMOVE: {



More information about the libc-commits mailing list