[llvm] 203232f - [llvm][Support] ListeningSocket::accept returns operation_canceled if FD is set to -1 (#89479)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 21 17:32:15 PDT 2024


Author: Connor Sughrue
Date: 2024-05-21T20:32:11-04:00
New Revision: 203232ffbd80e9f4631213a3876f14dde155a92d

URL: https://github.com/llvm/llvm-project/commit/203232ffbd80e9f4631213a3876f14dde155a92d
DIFF: https://github.com/llvm/llvm-project/commit/203232ffbd80e9f4631213a3876f14dde155a92d.diff

LOG: [llvm][Support] ListeningSocket::accept returns operation_canceled if FD is set to -1 (#89479)

If `::poll` (or `WSAPoll` on Windows) returns and `FD` equals -1, then
`ListeningSocket::shutdown` has been called. So, regardless of any other
information that could be gleaned from `FDs[i].revents` or `PollStatus`,
it is appropriate to return `std::errc::operation_canceled`. This check is
reliable because `ListeningSocket::shutdown` copies `FD`'s value to
`ObservedFD` and then sets `FD` to -1 before canceling `::poll` by calling
`::close(ObservedFD)` and writing to the pipe.

Added: 
    

Modified: 
    llvm/lib/Support/raw_socket_stream.cpp
    llvm/unittests/Support/raw_socket_stream_test.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 14e2308df4d7e..549d537709bf2 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -204,17 +204,26 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
     auto Start = std::chrono::steady_clock::now();
 #ifdef _WIN32
     PollStatus = WSAPoll(FDs, 2, RemainingTime);
-    if (PollStatus == SOCKET_ERROR) {
 #else
     PollStatus = ::poll(FDs, 2, RemainingTime);
+#endif
+    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
+    // appropriate to return operation_canceled
+    if (FD.load() == -1)
+      return llvm::make_error<StringError>(
+          std::make_error_code(std::errc::operation_canceled),
+          "Accept canceled");
+
+#if _WIN32
+    if (PollStatus == SOCKET_ERROR) {
+#else
     if (PollStatus == -1) {
 #endif
-      // Ignore error if caused by interupting signal
       std::error_code PollErrCode = getLastSocketErrorCode();
+      // Ignore EINTR (signal occured before any request event) and retry
       if (PollErrCode != std::errc::interrupted)
         return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
     }
-
     if (PollStatus == 0)
       return llvm::make_error<StringError>(
           std::make_error_code(std::errc::timed_out),
@@ -222,13 +231,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
 
     if (FDs[0].revents & POLLNVAL)
       return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::bad_file_descriptor),
-          "File descriptor closed by another thread");
-
-    if (FDs[1].revents & POLLIN)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::operation_canceled),
-          "Accept canceled");
+          std::make_error_code(std::errc::bad_file_descriptor));
 
     auto Stop = std::chrono::steady_clock::now();
     ElapsedTime +=

diff  --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index a8536228666db..c4e8cfbbe7e6a 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -7,7 +7,6 @@
 #include "llvm/Testing/Support/Error.h"
 #include "gtest/gtest.h"
 #include <future>
-#include <iostream>
 #include <stdlib.h>
 #include <thread>
 
@@ -86,13 +85,8 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
   std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
   Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
       ServerListener.accept(Timeout);
-
-  ASSERT_THAT_EXPECTED(MaybeServer, Failed());
-  llvm::Error Err = MaybeServer.takeError();
-  llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
-    std::error_code EC = SE.convertToErrorCode();
-    ASSERT_EQ(EC, std::errc::timed_out);
-  });
+  ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
+            std::errc::timed_out);
 }
 
 TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
@@ -122,12 +116,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
 
   // Wait for the CloseThread to finish
   CloseThread.join();
-
-  ASSERT_THAT_EXPECTED(MaybeServer, Failed());
-  llvm::Error Err = MaybeServer.takeError();
-  llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
-    std::error_code EC = SE.convertToErrorCode();
-    ASSERT_EQ(EC, std::errc::operation_canceled);
-  });
+  ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
+            std::errc::operation_canceled);
 }
 } // namespace


        


More information about the llvm-commits mailing list