[llvm] [ADT][ConcurrentHashTable] Refactor ConcurrentHashTable. (PR #71932)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 12 14:28:17 PST 2023
https://github.com/avl-llvm updated https://github.com/llvm/llvm-project/pull/71932
>From 8b2dfd1e56e8507ffc8dd7c3c2136857bdfbd098 Mon Sep 17 00:00:00 2001
From: Alexey Lapshin <a.v.lapshin at mail.ru>
Date: Thu, 9 Nov 2023 14:48:13 +0100
Subject: [PATCH] [ADT][ConcurrentHashTable] Refactor ConcurrentHashTable.
This patch adds the following improvements to ConcurrentHashTable:
1. Hashtable is optimized for the case when read operations exceed
write operations. It uses std::shared_mutex for synchronization
now.
2. The type of the mutex may be changed. This allows the use of more
efficient synchronization primitives.
3. Separate implementation for non-aggregate types is added.
---
llvm/include/llvm/ADT/ConcurrentHashtable.h | 831 ++++++++++++------
.../unittests/ADT/ConcurrentHashtableTest.cpp | 271 ++++++
2 files changed, 853 insertions(+), 249 deletions(-)
diff --git a/llvm/include/llvm/ADT/ConcurrentHashtable.h b/llvm/include/llvm/ADT/ConcurrentHashtable.h
index ffbeece1a89345f..215c64a66a516f8 100644
--- a/llvm/include/llvm/ADT/ConcurrentHashtable.h
+++ b/llvm/include/llvm/ADT/ConcurrentHashtable.h
@@ -16,24 +16,21 @@
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"
+#include "llvm/Support/RWMutex.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/xxhash.h"
#include <atomic>
#include <cstddef>
#include <iomanip>
-#include <mutex>
#include <sstream>
#include <type_traits>
namespace llvm {
-/// ConcurrentHashTable - is a resizeable concurrent hashtable.
-/// The number of resizings limited up to x2^31. This hashtable is
-/// useful to have efficient access to aggregate data(like strings,
-/// type descriptors...) and to keep only single copy of such
-/// an aggregate. The hashtable allows only concurrent insertions:
+/// This file contains an implementation of a resizeable concurrent hashtable.
+/// The hashtable allows only concurrent insertions:
///
-/// KeyDataTy* = insert ( const KeyTy& );
+/// std::pair<DataTy, bool> = insert ( const KeyTy& );
///
/// Data structure:
///
@@ -53,343 +50,679 @@ namespace llvm {
/// Different buckets may have different sizes. If the single bucket is full
/// then the bucket is resized.
///
-/// BucketsArray keeps all buckets. Each bucket keeps an array of Entries
-/// (pointers to KeyDataTy) and another array of entries hashes:
+/// ConcurrentHashTableBase is a base implementation which encapsulates
+/// common operations. The implementation assumes that stored data are
+/// POD. It uses special markers for uninitialised values. An uninitialised
+/// value is either zero or 0xff (depending on the ZeroIsUndefValue
+/// parameter).
///
-/// BucketsArray[BucketIdx].Hashes[EntryIdx]:
-/// BucketsArray[BucketIdx].Entries[EntryIdx]:
-///
-/// [Bucket 0].Hashes -> [uint32_t][uint32_t]
-/// [Bucket 0].Entries -> [KeyDataTy*][KeyDataTy*]
-///
-/// [Bucket 1].Hashes -> [uint32_t][uint32_t][uint32_t][uint32_t]
-/// [Bucket 1].Entries -> [KeyDataTy*][KeyDataTy*][KeyDataTy*][KeyDataTy*]
-/// .........................
-/// [Bucket N].Hashes -> [uint32_t][uint32_t][uint32_t]
-/// [Bucket N].Entries -> [KeyDataTy*][KeyDataTy*][KeyDataTy*]
-///
-/// ConcurrentHashTableByPtr uses an external thread-safe allocator to allocate
-/// KeyDataTy items.
-
-template <typename KeyTy, typename KeyDataTy, typename AllocatorTy>
-class ConcurrentHashTableInfoByPtr {
-public:
+/// ConcurrentHashTableBase has a MutexTy parameter which should satisfy
+/// std::shared_mutex interface to lock buckets:
+/// - lock_shared/unlock_shared if rehashing is not necessary.
+/// - lock/unlock when the bucket should be exclusively locked and resized.
+/// To get a single-threaded version of ConcurrentHashTableBase set MutexTy
+/// to void.
+
+template <typename KeyTy> struct ConcurrentHashTableBaseInfo {
/// \returns Hash value for the specified \p Key.
static inline uint64_t getHashValue(const KeyTy &Key) {
- return xxh3_64bits(Key);
- }
-
- /// \returns true if both \p LHS and \p RHS are equal.
- static inline bool isEqual(const KeyTy &LHS, const KeyTy &RHS) {
- return LHS == RHS;
- }
-
- /// \returns key for the specified \p KeyData.
- static inline const KeyTy &getKey(const KeyDataTy &KeyData) {
- return KeyData.getKey();
- }
-
- /// \returns newly created object of KeyDataTy type.
- static inline KeyDataTy *create(const KeyTy &Key, AllocatorTy &Allocator) {
- return KeyDataTy::create(Key, Allocator);
+ return std::hash<KeyTy>{}(Key);
}
};
-template <typename KeyTy, typename KeyDataTy, typename AllocatorTy,
- typename Info =
- ConcurrentHashTableInfoByPtr<KeyTy, KeyDataTy, AllocatorTy>>
-class ConcurrentHashTableByPtr {
+template <typename KeyTy, typename DataTy, typename DerivedImplTy,
+ typename Info = ConcurrentHashTableBaseInfo<KeyTy>,
+ typename MutexTy = sys::RWMutex, bool ZeroIsUndefValue = true>
+class ConcurrentHashTableBase {
public:
- ConcurrentHashTableByPtr(
- AllocatorTy &Allocator, uint64_t EstimatedSize = 100000,
+ /// ReservedSize - Specify the number of items for which space
+ /// will be allocated in advance.
+ /// ThreadsNum - Specify the number of threads that will work
+ /// with the table.
+ /// InitialNumberOfBuckets - Specify the number of buckets. A small
+ /// number of buckets may lead to high thread
+ /// contention and slow execution due to cache
+ /// synchronization.
+ ConcurrentHashTableBase(
+ uint64_t ReservedSize = 0,
size_t ThreadsNum = parallel::strategy.compute_thread_count(),
- size_t InitialNumberOfBuckets = 128)
- : MultiThreadAllocator(Allocator) {
+ uint64_t InitialNumberOfBuckets = 0) {
assert((ThreadsNum > 0) && "ThreadsNum must be greater than 0");
- assert((InitialNumberOfBuckets > 0) &&
- "InitialNumberOfBuckets must be greater than 0");
+
+ this->ThreadsNum = ThreadsNum;
// Calculate number of buckets.
- uint64_t EstimatedNumberOfBuckets = ThreadsNum;
- if (ThreadsNum > 1) {
- EstimatedNumberOfBuckets *= InitialNumberOfBuckets;
- EstimatedNumberOfBuckets *= std::max(
- 1,
- countr_zero(PowerOf2Ceil(EstimatedSize / InitialNumberOfBuckets)) >>
- 2);
+ if constexpr (std::is_void<MutexTy>::value)
+ NumberOfBuckets = 1;
+ else {
+ if (InitialNumberOfBuckets)
+ NumberOfBuckets = InitialNumberOfBuckets;
+ else
+ NumberOfBuckets = (ThreadsNum == 1) ? 1 : ThreadsNum * 256;
+ NumberOfBuckets = PowerOf2Ceil(NumberOfBuckets);
}
- EstimatedNumberOfBuckets = PowerOf2Ceil(EstimatedNumberOfBuckets);
- NumberOfBuckets =
- std::min(EstimatedNumberOfBuckets, (uint64_t)(1Ull << 31));
// Allocate buckets.
BucketsArray = std::make_unique<Bucket[]>(NumberOfBuckets);
- InitialBucketSize = EstimatedSize / NumberOfBuckets;
- InitialBucketSize = std::max((uint32_t)1, InitialBucketSize);
- InitialBucketSize = PowerOf2Ceil(InitialBucketSize);
+ uint64_t InitialBucketSize =
+ calculateBucketSizeFromOverallSize(ReservedSize);
// Initialize each bucket.
- for (uint32_t Idx = 0; Idx < NumberOfBuckets; Idx++) {
- HashesPtr Hashes = new ExtHashBitsTy[InitialBucketSize];
- memset(Hashes, 0, sizeof(ExtHashBitsTy) * InitialBucketSize);
-
- DataPtr Entries = new EntryDataTy[InitialBucketSize];
- memset(Entries, 0, sizeof(EntryDataTy) * InitialBucketSize);
-
- BucketsArray[Idx].Size = InitialBucketSize;
- BucketsArray[Idx].Hashes = Hashes;
- BucketsArray[Idx].Entries = Entries;
+ for (uint64_t CurIdx = 0; CurIdx < NumberOfBuckets; ++CurIdx) {
+ BucketsArray[CurIdx].Size = InitialBucketSize;
+ BucketsArray[CurIdx].Data =
+ static_cast<DerivedImplTy *>(this)->allocateData(InitialBucketSize);
}
- // Calculate masks.
- HashMask = NumberOfBuckets - 1;
-
- size_t LeadingZerosNumber = countl_zero(HashMask);
- HashBitsNum = 64 - LeadingZerosNumber;
+ // Calculate mask.
+ BucketsHashMask = NumberOfBuckets - 1;
- // We keep only high 32-bits of hash value. So bucket size cannot
- // exceed 2^31. Bucket size is always power of two.
- MaxBucketSize = 1Ull << (std::min((size_t)31, LeadingZerosNumber));
-
- // Calculate mask for extended hash bits.
- ExtHashMask = (uint64_t)NumberOfBuckets * MaxBucketSize - 1;
+ size_t LeadingZerosNumber = countl_zero(BucketsHashMask);
+ BucketsHashBitsNum = 64 - LeadingZerosNumber;
}
- virtual ~ConcurrentHashTableByPtr() {
- // Deallocate buckets.
- for (uint32_t Idx = 0; Idx < NumberOfBuckets; Idx++) {
- delete[] BucketsArray[Idx].Hashes;
- delete[] BucketsArray[Idx].Entries;
- }
- }
-
- /// Insert new value \p NewValue or return already existing entry.
- ///
- /// \returns entry and "true" if an entry is just inserted or
- /// "false" if an entry already exists.
- std::pair<KeyDataTy *, bool> insert(const KeyTy &NewValue) {
- // Calculate bucket index.
- uint64_t Hash = Info::getHashValue(NewValue);
- Bucket &CurBucket = BucketsArray[getBucketIdx(Hash)];
- uint32_t ExtHashBits = getExtHashBits(Hash);
-
-#if LLVM_ENABLE_THREADS
- // Lock bucket.
- CurBucket.Guard.lock();
-#endif
-
- HashesPtr BucketHashes = CurBucket.Hashes;
- DataPtr BucketEntries = CurBucket.Entries;
- uint32_t CurEntryIdx = getStartIdx(ExtHashBits, CurBucket.Size);
-
- while (true) {
- uint32_t CurEntryHashBits = BucketHashes[CurEntryIdx];
-
- if (CurEntryHashBits == 0 && BucketEntries[CurEntryIdx] == nullptr) {
- // Found empty slot. Insert data.
- KeyDataTy *NewData = Info::create(NewValue, MultiThreadAllocator);
- BucketEntries[CurEntryIdx] = NewData;
- BucketHashes[CurEntryIdx] = ExtHashBits;
-
- CurBucket.NumberOfEntries++;
- RehashBucket(CurBucket);
-
-#if LLVM_ENABLE_THREADS
- CurBucket.Guard.unlock();
-#endif
-
- return {NewData, true};
+ /// Erase content of table.
+ void clear(uint64_t ReservedSize = 0) {
+ if (ReservedSize) {
+ for (uint64_t CurIdx = 0; CurIdx < NumberOfBuckets; ++CurIdx) {
+ Bucket &CurBucket = BucketsArray[CurIdx];
+ delete[] CurBucket.Data;
+ uint64_t BucketSize = calculateBucketSizeFromOverallSize(ReservedSize);
+ CurBucket.Size = BucketSize;
+ CurBucket.Data =
+ static_cast<DerivedImplTy *>(this)->allocateData(BucketSize);
}
-
- if (CurEntryHashBits == ExtHashBits) {
- // Hash matched. Check value for equality.
- KeyDataTy *EntryData = BucketEntries[CurEntryIdx];
- if (Info::isEqual(Info::getKey(*EntryData), NewValue)) {
- // Already existed entry matched with inserted data is found.
-#if LLVM_ENABLE_THREADS
- CurBucket.Guard.unlock();
-#endif
-
- return {EntryData, false};
- }
+ } else {
+ for (uint64_t CurIdx = 0; CurIdx < NumberOfBuckets; ++CurIdx) {
+ Bucket &CurBucket = BucketsArray[CurIdx];
+ uint64_t BufferSize =
+ static_cast<DerivedImplTy *>(this)->getBufferSize(CurBucket.Size);
+ fillBufferWithUndefValue(CurBucket.Data, BufferSize);
}
-
- CurEntryIdx++;
- CurEntryIdx &= (CurBucket.Size - 1);
}
+ }
- llvm_unreachable("Insertion error.");
- return {};
+ ~ConcurrentHashTableBase() {
+ // Deallocate buckets.
+ for (uint64_t Idx = 0; Idx < NumberOfBuckets; Idx++)
+ delete[] BucketsArray[Idx].Data;
}
/// Print information about current state of hash table structures.
void printStatistic(raw_ostream &OS) {
OS << "\n--- HashTable statistic:\n";
OS << "\nNumber of buckets = " << NumberOfBuckets;
- OS << "\nInitial bucket size = " << InitialBucketSize;
- uint64_t NumberOfNonEmptyBuckets = 0;
- uint64_t NumberOfEntriesPlusEmpty = 0;
+ uint64_t OverallNumberOfAllocatedEntries = 0;
uint64_t OverallNumberOfEntries = 0;
uint64_t OverallSize = sizeof(*this) + NumberOfBuckets * sizeof(Bucket);
- DenseMap<uint32_t, uint32_t> BucketSizesMap;
+ DenseMap<uint64_t, uint64_t> BucketSizesMap;
// For each bucket...
- for (uint32_t Idx = 0; Idx < NumberOfBuckets; Idx++) {
+ for (uint64_t Idx = 0; Idx < NumberOfBuckets; Idx++) {
Bucket &CurBucket = BucketsArray[Idx];
BucketSizesMap[CurBucket.Size]++;
- if (CurBucket.NumberOfEntries != 0)
- NumberOfNonEmptyBuckets++;
- NumberOfEntriesPlusEmpty += CurBucket.Size;
- OverallNumberOfEntries += CurBucket.NumberOfEntries;
+ OverallNumberOfAllocatedEntries += CurBucket.Size;
+ OverallNumberOfEntries +=
+ calculateNumberOfEntries(CurBucket.Data, CurBucket.Size);
OverallSize +=
- (sizeof(ExtHashBitsTy) + sizeof(EntryDataTy)) * CurBucket.Size;
+ static_cast<DerivedImplTy *>(this)->getBufferSize(CurBucket.Size);
}
OS << "\nOverall number of entries = " << OverallNumberOfEntries;
- OS << "\nOverall number of non empty buckets = " << NumberOfNonEmptyBuckets;
- for (auto &BucketSize : BucketSizesMap)
- OS << "\n Number of buckets with size " << BucketSize.first << ": "
- << BucketSize.second;
+ OS << "\nOverall allocated size = " << OverallSize;
std::stringstream stream;
stream << std::fixed << std::setprecision(2)
- << ((float)OverallNumberOfEntries / (float)NumberOfEntriesPlusEmpty);
+ << ((float)OverallNumberOfEntries /
+ (float)OverallNumberOfAllocatedEntries);
std::string str = stream.str();
OS << "\nLoad factor = " << str;
- OS << "\nOverall allocated size = " << OverallSize;
+ for (auto &BucketSize : BucketSizesMap)
+ OS << "\n Number of buckets with size " << BucketSize.first << ": "
+ << BucketSize.second;
}
protected:
- using ExtHashBitsTy = uint32_t;
- using EntryDataTy = KeyDataTy *;
-
- using HashesPtr = ExtHashBitsTy *;
- using DataPtr = EntryDataTy *;
+ struct VoidMutex {
+ inline void lock_shared() {}
+ inline void unlock_shared() {}
+ inline void lock() {}
+ inline void unlock() {}
+ };
- // Bucket structure. Keeps bucket data.
- struct Bucket {
+ /// Bucket structure. Keeps bucket data.
+ struct Bucket
+ : public
+#if LLVM_ENABLE_THREADS
+ std::conditional_t<std::is_void<MutexTy>::value, VoidMutex, MutexTy>
+#else
+ VoidMutex
+#endif
+ {
Bucket() = default;
- // Size of bucket.
- uint32_t Size = 0;
+ /// Size of bucket.
+ uint64_t Size;
- // Number of non-null entries.
- uint32_t NumberOfEntries = 0;
+ /// Buffer keeping bucket data.
+ uint8_t *Data;
+ };
- // Hashes for [Size] entries.
- HashesPtr Hashes = nullptr;
+ void fillBufferWithUndefValue(uint8_t *Data, uint64_t BufferSize) const {
+ if constexpr (ZeroIsUndefValue)
+ memset(Data, 0, BufferSize);
+ else
+ memset(Data, 0xff, BufferSize);
+ }
- // [Size] entries.
- DataPtr Entries = nullptr;
+ uint64_t calculateNumberOfEntries(uint8_t *Data, uint64_t Size) {
+ uint64_t Result = 0;
+ for (uint64_t CurIdx = 0; CurIdx < Size; CurIdx++) {
+ auto &AtomicData =
+ static_cast<DerivedImplTy *>(this)->getDataEntry(Data, CurIdx, Size);
+ if (!isNull(AtomicData.load()))
+ Result++;
+ }
-#if LLVM_ENABLE_THREADS
- // Mutex for this bucket.
- std::mutex Guard;
-#endif
- };
+ return Result;
+ }
- // Reallocate and rehash bucket if this is full enough.
- void RehashBucket(Bucket &CurBucket) {
- assert((CurBucket.Size > 0) && "Uninitialised bucket");
- if (CurBucket.NumberOfEntries < CurBucket.Size * 0.9)
+ uint64_t calculateBucketSizeFromOverallSize(uint64_t OverallSize) const {
+ uint64_t BucketSize = OverallSize / NumberOfBuckets;
+ BucketSize = std::max((uint64_t)1, BucketSize);
+ BucketSize = PowerOf2Ceil(BucketSize);
+ return BucketSize;
+ }
+
+ template <typename T> static inline bool isNull(T Data) {
+ if constexpr (ZeroIsUndefValue)
+ return Data == 0;
+ else if constexpr (sizeof(Data) == 1)
+ return reinterpret_cast<uint8_t>(Data) == 0xff;
+ else if constexpr (sizeof(Data) == 2)
+ return reinterpret_cast<uint16_t>(Data) == 0xffff;
+ else if constexpr (sizeof(Data) == 4)
+ return reinterpret_cast<uint32_t>(Data) == 0xffffffff;
+ else if constexpr (sizeof(Data) == 8)
+ return reinterpret_cast<uint64_t>(Data) == 0xffffffffffffffff;
+
+ llvm_unreachable("Unsupported data size");
+ }
+
+ /// Common implementation of the insert method. This implementation selects
+ /// the bucket, locks it and calls the child implementation, which does the
+ /// final insertion.
+ template <typename... Args>
+ std::pair<DataTy, bool> insert(const KeyTy &NewKey, Args... args) {
+ // Calculate hash.
+ uint64_t Hash = Info::getHashValue(NewKey);
+ // Get bucket.
+ Bucket &CurBucket = BucketsArray[getBucketIdx(Hash)];
+
+ // Calculate extended hash bits.
+ uint64_t ExtHashBits = Hash >> BucketsHashBitsNum;
+ std::pair<DataTy, bool> Result;
+
+ while (true) {
+ uint64_t BucketSizeForRehashing = 0;
+ CurBucket.lock_shared();
+ // Call child implementation.
+ if (static_cast<DerivedImplTy *>(this)->insertImpl(
+ CurBucket, ExtHashBits, NewKey, Result, args...)) {
+ CurBucket.unlock_shared();
+ return Result;
+ }
+
+ BucketSizeForRehashing = CurBucket.Size;
+ CurBucket.unlock_shared();
+
+ // Rehash bucket.
+ rehashBucket(CurBucket, BucketSizeForRehashing);
+ }
+
+ llvm_unreachable("Unhandled path of insert() method");
+ return {};
+ }
+
+ /// Rehash bucket data.
+ void rehashBucket(Bucket &CurBucket, uint64_t BucketSizeForRehashing) {
+ CurBucket.lock();
+ uint64_t OldSize = CurBucket.Size;
+ if (BucketSizeForRehashing != OldSize) {
+ CurBucket.unlock();
return;
+ }
- if (CurBucket.Size >= MaxBucketSize)
- report_fatal_error("ConcurrentHashTable is full");
+ uint8_t *OldData = CurBucket.Data;
+ uint64_t NewSize = OldSize << 1;
+ uint8_t *NewData =
+ static_cast<DerivedImplTy *>(this)->allocateData(NewSize);
+
+ // Iterate through old data.
+ for (uint64_t CurIdx = 0; CurIdx < OldSize; ++CurIdx) {
+ auto &AtomicData = static_cast<DerivedImplTy *>(this)->getDataEntry(
+ OldData, CurIdx, OldSize);
+ auto CurData = AtomicData.load(std::memory_order_acquire);
+
+ // Check data entry for null value.
+ if (!isNull(CurData)) {
+ auto &AtomicKey = static_cast<DerivedImplTy *>(this)->getKeyEntry(
+ OldData, CurIdx, OldSize);
+ auto CurKey = AtomicKey.load(std::memory_order_acquire);
+
+ // Get index for position in the new bucket.
+ uint64_t ExtHashBits =
+ static_cast<DerivedImplTy *>(this)->getExtHashBits(CurKey);
+ uint64_t NewIdx = getStartIdx(ExtHashBits, NewSize);
+ while (true) {
+ auto &NewAtomicData =
+ static_cast<DerivedImplTy *>(this)->getDataEntry(NewData, NewIdx,
+ NewSize);
+ auto NewCurData = NewAtomicData.load(std::memory_order_acquire);
+
+ if (isNull(NewCurData)) {
+ // Store data entry and key into the new bucket data.
+ NewAtomicData.store(CurData, std::memory_order_release);
+ auto &NewAtomicKey =
+ static_cast<DerivedImplTy *>(this)->getKeyEntry(NewData, NewIdx,
+ NewSize);
+ NewAtomicKey.store(CurKey, std::memory_order_release);
+ break;
+ }
+
+ ++NewIdx;
+ NewIdx &= (NewSize - 1);
+ }
+ }
+ }
- uint32_t NewBucketSize = CurBucket.Size << 1;
- assert((NewBucketSize <= MaxBucketSize) && "New bucket size is too big");
- assert((CurBucket.Size < NewBucketSize) &&
- "New bucket size less than size of current bucket");
+ CurBucket.Size = NewSize;
+ CurBucket.Data = NewData;
+ CurBucket.unlock();
- // Store old entries & hashes arrays.
- HashesPtr SrcHashes = CurBucket.Hashes;
- DataPtr SrcEntries = CurBucket.Entries;
+ delete[] OldData;
+ }
- // Allocate new entries&hashes arrays.
- HashesPtr DestHashes = new ExtHashBitsTy[NewBucketSize];
- memset(DestHashes, 0, sizeof(ExtHashBitsTy) * NewBucketSize);
+ uint64_t getBucketIdx(hash_code Hash) { return Hash & BucketsHashMask; }
- DataPtr DestEntries = new EntryDataTy[NewBucketSize];
- memset(DestEntries, 0, sizeof(EntryDataTy) * NewBucketSize);
+ uint64_t getStartIdx(uint64_t ExtHashBits, uint64_t BucketSize) {
+ assert((BucketSize > 0) && "Empty bucket");
- // For each entry in source arrays...
- for (uint32_t CurSrcEntryIdx = 0; CurSrcEntryIdx < CurBucket.Size;
- CurSrcEntryIdx++) {
- uint32_t CurSrcEntryHashBits = SrcHashes[CurSrcEntryIdx];
+ return ExtHashBits & (BucketSize - 1);
+ }
- // Check for null entry.
- if (CurSrcEntryHashBits == 0 && SrcEntries[CurSrcEntryIdx] == nullptr)
- continue;
+ /// Number of bits in hash mask.
+ uint8_t BucketsHashBitsNum = 0;
- uint32_t StartDestIdx = getStartIdx(CurSrcEntryHashBits, NewBucketSize);
+ /// Hash mask.
+ uint64_t BucketsHashMask = 0;
- // Insert non-null entry into the new arrays.
- while (true) {
- uint32_t CurDestEntryHashBits = DestHashes[StartDestIdx];
+ /// Array of buckets.
+ std::unique_ptr<Bucket[]> BucketsArray;
- if (CurDestEntryHashBits == 0 && DestEntries[StartDestIdx] == nullptr) {
- // Found empty slot. Insert data.
- DestHashes[StartDestIdx] = CurSrcEntryHashBits;
- DestEntries[StartDestIdx] = SrcEntries[CurSrcEntryIdx];
- break;
+ /// The number of buckets.
+ uint64_t NumberOfBuckets = 0;
+
+ /// Number of available threads.
+ size_t ThreadsNum = 0;
+};
+
+/// ConcurrentHashTable: This class is optimized for small data like
+/// uint32_t or uint64_t. It keeps keys and data in the internal table.
+/// Keys and data should have equal alignment and size. They also should
+/// satisfy requirements for atomic operations.
+///
+/// Bucket.Data contains an array of pairs [ DataTy, KeyTy ]:
+///
+/// [Bucket].Data -> [DataTy0][KeyTy0]...[DataTyN][KeyTyN]
+
+template <typename KeyTy> class ConcurrentHashTableInfo {
+public:
+ /// \returns Hash value for the specified \p Key.
+ static inline uint64_t getHashValue(KeyTy Key) {
+ return std::hash<KeyTy>{}(Key);
+ }
+
+ /// \returns true if both \p LHS and \p RHS are equal.
+ static inline bool isEqual(KeyTy LHS, KeyTy RHS) { return LHS == RHS; }
+};
+
+template <typename KeyTy, typename DataTy,
+ typename Info = ConcurrentHashTableInfo<KeyTy>,
+ typename MutexTy = sys::RWMutex, bool ZeroIsUndefValue = false,
+ uint64_t MaxProbeCount = 512>
+class ConcurrentHashTable
+ : public ConcurrentHashTableBase<
+ KeyTy, DataTy,
+ ConcurrentHashTable<KeyTy, DataTy, Info, MutexTy, ZeroIsUndefValue,
+ MaxProbeCount>,
+ Info, MutexTy, ZeroIsUndefValue> {
+ using SuperClass = ConcurrentHashTableBase<
+ KeyTy, DataTy,
+ ConcurrentHashTable<KeyTy, DataTy, Info, MutexTy, ZeroIsUndefValue>, Info,
+ MutexTy, ZeroIsUndefValue>;
+ friend SuperClass;
+
+ using Bucket = typename SuperClass::Bucket;
+ using AtomicEntryTy = std::atomic<DataTy>;
+ using AtomicKeyTy = std::atomic<KeyTy>;
+
+ static_assert(sizeof(KeyTy) == sizeof(DataTy));
+ static_assert(alignof(KeyTy) == alignof(DataTy));
+
+public:
+ ConcurrentHashTable(
+ uint64_t ReservedSize = 0,
+ size_t ThreadsNum = parallel::strategy.compute_thread_count(),
+ uint64_t InitialNumberOfBuckets = 0)
+ : ConcurrentHashTableBase<
+ KeyTy, DataTy,
+ ConcurrentHashTable<KeyTy, DataTy, Info, MutexTy, ZeroIsUndefValue,
+ MaxProbeCount>,
+ Info, MutexTy, ZeroIsUndefValue>(ReservedSize, ThreadsNum,
+ InitialNumberOfBuckets) {}
+
+ std::pair<DataTy, bool> insert(const KeyTy &Key,
+ function_ref<DataTy(KeyTy)> onInsert) {
+ return SuperClass::insert(Key, onInsert);
+ }
+
+protected:
+ /// Returns size of the buffer required to keep bucket data of \p Size.
+ uint64_t getBufferSize(uint64_t Size) const {
+ return (sizeof(DataTy) + sizeof(KeyTy)) * Size;
+ }
+
+ /// Allocates bucket data.
+ uint8_t *allocateData(uint64_t Size) const {
+ uint64_t BufferSize = getBufferSize(Size);
+ uint8_t *Data = static_cast<uint8_t *>(
+ llvm::allocate_buffer(BufferSize, alignof(DataTy)));
+ SuperClass::fillBufferWithUndefValue(Data, BufferSize);
+ return Data;
+ }
+
+ /// Returns reference to data entry with index \p CurIdx.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE AtomicEntryTy &
+ getDataEntry(uint8_t *Data, uint64_t CurIdx, uint64_t Size) {
+ return *(reinterpret_cast<AtomicEntryTy *>(
+ Data + (sizeof(DataTy) + sizeof(KeyTy)) * CurIdx));
+ }
+
+ /// Returns reference to key entry with index \p CurIdx.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE AtomicKeyTy &
+ getKeyEntry(uint8_t *Data, uint64_t CurIdx, uint64_t Size) {
+ return *(reinterpret_cast<AtomicKeyTy *>(
+ Data + (sizeof(KeyTy) + sizeof(DataTy)) * CurIdx + sizeof(DataTy)));
+ }
+
+ /// Returns extended hash bits value for specified key.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE uint64_t getExtHashBits(KeyTy Key) const {
+ return (Info::getHashValue(Key)) >> SuperClass::BucketsHashBitsNum;
+ }
+
+ /// Inserts data returned by \p onInsert into the hashtable.
+ /// a) If data was inserted returns true and set \p Result.second = true
+ /// and \p Result.first = Data.
+ /// b) If data was found returns true and set \p Result.second = false
+ /// and \p Result.first = Data.
+ /// c) If the table is full returns false.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool
+ insertImpl(Bucket &CurBucket, uint64_t ExtHashBits, const KeyTy &NewKey,
+ std::pair<DataTy, bool> &Result,
+ function_ref<DataTy(KeyTy)> onInsert) {
+ assert(!SuperClass::isNull(NewKey) && "Null key value");
+
+ uint64_t BucketSize = CurBucket.Size;
+ uint8_t *Data = CurBucket.Data;
+ uint64_t BucketMaxProbeCount = std::min(BucketSize, MaxProbeCount);
+ uint64_t CurProbeCount = 0;
+ uint64_t CurEntryIdx = SuperClass::getStartIdx(ExtHashBits, BucketSize);
+
+ while (CurProbeCount < BucketMaxProbeCount) {
+ AtomicKeyTy &AtomicKey = getKeyEntry(Data, CurEntryIdx, BucketSize);
+ KeyTy CurKey = AtomicKey.load(std::memory_order_acquire);
+
+ AtomicEntryTy &AtomicEntry = getDataEntry(Data, CurEntryIdx, BucketSize);
+
+ if (SuperClass::isNull(CurKey)) {
+ // Found empty slot. Insert data.
+ if (AtomicKey.compare_exchange_strong(CurKey, NewKey)) {
+ DataTy NewData = onInsert(NewKey);
+ assert(!SuperClass::isNull(NewData) && "Null data value");
+
+ AtomicEntry.store(NewData, std::memory_order_release);
+ Result.first = NewData;
+ Result.second = true;
+ return true;
}
- StartDestIdx++;
- StartDestIdx = StartDestIdx & (NewBucketSize - 1);
+ // The slot is overwritten from another thread. Retry slot probing.
+ continue;
+ } else if (Info::isEqual(CurKey, NewKey)) {
+ // An already existing entry matching the inserted key is found.
+
+ DataTy CurData = AtomicEntry.load(std::memory_order_acquire);
+ while (SuperClass::isNull(CurData))
+ CurData = AtomicEntry.load(std::memory_order_acquire);
+
+ Result.first = CurData;
+ Result.second = false;
+ return true;
}
+
+ CurProbeCount++;
+ CurEntryIdx++;
+ CurEntryIdx &= (BucketSize - 1);
}
- // Update bucket fields.
- CurBucket.Hashes = DestHashes;
- CurBucket.Entries = DestEntries;
- CurBucket.Size = NewBucketSize;
+ return false;
+ }
+};
+
+/// ConcurrentHashTableByPtr: This class is optimized for the case when key
+/// and/or data is an aggregate type. It keeps hash instead of the key and
+/// pointer to the data allocated in external thread-safe allocator.
+/// This hashtable is useful to have efficient access to aggregate data(like
+/// strings, type descriptors...) and to keep only single copy of such an
+/// aggregate.
+///
+/// To save space it keeps only 32 bits of the hash value, which limits the
+/// number of resizings for a single bucket up to x2^31.
+///
+/// Bucket.Data contains an array of EntryDataTy first and then an array of
+/// ExtHashBitsTy:
+///
+/// [Bucket].Data ->
+/// [EntryDataTy0]...[EntryDataTyN][ExtHashBitsTy0]...[ExtHashBitsTyN]
- // Delete old bucket entries.
- if (SrcHashes != nullptr)
- delete[] SrcHashes;
- if (SrcEntries != nullptr)
- delete[] SrcEntries;
+template <typename KeyTy, typename KeyDataTy, typename AllocatorTy>
+class ConcurrentHashTableInfoByPtr {
+public:
+ /// \returns Hash value for the specified \p Key.
+ static inline uint64_t getHashValue(const KeyTy &Key) {
+ return xxh3_64bits(Key);
}
- uint32_t getBucketIdx(hash_code Hash) { return Hash & HashMask; }
+ /// \returns true if both \p LHS and \p RHS are equal.
+ static inline bool isEqual(const KeyTy &LHS, const KeyTy &RHS) {
+ return LHS == RHS;
+ }
- uint32_t getExtHashBits(uint64_t Hash) {
- return (Hash & ExtHashMask) >> HashBitsNum;
+ /// \returns key for the specified \p KeyData.
+ static inline const KeyTy &getKey(const KeyDataTy &KeyData) {
+ return KeyData.getKey();
}
- uint32_t getStartIdx(uint32_t ExtHashBits, uint32_t BucketSize) {
- assert((BucketSize > 0) && "Empty bucket");
+ /// \returns newly created object of KeyDataTy type.
+ static inline KeyDataTy *create(const KeyTy &Key, AllocatorTy &Allocator) {
+ return KeyDataTy::create(Key, Allocator);
+ }
+};
- return ExtHashBits & (BucketSize - 1);
+template <typename KeyTy, typename KeyDataTy, typename AllocatorTy,
+ typename Info =
+ ConcurrentHashTableInfoByPtr<KeyTy, KeyDataTy, AllocatorTy>,
+ typename MutexTy = sys::RWMutex, bool ZeroIsUndefValue = true,
+ uint64_t MaxProbeCount = 512>
+class ConcurrentHashTableByPtr
+ : public ConcurrentHashTableBase<
+ KeyTy, KeyDataTy *,
+ ConcurrentHashTableByPtr<KeyTy, KeyDataTy, AllocatorTy, Info, MutexTy,
+ ZeroIsUndefValue, MaxProbeCount>,
+ Info, MutexTy, ZeroIsUndefValue> {
+ using SuperClass = ConcurrentHashTableBase<
+ KeyTy, KeyDataTy *,
+ ConcurrentHashTableByPtr<KeyTy, KeyDataTy, AllocatorTy, Info, MutexTy,
+ ZeroIsUndefValue, MaxProbeCount>,
+ Info, MutexTy, ZeroIsUndefValue>;
+ friend SuperClass;
+
+ using Bucket = typename SuperClass::Bucket;
+
+ using EntryDataTy = KeyDataTy *;
+ using ExtHashBitsTy = uint32_t;
+
+ using AtomicEntryDataTy = std::atomic<EntryDataTy>;
+ using AtomicExtHashBitsTy = std::atomic<ExtHashBitsTy>;
+
+ static constexpr uint64_t MaxBucketSize = 1Ull << 31;
+ static constexpr uint64_t MaxNumberOfBuckets = 0xFFFFFFFFUll;
+
+ static_assert(alignof(EntryDataTy) >= alignof(ExtHashBitsTy),
+ "EntryDataTy alignment must be greater or equal to "
+ "ExtHashBitsTy alignment");
+ static_assert(
+ (alignof(EntryDataTy) % alignof(ExtHashBitsTy)) == 0,
+ "EntryDataTy alignment must be a multiple of ExtHashBitsTy alignment");
+
+public:
+ ConcurrentHashTableByPtr(
+ AllocatorTy &Allocator, uint64_t ReservedSize = 0,
+ size_t ThreadsNum = parallel::strategy.compute_thread_count(),
+ uint64_t InitialNumberOfBuckets = 0)
+ : ConcurrentHashTableBase<
+ KeyTy, KeyDataTy *,
+ ConcurrentHashTableByPtr<KeyTy, KeyDataTy, AllocatorTy, Info,
+ MutexTy, ZeroIsUndefValue, MaxProbeCount>,
+ Info, MutexTy, ZeroIsUndefValue>(ReservedSize, ThreadsNum,
+ InitialNumberOfBuckets),
+ MultiThreadAllocator(Allocator) {
+ assert((SuperClass::NumberOfBuckets <= MaxNumberOfBuckets) &&
+ "NumberOfBuckets is too big");
}
- // Number of bits in hash mask.
- uint64_t HashBitsNum = 0;
+ std::pair<KeyDataTy *, bool> insert(const KeyTy &Key) {
+ KeyDataTy *NewData = nullptr;
+ return SuperClass::insert(Key, NewData);
+ }
- // Hash mask.
- uint64_t HashMask = 0;
+protected:
+ /// Returns size of the buffer required to keep bucket data of \p Size.
+ uint64_t getBufferSize(uint64_t Size) const {
+ return (sizeof(EntryDataTy) + sizeof(ExtHashBitsTy)) * Size;
+ }
- // Hash mask for the extended hash bits.
- uint64_t ExtHashMask = 0;
+ /// Allocates bucket data.
+ uint8_t *allocateData(uint64_t Size) const {
+ uint64_t BufferSize = getBufferSize(Size);
+ uint8_t *Data = static_cast<uint8_t *>(
+ llvm::allocate_buffer(BufferSize, alignof(EntryDataTy)));
+ SuperClass::fillBufferWithUndefValue(Data, BufferSize);
+ return Data;
+ }
- // The maximal bucket size.
- uint32_t MaxBucketSize = 0;
+ /// Returns reference to data entry with index \p CurIdx.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE AtomicEntryDataTy &
+ getDataEntry(uint8_t *Data, uint64_t CurIdx, uint64_t) {
+ return *(reinterpret_cast<AtomicEntryDataTy *>(
+ Data + sizeof(AtomicEntryDataTy) * CurIdx));
+ }
+
+ /// Returns reference to key entry with index \p CurIdx.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE AtomicExtHashBitsTy &
+ getKeyEntry(uint8_t *Data, uint64_t CurIdx, uint64_t Size) {
+ return *(reinterpret_cast<AtomicExtHashBitsTy *>(
+ Data + sizeof(AtomicEntryDataTy) * Size +
+ sizeof(AtomicExtHashBitsTy) * CurIdx));
+ }
- // Initial size of bucket.
- uint32_t InitialBucketSize = 0;
+ /// Returns extended hash bits value for specified key. We keep hash value
+ /// instead of the key. So we can use kept value instead of calculating hash
+ /// again.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE ExtHashBitsTy
+ getExtHashBits(ExtHashBitsTy Key) const {
+ return Key;
+ }
- // The number of buckets.
- uint32_t NumberOfBuckets = 0;
+ /// Inserts data created from \p NewKey into the hashtable.
+ /// a) If data was inserted then returns true and set \p Result.second =
+ /// true and \p Result.first = KeyDataTy*.
+ /// b) If data was found returns true and set \p Result.second = false
+ /// and \p Result.first = KeyDataTy*.
+ /// c) If the table is full returns false.
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool
+ insertImpl(Bucket &CurBucket, uint64_t ExtHashBits, const KeyTy &NewKey,
+ std::pair<KeyDataTy *, bool> &Result, KeyDataTy *&NewData) {
+ uint64_t BucketSize = CurBucket.Size;
+ uint8_t *Data = CurBucket.Data;
+ uint64_t BucketMaxProbeCount = std::min(BucketSize, MaxProbeCount);
+ uint64_t CurProbeCount = 0;
+ uint64_t CurEntryIdx = SuperClass::getStartIdx(ExtHashBits, BucketSize);
+
+ while (CurProbeCount < BucketMaxProbeCount) {
+ AtomicExtHashBitsTy &AtomicKey =
+ getKeyEntry(Data, CurEntryIdx, BucketSize);
+ ExtHashBitsTy CurHashBits = AtomicKey.load(std::memory_order_acquire);
+
+ if (CurHashBits == static_cast<ExtHashBitsTy>(ExtHashBits) ||
+ SuperClass::isNull(CurHashBits)) {
+ AtomicEntryDataTy &AtomicData =
+ getDataEntry(Data, CurEntryIdx, BucketSize);
+ EntryDataTy EntryData = AtomicData.load(std::memory_order_acquire);
+ if (SuperClass::isNull(EntryData)) {
+ // Found empty slot. Insert data.
+ if (!NewData)
+ NewData = Info::create(NewKey, MultiThreadAllocator);
- // Array of buckets.
- std::unique_ptr<Bucket[]> BucketsArray;
+ if (AtomicData.compare_exchange_strong(EntryData, NewData)) {
+
+ AtomicKey.store(ExtHashBits, std::memory_order_release);
+ Result.first = NewData;
+ Result.second = true;
+ return true;
+ }
+
+ // The slot is overwritten from another thread. Retry slot probing.
+ continue;
+ } else if (Info::isEqual(Info::getKey(*EntryData), NewKey)) {
+ // Hash matched. Check value for equality.
+ if (NewData)
+ MultiThreadAllocator.Deallocate(NewData);
+
+ // An already existing entry matching the inserted key is found.
+ Result.first = EntryData;
+ Result.second = false;
+ return true;
+ }
+ }
+
+ CurProbeCount++;
+ CurEntryIdx++;
+ CurEntryIdx &= (BucketSize - 1);
+ }
+
+ if (BucketSize == MaxBucketSize)
+ report_fatal_error("ConcurrentHashTableByPtr is full");
+
+ return false;
+ }
// Used for allocating KeyDataTy values.
AllocatorTy &MultiThreadAllocator;
diff --git a/llvm/unittests/ADT/ConcurrentHashtableTest.cpp b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp
index ee1ee41f453a306..82821bdb9424d96 100644
--- a/llvm/unittests/ADT/ConcurrentHashtableTest.cpp
+++ b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp
@@ -19,6 +19,277 @@ using namespace llvm;
using namespace parallel;
namespace {
+
+TEST(ConcurrentHashTableTest, AddIntEntries) {
+ ConcurrentHashTable<uint32_t, uint32_t> HashTable(10);
+
+ std::function<uint32_t(uint32_t)> InsertionFunc =
+ [&](uint32_t Key) -> uint32_t { return Key; };
+
+ std::pair<uint32_t, bool> res1 = HashTable.insert(1, InsertionFunc);
+ // Check entry is inserted.
+ EXPECT_TRUE(res1.first == 1);
+ EXPECT_TRUE(res1.second);
+
+ std::pair<uint32_t, bool> res2 = HashTable.insert(2, InsertionFunc);
+ // Check old entry is still valid.
+ EXPECT_TRUE(res1.first == 1);
+ // Check new entry is inserted.
+ EXPECT_TRUE(res2.first == 2);
+ EXPECT_TRUE(res2.second);
+ // Check new and old entries hold different values.
+ EXPECT_TRUE(res1.first != res2.first);
+
+ std::pair<uint32_t, bool> res3 = HashTable.insert(3, InsertionFunc);
+ // Check one more entry is inserted.
+ EXPECT_TRUE(res3.first == 3);
+ EXPECT_TRUE(res3.second);
+
+ std::pair<uint32_t, bool> res4 = HashTable.insert(1, InsertionFunc);
+ // Check duplicated entry is not inserted again.
+ EXPECT_TRUE(res4.first == 1);
+ EXPECT_FALSE(res4.second);
+ // Check duplicated entry returns the same value.
+ EXPECT_TRUE(res1.first == res4.first);
+
+ // Check first entry is still valid.
+ EXPECT_TRUE(res1.first == 1);
+
+ // Check statistic.
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 3\n") !=
+ std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddIntMultiplueEntries) {
+ const size_t NumElements = 10000;
+ ConcurrentHashTable<uint32_t, uint32_t> HashTable;
+
+ std::function<uint32_t(uint32_t)> InsertionFunc =
+ [&](uint32_t Key) -> uint32_t { return Key; };
+
+ // Check insertion.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ // Verifying that the table contains exactly the number of elements we
+ // inserted.
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
+ std::string::npos);
+
+ // Check insertion of duplicates.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_FALSE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
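+ // Check statistic: the table must still report exactly the inserted number
+ // of elements. NOTE(review): StatisticString is not refreshed by a new
+ // printStatistic() call, so this re-checks the earlier output.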
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
+ std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddIntMultiplueEntriesWithClearance) {
+ const size_t NumElements = 100;
+ ConcurrentHashTable<uint32_t, uint32_t> HashTable;
+
+ std::function<uint32_t(uint32_t)> InsertionFunc =
+ [&](uint32_t Key) -> uint32_t { return Key; };
+
+ // Check insertion.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ // Verifying that the table contains exactly the number of elements we
+ // inserted.
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 100\n") !=
+ std::string::npos);
+
+ HashTable.clear();
+
+ // Check that the entries can be inserted again after clearing the table.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
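+ // Check statistic: the table must again report exactly the inserted number
+ // of elements. NOTE(review): StatisticString is not refreshed by a new
+ // printStatistic() call, so this re-checks the earlier output.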
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 100\n") !=
+ std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddIntMultiplueEntriesWithResize) {
+ const size_t NumElements = 20000;
+ ConcurrentHashTable<uint32_t, uint32_t> HashTable(10);
+
+ std::function<uint32_t(uint32_t)> InsertionFunc =
+ [&](uint32_t Key) -> uint32_t { return Key; };
+
+ // Check insertion.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ // Verifying that the table contains exactly the number of elements we
+ // inserted.
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
+ std::string::npos);
+
+ // Check insertion of duplicates.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_FALSE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
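+ // Check statistic: the table must still report exactly the inserted number
+ // of elements. NOTE(review): StatisticString is not refreshed by a new
+ // printStatistic() call, so this re-checks the earlier output.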
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
+ std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddIntEntriesParallel) {
+ const size_t NumElements = 10000;
+ ConcurrentHashTable<uint32_t, uint32_t> HashTable;
+
+ std::function<uint32_t(uint32_t)> InsertionFunc =
+ [&](uint32_t Key) -> uint32_t { return Key * 100; };
+
+ // Check parallel insertion.
+ parallelFor(0, NumElements, [&](size_t I) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I * 100);
+ });
+
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ // Verifying that the table contains exactly the number of elements we
+ // inserted.
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
+ std::string::npos);
+
+ // Check parallel insertion of duplicates.
+ parallelFor(0, NumElements, [&](size_t I) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_FALSE(Entry.second);
+ EXPECT_TRUE(Entry.first == I * 100);
+ });
+
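+ // Check statistic: the table must still report exactly the inserted number
+ // of elements. NOTE(review): StatisticString is not refreshed by a new
+ // printStatistic() call, so this re-checks the earlier output.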
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
+ std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddIntEntriesParallelWithResize) {
+ const size_t NumElements = 20000;
+ ConcurrentHashTable<uint64_t, uint64_t> HashTable(100);
+
+ std::function<uint64_t(uint64_t)> InsertionFunc =
+ [&](uint64_t Key) -> uint64_t { return Key * 100; };
+
+ // Check parallel insertion.
+ parallelFor(0, NumElements, [&](size_t I) {
+ std::pair<uint64_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I * 100);
+ });
+
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ // Verifying that the table contains exactly the number of elements we
+ // inserted.
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
+ std::string::npos);
+
+ // Check parallel insertion of duplicates.
+ parallelFor(0, NumElements, [&](size_t I) {
+ std::pair<uint64_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_FALSE(Entry.second);
+ EXPECT_TRUE(Entry.first == I * 100);
+ });
+
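+ // Check statistic: the table must still report exactly the inserted number
+ // of elements. NOTE(review): StatisticString is not refreshed by a new
+ // printStatistic() call, so this re-checks the earlier output.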
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
+ std::string::npos);
+}
+
+TEST(ConcurrentHashTableTest, AddIntEntriesNoThreads) {
+ const size_t NumElements = 500;
+ ConcurrentHashTable<uint32_t, uint32_t, ConcurrentHashTableInfo<uint32_t>,
+ void>
+ HashTable;
+
+ std::function<uint32_t(uint32_t)> InsertionFunc =
+ [&](uint32_t Key) -> uint32_t { return Key; };
+
+ // Check insertion.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_TRUE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
+ std::string StatisticString;
+ raw_string_ostream StatisticStream(StatisticString);
+ HashTable.printStatistic(StatisticStream);
+
+ // Verifying that the table contains exactly the number of elements we
+ // inserted.
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 500\n") !=
+ std::string::npos);
+
+ // Check insertion of duplicates.
+ for (uint32_t I = 0; I < NumElements; I++) {
+ std::pair<uint32_t, bool> Entry = HashTable.insert(I, InsertionFunc);
+ EXPECT_FALSE(Entry.second);
+ EXPECT_TRUE(Entry.first == I);
+ }
+
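+ // Check statistic: the table must still report exactly the inserted number
+ // of elements. NOTE(review): StatisticString is not refreshed by a new
+ // printStatistic() call, so this re-checks the earlier output.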
+ EXPECT_TRUE(StatisticString.find("Overall number of entries = 500\n") !=
+ std::string::npos);
+}
+
class String {
public:
String() {}
More information about the llvm-commits
mailing list