[libc-commits] [libc] bbe79a8 - [libc] Use RAII alloc in gpu rpc printf impl (#110352)

Ivan Butygin via libc-commits (libc-commits at lists.llvm.org)
Sat Sep 28 05:44:04 PDT 2024


Author: Ivan Butygin
Date: 2024-09-28T15:44:01+03:00
New Revision: bbe79a803c84f4193c39566c9b0189ecadf5d8b4

URL: https://github.com/llvm/llvm-project/commit/bbe79a803c84f4193c39566c9b0189ecadf5d8b4
DIFF: https://github.com/llvm/llvm-project/commit/bbe79a803c84f4193c39566c9b0189ecadf5d8b4.diff

LOG: [libc] Use RAII alloc in gpu rpc printf impl (#110352)

Added: 
    

Modified: 
    libc/utils/gpu/server/rpc_server.cpp

Removed: 
    


################################################################################
diff  --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index 888fca6cb0bb30..6951c5ae147df7 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -42,8 +42,19 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
 static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
               "Incorrect maximum port count");
 
+namespace {
+struct TempStorage {
+  char *alloc(size_t size) {
+    storage.emplace_back(std::make_unique<char[]>(size));
+    return storage.back().get();
+  }
+
+  std::vector<std::unique_ptr<char[]>> storage;
+};
+} // namespace
+
 template <bool packed, uint32_t lane_size>
-void handle_printf(rpc::Server::Port &port) {
+static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
   FILE *files[lane_size] = {nullptr};
   // Get the appropriate output stream to use.
   if (port.get_opcode() == RPC_PRINTF_TO_STREAM ||
@@ -65,7 +76,7 @@ void handle_printf(rpc::Server::Port &port) {
 
   // Recieve the format string and arguments from the client.
   port.recv_n(format, format_sizes,
-              [&](uint64_t size) { return new char[size]; });
+              [&](uint64_t size) { return temp_storage.alloc(size); });
 
   // Parse the format string to get the expected size of the buffer.
   for (uint32_t lane = 0; lane < lane_size; ++lane) {
@@ -88,7 +99,8 @@ void handle_printf(rpc::Server::Port &port) {
   port.send([&](rpc::Buffer *buffer, uint32_t id) {
     buffer->data[0] = args_sizes[id];
   });
-  port.recv_n(args, args_sizes, [&](uint64_t size) { return new char[size]; });
+  port.recv_n(args, args_sizes,
+              [&](uint64_t size) { return temp_storage.alloc(size); });
 
   // Identify any arguments that are actually pointers to strings on the client.
   // Additionally we want to determine how much buffer space we need to print.
@@ -137,7 +149,8 @@ void handle_printf(rpc::Server::Port &port) {
     });
     uint64_t str_sizes[lane_size] = {0};
     void *strs[lane_size] = {nullptr};
-    port.recv_n(strs, str_sizes, [](uint64_t size) { return new char[size]; });
+    port.recv_n(strs, str_sizes,
+                [&](uint64_t size) { return temp_storage.alloc(size); });
     for (uint32_t lane = 0; lane < lane_size; ++lane) {
       if (!strs[lane])
         continue;
@@ -149,13 +162,12 @@ void handle_printf(rpc::Server::Port &port) {
 
   // Perform the final formatting and printing using the LLVM C library printf.
   int results[lane_size] = {0};
-  std::vector<void *> to_be_deleted;
   for (uint32_t lane = 0; lane < lane_size; ++lane) {
     if (!format[lane])
       continue;
 
-    std::unique_ptr<char[]> buffer(new char[buffer_size[lane]]);
-    WriteBuffer wb(buffer.get(), buffer_size[lane]);
+    char *buffer = temp_storage.alloc(buffer_size[lane]);
+    WriteBuffer wb(buffer, buffer_size[lane]);
     Writer writer(&wb);
 
     internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
@@ -173,7 +185,6 @@ void handle_printf(rpc::Server::Port &port) {
       if (cur_section.has_conv && cur_section.conv_name == 's') {
         if (!copied_strs[lane].empty()) {
           cur_section.conv_val_ptr = copied_strs[lane].back();
-          to_be_deleted.push_back(copied_strs[lane].back());
           copied_strs[lane].pop_back();
         } else {
           cur_section.conv_val_ptr = nullptr;
@@ -188,8 +199,7 @@ void handle_printf(rpc::Server::Port &port) {
       }
     }
 
-    results[lane] =
-        fwrite(buffer.get(), 1, writer.get_chars_written(), files[lane]);
+    results[lane] = fwrite(buffer, 1, writer.get_chars_written(), files[lane]);
     if (results[lane] != writer.get_chars_written() || ret == -1)
       results[lane] = -1;
   }
@@ -199,24 +209,9 @@ void handle_printf(rpc::Server::Port &port) {
   port.send([&](rpc::Buffer *buffer, uint32_t id) {
     buffer->data[0] = static_cast<uint64_t>(results[id]);
     buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
-    delete[] reinterpret_cast<char *>(format[id]);
-    delete[] reinterpret_cast<char *>(args[id]);
   });
-  for (void *ptr : to_be_deleted)
-    delete[] reinterpret_cast<char *>(ptr);
 }
 
-namespace {
-struct TempStorage {
-  char *alloc(size_t size) {
-    storage.emplace_back(std::make_unique<char[]>(size));
-    return storage.back().get();
-  }
-
-  std::vector<std::unique_ptr<char[]>> storage;
-};
-} // namespace
-
 template <uint32_t lane_size>
 rpc_status_t handle_server_impl(
     rpc::Server &server,
@@ -381,13 +376,13 @@ rpc_status_t handle_server_impl(
   case RPC_PRINTF_TO_STREAM_PACKED:
   case RPC_PRINTF_TO_STDOUT_PACKED:
   case RPC_PRINTF_TO_STDERR_PACKED: {
-    handle_printf<true, lane_size>(*port);
+    handle_printf<true, lane_size>(*port, temp_storage);
     break;
   }
   case RPC_PRINTF_TO_STREAM:
   case RPC_PRINTF_TO_STDOUT:
   case RPC_PRINTF_TO_STDERR: {
-    handle_printf<false, lane_size>(*port);
+    handle_printf<false, lane_size>(*port, temp_storage);
     break;
   }
   case RPC_REMOVE: {


        


More information about the libc-commits mailing list