[clang] [libc] RPCLaneSize (PR #84557)

Joseph Huber via cfe-commits cfe-commits at lists.llvm.org
Fri Mar 8 12:18:37 PST 2024


https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/84557

- [HIP] Make the new driver bundle outputs for device-only
- [libc][NFCI] Remove lane size template argument on RPC server


>From 99a769ec7ffaa7728847fdf2f67a1be11ce98f2b Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 8 Mar 2024 12:49:38 -0600
Subject: [PATCH 1/2] [HIP] Make the new driver bundle outputs for device-only

Summary:
The current behavior of HIP is that when `--offload-device-only` is set it
still bundles the outputs into a fat binary. Even though this is
different from how all the other targets handle this, it seems to be
depended on by some tooling, so just make it backwards compatible for
the `-fno-gpu-rdc` case.
---
 clang/lib/Driver/Driver.cpp       | 10 +++++++++-
 clang/test/Driver/hip-binding.hip | 10 +++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Driver/Driver.cpp b/clang/lib/Driver/Driver.cpp
index fce43430a91374..eba43d97431364 100644
--- a/clang/lib/Driver/Driver.cpp
+++ b/clang/lib/Driver/Driver.cpp
@@ -4638,7 +4638,10 @@ Action *Driver::BuildOffloadingActions(Compilation &C,
     }
   }
 
-  if (offloadDeviceOnly())
+  // All kinds exit now in device-only mode except for non-RDC mode HIP.
+  if (offloadDeviceOnly() &&
+      (!C.isOffloadingHostKind(Action::OFK_HIP) ||
+       Args.hasFlag(options::OPT_fgpu_rdc, options::OPT_fno_gpu_rdc, false)))
     return C.MakeAction<OffloadAction>(DDeps, types::TY_Nothing);
 
   if (OffloadActions.empty())
@@ -4671,6 +4674,11 @@ Action *Driver::BuildOffloadingActions(Compilation &C,
              nullptr, C.getActiveOffloadKinds());
   }
 
+  // HIP wants '--offload-device-only' to create a fatbinary by default.
+  if (offloadDeviceOnly() && C.isOffloadingHostKind(Action::OFK_HIP) &&
+      !Args.hasFlag(options::OPT_fgpu_rdc, options::OPT_fno_gpu_rdc, false))
+    return C.MakeAction<OffloadAction>(DDep, types::TY_Nothing);
+
   // If we are unable to embed a single device output into the host, we need to
   // add each device output as a host dependency to ensure they are still built.
   bool SingleDeviceOutput = !llvm::any_of(OffloadActions, [](Action *A) {
diff --git a/clang/test/Driver/hip-binding.hip b/clang/test/Driver/hip-binding.hip
index 79ec2039edb74c..cb17112c28d90a 100644
--- a/clang/test/Driver/hip-binding.hip
+++ b/clang/test/Driver/hip-binding.hip
@@ -64,10 +64,14 @@
 // MULTI-D-ONLY-NEXT: # "amdgcn-amd-amdhsa" - "clang", inputs: ["[[INPUT]]"], output: "[[GFX90a:.+]]"
 // MULTI-D-ONLY-NEXT: # "amdgcn-amd-amdhsa" - "AMDGCN::Linker", inputs: ["[[GFX90a]]"], output: "[[GFX90a_OUT:.+]]"
 //
-// RUN: not %clang -### --target=x86_64-linux-gnu --offload-new-driver -ccc-print-bindings -nogpulib -nogpuinc \
-// RUN:        --offload-arch=gfx90a --offload-arch=gfx908 --offload-device-only -c -o %t %s 2>&1 \
+// RUN: %clang -### --target=x86_64-linux-gnu --offload-new-driver -ccc-print-bindings -nogpulib -nogpuinc \
+// RUN:        --offload-arch=gfx90a --offload-arch=gfx908 --offload-device-only -c -o a.out %s 2>&1 \
 // RUN: | FileCheck -check-prefix=MULTI-D-ONLY-O %s
-// MULTI-D-ONLY-O: error: cannot specify -o when generating multiple output files
+//      MULTI-D-ONLY-O: "amdgcn-amd-amdhsa" - "clang", inputs: ["[[INPUT:.+]]"], output: "[[GFX908_OBJ:.+]]"
+// MULTI-D-ONLY-O-NEXT: "amdgcn-amd-amdhsa" - "AMDGCN::Linker", inputs: ["[[GFX908_OBJ]]"], output: "[[GFX908:.+]]"
+// MULTI-D-ONLY-O-NEXT: "amdgcn-amd-amdhsa" - "clang", inputs: ["[[INPUT]]"], output: "[[GFX90A_OBJ:.+]]"
+// MULTI-D-ONLY-O-NEXT: "amdgcn-amd-amdhsa" - "AMDGCN::Linker", inputs: ["[[GFX90A_OBJ]]"], output: "[[GFX90A:.+]]"
+// MULTI-D-ONLY-O-NEXT: "amdgcn-amd-amdhsa" - "AMDGCN::Linker", inputs: ["[[GFX908]]", "[[GFX90A]]"], output: "a.out"
 
 //
 // Check to ensure that we can use '-fsyntax-only' for HIP output with the new

>From dd6f3db812231e4cad9b133f918671de77c5d488 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 8 Mar 2024 14:15:46 -0600
Subject: [PATCH 2/2] [libc][NFCI] Remove lane size template argument on RPC
 server

Summary:
We previously changed the data layout for the RPC buffer to make it lane
size agnostic. I put off making the same change for the server to keep
that patch smaller. This patch simply reorganizes code by making the lane
size a runtime argument rather than a template parameter. This heavily
simplifies a lot of code and removes the `std::variant` over server types.
---
 libc/src/__support/RPC/rpc.h         |  23 +-
 libc/utils/gpu/server/rpc_server.cpp | 485 +++++++++++++--------------
 2 files changed, 235 insertions(+), 273 deletions(-)

diff --git a/libc/src/__support/RPC/rpc.h b/libc/src/__support/RPC/rpc.h
index 5ed39ae0d7f7a9..5dcae518bb6f8f 100644
--- a/libc/src/__support/RPC/rpc.h
+++ b/libc/src/__support/RPC/rpc.h
@@ -310,7 +310,7 @@ template <bool T> struct Port {
   LIBC_INLINE Port &operator=(Port &&) = default;
 
   friend struct Client;
-  template <uint32_t U> friend struct Server;
+  friend struct Server;
   friend class cpp::optional<Port<T>>;
 
 public:
@@ -369,7 +369,7 @@ static_assert(cpp::is_trivially_copyable<Client>::value &&
               "The client is not trivially copyable from the server");
 
 /// The RPC server used to respond to the client.
-template <uint32_t lane_size> struct Server {
+struct Server {
   LIBC_INLINE Server() = default;
   LIBC_INLINE Server(const Server &) = delete;
   LIBC_INLINE Server &operator=(const Server &) = delete;
@@ -379,10 +379,12 @@ template <uint32_t lane_size> struct Server {
       : process(port_count, buffer) {}
 
   using Port = rpc::Port<true>;
-  LIBC_INLINE cpp::optional<Port> try_open(uint32_t start = 0);
-  LIBC_INLINE Port open();
+  LIBC_INLINE cpp::optional<Port> try_open(uint32_t lane_size,
+                                           uint32_t start = 0);
+  LIBC_INLINE Port open(uint32_t lane_size);
 
-  LIBC_INLINE static uint64_t allocation_size(uint32_t port_count) {
+  LIBC_INLINE static uint64_t allocation_size(uint32_t lane_size,
+                                              uint32_t port_count) {
     return Process<true>::allocation_size(port_count, lane_size);
   }
 
@@ -556,10 +558,8 @@ template <uint16_t opcode>
 
 /// Attempts to open a port to use as the server. The server can only open a
 /// port if it has a pending receive operation
-template <uint32_t lane_size>
-[[clang::convergent]] LIBC_INLINE
-    cpp::optional<typename Server<lane_size>::Port>
-    Server<lane_size>::try_open(uint32_t start) {
+[[clang::convergent]] LIBC_INLINE cpp::optional<typename Server::Port>
+Server::try_open(uint32_t lane_size, uint32_t start) {
   // Perform a naive linear scan for a port that has a pending request.
   for (uint32_t index = start; index < process.port_count; ++index) {
     uint64_t lane_mask = gpu::get_lane_mask();
@@ -588,10 +588,9 @@ template <uint32_t lane_size>
   return cpp::nullopt;
 }
 
-template <uint32_t lane_size>
-LIBC_INLINE typename Server<lane_size>::Port Server<lane_size>::open() {
+LIBC_INLINE Server::Port Server::open(uint32_t lane_size) {
   for (;;) {
-    if (cpp::optional<Server::Port> p = try_open())
+    if (cpp::optional<Server::Port> p = try_open(lane_size))
       return cpp::move(p.value());
     sleep_briefly();
   }
diff --git a/libc/utils/gpu/server/rpc_server.cpp b/libc/utils/gpu/server/rpc_server.cpp
index 707807a5cbaf7d..e21a9c05eaa68f 100644
--- a/libc/utils/gpu/server/rpc_server.cpp
+++ b/libc/utils/gpu/server/rpc_server.cpp
@@ -27,228 +27,218 @@ static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
 static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
               "Incorrect maximum port count");
 
-// The client needs to support different lane sizes for the SIMT model. Because
-// of this we need to select between the possible sizes that the client can use.
-struct Server {
-  template <uint32_t lane_size>
-  Server(std::unique_ptr<rpc::Server<lane_size>> &&server)
-      : server(std::move(server)) {}
-
-  rpc_status_t handle_server(
-      const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
-      const std::unordered_map<rpc_opcode_t, void *> &callback_data,
-      uint32_t &index) {
-    rpc_status_t ret = RPC_STATUS_SUCCESS;
-    std::visit(
-        [&](auto &server) {
-          ret = handle_server(*server, callbacks, callback_data, index);
-        },
-        server);
-    return ret;
-  }
-
-private:
-  template <uint32_t lane_size>
-  rpc_status_t handle_server(
-      rpc::Server<lane_size> &server,
-      const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
-      const std::unordered_map<rpc_opcode_t, void *> &callback_data,
-      uint32_t &index) {
-    auto port = server.try_open(index);
-    if (!port)
-      return RPC_STATUS_SUCCESS;
-
-    switch (port->get_opcode()) {
-    case RPC_WRITE_TO_STREAM:
-    case RPC_WRITE_TO_STDERR:
-    case RPC_WRITE_TO_STDOUT:
-    case RPC_WRITE_TO_STDOUT_NEWLINE: {
-      uint64_t sizes[lane_size] = {0};
-      void *strs[lane_size] = {nullptr};
-      FILE *files[lane_size] = {nullptr};
-      if (port->get_opcode() == RPC_WRITE_TO_STREAM) {
-        port->recv([&](rpc::Buffer *buffer, uint32_t id) {
-          files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
-        });
-      } else if (port->get_opcode() == RPC_WRITE_TO_STDERR) {
-        std::fill(files, files + lane_size, stderr);
-      } else {
-        std::fill(files, files + lane_size, stdout);
-      }
-
-      port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
-      port->send([&](rpc::Buffer *buffer, uint32_t id) {
-        flockfile(files[id]);
-        buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]);
-        if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
-            buffer->data[0] == sizes[id])
-          buffer->data[0] += fwrite_unlocked("\n", 1, 1, files[id]);
-        funlockfile(files[id]);
-        delete[] reinterpret_cast<uint8_t *>(strs[id]);
-      });
-      break;
-    }
-    case RPC_READ_FROM_STREAM: {
-      uint64_t sizes[lane_size] = {0};
-      void *data[lane_size] = {nullptr};
-      port->recv([&](rpc::Buffer *buffer, uint32_t id) {
-        data[id] = new char[buffer->data[0]];
-        sizes[id] = fread(data[id], 1, buffer->data[0],
-                          file::to_stream(buffer->data[1]));
-      });
-      port->send_n(data, sizes);
-      port->send([&](rpc::Buffer *buffer, uint32_t id) {
-        delete[] reinterpret_cast<uint8_t *>(data[id]);
-        std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
-      });
-      break;
-    }
-    case RPC_READ_FGETS: {
-      uint64_t sizes[lane_size] = {0};
-      void *data[lane_size] = {nullptr};
-      port->recv([&](rpc::Buffer *buffer, uint32_t id) {
-        data[id] = new char[buffer->data[0]];
-        const char *str =
-            fgets(reinterpret_cast<char *>(data[id]), buffer->data[0],
-                  file::to_stream(buffer->data[1]));
-        sizes[id] = !str ? 0 : std::strlen(str) + 1;
-      });
-      port->send_n(data, sizes);
-      for (uint32_t id = 0; id < lane_size; ++id)
-        if (data[id])
-          delete[] reinterpret_cast<uint8_t *>(data[id]);
-      break;
-    }
-    case RPC_OPEN_FILE: {
-      uint64_t sizes[lane_size] = {0};
-      void *paths[lane_size] = {nullptr};
-      port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; });
-      port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
-        FILE *file = fopen(reinterpret_cast<char *>(paths[id]),
-                           reinterpret_cast<char *>(buffer->data));
-        buffer->data[0] = reinterpret_cast<uintptr_t>(file);
-      });
-      break;
-    }
-    case RPC_CLOSE_FILE: {
-      port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
-        FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
-        buffer->data[0] = fclose(file);
-      });
-      break;
-    }
-    case RPC_EXIT: {
-      // Send a response to the client to signal that we are ready to exit.
-      port->recv_and_send([](rpc::Buffer *) {});
-      port->recv([](rpc::Buffer *buffer) {
-        int status = 0;
-        std::memcpy(&status, buffer->data, sizeof(int));
-        exit(status);
-      });
-      break;
-    }
-    case RPC_ABORT: {
-      // Send a response to the client to signal that we are ready to abort.
-      port->recv_and_send([](rpc::Buffer *) {});
-      port->recv([](rpc::Buffer *) {});
-      abort();
-      break;
-    }
-    case RPC_HOST_CALL: {
-      uint64_t sizes[lane_size] = {0};
-      void *args[lane_size] = {nullptr};
-      port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
+template <uint32_t lane_size>
+rpc_status_t handle_server_impl(
+    rpc::Server &server,
+    const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
+    const std::unordered_map<rpc_opcode_t, void *> &callback_data,
+    uint32_t &index) {
+  auto port = server.try_open(lane_size, index);
+  if (!port)
+    return RPC_STATUS_SUCCESS;
+
+  switch (port->get_opcode()) {
+  case RPC_WRITE_TO_STREAM:
+  case RPC_WRITE_TO_STDERR:
+  case RPC_WRITE_TO_STDOUT:
+  case RPC_WRITE_TO_STDOUT_NEWLINE: {
+    uint64_t sizes[lane_size] = {0};
+    void *strs[lane_size] = {nullptr};
+    FILE *files[lane_size] = {nullptr};
+    if (port->get_opcode() == RPC_WRITE_TO_STREAM) {
       port->recv([&](rpc::Buffer *buffer, uint32_t id) {
-        reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
-      });
-      port->send([&](rpc::Buffer *, uint32_t id) {
-        delete[] reinterpret_cast<uint8_t *>(args[id]);
-      });
-      break;
-    }
-    case RPC_FEOF: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        buffer->data[0] = feof(file::to_stream(buffer->data[0]));
-      });
-      break;
-    }
-    case RPC_FERROR: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        buffer->data[0] = ferror(file::to_stream(buffer->data[0]));
-      });
-      break;
-    }
-    case RPC_CLEARERR: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        clearerr(file::to_stream(buffer->data[0]));
-      });
-      break;
-    }
-    case RPC_FSEEK: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        buffer->data[0] = fseek(file::to_stream(buffer->data[0]),
-                                static_cast<long>(buffer->data[1]),
-                                static_cast<int>(buffer->data[2]));
-      });
-      break;
-    }
-    case RPC_FTELL: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        buffer->data[0] = ftell(file::to_stream(buffer->data[0]));
-      });
-      break;
-    }
-    case RPC_FFLUSH: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        buffer->data[0] = fflush(file::to_stream(buffer->data[0]));
+        files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
       });
-      break;
-    }
-    case RPC_UNGETC: {
-      port->recv_and_send([](rpc::Buffer *buffer) {
-        buffer->data[0] = ungetc(static_cast<int>(buffer->data[0]),
-                                 file::to_stream(buffer->data[1]));
-      });
-      break;
-    }
-    case RPC_NOOP: {
-      port->recv([](rpc::Buffer *) {});
-      break;
-    }
-    default: {
-      auto handler =
-          callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));
-
-      // We error out on an unhandled opcode.
-      if (handler == callbacks.end())
-        return RPC_STATUS_UNHANDLED_OPCODE;
-
-      // Invoke the registered callback with a reference to the port.
-      void *data =
-          callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
-      rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
-      (handler->second)(port_ref, data);
-    }
+    } else if (port->get_opcode() == RPC_WRITE_TO_STDERR) {
+      std::fill(files, files + lane_size, stderr);
+    } else {
+      std::fill(files, files + lane_size, stdout);
     }
 
-    // Increment the index so we start the scan after this port.
-    index = port->get_index() + 1;
-    port->close();
-    return RPC_STATUS_CONTINUE;
+    port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
+    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+      flockfile(files[id]);
+      buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]);
+      if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
+          buffer->data[0] == sizes[id])
+        buffer->data[0] += fwrite_unlocked("\n", 1, 1, files[id]);
+      funlockfile(files[id]);
+      delete[] reinterpret_cast<uint8_t *>(strs[id]);
+    });
+    break;
+  }
+  case RPC_READ_FROM_STREAM: {
+    uint64_t sizes[lane_size] = {0};
+    void *data[lane_size] = {nullptr};
+    port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+      data[id] = new char[buffer->data[0]];
+      sizes[id] =
+          fread(data[id], 1, buffer->data[0], file::to_stream(buffer->data[1]));
+    });
+    port->send_n(data, sizes);
+    port->send([&](rpc::Buffer *buffer, uint32_t id) {
+      delete[] reinterpret_cast<uint8_t *>(data[id]);
+      std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
+    });
+    break;
+  }
+  case RPC_READ_FGETS: {
+    uint64_t sizes[lane_size] = {0};
+    void *data[lane_size] = {nullptr};
+    port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+      data[id] = new char[buffer->data[0]];
+      const char *str =
+          fgets(reinterpret_cast<char *>(data[id]), buffer->data[0],
+                file::to_stream(buffer->data[1]));
+      sizes[id] = !str ? 0 : std::strlen(str) + 1;
+    });
+    port->send_n(data, sizes);
+    for (uint32_t id = 0; id < lane_size; ++id)
+      if (data[id])
+        delete[] reinterpret_cast<uint8_t *>(data[id]);
+    break;
+  }
+  case RPC_OPEN_FILE: {
+    uint64_t sizes[lane_size] = {0};
+    void *paths[lane_size] = {nullptr};
+    port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; });
+    port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
+      FILE *file = fopen(reinterpret_cast<char *>(paths[id]),
+                         reinterpret_cast<char *>(buffer->data));
+      buffer->data[0] = reinterpret_cast<uintptr_t>(file);
+    });
+    break;
+  }
+  case RPC_CLOSE_FILE: {
+    port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
+      FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
+      buffer->data[0] = fclose(file);
+    });
+    break;
+  }
+  case RPC_EXIT: {
+    // Send a response to the client to signal that we are ready to exit.
+    port->recv_and_send([](rpc::Buffer *) {});
+    port->recv([](rpc::Buffer *buffer) {
+      int status = 0;
+      std::memcpy(&status, buffer->data, sizeof(int));
+      exit(status);
+    });
+    break;
+  }
+  case RPC_ABORT: {
+    // Send a response to the client to signal that we are ready to abort.
+    port->recv_and_send([](rpc::Buffer *) {});
+    port->recv([](rpc::Buffer *) {});
+    abort();
+    break;
+  }
+  case RPC_HOST_CALL: {
+    uint64_t sizes[lane_size] = {0};
+    void *args[lane_size] = {nullptr};
+    port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
+    port->recv([&](rpc::Buffer *buffer, uint32_t id) {
+      reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
+    });
+    port->send([&](rpc::Buffer *, uint32_t id) {
+      delete[] reinterpret_cast<uint8_t *>(args[id]);
+    });
+    break;
+  }
+  case RPC_FEOF: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      buffer->data[0] = feof(file::to_stream(buffer->data[0]));
+    });
+    break;
+  }
+  case RPC_FERROR: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      buffer->data[0] = ferror(file::to_stream(buffer->data[0]));
+    });
+    break;
+  }
+  case RPC_CLEARERR: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      clearerr(file::to_stream(buffer->data[0]));
+    });
+    break;
+  }
+  case RPC_FSEEK: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      buffer->data[0] = fseek(file::to_stream(buffer->data[0]),
+                              static_cast<long>(buffer->data[1]),
+                              static_cast<int>(buffer->data[2]));
+    });
+    break;
+  }
+  case RPC_FTELL: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      buffer->data[0] = ftell(file::to_stream(buffer->data[0]));
+    });
+    break;
+  }
+  case RPC_FFLUSH: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      buffer->data[0] = fflush(file::to_stream(buffer->data[0]));
+    });
+    break;
+  }
+  case RPC_UNGETC: {
+    port->recv_and_send([](rpc::Buffer *buffer) {
+      buffer->data[0] = ungetc(static_cast<int>(buffer->data[0]),
+                               file::to_stream(buffer->data[1]));
+    });
+    break;
+  }
+  case RPC_NOOP: {
+    port->recv([](rpc::Buffer *) {});
+    break;
+  }
+  default: {
+    auto handler =
+        callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));
+
+    // We error out on an unhandled opcode.
+    if (handler == callbacks.end())
+      return RPC_STATUS_UNHANDLED_OPCODE;
+
+    // Invoke the registered callback with a reference to the port.
+    void *data =
+        callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
+    rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
+    (handler->second)(port_ref, data);
+  }
   }
 
-  std::variant<std::unique_ptr<rpc::Server<1>>,
-               std::unique_ptr<rpc::Server<32>>,
-               std::unique_ptr<rpc::Server<64>>>
-      server;
-};
+  // Increment the index so we start the scan after this port.
+  index = port->get_index() + 1;
+  port->close();
+
+  return RPC_STATUS_CONTINUE;
+}
 
 struct Device {
-  template <typename T>
-  Device(uint32_t num_ports, void *buffer, std::unique_ptr<T> &&server)
-      : buffer(buffer), server(std::move(server)), client(num_ports, buffer) {}
+  Device(uint32_t lane_size, uint32_t num_ports, void *buffer)
+      : lane_size(lane_size), buffer(buffer), server(num_ports, buffer),
+        client(num_ports, buffer) {}
+
+  rpc_status_t handle_server(uint32_t &index) {
+    switch (lane_size) {
+    case 1:
+      return handle_server_impl<1>(server, callbacks, callback_data, index);
+    case 32:
+      return handle_server_impl<32>(server, callbacks, callback_data, index);
+
+    case 64:
+      return handle_server_impl<64>(server, callbacks, callback_data, index);
+    default:
+      return RPC_STATUS_INVALID_LANE_SIZE;
+    }
+  }
+
+  uint32_t lane_size;
   void *buffer;
-  Server server;
+  rpc::Server server;
   rpc::Client client;
   std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
   std::unordered_map<rpc_opcode_t, void *> callback_data;
@@ -287,24 +277,6 @@ rpc_status_t rpc_shutdown(void) {
   return RPC_STATUS_SUCCESS;
 }
 
-template <uint32_t lane_size>
-rpc_status_t server_init_impl(uint32_t device_id, uint64_t num_ports,
-                              rpc_alloc_ty alloc, void *data) {
-  uint64_t size = rpc::Server<lane_size>::allocation_size(num_ports);
-  void *buffer = alloc(size, data);
-
-  if (!buffer)
-    return RPC_STATUS_ERROR;
-
-  state->devices[device_id] = std::make_unique<Device>(
-      num_ports, buffer,
-      std::make_unique<rpc::Server<lane_size>>(num_ports, buffer));
-  if (!state->devices[device_id])
-    return RPC_STATUS_ERROR;
-
-  return RPC_STATUS_SUCCESS;
-}
-
 rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
                              uint32_t lane_size, rpc_alloc_ty alloc,
                              void *data) {
@@ -312,28 +284,20 @@ rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
     return RPC_STATUS_NOT_INITIALIZED;
   if (device_id >= state->num_devices)
     return RPC_STATUS_OUT_OF_RANGE;
+  if (lane_size != 1 && lane_size != 32 && lane_size != 64)
+    return RPC_STATUS_INVALID_LANE_SIZE;
 
   if (!state->devices[device_id]) {
-    switch (lane_size) {
-    case 1:
-      if (rpc_status_t err =
-              server_init_impl<1>(device_id, num_ports, alloc, data))
-        return err;
-      break;
-    case 32: {
-      if (rpc_status_t err =
-              server_init_impl<32>(device_id, num_ports, alloc, data))
-        return err;
-      break;
-    }
-    case 64:
-      if (rpc_status_t err =
-              server_init_impl<64>(device_id, num_ports, alloc, data))
-        return err;
-      break;
-    default:
-      return RPC_STATUS_INVALID_LANE_SIZE;
-    }
+    uint64_t size = rpc::Server::allocation_size(lane_size, num_ports);
+    void *buffer = alloc(size, data);
+
+    if (!buffer)
+      return RPC_STATUS_ERROR;
+
+    state->devices[device_id] =
+        std::make_unique<Device>(lane_size, num_ports, buffer);
+    if (!state->devices[device_id])
+      return RPC_STATUS_ERROR;
   }
 
   return RPC_STATUS_SUCCESS;
@@ -365,9 +329,8 @@ rpc_status_t rpc_handle_server(uint32_t device_id) {
 
   uint32_t index = 0;
   for (;;) {
-    auto &device = *state->devices[device_id];
-    rpc_status_t status = device.server.handle_server(
-        device.callbacks, device.callback_data, index);
+    Device &device = *state->devices[device_id];
+    rpc_status_t status = device.handle_server(index);
     if (status != RPC_STATUS_CONTINUE)
       return status;
   }
@@ -396,26 +359,26 @@ const void *rpc_get_client_buffer(uint32_t device_id) {
 
 uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
 
-using ServerPort = std::variant<rpc::Server<0>::Port *>;
+using ServerPort = std::variant<rpc::Server::Port *>;
 
 ServerPort get_port(rpc_port_t ref) {
-  return reinterpret_cast<rpc::Server<0>::Port *>(ref.handle);
+  return reinterpret_cast<rpc::Server::Port *>(ref.handle);
 }
 
 void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
-  auto port = reinterpret_cast<rpc::Server<0>::Port *>(ref.handle);
+  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
   port->send([=](rpc::Buffer *buffer) {
     callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
   });
 }
 
 void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) {
-  auto port = reinterpret_cast<rpc::Server<0>::Port *>(ref.handle);
+  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
   port->send_n(src, size);
 }
 
 void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
-  auto port = reinterpret_cast<rpc::Server<0>::Port *>(ref.handle);
+  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
   port->recv([=](rpc::Buffer *buffer) {
     callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
   });
@@ -423,14 +386,14 @@ void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
 
 void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc,
                 void *data) {
-  auto port = reinterpret_cast<rpc::Server<0>::Port *>(ref.handle);
+  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
   auto alloc_fn = [=](uint64_t size) { return alloc(size, data); };
   port->recv_n(dst, size, alloc_fn);
 }
 
 void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
                        void *data) {
-  auto port = reinterpret_cast<rpc::Server<0>::Port *>(ref.handle);
+  auto port = reinterpret_cast<rpc::Server::Port *>(ref.handle);
   port->recv_and_send([=](rpc::Buffer *buffer) {
     callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
   });



More information about the cfe-commits mailing list