[llvm] [ADT] Add TrieRawHashMap (PR #69528)

Duncan P. N. Exon Smith via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 8 11:24:41 PST 2023


================
@@ -0,0 +1,483 @@
+//===- TrieRawHashMap.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/TrieRawHashMap.h"
+#include "TrieHashIndexGenerator.h"
+#include "llvm/ADT/LazyAtomicPointer.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ThreadSafeAllocator.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+struct TrieNode { // Common base of TrieContent (leaf) and TrieSubtrie (interior).
+  const bool IsSubtrie = false; // Discriminator read by classof() in both subclasses.
+
+  TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}
+
+  static void *operator new(size_t Size) { return ::malloc(Size); } // Paired with the free() in operator delete.
+  void operator delete(void *Ptr) { ::free(Ptr); } // Also releases the malloc() storage from TrieSubtrie::create().
+};
+
+struct TrieContent final : public TrieNode { // Leaf header; value and hash are found via byte offsets from `this`.
+  const uint8_t ContentOffset; // Byte offset from `this` to the stored value.
+  const uint8_t HashSize; // Length of the hash in bytes.
+  const uint8_t HashOffset; // Byte offset from `this` to the first hash byte.
+
+  void *getValuePointer() const {
+    auto *Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
+    return const_cast<uint8_t *>(Content); // Constness is dropped; the result is returned as void *.
+  }
+
+  ArrayRef<uint8_t> getHash() const {
+    auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
+    return ArrayRef(Begin, Begin + HashSize); // View of the hash bytes; nothing is copied.
+  }
+
+  TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset) // NOTE(review): size_t arguments are narrowed to uint8_t with no range check.
+      : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
+        HashSize(HashSize), HashOffset(HashOffset) {}
+
+  static bool classof(const TrieNode *TN) { return !TN->IsSubtrie; } // Enables isa/cast/dyn_cast<TrieContent>.
+};
+
+static_assert(sizeof(TrieContent) ==
+                  ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
+              "Check header assumption!"); // Presumably the header hard-codes this size; confirm in TrieRawHashMap.h.
+
+class TrieSubtrie final : public TrieNode { // Interior node: a header followed by co-allocated slots.
+public:
+  TrieNode *get(size_t I) const { return Slots[I].load(); } // Loads slot I; nullptr means the slot is still empty.
+
+  TrieSubtrie *
+  sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+       function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver); // Replaces Content in slot I with a new child subtrie holding it.
+
+  static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits); // Allocates the header and its slots in one block.
+
+  explicit TrieSubtrie(size_t StartBit, size_t NumBits); // Requires tail storage for the slots; see create().
+
+  static bool classof(const TrieNode *TN) { return TN->IsSubtrie; } // Enables isa/cast/dyn_cast<TrieSubtrie>.
+
+private:
+  // FIXME: Use a bitset to speed up access:
+  //
+  //     std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
+  //
+  // This will avoid needing to visit sparsely filled slots in
+  // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial
+  // destructor.
+  //
+  // It would also greatly speed up iteration, if we add that some day, and
+  // allow get() to return one level sooner.
+  //
+  // This would be the algorithm for updating IsSet (after updating Slots):
+  //
+  //     std::atomic<uint64_t> &Bits = IsSet[I.High];
+  //     const uint64_t NewBit = 1ULL << I.Low;
+  //     uint64_t Old = 0;
+  //     while (!Bits.compare_exchange_weak(Old, Old | NewBit))
+  //       ;
+
+  // Set at construction. Not only for debugging: sink() uses
+  // StartBit + NumBits as the start bit of the child it creates.
+  unsigned StartBit = 0;
+  unsigned NumBits = 0; // The constructor sizes Slots to 1 << NumBits entries.
+  friend class llvm::ThreadSafeTrieRawHashMapBase;
+
+public:
+  /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
+  std::atomic<TrieSubtrie *> Next;
+
+  /// The (co-allocated) slots of the subtrie.
+  MutableArrayRef<LazyAtomicPointer<TrieNode>> Slots;
+};
+} // end namespace
+
+static size_t getTrieTailSize(size_t StartBit, size_t NumBits) { // Bytes of tail storage for 2^NumBits slots; StartBit is unused.
+  assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
+  return sizeof(TrieNode *) * (1u << NumBits); // NOTE(review): assumes a LazyAtomicPointer<TrieNode> slot is pointer-sized.
+}
+
+std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
+                                                 size_t NumBits) {
+  size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits); // Header plus the slots that follow it.
+  void *Memory = ::malloc(Size); // NOTE(review): the result is not checked for nullptr.
+  TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits); // Placement-new; the constructor initializes the tail slots.
+  return std::unique_ptr<TrieSubtrie>(S); // Deleting goes through TrieNode::operator delete, i.e. free().
+}
+
+TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits) // Callers must provide storage for the slots directly after the object.
+    : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Next(nullptr),
+      Slots(reinterpret_cast<LazyAtomicPointer<TrieNode> *>(
+                reinterpret_cast<char *>(this) + sizeof(TrieSubtrie)), // Slots begin at the first byte past the header.
+            (1u << NumBits)) {
+  for (auto *I = Slots.begin(), *E = Slots.end(); I != E; ++I)
+    new (I) LazyAtomicPointer<TrieNode>(nullptr); // Construct every slot in place, initially empty.
+
+  static_assert(
+      std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
+      "Expected no work in destructor for TrieNode"); // The slots are placement-constructed here; no matching destructor call is visible.
+}
+
+TrieSubtrie *TrieSubtrie::sink( // Replaces Content in slot I with a new child subtrie that holds Content at NewI.
+    size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+    function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
+  assert(NumSubtrieBits > 0);
+  std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits); // The child starts where this subtrie's bits end.
+
+  assert(NewI < S->Slots.size());
+  S->Slots[NewI].store(&Content); // Fill the child before publishing it.
+
+  TrieNode *ExistingNode = &Content; // Expected current value of slot I.
+  assert(I < Slots.size());
+  if (Slots[I].compare_exchange_strong(ExistingNode, S.get())) // Publish only if slot I still holds Content.
+    return Saver(std::move(S)); // Pass ownership to Saver and return the pointer it gives back.
+
+  // Another thread created a subtrie already. Return it and let "S" be
+  // destructed.
+  return cast<TrieSubtrie>(ExistingNode); // The failed exchange left the slot's current value in ExistingNode.
+}
+
+struct ThreadSafeTrieRawHashMapBase::ImplType { // Holds the content allocator and the tail-allocated root subtrie.
+  static std::unique_ptr<ImplType> create(size_t StartBit, size_t NumBits) {
+    size_t Size = sizeof(ImplType) + getTrieTailSize(StartBit, NumBits); // Root's slots follow the ImplType object.
+    void *Memory = ::malloc(Size); // NOTE(review): the result is not checked for nullptr.
+    ImplType *Impl = ::new (Memory) ImplType(StartBit, NumBits);
+    return std::unique_ptr<ImplType>(Impl); // Deleting goes through the operator delete below, i.e. free().
+  }
+
+  TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) { // Takes ownership of S by pushing it onto the list headed at Root.Next.
+    assert(!S->Next && "Expected S to a freshly-constructed leaf");
+
+    TrieSubtrie *CurrentHead = nullptr; // Candidate for the current value of Root.Next.
+    // Add ownership of "S" to front of the list, so that Root -> S ->
+    // Root.Next. This works by repeatedly setting S->Next to a candidate value
+    // of Root.Next (initially nullptr), then setting Root.Next to S once the
+    // candidate matches reality.
+    while (!Root.Next.compare_exchange_weak(CurrentHead, S.get()))
+      S->Next.exchange(CurrentHead); // A failed exchange refreshed CurrentHead; link S to it and retry.
+
+    // Ownership transferred to subtrie successfully. Release the unique_ptr.
+    return S.release();
+  }
+
+  static void *operator new(size_t Size) { return ::malloc(Size); }
+  void operator delete(void *Ptr) { ::free(Ptr); } // Matches the malloc() in create().
+
+  /// FIXME: This should take a function that allocates and constructs the
+  /// content lazily (taking the hash as a separate parameter), in case of
+  /// collision.
+  ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc; // insert() allocates TrieContent storage from this.
+  TrieSubtrie Root; // Must be last! Tail-allocated.
+
+private:
+  ImplType(size_t StartBit, size_t NumBits) : Root(StartBit, NumBits) {} // Private: create() supplies the tail storage Root needs.
+};
+
+ThreadSafeTrieRawHashMapBase::ImplType &
+ThreadSafeTrieRawHashMapBase::getOrCreateImpl() { // Returns the ImplType in ImplPtr, creating and installing one if it is null.
+  if (ImplType *Impl = ImplPtr.load())
+    return *Impl; // Already created.
+
+  // Create a new ImplType and store it if another thread doesn't do so first.
+  // If another thread wins, this one is destroyed locally.
+  std::unique_ptr<ImplType> Impl = ImplType::create(0, NumRootBits);
+  ImplType *ExistingImpl = nullptr; // Expected value: nothing installed yet.
+
+  // If the ownership transferred successfully, release unique_ptr and return
+  // the pointer to the new ImplType.
+  if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get()))
+    return *Impl.release();
+
+  // Already created, return the existing ImplType.
+  return *ExistingImpl; // The failed exchange stored the installed pointer here.
+}
+
+ThreadSafeTrieRawHashMapBase::PointerBase
+ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const { // Read-only lookup of Hash; never stores to the trie.
+  assert(!Hash.empty() && "Uninitialized hash");
+
+  ImplType *Impl = ImplPtr.load();
+  if (!Impl)
+    return PointerBase(); // No ImplType installed yet, so nothing to find.
+
+  TrieSubtrie *S = &Impl->Root;
+  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+  size_t Index = IndexGen.next();
+  while (true) {
+    // Load the slot. If it is empty, Hash is absent: return the position.
+    TrieNode *Existing = S->get(Index);
+    if (!Existing)
+      return PointerBase(S, Index, *IndexGen.StartBit); // Subtrie, slot and start bit; presumably the hint insert() accepts.
+
+    // Check for an exact match.
+    if (auto *ExistingContent = dyn_cast<TrieContent>(Existing))
+      return ExistingContent->getHash() == Hash
+                 ? PointerBase(ExistingContent->getValuePointer())
+                 : PointerBase(S, Index, *IndexGen.StartBit); // A different hash occupies the slot.
+
+    Index = IndexGen.next(); // The slot holds a subtrie: descend one level.
+    S = cast<TrieSubtrie>(Existing);
+  }
+}
+
+ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert( // Returns the value stored for Hash, constructing it if absent.
+    PointerBase Hint, ArrayRef<uint8_t> Hash,
+    function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
+        Constructor) {
+  assert(!Hash.empty() && "Uninitialized hash");
+
+  ImplType &Impl = getOrCreateImpl(); // Creates the root on first use.
+  TrieSubtrie *S = &Impl.Root;
+  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+  size_t Index;
+  if (Hint.isHint()) {
+    S = static_cast<TrieSubtrie *>(Hint.P); // Start at the subtrie named by the hint instead of the root.
+    Index = IndexGen.hint(Hint.I, Hint.B);
+  } else {
+    Index = IndexGen.next();
+  }
+
+  while (true) {
+    // Load the node from the slot, allocating and calling the constructor if
+    // the slot is empty.
+    bool Generated = false; // Set only if the generator lambda below runs.
+    TrieNode &Existing = S->Slots[Index].loadOrGenerate([&]() {
+      Generated = true;
+
+      // Allocate the node, then construct the value ContentOffset bytes in.
+      uint8_t *Memory = reinterpret_cast<uint8_t *>(
+          Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign));
+      const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash); // Constructor reports where it stored the hash.
+
+      // Construct the TrieContent header, passing in the offset to the hash.
+      TrieContent *Content = ::new (Memory)
+          TrieContent(ContentOffset, Hash.size(), HashStorage - Memory);
+      assert(Hash == Content->getHash() && "Hash not properly initialized");
+      return Content;
+    });
+    // If we just generated it, return it!
+    if (Generated)
+      return PointerBase(cast<TrieContent>(Existing).getValuePointer());
+
+    if (auto *ST = dyn_cast<TrieSubtrie>(&Existing)) { // The slot holds a subtrie: descend and retry.
+      S = ST;
+      Index = IndexGen.next();
+      continue;
+    }
+
+    // Return the existing content if it's an exact match!
+    auto &ExistingContent = cast<TrieContent>(Existing);
+    if (ExistingContent.getHash() == Hash)
+      return PointerBase(ExistingContent.getValuePointer());
+
+    // Sink the existing content as long as the indexes match.
+    while (true) {
+      size_t NextIndex = IndexGen.next(); // Index of the new hash one level down.
+      size_t NewIndexForExistingContent =
+          IndexGen.getCollidingBits(ExistingContent.getHash()); // Index of the existing hash at that level.
+      S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
+                  NewIndexForExistingContent,
+                  [&Impl](std::unique_ptr<TrieSubtrie> S) {
+                    return Impl.save(std::move(S)); // Impl takes ownership of the new subtrie.
+                  });
+      Index = NextIndex;
+
+      // Found the difference.
+      if (NextIndex != NewIndexForExistingContent)
+        break; // The outer loop retries at the now-distinct slot.
+    }
+  }
+}
+
+ThreadSafeTrieRawHashMapBase::ThreadSafeTrieRawHashMapBase(
+    size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
+    std::optional<size_t> NumRootBits, std::optional<size_t> NumSubtrieBits)
+    : ContentAllocSize(ContentAllocSize), ContentAllocAlign(ContentAllocAlign),
+      ContentOffset(ContentOffset),
+      NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
+      NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
+      ImplPtr(nullptr) {
+  assert((!NumRootBits || *NumRootBits < 20) &&
----------------
dexonsmith wrote:

The point is to catch misuse, not logic errors, so it's okay for the assertion to be relaxed if there's another reasonable limit.

The limit of 20 bits ensures the root's allocation is under 8MB (1M entries). On some platforms, larger contiguous allocations might fail. An assertion could catch platform-dependent memory failures.

Certainly, I imagine a limit of 29 would be palatable? (500M entries, 4GB allocation for the root node)

In any case, it feels like someone is holding the trie "wrong" for root nodes this big. Going bigger is certainly valid in some sense, but I question the benefit of supporting it. If you really want to optimize for super large data sets, you probably want a file-backed trie instead.

I prefer the original limit of 20, unless someone has a specific argument for something else. (I don't feel strongly about what the limit is, just that there should be one.)

https://github.com/llvm/llvm-project/pull/69528


More information about the llvm-commits mailing list