[llvm] [llvm][Support] Add function to read from raw_socket_stream file descriptor with timeout (PR #92308)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 2 21:37:01 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-support

Author: Connor Sughrue (cpsughrue)

<details>
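<summary>Changes</summary>

Adds `raw_socket_stream::readFromSocket`, which reads from the socket's file descriptor with an optional timeout. The `poll`-based timeout handling previously inlined in `ListeningSocket::accept` moves into a shared static helper, `manageTimeout`. The unit tests now use `readFromSocket` and add coverage for large reads and read timeouts. They also rename the existing accept tests to `ACCEPT_WITH_TIMEOUT` and `ACCEPT_WITH_SHUTDOWN`.
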
---
Full diff: https://github.com/llvm/llvm-project/pull/92308.diff


3 Files Affected:

- (modified) llvm/include/llvm/Support/raw_socket_stream.h (+18-3) 
- (modified) llvm/lib/Support/raw_socket_stream.cpp (+70-23) 
- (modified) llvm/unittests/Support/raw_socket_stream_test.cpp (+96-11) 


``````````diff
diff --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h
index bddd47eb75e1a..225980cb28a42 100644
--- a/llvm/include/llvm/Support/raw_socket_stream.h
+++ b/llvm/include/llvm/Support/raw_socket_stream.h
@@ -92,10 +92,11 @@ class ListeningSocket {
   /// Accepts an incoming connection on the listening socket. This method can
   /// optionally either block until a connection is available or timeout after a
   /// specified amount of time has passed. By default the method will block
-  /// until the socket has recieved a connection.
+  /// until the socket has received a connection. If the accept times out this
+  /// method will return std::errc::timed_out.
   ///
   /// \param Timeout An optional timeout duration in milliseconds. Setting
-  /// Timeout to -1 causes accept to block indefinitely
+  /// Timeout to a negative number causes ::accept to block indefinitely
   ///
   Expected<std::unique_ptr<raw_socket_stream>>
   accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
@@ -124,11 +125,25 @@ class raw_socket_stream : public raw_fd_stream {
 
 public:
   raw_socket_stream(int SocketFD);
+  ~raw_socket_stream();
+
   /// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
   /// SocketPath.
   static Expected<std::unique_ptr<raw_socket_stream>>
   createConnectedUnix(StringRef SocketPath);
-  ~raw_socket_stream();
+
+  /// Attempt to read from the raw_socket_stream's file descriptor. This method
+  /// can optionally either block until data is read or an error has occurred or
+  /// timeout after a specified amount of time has passed. By default the method
+  /// will block until the socket has read data or encountered an error. If the
+  /// read times out this method will return std::errc::timed_out.
+  ///
+  /// \param Timeout An optional timeout duration in milliseconds. The default
+  /// of -1 waits indefinitely for data to become available.
+  /// \returns The bytes read from the socket, or an error.
+  ///
+  Expected<std::string> readFromSocket(
+      std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 549d537709bf2..063f6fc366da9 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -18,6 +18,7 @@
 
 #include <atomic>
 #include <fcntl.h>
+#include <functional>
 #include <thread>
 
 #ifndef _WIN32
@@ -177,22 +178,31 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
 #endif // _WIN32
 }
 
-Expected<std::unique_ptr<raw_socket_stream>>
-ListeningSocket::accept(std::chrono::milliseconds Timeout) {
-
-  struct pollfd FDs[2];
-  FDs[0].events = POLLIN;
+// Wait until ActiveFD is readable, Timeout expires, or the wait is canceled.
+// If a file descriptor being monitored by poll is closed by another thread the
+// result is unspecified, so a CancelFD (typically the read end of a pipe) can
+// be supplied; writing to it makes poll return. A Timeout of -1 waits forever;
+// any other Timeout <= 0 returns success immediately without polling.
+static llvm::Error manageTimeout(std::chrono::milliseconds Timeout,
+                                 std::function<int()> getActiveFD,
+                                 std::optional<int> CancelFD = std::nullopt) {
+  struct pollfd FD[2] = {};
+  FD[0].events = POLLIN;
 #ifdef _WIN32
-  SOCKET WinServerSock = _get_osfhandle(FD);
-  FDs[0].fd = WinServerSock;
+  SOCKET WinServerSock = _get_osfhandle(getActiveFD());
+  FD[0].fd = WinServerSock;
 #else
-  FDs[0].fd = FD;
+  FD[0].fd = getActiveFD();
 #endif
-  FDs[1].events = POLLIN;
-  FDs[1].fd = PipeFD[0];
+  uint8_t FDCount = 1;
+  if (CancelFD.has_value()) {
+    FD[1].events = POLLIN;
+    FD[1].fd = CancelFD.value();
+    FDCount++;
+  }
 
-  // Keep track of how much time has passed in case poll is interupted by a
-  // signal and needs to be recalled
+  // Keep track of how much time has passed in case ::poll or WSAPoll are
+  // interrupted by a signal and need to be recalled
   int RemainingTime = Timeout.count();
   std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
   int PollStatus = -1;
@@ -200,20 +210,20 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
   while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
     if (Timeout.count() != -1)
       RemainingTime -= ElapsedTime.count();
-
     auto Start = std::chrono::steady_clock::now();
+
 #ifdef _WIN32
-    PollStatus = WSAPoll(FDs, 2, RemainingTime);
+    PollStatus = WSAPoll(FD, FDCount, RemainingTime);
 #else
-    PollStatus = ::poll(FDs, 2, RemainingTime);
+    PollStatus = ::poll(FD, FDCount, RemainingTime);
 #endif
-    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
-    // appropriate to return operation_canceled
-    if (FD.load() == -1)
+
+    // If ActiveFD equals -1 or CancelFD has data to be read then the operation
+    // has been canceled by another thread (FD[1] is all zero without CancelFD)
+    if (getActiveFD() == -1 || (FD[1].revents & POLLIN))
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::operation_canceled),
           "Accept canceled");
-
 #if _WIN32
     if (PollStatus == SOCKET_ERROR) {
 #else
@@ -222,14 +232,14 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
       std::error_code PollErrCode = getLastSocketErrorCode();
       // Ignore EINTR (signal occured before any request event) and retry
       if (PollErrCode != std::errc::interrupted)
-        return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
+        return llvm::make_error<StringError>(PollErrCode, "poll failed");
     }
     if (PollStatus == 0)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::timed_out),
           "No client requests within timeout window");
 
-    if (FDs[0].revents & POLLNVAL)
+    if (FD[0].revents & POLLNVAL)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::bad_file_descriptor));
 
@@ -237,10 +247,19 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
     ElapsedTime +=
         std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
   }
+  return llvm::Error::success();
+}
+
+Expected<std::unique_ptr<raw_socket_stream>>
+ListeningSocket::accept(std::chrono::milliseconds Timeout) {
+  auto getActiveFD = [this]() -> int { return FD; };
+  llvm::Error TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
+  if (TimeoutErr)
+    return TimeoutErr;
 
   int AcceptFD;
 #ifdef _WIN32
-  SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
+  SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
   AcceptFD = _open_osfhandle(WinAcceptSock, 0);
 #else
   AcceptFD = ::accept(FD, NULL, NULL);
@@ -295,6 +314,8 @@ ListeningSocket::~ListeningSocket() {
 raw_socket_stream::raw_socket_stream(int SocketFD)
     : raw_fd_stream(SocketFD, true) {}
 
+raw_socket_stream::~raw_socket_stream() = default;
+
 Expected<std::unique_ptr<raw_socket_stream>>
 raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
 #ifdef _WIN32
@@ -306,4 +327,30 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
   return std::make_unique<raw_socket_stream>(*FD);
 }
 
-raw_socket_stream::~raw_socket_stream() {}
+Expected<std::string>
+raw_socket_stream::readFromSocket(std::chrono::milliseconds Timeout) {
+  auto getActiveFD = [this]() -> int { return this->get_fd(); };
+  llvm::Error TimeoutErr = manageTimeout(Timeout, getActiveFD);
+  if (TimeoutErr)
+    return TimeoutErr;
+
+  std::string Result;
+  constexpr ssize_t TmpBufferSize = 1024;
+  char TmpBuffer[TmpBufferSize];
+  while (true) {
+    ssize_t BytesRead = this->read(TmpBuffer, TmpBufferSize);
+    if (BytesRead == -1)
+      return llvm::make_error<StringError>(this->error(), "read failed");
+    Result.append(TmpBuffer, TmpBuffer + BytesRead);
+    if (BytesRead < TmpBufferSize) // Short read or EOF: assume nothing pending
+      break;
+    // Buffer filled exactly: read again only if more data is ready promptly
+    std::error_code EC = llvm::errorToErrorCode(
+        manageTimeout(std::chrono::milliseconds(1), getActiveFD));
+    if (EC == std::errc::timed_out)
+      break;
+    if (EC)
+      return llvm::make_error<StringError>(EC, "poll failed");
+  }
+  return Result;
+}
diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index c4e8cfbbe7e6a..1b8f85f88f1af 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -58,21 +58,106 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
   Client << "01234567";
   Client.flush();
 
-  char Bytes[8];
-  ssize_t BytesRead = Server.read(Bytes, 8);
+  llvm::Expected<std::string> MaybeText = Server.readFromSocket();
+  ASSERT_THAT_EXPECTED(MaybeText, llvm::Succeeded());
+  ASSERT_EQ("01234567", *MaybeText);
+}
+
+TEST(raw_socket_streamTest, LARGE_READ) {
+  if (!hasUnixSocketSupport())
+    GTEST_SKIP();
+
+  SmallString<100> SocketPath;
+  llvm::sys::fs::createUniquePath("large_read.sock", SocketPath, true);
+
+  // Make sure socket file does not exist. May still be there from the last test
+  std::remove(SocketPath.c_str());
+
+  Expected<ListeningSocket> MaybeServerListener =
+      ListeningSocket::createUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
+  ListeningSocket ServerListener = std::move(*MaybeServerListener);
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
+      raw_socket_stream::createConnectedUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
+  raw_socket_stream &Client = **MaybeClient;
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
+      ServerListener.accept();
+  ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
+  raw_socket_stream &Server = **MaybeServer;
+
+  // raw_socket_stream::readFromSocket reads into a 1024 byte buffer. Test to
+  // make sure readFromSocket can handle messages larger than the size of that
+  // buffer. Let the compiler size Text so the literal can be edited without
+  // having to recount its length.
+  constexpr char Text[] =
+      "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do "
+      "eiusmod tempor incididunt ut labore et dolore magna aliqua. Vel orci "
+      "porta non pulvinar neque laoreet suspendisse interdum consectetur. "
+      "Nulla facilisi etiam dignissim diam quis. Porttitor massa id neque "
+      "aliquam vestibulum morbi blandit cursus. Purus viverra accumsan in "
+      "nisl. Nunc non blandit massa enim nec dui nunc mattis enim. Rhoncus "
+      "dolor purus non enim praesent elementum facilisis leo. Parturient "
+      "montes nascetur ridiculus mus mauris. Urna condimentum mattis "
+      "pellentesque id nibh tortor id aliquet lectus. Orci eu lobortis "
+      "elementum nibh. Sagittis eu volutpat odio facilisis. Molestie a "
+      "iaculis at erat pellentesque adipiscing. Tincidunt augue interdum "
+      "velit euismod in pellentesque massa placerat. Cras ornare arcu dui "
+      "vivamus arcu felis bibendum ut tristique. Tellus elementum sagittis "
+      "vitae et leo duis. Scelerisque fermentum dui faucibus in ornare "
+      "quam. Ipsum a arcu cursus vitae congue. Sit amet nisl suscipit "
+      "adipiscing. Sociis natoque penatibus et magnis. Cras semper auctor "
+      "neque vitae tempus quam pellentesque. Neque gravida in fermentum et "
+      "sollicitudin ac orci phasellus egestas. Vitae suscipit tellus mauris "
+      "a diam maecenas sed. Lectus arcu bibendum at varius vel pharetra. "
+      "Dignissim sodales ut eu sem integer vitae justo. Id cursus metus "
+      "aliquam eleifend mi.";
+  Client << Text;
+  Client.flush();
+
+  llvm::Expected<std::string> MaybeText = Server.readFromSocket();
+  ASSERT_THAT_EXPECTED(MaybeText, llvm::Succeeded());
+  ASSERT_EQ(Text, *MaybeText);
+}
 
-  std::string string(Bytes, 8);
+TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
+  if (!hasUnixSocketSupport())
+    GTEST_SKIP();
+
+  SmallString<100> SocketPath;
+  llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);
 
-  ASSERT_EQ(8, BytesRead);
-  ASSERT_EQ("01234567", string);
+  // Make sure socket file does not exist. May still be there from the last test
+  std::remove(SocketPath.c_str());
+
+  Expected<ListeningSocket> MaybeServerListener =
+      ListeningSocket::createUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
+  ListeningSocket ServerListener = std::move(*MaybeServerListener);
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
+      raw_socket_stream::createConnectedUnix(SocketPath);
+  ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
+
+  Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
+      ServerListener.accept();
+  ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
+  raw_socket_stream &Server = **MaybeServer;
+  // The client stays connected but never writes, so the read must time out
+  llvm::Expected<std::string> MaybeText =
+      Server.readFromSocket(std::chrono::milliseconds(100));
+  ASSERT_EQ(llvm::errorToErrorCode(MaybeText.takeError()),
+            std::errc::timed_out);
 }
 
-TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
+TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
   if (!hasUnixSocketSupport())
     GTEST_SKIP();
 
   SmallString<100> SocketPath;
-  llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
+  llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);
 
   // Make sure socket file does not exist. May still be there from the last test
   std::remove(SocketPath.c_str());
@@ -82,19 +167,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
   ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
   ListeningSocket ServerListener = std::move(*MaybeServerListener);
 
-  std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
   Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
-      ServerListener.accept(Timeout);
+      ServerListener.accept(std::chrono::milliseconds(100));
   ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
             std::errc::timed_out);
 }
 
-TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
+TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
   if (!hasUnixSocketSupport())
     GTEST_SKIP();
 
   SmallString<100> SocketPath;
-  llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
+  llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
+                                  true);
 
   // Make sure socket file does not exist. May still be there from the last test
   std::remove(SocketPath.c_str());

``````````

</details>


https://github.com/llvm/llvm-project/pull/92308


More information about the llvm-commits mailing list