[llvm] [llvm][Support] Add function to read from raw_socket_stream file descriptor with timeout (PR #92308)

Connor Sughrue via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 06:26:38 PDT 2024


https://github.com/cpsughrue updated https://github.com/llvm/llvm-project/pull/92308

>From a7f9b96dea4a090ad1ff9c1d06cb7584c8f1fed5 Mon Sep 17 00:00:00 2001
From: cpsughrue <cpsughrue at gmail.com>
Date: Wed, 29 May 2024 09:26:07 -0400
Subject: [PATCH] WIP

---
 llvm/include/llvm/Support/FileDescriptor.h    | 32 +++++++
 llvm/lib/Support/CMakeLists.txt               |  3 +-
 llvm/lib/Support/FileDescriptor.cpp           | 91 +++++++++++++++++++
 llvm/lib/Support/raw_socket_stream.cpp        | 62 +------------
 .../gn/secondary/llvm/lib/Support/BUILD.gn    |  1 +
 5 files changed, 130 insertions(+), 59 deletions(-)
 create mode 100644 llvm/include/llvm/Support/FileDescriptor.h
 create mode 100644 llvm/lib/Support/FileDescriptor.cpp

diff --git a/llvm/include/llvm/Support/FileDescriptor.h b/llvm/include/llvm/Support/FileDescriptor.h
new file mode 100644
index 0000000000000..ab83b34fca0fa
--- /dev/null
+++ b/llvm/include/llvm/Support/FileDescriptor.h
@@ -0,0 +1,48 @@
+//===-- FileDescriptor.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains utility functions for working with file descriptors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_FILEDESCRIPTOR_H
+#define LLVM_SUPPORT_FILEDESCRIPTOR_H
+
+#include "llvm/Support/Error.h"
+#include <atomic>
+#include <chrono>
+#include <type_traits>
+
+namespace llvm {
+
+/// Returns the integer file descriptor held by \p FD, which is either a plain
+/// `int` or a `std::atomic<int>` (read with an atomic load).
+template <typename T> int getFD(T &FD) {
+  if constexpr (std::is_same_v<T, std::atomic<int>>) {
+    return FD.load();
+  } else {
+    return FD;
+  }
+}
+
+/// Waits until \p FD or \p PipeFD is readable (POLLIN) or \p Timeout elapses.
+/// A \p Timeout of -1 milliseconds waits without a time limit.
+///
+/// \p FD must be an `int` or a `std::atomic<int>` (enforced by a static_assert
+/// in the definition). If \p FD reads as -1 once poll returns, the wait is
+/// treated as canceled.
+///
+/// \returns Error::success() if a descriptor became ready; otherwise an error
+/// such as std::errc::timed_out, std::errc::operation_canceled,
+/// std::errc::bad_file_descriptor, or the error reported by poll.
+template <typename T>
+llvm::Error manageTimeout(std::chrono::milliseconds Timeout, T &FD, int PipeFD);
+
+} // namespace llvm
+
+#endif // LLVM_SUPPORT_FILEDESCRIPTOR_H
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index be4badc09efa5..0a65e58da88a8 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -176,8 +176,9 @@ add_llvm_component_library(LLVMSupport
   ExponentialBackoff.cpp
   ExtensibleRTTI.cpp
   FileCollector.cpp
-  FileUtilities.cpp
+  FileDescriptor.cpp
   FileOutputBuffer.cpp
+  FileUtilities.cpp
   FloatingPointMode.cpp
   FoldingSet.cpp
   FormattedStream.cpp
diff --git a/llvm/lib/Support/FileDescriptor.cpp b/llvm/lib/Support/FileDescriptor.cpp
new file mode 100644
index 0000000000000..9ecf991fa5dc3
--- /dev/null
+++ b/llvm/lib/Support/FileDescriptor.cpp
@@ -0,0 +1,121 @@
+//===-- FileDescriptor.cpp --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains utility functions for working with file descriptors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/FileDescriptor.h"
+#include "llvm/Support/Error.h"
+#include <atomic>
+#include <chrono>
+
+#ifdef _WIN32
+#include "llvm/Support/Windows/WindowsSupport.h"
+#include <winsock2.h>
+#include <io.h>
+#else
+#include <poll.h>
+#endif
+
+/// Returns the last socket error (WSAGetLastError / errno) as an error_code.
+static std::error_code getLastSocketErrorCode() {
+#ifdef _WIN32
+  return std::error_code(::WSAGetLastError(), std::system_category());
+#else
+  return llvm::errnoAsErrorCode();
+#endif
+}
+
+template <typename T>
+llvm::Error llvm::manageTimeout(std::chrono::milliseconds Timeout, T &FD,
+                                int PipeFD) {
+  static_assert(std::is_same_v<T, int> || std::is_same_v<T, std::atomic<int>>,
+                "FD must be of type int& or std::atomic<int>&");
+
+  // Poll FD together with PipeFD; readability of either one ends the wait.
+  struct pollfd FDs[2];
+  FDs[0].events = POLLIN;
+#ifdef _WIN32
+  SOCKET WinServerSock = _get_osfhandle(llvm::getFD(FD));
+  FDs[0].fd = WinServerSock;
+#else
+  FDs[0].fd = llvm::getFD(FD);
+#endif
+  FDs[1].events = POLLIN;
+  FDs[1].fd = PipeFD;
+
+  // Keep track of how much time has passed in case poll is interrupted by a
+  // signal and needs to be called again with the remaining time.
+  int RemainingTime = Timeout.count();
+  std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
+  int PollStatus = -1;
+
+  while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
+    // ElapsedTime is cumulative, so derive the remaining budget from the
+    // original Timeout rather than subtracting from the previous budget.
+    if (Timeout.count() != -1)
+      RemainingTime = Timeout.count() - ElapsedTime.count();
+
+    auto Start = std::chrono::steady_clock::now();
+#ifdef _WIN32
+    PollStatus = WSAPoll(FDs, 2, RemainingTime);
+#else
+    PollStatus = ::poll(FDs, 2, RemainingTime);
+#endif
+    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
+    // appropriate to return operation_canceled
+    if (llvm::getFD(FD) == -1)
+      return llvm::make_error<llvm::StringError>(
+          std::make_error_code(std::errc::operation_canceled),
+          "Accept canceled");
+
+#ifdef _WIN32
+    if (PollStatus == SOCKET_ERROR) {
+#else
+    if (PollStatus == -1) {
+#endif
+      std::error_code PollErrCode = getLastSocketErrorCode();
+      // Ignore EINTR (signal occurred before any requested event) and retry
+      if (PollErrCode != std::errc::interrupted)
+        return llvm::make_error<llvm::StringError>(PollErrCode,
+                                                   "FD poll failed");
+    }
+    if (PollStatus == 0)
+      return llvm::make_error<llvm::StringError>(
+          std::make_error_code(std::errc::timed_out),
+          "No client requests within timeout window");
+
+    // revents is only meaningful when poll reported ready descriptors.
+    if (PollStatus > 0 && (FDs[0].revents & POLLNVAL))
+      return llvm::make_error<llvm::StringError>(
+          std::make_error_code(std::errc::bad_file_descriptor));
+
+    auto Stop = std::chrono::steady_clock::now();
+    ElapsedTime +=
+        std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
+  }
+
+  // Reaching here with PollStatus == -1 means EINTR retries exhausted the
+  // timeout (or Timeout was 0 / negative other than -1): report a timeout.
+  if (PollStatus == -1)
+    return llvm::make_error<llvm::StringError>(
+        std::make_error_code(std::errc::timed_out),
+        "No client requests within timeout window");
+
+  return llvm::Error::success();
+}
+
+// The template is defined in this file rather than in the header, so
+// instantiate it for every FD type the static_assert accepts; otherwise
+// callers in other translation units fail to link.
+template llvm::Error llvm::manageTimeout<int>(std::chrono::milliseconds,
+                                              int &, int);
+template llvm::Error
+llvm::manageTimeout<std::atomic<int>>(std::chrono::milliseconds,
+                                      std::atomic<int> &, int);
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 549d537709bf2..4f1b24377bd5e 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Support/raw_socket_stream.h"
 #include "llvm/Config/config.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/FileDescriptor.h"
 #include "llvm/Support/FileSystem.h"
 
 #include <atomic>
@@ -179,64 +180,9 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
 
 Expected<std::unique_ptr<raw_socket_stream>>
 ListeningSocket::accept(std::chrono::milliseconds Timeout) {
-
-  struct pollfd FDs[2];
-  FDs[0].events = POLLIN;
-#ifdef _WIN32
-  SOCKET WinServerSock = _get_osfhandle(FD);
-  FDs[0].fd = WinServerSock;
-#else
-  FDs[0].fd = FD;
-#endif
-  FDs[1].events = POLLIN;
-  FDs[1].fd = PipeFD[0];
-
-  // Keep track of how much time has passed in case poll is interupted by a
-  // signal and needs to be recalled
-  int RemainingTime = Timeout.count();
-  std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0);
-  int PollStatus = -1;
-
-  while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) {
-    if (Timeout.count() != -1)
-      RemainingTime -= ElapsedTime.count();
-
-    auto Start = std::chrono::steady_clock::now();
-#ifdef _WIN32
-    PollStatus = WSAPoll(FDs, 2, RemainingTime);
-#else
-    PollStatus = ::poll(FDs, 2, RemainingTime);
-#endif
-    // If FD equals -1 then ListeningSocket::shutdown has been called and it is
-    // appropriate to return operation_canceled
-    if (FD.load() == -1)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::operation_canceled),
-          "Accept canceled");
-
-#if _WIN32
-    if (PollStatus == SOCKET_ERROR) {
-#else
-    if (PollStatus == -1) {
-#endif
-      std::error_code PollErrCode = getLastSocketErrorCode();
-      // Ignore EINTR (signal occured before any request event) and retry
-      if (PollErrCode != std::errc::interrupted)
-        return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
-    }
-    if (PollStatus == 0)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::timed_out),
-          "No client requests within timeout window");
-
-    if (FDs[0].revents & POLLNVAL)
-      return llvm::make_error<StringError>(
-          std::make_error_code(std::errc::bad_file_descriptor));
-
-    auto Stop = std::chrono::steady_clock::now();
-    ElapsedTime +=
-        std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
-  }
+  // Propagate timeout, cancellation, and poll failures to the caller.
+  if (llvm::Error TimeoutErr = manageTimeout(Timeout, FD, PipeFD[0]))
+    return std::move(TimeoutErr);
 
   int AcceptFD;
 #ifdef _WIN32
diff --git a/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn b/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
index 7728455499bf3..79259abb80022 100644
--- a/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/lib/Support/BUILD.gn
@@ -80,6 +80,7 @@ static_library("Support") {
     "ExponentialBackoff.cpp",
     "ExtensibleRTTI.cpp",
     "FileCollector.cpp",
+    "FileDescriptor.cpp",
     "FileOutputBuffer.cpp",
     "FileUtilities.cpp",
     "FloatingPointMode.cpp",



More information about the llvm-commits mailing list