[llvm] 76321b9 - [llvm][Support] Implement raw_socket_stream::read with optional timeout (#92308)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 21 20:50:31 PDT 2024


Author: Connor Sughrue
Date: 2024-07-21T23:50:28-04:00
New Revision: 76321b9f08ef31a2b8ca26f7522aee511a05f7a8

URL: https://github.com/llvm/llvm-project/commit/76321b9f08ef31a2b8ca26f7522aee511a05f7a8
DIFF: https://github.com/llvm/llvm-project/commit/76321b9f08ef31a2b8ca26f7522aee511a05f7a8.diff

LOG: [llvm][Support] Implement raw_socket_stream::read with optional timeout (#92308)

This PR implements `raw_socket_stream::read`, which extends the base
class `raw_fd_stream::read` with an optional timeout. `raw_socket_stream::read`
provides a way to time out the underlying `::read`. The timeout functionality
was not added to `raw_fd_stream::read` to avoid needlessly increasing compile
times and to allow convenient code reuse with `ListeningSocket::accept`,
which also requires timeout functionality. This PR supports the module
build daemon and will help guarantee it never becomes a zombie process.

Added: 
    

Modified: 
    llvm/include/llvm/Support/raw_socket_stream.h
    llvm/lib/Support/raw_socket_stream.cpp
    llvm/unittests/Support/raw_socket_stream_test.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h
index eed865fb5af49..6c65a66dec9a4 100644
--- a/llvm/include/llvm/Support/raw_socket_stream.h
+++ b/llvm/include/llvm/Support/raw_socket_stream.h
@@ -92,13 +92,14 @@ class ListeningSocket {
   /// Accepts an incoming connection on the listening socket. This method can
   /// optionally either block until a connection is available or timeout after a
   /// specified amount of time has passed. By default the method will block
-  /// until the socket has recieved a connection.
+  /// until the socket has received a connection. If the accept times out, the
+  /// returned error holds std::errc::timed_out.
   ///
   /// \param Timeout An optional timeout duration in milliseconds. Setting
-  /// Timeout to -1 causes accept to block indefinitely
+  /// Timeout to a negative number causes ::accept to block indefinitely
   ///
-  Expected<std::unique_ptr<raw_socket_stream>>
-  accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
+  Expected<std::unique_ptr<raw_socket_stream>> accept(
+      const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
 
   /// Creates a listening socket bound to the specified file system path.
   /// Handles the socket creation, binding, and immediately starts listening for
@@ -124,11 +125,28 @@ class raw_socket_stream : public raw_fd_stream {
 
 public:
   raw_socket_stream(int SocketFD);
+  ~raw_socket_stream();
+
   /// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
   /// SocketPath.
   static Expected<std::unique_ptr<raw_socket_stream>>
   createConnectedUnix(StringRef SocketPath);
-  ~raw_socket_stream();
+
+  /// Attempt to read from the raw_socket_stream's file descriptor.
+  ///
+  /// This method either blocks until data is read or an error has occurred,
+  /// or times out after a specified amount of time has passed. By default the
+  /// method blocks until the socket has read data or encountered an error. If
+  /// the read times out, this method returns -1 and records
+  /// std::errc::timed_out as the stream's error.
+  ///
+  /// \param Ptr The start of the buffer that will hold any read data
+  /// \param Size The number of bytes to be read
+  /// \param Timeout An optional timeout duration in milliseconds
+  ///
+  ssize_t read(
+      char *Ptr, size_t Size,
+      const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
 };
 
 } // end namespace llvm

diff  --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 4cd3d58b80198..04b3233084a41 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -18,6 +18,7 @@
 
 #include <atomic>
 #include <fcntl.h>
+#include <functional>
 #include <thread>
 
 #ifndef _WIN32
@@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
 #endif // _WIN32
 }
 
-Expected<std::unique_ptr<raw_socket_stream>>
-ListeningSocket::accept(std::chrono::milliseconds Timeout) {
-
-  struct pollfd FDs[2];
-  FDs[0].events = POLLIN;
+// If a file descriptor being monitored by ::poll is closed by another thread,
+// the result is unspecified. In case ::poll does not unblock and return when
+// ActiveFD is closed, you can provide another file descriptor via CancelFD
+// that, once it has data to read, will cause poll to return. Typically
+// CancelFD is the read end of a unidirectional pipe.
+//
+// Timeout should be -1 to block indefinitely.
+//
+// getActiveFD is a callback so ActiveFD can be an int or a std::atomic<int>.
+static std::error_code
+manageTimeout(const std::chrono::milliseconds &Timeout,
+              const std::function<int()> &getActiveFD,
+              const std::optional<int> &CancelFD = std::nullopt) {
+  struct pollfd FD[2];
+  FD[0].events = POLLIN;
 #ifdef _WIN32
-  SOCKET WinServerSock = _get_osfhandle(FD);
-  FDs[0].fd = WinServerSock;
+  SOCKET WinServerSock = _get_osfhandle(getActiveFD());
+  FD[0].fd = WinServerSock;
 #else
-  FDs[0].fd = FD;
+  FD[0].fd = getActiveFD();
 #endif
-  FDs[1].events = POLLIN;
-  FDs[1].fd = PipeFD[0];
-
-  // Keep track of how much time has passed in case poll is interupted by a
-  // signal and needs to be recalled
-  int RemainingTime = Timeout.count();
-  std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
-  int PollStatus = -1;
-
-  while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
-    if (Timeout.count() != -1)
-      RemainingTime -= ElapsedTime.count();
+  uint8_t FDCount = 1;
+  if (CancelFD.has_value()) {
+    FD[1].events = POLLIN;
+    FD[1].fd = CancelFD.value();
+    FDCount++;
+  }
 
-    auto Start = std::chrono::steady_clock::now();
+  // Keep track of how much time has passed in case ::poll or WSAPoll are
+  // interrupted by a signal and need to be called again
+  auto Start = std::chrono::steady_clock::now();
+  auto RemainingTimeout = Timeout;
+  int PollStatus = 0;
+  do {
+    // If Timeout is -1 then poll should block and RemainingTimeout does not
+    // need to be recalculated
+    if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
+      auto TotalElapsedTime =
+          std::chrono::duration_cast<std::chrono::milliseconds>(
+              std::chrono::steady_clock::now() - Start);
+
+      if (TotalElapsedTime >= Timeout)
+        return std::make_error_code(std::errc::operation_would_block);
+
+      RemainingTimeout = Timeout - TotalElapsedTime;
+    }
 #ifdef _WIN32
-    PollStatus = WSAPoll(FDs, 2, RemainingTime);
+    PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
+  } while (PollStatus == SOCKET_ERROR &&
+           getLastSocketErrorCode() == std::errc::interrupted);
 #else
-    PollStatus = ::poll(FDs, 2, RemainingTime);
+    PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
+  } while (PollStatus == -1 &&
+           getLastSocketErrorCode() == std::errc::interrupted);
 #endif
-    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
-    // appropriate to return operation_canceled
-    if (FD.load() == -1)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::operation_canceled),
-          "Accept canceled");
 
+  // If ActiveFD equals -1 or CancelFD has data to be read then the operation
+  // has been canceled by another thread
+  if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
+    return std::make_error_code(std::errc::operation_canceled);
 #if _WIN32
-    if (PollStatus == SOCKET_ERROR) {
+  if (PollStatus == SOCKET_ERROR)
 #else
-    if (PollStatus == -1) {
+  if (PollStatus == -1)
 #endif
-      std::error_code PollErrCode = getLastSocketErrorCode();
-      // Ignore EINTR (signal occured before any request event) and retry
-      if (PollErrCode != std::errc::interrupted)
-        return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
-    }
-    if (PollStatus == 0)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::timed_out),
-          "No client requests within timeout window");
-
-    if (FDs[0].revents & POLLNVAL)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::bad_file_descriptor));
+    return getLastSocketErrorCode();
+  if (PollStatus == 0)
+    return std::make_error_code(std::errc::timed_out);
+  if (FD[0].revents & POLLNVAL)
+    return std::make_error_code(std::errc::bad_file_descriptor);
+  return std::error_code();
+}
 
-    auto Stop = std::chrono::steady_clock::now();
-    ElapsedTime +=
-        std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
-  }
+Expected<std::unique_ptr<raw_socket_stream>>
+ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
+  auto getActiveFD = [this]() -> int { return FD; };
+  std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
+  if (TimeoutErr)
+    return llvm::make_error<StringError>(TimeoutErr, "Timeout error");
 
   int AcceptFD;
 #ifdef _WIN32
-  SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
+  SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
   AcceptFD = _open_osfhandle(WinAcceptSock, 0);
 #else
   AcceptFD = ::accept(FD, NULL, NULL);
@@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
 raw_socket_stream::raw_socket_stream(int SocketFD)
     : raw_fd_stream(SocketFD, true) {}
 
+raw_socket_stream::~raw_socket_stream() {}
+
 Expected<std::unique_ptr<raw_socket_stream>>
 raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
 #ifdef _WIN32
@@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
   return std::make_unique<raw_socket_stream>(*FD);
 }
 
-raw_socket_stream::~raw_socket_stream() {}
+ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
+                                const std::chrono::milliseconds &Timeout) {
+  auto getActiveFD = [this]() -> int { return this->get_fd(); };
+  std::error_code Err = manageTimeout(Timeout, getActiveFD);
+  // Mimic raw_fd_stream::read error handling behavior
+  if (Err) {
+    raw_fd_stream::error_detected(Err);
+    return -1;
+  }
+  return raw_fd_stream::read(Ptr, Size);
+}

diff  --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index c4e8cfbbe7e6a..348fb4bb3e089 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -62,17 +62,50 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
   ssize_t BytesRead = Server.read(Bytes, 8);
 
   std::string string(Bytes, 8);
+  ASSERT_EQ(Server.has_error(), false);
 
   ASSERT_EQ(8, BytesRead);
   ASSERT_EQ("01234567", string);
 }
 
-TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
+TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
   if (!hasUnixSocketSupport())
     GTEST_SKIP();
 
   SmallString<100> SocketPath;
-  llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
+  llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);
+
+  // Make sure socket file does not exist. May still be there from the last test
+  std::remove(SocketPath.c_str());
+
+  Expected<ListeningSocket> MaybeServerListener =
+      ListeningSocket::createUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
+  ListeningSocket ServerListener = std::move(*MaybeServerListener);
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
+      raw_socket_stream::createConnectedUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
+      ServerListener.accept();
+  ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
+  raw_socket_stream &Server = **MaybeServer;
+
+  char Bytes[8];
+  ssize_t BytesRead = Server.read(Bytes, 8, std::chrono::milliseconds(100));
+  ASSERT_EQ(BytesRead, -1);
+  ASSERT_EQ(Server.has_error(), true);
+  ASSERT_EQ(Server.error(), std::errc::timed_out);
+  Server.clear_error();
+}
+
+TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
+  if (!hasUnixSocketSupport())
+    GTEST_SKIP();
+
+  SmallString<100> SocketPath;
+  llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);
 
   // Make sure socket file does not exist. May still be there from the last test
   std::remove(SocketPath.c_str());
@@ -82,19 +115,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
   ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
   ListeningSocket ServerListener = std::move(*MaybeServerListener);
 
-  std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
   Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
-      ServerListener.accept(Timeout);
+      ServerListener.accept(std::chrono::milliseconds(100));
   ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
             std::errc::timed_out);
 }
 
-TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
+TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
   if (!hasUnixSocketSupport())
     GTEST_SKIP();
 
   SmallString<100> SocketPath;
-  llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
+  llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
+                                  true);
 
   // Make sure socket file does not exist. May still be there from the last test
   std::remove(SocketPath.c_str());


        


More information about the llvm-commits mailing list