[libc-commits] [libc] [libc] add hashtable fuzzing (PR #87949)

Schrodinger ZHU Yifan via libc-commits libc-commits at lists.llvm.org
Tue Apr 30 15:13:07 PDT 2024


https://github.com/SchrodingerZhu updated https://github.com/llvm/llvm-project/pull/87949

>From 6624d841cd809e117f59dbed54660c0a0cdece1d Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Sun, 7 Apr 2024 18:26:16 -0400
Subject: [PATCH 1/9] [libc] add hashtable fuzzing

---
 libc/fuzzing/__support/CMakeLists.txt     |   9 ++
 libc/fuzzing/__support/hashtable_fuzz.cpp | 157 ++++++++++++++++++++++
 2 files changed, 166 insertions(+)
 create mode 100644 libc/fuzzing/__support/hashtable_fuzz.cpp

Review notes (not part of the commit message):
* next_cstr() takes `result` before next_uint64() advances global_buffer,
  so each returned key begins with the 8 length-prefix bytes instead of
  the string that follows them.
* The NUL scan in next_cstr() tests *global_buffer, which the loop never
  advances, so only the first byte is ever checked.
* trap_with_message() ignores its argument.
* hashtable_fuzz.cpp is rewritten in [PATCH 6/9], which drops next_cstr().

diff --git a/libc/fuzzing/__support/CMakeLists.txt b/libc/fuzzing/__support/CMakeLists.txt
index d4f6db71fdd849..b5d2b488447fc5 100644
--- a/libc/fuzzing/__support/CMakeLists.txt
+++ b/libc/fuzzing/__support/CMakeLists.txt
@@ -5,3 +5,12 @@ add_libc_fuzzer(
   DEPENDS
     libc.src.__support.big_int
 )
+
+add_libc_fuzzer(
+  hashtable_fuzz
+  SRCS
+    hashtable_fuzz.cpp
+  DEPENDS
+    libc.src.__support.HashTable.table
+    libc.src.string.memcpy
+)
diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
new file mode 100644
index 00000000000000..4b862b03b9d309
--- /dev/null
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -0,0 +1,157 @@
+#include "src/__support/CPP/new.h"
+#include "src/__support/CPP/optional.h"
+#include "src/__support/HashTable/table.h"
+#include "src/string/memcpy.h"
+#include <search.h>
+#include <stdint.h>
+namespace LIBC_NAMESPACE {
+
+enum class Action { Find, Insert, CrossCheck };
+static uint8_t *global_buffer = nullptr;
+static size_t remaining = 0;
+
+static cpp::optional<uint8_t> next_u8() {
+  if (remaining == 0)
+    return cpp::nullopt;
+  uint8_t result = *global_buffer;
+  global_buffer++;
+  remaining--;
+  return result;
+}
+
+static cpp::optional<uint64_t> next_uint64() {
+  uint64_t result;
+  if (remaining < sizeof(result))
+    return cpp::nullopt;
+  memcpy(&result, global_buffer, sizeof(result));
+  global_buffer += sizeof(result);
+  remaining -= sizeof(result);
+  return result;
+}
+
+static cpp::optional<Action> next_action() {
+  if (cpp::optional<uint8_t> action = next_u8()) {
+    switch (*action % 3) {
+    case 0:
+      return Action::Find;
+    case 1:
+      return Action::Insert;
+    case 2:
+      return Action::CrossCheck;
+    }
+  }
+  return cpp::nullopt;
+}
+
+static cpp::optional<char *> next_cstr() {
+  char *result = reinterpret_cast<char *>(global_buffer);
+  if (cpp::optional<uint64_t> len = next_uint64()) {
+    uint64_t length;
+    for (length = 0; length < *len % 128; length++) {
+      if (length >= remaining)
+        return cpp::nullopt;
+      if (*global_buffer == '\0')
+        break;
+    }
+    if (length >= remaining)
+      return cpp::nullopt;
+    global_buffer[length] = '\0';
+    global_buffer += length + 1;
+    remaining -= length + 1;
+    return result;
+  }
+  return cpp::nullopt;
+}
+
+#define GET_VAL(op)                                                            \
+  __extension__({                                                              \
+    auto val = op();                                                           \
+    if (!val)                                                                  \
+      return 0;                                                                \
+    *val;                                                                      \
+  })
+
+template <typename Fn> struct CleanUpHook {
+  cpp::optional<Fn> fn;
+  ~CleanUpHook() {
+    if (fn)
+      (*fn)();
+  }
+  CleanUpHook(Fn fn) : fn(cpp::move(fn)) {}
+  CleanUpHook(const CleanUpHook &) = delete;
+  CleanUpHook(CleanUpHook &&other) : fn(cpp::move(other.fn)) {
+    other.fn = cpp::nullopt;
+  }
+};
+
+#define register_cleanup(ID, ...)                                              \
+  auto cleanup_hook##ID = __extension__({                                      \
+    auto a = __VA_ARGS__;                                                      \
+    CleanUpHook<decltype(a)>{a};                                               \
+  });
+
+static void trap_with_message(const char *msg) { __builtin_trap(); }
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
+  AllocChecker ac;
+  global_buffer = static_cast<uint8_t *>(::operator new(size, ac));
+  register_cleanup(0, [global_buffer = global_buffer, size] {
+    ::operator delete(global_buffer, size);
+  });
+  if (!ac)
+    return 0;
+  memcpy(global_buffer, data, size);
+
+  remaining = size;
+  uint64_t size_a = GET_VAL(next_uint64) % 256;
+  uint64_t size_b = GET_VAL(next_uint64) % 256;
+  uint64_t rand_a = GET_VAL(next_uint64);
+  uint64_t rand_b = GET_VAL(next_uint64);
+  internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a);
+  register_cleanup(1, [&table_a] { internal::HashTable::deallocate(table_a); });
+  internal::HashTable *table_b = internal::HashTable::allocate(size_b, rand_b);
+  register_cleanup(2, [&table_b] { internal::HashTable::deallocate(table_b); });
+  if (!table_a || !table_b)
+    return 0;
+  for (;;) {
+    Action action = GET_VAL(next_action);
+    switch (action) {
+    case Action::Find: {
+      const char *key = GET_VAL(next_cstr);
+      if (!key)
+        return 0;
+      if (static_cast<bool>(table_a->find(key)) !=
+          static_cast<bool>(table_b->find(key)))
+        trap_with_message(key);
+      break;
+    }
+    case Action::Insert: {
+      char *key = GET_VAL(next_cstr);
+      if (!key)
+        return 0;
+      ENTRY *a = internal::HashTable::insert(table_a, ENTRY{key, key});
+      ENTRY *b = internal::HashTable::insert(table_b, ENTRY{key, key});
+      if (a->data != b->data)
+        __builtin_trap();
+      break;
+    }
+    case Action::CrossCheck: {
+      for (ENTRY a : *table_a) {
+        if (const ENTRY *b = table_b->find(a.key)) {
+          if (a.data != b->data)
+            __builtin_trap();
+        }
+      }
+      for (ENTRY b : *table_b) {
+        if (const ENTRY *a = table_a->find(b.key)) {
+          if (a->data != b.data)
+            __builtin_trap();
+        }
+      }
+      break;
+    }
+    }
+  }
+}
+
+} // namespace LIBC_NAMESPACE

>From 395b76d705e3eb5cc33d85f5b63fc2d2d06ccb81 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Sun, 7 Apr 2024 18:35:17 -0400
Subject: [PATCH 2/9] remove extra code

---
 libc/fuzzing/__support/hashtable_fuzz.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

Review notes (not part of the commit message):
* The removed `if (!key) return 0;` checks were unreachable, because
  GET_VAL already returns 0 from the caller when next_cstr() yields no
  value.
* deallocate() is now skipped for a null table.

diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index 4b862b03b9d309..758c8d1aae01bf 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -108,9 +108,15 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
   uint64_t rand_a = GET_VAL(next_uint64);
   uint64_t rand_b = GET_VAL(next_uint64);
   internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a);
-  register_cleanup(1, [&table_a] { internal::HashTable::deallocate(table_a); });
+  register_cleanup(1, [&table_a] {
+    if (table_a)
+      internal::HashTable::deallocate(table_a);
+  });
   internal::HashTable *table_b = internal::HashTable::allocate(size_b, rand_b);
-  register_cleanup(2, [&table_b] { internal::HashTable::deallocate(table_b); });
+  register_cleanup(2, [&table_b] {
+    if (table_b)
+      internal::HashTable::deallocate(table_b);
+  });
   if (!table_a || !table_b)
     return 0;
   for (;;) {
@@ -118,8 +124,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
     switch (action) {
     case Action::Find: {
       const char *key = GET_VAL(next_cstr);
-      if (!key)
-        return 0;
       if (static_cast<bool>(table_a->find(key)) !=
           static_cast<bool>(table_b->find(key)))
         trap_with_message(key);
@@ -127,8 +131,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
     }
     case Action::Insert: {
       char *key = GET_VAL(next_cstr);
-      if (!key)
-        return 0;
       ENTRY *a = internal::HashTable::insert(table_a, ENTRY{key, key});
       ENTRY *b = internal::HashTable::insert(table_b, ENTRY{key, key});
       if (a->data != b->data)

>From 2fef4145834707c01bcedd4fcf03e7904de47f23 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Mon, 8 Apr 2024 09:18:48 -0400
Subject: [PATCH 3/9] make style consistent

---
 libc/fuzzing/__support/hashtable_fuzz.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

Review notes (not part of the commit message):
* Renames GET_VAL to get_value. The macro still expands to a hidden
  `return 0;`, and a function-style lowercase name makes that early exit
  easy to miss at call sites. The macro is dropped in [PATCH 6/9].

diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index 758c8d1aae01bf..d5c64970b53234 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -63,7 +63,7 @@ static cpp::optional<char *> next_cstr() {
   return cpp::nullopt;
 }
 
-#define GET_VAL(op)                                                            \
+#define get_value(op)                                                          \
   __extension__({                                                              \
     auto val = op();                                                           \
     if (!val)                                                                  \
@@ -103,10 +103,10 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
   memcpy(global_buffer, data, size);
 
   remaining = size;
-  uint64_t size_a = GET_VAL(next_uint64) % 256;
-  uint64_t size_b = GET_VAL(next_uint64) % 256;
-  uint64_t rand_a = GET_VAL(next_uint64);
-  uint64_t rand_b = GET_VAL(next_uint64);
+  uint64_t size_a = get_value(next_uint64) % 256;
+  uint64_t size_b = get_value(next_uint64) % 256;
+  uint64_t rand_a = get_value(next_uint64);
+  uint64_t rand_b = get_value(next_uint64);
   internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a);
   register_cleanup(1, [&table_a] {
     if (table_a)
@@ -120,17 +120,17 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
   if (!table_a || !table_b)
     return 0;
   for (;;) {
-    Action action = GET_VAL(next_action);
+    Action action = get_value(next_action);
     switch (action) {
     case Action::Find: {
-      const char *key = GET_VAL(next_cstr);
+      const char *key = get_value(next_cstr);
       if (static_cast<bool>(table_a->find(key)) !=
           static_cast<bool>(table_b->find(key)))
         trap_with_message(key);
       break;
     }
     case Action::Insert: {
-      char *key = GET_VAL(next_cstr);
+      char *key = get_value(next_cstr);
       ENTRY *a = internal::HashTable::insert(table_a, ENTRY{key, key});
       ENTRY *b = internal::HashTable::insert(table_b, ENTRY{key, key});
       if (a->data != b->data)

>From 8c14a30a64b0e01f97dd83b88e90912ad9813fb3 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Mon, 8 Apr 2024 12:50:59 -0400
Subject: [PATCH 4/9] make information more concentrated

---
 libc/fuzzing/__support/hashtable_fuzz.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

Review notes (not part of the commit message):
* Length and table-size prefixes shrink from 8 bytes to 1 byte.
* register_cleanup now moves the lambda into CleanUpHook instead of copying
  it.
* next_cstr() still captures `result` before the length byte is consumed,
  so every key still starts with that length byte.

diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index d5c64970b53234..f346d2b79c6db4 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -45,9 +45,9 @@ static cpp::optional<Action> next_action() {
 
 static cpp::optional<char *> next_cstr() {
   char *result = reinterpret_cast<char *>(global_buffer);
-  if (cpp::optional<uint64_t> len = next_uint64()) {
+  if (cpp::optional<uint8_t> len = next_u8()) {
     uint64_t length;
-    for (length = 0; length < *len % 128; length++) {
+    for (length = 0; length < *len; length++) {
       if (length >= remaining)
         return cpp::nullopt;
       if (*global_buffer == '\0')
@@ -87,7 +87,7 @@ template <typename Fn> struct CleanUpHook {
 #define register_cleanup(ID, ...)                                              \
   auto cleanup_hook##ID = __extension__({                                      \
     auto a = __VA_ARGS__;                                                      \
-    CleanUpHook<decltype(a)>{a};                                               \
+    CleanUpHook<decltype(a)>(cpp::move(a));                                    \
   });
 
 static void trap_with_message(const char *msg) { __builtin_trap(); }
@@ -103,8 +103,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
   memcpy(global_buffer, data, size);
 
   remaining = size;
-  uint64_t size_a = get_value(next_uint64) % 256;
-  uint64_t size_b = get_value(next_uint64) % 256;
+  uint64_t size_a = get_value(next_u8);
+  uint64_t size_b = get_value(next_u8);
   uint64_t rand_a = get_value(next_uint64);
   uint64_t rand_b = get_value(next_uint64);
   internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a);

>From ea094120be6387893bc53f4ce173b1697e9bf22d Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Fri, 19 Apr 2024 17:56:50 -0400
Subject: [PATCH 5/9] add missing license headers

---
 libc/fuzzing/__support/hashtable_fuzz.cpp | 11 +++++++++++
 libc/fuzzing/__support/uint_fuzz.cpp      | 11 +++++++++++
 2 files changed, 22 insertions(+)

Review notes (not part of the commit message):
* Adds LLVM license/file headers to both fuzzers; no functional change.

diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index f346d2b79c6db4..4a726b229b7818 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -1,3 +1,14 @@
+//===-- hashtable_fuzz.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// Fuzzing test for llvm-libc hashtable implementations.
+///
+//===----------------------------------------------------------------------===//
 #include "src/__support/CPP/new.h"
 #include "src/__support/CPP/optional.h"
 #include "src/__support/HashTable/table.h"
diff --git a/libc/fuzzing/__support/uint_fuzz.cpp b/libc/fuzzing/__support/uint_fuzz.cpp
index 07149f511b8386..109375f84da780 100644
--- a/libc/fuzzing/__support/uint_fuzz.cpp
+++ b/libc/fuzzing/__support/uint_fuzz.cpp
@@ -1,3 +1,14 @@
+//===-- uint_fuzz.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// Fuzzing test for llvm-libc unsigned integer utilities.
+///
+//===----------------------------------------------------------------------===//
 #include "src/__support/CPP/bit.h"
 #include "src/__support/big_int.h"
 #include "src/string/memory_utils/inline_memcpy.h"

>From 184fbd4d73e86da2ce9fea98f534d3d259c18744 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Tue, 30 Apr 2024 17:55:33 -0400
Subject: [PATCH 6/9] fix

---
 libc/fuzzing/__support/CMakeLists.txt         |  11 +
 libc/fuzzing/__support/hashtable_fuzz.cpp     | 271 ++++++++++--------
 .../HashTable/generic/bitmask_impl.inc        |  15 +-
 3 files changed, 170 insertions(+), 127 deletions(-)

Review notes (not part of the commit message):
* CrossCheck null-checks the find() result before reading its data. A key
  present in one table but missing from the other is now reported as a
  mismatch instead of dereferencing a null ENTRY.
* global_status.actions is reset for every input. Otherwise the action list,
  and the key strings it owns, grows across all fuzzer iterations.
* The HashTable wrapper only deallocates a non-null table. The table is
  null when allocate() fails or after a move.
* UniquePtr's bool conversion is explicit and const.
* Trailing whitespace is removed from the new CMake block.
* Find still compares raw ENTRY pointers, and Insert still uses a fixed
  "key" placeholder. Both are corrected later in the series (8/9 and 7/9).
* next<T>() reads a union member other than the one it wrote. Clang accepts
  this as an extension, but memcpy or bit_cast would be portable.

diff --git a/libc/fuzzing/__support/CMakeLists.txt b/libc/fuzzing/__support/CMakeLists.txt
index b5d2b488447fc5..0c35948cb1e1d6 100644
--- a/libc/fuzzing/__support/CMakeLists.txt
+++ b/libc/fuzzing/__support/CMakeLists.txt
@@ -14,3 +14,14 @@ add_libc_fuzzer(
     libc.src.__support.HashTable.table
     libc.src.string.memcpy
 )
+
+add_libc_fuzzer(
+  hashtable_opt_fuzz
+  SRCS
+    hashtable_fuzz.cpp
+  DEPENDS
+    libc.src.__support.HashTable.table
+    libc.src.string.memcpy
+  COMPILE_OPTIONS
+    -D__LIBC_EXPLICIT_SIMD_OPT
+)
diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index 4a726b229b7818..21124eef9e154c 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -11,160 +11,189 @@
 //===----------------------------------------------------------------------===//
 #include "src/__support/CPP/new.h"
 #include "src/__support/CPP/optional.h"
+#include "src/__support/CPP/string.h"
+#include "src/__support/CPP/utility/forward.h"
 #include "src/__support/HashTable/table.h"
-#include "src/string/memcpy.h"
-#include <search.h>
 #include <stdint.h>
 namespace LIBC_NAMESPACE {
 
-enum class Action { Find, Insert, CrossCheck };
-static uint8_t *global_buffer = nullptr;
-static size_t remaining = 0;
+template <typename T> class UniquePtr {
+  T *ptr;
 
-static cpp::optional<uint8_t> next_u8() {
-  if (remaining == 0)
-    return cpp::nullopt;
-  uint8_t result = *global_buffer;
-  global_buffer++;
-  remaining--;
-  return result;
-}
+public:
+  UniquePtr(T *ptr) : ptr(ptr) {}
+  ~UniquePtr() { delete ptr; }
+  UniquePtr(UniquePtr &&other) : ptr(other.ptr) { other.ptr = nullptr; }
+  UniquePtr &operator=(UniquePtr &&other) {
+    delete ptr;
+    ptr = other.ptr;
+    other.ptr = nullptr;
+    return *this;
+  }
+  T *operator->() { return ptr; }
+  template <typename... U> static UniquePtr create(U &&...x) {
+    AllocChecker ac;
+    T *ptr = new (ac) T(cpp::forward<U>(x)...);
+    if (!ac)
+      return {nullptr};
+    return UniquePtr(ptr);
+  }
+  explicit operator bool() const { return ptr != nullptr; }
+  T *get() { return ptr; }
+};
 
-static cpp::optional<uint64_t> next_uint64() {
-  uint64_t result;
-  if (remaining < sizeof(result))
-    return cpp::nullopt;
-  memcpy(&result, global_buffer, sizeof(result));
-  global_buffer += sizeof(result);
-  remaining -= sizeof(result);
-  return result;
-}
+// a tagged union
+struct Action {
+  enum class Tag { Find, Insert, CrossCheck } tag;
+  cpp::string key;
+  UniquePtr<Action> next;
+  Action(Tag tag, cpp::string key, UniquePtr<Action> next)
+      : tag(tag), key(cpp::move(key)), next(cpp::move(next)) {}
+};
 
-static cpp::optional<Action> next_action() {
-  if (cpp::optional<uint8_t> action = next_u8()) {
-    switch (*action % 3) {
-    case 0:
-      return Action::Find;
-    case 1:
-      return Action::Insert;
-    case 2:
-      return Action::CrossCheck;
-    }
+static struct {
+  UniquePtr<Action> actions = nullptr;
+  size_t remaining;
+  const char *buffer;
+
+  template <typename T> cpp::optional<T> next() {
+    static_assert(cpp::is_integral<T>::value, "T must be an integral type");
+    union {
+      T result;
+      char data[sizeof(T)];
+    };
+    if (remaining < sizeof(result))
+      return cpp::nullopt;
+    for (size_t i = 0; i < sizeof(result); i++)
+      data[i] = buffer[i];
+    buffer += sizeof(result);
+    remaining -= sizeof(result);
+    return result;
   }
-  return cpp::nullopt;
-}
 
-static cpp::optional<char *> next_cstr() {
-  char *result = reinterpret_cast<char *>(global_buffer);
-  if (cpp::optional<uint8_t> len = next_u8()) {
-    uint64_t length;
-    for (length = 0; length < *len; length++) {
-      if (length >= remaining)
-        return cpp::nullopt;
-      if (*global_buffer == '\0')
+  cpp::optional<cpp::string> next_string() {
+    if (cpp::optional<uint16_t> len = next<uint16_t>()) {
+      uint64_t length;
+      for (length = 0; length < *len && length < remaining; length++)
+        if (buffer[length] == '\0')
+          break;
+      cpp::string result(buffer, length);
+      result += '\0';
+      buffer += length;
+      remaining -= length;
+      return result;
+    }
+    return cpp::nullopt;
+  }
+  Action *next_action() {
+    if (cpp::optional<uint8_t> action = next<uint8_t>()) {
+      switch (*action % 3) {
+      case 0: {
+        if (cpp::optional<cpp::string> key = next_string())
+          actions = UniquePtr<Action>::create(
+              Action::Tag::Find, cpp::move(*key), cpp::move(actions));
+        else
+          return nullptr;
         break;
+      }
+      case 1: {
+        if (cpp::optional<cpp::string> key = next_string())
+          actions = UniquePtr<Action>::create(
+              Action::Tag::Insert, cpp::move(*key), cpp::move(actions));
+        else
+          return nullptr;
+        break;
+      }
+      case 2: {
+        actions = UniquePtr<Action>::create(Action::Tag::CrossCheck, "",
+                                            cpp::move(actions));
+        break;
+      }
+      }
+      return actions.get();
     }
-    if (length >= remaining)
-      return cpp::nullopt;
-    global_buffer[length] = '\0';
-    global_buffer += length + 1;
-    remaining -= length + 1;
-    return result;
+    return nullptr;
   }
-  return cpp::nullopt;
-}
+} global_status;
 
-#define get_value(op)                                                          \
-  __extension__({                                                              \
-    auto val = op();                                                           \
-    if (!val)                                                                  \
-      return 0;                                                                \
-    *val;                                                                      \
-  })
+class HashTable {
+  internal::HashTable *table;
 
-template <typename Fn> struct CleanUpHook {
-  cpp::optional<Fn> fn;
-  ~CleanUpHook() {
-    if (fn)
-      (*fn)();
-  }
-  CleanUpHook(Fn fn) : fn(cpp::move(fn)) {}
-  CleanUpHook(const CleanUpHook &) = delete;
-  CleanUpHook(CleanUpHook &&other) : fn(cpp::move(other.fn)) {
-    other.fn = cpp::nullopt;
+public:
+  HashTable(uint64_t size, uint64_t seed)
+      : table(internal::HashTable::allocate(size, seed)) {}
+  HashTable(internal::HashTable *table) : table(table) {}
+  ~HashTable() {
+    // table is null if allocate() failed or this wrapper was moved from.
+    if (table)
+      internal::HashTable::deallocate(table);
+  }
+  HashTable(HashTable &&other) : table(other.table) { other.table = nullptr; }
+  bool is_valid() const { return table != nullptr; }
+  ENTRY *find(const char *key) { return table->find(key); }
+  ENTRY *insert(const ENTRY &entry) {
+    return internal::HashTable::insert(this->table, entry);
   }
+  using iterator = internal::HashTable::iterator;
+  iterator begin() const { return table->begin(); }
+  iterator end() const { return table->end(); }
 };
 
-#define register_cleanup(ID, ...)                                              \
-  auto cleanup_hook##ID = __extension__({                                      \
-    auto a = __VA_ARGS__;                                                      \
-    CleanUpHook<decltype(a)>(cpp::move(a));                                    \
-  });
+HashTable next_hashtable() {
+  if (cpp::optional<uint16_t> size = global_status.next<uint16_t>())
+    if (cpp::optional<uint64_t> seed = global_status.next<uint64_t>())
+      return HashTable(*size, *seed);
 
-static void trap_with_message(const char *msg) { __builtin_trap(); }
+  return HashTable(0, 0);
+}
 
 extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
-  AllocChecker ac;
-  global_buffer = static_cast<uint8_t *>(::operator new(size, ac));
-  register_cleanup(0, [global_buffer = global_buffer, size] {
-    ::operator delete(global_buffer, size);
-  });
-  if (!ac)
+  char key[] = "key";
+  global_status.buffer = reinterpret_cast<const char *>(data);
+  global_status.remaining = size;
+  HashTable table_a = next_hashtable();
+  HashTable table_b = next_hashtable();
+  if (!table_a.is_valid() || !table_b.is_valid())
     return 0;
-  memcpy(global_buffer, data, size);
 
-  remaining = size;
-  uint64_t size_a = get_value(next_u8);
-  uint64_t size_b = get_value(next_u8);
-  uint64_t rand_a = get_value(next_uint64);
-  uint64_t rand_b = get_value(next_uint64);
-  internal::HashTable *table_a = internal::HashTable::allocate(size_a, rand_a);
-  register_cleanup(1, [&table_a] {
-    if (table_a)
-      internal::HashTable::deallocate(table_a);
-  });
-  internal::HashTable *table_b = internal::HashTable::allocate(size_b, rand_b);
-  register_cleanup(2, [&table_b] {
-    if (table_b)
-      internal::HashTable::deallocate(table_b);
-  });
-  if (!table_a || !table_b)
-    return 0;
+  // Each input starts from an empty action list; without this the nodes (and
+  // the key strings they own) would accumulate across fuzzer iterations.
+  global_status.actions = nullptr;
+
   for (;;) {
-    Action action = get_value(next_action);
-    switch (action) {
-    case Action::Find: {
-      const char *key = get_value(next_cstr);
-      if (static_cast<bool>(table_a->find(key)) !=
-          static_cast<bool>(table_b->find(key)))
-        trap_with_message(key);
+    Action *action = global_status.next_action();
+    if (!action)
+      return 0;
+    switch (action->tag) {
+    case Action::Tag::Find: {
+      if (table_a.find(action->key.c_str()) !=
+          table_b.find(action->key.c_str()))
+        __builtin_trap();
       break;
     }
-    case Action::Insert: {
-      char *key = get_value(next_cstr);
-      ENTRY *a = internal::HashTable::insert(table_a, ENTRY{key, key});
-      ENTRY *b = internal::HashTable::insert(table_b, ENTRY{key, key});
+    case Action::Tag::Insert: {
+      ENTRY *a = table_a.insert(ENTRY{key, key});
+      ENTRY *b = table_b.insert(ENTRY{key, key});
       if (a->data != b->data)
         __builtin_trap();
       break;
     }
-    case Action::CrossCheck: {
-      for (ENTRY a : *table_a) {
-        if (const ENTRY *b = table_b->find(a.key)) {
-          if (a.data != b->data)
-            __builtin_trap();
-        }
-      }
-      for (ENTRY b : *table_b) {
-        if (const ENTRY *a = table_a->find(b.key)) {
-          if (a->data != b.data)
-            __builtin_trap();
-        }
-      }
+    case Action::Tag::CrossCheck: {
+      // Both tables receive the same inserts, so every key present in one
+      // must be present in the other with identical data.
+      for (ENTRY a : table_a)
+        if (const ENTRY *b = table_b.find(a.key); !b || a.data != b->data)
+          __builtin_trap();
+
+      for (ENTRY b : table_b)
+        if (const ENTRY *a = table_a.find(b.key); !a || a->data != b.data)
+          __builtin_trap();
+
       break;
     }
     }
   }
+  return 0;
 }
 
 } // namespace LIBC_NAMESPACE
diff --git a/libc/src/__support/HashTable/generic/bitmask_impl.inc b/libc/src/__support/HashTable/generic/bitmask_impl.inc
index 56b540d568d005..b1ebc20721e309 100644
--- a/libc/src/__support/HashTable/generic/bitmask_impl.inc
+++ b/libc/src/__support/HashTable/generic/bitmask_impl.inc
@@ -34,10 +34,11 @@ LIBC_INLINE constexpr bitmask_t repeat_byte(bitmask_t byte) {
   return byte;
 }
 
-using BitMask = BitMaskAdaptor<bitmask_t, 0x8ull>;
+using BitMask = BitMaskAdaptor<bitmask_t, 0x8ul>;
 using IteratableBitMask = IteratableBitMaskAdaptor<BitMask>;
 
 struct Group {
+  LIBC_INLINE_VAR static constexpr bitmask_t MASK = repeat_byte(0x80ul);
   bitmask_t data;
 
   // Load a group of control words from an arbitary address.
@@ -100,21 +101,23 @@ struct Group {
     //  - The check for key equality will catch these.
     //  - This only happens if there is at least 1 true match.
     //  - The chance of this happening is very low (< 1% chance per byte).
+    static constexpr bitmask_t ONES = repeat_byte(0x01ul);
     auto cmp = data ^ repeat_byte(byte);
-    auto result = LIBC_NAMESPACE::Endian::to_little_endian(
-        (cmp - repeat_byte(0x01)) & ~cmp & repeat_byte(0x80));
+    auto result =
+        LIBC_NAMESPACE::Endian::to_little_endian((cmp - ONES) & ~cmp & MASK);
     return {BitMask{result}};
   }
 
   // Find out the lanes equal to EMPTY or DELETE (highest bit set) and
   // return the bitmask with corresponding bits set.
   LIBC_INLINE BitMask mask_available() const {
-    return {LIBC_NAMESPACE::Endian::to_little_endian(data) & repeat_byte(0x80)};
+    bitmask_t le_data = LIBC_NAMESPACE::Endian::to_little_endian(data);
+    return {le_data & MASK};
   }
 
   LIBC_INLINE IteratableBitMask occupied() const {
-    return {
-        {static_cast<bitmask_t>(mask_available().word ^ repeat_byte(0x80))}};
+    bitmask_t available = mask_available().word;
+    return {BitMask{available ^ MASK}};
   }
 };
 } // namespace internal

>From 7bef0dca0d8207c69e63bc7b349bd8983015c535 Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Tue, 30 Apr 2024 18:01:23 -0400
Subject: [PATCH 7/9] fix

---
 libc/fuzzing/__support/hashtable_fuzz.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

Review notes (not part of the commit message):
* Insert now keys each entry on the action's own string instead of the
  fixed "key" placeholder, so different inputs reach different slots.
* Both insert() results are null-checked before their data is compared.
  A failed insert ends the current input rather than dereferencing nullptr.

diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index 21124eef9e154c..b5fdcfa6482600 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -144,7 +144,6 @@ HashTable next_hashtable() {
 }
 
 extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
-  char key[] = "key";
   global_status.buffer = reinterpret_cast<const char *>(data);
   global_status.remaining = size;
   HashTable table_a = next_hashtable();
@@ -164,8 +163,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
       break;
     }
     case Action::Tag::Insert: {
-      ENTRY *a = table_a.insert(ENTRY{key, key});
-      ENTRY *b = table_b.insert(ENTRY{key, key});
+      ENTRY *a = table_a.insert(ENTRY{action->key.data(), action->key.data()});
+      ENTRY *b = table_b.insert(ENTRY{action->key.data(), action->key.data()});
+      // Guard against a null result from insert() (presumably an allocation
+      // failure) instead of dereferencing it; end this input.
+      if (!a || !b)
+        return 0;
       if (a->data != b->data)
         __builtin_trap();
       break;

>From d5b7e2617528d48ad90b3cac0d8a6055236f17bc Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Tue, 30 Apr 2024 18:10:23 -0400
Subject: [PATCH 8/9] fix

---
 libc/fuzzing/__support/hashtable_fuzz.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

Review notes (not part of the commit message):
* Find now compares presence (bool) rather than ENTRY pointers. The two
  tables are allocated separately and keep their own entries, so hits in
  both return different addresses even when both contain the key.

diff --git a/libc/fuzzing/__support/hashtable_fuzz.cpp b/libc/fuzzing/__support/hashtable_fuzz.cpp
index b5fdcfa6482600..cc1b63716ec589 100644
--- a/libc/fuzzing/__support/hashtable_fuzz.cpp
+++ b/libc/fuzzing/__support/hashtable_fuzz.cpp
@@ -157,8 +157,8 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
       return 0;
     switch (action->tag) {
     case Action::Tag::Find: {
-      if (table_a.find(action->key.c_str()) !=
-          table_b.find(action->key.c_str()))
+      if (static_cast<bool>(table_a.find(action->key.c_str())) !=
+          static_cast<bool>(table_b.find(action->key.c_str())))
         __builtin_trap();
       break;
     }

>From 5a2873388d1b9ba855f5cb9a19bae7d547cd9d9e Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu at rochester.edu>
Date: Tue, 30 Apr 2024 18:12:44 -0400
Subject: [PATCH 9/9] fix

---
 libc/src/__support/HashTable/generic/bitmask_impl.inc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

Review notes (not part of the commit message):
* The `byte` argument is widened to bitmask_t and masked to its low 8 bits
  before repeat_byte(). The parameter's declared type lies outside this
  hunk. The mask presumably guards against a sign-extended or wider value
  spilling into neighbouring lanes; confirm against the signature.

diff --git a/libc/src/__support/HashTable/generic/bitmask_impl.inc b/libc/src/__support/HashTable/generic/bitmask_impl.inc
index b1ebc20721e309..d6c5ae075558af 100644
--- a/libc/src/__support/HashTable/generic/bitmask_impl.inc
+++ b/libc/src/__support/HashTable/generic/bitmask_impl.inc
@@ -102,7 +102,7 @@ struct Group {
     //  - This only happens if there is at least 1 true match.
     //  - The chance of this happening is very low (< 1% chance per byte).
     static constexpr bitmask_t ONES = repeat_byte(0x01ul);
-    auto cmp = data ^ repeat_byte(byte);
+    auto cmp = data ^ repeat_byte(static_cast<bitmask_t>(byte) & 0xFFul);
     auto result =
         LIBC_NAMESPACE::Endian::to_little_endian((cmp - ONES) & ~cmp & MASK);
     return {BitMask{result}};



More information about the libc-commits mailing list