[llvm] [ADT] Add TrieRawHashMap (PR #69528)

Steven Wu via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 23 10:51:27 PDT 2024


https://github.com/cachemeifyoucan updated https://github.com/llvm/llvm-project/pull/69528

>From 458e3af5f3b8d1231d00e46690726d48f13c1d77 Mon Sep 17 00:00:00 2001
From: Steven Wu <stevenwu at apple.com>
Date: Thu, 5 Oct 2023 13:02:40 -0700
Subject: [PATCH 1/5] [ADT] Add TrieRawHashMap

Implement TrieRawHashMap which stores objects into a Trie based on the
hash of the object.

The user needs to supply the hashing function and guarantee the uniqueness of
the hash for the objects to be inserted. Hash collisions are not
supported.
---
 llvm/include/llvm/ADT/TrieRawHashMap.h    | 398 ++++++++++++++++++
 llvm/lib/Support/CMakeLists.txt           |   1 +
 llvm/lib/Support/TrieHashIndexGenerator.h |  89 ++++
 llvm/lib/Support/TrieRawHashMap.cpp       | 478 ++++++++++++++++++++++
 llvm/unittests/ADT/CMakeLists.txt         |   1 +
 llvm/unittests/ADT/TrieRawHashMapTest.cpp | 343 ++++++++++++++++
 6 files changed, 1310 insertions(+)
 create mode 100644 llvm/include/llvm/ADT/TrieRawHashMap.h
 create mode 100644 llvm/lib/Support/TrieHashIndexGenerator.h
 create mode 100644 llvm/lib/Support/TrieRawHashMap.cpp
 create mode 100644 llvm/unittests/ADT/TrieRawHashMapTest.cpp

diff --git a/llvm/include/llvm/ADT/TrieRawHashMap.h b/llvm/include/llvm/ADT/TrieRawHashMap.h
new file mode 100644
index 00000000000000..baa08e214ce6fd
--- /dev/null
+++ b/llvm/include/llvm/ADT/TrieRawHashMap.h
@@ -0,0 +1,398 @@
+//===- TrieRawHashMap.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_TRIERAWHASHMAP_H
+#define LLVM_ADT_TRIERAWHASHMAP_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <atomic>
+#include <optional>
+
+namespace llvm {
+
+class raw_ostream;
+
+/// TrieRawHashMap is a lock-free thread-safe trie that can be used to
+/// store/index data based on a hash value. It can be customized to work with
+/// any hash algorithm or store any data.
+///
+/// Data structure:
+/// Data node stored in the Trie contains both hash and data:
+/// struct {
+///    HashT Hash;
+///    DataT Data;
+/// };
+///
+/// Data is stored/indexed via a prefix tree, where each node in the tree can be
+/// either the root, a sub-trie or a data node. Assuming a 4-bit hash and two
+/// data objects {0001, A} and {0100, B}, they can be stored in a trie
+/// (assuming Root has 2 bits, SubTrie has 1 bit):
+///  +--------+
+///  |Root[00]| -> {0001, A}
+///  |    [01]| -> {0100, B}
+///  |    [10]| (empty)
+///  |    [11]| (empty)
+///  +--------+
+///
+/// Inserting a new object {0010, C} will result in:
+///  +--------+    +----------+
+///  |Root[00]| -> |SubTrie[0]| -> {0001, A}
+///  |        |    |       [1]| -> {0010, C}
+///  |        |    +----------+
+///  |    [01]| -> {0100, B}
+///  |    [10]| (empty)
+///  |    [11]| (empty)
+///  +--------+
+/// Note object A is sunk down to a sub-trie during the insertion. All the
+/// nodes are inserted through compare-exchange to ensure thread-safety and
+/// lock-freedom.
+///
+/// To find an object in the trie, walk the tree with prefix of the hash until
+/// the data node is found. Then the hash is compared with the hash stored in
+/// the data node to see if it is the same object.
+///
+/// Hash collision is not allowed so it is recommended to use trie with a
+/// "strong" hashing algorithm. A well-distributed hash can also result in
+/// better performance and memory usage.
+///
+/// It currently does not support iteration and deletion.
+
+/// Base class for a lock-free thread-safe hash-mapped trie.
+///
+/// The base works on raw memory: subclasses describe the size, alignment and
+/// in-allocation offset of their value type, and supply callbacks that
+/// construct and destroy values.
+class ThreadSafeTrieRawHashMapBase {
+public:
+  /// Size in bytes of the TrieContent header that precedes every stored
+  /// value (checked by a static_assert in TrieRawHashMap.cpp).
+  static constexpr size_t TrieContentBaseSize = 4;
+  /// Index bits used by the root when none are specified (64 slots).
+  static constexpr size_t DefaultNumRootBits = 6;
+  /// Index bits used by each sub-trie when none are specified (16 slots).
+  static constexpr size_t DefaultNumSubtrieBits = 4;
+
+private:
+  /// Layout of one content allocation: room for the TrieContent header,
+  /// followed by storage suitably sized and aligned for a \p T.
+  ///
+  /// An alignas'd byte array is used instead of std::aligned_union_t, which is
+  /// deprecated in C++23; the resulting size and alignment are the same.
+  template <class T> struct AllocValueType {
+    char Base[TrieContentBaseSize];
+    alignas(T) char Content[sizeof(T)];
+  };
+
+protected:
+  /// Size of a content allocation holding a \p T.
+  template <class T>
+  static constexpr size_t DefaultContentAllocSize = sizeof(AllocValueType<T>);
+
+  /// Alignment of a content allocation holding a \p T.
+  template <class T>
+  static constexpr size_t DefaultContentAllocAlign = alignof(AllocValueType<T>);
+
+  /// Offset of the \p T within its content allocation.
+  template <class T>
+  static constexpr size_t DefaultContentOffset =
+      offsetof(AllocValueType<T>, Content);
+
+public:
+  void operator delete(void *Ptr) { ::free(Ptr); }
+
+  // NOTE(review): no definitions of dump()/print() are visible in
+  // TrieRawHashMap.cpp; confirm they are provided before they are used.
+  LLVM_DUMP_METHOD void dump() const;
+  void print(raw_ostream &OS) const;
+
+protected:
+  /// Result of a lookup. Suitable for an insertion hint. Maybe could be
+  /// expanded into an iterator of sorts, but likely not useful (visiting
+  /// everything in the trie should probably be done some way other than
+  /// through an iterator pattern).
+  ///
+  /// The state is encoded in \c I:
+  ///   - \c -1u (default): empty; \c get() returns null.
+  ///   - \c -2u: \c P is a resolved node pointer, returned by \c get().
+  ///   - anything else: an insertion hint; \c P is the sub-trie that was being
+  ///     searched, \c I the slot index in it, and \c B its start bit.
+  class PointerBase {
+  protected:
+    void *get() const { return I == -2u ? P : nullptr; }
+
+  public:
+    PointerBase() noexcept = default;
+    PointerBase(PointerBase &&) = default;
+    PointerBase(const PointerBase &) = default;
+    PointerBase &operator=(PointerBase &&) = default;
+    PointerBase &operator=(const PointerBase &) = default;
+
+  private:
+    friend class ThreadSafeTrieRawHashMapBase;
+    explicit PointerBase(void *Content) : P(Content), I(-2u) {}
+    PointerBase(void *P, unsigned I, unsigned B) : P(P), I(I), B(B) {}
+
+    bool isHint() const { return I != -1u && I != -2u; }
+
+    void *P = nullptr;
+    unsigned I = -1u;
+    unsigned B = 0;
+  };
+
+  /// Find the stored content with hash. On a miss the result is a hint that
+  /// can be passed to insert().
+  PointerBase find(ArrayRef<uint8_t> Hash) const;
+
+  /// Insert and return the stored content. If content with the same hash is
+  /// already present it is returned and \p Constructor is not called.
+  /// Otherwise \p Constructor builds the value in \p Mem and must return a
+  /// pointer to the copy of \p Hash stored inside that value.
+  PointerBase
+  insert(PointerBase Hint, ArrayRef<uint8_t> Hash,
+         function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
+             Constructor);
+
+  ThreadSafeTrieRawHashMapBase() = delete;
+
+  /// \param ContentAllocSize, ContentAllocAlign, ContentOffset layout of one
+  /// content allocation (see DefaultContentAllocSize and friends).
+  /// \param NumRootBits, NumSubtrieBits index bits per level; when not
+  /// provided, DefaultNumRootBits / DefaultNumSubtrieBits are used.
+  ThreadSafeTrieRawHashMapBase(
+      size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
+      std::optional<size_t> NumRootBits = std::nullopt,
+      std::optional<size_t> NumSubtrieBits = std::nullopt);
+
+  /// Destructor, which asserts if there's anything to do. Subclasses should
+  /// call \a destroyImpl().
+  ///
+  /// \pre \a destroyImpl() was already called.
+  ~ThreadSafeTrieRawHashMapBase();
+  /// Release all tries and content, first calling \p Destructor (if set) on
+  /// every stored value.
+  void destroyImpl(function_ref<void(void *ValueMem)> Destructor);
+
+  /// Take over the contents of \p RHS, leaving it empty.
+  ThreadSafeTrieRawHashMapBase(ThreadSafeTrieRawHashMapBase &&RHS);
+
+  // Move assignment can be implemented in a thread-safe way if NumRootBits and
+  // NumSubtrieBits are stored inside the Root.
+  ThreadSafeTrieRawHashMapBase &
+  operator=(ThreadSafeTrieRawHashMapBase &&RHS) = delete;
+
+  // No copy.
+  ThreadSafeTrieRawHashMapBase(const ThreadSafeTrieRawHashMapBase &) = delete;
+  ThreadSafeTrieRawHashMapBase &
+  operator=(const ThreadSafeTrieRawHashMapBase &) = delete;
+
+  // Debug functions. Implementation details and not guaranteed to be
+  // thread-safe.
+  PointerBase getRoot() const;
+  unsigned getStartBit(PointerBase P) const;
+  unsigned getNumBits(PointerBase P) const;
+  unsigned getNumSlotUsed(PointerBase P) const;
+  std::string getTriePrefixAsString(PointerBase P) const;
+  unsigned getNumTries() const;
+  // Visit next trie in the allocation chain.
+  PointerBase getNextTrie(PointerBase P) const;
+
+private:
+  friend class TrieRawHashMapTestHelper;
+  const unsigned short ContentAllocSize;
+  const unsigned short ContentAllocAlign;
+  const unsigned short ContentOffset;
+  unsigned short NumRootBits;
+  unsigned short NumSubtrieBits;
+  struct ImplType;
+  // ImplPtr is owned by ThreadSafeTrieRawHashMapBase and needs to be freed in
+  // destroyImpl. It stays null until the first insertion.
+  std::atomic<ImplType *> ImplPtr;
+  ImplType &getOrCreateImpl();
+  ImplType *getImpl() const;
+};
+
+/// Lock-free thread-safe hash-mapped trie.
+///
+/// \tparam T Type of the data stored next to each hash.
+/// \tparam NumHashBytes Size in bytes of every hash; hashes of any other size
+/// trip assertions (find(), value_type, LazyValueConstructor).
+template <class T, size_t NumHashBytes>
+class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase {
+public:
+  using HashT = std::array<uint8_t, NumHashBytes>;
+
+  class LazyValueConstructor;
+  /// A stored entry: the immutable hash followed by the user's data.
+  struct value_type {
+    const HashT Hash;
+    T Data;
+
+    value_type(value_type &&) = default;
+    value_type(const value_type &) = default;
+
+    value_type(ArrayRef<uint8_t> Hash, const T &Data)
+        : Hash(makeHash(Hash)), Data(Data) {}
+    value_type(ArrayRef<uint8_t> Hash, T &&Data)
+        : Hash(makeHash(Hash)), Data(std::move(Data)) {}
+
+  private:
+    friend class LazyValueConstructor;
+
+    // Tag selecting the in-place (emplace) constructor below.
+    struct EmplaceTag {};
+    template <class... ArgsT>
+    value_type(ArrayRef<uint8_t> Hash, EmplaceTag, ArgsT &&...Args)
+        : Hash(makeHash(Hash)), Data(std::forward<ArgsT>(Args)...) {}
+
+    static HashT makeHash(ArrayRef<uint8_t> HashRef) {
+      // Copying more than NumHashBytes would overflow Hash; fewer would leave
+      // part of it uninitialized.
+      assert(HashRef.size() == NumHashBytes && "Invalid hash size");
+      HashT Hash;
+      std::copy(HashRef.begin(), HashRef.end(), Hash.data());
+      return Hash;
+    }
+  };
+
+  using ThreadSafeTrieRawHashMapBase::operator delete;
+  using HashType = HashT;
+
+  using ThreadSafeTrieRawHashMapBase::dump;
+  using ThreadSafeTrieRawHashMapBase::print;
+
+private:
+  /// Typed view of a PointerBase: dereferences to the stored value when one
+  /// was found or inserted, and converts to false otherwise.
+  template <class ValueT> class PointerImpl : PointerBase {
+    friend class ThreadSafeTrieRawHashMap;
+
+    ValueT *get() const {
+      if (void *B = PointerBase::get())
+        return reinterpret_cast<ValueT *>(B);
+      return nullptr;
+    }
+
+  public:
+    ValueT &operator*() const {
+      assert(get());
+      return *get();
+    }
+    ValueT *operator->() const {
+      assert(get());
+      return get();
+    }
+    explicit operator bool() const { return get(); }
+
+    PointerImpl() = default;
+    PointerImpl(PointerImpl &&) = default;
+    PointerImpl(const PointerImpl &) = default;
+    PointerImpl &operator=(PointerImpl &&) = default;
+    PointerImpl &operator=(const PointerImpl &) = default;
+
+  protected:
+    PointerImpl(PointerBase Result) : PointerBase(Result) {}
+  };
+
+public:
+  class pointer;
+  class const_pointer;
+  /// Mutable handle to an entry (or an insertion hint when empty).
+  class pointer : public PointerImpl<value_type> {
+    friend class ThreadSafeTrieRawHashMap;
+    friend class const_pointer;
+
+  public:
+    pointer() = default;
+    pointer(pointer &&) = default;
+    pointer(const pointer &) = default;
+    pointer &operator=(pointer &&) = default;
+    pointer &operator=(const pointer &) = default;
+
+  private:
+    pointer(PointerBase Result) : pointer::PointerImpl(Result) {}
+  };
+
+  /// Read-only handle to an entry; implicitly constructible from pointer.
+  class const_pointer : public PointerImpl<const value_type> {
+    friend class ThreadSafeTrieRawHashMap;
+
+  public:
+    const_pointer() = default;
+    const_pointer(const_pointer &&) = default;
+    const_pointer(const const_pointer &) = default;
+    const_pointer &operator=(const_pointer &&) = default;
+    const_pointer &operator=(const const_pointer &) = default;
+
+    const_pointer(const pointer &P) : const_pointer::PointerImpl(P) {}
+
+  private:
+    const_pointer(PointerBase Result) : const_pointer::PointerImpl(Result) {}
+  };
+
+  /// Handed to the insertLazy() callback. Exactly one of operator() or
+  /// emplace() must be called to construct the value in trie-owned memory;
+  /// the destructor asserts otherwise.
+  class LazyValueConstructor {
+  public:
+    value_type &operator()(T &&RHS) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem) value_type(Hash, std::move(RHS)));
+    }
+    value_type &operator()(const T &RHS) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem) value_type(Hash, RHS));
+    }
+    template <class... ArgsT> value_type &emplace(ArgsT &&...Args) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem)
+                        value_type(Hash, typename value_type::EmplaceTag{},
+                                   std::forward<ArgsT>(Args)...));
+    }
+
+    LazyValueConstructor(LazyValueConstructor &&RHS)
+        : Mem(RHS.Mem), Result(RHS.Result), Hash(RHS.Hash) {
+      RHS.Mem = nullptr; // Moved away, cannot call.
+    }
+    ~LazyValueConstructor() { assert(!Mem && "Constructor never called!"); }
+
+  private:
+    // Record the constructed value and disarm further construction.
+    value_type &assign(value_type *V) {
+      Mem = nullptr;
+      Result = V;
+      return *V;
+    }
+    friend class ThreadSafeTrieRawHashMap;
+    LazyValueConstructor() = delete;
+    LazyValueConstructor(void *Mem, value_type *&Result, ArrayRef<uint8_t> Hash)
+        : Mem(Mem), Result(Result), Hash(Hash) {
+      assert(Hash.size() == sizeof(HashT) && "Invalid hash");
+      assert(Mem && "Invalid memory for construction");
+    }
+    void *Mem;
+    value_type *&Result;
+    ArrayRef<uint8_t> Hash;
+  };
+
+  /// Insert with a hint. Default-constructed hint will work, but it's
+  /// recommended to start with a lookup to avoid overhead in object creation
+  /// if it already exists.
+  pointer insertLazy(const_pointer Hint, ArrayRef<uint8_t> Hash,
+                     function_ref<void(LazyValueConstructor)> OnConstruct) {
+    return pointer(ThreadSafeTrieRawHashMapBase::insert(
+        Hint, Hash, [&](void *Mem, ArrayRef<uint8_t> Hash) {
+          value_type *Result = nullptr;
+          OnConstruct(LazyValueConstructor(Mem, Result, Hash));
+          return Result->Hash.data();
+        }));
+  }
+
+  pointer insertLazy(ArrayRef<uint8_t> Hash,
+                     function_ref<void(LazyValueConstructor)> OnConstruct) {
+    return insertLazy(const_pointer(), Hash, OnConstruct);
+  }
+
+  /// Insert \p HashedData (moving its data) unless an entry with the same hash
+  /// exists; returns the stored entry either way.
+  pointer insert(const_pointer Hint, value_type &&HashedData) {
+    return insertLazy(Hint, HashedData.Hash, [&](LazyValueConstructor C) {
+      C(std::move(HashedData.Data));
+    });
+  }
+
+  /// Insert a copy of \p HashedData unless an entry with the same hash
+  /// exists; returns the stored entry either way.
+  pointer insert(const_pointer Hint, const value_type &HashedData) {
+    return insertLazy(Hint, HashedData.Hash,
+                      [&](LazyValueConstructor C) { C(HashedData.Data); });
+  }
+
+  /// Look up \p Hash. On a miss the returned pointer is false but can still
+  /// be used as an insertion hint.
+  pointer find(ArrayRef<uint8_t> Hash) {
+    assert(Hash.size() == std::tuple_size<HashT>::value);
+    return ThreadSafeTrieRawHashMapBase::find(Hash);
+  }
+
+  const_pointer find(ArrayRef<uint8_t> Hash) const {
+    assert(Hash.size() == std::tuple_size<HashT>::value);
+    return ThreadSafeTrieRawHashMapBase::find(Hash);
+  }
+
+  ThreadSafeTrieRawHashMap(std::optional<size_t> NumRootBits = std::nullopt,
+                           std::optional<size_t> NumSubtrieBits = std::nullopt)
+      : ThreadSafeTrieRawHashMapBase(DefaultContentAllocSize<value_type>,
+                                     DefaultContentAllocAlign<value_type>,
+                                     DefaultContentOffset<value_type>,
+                                     NumRootBits, NumSubtrieBits) {}
+
+  ~ThreadSafeTrieRawHashMap() {
+    // Skip visiting every slot when values need no destruction.
+    if constexpr (std::is_trivially_destructible<value_type>::value)
+      this->destroyImpl(nullptr);
+    else
+      this->destroyImpl(
+          [](void *P) { static_cast<value_type *>(P)->~value_type(); });
+  }
+
+  // Move constructor okay.
+  ThreadSafeTrieRawHashMap(ThreadSafeTrieRawHashMap &&) = default;
+
+  // No move assignment or any copy.
+  ThreadSafeTrieRawHashMap &operator=(ThreadSafeTrieRawHashMap &&) = delete;
+  ThreadSafeTrieRawHashMap(const ThreadSafeTrieRawHashMap &) = delete;
+  ThreadSafeTrieRawHashMap &
+  operator=(const ThreadSafeTrieRawHashMap &) = delete;
+};
+
+} // namespace llvm
+
+#endif // LLVM_ADT_TRIERAWHASHMAP_H
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index 531bdeaca12614..2ecaea4b02bf61 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -256,6 +256,7 @@ add_llvm_component_library(LLVMSupport
   TimeProfiler.cpp
   Timer.cpp
   ToolOutputFile.cpp
+  TrieRawHashMap.cpp
   Twine.cpp
   TypeSize.cpp
   Unicode.cpp
diff --git a/llvm/lib/Support/TrieHashIndexGenerator.h b/llvm/lib/Support/TrieHashIndexGenerator.h
new file mode 100644
index 00000000000000..c9e9b70e10d3c7
--- /dev/null
+++ b/llvm/lib/Support/TrieHashIndexGenerator.h
@@ -0,0 +1,89 @@
+//===- TrieHashIndexGenerator.h ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H
+#define LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include <optional>
+
+namespace llvm {
+
+/// Produces the slot index used at each level of the trie from a hash.
+///
+/// The root consumes NumRootBits bits and every level below it consumes
+/// NumSubtrieBits bits. Bits are taken from Bytes[0] onwards, most-significant
+/// bit of each byte first.
+struct IndexGenerator {
+  size_t NumRootBits;
+  size_t NumSubtrieBits;
+  ArrayRef<uint8_t> Bytes;
+  /// Bit offset of the current level; std::nullopt until the first call to
+  /// next() or hint().
+  std::optional<size_t> StartBit = std::nullopt;
+
+  /// Number of index bits at the current level, clamped to the bits left in
+  /// the hash.
+  size_t getNumBits() const {
+    assert(StartBit);
+    size_t TotalNumBits = Bytes.size() * 8;
+    assert(*StartBit <= TotalNumBits);
+    return std::min(*StartBit ? NumSubtrieBits : NumRootBits,
+                    TotalNumBits - *StartBit);
+  }
+
+  /// Move to the next level (the root on the first call) and return the slot
+  /// index there.
+  size_t next() {
+    size_t Index;
+    if (!StartBit) {
+      StartBit = 0;
+      Index = getIndex(Bytes, *StartBit, NumRootBits);
+    } else {
+      *StartBit += *StartBit ? NumSubtrieBits : NumRootBits;
+      assert((*StartBit - NumRootBits) % NumSubtrieBits == 0);
+      Index = getIndex(Bytes, *StartBit, NumSubtrieBits);
+    }
+    return Index;
+  }
+
+  /// Resume from a lookup hint: make the level starting at \p Bit current and
+  /// return the hinted slot \p Index unchanged.
+  size_t hint(unsigned Index, unsigned Bit) {
+    assert(Bit < Bytes.size() * 8);
+    assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0);
+    StartBit = Bit;
+    // The hinted slot must exist in a level of this width.
+    assert(Index < (size_t(1) << getNumBits()) && "Hint index out of range");
+    return Index;
+  }
+
+  /// Index that \p CollidingBits (another hash) maps to at the current
+  /// sub-trie level.
+  size_t getCollidingBits(ArrayRef<uint8_t> CollidingBits) const {
+    assert(StartBit);
+    return getIndex(CollidingBits, *StartBit, NumSubtrieBits);
+  }
+
+  /// Read up to \p NumBits bits of \p Bytes, starting at bit \p StartBit
+  /// (most-significant bit first), as an unsigned integer. Stops early if
+  /// the bytes run out.
+  static size_t getIndex(ArrayRef<uint8_t> Bytes, size_t StartBit,
+                         size_t NumBits) {
+    assert(StartBit < Bytes.size() * 8);
+
+    Bytes = Bytes.drop_front(StartBit / 8u);
+    StartBit %= 8u;
+    size_t Index = 0;
+    for (uint8_t Byte : Bytes) {
+      size_t ByteStart = 0, ByteEnd = 8;
+      if (StartBit) {
+        // Discard the already-consumed high bits of the first byte.
+        ByteStart = StartBit;
+        Byte &= (1u << (8 - StartBit)) - 1u;
+        StartBit = 0;
+      }
+      size_t CurrentNumBits = ByteEnd - ByteStart;
+      if (CurrentNumBits > NumBits) {
+        // Keep only the high bits that are still needed.
+        Byte >>= CurrentNumBits - NumBits;
+        CurrentNumBits = NumBits;
+      }
+      Index <<= CurrentNumBits;
+      Index |= Byte & ((1u << CurrentNumBits) - 1u);
+
+      assert(NumBits >= CurrentNumBits);
+      NumBits -= CurrentNumBits;
+      if (!NumBits)
+        break;
+    }
+    return Index;
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
new file mode 100644
index 00000000000000..5f9b8b9ffea038
--- /dev/null
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -0,0 +1,478 @@
+//===- TrieRawHashMap.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/TrieRawHashMap.h"
+#include "TrieHashIndexGenerator.h"
+#include "llvm/ADT/LazyAtomicPointer.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MemAlloc.h"
+#include "llvm/Support/ThreadSafeAllocator.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+/// Common header of every node in the trie: either stored content or a
+/// sub-trie, distinguished by IsSubtrie (used by the classof() hooks).
+struct TrieNode {
+  const bool IsSubtrie = false;
+
+  TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}
+
+  // Released with ::free(); safe_malloc reports a fatal error instead of
+  // returning null.
+  static void *operator new(size_t Size) { return safe_malloc(Size); }
+  void operator delete(void *Ptr) { ::free(Ptr); }
+};
+
+/// Header placed at the start of each content allocation. The stored value
+/// lives ContentOffset bytes past the header and its hash HashOffset bytes
+/// past it; both offsets and the hash size are single bytes so the header
+/// stays TrieContentBaseSize bytes (see the static_assert below).
+struct TrieContent final : public TrieNode {
+  const uint8_t ContentOffset;
+  const uint8_t HashSize;
+  const uint8_t HashOffset;
+
+  void *getValuePointer() const {
+    auto *Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
+    return const_cast<uint8_t *>(Content);
+  }
+
+  ArrayRef<uint8_t> getHash() const {
+    auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
+    return ArrayRef(Begin, Begin + HashSize);
+  }
+
+  TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset)
+      : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
+        HashSize(HashSize), HashOffset(HashOffset) {
+    // The fields are single bytes; catch silent truncation, e.g. a value type
+    // aligned to 256 bytes or more, or a hash longer than 255 bytes.
+    assert(this->ContentOffset == ContentOffset && "Content offset too large");
+    assert(this->HashSize == HashSize && "Hash too large");
+    assert(this->HashOffset == HashOffset && "Hash offset too large");
+  }
+
+  static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; }
+};
+static_assert(sizeof(TrieContent) ==
+                  ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
+              "Check header assumption!");
+
+/// A trie level with 2^NumBits slots, co-allocated directly after the object
+/// by create().
+class TrieSubtrie final : public TrieNode {
+public:
+  /// Current node in slot \p I, or null if the slot is empty.
+  TrieNode *get(size_t I) const { return Slots[I].load(); }
+
+  TrieSubtrie *
+  sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+       function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver);
+
+  static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits);
+
+  explicit TrieSubtrie(size_t StartBit, size_t NumBits);
+
+  static bool classof(const TrieNode *TN) { return TN->IsSubtrie; }
+
+private:
+  // FIXME: Use a bitset to speed up access:
+  //
+  //     std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
+  //
+  // This will avoid needing to visit sparsely filled slots in
+  // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial
+  // destructor.
+  //
+  // It would also greatly speed up iteration, if we add that some day, and
+  // allow get() to return one level sooner.
+  //
+  // This would be the algorithm for updating IsSet (after updating Slots):
+  //
+  //     std::atomic<uint64_t> &Bits = IsSet[I.High];
+  //     const uint64_t NewBit = 1ULL << I.Low;
+  //     uint64_t Old = 0;
+  //     while (!Bits.compare_exchange_weak(Old, Old | NewBit))
+  //       ;
+
+  // For debugging.
+  unsigned StartBit = 0;
+  unsigned NumBits = 0;
+  friend class llvm::ThreadSafeTrieRawHashMapBase;
+
+public:
+  /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
+  std::atomic<TrieSubtrie *> Next;
+
+  /// The (co-allocated) slots of the subtrie.
+  MutableArrayRef<LazyAtomicPointer<TrieNode>> Slots;
+};
+} // end namespace
+
+/// Bytes needed after a TrieSubtrie (or ImplType) header for its 2^NumBits
+/// co-allocated slots. \p StartBit does not affect the size.
+static size_t getTrieTailSize(size_t StartBit, size_t NumBits) {
+  assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
+  // Size by the real slot type rather than a raw TrieNode pointer, so the
+  // tail stays large enough even if LazyAtomicPointer is not pointer-sized.
+  return sizeof(LazyAtomicPointer<TrieNode>) * (size_t(1) << NumBits);
+}
+
+/// Allocate a sub-trie with its slots tail-allocated directly after it.
+/// safe_malloc reports a fatal error on failure instead of returning null, so
+/// placement-new never runs on a null pointer. The memory is released by
+/// TrieNode::operator delete (::free).
+std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
+                                                 size_t NumBits) {
+  size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits);
+  void *Memory = safe_malloc(Size);
+  TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
+  return std::unique_ptr<TrieSubtrie>(S);
+}
+
+/// Initialize the header and point Slots at the tail storage reserved by
+/// create(), then start every slot out empty.
+TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
+    : TrieNode(/*IsSubtrie=*/true), StartBit(StartBit), NumBits(NumBits),
+      Next(nullptr),
+      Slots(reinterpret_cast<LazyAtomicPointer<TrieNode> *>(
+                reinterpret_cast<char *>(this) + sizeof(TrieSubtrie)),
+            (1u << NumBits)) {
+  // Slots are never destroyed individually, so their destructor must be a
+  // no-op.
+  static_assert(
+      std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
+      "Expected no work in destructor for TrieNode");
+
+  for (LazyAtomicPointer<TrieNode> &Slot : Slots)
+    new (&Slot) LazyAtomicPointer<TrieNode>(nullptr);
+}
+
+/// Replace the content in slot \p I with a new sub-trie of \p NumSubtrieBits
+/// index bits that holds \p Content at slot \p NewI.
+///
+/// If the compare-exchange succeeds, the new sub-trie is handed to \p Saver
+/// (which takes ownership) and its result is returned. If slot \p I no longer
+/// holds \p Content, another thread already sank it; the sub-trie found there
+/// is returned and the local one is freed.
+TrieSubtrie *TrieSubtrie::sink(
+    size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+    function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
+  assert(NumSubtrieBits > 0);
+  // The new level starts right after the bits consumed by this one.
+  std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);
+
+  // Populate the new sub-trie before it is published below.
+  assert(NewI < S->Slots.size());
+  S->Slots[NewI].store(&Content);
+
+  // Publish only if slot I still holds exactly this content.
+  TrieNode *ExistingNode = &Content;
+  assert(I < Slots.size());
+  if (Slots[I].compare_exchange_strong(ExistingNode, S.get()))
+    return Saver(std::move(S));
+
+  // Another thread created a subtrie already. Return it and let "S" be
+  // destructed.
+  return cast<TrieSubtrie>(ExistingNode);
+}
+
+/// Owned state of a trie: the content allocator plus the root sub-trie,
+/// whose slots are tail-allocated after this object.
+struct ThreadSafeTrieRawHashMapBase::ImplType {
+  /// Allocate an ImplType with room for the root's slots. safe_malloc reports
+  /// a fatal error instead of returning null, so placement-new never sees a
+  /// null pointer.
+  static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) {
+    size_t Size = sizeof(ImplType) + getTrieTailSize(StartBit, NumBits);
+    void *Memory = safe_malloc(Size);
+    ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits);
+    return std::unique_ptr<ImplType>(Impl);
+  }
+
+  /// Take ownership of a freshly created sub-trie by pushing it onto the
+  /// list hanging off Root.Next, and return the raw pointer.
+  TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) {
+    assert(!S->Next && "Expected S to a freshly-constructed leaf");
+
+    TrieSubtrie *CurrentHead = nullptr;
+    // Add ownership of "S" to front of the list, so that Root -> S ->
+    // Root.Next. This works by repeatedly setting S->Next to a candidate value
+    // of Root.Next (initially nullptr), then setting Root.Next to S once the
+    // candidate matches reality.
+    while (!Root.Next.compare_exchange_weak(CurrentHead, S.get()))
+      S->Next.exchange(CurrentHead);
+
+    // Ownership transferred to subtrie.
+    return S.release();
+  }
+
+  // Paired with ::free() below.
+  static void *operator new(size_t Size) { return safe_malloc(Size); }
+  void operator delete(void *Ptr) { ::free(Ptr); }
+
+  /// FIXME: This should take a function that allocates and constructs the
+  /// content lazily (taking the hash as a separate parameter), in case of
+  /// collision.
+  ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc;
+  TrieSubtrie Root; // Must be last! Tail-allocated.
+
+private:
+  ImplType(size_t StartBit, size_t NumBits) : Root(StartBit, NumBits) {}
+};
+
+/// Return the published implementation, creating and publishing one if none
+/// exists yet. Safe to call concurrently: exactly one candidate wins.
+ThreadSafeTrieRawHashMapBase::ImplType &
+ThreadSafeTrieRawHashMapBase::getOrCreateImpl() {
+  ImplType *Current = ImplPtr.load();
+  if (Current)
+    return *Current;
+
+  // Race to publish a freshly built implementation. On failure Current
+  // receives the winner's pointer and the local candidate is freed when it
+  // goes out of scope.
+  std::unique_ptr<ImplType> Candidate = ImplType::create(0, NumRootBits);
+  if (!ImplPtr.compare_exchange_strong(Current, Candidate.get()))
+    return *Current;
+  return *Candidate.release();
+}
+
+/// Walk the trie along \p Hash. Returns the stored content on an exact match;
+/// otherwise returns a hint naming the sub-trie and slot where the search
+/// stopped, or an empty result if nothing was ever inserted.
+ThreadSafeTrieRawHashMapBase::PointerBase
+ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
+  assert(!Hash.empty() && "Uninitialized hash");
+
+  ImplType *Impl = ImplPtr.load();
+  if (!Impl)
+    return PointerBase();
+
+  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+  TrieSubtrie *Current = &Impl->Root;
+  for (size_t Slot = IndexGen.next();; Slot = IndexGen.next()) {
+    TrieNode *Node = Current->get(Slot);
+    // Empty slot: a miss; the caller can insert here.
+    if (!Node)
+      return PointerBase(Current, Slot, *IndexGen.StartBit);
+
+    auto *Content = dyn_cast<TrieContent>(Node);
+    if (!Content) {
+      // A sub-trie: descend one level.
+      Current = cast<TrieSubtrie>(Node);
+      continue;
+    }
+
+    // Content is here: either our entry, or a different hash occupying the
+    // slot we would insert into.
+    if (Content->getHash() == Hash)
+      return PointerBase(Content->getValuePointer());
+    return PointerBase(Current, Slot, *IndexGen.StartBit);
+  }
+}
+
+/// Insert content for \p Hash, or return the existing content with that hash.
+///
+/// The walk starts at \p Hint when it is a hint produced by find(), and at the
+/// root otherwise. An empty slot is filled by calling \p Constructor on memory
+/// from ContentAlloc. When a slot holds content with a different hash, that
+/// content is sunk into new sub-tries until the two hashes' indices differ.
+ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
+    PointerBase Hint, ArrayRef<uint8_t> Hash,
+    function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
+        Constructor) {
+  assert(!Hash.empty() && "Uninitialized hash");
+
+  ImplType &Impl = getOrCreateImpl();
+  TrieSubtrie *S = &Impl.Root;
+  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+  size_t Index;
+  if (Hint.isHint()) {
+    // Resume at the sub-trie and slot where a previous lookup stopped.
+    S = static_cast<TrieSubtrie *>(Hint.P);
+    Index = IndexGen.hint(Hint.I, Hint.B);
+  } else {
+    Index = IndexGen.next();
+  }
+
+  while (true) {
+    // Load the node from the slot, allocating and calling the constructor if
+    // the slot is empty.
+    bool Generated = false;
+    TrieNode &Existing = S->Slots[Index].loadOrGenerate([&]() {
+      Generated = true;
+
+      // Construct the value itself at the tail.
+      uint8_t *Memory = reinterpret_cast<uint8_t *>(
+          Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign));
+      const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash);
+
+      // Construct the TrieContent header, passing in the offset to the hash.
+      TrieContent *Content = ::new (Memory)
+          TrieContent(ContentOffset, Hash.size(), HashStorage - Memory);
+      assert(Hash == Content->getHash() && "Hash not properly initialized");
+      return Content;
+    });
+    // If we just generated it, return it!
+    if (Generated)
+      return PointerBase(cast<TrieContent>(Existing).getValuePointer());
+
+    // The slot holds a sub-trie: descend one level.
+    if (auto *ST = dyn_cast<TrieSubtrie>(&Existing)) {
+      S = ST;
+      Index = IndexGen.next();
+      continue;
+    }
+
+    // Return the existing content if it's an exact match!
+    auto &ExistingContent = cast<TrieContent>(Existing);
+    if (ExistingContent.getHash() == Hash)
+      return PointerBase(ExistingContent.getValuePointer());
+
+    // Sink the existing content as long as the indexes match.
+    while (true) {
+      // NextIndex: where the new hash goes at the next level.
+      // NewIndexForExistingContent: where the existing content goes there.
+      size_t NextIndex = IndexGen.next();
+      size_t NewIndexForExistingContent =
+          IndexGen.getCollidingBits(ExistingContent.getHash());
+      S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
+                  NewIndexForExistingContent,
+                  [&Impl](std::unique_ptr<TrieSubtrie> S) {
+                    return Impl.save(std::move(S));
+                  });
+      Index = NextIndex;
+
+      // Found the difference.
+      if (NextIndex != NewIndexForExistingContent)
+        break;
+    }
+    // Retry the outer loop at the slot where the two hashes diverge.
+  }
+}
+
+/// Record the content layout and per-level bit widths. The trie itself is
+/// created lazily on first insertion.
+ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
+    size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
+    std::optional<size_t> NumRootBits, std::optional<size_t> NumSubtrieBits)
+    : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign),
+      ContentOffset(ContentOffset),
+      NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
+      NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
+      ImplPtr(nullptr) {
+  assert((!NumRootBits || *NumRootBits < 20) &&
+         "Root should have fewer than ~1M slots");
+  assert((!NumSubtrieBits || *NumSubtrieBits < 10) &&
+         "Subtries should have fewer than ~1K slots");
+  // With zero bits per level the index generator never advances, so insert()
+  // can sink forever; NumSubtrieBits == 0 also divides by zero in its checks.
+  assert(this->NumRootBits > 0 && "Root must have at least one bit");
+  assert(this->NumSubtrieBits > 0 && "Subtries must have at least one bit");
+  // The layout parameters are stored as unsigned short; catch truncation for
+  // very large or over-aligned value types.
+  assert(this->ContentAllocSize == ContentAllocSize &&
+         "Content allocation size overflow");
+  assert(this->ContentAllocAlign == ContentAllocAlign &&
+         "Content alignment overflow");
+  assert(this->ContentOffset == ContentOffset && "Content offset overflow");
+}
+
+/// Copy the layout parameters and take over RHS's implementation, leaving RHS
+/// empty so that its destroyImpl() has nothing to do.
+ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
+    ThreadSafeTrieRawHashMapBase &&RHS)
+    : ContentAllocSize(RHS.ContentAllocSize),
+      ContentAllocAlign(RHS.ContentAllocAlign),
+      ContentOffset(RHS.ContentOffset), NumRootBits(RHS.NumRootBits),
+      NumSubtrieBits(RHS.NumSubtrieBits),
+      ImplPtr(RHS.ImplPtr.exchange(nullptr)) {}
+
+/// Teardown is the subclass's job (it knows the value type); by now the
+/// implementation must already have been released.
+ThreadSafeTrieRawHashMapBase::~ThreadSafeTrieRawHashMapBase() {
+  assert(ImplPtr.load() == nullptr &&
+         "Expected subclass to call destroyImpl()");
+}
+
+/// Release everything owned by the trie. \p Destructor, if provided, runs on
+/// each stored value before any sub-trie is freed.
+void ThreadSafeTrieRawHashMapBase::destroyImpl(
+    function_ref<void(void *)> Destructor) {
+  // Take ownership; an empty or moved-from map has nothing to release.
+  std::unique_ptr<ImplType> Impl(ImplPtr.exchange(nullptr));
+  if (!Impl)
+    return;
+
+  // Run value destructors first: the sub-tries must still be alive so that
+  // TrieNode::classof() can tell content nodes apart.
+  //
+  // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them
+  // facilitate sparse iteration here.
+  if (Destructor) {
+    for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load()) {
+      for (auto &Slot : Trie->Slots) {
+        TrieNode *Node = Slot.load();
+        if (auto *Content = dyn_cast_or_null<TrieContent>(Node))
+          Destructor(Content->getValuePointer());
+      }
+    }
+  }
+
+  // Free every sub-trie on the ownership list (newest first); the root and
+  // the content allocator go away with Impl.
+  for (TrieSubtrie *Trie = Impl->Root.Next; Trie;) {
+    TrieSubtrie *Following = Trie->Next.exchange(nullptr);
+    delete Trie;
+    Trie = Following;
+  }
+}
+
+/// Debug helper: pointer to the root trie, or empty before the first insert.
+ThreadSafeTrieRawHashMapBase::PointerBase
+ThreadSafeTrieRawHashMapBase::getRoot() const {
+  if (ImplType *Impl = ImplPtr.load())
+    return PointerBase(&Impl->Root);
+  return PointerBase();
+}
+
+/// Debug helper: first hash bit indexed by the trie \p P, or 0 if \p P is not
+/// a sub-trie.
+unsigned ThreadSafeTrieRawHashMapBase::getStartBit(
+    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  auto *S = dyn_cast_or_null<TrieSubtrie>(static_cast<TrieNode *>(P.P));
+  return S ? S->StartBit : 0;
+}
+
+/// Debug helper: number of index bits of the trie \p P, or 0 if \p P is not a
+/// sub-trie.
+unsigned ThreadSafeTrieRawHashMapBase::getNumBits(
+    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  auto *Node = static_cast<TrieNode *>(P.P);
+  if (auto *S = dyn_cast_or_null<TrieSubtrie>(Node))
+    return S->NumBits;
+  return 0;
+}
+
+/// Debug helper: number of non-empty slots in the trie \p P, or 0 if \p P is
+/// not a sub-trie.
+unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed(
+    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  if (!P.P)
+    return 0;
+  auto *S = dyn_cast<TrieSubtrie>(static_cast<TrieNode *>(P.P));
+  if (!S)
+    return 0;
+  unsigned Num = 0;
+  for (auto &Slot : S->Slots)
+    if (Slot.load())
+      ++Num;
+  return Num;
+}
+
+/// Debug helper: the hash prefix shared by everything under the trie \p P.
+/// Whole bytes are printed as lowercase hex; any remaining bits are printed
+/// individually inside brackets, e.g. "ffff[101]". Returns "" for non-tries.
+std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString(
+    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  if (!P.P)
+    return "";
+
+  auto *S = dyn_cast<TrieSubtrie>(static_cast<TrieNode *>(P.P));
+  if (!S)
+    return "";
+
+  // Find a TrieContent node which has hash stored. Depth search following the
+  // first used slot until a TrieContent node is found.
+  TrieContent *Node = nullptr;
+  for (TrieSubtrie *Current = S; Current && !Node;) {
+    TrieSubtrie *Next = nullptr;
+    // Find the first used slot in this trie.
+    for (auto &Slot : Current->Slots) {
+      TrieNode *Child = Slot.load();
+      if (!Child)
+        continue;
+      if (auto *Content = dyn_cast<TrieContent>(Child))
+        Node = Content;
+      else
+        Next = cast<TrieSubtrie>(Child);
+      break;
+    }
+    // Continue to the next level if the node is not found.
+    Current = Next;
+  }
+
+  assert(Node && "malformed trie, cannot find TrieContent on leaf node");
+  if (!Node)
+    return "";
+
+  // The prefix for the current trie is the first `StartBit` of the content
+  // stored underneath this subtrie.
+  std::string Str;
+  raw_string_ostream SS(Str);
+
+  ArrayRef<uint8_t> Hash = Node->getHash();
+  // Whole bytes in the prefix. Round down, so a StartBit below 8 (such as the
+  // root's 0) gives zero bytes instead of wrapping around to UINT_MAX.
+  unsigned StartFullBytes = S->StartBit / 8;
+  SS << toHex(toStringRef(Hash).take_front(StartFullBytes),
+              /*LowerCase=*/true);
+
+  // For the part of the prefix that doesn't fill a byte, print raw bit values
+  // (most-significant bit first).
+  std::string Bits;
+  for (unsigned I = StartFullBytes * 8, E = S->StartBit; I < E; ++I) {
+    unsigned Index = I / 8;
+    unsigned Offset = 7 - I % 8;
+    Bits.push_back('0' + ((Hash[Index] >> Offset) & 1));
+  }
+
+  if (!Bits.empty())
+    SS << "[" << Bits << "]";
+
+  return SS.str();
+}
+
+/// Debug helper: count of tries (root included) on the allocation chain.
+unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const {
+  unsigned Count = 0;
+  if (ImplType *Impl = ImplPtr.load())
+    for (TrieSubtrie *T = &Impl->Root; T != nullptr; T = T->Next.load())
+      ++Count;
+  return Count;
+}
+
+/// Debug helper: the trie following \p P on the allocation chain, or empty at
+/// the end of the chain or when \p P is not a sub-trie.
+ThreadSafeTrieRawHashMapBase::PointerBase
+ThreadSafeTrieRawHashMapBase::getNextTrie(
+    ThreadSafeTrieRawHashMapBase::PointerBase P) const {
+  assert(!P.isHint() && "Not a valid trie");
+  auto *S = dyn_cast_or_null<TrieSubtrie>(static_cast<TrieNode *>(P.P));
+  TrieSubtrie *Next = S ? S->Next.load() : nullptr;
+  return Next ? PointerBase(Next) : PointerBase();
+}
diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt
index 745e4d9fb74a4a..b0077d5b54a3ee 100644
--- a/llvm/unittests/ADT/CMakeLists.txt
+++ b/llvm/unittests/ADT/CMakeLists.txt
@@ -86,6 +86,7 @@ add_llvm_unittest(ADTTests
   StringSetTest.cpp
   StringSwitchTest.cpp
   TinyPtrVectorTest.cpp
+  TrieRawHashMapTest.cpp
   TwineTest.cpp
   TypeSwitchTest.cpp
   TypeTraitsTest.cpp
diff --git a/llvm/unittests/ADT/TrieRawHashMapTest.cpp b/llvm/unittests/ADT/TrieRawHashMapTest.cpp
new file mode 100644
index 00000000000000..95c26fa70e3ddf
--- /dev/null
+++ b/llvm/unittests/ADT/TrieRawHashMapTest.cpp
@@ -0,0 +1,343 @@
+//===- TrieRawHashMapTest.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/TrieRawHashMap.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/SHA1.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace llvm {
+/// Test-only wrapper that forwards to the introspection hooks of
+/// ThreadSafeTrieRawHashMapBase (getRoot, getNumTries, getNextTrie, ...).
+class TrieRawHashMapTestHelper {
+public:
+  using PointerBase = ThreadSafeTrieRawHashMapBase::PointerBase;
+
+  TrieRawHashMapTestHelper() = default;
+
+  /// Select the trie to inspect. The helper does not take ownership.
+  void setTrie(ThreadSafeTrieRawHashMapBase *T) { Target = T; }
+
+  PointerBase getRoot() const { return Target->getRoot(); }
+  unsigned getStartBit(PointerBase P) const {
+    return Target->getStartBit(P);
+  }
+  unsigned getNumBits(PointerBase P) const { return Target->getNumBits(P); }
+  unsigned getNumSlotUsed(PointerBase P) const {
+    return Target->getNumSlotUsed(P);
+  }
+  unsigned getNumTries() const { return Target->getNumTries(); }
+  std::string getTriePrefixAsString(PointerBase P) const {
+    return Target->getTriePrefixAsString(P);
+  }
+  PointerBase getNextTrie(PointerBase P) const {
+    return Target->getNextTrie(P);
+  }
+
+private:
+  /// The trie under inspection, as set by setTrie().
+  ThreadSafeTrieRawHashMapBase *Target = nullptr;
+};
+} // namespace llvm
+
+namespace {
+template <typename DataType, size_t HashSize>
+class SimpleTrieHashMapTest : public TrieRawHashMapTestHelper,
+                              public ::testing::Test {
+public:
+  using NumType = DataType;
+  using HashType = std::array<uint8_t, HashSize>;
+  using TrieType = ThreadSafeTrieRawHashMap<DataType, sizeof(HashType)>;
+
+  TrieType &createTrie(size_t RootBits, size_t SubtrieBits) {
+    auto &Ret = Trie.emplace(RootBits, SubtrieBits);
+    TrieRawHashMapTestHelper::setTrie(&Ret);
+    return Ret;
+  }
+
+  void destroyTrie() { Trie.reset(); }
+
+  ~SimpleTrieHashMapTest() {
+    if (Trie)
+      Trie.reset();
+  }
+
+  // Use the number itself as hash to test the pathological case.
+  static HashType hash(uint64_t Num) {
+    uint64_t HashN =
+        llvm::support::endian::byte_swap(Num, llvm::endianness::big);
+    HashType Hash;
+    memcpy(Hash.data(), &HashN, sizeof(HashType));
+    return Hash;
+  }
+  static_assert(HashSize <= sizeof(uint64_t), "hash() fills at most 8 bytes");
+private:
+  std::optional<TrieType> Trie;
+};
+
+using SmallNodeTrieTest = SimpleTrieHashMapTest<uint64_t, sizeof(uint64_t)>;
+
+TEST_F(SmallNodeTrieTest, TrieAllocation) {
+  NumType Numbers[] = {
+      0x0, std::numeric_limits<NumType>::max(),      0x1, 0x2,
+      0x3, std::numeric_limits<NumType>::max() - 1u,
+  };
+  // Total number of tries, root included, after each insertion.
+  unsigned ExpectedTries[] = {
+      1,       // Allocate Root.
+      1,       // Both on the root.
+      64,      // 0 and 1 sink all the way down.
+      64,      // No new allocation needed.
+      65,      // Need a new node between 2 and 3.
+      65 + 63, // 63 new allocations to sink the two big numbers all the way.
+  };
+  // Prefix of the most recently allocated subtrie after each insertion.
+  const char *ExpectedPrefix[] = {
+      "", // Root.
+      "", // Root.
+      "00000000000000[0000000]",
+      "00000000000000[0000000]",
+      "00000000000000[0000001]",
+      "ffffffffffffff[1111111]",
+  };
+  static_assert(std::size(ExpectedTries) == std::size(Numbers) &&
+                std::size(ExpectedPrefix) == std::size(Numbers));
+
+  // Use root and subtrie sizes of 1 so this gets sunk quite deep.
+  auto &Trie = createTrie(/*RootBits=*/1, /*SubtrieBits=*/1);
+  for (size_t I = 0, E = std::size(Numbers); I != E; ++I) {
+    // Lookup first to exercise hint code for deep tries.
+    TrieType::pointer Lookup = Trie.find(hash(Numbers[I]));
+    EXPECT_FALSE(Lookup);
+    Trie.insert(Lookup, TrieType::value_type(hash(Numbers[I]), Numbers[I]));
+    EXPECT_EQ(getNumTries(), ExpectedTries[I]);
+    EXPECT_EQ(getTriePrefixAsString(getNextTrie(getRoot())), ExpectedPrefix[I]);
+  }
+}
+
+TEST_F(SmallNodeTrieTest, TrieStructure) {
+  NumType Numbers[] = {
+      // Three numbers that will nest deeply to test (1) sinking subtries and
+      // (2) deep, non-trivial hints.
+      std::numeric_limits<NumType>::max(),
+      std::numeric_limits<NumType>::max() - 2u,
+      std::numeric_limits<NumType>::max() - 3u,
+      // One number to stay at the top-level.
+      0x37,
+  };
+
+  // Use root and subtrie sizes of 1 so this gets sunk quite deep.
+  auto &Trie = createTrie(/*RootBits=*/1, /*SubtrieBits=*/1);
+
+  for (NumType N : Numbers) {
+    // Lookup first to exercise hint code for deep tries.
+    TrieType::pointer Lookup = Trie.find(hash(N));
+    EXPECT_FALSE(Lookup);
+
+    Trie.insert(Lookup, TrieType::value_type(hash(N), N));
+  }
+  for (NumType N : Numbers) {
+    TrieType::pointer Lookup = Trie.find(hash(N));
+    EXPECT_TRUE(Lookup);
+    if (!Lookup)
+      continue;
+    EXPECT_EQ(hash(N), Lookup->Hash);
+    EXPECT_EQ(N, Lookup->Data);
+
+    // Confirm a subsequent insertion fails to overwrite by trying to insert a
+    // bad value.
+    auto Result = Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1));
+    EXPECT_EQ(N, Result->Data);
+  }
+
+  // Check the trie so we can confirm the structure is correct. Each subtrie
+  // should have 2 slots. The root's index=0 should have the content for
+  // 0x37 directly, and index=1 should be a linked-list of subtries, finally
+  // ending with content for (max-2) and (max-3).
+  //
+  // Note: This structure is not exhaustive (too expensive to update tests),
+  // but it does test that the dump format is somewhat readable and that the
+  // basic structure is correct.
+  //
+  // Note: This test requires that the trie reads bytes starting from index 0
+  // of the array of uint8_t, and then reads each byte's bits from high to low.
+
+  // Check the Trie.
+  // All 64 one-bit tries (root included) get allocated for the 64-bit hash.
+  ASSERT_EQ(getNumTries(), 64u);
+  // Check the root trie. Two slots and both are used.
+  ASSERT_EQ(getNumSlotUsed(getRoot()), 2u);
+  // Check last subtrie.
+  // Last allocated trie is the next node in the allocation chain.
+  auto LastAllocatedSubTrie = getNextTrie(getRoot());
+  ASSERT_EQ(getTriePrefixAsString(LastAllocatedSubTrie),
+            "ffffffffffffff[1111110]");
+  ASSERT_EQ(getStartBit(LastAllocatedSubTrie), 63u);
+  ASSERT_EQ(getNumBits(LastAllocatedSubTrie), 1u);
+  ASSERT_EQ(getNumSlotUsed(LastAllocatedSubTrie), 2u);
+}
+
+TEST_F(SmallNodeTrieTest, TrieStructureSmallFinalSubtrie) {
+  NumType Numbers[] = {
+      // Three numbers that will nest deeply to test (1) sinking subtries and
+      // (2) deep, non-trivial hints.
+      std::numeric_limits<NumType>::max(),
+      std::numeric_limits<NumType>::max() - 2u,
+      std::numeric_limits<NumType>::max() - 3u,
+      // One number to stay at the top-level.
+      0x37,
+  };
+
+  // Use subtrie size of 5 to avoid hitting 64 evenly, making the final subtrie
+  // small.
+  auto &Trie = createTrie(/*RootBits=*/8, /*SubtrieBits=*/5);
+
+  for (NumType N : Numbers) {
+    // Lookup first to exercise hint code for deep tries.
+    TrieType::pointer Lookup = Trie.find(hash(N));
+    EXPECT_FALSE(Lookup);
+
+    Trie.insert(Lookup, TrieType::value_type(hash(N), N));
+  }
+  for (NumType N : Numbers) {
+    TrieType::pointer Lookup = Trie.find(hash(N));
+    EXPECT_TRUE(Lookup);
+    if (!Lookup)
+      continue;
+    EXPECT_EQ(hash(N), Lookup->Hash);
+    EXPECT_EQ(N, Lookup->Data);
+
+    // Confirm a subsequent insertion fails to overwrite by trying to insert a
+    // bad value.
+    auto Result = Trie.insert(Lookup, TrieType::value_type(hash(N), N - 1));
+    EXPECT_EQ(N, Result->Data);
+  }
+
+  // Check the trie so we can confirm the structure is correct. The root
+  // should have 2^8=256 slots, most subtries should have 2^5=32 slots, and the
+  // deepest subtrie should have 2^1=2 slots (since (64-8)mod(5)=1). The root's
+  // index=0 should have the content for 0x37 directly, and index=0xff should
+  // be a linked-list of subtries, finally ending with content for (max-2) and
+  // (max-3).
+  //
+  // Note: This structure is not exhaustive (too expensive to update tests),
+  // but it does test that the dump format is somewhat readable and that the
+  // basic structure is correct.
+  //
+  // Note: This test requires that the trie reads bytes starting from index 0
+  // of the array of uint8_t, and then reads each byte's bits from high to low.
+
+  // Check the Trie.
+  // 64 bit hash = 8 + 5 * 11 + 1, so 1 root, 11 5-bit subtries and 1 last
+  // level 1-bit subtrie, 13 total.
+  ASSERT_EQ(getNumTries(), 13u);
+  // Check the root trie. Only 2 of its 256 slots are used.
+  ASSERT_EQ(getNumSlotUsed(getRoot()), 2u);
+  // Check last subtrie.
+  // Last allocated trie is the next node in the allocation chain.
+  auto LastAllocatedSubTrie = getNextTrie(getRoot());
+  ASSERT_EQ(getTriePrefixAsString(LastAllocatedSubTrie),
+            "ffffffffffffff[1111110]");
+  ASSERT_EQ(getStartBit(LastAllocatedSubTrie), 63u);
+  ASSERT_EQ(getNumBits(LastAllocatedSubTrie), 1u);
+  ASSERT_EQ(getNumSlotUsed(LastAllocatedSubTrie), 2u);
+}
+
+TEST_F(SmallNodeTrieTest, TrieDestructionLoop) {
+  // Test destroying large Trie. Make sure there is no recursion that can
+  // overflow the stack.
+
+  // Limit the tries to 2 slots (1 bit) to generate subtries at a higher rate.
+  auto &Trie = createTrie(/*RootBits=*/1, /*SubtrieBits=*/1);
+
+  // Fill them up. Pick a MaxN high enough to cause a stack overflow in debug
+  // builds.
+  static constexpr uint64_t MaxN = 100000;
+  for (uint64_t N = 0; N != MaxN; ++N) {
+    HashType Hash = hash(N);
+    Trie.insert(TrieType::pointer(), TrieType::value_type(Hash, NumType{N}));
+  }
+
+  // Destroy tries. If destruction is recursive and MaxN is high enough, these
+  // will both fail.
+  destroyTrie();
+}
+
+struct NumWithDestructorT {
+  uint64_t Num;
+  ~NumWithDestructorT() {}
+};
+
+using NodeWithDestructorTrieTest =
+    SimpleTrieHashMapTest<NumWithDestructorT, sizeof(uint64_t)>;
+
+TEST_F(NodeWithDestructorTrieTest, TrieDestructionLoop) {
+  // Test destroying large Trie. Make sure there is no recursion that can
+  // overflow the stack.
+
+  // Limit the tries to 2 slots (1 bit) to generate subtries at a higher rate.
+  auto &Trie = createTrie(/*RootBits=*/1, /*SubtrieBits=*/1);
+
+  // Fill them up. Pick a MaxN high enough to cause a stack overflow in debug
+  // builds.
+  static constexpr uint64_t MaxN = 100000;
+  for (uint64_t N = 0; N != MaxN; ++N) {
+    HashType Hash = hash(N);
+    Trie.insert(TrieType::pointer(), TrieType::value_type(Hash, NumType{N}));
+  }
+
+  // Destroy tries. If destruction is recursive and MaxN is high enough, these
+  // will both fail.
+  destroyTrie();
+}
+
+using NumStrNodeTrieTest = SimpleTrieHashMapTest<std::string, sizeof(uint64_t)>;
+
+TEST_F(NumStrNodeTrieTest, TrieInsertLazy) {
+  for (unsigned RootBits : {2, 3, 6, 10}) {
+    for (unsigned SubtrieBits : {2, 3, 4}) {
+      auto &Trie = createTrie(RootBits, SubtrieBits);
+      for (int I = 0, E = 1000; I != E; ++I) {
+        // Every other key is looked up before its first insertion below, so
+        // the lookup must miss.
+        if (I & 1)
+          EXPECT_FALSE(Trie.find(hash(I)));
+        // Lazily insert the decimal string of Num under hash(Num).
+        auto insertNum = [&](uint64_t Num) {
+          std::string S = Twine(Num).str();
+          auto Hash = hash(Num);
+          return Trie.insertLazy(
+              Hash, [&](TrieType::LazyValueConstructor C) { C(std::move(S)); });
+        };
+        auto S1 = insertNum(I);
+        // The address of the Data should be the same.
+        EXPECT_EQ(&S1->Data, &insertNum(I)->Data);
+
+        auto insertStr = [&](std::string S) {
+          int Num = std::stoi(S);
+          return insertNum(Num);
+        };
+        std::string S2 = S1->Data;
+        // The address of the Data should be the same.
+        EXPECT_EQ(&S1->Data, &insertStr(S2)->Data);
+      }
+      for (int I = 0, E = 1000; I != E; ++I) {
+        std::string S = Twine(I).str();
+        TrieType::pointer Lookup = Trie.find(hash(I));
+        EXPECT_TRUE(Lookup);
+        if (!Lookup)
+          continue;
+        EXPECT_EQ(S, Lookup->Data);
+      }
+    }
+  }
+}
+} // end anonymous namespace

>From 63ccd9cc08f5c45fd212f3c8b85d6db6f7360dd3 Mon Sep 17 00:00:00 2001
From: Steven Wu <stevenwu at apple.com>
Date: Fri, 20 Oct 2023 13:04:31 -0700
Subject: [PATCH 2/5] fixup! [ADT] Add TrieRawHashMap

---
 llvm/include/llvm/ADT/TrieRawHashMap.h    | 21 +--------------
 llvm/lib/Support/TrieHashIndexGenerator.h |  1 -
 llvm/lib/Support/TrieRawHashMap.cpp       |  7 ++++-
 llvm/unittests/ADT/TrieRawHashMapTest.cpp | 33 ++++++++++++-----------
 4 files changed, 25 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/ADT/TrieRawHashMap.h b/llvm/include/llvm/ADT/TrieRawHashMap.h
index baa08e214ce6fd..b7c4b92d1df307 100644
--- a/llvm/include/llvm/ADT/TrieRawHashMap.h
+++ b/llvm/include/llvm/ADT/TrieRawHashMap.h
@@ -105,10 +105,6 @@ class ThreadSafeTrieRawHashMapBase {
 
   public:
     PointerBase() noexcept = default;
-    PointerBase(PointerBase &&) = default;
-    PointerBase(const PointerBase &) = default;
-    PointerBase &operator=(PointerBase &&) = default;
-    PointerBase &operator=(const PointerBase &) = default;
 
   private:
     friend class ThreadSafeTrieRawHashMapBase;
@@ -228,9 +224,7 @@ class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase {
     friend class ThreadSafeTrieRawHashMap;
 
     ValueT *get() const {
-      if (void *B = PointerBase::get())
-        return reinterpret_cast<ValueT *>(B);
-      return nullptr;
+      return reinterpret_cast<ValueT *>(PointerBase::get());
     }
 
   public:
@@ -245,10 +239,6 @@ class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase {
     explicit operator bool() const { return get(); }
 
     PointerImpl() = default;
-    PointerImpl(PointerImpl &&) = default;
-    PointerImpl(const PointerImpl &) = default;
-    PointerImpl &operator=(PointerImpl &&) = default;
-    PointerImpl &operator=(const PointerImpl &) = default;
 
   protected:
     PointerImpl(PointerBase Result) : PointerBase(Result) {}
@@ -263,10 +253,6 @@ class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase {
 
   public:
     pointer() = default;
-    pointer(pointer &&) = default;
-    pointer(const pointer &) = default;
-    pointer &operator=(pointer &&) = default;
-    pointer &operator=(const pointer &) = default;
 
   private:
     pointer(PointerBase Result) : pointer::PointerImpl(Result) {}
@@ -277,11 +263,6 @@ class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase {
 
   public:
     const_pointer() = default;
-    const_pointer(const_pointer &&) = default;
-    const_pointer(const const_pointer &) = default;
-    const_pointer &operator=(const_pointer &&) = default;
-    const_pointer &operator=(const const_pointer &) = default;
-
     const_pointer(const pointer &P) : const_pointer::PointerImpl(P) {}
 
   private:
diff --git a/llvm/lib/Support/TrieHashIndexGenerator.h b/llvm/lib/Support/TrieHashIndexGenerator.h
index c9e9b70e10d3c7..fc1ebe92377a61 100644
--- a/llvm/lib/Support/TrieHashIndexGenerator.h
+++ b/llvm/lib/Support/TrieHashIndexGenerator.h
@@ -41,7 +41,6 @@ struct IndexGenerator {
   }
 
   size_t hint(unsigned Index, unsigned Bit) {
-    assert(Index >= 0);
     assert(Bit < Bytes.size() * 8);
     assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0);
     StartBit = Bit;
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 5f9b8b9ffea038..3a4656d4362120 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -49,6 +49,7 @@ struct TrieContent final : public TrieNode {
 
   static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; }
 };
+
 static_assert(sizeof(TrieContent) ==
                   ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
               "Check header assumption!");
@@ -165,7 +166,7 @@ struct ThreadSafeTrieRawHashMapBase::ImplType {
     while (!Root.Next.compare_exchange_weak(CurrentHead, S.get()))
       S->Next.exchange(CurrentHead);
 
-    // Ownership transferred to subtrie.
+    // Ownership transferred to subtrie successfully. Release the unique_ptr.
     return S.release();
   }
 
@@ -191,9 +192,13 @@ ThreadSafeTrieRawHashMapBase::getOrCreateImpl() {
   // If another thread wins this one is destroyed locally.
   std::unique_ptr<ImplType> Impl = ImplType::create(0, NumRootBits);
   ImplType *ExistingImpl = nullptr;
+
+  // If the ownership transferred succesfully, release unique_ptr and return
+  // the pointer to the new ImplType.
   if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get()))
     return *Impl.release();
 
+  // Already created, return the existing ImplType.
   return *ExistingImpl;
 }
 
diff --git a/llvm/unittests/ADT/TrieRawHashMapTest.cpp b/llvm/unittests/ADT/TrieRawHashMapTest.cpp
index 95c26fa70e3ddf..c9081f547812e9 100644
--- a/llvm/unittests/ADT/TrieRawHashMapTest.cpp
+++ b/llvm/unittests/ADT/TrieRawHashMapTest.cpp
@@ -49,7 +49,7 @@ class TrieRawHashMapTestHelper {
 } // namespace llvm
 
 namespace {
-template <typename DataType, size_t HashSize>
+template <typename DataType, size_t HashSize = sizeof(uint64_t)>
 class SimpleTrieHashMapTest : public TrieRawHashMapTestHelper,
                               public ::testing::Test {
 public:
@@ -64,11 +64,7 @@ class SimpleTrieHashMapTest : public TrieRawHashMapTestHelper,
   }
 
   void destroyTrie() { Trie.reset(); }
-
-  ~SimpleTrieHashMapTest() {
-    if (Trie)
-      Trie.reset();
-  }
+  ~SimpleTrieHashMapTest() { destroyTrie(); }
 
   // Use the number itself as hash to test the pathological case.
   static HashType hash(uint64_t Num) {
@@ -83,7 +79,7 @@ class SimpleTrieHashMapTest : public TrieRawHashMapTestHelper,
   std::optional<TrieType> Trie;
 };
 
-using SmallNodeTrieTest = SimpleTrieHashMapTest<uint64_t, sizeof(uint64_t)>;
+using SmallNodeTrieTest = SimpleTrieHashMapTest<uint64_t>;
 
 TEST_F(SmallNodeTrieTest, TrieAllocation) {
   NumType Numbers[] = {
@@ -209,9 +205,7 @@ TEST_F(SmallNodeTrieTest, TrieStructureSmallFinalSubtrie) {
   }
   for (NumType N : Numbers) {
     TrieType::pointer Lookup = Trie.find(hash(N));
-    EXPECT_TRUE(Lookup);
-    if (!Lookup)
-      continue;
+    ASSERT_TRUE(Lookup);
     EXPECT_EQ(hash(N), Lookup->Hash);
     EXPECT_EQ(N, Lookup->Data);
 
@@ -273,11 +267,11 @@ TEST_F(SmallNodeTrieTest, TrieDestructionLoop) {
 
 struct NumWithDestructorT {
   uint64_t Num;
-  ~NumWithDestructorT() {}
+  llvm::function_ref<void()> DestructorCallback;
+  ~NumWithDestructorT() { DestructorCallback(); }
 };
 
-using NodeWithDestructorTrieTest =
-    SimpleTrieHashMapTest<NumWithDestructorT, sizeof(uint64_t)>;
+using NodeWithDestructorTrieTest = SimpleTrieHashMapTest<NumWithDestructorT>;
 
 TEST_F(NodeWithDestructorTrieTest, TrieDestructionLoop) {
   // Test destroying large Trie. Make sure there is no recursion that can
@@ -289,17 +283,26 @@ TEST_F(NodeWithDestructorTrieTest, TrieDestructionLoop) {
   // Fill them up. Pick a MaxN high enough to cause a stack overflow in debug
   // builds.
   static constexpr uint64_t MaxN = 100000;
+
+  uint64_t DestructorCalled = 0;
+  auto DtorCallback = [&DestructorCalled]() { ++DestructorCalled; };
   for (uint64_t N = 0; N != MaxN; ++N) {
     HashType Hash = hash(N);
-    Trie.insert(TrieType::pointer(), TrieType::value_type(Hash, NumType{N}));
+    Trie.insert(TrieType::pointer(),
+                TrieType::value_type(Hash, NumType{N, DtorCallback}));
   }
+  // Reset the count after all the temporaries get destroyed.
+  DestructorCalled = 0;
 
   // Destroy tries. If destruction is recursive and MaxN is high enough, these
   // will both fail.
   destroyTrie();
+
+  // Count the number of destructor calls during `destroyTrie()`.
+  ASSERT_EQ(DestructorCalled, MaxN);
 }
 
-using NumStrNodeTrieTest = SimpleTrieHashMapTest<std::string, sizeof(uint64_t)>;
+using NumStrNodeTrieTest = SimpleTrieHashMapTest<std::string>;
 
 TEST_F(NumStrNodeTrieTest, TrieInsertLazy) {
   for (unsigned RootBits : {2, 3, 6, 10}) {

>From 6f55c27203dc2c28a94b2a0e1c3bcdff68a03034 Mon Sep 17 00:00:00 2001
From: Steven Wu <stevenwu at apple.com>
Date: Mon, 30 Oct 2023 10:13:40 -0700
Subject: [PATCH 3/5] fixup! Address review feedback for more comments

---
 llvm/lib/Support/TrieHashIndexGenerator.h | 33 ++++++++++++++++++++++-
 llvm/lib/Support/TrieRawHashMap.cpp       | 28 +++++++++++--------
 2 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Support/TrieHashIndexGenerator.h b/llvm/lib/Support/TrieHashIndexGenerator.h
index fc1ebe92377a61..d67d22ec523bf4 100644
--- a/llvm/lib/Support/TrieHashIndexGenerator.h
+++ b/llvm/lib/Support/TrieHashIndexGenerator.h
@@ -14,12 +14,27 @@
 
 namespace llvm {
 
+/// A utility class that helps compute the index of an object inside a trie
+/// from its hash. The generator can be configured with the number of bits
+/// used for each level of the trie structure with \c NumRootBits and \c
+/// NumSubtrieBits.
+/// For example, try computing indexes for a 16-bit hash 0x1234 with 8-bit root
+/// and 4-bit sub-trie:
+///
+///   IndexGenerator IndexGen{8, 4, Hash};
+///   size_t index1 = IndexGen.next(); // index 18 in root node.
+///   size_t index2 = IndexGen.next(); // index 3 in sub-trie level 1.
+///   size_t index3 = IndexGen.next(); // index 4 in sub-trie level 2.
+///
+/// This is used by different trie implementations to figure out where to
+/// insert/find the object in the data structure.
 struct IndexGenerator {
   size_t NumRootBits;
   size_t NumSubtrieBits;
   ArrayRef<uint8_t> Bytes;
   std::optional<size_t> StartBit = std::nullopt;
 
+  // Get the number of bits used to generate current index.
   size_t getNumBits() const {
     assert(StartBit);
     size_t TotalNumBits = Bytes.size() * 8;
@@ -27,12 +42,16 @@ struct IndexGenerator {
     return std::min(*StartBit ? NumSubtrieBits : NumRootBits,
                     TotalNumBits - *StartBit);
   }
+
+  // Get the index of the object in the next level of trie.
   size_t next() {
     size_t Index;
     if (!StartBit) {
+      // Compute index for root when StartBit is not set.
       StartBit = 0;
       Index = getIndex(Bytes, *StartBit, NumRootBits);
     } else {
+      // Compute index for sub-trie.
       *StartBit += *StartBit ? NumSubtrieBits : NumRootBits;
       assert((*StartBit - NumRootBits) % NumSubtrieBits == 0);
       Index = getIndex(Bytes, *StartBit, NumSubtrieBits);
@@ -40,6 +59,11 @@ struct IndexGenerator {
     return Index;
   }
 
+  // Provide a hint to speed up the index generation by providing the
+  // information of the hash in current level. For example, if the object is
+  // known to have \c Index on a level that already consumes first n \c Bits of
+  // the hash, it can start index generation from this level by calling \c hint
+  // function.
   size_t hint(unsigned Index, unsigned Bit) {
     assert(Bit < Bytes.size() * 8);
     assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0);
@@ -47,18 +71,25 @@ struct IndexGenerator {
     return Index;
   }
 
+  // Utility function for looking up the index in the trie for an object whose
+  // leading hash bits collide with the hash of the object that is currently
+  // being computed.
   size_t getCollidingBits(ArrayRef<uint8_t> CollidingBits) const {
     assert(StartBit);
     return getIndex(CollidingBits, *StartBit, NumSubtrieBits);
   }
 
+  // Compute the index for the object from its hash, current start bits, and
+  // the number of bits used for current level.
   static size_t getIndex(ArrayRef<uint8_t> Bytes, size_t StartBit,
                          size_t NumBits) {
     assert(StartBit < Bytes.size() * 8);
-
+    // Drop all the bits before StartBit.
     Bytes = Bytes.drop_front(StartBit / 8u);
     StartBit %= 8u;
     size_t Index = 0;
+    // Compute the index using the bits in range [StartBit, StartBit + NumBits);
+    // note that the range can span several `uint8_t`s in the array.
     for (uint8_t Byte : Bytes) {
       size_t ByteStart = 0, ByteEnd = 8;
       if (StartBit) {
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 3a4656d4362120..cf27a5caf409c2 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -102,14 +102,15 @@ class TrieSubtrie final : public TrieNode {
 };
 } // end namespace
 
-static size_t getTrieTailSize(size_t StartBit, size_t NumBits) {
-  assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
+// Compute the trailing object size in the trie node. This is the size of \c
+// Slots in TrieNodes that pointing to the children.
+static size_t getTrieTailSize(size_t NumBits) {
   return sizeof(TrieNode *) * (1u << NumBits);
 }
 
 std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
                                                  size_t NumBits) {
-  size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits);
+  size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(NumBits);
   void *Memory = ::malloc(Size);
   TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
   return std::unique_ptr<TrieSubtrie>(S);
@@ -128,15 +129,22 @@ TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
       "Expected no work in destructor for TrieNode");
 }
 
+// Sink the nodes down sub-trie when the object being inserted collides with
+// the index of existing object in the trie. In this case, a new sub-trie needs
+// to be allocated to hold existing object.
 TrieSubtrie *TrieSubtrie::sink(
     size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
     function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
+  // Create a new sub-trie that points to the existing object with the new
+  // index for the next level.
   assert(NumSubtrieBits > 0);
   std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);
 
   assert(NewI < S->Slots.size());
   S->Slots[NewI].store(&Content);
 
+  // Using compare_exchange to atomically add back the new sub-trie to the trie
+  // in the place of the exsiting object.
   TrieNode *ExistingNode = &Content;
   assert(I < Slots.size());
   if (Slots[I].compare_exchange_strong(ExistingNode, S.get()))
@@ -149,12 +157,15 @@ TrieSubtrie *TrieSubtrie::sink(
 
 struct ThreadSafeTrieRawHashMapBase::ImplType {
   static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) {
-    size_t Size = sizeof(ImplType) + getTrieTailSize(StartBit, NumBits);
+    size_t Size = sizeof(ImplType) + getTrieTailSize(NumBits);
     void *Memory = ::malloc(Size);
     ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits);
     return std::unique_ptr<ImplType>(Impl);
   }
 
+  // Save the Subtrie into the ownship list of the trie structure in a
+  // thread-safe way. The ownership transfer is done by compare_exchange the
+  // pointer value inside the unique_ptr.
   TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) {
     assert(!S->Next && "Expected S to a freshly-constructed leaf");
 
@@ -306,12 +317,7 @@ ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
       ContentOffset(ContentOffset),
       NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
       NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
-      ImplPtr(nullptr) {
-  assert((!NumRootBits || *NumRootBits < 20) &&
-         "Root should have fewer than ~1M slots");
-  assert((!NumSubtrieBits || *NumSubtrieBits < 10) &&
-         "Subtries should have fewer than ~1K slots");
-}
+      ImplPtr(nullptr) {}
 
 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
     ThreadSafeTrieRawHashMapBase &&RHS)
@@ -413,7 +419,7 @@ std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString(
   TrieContent *Node = nullptr;
   while (Current) {
     TrieSubtrie *Next = nullptr;
-    // find first used slot in the trie.
+    // Find first used slot in the trie.
     for (unsigned I = 0, E = Current->Slots.size(); I < E; ++I) {
       auto *S = Current->get(I);
       if (!S)

>From 702636e5accb441c2c806b3f1b9e0e51430cf95a Mon Sep 17 00:00:00 2001
From: Steven Wu <stevenwu at apple.com>
Date: Mon, 30 Oct 2023 11:26:21 -0700
Subject: [PATCH 4/5] fixup! Avoid `while (true)` during index generation

---
 llvm/lib/Support/TrieHashIndexGenerator.h | 13 ++++++++-----
 llvm/lib/Support/TrieRawHashMap.cpp       | 12 ++++++++----
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Support/TrieHashIndexGenerator.h b/llvm/lib/Support/TrieHashIndexGenerator.h
index d67d22ec523bf4..404db2c28549be 100644
--- a/llvm/lib/Support/TrieHashIndexGenerator.h
+++ b/llvm/lib/Support/TrieHashIndexGenerator.h
@@ -45,18 +45,19 @@ struct IndexGenerator {
 
   // Get the index of the object in the next level of trie.
   size_t next() {
-    size_t Index;
     if (!StartBit) {
       // Compute index for root when StartBit is not set.
       StartBit = 0;
-      Index = getIndex(Bytes, *StartBit, NumRootBits);
-    } else {
+      return getIndex(Bytes, *StartBit, NumRootBits);
+    }
+    if (*StartBit < Bytes.size() * 8) {
       // Compute index for sub-trie.
       *StartBit += *StartBit ? NumSubtrieBits : NumRootBits;
       assert((*StartBit - NumRootBits) % NumSubtrieBits == 0);
-      Index = getIndex(Bytes, *StartBit, NumSubtrieBits);
+      return getIndex(Bytes, *StartBit, NumSubtrieBits);
     }
-    return Index;
+    // All the bits are consumed.
+    return end();
   }
 
   // Provide a hint to speed up the index generation by providing the
@@ -79,6 +80,8 @@ struct IndexGenerator {
     return getIndex(CollidingBits, *StartBit, NumSubtrieBits);
   }
 
+  size_t end() const { return SIZE_MAX; }
+
   // Compute the index for the object from its hash, current start bits, and
   // the number of bits used for current level.
   static size_t getIndex(ArrayRef<uint8_t> Bytes, size_t StartBit,
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index cf27a5caf409c2..41a44eda4cd662 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -224,7 +224,7 @@ ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
   TrieSubtrie *S = &Impl->Root;
   IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
   size_t Index = IndexGen.next();
-  while (true) {
+  while (Index != IndexGen.end()) {
     // Try to set the content.
     TrieNode *Existing = S->get(Index);
     if (!Existing)
@@ -239,6 +239,7 @@ ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
     Index = IndexGen.next();
     S = cast<TrieSubtrie>(Existing);
   }
+  llvm_unreachable("failed to locate the node after consuming all hash bytes");
 }
 
 ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
@@ -258,7 +259,7 @@ ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
     Index = IndexGen.next();
   }
 
-  while (true) {
+  while (Index != IndexGen.end()) {
     // Load the node from the slot, allocating and calling the constructor if
     // the slot is empty.
     bool Generated = false;
@@ -292,8 +293,8 @@ ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
       return PointerBase(ExistingContent.getValuePointer());
 
     // Sink the existing content as long as the indexes match.
-    while (true) {
-      size_t NextIndex = IndexGen.next();
+    size_t NextIndex = IndexGen.next();
+    while (NextIndex != IndexGen.end()) {
       size_t NewIndexForExistingContent =
           IndexGen.getCollidingBits(ExistingContent.getHash());
       S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
@@ -306,8 +307,11 @@ ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
       // Found the difference.
       if (NextIndex != NewIndexForExistingContent)
         break;
+
+      NextIndex = IndexGen.next();
     }
   }
+  llvm_unreachable("failed to insert the node after consuming all hash bytes");
 }
 
 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(

>From f853057d095fdb8131595ae62899f1a24ccf2e1c Mon Sep 17 00:00:00 2001
From: Steven Wu <stevenwu at apple.com>
Date: Tue, 22 Oct 2024 17:05:25 -0700
Subject: [PATCH 5/5] fixup! Use TrailingObjects implementation

---
 llvm/include/llvm/ADT/TrieRawHashMap.h |   2 +-
 llvm/lib/Support/TrieRawHashMap.cpp    | 105 +++++++++++++++----------
 2 files changed, 64 insertions(+), 43 deletions(-)

diff --git a/llvm/include/llvm/ADT/TrieRawHashMap.h b/llvm/include/llvm/ADT/TrieRawHashMap.h
index b7c4b92d1df307..5bf378b3c79715 100644
--- a/llvm/include/llvm/ADT/TrieRawHashMap.h
+++ b/llvm/include/llvm/ADT/TrieRawHashMap.h
@@ -171,7 +171,7 @@ class ThreadSafeTrieRawHashMapBase {
   const unsigned short ContentOffset;
   unsigned short NumRootBits;
   unsigned short NumSubtrieBits;
-  struct ImplType;
+  class ImplType;
   // ImplPtr is owned by ThreadSafeTrieRawHashMapBase and needs to be freed in
   // destoryImpl.
   std::atomic<ImplType *> ImplPtr;
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
index 41a44eda4cd662..f7106ca233add7 100644
--- a/llvm/lib/Support/TrieRawHashMap.cpp
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -13,6 +13,7 @@
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ThreadSafeAllocator.h"
+#include "llvm/Support/TrailingObjects.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
@@ -54,9 +55,16 @@ static_assert(sizeof(TrieContent) ==
                   ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
               "Check header assumption!");
 
-class TrieSubtrie final : public TrieNode {
+class TrieSubtrie final
+    : public TrieNode,
+      private TrailingObjects<TrieSubtrie, LazyAtomicPointer<TrieNode>> {
 public:
-  TrieNode *get(size_t I) const { return Slots[I].load(); }
+  using Slot = LazyAtomicPointer<TrieNode>;
+
+  Slot &get(size_t I) { return getTrailingObjects<Slot>()[I]; }
+  TrieNode *load(size_t I) { return get(I).load(); }
+
+  unsigned size() const { return Size; }
 
   TrieSubtrie *
   sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
@@ -68,6 +76,12 @@ class TrieSubtrie final : public TrieNode {
 
   static bool classof(const TrieNode *TN) { return TN->IsSubtrie; }
 
+  static constexpr size_t sizeToAlloc(unsigned NumBits) {
+    assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
+    size_t Count = 1u << NumBits;
+    return totalSizeToAlloc<LazyAtomicPointer<TrieNode>>(Count);
+  }
+
 private:
   // FIXME: Use a bitset to speed up access:
   //
@@ -91,38 +105,28 @@ class TrieSubtrie final : public TrieNode {
   // For debugging.
   unsigned StartBit = 0;
   unsigned NumBits = 0;
+  unsigned Size = 0;
   friend class llvm::ThreadSafeTrieRawHashMapBase;
+  friend class TrailingObjects;
 
 public:
   /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
   std::atomic<TrieSubtrie *> Next;
-
-  /// The (co-allocated) slots of the subtrie.
-  MutableArrayRef<LazyAtomicPointer<TrieNode>> Slots;
 };
 } // end namespace
 
-// Compute the trailing object size in the trie node. This is the size of \c
-// Slots in TrieNodes that pointing to the children.
-static size_t getTrieTailSize(size_t NumBits) {
-  return sizeof(TrieNode *) * (1u << NumBits);
-}
-
 std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
                                                  size_t NumBits) {
-  size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(NumBits);
-  void *Memory = ::malloc(Size);
+  void *Memory = ::malloc(sizeToAlloc(NumBits));
   TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
   return std::unique_ptr<TrieSubtrie>(S);
 }
 
 TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
-    : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Next(nullptr),
-      Slots(reinterpret_cast<LazyAtomicPointer<TrieNode> *>(
-                reinterpret_cast<char *>(this) + sizeof(TrieSubtrie)),
-            (1u << NumBits)) {
-  for (auto *I = Slots.begin(), *E = Slots.end(); I != E; ++I)
-    new (I) LazyAtomicPointer<TrieNode>(nullptr);
+    : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Size(1u << NumBits),
+      Next(nullptr) {
+  for (unsigned I = 0; I < Size; ++I)
+    new (&get(I)) Slot(nullptr);
 
   static_assert(
       std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
@@ -140,14 +144,14 @@ TrieSubtrie *TrieSubtrie::sink(
   assert(NumSubtrieBits > 0);
   std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);
 
-  assert(NewI < S->Slots.size());
-  S->Slots[NewI].store(&Content);
+  assert(NewI < Size);
+  S->get(NewI).store(&Content);
 
   // Using compare_exchange to atomically add back the new sub-trie to the trie
   // in the place of the exsiting object.
   TrieNode *ExistingNode = &Content;
-  assert(I < Slots.size());
-  if (Slots[I].compare_exchange_strong(ExistingNode, S.get()))
+  assert(I < Size);
+  if (get(I).compare_exchange_strong(ExistingNode, S.get()))
     return Saver(std::move(S));
 
   // Another thread created a subtrie already. Return it and let "S" be
@@ -155,9 +159,12 @@ TrieSubtrie *TrieSubtrie::sink(
   return cast<TrieSubtrie>(ExistingNode);
 }
 
-struct ThreadSafeTrieRawHashMapBase::ImplType {
+class ThreadSafeTrieRawHashMapBase::ImplType final
+    : private TrailingObjects<ThreadSafeTrieRawHashMapBase::ImplType,
+                              TrieSubtrie> {
+public:
   static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) {
-    size_t Size = sizeof(ImplType) + getTrieTailSize(NumBits);
+    size_t Size = sizeof(ImplType) + TrieSubtrie::sizeToAlloc(NumBits);
     void *Memory = ::malloc(Size);
     ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits);
     return std::unique_ptr<ImplType>(Impl);
@@ -174,13 +181,16 @@ struct ThreadSafeTrieRawHashMapBase::ImplType {
     // Root.Next. This works by repeatedly setting S->Next to a candidate value
     // of Root.Next (initially nullptr), then setting Root.Next to S once the
     // candidate matches reality.
-    while (!Root.Next.compare_exchange_weak(CurrentHead, S.get()))
+    while (!getRoot()->Next.compare_exchange_weak(CurrentHead, S.get()))
       S->Next.exchange(CurrentHead);
 
     // Ownership transferred to subtrie successfully. Release the unique_ptr.
     return S.release();
   }
 
+  // Get the root which is the trailing object.
+  TrieSubtrie *getRoot() { return getTrailingObjects<TrieSubtrie>(); }
+
   static void *operator new(size_t Size) { return ::malloc(Size); }
   void operator delete(void *Ptr) { ::free(Ptr); }
 
@@ -188,10 +198,13 @@ struct ThreadSafeTrieRawHashMapBase::ImplType {
   /// content lazily (taking the hash as a separate parameter), in case of
   /// collision.
   ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc;
-  TrieSubtrie Root; // Must be last! Tail-allocated.
 
 private:
-  ImplType(size_t StartBit, size_t NumBits) : Root(StartBit, NumBits) {}
+  friend class TrailingObjects;
+
+  ImplType(size_t StartBit, size_t NumBits) {
+    ::new (getRoot()) TrieSubtrie(StartBit, NumBits);
+  }
 };
 
 ThreadSafeTrieRawHashMapBase::ImplType &
@@ -221,7 +234,7 @@ ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
   if (!Impl)
     return PointerBase();
 
-  TrieSubtrie *S = &Impl->Root;
+  TrieSubtrie *S = Impl->getRoot();
   IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
   size_t Index = IndexGen.next();
   while (Index != IndexGen.end()) {
@@ -249,7 +262,7 @@ ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
   assert(!Hash.empty() && "Uninitialized hash");
 
   ImplType &Impl = getOrCreateImpl();
-  TrieSubtrie *S = &Impl.Root;
+  TrieSubtrie *S = Impl.getRoot();
   IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
   size_t Index;
   if (Hint.isHint()) {
@@ -263,7 +276,7 @@ ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
     // Load the node from the slot, allocating and calling the constructor if
     // the slot is empty.
     bool Generated = false;
-    TrieNode &Existing = S->Slots[Index].loadOrGenerate([&]() {
+    TrieNode &Existing = S->get(Index).loadOrGenerate([&]() {
       Generated = true;
 
       // Construct the value itself at the tail.
@@ -321,7 +334,15 @@ ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
       ContentOffset(ContentOffset),
       NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
       NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
-      ImplPtr(nullptr) {}
+      ImplPtr(nullptr) {
+  // Assertion checks for reasonable configuration. The settings below are not
+  // hard limits on most platforms, but a reasonable configuration should fall
+  // within those limits.
+  assert((!NumRootBits || *NumRootBits < 20) &&
+         "Root should have fewer than ~1M slots");
+  assert((!NumSubtrieBits || *NumSubtrieBits < 10) &&
+         "Subtries should have fewer than ~1K slots");
+}
 
 ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
     ThreadSafeTrieRawHashMapBase &&RHS)
@@ -349,14 +370,14 @@ void ThreadSafeTrieRawHashMapBase::destroyImpl(
   // FIXME: Once we have bitsets (see FIXME in TrieSubtrie class), use them
   // facilitate sparse iteration here.
   if (Destructor)
-    for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load())
-      for (auto &Slot : Trie->Slots)
-        if (auto *Content = dyn_cast_or_null<TrieContent>(Slot.load()))
+    for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load())
+      for (unsigned I = 0; I < Trie->size(); ++I)
+        if (auto *Content = dyn_cast_or_null<TrieContent>(Trie->load(I)))
           Destructor(Content->getValuePointer());
 
   // Destroy the subtries. Incidentally, this destroys them in the reverse order
   // of saving.
-  TrieSubtrie *Trie = Impl->Root.Next;
+  TrieSubtrie *Trie = Impl->getRoot()->Next;
   while (Trie) {
     TrieSubtrie *Next = Trie->Next.exchange(nullptr);
     delete Trie;
@@ -369,7 +390,7 @@ ThreadSafeTrieRawHashMapBase::getRoot() const {
   ImplType *Impl = ImplPtr.load();
   if (!Impl)
     return PointerBase();
-  return PointerBase(&Impl->Root);
+  return PointerBase(Impl->getRoot());
 }
 
 unsigned ThreadSafeTrieRawHashMapBase::getStartBit(
@@ -401,8 +422,8 @@ unsigned ThreadSafeTrieRawHashMapBase::getNumSlotUsed(
   if (!S)
     return 0;
   unsigned Num = 0;
-  for (unsigned I = 0, E = S->Slots.size(); I < E; ++I)
-    if (auto *E = S->Slots[I].load())
+  for (unsigned I = 0, E = S->size(); I < E; ++I)
+    if (auto *E = S->load(I))
       ++Num;
   return Num;
 }
@@ -424,8 +445,8 @@ std::string ThreadSafeTrieRawHashMapBase::getTriePrefixAsString(
   while (Current) {
     TrieSubtrie *Next = nullptr;
     // Find first used slot in the trie.
-    for (unsigned I = 0, E = Current->Slots.size(); I < E; ++I) {
-      auto *S = Current->get(I);
+    for (unsigned I = 0, E = Current->size(); I < E; ++I) {
+      auto *S = Current->load(I);
       if (!S)
         continue;
 
@@ -473,7 +494,7 @@ unsigned ThreadSafeTrieRawHashMapBase::getNumTries() const {
   if (!Impl)
     return 0;
   unsigned Num = 0;
-  for (TrieSubtrie *Trie = &Impl->Root; Trie; Trie = Trie->Next.load())
+  for (TrieSubtrie *Trie = Impl->getRoot(); Trie; Trie = Trie->Next.load())
     ++Num;
   return Num;
 }



More information about the llvm-commits mailing list