[libc-commits] [libc] [libc] implement secure random buffer filling with vDSO (PR #109870)

Schrodinger ZHU Yifan via libc-commits libc-commits at lists.llvm.org
Wed Sep 25 06:39:08 PDT 2024


https://github.com/SchrodingerZhu updated https://github.com/llvm/llvm-project/pull/109870

>From 553925e8ed93fdf035cb1094c8018eb7ea36026f Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Tue, 24 Sep 2024 18:01:46 -0400
Subject: [PATCH 1/5] [libc] implement secure random buffer filling with vDSO

---
 .../src/__support/OSUtil/linux/CMakeLists.txt |  22 ++
 .../src/__support/OSUtil/linux/aarch64/vdso.h |   2 +
 libc/src/__support/OSUtil/linux/random.cpp    | 323 ++++++++++++++++++
 libc/src/__support/OSUtil/linux/random.h      |  20 ++
 libc/src/__support/OSUtil/linux/vdso_sym.h    |   9 +-
 libc/src/__support/OSUtil/linux/x86_64/vdso.h |   2 +
 libc/src/__support/threads/thread.cpp         |   2 +-
 .../integration/src/__support/CMakeLists.txt  |   1 +
 .../src/__support/OSUtil/CMakeLists.txt       |   5 +
 .../src/__support/OSUtil/linux/CMakeLists.txt |  10 +
 .../OSUtil/linux/random_fill_test.cpp         |  22 ++
 11 files changed, 414 insertions(+), 4 deletions(-)
 create mode 100644 libc/src/__support/OSUtil/linux/random.cpp
 create mode 100644 libc/src/__support/OSUtil/linux/random.h
 create mode 100644 libc/test/integration/src/__support/OSUtil/CMakeLists.txt
 create mode 100644 libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt
 create mode 100644 libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp

diff --git a/libc/src/__support/OSUtil/linux/CMakeLists.txt b/libc/src/__support/OSUtil/linux/CMakeLists.txt
index 6c7014940407d8..4752e326cc153e 100644
--- a/libc/src/__support/OSUtil/linux/CMakeLists.txt
+++ b/libc/src/__support/OSUtil/linux/CMakeLists.txt
@@ -53,3 +53,25 @@ add_object_library(
     libc.src.errno.errno
     libc.src.sys.auxv.getauxval
 )
+
+add_object_library(
+    random
+  HDRS
+    random.h
+  SRCS
+    random.cpp
+  DEPENDS
+    libc.src.sys.random.getrandom
+    libc.src.sys.mman.mmap
+    libc.src.sys.mman.munmap
+    libc.src.unistd.sysconf
+    libc.src.errno.errno
+    libc.src.__support.common
+    libc.src.__support.OSUtil.linux.vdso
+    libc.src.__support.threads.callonce
+    libc.src.__support.threads.linux.raw_mutex
+    libc.src.__support.threads.thread
+    libc.src.sched.sched_getaffinity
+    libc.src.sched.__sched_getcpucount
+)
+
diff --git a/libc/src/__support/OSUtil/linux/aarch64/vdso.h b/libc/src/__support/OSUtil/linux/aarch64/vdso.h
index 3c4c6205071da2..ee5777ad67f6dd 100644
--- a/libc/src/__support/OSUtil/linux/aarch64/vdso.h
+++ b/libc/src/__support/OSUtil/linux/aarch64/vdso.h
@@ -23,6 +23,8 @@ LIBC_INLINE constexpr cpp::string_view symbol_name(VDSOSym sym) {
     return "__kernel_clock_gettime";
   case VDSOSym::ClockGetRes:
     return "__kernel_clock_getres";
+  case VDSOSym::GetRandom:
+    return "__kernel_getrandom";
   default:
     return "";
   }
diff --git a/libc/src/__support/OSUtil/linux/random.cpp b/libc/src/__support/OSUtil/linux/random.cpp
new file mode 100644
index 00000000000000..612a184bda8455
--- /dev/null
+++ b/libc/src/__support/OSUtil/linux/random.cpp
@@ -0,0 +1,323 @@
+#include "src/__support/OSUtil/linux/random.h"
+#include "src/__support/CPP/mutex.h"
+#include "src/__support/CPP/new.h"
+#include "src/__support/OSUtil/linux/syscall.h"
+#include "src/__support/OSUtil/linux/vdso.h"
+#include "src/__support/OSUtil/linux/x86_64/vdso.h"
+#include "src/__support/libc_assert.h"
+#include "src/__support/memory_size.h"
+#include "src/__support/threads/callonce.h"
+#include "src/__support/threads/linux/callonce.h"
+#include "src/__support/threads/linux/raw_mutex.h"
+#include "src/errno/libc_errno.h"
+#include "src/sched/sched_getaffinity.h"
+#include "src/sched/sched_getcpucount.h"
+#include "src/stdlib/atexit.h"
+#include "src/sys/mman/mmap.h"
+#include "src/sys/mman/munmap.h"
+#include "src/sys/random/getrandom.h"
+#include "src/unistd/sysconf.h"
+#include <asm/param.h>
+
+namespace LIBC_NAMESPACE_DECL {
+namespace {
+// Saves errno and clears it; the destructor restores the saved value.
+struct ErrnoProtect {
+  int backup; // errno value seen at construction
+  ErrnoProtect() : backup(libc_errno) { libc_errno = 0; }
+  ~ErrnoProtect() { libc_errno = backup; }
+};
+
+// parameters for allocating per-thread random state
+struct Params {
+  unsigned size_of_opaque_state;
+  unsigned mmap_prot;
+  unsigned mmap_flags;
+  unsigned reserved[13]; // NOTE(review): presumably ABI padding; confirm
+};
+
+// Registers a callback to be run when the calling thread exits.
+using Destructor = void(void *); // callback type; receives the registered object
+extern "C" int __cxa_thread_atexit_impl(Destructor *, void *, void *);
+extern "C" [[gnu::weak, gnu::visibility("hidden")]] void *__dso_handle =
+    nullptr; // weak: yields to a strong __dso_handle defined elsewhere
+
+// Growable pointer stack backed by an anonymous mmap region (no malloc).
+// Elements live in [ptr, usage); [usage, boundary) is spare capacity.
+class MMapContainer {
+  void **ptr = nullptr;
+  void **usage = nullptr;
+  void **boundary = nullptr;
+
+  // Number of pointer slots in the mapping (an element count, not bytes).
+  size_t capacity() const { return static_cast<size_t>(boundary - ptr); }
+  // Length of the mapping in bytes.
+  size_t bytes() const { return capacity() * sizeof(void *); }
+
+  // Maps a fresh region of at least `min_bytes`, rounded up to whole pages.
+  bool initialize(size_t min_bytes) {
+    internal::SafeMemSize page_size{static_cast<size_t>(sysconf(_SC_PAGESIZE))};
+    internal::SafeMemSize length =
+        internal::SafeMemSize{min_bytes}.align_up(page_size);
+    if (!page_size.valid() || !length.valid())
+      return false;
+    size_t map_len = static_cast<size_t>(length);
+    void *mem = mmap(nullptr, map_len, PROT_READ | PROT_WRITE,
+                     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+    if (mem == MAP_FAILED)
+      return false;
+    ptr = usage = static_cast<void **>(mem);
+    boundary = ptr + map_len / sizeof(void *);
+    return true;
+  }
+
+  // Doubles the mapping until `additional` more pointers fit.
+  bool grow(size_t additional) {
+    size_t used = static_cast<size_t>(usage - ptr);
+    internal::SafeMemSize needed =
+        (internal::SafeMemSize{used} + internal::SafeMemSize{additional}) *
+        internal::SafeMemSize{sizeof(void *)};
+    if (!needed.valid())
+      return false;
+    if (ptr == nullptr)
+      return initialize(needed);
+    internal::SafeMemSize new_bytes{bytes()};
+    while (new_bytes < needed) {
+      new_bytes = new_bytes * internal::SafeMemSize{static_cast<size_t>(2)};
+      if (!new_bytes.valid())
+        return false;
+    }
+    // TODO: migrate to syscall wrapper once it's available
+    // mremap(old_address, old_size, new_size, flags)
+    auto result = syscall_impl<intptr_t>(SYS_mremap, ptr, bytes(),
+                                         static_cast<size_t>(new_bytes),
+                                         MREMAP_MAYMOVE);
+    // raw syscalls return a small negative -errno value on failure
+    if (result < 0 && result > -EXEC_PAGESIZE)
+      return false;
+    ptr = reinterpret_cast<void **>(result);
+    usage = ptr + used; // the element count is unchanged by the remap
+    boundary = ptr + static_cast<size_t>(new_bytes) / sizeof(void *);
+    return true;
+  }
+
+public:
+  MMapContainer() = default;
+  MMapContainer(const MMapContainer &) = delete;
+  MMapContainer &operator=(const MMapContainer &) = delete;
+  ~MMapContainer() {
+    if (ptr)
+      munmap(ptr, bytes());
+  }
+
+  // Makes sure `additional` more pointers can be pushed without growing.
+  bool ensure_space(size_t additional) {
+    return additional <= static_cast<size_t>(boundary - usage) ||
+           grow(additional);
+  }
+
+  void push_unchecked(void *value) {
+    LIBC_ASSERT(usage != boundary && "pushing into full container");
+    *usage++ = value;
+  }
+
+  using iterator = void **;
+  using value_type = void *;
+  iterator begin() const { return ptr; }
+  iterator end() const { return usage; }
+
+  bool empty() const { return begin() == end(); }
+  void *pop() {
+    LIBC_ASSERT(!empty() && "popping from empty container");
+    return *--usage;
+  }
+  // Number of stored pointers (an element count, not bytes).
+  internal::SafeMemSize size() const {
+    return internal::SafeMemSize{static_cast<size_t>(usage - ptr)};
+  }
+};
+
+class StateFactory {
+  RawMutex mutex{};
+  MMapContainer allocations{};
+  MMapContainer freelist{};
+  Params params{};
+  size_t states_per_page = 0;
+  size_t pages_per_allocation = 0;
+  size_t page_size = 0;
+
+  bool prepare() {
+    vdso::TypedSymbol<vdso::VDSOSym::GetRandom> vgetrandom;
+    if (!vgetrandom)
+      return false;
+
+    // get the allocation configuration suggested by the kernel
+    if (vgetrandom(nullptr, 0, 0, &params, ~0UL))
+      return false;
+
+    cpu_set_t cs{};
+    if (LIBC_NAMESPACE::sched_getaffinity(0, sizeof(cs), &cs))
+      return false;
+
+    internal::SafeMemSize count{static_cast<size_t>(
+        LIBC_NAMESPACE::__sched_getcpucount(sizeof(cs), &cs))};
+
+    // one state per CPU in the affinity mask, rounded up to whole pages
+    internal::SafeMemSize allocation_size =
+        internal::SafeMemSize{
+            static_cast<size_t>(params.size_of_opaque_state)} *
+        count;
+    page_size = static_cast<size_t>(sysconf(_SC_PAGESIZE));
+    allocation_size = allocation_size.align_up(page_size);
+    if (!allocation_size.valid())
+      return false;
+    // a state must fit within one page; this also avoids dividing by zero
+    if (params.size_of_opaque_state == 0 ||
+        params.size_of_opaque_state > page_size)
+      return false;
+    states_per_page = page_size / params.size_of_opaque_state;
+    pages_per_allocation = allocation_size / page_size;
+    return true;
+  }
+
+  bool allocate_new_states() {
+    if (!allocations.ensure_space(1))
+      return false;
+
+    // freelist room for all states of every allocation, including this one
+    internal::SafeMemSize total_states =
+        internal::SafeMemSize{states_per_page} *
+        internal::SafeMemSize{pages_per_allocation} *
+        (internal::SafeMemSize{static_cast<size_t>(1)} + allocations.size());
+
+    if (!total_states.valid() ||
+        !freelist.ensure_space(total_states - freelist.size()))
+      return false;
+
+    auto *new_allocation =
+        static_cast<char *>(mmap(nullptr, page_size * pages_per_allocation,
+                                 params.mmap_prot, params.mmap_flags, -1, 0));
+    if (new_allocation == MAP_FAILED)
+      return false;
+    allocations.push_unchecked(new_allocation); // unmapped by the destructor
+    for (size_t i = 0; i < pages_per_allocation; ++i) {
+      auto *page = new_allocation + i * page_size;
+      for (size_t j = 0; j < states_per_page; ++j)
+        freelist.push_unchecked(page + j * params.size_of_opaque_state);
+    }
+    return true;
+  }
+
+  static StateFactory *instance() {
+    alignas(StateFactory) static char storage[sizeof(StateFactory)]{};
+    static CallOnceFlag flag = callonce_impl::NOT_CALLED;
+    static bool valid = false;
+    callonce(&flag, []() {
+      auto *factory = new (storage) StateFactory();
+      valid = factory->prepare();
+      if (valid)
+        atexit([]() {
+          auto factory = reinterpret_cast<StateFactory *>(storage);
+          factory->~StateFactory();
+          valid = false;
+        });
+    });
+    return valid ? reinterpret_cast<StateFactory *>(storage) : nullptr;
+  }
+
+  void *acquire() {
+    cpp::lock_guard guard{mutex};
+    if (freelist.empty() && !allocate_new_states())
+      return nullptr;
+    return freelist.pop();
+  }
+  void release(void *state) {
+    cpp::lock_guard guard{mutex};
+    // there should be no need to check this pushing
+    freelist.push_unchecked(state);
+  }
+  ~StateFactory() {
+    for (auto *allocation : allocations)
+      munmap(allocation, page_size * pages_per_allocation);
+  }
+
+public:
+  static void *acquire_global() {
+    auto *factory = instance();
+    if (!factory)
+      return nullptr;
+    return factory->acquire();
+  }
+  static void release_global(void *state) {
+    auto *factory = instance();
+    if (!factory)
+      return;
+    factory->release(state);
+  }
+  static size_t size_of_opaque_state() {
+    return instance()->params.size_of_opaque_state;
+  }
+};
+
+void *acquire_tls() {
+  static thread_local void *state = nullptr;
+  // previous acquire failed, do not try again
+  if (state == MAP_FAILED)
+    return nullptr;
+  // first acquirement
+  if (state == nullptr) {
+    state = StateFactory::acquire_global();
+    // if still fails, remember the failure
+    if (state == nullptr) {
+      state = MAP_FAILED;
+      return nullptr;
+    } else {
+      // register the release callback.
+      if (__cxa_thread_atexit_impl(
+              [](void *s) { StateFactory::release_global(s); }, state,
+              __dso_handle)) {
+        StateFactory::release_global(state);
+        state = MAP_FAILED;
+        return nullptr;
+      }
+    }
+  }
+  return state;
+}
+
+template <class F> void random_fill_impl(F gen, void *buf, size_t size) {
+  auto *buffer = reinterpret_cast<uint8_t *>(buf);
+  while (size > 0) { // gen may produce fewer bytes than requested
+    ssize_t len = gen(buffer, size); // expected: -1 + libc_errno on failure
+    if (len == -1) {
+      if (libc_errno == EINTR)
+        continue; // interrupted by a signal: retry
+      break; // NOTE(review): remaining bytes stay unfilled, silently
+    }
+    size -= len;
+    buffer += len;
+  }
+}
+} // namespace
+
+void random_fill(void *buf, size_t size) {
+  ErrnoProtect protect; // the caller's errno is restored on return
+  void *state = acquire_tls(); // nullptr: use LIBC_NAMESPACE::getrandom instead
+  if (state) { // NOTE(review): vDSO result is typed as int; verify large sizes
+    random_fill_impl(
+        [state](void *buf, size_t size) {
+          vdso::TypedSymbol<vdso::VDSOSym::GetRandom> vgetrandom;
+          return vgetrandom(buf, size, 0, state,
+                            StateFactory::size_of_opaque_state());
+        },
+        buf, size);
+  } else {
+    random_fill_impl(
+        [](void *buf, size_t size) {
+          return LIBC_NAMESPACE::getrandom(buf, size, 0);
+        },
+        buf, size);
+  }
+}
+
+} // namespace LIBC_NAMESPACE_DECL
diff --git a/libc/src/__support/OSUtil/linux/random.h b/libc/src/__support/OSUtil/linux/random.h
new file mode 100644
index 00000000000000..0e2d51391ec31c
--- /dev/null
+++ b/libc/src/__support/OSUtil/linux/random.h
@@ -0,0 +1,20 @@
+//===-- Utilities for getting secure randomness -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC___SUPPORT_RANDOMNESS_H
+#define LLVM_LIBC_SRC___SUPPORT_RANDOMNESS_H
+
+#include "src/__support/common.h"
+
+#define __need_size_t
+#include <stddef.h>
+
+namespace LIBC_NAMESPACE_DECL {
+void random_fill(void *buf, unsigned long size);
+} // namespace LIBC_NAMESPACE_DECL
+#endif // LLVM_LIBC_SRC___SUPPORT_RANDOMNESS_H
diff --git a/libc/src/__support/OSUtil/linux/vdso_sym.h b/libc/src/__support/OSUtil/linux/vdso_sym.h
index 968e1536c4d270..2b1cf398369de2 100644
--- a/libc/src/__support/OSUtil/linux/vdso_sym.h
+++ b/libc/src/__support/OSUtil/linux/vdso_sym.h
@@ -19,7 +19,6 @@ struct __kernel_timespec;
 struct timezone;
 struct riscv_hwprobe;
 struct getcpu_cache;
-struct cpu_set_t;
 // NOLINTEND(llvmlibc-implementation-in-namespace)
 
 namespace LIBC_NAMESPACE_DECL {
@@ -35,7 +34,8 @@ enum class VDSOSym {
   RTSigReturn,
   FlushICache,
   RiscvHwProbe,
-  VDSOSymCount
+  GetRandom,
+  VDSOSymCount,
 };
 
 template <VDSOSym sym> LIBC_INLINE constexpr auto dispatcher() {
@@ -58,8 +58,11 @@ template <VDSOSym sym> LIBC_INLINE constexpr auto dispatcher() {
   else if constexpr (sym == VDSOSym::FlushICache)
     return static_cast<void (*)(void *, void *, unsigned int)>(nullptr);
   else if constexpr (sym == VDSOSym::RiscvHwProbe)
-    return static_cast<int (*)(riscv_hwprobe *, size_t, size_t, cpu_set_t *,
+    return static_cast<int (*)(riscv_hwprobe *, size_t, size_t, void *,
                                unsigned)>(nullptr);
+  else if constexpr (sym == VDSOSym::GetRandom)
+    return static_cast<int (*)(void *, size_t, unsigned int, void *, size_t)>(
+        nullptr);
   else
     return static_cast<void *>(nullptr);
 }
diff --git a/libc/src/__support/OSUtil/linux/x86_64/vdso.h b/libc/src/__support/OSUtil/linux/x86_64/vdso.h
index abe7c33e07cfab..f46fcb038f2e60 100644
--- a/libc/src/__support/OSUtil/linux/x86_64/vdso.h
+++ b/libc/src/__support/OSUtil/linux/x86_64/vdso.h
@@ -29,6 +29,8 @@ LIBC_INLINE constexpr cpp::string_view symbol_name(VDSOSym sym) {
     return "__vdso_time";
   case VDSOSym::ClockGetRes:
     return "__vdso_clock_getres";
+  case VDSOSym::GetRandom:
+    return "__vdso_getrandom";
   default:
     return "";
   }
diff --git a/libc/src/__support/threads/thread.cpp b/libc/src/__support/threads/thread.cpp
index dad4f75f092ede..04668dbfcbb63a 100644
--- a/libc/src/__support/threads/thread.cpp
+++ b/libc/src/__support/threads/thread.cpp
@@ -117,7 +117,7 @@ class ThreadAtExitCallbackMgr {
 
   int add_callback(AtExitCallback *callback, void *obj) {
     cpp::lock_guard lock(mtx);
-    return callback_list.push_back({callback, obj});
+    return callback_list.push_back({callback, obj}) ? 0 : -1;
   }
 
   void call() {
diff --git a/libc/test/integration/src/__support/CMakeLists.txt b/libc/test/integration/src/__support/CMakeLists.txt
index b5b6557e8d6899..d2dae7e02a9c57 100644
--- a/libc/test/integration/src/__support/CMakeLists.txt
+++ b/libc/test/integration/src/__support/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(threads)
+add_subdirectory(OSUtil)
 if(LIBC_TARGET_OS_IS_GPU)
   add_subdirectory(GPU)
 endif()
diff --git a/libc/test/integration/src/__support/OSUtil/CMakeLists.txt b/libc/test/integration/src/__support/OSUtil/CMakeLists.txt
new file mode 100644
index 00000000000000..5ff1a11aff5c9d
--- /dev/null
+++ b/libc/test/integration/src/__support/OSUtil/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_custom_target(libc-osutil-integration-tests)
+
+if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${LIBC_TARGET_OS})
+  add_subdirectory(${LIBC_TARGET_OS})
+endif()
diff --git a/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt b/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt
new file mode 100644
index 00000000000000..a4f15ad8370352
--- /dev/null
+++ b/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_integration_test(
+  random_fill_test
+  SUITE
+    libc-osutil-integration-tests
+  SRCS
+    random_fill_test.cpp
+  DEPENDS
+    libc.include.pthread
+    libc.src.__support.OSUtil.linux.random
+)
diff --git a/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp b/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
new file mode 100644
index 00000000000000..bde24a1e0d5521
--- /dev/null
+++ b/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
@@ -0,0 +1,22 @@
+//===-- Tests for pthread_equal -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/__support/OSUtil/linux/random.h"
+
+#include "test/IntegrationTest/test.h"
+
+void smoke_test() {
+  using namespace LIBC_NAMESPACE;
+  uint32_t buffer; // contents are not inspected; this only exercises the call
+  random_fill(&buffer, sizeof(buffer));
+}
+
+TEST_MAIN() {
+  smoke_test();
+  return 0;
+}

>From a09f3f3c79da87107166ac3247d7454242d3cbbc Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <i at zhuyi.fan>
Date: Tue, 24 Sep 2024 22:38:49 -0400
Subject: [PATCH 2/5] [libc] random: add license header, drop x86_64-only include, map vDSO errors to errno

---
 libc/src/__support/OSUtil/linux/random.cpp      | 17 ++++++++++++++---
 .../__support/OSUtil/linux/random_fill_test.cpp |  2 +-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/libc/src/__support/OSUtil/linux/random.cpp b/libc/src/__support/OSUtil/linux/random.cpp
index 612a184bda8455..553b481200c299 100644
--- a/libc/src/__support/OSUtil/linux/random.cpp
+++ b/libc/src/__support/OSUtil/linux/random.cpp
@@ -1,9 +1,15 @@
+//===- Linux implementation of secure random buffer generation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
 #include "src/__support/OSUtil/linux/random.h"
 #include "src/__support/CPP/mutex.h"
 #include "src/__support/CPP/new.h"
 #include "src/__support/OSUtil/linux/syscall.h"
 #include "src/__support/OSUtil/linux/vdso.h"
-#include "src/__support/OSUtil/linux/x86_64/vdso.h"
 #include "src/__support/libc_assert.h"
 #include "src/__support/memory_size.h"
 #include "src/__support/threads/callonce.h"
@@ -307,8 +313,13 @@ void random_fill(void *buf, size_t size) {
     random_fill_impl(
         [state](void *buf, size_t size) {
           vdso::TypedSymbol<vdso::VDSOSym::GetRandom> vgetrandom;
-          return vgetrandom(buf, size, 0, state,
-                            StateFactory::size_of_opaque_state());
+          int res = vgetrandom(buf, size, 0, state,
+                               StateFactory::size_of_opaque_state());
+          if (res < 0) {
+            libc_errno = -res;
+            return -1;
+          }
+          return res;
         },
         buf, size);
   } else {
diff --git a/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp b/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
index bde24a1e0d5521..4e029e484742e9 100644
--- a/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
+++ b/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
@@ -1,4 +1,4 @@
-//===-- Tests for pthread_equal -------------------------------------------===//
+//===-- Tests for random_fill ---------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From d8b722a6427991e73ecc23c635d8ec515b439026 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <i at zhuyi.fan>
Date: Tue, 24 Sep 2024 23:11:46 -0400
Subject: [PATCH 3/5] [libc] random: add fork hooks that reset the random state in the child

---
 libc/src/__support/OSUtil/linux/random.cpp | 75 ++++++++++++++--------
 libc/src/__support/OSUtil/linux/random.h   |  3 +
 2 files changed, 52 insertions(+), 26 deletions(-)

diff --git a/libc/src/__support/OSUtil/linux/random.cpp b/libc/src/__support/OSUtil/linux/random.cpp
index 553b481200c299..13a6c40ce965a7 100644
--- a/libc/src/__support/OSUtil/linux/random.cpp
+++ b/libc/src/__support/OSUtil/linux/random.cpp
@@ -214,22 +214,7 @@ class StateFactory {
     return true;
   }
 
-  static StateFactory *instance() {
-    alignas(StateFactory) static char storage[sizeof(StateFactory)]{};
-    static CallOnceFlag flag = callonce_impl::NOT_CALLED;
-    static bool valid = false;
-    callonce(&flag, []() {
-      auto *factory = new (storage) StateFactory();
-      valid = factory->prepare();
-      if (valid)
-        atexit([]() {
-          auto factory = reinterpret_cast<StateFactory *>(storage);
-          factory->~StateFactory();
-          valid = false;
-        });
-    });
-    return valid ? reinterpret_cast<StateFactory *>(storage) : nullptr;
-  }
+  static StateFactory *instance();
 
   void *acquire() {
     cpp::lock_guard guard{mutex};
@@ -263,32 +248,62 @@ class StateFactory {
   static size_t size_of_opaque_state() {
     return instance()->params.size_of_opaque_state;
   }
+  static void postfork_cleanup();
 };
 
+thread_local bool fork_inflight = false;
+thread_local void *tls_state = nullptr;
+alignas(StateFactory) static char factory_storage[sizeof(StateFactory)]{};
+static CallOnceFlag factory_onceflag = callonce_impl::NOT_CALLED;
+static bool factory_valid = false;
+
+StateFactory *StateFactory::instance() {
+  callonce(&factory_onceflag, []() {
+    auto *factory = new (factory_storage) StateFactory();
+    factory_valid = factory->prepare();
+    if (factory_valid)
+      atexit([]() {
+        auto factory = reinterpret_cast<StateFactory *>(factory_storage);
+        factory->~StateFactory();
+        factory_valid = false;
+      });
+  });
+  return factory_valid ? reinterpret_cast<StateFactory *>(factory_storage)
+                       : nullptr;
+}
+
+void StateFactory::postfork_cleanup() {
+  if (factory_valid)
+    reinterpret_cast<StateFactory *>(factory_storage)->~StateFactory();
+  factory_onceflag = callonce_impl::NOT_CALLED;
+  factory_valid = false;
+}
+
 void *acquire_tls() {
-  static thread_local void *state = nullptr;
+  if (fork_inflight)
+    return nullptr;
   // previous acquire failed, do not try again
-  if (state == MAP_FAILED)
+  if (tls_state == MAP_FAILED)
     return nullptr;
   // first acquirement
-  if (state == nullptr) {
-    state = StateFactory::acquire_global();
+  if (tls_state == nullptr) {
+    tls_state = StateFactory::acquire_global();
     // if still fails, remember the failure
-    if (state == nullptr) {
-      state = MAP_FAILED;
+    if (tls_state == nullptr) {
+      tls_state = MAP_FAILED;
       return nullptr;
     } else {
       // register the release callback.
       if (__cxa_thread_atexit_impl(
-              [](void *s) { StateFactory::release_global(s); }, state,
+              [](void *s) { StateFactory::release_global(s); }, tls_state,
               __dso_handle)) {
-        StateFactory::release_global(state);
-        state = MAP_FAILED;
+        StateFactory::release_global(tls_state);
+        tls_state = MAP_FAILED;
         return nullptr;
       }
     }
   }
-  return state;
+  return tls_state;
 }
 
 template <class F> void random_fill_impl(F gen, void *buf, size_t size) {
@@ -331,4 +346,12 @@ void random_fill(void *buf, size_t size) {
   }
 }
 
+void random_prefork() { fork_inflight = true; }
+void random_postfork_parent() { fork_inflight = false; }
+void random_postfork_child() {
+  tls_state = nullptr;
+  StateFactory::postfork_cleanup();
+  fork_inflight = false;
+}
+
 } // namespace LIBC_NAMESPACE_DECL
diff --git a/libc/src/__support/OSUtil/linux/random.h b/libc/src/__support/OSUtil/linux/random.h
index 0e2d51391ec31c..567d5f3f412f07 100644
--- a/libc/src/__support/OSUtil/linux/random.h
+++ b/libc/src/__support/OSUtil/linux/random.h
@@ -16,5 +16,8 @@
 
 namespace LIBC_NAMESPACE_DECL {
 void random_fill(void *buf, unsigned long size);
+void random_prefork();
+void random_postfork_parent();
+void random_postfork_child();
 } // namespace LIBC_NAMESPACE_DECL
 #endif // LLVM_LIBC_SRC___SUPPORT_RANDOMNESS_H

>From 6e95fe9f6d634b774c0671f23defeebb0b24e6f6 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <i at zhuyi.fan>
Date: Tue, 24 Sep 2024 23:18:48 -0400
Subject: [PATCH 4/5] [libc] random: rename Params/StateFactory to RandomStateMMapParams/RandomStateFactory

---
 libc/src/__support/OSUtil/linux/random.cpp | 38 ++++++++++++----------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/libc/src/__support/OSUtil/linux/random.cpp b/libc/src/__support/OSUtil/linux/random.cpp
index 13a6c40ce965a7..ab63c4e27e3525 100644
--- a/libc/src/__support/OSUtil/linux/random.cpp
+++ b/libc/src/__support/OSUtil/linux/random.cpp
@@ -35,7 +35,7 @@ struct ErrnoProtect {
 };
 
 // parameters for allocating per-thread random state
-struct Params {
+struct RandomStateMMapParams {
   unsigned size_of_opaque_state;
   unsigned mmap_prot;
   unsigned mmap_flags;
@@ -143,11 +143,11 @@ class MMapContainer {
   }
 };
 
-class StateFactory {
+class RandomStateFactory {
   RawMutex mutex{};
   MMapContainer allocations{};
   MMapContainer freelist{};
-  Params params{};
+  RandomStateMMapParams params{};
   size_t states_per_page = 0;
   size_t pages_per_allocation = 0;
   size_t page_size = 0;
@@ -214,7 +214,7 @@ class StateFactory {
     return true;
   }
 
-  static StateFactory *instance();
+  static RandomStateFactory *instance();
 
   void *acquire() {
     cpp::lock_guard guard{mutex};
@@ -227,7 +227,7 @@ class StateFactory {
     // there should be no need to check this pushing
     freelist.push_unchecked(state);
   }
-  ~StateFactory() {
+  ~RandomStateFactory() {
     for (auto *allocation : allocations)
       munmap(allocation, page_size * pages_per_allocation);
   }
@@ -253,28 +253,30 @@ class StateFactory {
 
 thread_local bool fork_inflight = false;
 thread_local void *tls_state = nullptr;
-alignas(StateFactory) static char factory_storage[sizeof(StateFactory)]{};
+alignas(RandomStateFactory) static char factory_storage[sizeof(
+    RandomStateFactory)]{};
 static CallOnceFlag factory_onceflag = callonce_impl::NOT_CALLED;
 static bool factory_valid = false;
 
-StateFactory *StateFactory::instance() {
+RandomStateFactory *RandomStateFactory::instance() {
   callonce(&factory_onceflag, []() {
-    auto *factory = new (factory_storage) StateFactory();
+    auto *factory = new (factory_storage) RandomStateFactory();
     factory_valid = factory->prepare();
     if (factory_valid)
       atexit([]() {
-        auto factory = reinterpret_cast<StateFactory *>(factory_storage);
-        factory->~StateFactory();
+        auto factory = reinterpret_cast<RandomStateFactory *>(factory_storage);
+        factory->~RandomStateFactory();
         factory_valid = false;
       });
   });
-  return factory_valid ? reinterpret_cast<StateFactory *>(factory_storage)
+  return factory_valid ? reinterpret_cast<RandomStateFactory *>(factory_storage)
                        : nullptr;
 }
 
-void StateFactory::postfork_cleanup() {
+void RandomStateFactory::postfork_cleanup() {
   if (factory_valid)
-    reinterpret_cast<StateFactory *>(factory_storage)->~StateFactory();
+    reinterpret_cast<RandomStateFactory *>(factory_storage)
+        ->~RandomStateFactory();
   factory_onceflag = callonce_impl::NOT_CALLED;
   factory_valid = false;
 }
@@ -287,7 +289,7 @@ void *acquire_tls() {
     return nullptr;
   // first acquirement
   if (tls_state == nullptr) {
-    tls_state = StateFactory::acquire_global();
+    tls_state = RandomStateFactory::acquire_global();
     // if still fails, remember the failure
     if (tls_state == nullptr) {
       tls_state = MAP_FAILED;
@@ -295,9 +297,9 @@ void *acquire_tls() {
     } else {
       // register the release callback.
       if (__cxa_thread_atexit_impl(
-              [](void *s) { StateFactory::release_global(s); }, tls_state,
+              [](void *s) { RandomStateFactory::release_global(s); }, tls_state,
               __dso_handle)) {
-        StateFactory::release_global(tls_state);
+        RandomStateFactory::release_global(tls_state);
         tls_state = MAP_FAILED;
         return nullptr;
       }
@@ -329,7 +331,7 @@ void random_fill(void *buf, size_t size) {
         [state](void *buf, size_t size) {
           vdso::TypedSymbol<vdso::VDSOSym::GetRandom> vgetrandom;
           int res = vgetrandom(buf, size, 0, state,
-                               StateFactory::size_of_opaque_state());
+                               RandomStateFactory::size_of_opaque_state());
           if (res < 0) {
             libc_errno = -res;
             return -1;
@@ -350,7 +352,7 @@ void random_prefork() { fork_inflight = true; }
 void random_postfork_parent() { fork_inflight = false; }
 void random_postfork_child() {
   tls_state = nullptr;
-  StateFactory::postfork_cleanup();
+  RandomStateFactory::postfork_cleanup();
   fork_inflight = false;
 }
 

>From 06c79b37cd1cf36e5e19d5e9fc5c870e4ab3cee4 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Wed, 25 Sep 2024 09:38:50 -0400
Subject: [PATCH 5/5] [libc] add larger-buffer and multithreaded random_fill tests

---
 .../src/__support/OSUtil/linux/CMakeLists.txt |  2 ++
 .../OSUtil/linux/random_fill_test.cpp         | 29 ++++++++++++++++++-
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt b/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt
index a4f15ad8370352..a61a36ec24d451 100644
--- a/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt
+++ b/libc/test/integration/src/__support/OSUtil/linux/CMakeLists.txt
@@ -7,4 +7,6 @@ add_integration_test(
   DEPENDS
     libc.include.pthread
     libc.src.__support.OSUtil.linux.random
+    libc.src.pthread.pthread_create
+    libc.src.pthread.pthread_join
 )
diff --git a/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp b/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
index 4e029e484742e9..3f1f104c5883c7 100644
--- a/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
+++ b/libc/test/integration/src/__support/OSUtil/linux/random_fill_test.cpp
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/__support/OSUtil/linux/random.h"
-
+#include "src/pthread/pthread_create.h"
+#include "src/pthread/pthread_join.h"
 #include "test/IntegrationTest/test.h"
 
 void smoke_test() {
@@ -16,7 +17,33 @@ void smoke_test() {
   random_fill(&buffer, sizeof(buffer));
 }
 
+void larger_smoke_test() {
+  // a single request much larger than the 4 bytes used by smoke_test
+  char bytes[1024];
+  LIBC_NAMESPACE::random_fill(bytes, sizeof(bytes));
+}
+
+void threaded_test() {
+  // 32 threads, presumably enough to need more than one page of per-thread
+  // states on a 4 KiB-page system (depends on the kernel's state size).
+  pthread_t threads[32];
+  for (auto &thread : threads) {
+    int rc = LIBC_NAMESPACE::pthread_create(
+        &thread, nullptr,
+        [](void *) -> void * {
+          smoke_test();
+          return nullptr;
+        },
+        nullptr);
+    ASSERT_EQ(rc, 0);
+  }
+  for (auto &thread : threads)
+    ASSERT_EQ(LIBC_NAMESPACE::pthread_join(thread, nullptr), 0);
+}
+
 TEST_MAIN() {
   smoke_test();
+  larger_smoke_test();
+  threaded_test();
   return 0;
 }



More information about the libc-commits mailing list