[llvm] [CAS] LLVMCAS implementation (PR #68448)
David Blaikie via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 9 12:22:19 PDT 2023
================
@@ -0,0 +1,493 @@
+//===- TrieRawHashMap.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/TrieRawHashMap.h"
+#include "TrieHashIndexGenerator.h"
+#include "llvm/ADT/LazyAtomicPointer.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ThreadSafeAllocator.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+struct TrieNode {
+ const bool IsSubtrie = false;
+
+ TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}
+
+ static void *operator new(size_t Size) { return ::malloc(Size); }
+ void operator delete(void *Ptr) { ::free(Ptr); }
+};
+
+struct TrieContent final : public TrieNode {
+ const uint8_t ContentOffset;
+ const uint8_t HashSize;
+ const uint8_t HashOffset;
+
+ void *getValuePointer() const {
+ auto Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
+ return const_cast<uint8_t *>(Content);
+ }
+
+ ArrayRef<uint8_t> getHash() const {
+ auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
+ return ArrayRef(Begin, Begin + HashSize);
+ }
+
+ TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset)
+ : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
+ HashSize(HashSize), HashOffset(HashOffset) {}
+};
+static_assert(sizeof(TrieContent) ==
+ ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
+ "Check header assumption!");
+
+class TrieSubtrie final : public TrieNode {
+public:
+ TrieNode *get(size_t I) const { return Slots[I].load(); }
+
+ TrieSubtrie *
+ sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+ function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver);
+
+ static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits);
+
+ explicit TrieSubtrie(size_t StartBit, size_t NumBits);
+
+private:
+ // FIXME: Use a bitset to speed up access:
+ //
+ // std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
+ //
+ // This will avoid needing to visit sparsely filled slots in
+ // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial
+ // destructor.
+ //
+ // It would also greatly speed up iteration, if we add that some day, and
+ // allow get() to return one level sooner.
+ //
+ // This would be the algorithm for updating IsSet (after updating Slots):
+ //
+ // std::atomic<uint64_t> &Bits = IsSet[I.High];
+ // const uint64_t NewBit = 1ULL << I.Low;
+ // uint64_t Old = 0;
+ // while (!Bits.compare_exchange_weak(Old, Old | NewBit))
+ // ;
+
+ // For debugging.
+ unsigned StartBit = 0;
+ unsigned NumBits = 0;
+ friend class llvm::ThreadSafeTrieRawHashMapBase;
+
+public:
+ /// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
+ std::atomic<TrieSubtrie *> Next;
+
+ /// The (co-allocated) slots of the subtrie.
+ MutableArrayRef<LazyAtomicPointer<TrieNode>> Slots;
+};
+} // end namespace
+
+namespace llvm {
+template <> struct isa_impl<TrieContent, TrieNode> {
+ static inline bool doit(const TrieNode &TN) { return !TN.IsSubtrie; }
+};
+template <> struct isa_impl<TrieSubtrie, TrieNode> {
+ static inline bool doit(const TrieNode &TN) { return TN.IsSubtrie; }
+};
+} // end namespace llvm
+
+static size_t getTrieTailSize(size_t StartBit, size_t NumBits) {
+ assert(NumBits < 20 && "Tries should have fewer than ~1M slots");
+ return sizeof(TrieNode *) * (1u << NumBits);
+}
+
+std::unique_ptr<TrieSubtrie> TrieSubtrie::create(size_t StartBit,
+ size_t NumBits) {
+ size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits);
+ void *Memory = ::malloc(Size);
+ TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
+ return std::unique_ptr<TrieSubtrie>(S);
+}
+
+TrieSubtrie::TrieSubtrie(size_t StartBit, size_t NumBits)
+ : TrieNode(true), StartBit(StartBit), NumBits(NumBits), Next(nullptr),
+ Slots(reinterpret_cast<LazyAtomicPointer<TrieNode> *>(
+ reinterpret_cast<char *>(this) + sizeof(TrieSubtrie)),
+ (1u << NumBits)) {
+ for (auto *I = Slots.begin(), *E = Slots.end(); I != E; ++I)
+ new (I) LazyAtomicPointer<TrieNode>(nullptr);
+
+ static_assert(
+ std::is_trivially_destructible<LazyAtomicPointer<TrieNode>>::value,
+ "Expected no work in destructor for TrieNode");
+}
+
/// Pushes \p Content one level deeper.
///
/// The function creates a child subtrie with \p NumSubtrieBits index bits and
/// places \p Content in the child's slot \p NewI. It then tries to atomically
/// replace \p Content in this subtrie's slot \p I with the child.
///
/// On success, the child is passed to \p Saver, which takes ownership, and
/// Saver's result is returned. If the compare-exchange fails, another thread
/// already replaced the slot with a subtrie. That subtrie is returned and the
/// local child is destroyed.
TrieSubtrie *TrieSubtrie::sink(
    size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
    function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver) {
  assert(NumSubtrieBits > 0);
  // The child's index bits start where this subtrie's bits end.
  std::unique_ptr<TrieSubtrie> S = create(StartBit + NumBits, NumSubtrieBits);

  // Fill in the child before it is published by the compare-exchange below.
  assert(NewI < S->Slots.size());
  S->Slots[NewI].store(&Content);

  // Only replace the slot if it still holds Content.
  TrieNode *ExistingNode = &Content;
  assert(I < Slots.size());
  if (Slots[I].compare_exchange_strong(ExistingNode, S.get()))
    return Saver(std::move(S));

  // Another thread created a subtrie already. Return it and let "S" be
  // destructed.
  return cast<TrieSubtrie>(ExistingNode);
}
+
+struct ThreadSafeTrieRawHashMapBase::ImplType {
+ static ImplType *create(size_t StartBit, size_t NumBits) {
+ size_t Size = sizeof(ImplType) + getTrieTailSize(StartBit, NumBits);
+ void *Memory = ::malloc(Size);
+ return ::new (Memory) ImplType(StartBit, NumBits);
+ }
+
+ TrieSubtrie *save(std::unique_ptr<TrieSubtrie> S) {
+ assert(!S->Next && "Expected S to a freshly-constructed leaf");
+
+ TrieSubtrie *CurrentHead = nullptr;
+ // Add ownership of "S" to front of the list, so that Root -> S ->
+ // Root.Next. This works by repeatedly setting S->Next to a candidate value
+ // of Root.Next (initially nullptr), then setting Root.Next to S once the
+ // candidate matches reality.
+ while (!Root.Next.compare_exchange_weak(CurrentHead, S.get()))
+ S->Next.exchange(CurrentHead);
+
+ // Ownership transferred to subtrie.
+ return S.release();
+ }
+
+ static void *operator new(size_t Size) { return ::malloc(Size); }
+ void operator delete(void *Ptr) { ::free(Ptr); }
+
+ /// FIXME: This should take a function that allocates and constructs the
+ /// content lazily (taking the hash as a separate parameter), in case of
+ /// collision.
+ ThreadSafeAllocator<BumpPtrAllocator> ContentAlloc;
+ TrieSubtrie Root; // Must be last! Tail-allocated.
+
+private:
+ ImplType(size_t StartBit, size_t NumBits) : Root(StartBit, NumBits) {}
+};
+
+ThreadSafeTrieRawHashMapBase::ImplType &
+ThreadSafeTrieRawHashMapBase::getOrCreateImpl() {
+ if (ImplType *Impl = ImplPtr.load())
+ return *Impl;
+
+ // Create a new ImplType and store it if another thread doesn't do so first.
+ // If another thread wins this one is destroyed locally.
+ std::unique_ptr<ImplType> Impl(ImplType::create(0, NumRootBits));
+ ImplType *ExistingImpl = nullptr;
+ if (ImplPtr.compare_exchange_strong(ExistingImpl, Impl.get()))
+ return *Impl.release();
+
+ return *ExistingImpl;
+}
+
+ThreadSafeTrieRawHashMapBase::PointerBase
+ThreadSafeTrieRawHashMapBase::find(ArrayRef<uint8_t> Hash) const {
+ assert(!Hash.empty() && "Uninitialized hash");
+
+ ImplType *Impl = ImplPtr.load();
+ if (!Impl)
+ return PointerBase();
+
+ TrieSubtrie *S = &Impl->Root;
+ IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
+ size_t Index = IndexGen.next();
+ for (;;) {
+ // Try to set the content.
+ TrieNode *Existing = S->get(Index);
+ if (!Existing)
+ return PointerBase(S, Index, *IndexGen.StartBit);
+
+ // Check for an exact match.
+ if (auto *ExistingContent = dyn_cast<TrieContent>(Existing))
+ return ExistingContent->getHash() == Hash
+ ? PointerBase(ExistingContent->getValuePointer())
+ : PointerBase(S, Index, *IndexGen.StartBit);
+
+ Index = IndexGen.next();
+ S = cast<TrieSubtrie>(Existing);
+ }
+}
+
/// Inserts a value for \p Hash, or returns the existing value for it.
///
/// The walk starts at the subtrie and index recorded in \p Hint when
/// Hint.isHint() is true (presumably a position returned by find() for a
/// missing key; confirm with callers). Otherwise it starts at the root. At
/// each slot:
///   - Empty: allocate ContentAllocSize bytes from ContentAlloc and call
///     \p Constructor to build the value at ContentOffset. Constructor returns
///     where it stored the hash bytes. Then construct the TrieContent header at
///     the start of the allocation and return the new value.
///   - Subtrie: descend into it.
///   - Content with the same hash: return its value.
///   - Content with a different hash: sink the existing content into new
///     subtries until its index diverges from ours, then retry at that level.
ThreadSafeTrieRawHashMapBase::PointerBase ThreadSafeTrieRawHashMapBase::insert(
    PointerBase Hint, ArrayRef<uint8_t> Hash,
    function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
        Constructor) {
  assert(!Hash.empty() && "Uninitialized hash");

  ImplType &Impl = getOrCreateImpl();
  TrieSubtrie *S = &Impl.Root;
  IndexGenerator IndexGen{NumRootBits, NumSubtrieBits, Hash};
  size_t Index;
  if (Hint.isHint()) {
    // Resume from the hinted subtrie and position instead of the root.
    S = static_cast<TrieSubtrie *>(Hint.P);
    Index = IndexGen.hint(Hint.I, Hint.B);
  } else {
    Index = IndexGen.next();
  }

  for (;;) {
    // Load the node from the slot, allocating and calling the constructor if
    // the slot is empty.
    bool Generated = false;
    TrieNode &Existing = S->Slots[Index].loadOrGenerate([&]() {
      Generated = true;

      // Construct the value itself at the tail.
      uint8_t *Memory = reinterpret_cast<uint8_t *>(
          Impl.ContentAlloc.Allocate(ContentAllocSize, ContentAllocAlign));
      const uint8_t *HashStorage = Constructor(Memory + ContentOffset, Hash);

      // Construct the TrieContent header, passing in the offset to the hash.
      TrieContent *Content = ::new (Memory)
          TrieContent(ContentOffset, Hash.size(), HashStorage - Memory);
      assert(Hash == Content->getHash() && "Hash not properly initialized");
      return Content;
    });
    // If we just generated it, return it!
    if (Generated)
      return PointerBase(cast<TrieContent>(Existing).getValuePointer());

    // An interior node: move down one level and continue the walk.
    if (isa<TrieSubtrie>(Existing)) {
      S = &cast<TrieSubtrie>(Existing);
      Index = IndexGen.next();
      continue;
    }

    // Return the existing content if it's an exact match!
    auto &ExistingContent = cast<TrieContent>(Existing);
    if (ExistingContent.getHash() == Hash)
      return PointerBase(ExistingContent.getValuePointer());

    // Sink the existing content as long as the indexes match.
    for (;;) {
      // Compute our index and the existing content's index at the next level.
      size_t NextIndex = IndexGen.next();
      size_t NewIndexForExistingContent =
          IndexGen.getCollidingBits(ExistingContent.getHash());
      // Move the existing content into a new child subtrie. S becomes that
      // child, or the subtrie another thread installed first.
      S = S->sink(Index, ExistingContent, IndexGen.getNumBits(),
                  NewIndexForExistingContent,
                  [&Impl](std::unique_ptr<TrieSubtrie> S) {
                    return Impl.save(std::move(S));
                  });
      Index = NextIndex;

      // Found the difference.
      if (NextIndex != NewIndexForExistingContent)
        break;
    }
  }
}
+
+static void printHexDigit(raw_ostream &OS, uint8_t Digit) {
+ if (Digit < 10)
+ OS << char(Digit + '0');
+ else
+ OS << char(Digit - 10 + 'a');
+}
----------------
dwblaikie wrote:
I think we already have abstractions for printing hex numbers - perhaps we could use those? (And they print whole hex numbers, which might remove the need for some of the `printPrefix` code too.)
https://github.com/llvm/llvm-project/pull/68448
More information about the llvm-commits mailing list