[Lldb-commits] [lldb] [lldb] Add a MainLoop version of DomainSocket::Accept (PR #108188)

via lldb-commits lldb-commits at lists.llvm.org
Wed Sep 11 03:58:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-lldb

Author: Pavel Labath (labath)

<details>
<summary>Changes</summary>

This adds a MainLoop-based (non-blocking) version of DomainSocket::Accept, to go along with the existing TCPSocket implementation.

---
Full diff: https://github.com/llvm/llvm-project/pull/108188.diff


9 Files Affected:

- (modified) lldb/include/lldb/Host/Socket.h (+12-1) 
- (modified) lldb/include/lldb/Host/common/TCPSocket.h (+2-8) 
- (modified) lldb/include/lldb/Host/common/UDPSocket.h (+7-1) 
- (modified) lldb/include/lldb/Host/posix/DomainSocket.h (+7-1) 
- (modified) lldb/source/Host/common/Socket.cpp (+14) 
- (modified) lldb/source/Host/common/TCPSocket.cpp (+3-16) 
- (modified) lldb/source/Host/common/UDPSocket.cpp (-4) 
- (modified) lldb/source/Host/posix/DomainSocket.cpp (+34-8) 
- (modified) lldb/unittests/Host/SocketTest.cpp (+38-2) 


``````````diff
diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h
index 764a048976eb41..14468c98ac5a3a 100644
--- a/lldb/include/lldb/Host/Socket.h
+++ b/lldb/include/lldb/Host/Socket.h
@@ -12,6 +12,7 @@
 #include <memory>
 #include <string>
 
+#include "lldb/Host/MainLoopBase.h"
 #include "lldb/lldb-private.h"
 
 #include "lldb/Host/SocketAddress.h"
@@ -97,7 +98,17 @@ class Socket : public IOObject {
 
   virtual Status Connect(llvm::StringRef name) = 0;
   virtual Status Listen(llvm::StringRef name, int backlog) = 0;
-  virtual Status Accept(Socket *&socket) = 0;
+
+  // Use the provided main loop instance to accept new connections. The callback
+  // will be called (from MainLoop::Run) for each new connection. This function
+  // does not block.
+  virtual llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
+  Accept(MainLoopBase &loop,
+         std::function<void(std::unique_ptr<Socket> socket)> sock_cb) = 0;
+
+  // Accept a single connection and "return" it in the pointer argument. This
+  // function blocks until the connection arrives.
+  virtual Status Accept(Socket *&socket);
 
   // Initialize a Tcp Socket object in listening mode.  listen and accept are
   // implemented separately because the caller may wish to manipulate or query
diff --git a/lldb/include/lldb/Host/common/TCPSocket.h b/lldb/include/lldb/Host/common/TCPSocket.h
index 78e80568e39967..eefe0240fe4a95 100644
--- a/lldb/include/lldb/Host/common/TCPSocket.h
+++ b/lldb/include/lldb/Host/common/TCPSocket.h
@@ -42,16 +42,10 @@ class TCPSocket : public Socket {
   Status Connect(llvm::StringRef name) override;
   Status Listen(llvm::StringRef name, int backlog) override;
 
-  // Use the provided main loop instance to accept new connections. The callback
-  // will be called (from MainLoop::Run) for each new connection. This function
-  // does not block.
+  using Socket::Accept;
   llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
   Accept(MainLoopBase &loop,
-         std::function<void(std::unique_ptr<TCPSocket> socket)> sock_cb);
-
-  // Accept a single connection and "return" it in the pointer argument. This
-  // function blocks until the connection arrives.
-  Status Accept(Socket *&conn_socket) override;
+         std::function<void(std::unique_ptr<Socket> socket)> sock_cb) override;
 
   Status CreateSocket(int domain);
 
diff --git a/lldb/include/lldb/Host/common/UDPSocket.h b/lldb/include/lldb/Host/common/UDPSocket.h
index bae707e345d87c..7348010d02ada2 100644
--- a/lldb/include/lldb/Host/common/UDPSocket.h
+++ b/lldb/include/lldb/Host/common/UDPSocket.h
@@ -27,7 +27,13 @@ class UDPSocket : public Socket {
   size_t Send(const void *buf, const size_t num_bytes) override;
   Status Connect(llvm::StringRef name) override;
   Status Listen(llvm::StringRef name, int backlog) override;
-  Status Accept(Socket *&socket) override;
+
+  llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
+  Accept(MainLoopBase &loop,
+         std::function<void(std::unique_ptr<Socket> socket)> sock_cb) override {
+    return llvm::errorCodeToError(
+        std::make_error_code(std::errc::operation_not_supported));
+  }
 
   SocketAddress m_sockaddr;
 };
diff --git a/lldb/include/lldb/Host/posix/DomainSocket.h b/lldb/include/lldb/Host/posix/DomainSocket.h
index 35c33811f60de6..983f43bd633719 100644
--- a/lldb/include/lldb/Host/posix/DomainSocket.h
+++ b/lldb/include/lldb/Host/posix/DomainSocket.h
@@ -14,11 +14,17 @@
 namespace lldb_private {
 class DomainSocket : public Socket {
 public:
+  DomainSocket(NativeSocket socket, bool should_close,
+               bool child_processes_inherit);
   DomainSocket(bool should_close, bool child_processes_inherit);
 
   Status Connect(llvm::StringRef name) override;
   Status Listen(llvm::StringRef name, int backlog) override;
-  Status Accept(Socket *&socket) override;
+
+  using Socket::Accept;
+  llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
+  Accept(MainLoopBase &loop,
+         std::function<void(std::unique_ptr<Socket> socket)> sock_cb) override;
 
   std::string GetRemoteConnectionURI() const override;
 
diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp
index 1a63571b94c6f1..d69eb608204033 100644
--- a/lldb/source/Host/common/Socket.cpp
+++ b/lldb/source/Host/common/Socket.cpp
@@ -10,6 +10,7 @@
 
 #include "lldb/Host/Config.h"
 #include "lldb/Host/Host.h"
+#include "lldb/Host/MainLoop.h"
 #include "lldb/Host/SocketAddress.h"
 #include "lldb/Host/common/TCPSocket.h"
 #include "lldb/Host/common/UDPSocket.h"
@@ -459,6 +460,19 @@ NativeSocket Socket::CreateSocket(const int domain, const int type,
   return sock;
 }
 
+Status Socket::Accept(Socket *&socket) {
+  MainLoop accept_loop;
+  llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles =
+      Accept(accept_loop,
+             [&accept_loop, &socket](std::unique_ptr<Socket> sock) {
+               socket = sock.release();
+               accept_loop.RequestTermination();
+             });
+  if (!expected_handles)
+    return Status::FromError(expected_handles.takeError());
+  return accept_loop.Run();
+}
+
 NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr,
                                   socklen_t *addrlen,
                                   bool child_processes_inherit, Status &error) {
diff --git a/lldb/source/Host/common/TCPSocket.cpp b/lldb/source/Host/common/TCPSocket.cpp
index b28ba148ee1afa..2d16b605af9497 100644
--- a/lldb/source/Host/common/TCPSocket.cpp
+++ b/lldb/source/Host/common/TCPSocket.cpp
@@ -232,9 +232,9 @@ void TCPSocket::CloseListenSockets() {
   m_listen_sockets.clear();
 }
 
-llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> TCPSocket::Accept(
-    MainLoopBase &loop,
-    std::function<void(std::unique_ptr<TCPSocket> socket)> sock_cb) {
+llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
+TCPSocket::Accept(MainLoopBase &loop,
+                  std::function<void(std::unique_ptr<Socket> socket)> sock_cb) {
   if (m_listen_sockets.size() == 0)
     return llvm::createStringError("No open listening sockets!");
 
@@ -278,19 +278,6 @@ llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> TCPSocket::Accept(
   return handles;
 }
 
-Status TCPSocket::Accept(Socket *&conn_socket) {
-  MainLoop accept_loop;
-  llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles =
-      Accept(accept_loop,
-             [&accept_loop, &conn_socket](std::unique_ptr<TCPSocket> sock) {
-               conn_socket = sock.release();
-               accept_loop.RequestTermination();
-             });
-  if (!expected_handles)
-    return Status::FromError(expected_handles.takeError());
-  return accept_loop.Run();
-}
-
 int TCPSocket::SetOptionNoDelay() {
   return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
 }
diff --git a/lldb/source/Host/common/UDPSocket.cpp b/lldb/source/Host/common/UDPSocket.cpp
index 2a7a6cff414b14..05d7b2e6506027 100644
--- a/lldb/source/Host/common/UDPSocket.cpp
+++ b/lldb/source/Host/common/UDPSocket.cpp
@@ -47,10 +47,6 @@ Status UDPSocket::Listen(llvm::StringRef name, int backlog) {
   return Status::FromErrorStringWithFormat("%s", g_not_supported_error);
 }
 
-Status UDPSocket::Accept(Socket *&socket) {
-  return Status::FromErrorStringWithFormat("%s", g_not_supported_error);
-}
-
 llvm::Expected<std::unique_ptr<UDPSocket>>
 UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit) {
   std::unique_ptr<UDPSocket> socket;
diff --git a/lldb/source/Host/posix/DomainSocket.cpp b/lldb/source/Host/posix/DomainSocket.cpp
index 2d18995c3bb469..369123f2239302 100644
--- a/lldb/source/Host/posix/DomainSocket.cpp
+++ b/lldb/source/Host/posix/DomainSocket.cpp
@@ -7,11 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "lldb/Host/posix/DomainSocket.h"
+#include "lldb/Utility/LLDBLog.h"
 
 #include "llvm/Support/Errno.h"
 #include "llvm/Support/FileSystem.h"
 
 #include <cstddef>
+#include <memory>
 #include <sys/socket.h>
 #include <sys/un.h>
 
@@ -57,7 +59,14 @@ static bool SetSockAddr(llvm::StringRef name, const size_t name_offset,
 }
 
 DomainSocket::DomainSocket(bool should_close, bool child_processes_inherit)
-    : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) {}
+    : DomainSocket(kInvalidSocketValue, should_close, child_processes_inherit) {
+}
+
+DomainSocket::DomainSocket(NativeSocket socket, bool should_close,
+                           bool child_processes_inherit)
+    : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) {
+  m_socket = socket;
+}
 
 DomainSocket::DomainSocket(SocketProtocol protocol,
                            bool child_processes_inherit)
@@ -108,14 +117,31 @@ Status DomainSocket::Listen(llvm::StringRef name, int backlog) {
   return error;
 }
 
-Status DomainSocket::Accept(Socket *&socket) {
-  Status error;
-  auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr,
-                              m_child_processes_inherit, error);
-  if (error.Success())
-    socket = new DomainSocket(conn_fd, *this);
+llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> DomainSocket::Accept(
+    MainLoopBase &loop,
+    std::function<void(std::unique_ptr<Socket> socket)> sock_cb) {
+  // TODO: Refactor MainLoop to avoid the shared_ptr requirement.
+  auto io_sp = std::make_shared<DomainSocket>(GetNativeSocket(), false,
+                                              m_child_processes_inherit);
+  auto cb = [this, sock_cb](MainLoopBase &loop) {
+    Log *log = GetLog(LLDBLog::Host);
+    Status error;
+    auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr,
+                                m_child_processes_inherit, error);
+    if (error.Fail()) {
+      LLDB_LOG(log, "AcceptSocket({0}): {1}", GetNativeSocket(), error);
+      return;
+    }
+    std::unique_ptr<DomainSocket> sock_up(new DomainSocket(conn_fd, *this));
+    sock_cb(std::move(sock_up));
+  };
 
-  return error;
+  Status error;
+  std::vector<MainLoopBase::ReadHandleUP> handles;
+  handles.emplace_back(loop.RegisterReadObject(io_sp, cb, error));
+  if (error.Fail())
+    return error.ToError();
+  return handles;
 }
 
 size_t DomainSocket::GetNameOffset() const { return 0; }
diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp
index 3a356d11ba1a51..a93b928e274d03 100644
--- a/lldb/unittests/Host/SocketTest.cpp
+++ b/lldb/unittests/Host/SocketTest.cpp
@@ -85,6 +85,42 @@ TEST_P(SocketTest, DomainListenConnectAccept) {
   std::unique_ptr<DomainSocket> socket_b_up;
   CreateDomainConnectedSockets(Path, &socket_a_up, &socket_b_up);
 }
+
+TEST_P(SocketTest, DomainMainLoopAccept) {
+  llvm::SmallString<64> Path;
+  std::error_code EC = llvm::sys::fs::createUniqueDirectory("DomainListenConnectAccept", Path);
+  ASSERT_FALSE(EC);
+  llvm::sys::path::append(Path, "test");
+
+  // Skip the test if the $TMPDIR is too long to hold a domain socket.
+  if (Path.size() > 107u)
+    return;
+
+  auto listen_socket_up = std::make_unique<DomainSocket>(
+      /*should_close=*/true, /*child_process_inherit=*/false);
+  Status error = listen_socket_up->Listen(Path, 5);
+  ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded());
+  ASSERT_TRUE(listen_socket_up->IsValid());
+
+  MainLoop loop;
+  std::unique_ptr<Socket> accepted_socket_up;
+  auto expected_handles = listen_socket_up->Accept(
+      loop, [&accepted_socket_up, &loop](std::unique_ptr<Socket> sock_up) {
+        accepted_socket_up = std::move(sock_up);
+        loop.RequestTermination();
+      });
+  ASSERT_THAT_EXPECTED(expected_handles, llvm::Succeeded());
+
+  auto connect_socket_up = std::make_unique<DomainSocket>(
+      /*should_close=*/true, /*child_process_inherit=*/false);
+  ASSERT_THAT_ERROR(connect_socket_up->Connect(Path).ToError(),
+                    llvm::Succeeded());
+  ASSERT_TRUE(connect_socket_up->IsValid());
+
+  loop.Run();
+  ASSERT_TRUE(accepted_socket_up);
+  ASSERT_TRUE(accepted_socket_up->IsValid());
+}
 #endif
 
 TEST_P(SocketTest, TCPListen0ConnectAccept) {
@@ -109,9 +145,9 @@ TEST_P(SocketTest, TCPMainLoopAccept) {
   ASSERT_TRUE(listen_socket_up->IsValid());
 
   MainLoop loop;
-  std::unique_ptr<TCPSocket> accepted_socket_up;
+  std::unique_ptr<Socket> accepted_socket_up;
   auto expected_handles = listen_socket_up->Accept(
-      loop, [&accepted_socket_up, &loop](std::unique_ptr<TCPSocket> sock_up) {
+      loop, [&accepted_socket_up, &loop](std::unique_ptr<Socket> sock_up) {
         accepted_socket_up = std::move(sock_up);
         loop.RequestTermination();
       });

``````````

</details>


https://github.com/llvm/llvm-project/pull/108188


More information about the lldb-commits mailing list