[llvm] 76321b9 - [llvm][Support] Implement raw_socket_stream::read with optional timeout (#92308)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 21 20:50:31 PDT 2024
Author: Connor Sughrue
Date: 2024-07-21T23:50:28-04:00
New Revision: 76321b9f08ef31a2b8ca26f7522aee511a05f7a8
URL: https://github.com/llvm/llvm-project/commit/76321b9f08ef31a2b8ca26f7522aee511a05f7a8
DIFF: https://github.com/llvm/llvm-project/commit/76321b9f08ef31a2b8ca26f7522aee511a05f7a8.diff
LOG: [llvm][Support] Implement raw_socket_stream::read with optional timeout (#92308)
This PR implements `raw_socket_stream::read`, which overloads the base
class `raw_fd_stream::read`. `raw_socket_stream::read` provides a way to
time out the underlying `::read`. The timeout functionality was not added
to `raw_fd_stream::read` to avoid needlessly increasing compile times
and to allow for convenient code reuse with `raw_socket_stream::accept`,
which also requires timeout functionality. This PR supports the module
build daemon and will help guarantee it never becomes a zombie process.
Added:
Modified:
llvm/include/llvm/Support/raw_socket_stream.h
llvm/lib/Support/raw_socket_stream.cpp
llvm/unittests/Support/raw_socket_stream_test.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h
index eed865fb5af49..6c65a66dec9a4 100644
--- a/llvm/include/llvm/Support/raw_socket_stream.h
+++ b/llvm/include/llvm/Support/raw_socket_stream.h
@@ -92,13 +92,14 @@ class ListeningSocket {
/// Accepts an incoming connection on the listening socket. This method can
/// optionally either block until a connection is available or timeout after a
/// specified amount of time has passed. By default the method will block
- /// until the socket has recieved a connection.
+ /// until the socket has recieved a connection. If the accept timesout this
+ /// method will return std::errc:timed_out
///
/// \param Timeout An optional timeout duration in milliseconds. Setting
- /// Timeout to -1 causes accept to block indefinitely
+ /// Timeout to a negative number causes ::accept to block indefinitely
///
- Expected<std::unique_ptr<raw_socket_stream>>
- accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1));
+ Expected<std::unique_ptr<raw_socket_stream>> accept(
+ const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
/// Creates a listening socket bound to the specified file system path.
/// Handles the socket creation, binding, and immediately starts listening for
@@ -124,11 +125,28 @@ class raw_socket_stream : public raw_fd_stream {
public:
raw_socket_stream(int SocketFD);
+ ~raw_socket_stream();
+
/// Create a \p raw_socket_stream connected to the UNIX domain socket at \p
/// SocketPath.
static Expected<std::unique_ptr<raw_socket_stream>>
createConnectedUnix(StringRef SocketPath);
- ~raw_socket_stream();
+
+ /// Attempt to read from the raw_socket_stream's file descriptor.
+ ///
+ /// This method can optionally either block until data is read or an error has
+ /// occurred or timeout after a specified amount of time has passed. By
+ /// default the method will block until the socket has read data or
+ /// encountered an error. If the read times out this method will return
+ /// std::errc:timed_out
+ ///
+ /// \param Ptr The start of the buffer that will hold any read data
+ /// \param Size The number of bytes to be read
+ /// \param Timeout An optional timeout duration in milliseconds
+ ///
+ ssize_t read(
+ char *Ptr, size_t Size,
+ const std::chrono::milliseconds &Timeout = std::chrono::milliseconds(-1));
};
} // end namespace llvm
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 4cd3d58b80198..04b3233084a41 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -18,6 +18,7 @@
#include <atomic>
#include <fcntl.h>
+#include <functional>
#include <thread>
#ifndef _WIN32
@@ -177,70 +178,89 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
#endif // _WIN32
}
-Expected<std::unique_ptr<raw_socket_stream>>
-ListeningSocket::accept(std::chrono::milliseconds Timeout) {
-
- struct pollfd FDs[2];
- FDs[0].events = POLLIN;
+// If a file descriptor being monitored by ::poll is closed by another thread,
+// the result is unspecified. In the case ::poll does not unblock and return,
+// when ActiveFD is closed, you can provide another file descriptor via CancelFD
+// that when written to will cause poll to return. Typically CancelFD is the
+// read end of a unidirectional pipe.
+//
+// Timeout should be -1 to block indefinitly
+//
+// getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
+static std::error_code
+manageTimeout(const std::chrono::milliseconds &Timeout,
+ const std::function<int()> &getActiveFD,
+ const std::optional<int> &CancelFD = std::nullopt) {
+ struct pollfd FD[2];
+ FD[0].events = POLLIN;
#ifdef _WIN32
- SOCKET WinServerSock = _get_osfhandle(FD);
- FDs[0].fd = WinServerSock;
+ SOCKET WinServerSock = _get_osfhandle(getActiveFD());
+ FD[0].fd = WinServerSock;
#else
- FDs[0].fd = FD;
+ FD[0].fd = getActiveFD();
#endif
- FDs[1].events = POLLIN;
- FDs[1].fd = PipeFD[0];
-
- // Keep track of how much time has passed in case poll is interupted by a
- // signal and needs to be recalled
- int RemainingTime = Timeout.count();
- std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
- int PollStatus = -1;
-
- while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
- if (Timeout.count() != -1)
- RemainingTime -= ElapsedTime.count();
+ uint8_t FDCount = 1;
+ if (CancelFD.has_value()) {
+ FD[1].events = POLLIN;
+ FD[1].fd = CancelFD.value();
+ FDCount++;
+ }
- auto Start = std::chrono::steady_clock::now();
+ // Keep track of how much time has passed in case ::poll or WSAPoll are
+ // interupted by a signal and need to be recalled
+ auto Start = std::chrono::steady_clock::now();
+ auto RemainingTimeout = Timeout;
+ int PollStatus = 0;
+ do {
+ // If Timeout is -1 then poll should block and RemainingTimeout does not
+ // need to be recalculated
+ if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
+ auto TotalElapsedTime =
+ std::chrono::duration_cast<std::chrono::milliseconds>(
+ std::chrono::steady_clock::now() - Start);
+
+ if (TotalElapsedTime >= Timeout)
+ return std::make_error_code(std::errc::operation_would_block);
+
+ RemainingTimeout = Timeout - TotalElapsedTime;
+ }
#ifdef _WIN32
- PollStatus = WSAPoll(FDs, 2, RemainingTime);
+ PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
+ } while (PollStatus == SOCKET_ERROR &&
+ getLastSocketErrorCode() == std::errc::interrupted);
#else
- PollStatus = ::poll(FDs, 2, RemainingTime);
+ PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
+ } while (PollStatus == -1 &&
+ getLastSocketErrorCode() == std::errc::interrupted);
#endif
- // If FD equals -1 then ListeningSocket::shutdown has been called and it is
- // appropriate to return operation_canceled
- if (FD.load() == -1)
- return llvm::make_error<StringError>(
- std::make_error_code(std::errc::operation_canceled),
- "Accept canceled");
+ // If ActiveFD equals -1 or CancelFD has data to be read then the operation
+ // has been canceled by another thread
+ if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
+ return std::make_error_code(std::errc::operation_canceled);
#if _WIN32
- if (PollStatus == SOCKET_ERROR) {
+ if (PollStatus == SOCKET_ERROR)
#else
- if (PollStatus == -1) {
+ if (PollStatus == -1)
#endif
- std::error_code PollErrCode = getLastSocketErrorCode();
- // Ignore EINTR (signal occured before any request event) and retry
- if (PollErrCode != std::errc::interrupted)
- return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
- }
- if (PollStatus == 0)
- return llvm::make_error<StringError>(
- std::make_error_code(std::errc::timed_out),
- "No client requests within timeout window");
-
- if (FDs[0].revents & POLLNVAL)
- return llvm::make_error<StringError>(
- std::make_error_code(std::errc::bad_file_descriptor));
+ return getLastSocketErrorCode();
+ if (PollStatus == 0)
+ return std::make_error_code(std::errc::timed_out);
+ if (FD[0].revents & POLLNVAL)
+ return std::make_error_code(std::errc::bad_file_descriptor);
+ return std::error_code();
+}
- auto Stop = std::chrono::steady_clock::now();
- ElapsedTime +=
- std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
- }
+Expected<std::unique_ptr<raw_socket_stream>>
+ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
+ auto getActiveFD = [this]() -> int { return FD; };
+ std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
+ if (TimeoutErr)
+ return llvm::make_error<StringError>(TimeoutErr, "Timeout error");
int AcceptFD;
#ifdef _WIN32
- SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL);
+ SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
AcceptFD = _open_osfhandle(WinAcceptSock, 0);
#else
AcceptFD = ::accept(FD, NULL, NULL);
@@ -295,6 +315,8 @@ ListeningSocket::~ListeningSocket() {
raw_socket_stream::raw_socket_stream(int SocketFD)
: raw_fd_stream(SocketFD, true) {}
+raw_socket_stream::~raw_socket_stream() {}
+
Expected<std::unique_ptr<raw_socket_stream>>
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
#ifdef _WIN32
@@ -306,4 +328,14 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
return std::make_unique<raw_socket_stream>(*FD);
}
-raw_socket_stream::~raw_socket_stream() {}
+ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
+ const std::chrono::milliseconds &Timeout) {
+ auto getActiveFD = [this]() -> int { return this->get_fd(); };
+ std::error_code Err = manageTimeout(Timeout, getActiveFD);
+ // Mimic raw_fd_stream::read error handling behavior
+ if (Err) {
+ raw_fd_stream::error_detected(Err);
+ return -1;
+ }
+ return raw_fd_stream::read(Ptr, Size);
+}
diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index c4e8cfbbe7e6a..348fb4bb3e089 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -62,17 +62,50 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
ssize_t BytesRead = Server.read(Bytes, 8);
std::string string(Bytes, 8);
+ ASSERT_EQ(Server.has_error(), false);
ASSERT_EQ(8, BytesRead);
ASSERT_EQ("01234567", string);
}
-TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
+TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();
SmallString<100> SocketPath;
- llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true);
+ llvm::sys::fs::createUniquePath("read_with_timeout.sock", SocketPath, true);
+
+ // Make sure socket file does not exist. May still be there from the last test
+ std::remove(SocketPath.c_str());
+
+ Expected<ListeningSocket> MaybeServerListener =
+ ListeningSocket::createUnix(SocketPath);
+ ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
+ ListeningSocket ServerListener = std::move(*MaybeServerListener);
+
+ Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
+ raw_socket_stream::createConnectedUnix(SocketPath);
+ ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());
+
+ Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
+ ServerListener.accept();
+ ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
+ raw_socket_stream &Server = **MaybeServer;
+
+ char Bytes[8];
+ ssize_t BytesRead = Server.read(Bytes, 8, std::chrono::milliseconds(100));
+ ASSERT_EQ(BytesRead, -1);
+ ASSERT_EQ(Server.has_error(), true);
+ ASSERT_EQ(Server.error(), std::errc::timed_out);
+ Server.clear_error();
+}
+
+TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
+ if (!hasUnixSocketSupport())
+ GTEST_SKIP();
+
+ SmallString<100> SocketPath;
+ llvm::sys::fs::createUniquePath("accept_with_timeout.sock", SocketPath, true);
// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());
@@ -82,19 +115,19 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);
- std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
- ServerListener.accept(Timeout);
+ ServerListener.accept(std::chrono::milliseconds(100));
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::timed_out);
}
-TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
+TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
if (!hasUnixSocketSupport())
GTEST_SKIP();
SmallString<100> SocketPath;
- llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true);
+ llvm::sys::fs::createUniquePath("accept_with_shutdown.sock", SocketPath,
+ true);
// Make sure socket file does not exist. May still be there from the last test
std::remove(SocketPath.c_str());
More information about the llvm-commits
mailing list