[llvm] [llvm][Support] Add function to read from raw_socket_stream file descriptor with timeout (PR #92308)

Connor Sughrue via llvm-commits llvm-commits at lists.llvm.org
Wed May 15 12:49:01 PDT 2024


https://github.com/cpsughrue created https://github.com/llvm/llvm-project/pull/92308

None

>From b9e468a13df3af03ad2da94bc1686ea7770bee7e Mon Sep 17 00:00:00 2001
From: cpsughrue <cpsughrue at gmail.com>
Date: Wed, 15 May 2024 12:33:55 -0400
Subject: [PATCH 1/2] Include changes that will be upstreamed with future patch

---
 llvm/lib/Support/raw_socket_stream.cpp        | 23 +++++++++++--------
 .../Support/raw_socket_stream_test.cpp        | 19 ++++-----------
 2 files changed, 17 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 14e2308df4d7e..549d537709bf2 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -204,17 +204,26 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
     auto Start = std::chrono::steady_clock::now();
 #ifdef _WIN32
     PollStatus = WSAPoll(FDs, 2, RemainingTime);
-    if (PollStatus == SOCKET_ERROR) {
 #else
     PollStatus = ::poll(FDs, 2, RemainingTime);
+#endif
+    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
+    // appropriate to return operation_canceled
+    if (FD.load() == -1)
+      return llvm::make_error<StringError>(
+          std::make_error_code(std::errc::operation_canceled),
+          "Accept canceled");
+
+#if _WIN32
+    if (PollStatus == SOCKET_ERROR) {
+#else
     if (PollStatus == -1) {
 #endif
-      // Ignore error if caused by interupting signal
       std::error_code PollErrCode = getLastSocketErrorCode();
+      // Ignore EINTR (signal occured before any request event) and retry
       if (PollErrCode != std::errc::interrupted)
         return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
     }
-
     if (PollStatus == 0)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::timed_out),
@@ -222,13 +231,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
 
     if (FDs[0].revents & POLLNVAL)
       return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::bad_file_descriptor),
-          "File descriptor closed by another thread");
-
-    if (FDs[1].revents & POLLIN)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::operation_canceled),
-          "Accept canceled");
+          std::make_error_code(std::errc::bad_file_descriptor));
 
     auto Stop = std::chrono::steady_clock::now();
     ElapsedTime +=
diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index a8536228666db..c4e8cfbbe7e6a 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -7,7 +7,6 @@
 #include "llvm/Testing/Support/Error.h"
 #include "gtest/gtest.h"
 #include <future>
-#include <iostream>
 #include <stdlib.h>
 #include <thread>
 
@@ -86,13 +85,8 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
   std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
   Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
       ServerListener.accept(Timeout);
-
-  ASSERT_THAT_EXPECTED(MaybeServer, Failed());
-  llvm::Error Err = MaybeServer.takeError();
-  llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
-    std::error_code EC = SE.convertToErrorCode();
-    ASSERT_EQ(EC, std::errc::timed_out);
-  });
+  ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
+            std::errc::timed_out);
 }
 
 TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
@@ -122,12 +116,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
 
   // Wait for the CloseThread to finish
   CloseThread.join();
-
-  ASSERT_THAT_EXPECTED(MaybeServer, Failed());
-  llvm::Error Err = MaybeServer.takeError();
-  llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
-    std::error_code EC = SE.convertToErrorCode();
-    ASSERT_EQ(EC, std::errc::operation_canceled);
-  });
+  ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
+            std::errc::operation_canceled);
 }
 } // namespace

>From f16329949ee99431e061386a093d4f7006f050ff Mon Sep 17 00:00:00 2001
From: cpsughrue <cpsughrue at gmail.com>
Date: Wed, 15 May 2024 15:47:27 -0400
Subject: [PATCH 2/2] Rough draft of readWithTimeout to figure out structure

---
 llvm/include/llvm/Support/raw_socket_stream.h |  2 ++
 llvm/lib/Support/raw_socket_stream.cpp        | 32 ++++++++++++++-----
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h
index bddd47eb75e1a..f23bda8e429d0 100644
--- a/llvm/include/llvm/Support/raw_socket_stream.h
+++ b/llvm/include/llvm/Support/raw_socket_stream.h
@@ -128,6 +128,8 @@ class raw_socket_stream : public raw_fd_stream {
   /// SocketPath.
   static Expected<std::unique_ptr<raw_socket_stream>>
   createConnectedUnix(StringRef SocketPath);
+  llvm::Error readWithTimeout(char *Ptr, size_t Size,
+                              std::chrono::milliseconds Timeout);
   ~raw_socket_stream();
 };
 
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 549d537709bf2..b256017d3d50d 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -177,19 +177,20 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
 #endif // _WIN32
 }
 
-Expected<std::unique_ptr<raw_socket_stream>>
-ListeningSocket::accept(std::chrono::milliseconds Timeout) {
-
+static llvm::Error manageTimeout(const std::chrono::milliseconds Timeout,
+                                 const std::atomic<int> &ActiveFD,
+                                 const int PipeFD) {
+  // Populate array of file descriptors that ::poll will monitor
   struct pollfd FDs[2];
   FDs[0].events = POLLIN;
 #ifdef _WIN32
   SOCKET WinServerSock = _get_osfhandle(FD);
   FDs[0].fd = WinServerSock;
 #else
-  FDs[0].fd = FD;
+  FDs[0].fd = ActiveFD;
 #endif
   FDs[1].events = POLLIN;
-  FDs[1].fd = PipeFD[0];
+  FDs[1].fd = PipeFD;
 
   // Keep track of how much time has passed in case poll is interupted by a
   // signal and needs to be recalled
@@ -209,7 +210,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
 #endif
     // If FD equals -1 then ListeningSocket::shutdown has been called and it is
     // appropriate to return operation_canceled
-    if (FD.load() == -1)
+    if (ActiveFD.load() == -1)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::operation_canceled),
           "Accept canceled");
@@ -237,6 +238,13 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
     ElapsedTime +=
         std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
   }
+  return llvm::Error::success();
+}
+
+Expected<std::unique_ptr<raw_socket_stream>>
+ListeningSocket::accept(std::chrono::milliseconds Timeout) {
+  if (llvm::Error TimeoutErr = manageTimeout(Timeout, FD, PipeFD[0]))
+    return std::move(TimeoutErr);
 
   int AcceptFD;
 #ifdef _WIN32
@@ -267,8 +275,7 @@ void ListeningSocket::shutdown() {
   ::unlink(SocketPath.c_str());
 
   // Ensure ::poll returns if shutdown is called by a seperate thread
-  char Byte = 'A';
-  ssize_t written = ::write(PipeFD[1], &Byte, 1);
+  ssize_t written = ::write(PipeFD[1], ".", 1);
 
   // Ignore any write() error
   (void)written;
@@ -306,4 +313,28 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
  return std::make_unique<raw_socket_stream>(*FD);
}

+/// Waits up to \p Timeout for the socket to become readable, then issues a
+/// single read of at most \p Size bytes into \p Ptr.
+///
+/// Returns the error reported by manageTimeout (for example timed_out) when
+/// the wait does not succeed, or the stream's error code when the read itself
+/// fails. NOTE: the byte count of the read is not returned, so fewer than
+/// \p Size bytes may have been stored in \p Ptr.
+llvm::Error
+raw_socket_stream::readWithTimeout(char *Ptr, size_t Size,
+                                   std::chrono::milliseconds Timeout) {
+  // FIXME: Give raw_socket_stream its own cancellation pipe. Until then pass
+  // -1, which POSIX poll() ignores, instead of a hard-coded descriptor (10)
+  // that could alias an unrelated open file and wake poll() early.
+  if (llvm::Error TimeoutErr = manageTimeout(Timeout, get_fd(), -1))
+    return std::move(TimeoutErr);
+
+  // raw_fd_stream::read signals failure by returning -1 and recording the
+  // cause, which error() then exposes.
+  ssize_t Ret = this->read(Ptr, Size);
+  if (Ret == -1)
+    return llvm::errorCodeToError(error());
+  return llvm::Error::success();
+}
+
 raw_socket_stream::~raw_socket_stream() {}



More information about the llvm-commits mailing list