[llvm] 6ab43f9 - [Support] Add PerThreadBumpPtrAllocator class.

Alexey Lapshin via llvm-commits llvm-commits at lists.llvm.org
Sat May 6 05:36:19 PDT 2023


Author: Alexey Lapshin
Date: 2023-05-06T14:35:26+02:00
New Revision: 6ab43f9b87ce982fed7073d91da6c5f027321b53

URL: https://github.com/llvm/llvm-project/commit/6ab43f9b87ce982fed7073d91da6c5f027321b53
DIFF: https://github.com/llvm/llvm-project/commit/6ab43f9b87ce982fed7073d91da6c5f027321b53.diff

LOG: [Support] Add PerThreadBumpPtrAllocator class.

PerThreadBumpPtrAllocator allows separating allocations by thread id.
That makes allocations race free. It is possible because
ThreadPoolExecutor class creates threads, keeps them until
the destructor of ThreadPoolExecutor is called, and assigns ids
to the threads. Thus PerThreadBumpPtrAllocator should only be used with
threads created by ThreadPoolExecutor. This allocator is useful when
a thread-safe BumpPtrAllocator is needed.

Reviewed By: MaskRay, dexonsmith, andrewng

Differential Revision: https://reviews.llvm.org/D142318

Added: 
    llvm/include/llvm/Support/PerThreadBumpPtrAllocator.h
    llvm/unittests/Support/PerThreadBumpPtrAllocatorTest.cpp

Modified: 
    llvm/include/llvm/DWARFLinkerParallel/StringPool.h
    llvm/include/llvm/Support/Parallel.h
    llvm/lib/Support/Parallel.cpp
    llvm/unittests/ADT/ConcurrentHashtableTest.cpp
    llvm/unittests/DWARFLinkerParallel/StringPoolTest.cpp
    llvm/unittests/DWARFLinkerParallel/StringTableTest.cpp
    llvm/unittests/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/DWARFLinkerParallel/StringPool.h b/llvm/include/llvm/DWARFLinkerParallel/StringPool.h
index 5406e03c8af19..4828ff4e4b05d 100644
--- a/llvm/include/llvm/DWARFLinkerParallel/StringPool.h
+++ b/llvm/include/llvm/DWARFLinkerParallel/StringPool.h
@@ -12,6 +12,7 @@
 #include "llvm/ADT/ConcurrentHashtable.h"
 #include "llvm/CodeGen/DwarfStringPoolEntry.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/PerThreadBumpPtrAllocator.h"
 #include <string>
 #include <string_view>
 
@@ -22,28 +23,6 @@ namespace dwarflinker_parallel {
 /// and a string body which is placed right after StringEntry.
 using StringEntry = StringMapEntry<DwarfStringPoolEntry *>;
 
-class StringAllocator : public AllocatorBase<StringAllocator> {
-public:
-  inline LLVM_ATTRIBUTE_RETURNS_NONNULL void *Allocate(size_t Size,
-                                                       size_t Alignment) {
-#if LLVM_ENABLE_THREADS
-    std::lock_guard<std::mutex> Guard(AllocatorMutex);
-#endif
-
-    return Allocator.Allocate(Size, Align(Alignment));
-  }
-
-  // Pull in base class overloads.
-  using AllocatorBase<StringAllocator>::Allocate;
-
-private:
-#if LLVM_ENABLE_THREADS
-  std::mutex AllocatorMutex;
-#endif
-
-  BumpPtrAllocator Allocator;
-};
-
 class StringPoolEntryInfo {
 public:
   /// \returns Hash value for the specified \p Key.
@@ -62,28 +41,31 @@ class StringPoolEntryInfo {
   }
 
   /// \returns newly created object of KeyDataTy type.
-  static inline StringEntry *create(const StringRef &Key,
-                                    StringAllocator &Allocator) {
+  static inline StringEntry *
+  create(const StringRef &Key, parallel::PerThreadBumpPtrAllocator &Allocator) {
     return StringEntry::create(Key, Allocator);
   }
 };
 
 class StringPool
-    : public ConcurrentHashTableByPtr<StringRef, StringEntry, StringAllocator,
+    : public ConcurrentHashTableByPtr<StringRef, StringEntry,
+                                      parallel::PerThreadBumpPtrAllocator,
                                       StringPoolEntryInfo> {
 public:
   StringPool()
-      : ConcurrentHashTableByPtr<StringRef, StringEntry, StringAllocator,
+      : ConcurrentHashTableByPtr<StringRef, StringEntry,
+                                 parallel::PerThreadBumpPtrAllocator,
                                  StringPoolEntryInfo>(Allocator) {}
 
   StringPool(size_t InitialSize)
-      : ConcurrentHashTableByPtr<StringRef, StringEntry, StringAllocator,
+      : ConcurrentHashTableByPtr<StringRef, StringEntry,
+                                 parallel::PerThreadBumpPtrAllocator,
                                  StringPoolEntryInfo>(Allocator, InitialSize) {}
 
-  StringAllocator &getAllocatorRef() { return Allocator; }
+  parallel::PerThreadBumpPtrAllocator &getAllocatorRef() { return Allocator; }
 
 private:
-  StringAllocator Allocator;
+  parallel::PerThreadBumpPtrAllocator Allocator;
 };
 
 } // end of namespace dwarflinker_parallel

diff --git a/llvm/include/llvm/Support/Parallel.h b/llvm/include/llvm/Support/Parallel.h
index 75e7e8d597c44..8170da98f15a8 100644
--- a/llvm/include/llvm/Support/Parallel.h
+++ b/llvm/include/llvm/Support/Parallel.h
@@ -48,8 +48,11 @@ extern thread_local unsigned threadIndex;
 
 inline unsigned getThreadIndex() { GET_THREAD_INDEX_IMPL; }
 #endif
+
+size_t getThreadCount();
 #else
 inline unsigned getThreadIndex() { return 0; }
+inline size_t getThreadCount() { return 1; }
 #endif
 
 namespace detail {

diff --git a/llvm/include/llvm/Support/PerThreadBumpPtrAllocator.h b/llvm/include/llvm/Support/PerThreadBumpPtrAllocator.h
new file mode 100644
index 0000000000000..f94d18f62e9ab
--- /dev/null
+++ b/llvm/include/llvm/Support/PerThreadBumpPtrAllocator.h
@@ -0,0 +1,120 @@
+//===- PerThreadBumpPtrAllocator.h ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_PERTHREADBUMPPTRALLOCATOR_H
+#define LLVM_SUPPORT_PERTHREADBUMPPTRALLOCATOR_H
+
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Parallel.h"
+
+namespace llvm {
+namespace parallel {
+
+/// PerThreadAllocator is used in conjunction with ThreadPoolExecutor to allow
+/// per-thread allocations. It wraps a possibly thread-unsafe allocator,
+/// e.g. BumpPtrAllocator. PerThreadAllocator must be used with only main thread
+/// or threads created by ThreadPoolExecutor, as it utilizes getThreadIndex,
+/// which is set by ThreadPoolExecutor. To work properly, ThreadPoolExecutor
+/// should be initialized before PerThreadAllocator is created.
+/// TODO: The same approach might be implemented for ThreadPool.
+
+template <typename AllocatorTy>
+class PerThreadAllocator
+    : public AllocatorBase<PerThreadAllocator<AllocatorTy>> {
+public:
+  PerThreadAllocator()
+      : NumOfAllocators(parallel::getThreadCount()),
+        Allocators(std::make_unique<AllocatorTy[]>(NumOfAllocators)) {}
+
+  /// \defgroup Methods which could be called asynchronously:
+  ///
+  /// @{
+
+  using AllocatorBase<PerThreadAllocator<AllocatorTy>>::Allocate;
+
+  using AllocatorBase<PerThreadAllocator<AllocatorTy>>::Deallocate;
+
+  /// Allocate \a Size bytes of \a Alignment aligned memory.
+  void *Allocate(size_t Size, size_t Alignment) {
+    assert(getThreadIndex() < NumOfAllocators);
+    return Allocators[getThreadIndex()].Allocate(Size, Alignment);
+  }
+
+  /// Deallocate \a Ptr to \a Size bytes of memory allocated by this
+  /// allocator.
+  void Deallocate(const void *Ptr, size_t Size, size_t Alignment) {
+    assert(getThreadIndex() < NumOfAllocators);
+    return Allocators[getThreadIndex()].Deallocate(Ptr, Size, Alignment);
+  }
+
+  /// Return allocator corresponding to the current thread.
+  AllocatorTy &getThreadLocalAllocator() {
+    assert(getThreadIndex() < NumOfAllocators);
+    return Allocators[getThreadIndex()];
+  }
+
+  // Return number of used allocators.
+  size_t getNumberOfAllocators() const { return NumOfAllocators; }
+  /// @}
+
+  /// \defgroup Methods which could not be called asynchronously:
+  ///
+  /// @{
+
+  /// Reset state of allocators.
+  void Reset() {
+    for (size_t Idx = 0; Idx < getNumberOfAllocators(); Idx++)
+      Allocators[Idx].Reset();
+  }
+
+  /// Return total memory size used by all allocators.
+  size_t getTotalMemory() const {
+    size_t TotalMemory = 0;
+
+    for (size_t Idx = 0; Idx < getNumberOfAllocators(); Idx++)
+      TotalMemory += Allocators[Idx].getTotalMemory();
+
+    return TotalMemory;
+  }
+
+  /// Return allocated size by all allocators.
+  size_t getBytesAllocated() const {
+    size_t BytesAllocated = 0;
+
+    for (size_t Idx = 0; Idx < getNumberOfAllocators(); Idx++)
+      BytesAllocated += Allocators[Idx].getBytesAllocated();
+
+    return BytesAllocated;
+  }
+
+  /// Set red zone for all allocators.
+  void setRedZoneSize(size_t NewSize) {
+    for (size_t Idx = 0; Idx < getNumberOfAllocators(); Idx++)
+      Allocators[Idx].setRedZoneSize(NewSize);
+  }
+
+  /// Print statistic for each allocator.
+  void PrintStats() const {
+    for (size_t Idx = 0; Idx < getNumberOfAllocators(); Idx++) {
+      errs() << "\n Allocator " << Idx << "\n";
+      Allocators[Idx].PrintStats();
+    }
+  }
+  /// @}
+
+protected:
+  size_t NumOfAllocators;
+  std::unique_ptr<AllocatorTy[]> Allocators;
+};
+
+using PerThreadBumpPtrAllocator = PerThreadAllocator<BumpPtrAllocator>;
+
+} // end namespace parallel
+} // end namespace llvm
+
+#endif // LLVM_SUPPORT_PERTHREADBUMPPTRALLOCATOR_H

diff --git a/llvm/lib/Support/Parallel.cpp b/llvm/lib/Support/Parallel.cpp
index f54479069cf67..9b14b05b52116 100644
--- a/llvm/lib/Support/Parallel.cpp
+++ b/llvm/lib/Support/Parallel.cpp
@@ -40,6 +40,7 @@ class Executor {
 public:
   virtual ~Executor() = default;
   virtual void add(std::function<void()> func, bool Sequential = false) = 0;
+  virtual size_t getThreadCount() const = 0;
 
   static Executor *getDefaultExecutor();
 };
@@ -49,7 +50,7 @@ class Executor {
 class ThreadPoolExecutor : public Executor {
 public:
   explicit ThreadPoolExecutor(ThreadPoolStrategy S = hardware_concurrency()) {
-    unsigned ThreadCount = S.compute_thread_count();
+    ThreadCount = S.compute_thread_count();
     // Spawn all but one of the threads in another thread as spawning threads
     // can take a while.
     Threads.reserve(ThreadCount);
@@ -58,7 +59,7 @@ class ThreadPoolExecutor : public Executor {
     // Use operator[] before creating the thread to avoid data race in .size()
     // in “safe libc++” mode.
     auto &Thread0 = Threads[0];
-    Thread0 = std::thread([this, ThreadCount, S] {
+    Thread0 = std::thread([this, S] {
       for (unsigned I = 1; I < ThreadCount; ++I) {
         Threads.emplace_back([=] { work(S, I); });
         if (Stop)
@@ -108,6 +109,8 @@ class ThreadPoolExecutor : public Executor {
     Cond.notify_one();
   }
 
+  size_t getThreadCount() const override { return ThreadCount; }
+
 private:
   bool hasSequentialTasks() const {
     return !WorkQueueSequential.empty() && !SequentialQueueIsLocked;
@@ -149,6 +152,7 @@ class ThreadPoolExecutor : public Executor {
   std::condition_variable Cond;
   std::promise<void> ThreadsCreated;
   std::vector<std::thread> Threads;
+  unsigned ThreadCount;
 };
 
 Executor *Executor::getDefaultExecutor() {
@@ -178,6 +182,10 @@ Executor *Executor::getDefaultExecutor() {
 }
 } // namespace
 } // namespace detail
+
+size_t getThreadCount() {
+  return detail::Executor::getDefaultExecutor()->getThreadCount();
+}
 #endif
 
 // Latch::sync() called by the dtor may cause one thread to block. If is a dead

diff --git a/llvm/unittests/ADT/ConcurrentHashtableTest.cpp b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp
index 895bde85ea9e7..ee1ee41f453a3 100644
--- a/llvm/unittests/ADT/ConcurrentHashtableTest.cpp
+++ b/llvm/unittests/ADT/ConcurrentHashtableTest.cpp
@@ -10,11 +10,13 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Parallel.h"
+#include "llvm/Support/PerThreadBumpPtrAllocator.h"
 #include "gtest/gtest.h"
 #include <limits>
 #include <random>
 #include <vector>
 using namespace llvm;
+using namespace parallel;
 
 namespace {
 class String {
@@ -36,186 +38,195 @@ class String {
   std::array<char, 0x20> ExtraData;
 };
 
-class SimpleAllocator : public AllocatorBase<SimpleAllocator> {
-public:
-  inline LLVM_ATTRIBUTE_RETURNS_NONNULL void *Allocate(size_t Size,
-                                                       size_t Alignment) {
-#if LLVM_ENABLE_THREADS
-    std::lock_guard<std::mutex> Guard(AllocatorMutex);
-#endif
-
-    return Allocator.Allocate(Size, Align(Alignment));
-  }
-  inline size_t getBytesAllocated() {
-#if LLVM_ENABLE_THREADS
-    std::lock_guard<std::mutex> Guard(AllocatorMutex);
-#endif
-
-    return Allocator.getBytesAllocated();
-  }
-
-  // Pull in base class overloads.
-  using AllocatorBase<SimpleAllocator>::Allocate;
-
-protected:
-#if LLVM_ENABLE_THREADS
-  std::mutex AllocatorMutex;
-#endif
-  BumpPtrAllocator Allocator;
-} Allocator;
-
 TEST(ConcurrentHashTableTest, AddStringEntries) {
-  ConcurrentHashTableByPtr<
-      std::string, String, SimpleAllocator,
-      ConcurrentHashTableInfoByPtr<std::string, String, SimpleAllocator>>
+  PerThreadBumpPtrAllocator Allocator;
+  ConcurrentHashTableByPtr<std::string, String, PerThreadBumpPtrAllocator,
+                           ConcurrentHashTableInfoByPtr<
+                               std::string, String, PerThreadBumpPtrAllocator>>
       HashTable(Allocator, 10);
 
-  size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
-  std::pair<String *, bool> res1 = HashTable.insert("1");
-  // Check entry is inserted.
-  EXPECT_TRUE(res1.first->getKey() == "1");
-  EXPECT_TRUE(res1.second);
-
-  std::pair<String *, bool> res2 = HashTable.insert("2");
-  // Check old entry is still valid.
-  EXPECT_TRUE(res1.first->getKey() == "1");
-  // Check new entry is inserted.
-  EXPECT_TRUE(res2.first->getKey() == "2");
-  EXPECT_TRUE(res2.second);
-  // Check new and old entries use different memory.
-  EXPECT_TRUE(res1.first != res2.first);
-
-  std::pair<String *, bool> res3 = HashTable.insert("3");
-  // Check one more entry is inserted.
-  EXPECT_TRUE(res3.first->getKey() == "3");
-  EXPECT_TRUE(res3.second);
-
-  std::pair<String *, bool> res4 = HashTable.insert("1");
-  // Check duplicated entry is inserted.
-  EXPECT_TRUE(res4.first->getKey() == "1");
-  EXPECT_FALSE(res4.second);
-  // Check duplicated entry uses the same memory.
-  EXPECT_TRUE(res1.first == res4.first);
-
-  // Check first entry is still valid.
-  EXPECT_TRUE(res1.first->getKey() == "1");
-
-  // Check data was allocated by allocator.
-  EXPECT_TRUE(Allocator.getBytesAllocated() > AllocatedBytesAtStart);
+  // PerThreadBumpPtrAllocator should be accessed from threads created by
+  // ThreadPoolExecutor. Use TaskGroup to run on ThreadPoolExecutor threads.
+  parallel::TaskGroup tg;
 
-  // Check statistic.
-  std::string StatisticString;
-  raw_string_ostream StatisticStream(StatisticString);
-  HashTable.printStatistic(StatisticStream);
+  tg.spawn([&]() {
+    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
+    std::pair<String *, bool> res1 = HashTable.insert("1");
+    // Check entry is inserted.
+    EXPECT_TRUE(res1.first->getKey() == "1");
+    EXPECT_TRUE(res1.second);
+
+    std::pair<String *, bool> res2 = HashTable.insert("2");
+    // Check old entry is still valid.
+    EXPECT_TRUE(res1.first->getKey() == "1");
+    // Check new entry is inserted.
+    EXPECT_TRUE(res2.first->getKey() == "2");
+    EXPECT_TRUE(res2.second);
+    // Check new and old entries use different memory.
+    EXPECT_TRUE(res1.first != res2.first);
+
+    std::pair<String *, bool> res3 = HashTable.insert("3");
+    // Check one more entry is inserted.
+    EXPECT_TRUE(res3.first->getKey() == "3");
+    EXPECT_TRUE(res3.second);
+
+    std::pair<String *, bool> res4 = HashTable.insert("1");
+    // Check duplicated entry is inserted.
+    EXPECT_TRUE(res4.first->getKey() == "1");
+    EXPECT_FALSE(res4.second);
+    // Check duplicated entry uses the same memory.
+    EXPECT_TRUE(res1.first == res4.first);
+
+    // Check first entry is still valid.
+    EXPECT_TRUE(res1.first->getKey() == "1");
+
+    // Check data was allocated by allocator.
+    EXPECT_TRUE(Allocator.getBytesAllocated() > AllocatedBytesAtStart);
 
-  EXPECT_TRUE(StatisticString.find("Overall number of entries = 3\n") !=
-              std::string::npos);
+    // Check statistic.
+    std::string StatisticString;
+    raw_string_ostream StatisticStream(StatisticString);
+    HashTable.printStatistic(StatisticStream);
+
+    EXPECT_TRUE(StatisticString.find("Overall number of entries = 3\n") !=
+                std::string::npos);
+  });
 }
 
 TEST(ConcurrentHashTableTest, AddStringMultiplueEntries) {
+  PerThreadBumpPtrAllocator Allocator;
   const size_t NumElements = 10000;
-  ConcurrentHashTableByPtr<
-      std::string, String, SimpleAllocator,
-      ConcurrentHashTableInfoByPtr<std::string, String, SimpleAllocator>>
+  ConcurrentHashTableByPtr<std::string, String, PerThreadBumpPtrAllocator,
+                           ConcurrentHashTableInfoByPtr<
+                               std::string, String, PerThreadBumpPtrAllocator>>
       HashTable(Allocator);
 
-  // Check insertion.
-  for (size_t I = 0; I < NumElements; I++) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
-    std::string StringForElement = formatv("{0}", I);
-    std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
-    EXPECT_TRUE(Entry.second);
-    EXPECT_TRUE(Entry.first->getKey() == StringForElement);
-    EXPECT_TRUE(Allocator.getBytesAllocated() > AllocatedBytesAtStart);
-  }
-
-  std::string StatisticString;
-  raw_string_ostream StatisticStream(StatisticString);
-  HashTable.printStatistic(StatisticStream);
-
-  // Verifying that the table contains exactly the number of elements we
-  // inserted.
-  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
-              std::string::npos);
-
-  // Check insertion of duplicates.
-  for (size_t I = 0; I < NumElements; I++) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
-    std::string StringForElement = formatv("{0}", I);
-    std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
-    EXPECT_FALSE(Entry.second);
-    EXPECT_TRUE(Entry.first->getKey() == StringForElement);
-    // Check no additional bytes were allocated for duplicate.
-    EXPECT_TRUE(Allocator.getBytesAllocated() == AllocatedBytesAtStart);
-  }
-
-  // Check statistic.
-  // Verifying that the table contains exactly the number of elements we
-  // inserted.
-  EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
-              std::string::npos);
+  // PerThreadBumpPtrAllocator should be accessed from threads created by
+  // ThreadPoolExecutor. Use TaskGroup to run on ThreadPoolExecutor threads.
+  parallel::TaskGroup tg;
+
+  tg.spawn([&]() {
+    // Check insertion.
+    for (size_t I = 0; I < NumElements; I++) {
+      BumpPtrAllocator &ThreadLocalAllocator =
+          Allocator.getThreadLocalAllocator();
+      size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
+      std::string StringForElement = formatv("{0}", I);
+      std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
+      EXPECT_TRUE(Entry.second);
+      EXPECT_TRUE(Entry.first->getKey() == StringForElement);
+      EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() >
+                  AllocatedBytesAtStart);
+    }
+
+    std::string StatisticString;
+    raw_string_ostream StatisticStream(StatisticString);
+    HashTable.printStatistic(StatisticStream);
+
+    // Verifying that the table contains exactly the number of elements we
+    // inserted.
+    EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
+                std::string::npos);
+
+    // Check insertion of duplicates.
+    for (size_t I = 0; I < NumElements; I++) {
+      BumpPtrAllocator &ThreadLocalAllocator =
+          Allocator.getThreadLocalAllocator();
+      size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
+      std::string StringForElement = formatv("{0}", I);
+      std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
+      EXPECT_FALSE(Entry.second);
+      EXPECT_TRUE(Entry.first->getKey() == StringForElement);
+      // Check no additional bytes were allocated for duplicate.
+      EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() ==
+                  AllocatedBytesAtStart);
+    }
+
+    // Check statistic.
+    // Verifying that the table contains exactly the number of elements we
+    // inserted.
+    EXPECT_TRUE(StatisticString.find("Overall number of entries = 10000\n") !=
+                std::string::npos);
+  });
 }
 
 TEST(ConcurrentHashTableTest, AddStringMultiplueEntriesWithResize) {
+  PerThreadBumpPtrAllocator Allocator;
   // Number of elements exceeds original size, thus hashtable should be resized.
   const size_t NumElements = 20000;
-  ConcurrentHashTableByPtr<
-      std::string, String, SimpleAllocator,
-      ConcurrentHashTableInfoByPtr<std::string, String, SimpleAllocator>>
+  ConcurrentHashTableByPtr<std::string, String, PerThreadBumpPtrAllocator,
+                           ConcurrentHashTableInfoByPtr<
+                               std::string, String, PerThreadBumpPtrAllocator>>
       HashTable(Allocator, 100);
 
-  // Check insertion.
-  for (size_t I = 0; I < NumElements; I++) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
-    std::string StringForElement = formatv("{0} {1}", I, I + 100);
-    std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
-    EXPECT_TRUE(Entry.second);
-    EXPECT_TRUE(Entry.first->getKey() == StringForElement);
-    EXPECT_TRUE(Allocator.getBytesAllocated() > AllocatedBytesAtStart);
-  }
-
-  std::string StatisticString;
-  raw_string_ostream StatisticStream(StatisticString);
-  HashTable.printStatistic(StatisticStream);
-
-  // Verifying that the table contains exactly the number of elements we
-  // inserted.
-  EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
-              std::string::npos);
-
-  // Check insertion of duplicates.
-  for (size_t I = 0; I < NumElements; I++) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
-    std::string StringForElement = formatv("{0} {1}", I, I + 100);
-    std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
-    EXPECT_FALSE(Entry.second);
-    EXPECT_TRUE(Entry.first->getKey() == StringForElement);
-    // Check no additional bytes were allocated for duplicate.
-    EXPECT_TRUE(Allocator.getBytesAllocated() == AllocatedBytesAtStart);
-  }
-
-  // Check statistic.
-  // Verifying that the table contains exactly the number of elements we
-  // inserted.
-  EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
-              std::string::npos);
+  // PerThreadBumpPtrAllocator should be accessed from threads created by
+  // ThreadPoolExecutor. Use TaskGroup to run on ThreadPoolExecutor threads.
+  parallel::TaskGroup tg;
+
+  tg.spawn([&]() {
+    // Check insertion.
+    for (size_t I = 0; I < NumElements; I++) {
+      BumpPtrAllocator &ThreadLocalAllocator =
+          Allocator.getThreadLocalAllocator();
+      size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
+      std::string StringForElement = formatv("{0} {1}", I, I + 100);
+      std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
+      EXPECT_TRUE(Entry.second);
+      EXPECT_TRUE(Entry.first->getKey() == StringForElement);
+      EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() >
+                  AllocatedBytesAtStart);
+    }
+
+    std::string StatisticString;
+    raw_string_ostream StatisticStream(StatisticString);
+    HashTable.printStatistic(StatisticStream);
+
+    // Verifying that the table contains exactly the number of elements we
+    // inserted.
+    EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
+                std::string::npos);
+
+    // Check insertion of duplicates.
+    for (size_t I = 0; I < NumElements; I++) {
+      BumpPtrAllocator &ThreadLocalAllocator =
+          Allocator.getThreadLocalAllocator();
+      size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
+      std::string StringForElement = formatv("{0} {1}", I, I + 100);
+      std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
+      EXPECT_FALSE(Entry.second);
+      EXPECT_TRUE(Entry.first->getKey() == StringForElement);
+      // Check no additional bytes were allocated for duplicate.
+      EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() ==
+                  AllocatedBytesAtStart);
+    }
+
+    // Check statistic.
+    // Verifying that the table contains exactly the number of elements we
+    // inserted.
+    EXPECT_TRUE(StatisticString.find("Overall number of entries = 20000\n") !=
+                std::string::npos);
+  });
 }
 
 TEST(ConcurrentHashTableTest, AddStringEntriesParallel) {
+  PerThreadBumpPtrAllocator Allocator;
   const size_t NumElements = 10000;
-  ConcurrentHashTableByPtr<
-      std::string, String, SimpleAllocator,
-      ConcurrentHashTableInfoByPtr<std::string, String, SimpleAllocator>>
+  ConcurrentHashTableByPtr<std::string, String, PerThreadBumpPtrAllocator,
+                           ConcurrentHashTableInfoByPtr<
+                               std::string, String, PerThreadBumpPtrAllocator>>
       HashTable(Allocator);
 
   // Check parallel insertion.
   parallelFor(0, NumElements, [&](size_t I) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
+    BumpPtrAllocator &ThreadLocalAllocator =
+        Allocator.getThreadLocalAllocator();
+    size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
     std::string StringForElement = formatv("{0}", I);
     std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
     EXPECT_TRUE(Entry.second);
     EXPECT_TRUE(Entry.first->getKey() == StringForElement);
-    EXPECT_TRUE(Allocator.getBytesAllocated() > AllocatedBytesAtStart);
+    EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() >
+                AllocatedBytesAtStart);
   });
 
   std::string StatisticString;
@@ -229,13 +240,16 @@ TEST(ConcurrentHashTableTest, AddStringEntriesParallel) {
 
   // Check parallel insertion of duplicates.
   parallelFor(0, NumElements, [&](size_t I) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
+    BumpPtrAllocator &ThreadLocalAllocator =
+        Allocator.getThreadLocalAllocator();
+    size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
     std::string StringForElement = formatv("{0}", I);
     std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
     EXPECT_FALSE(Entry.second);
     EXPECT_TRUE(Entry.first->getKey() == StringForElement);
     // Check no additional bytes were allocated for duplicate.
-    EXPECT_TRUE(Allocator.getBytesAllocated() == AllocatedBytesAtStart);
+    EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() ==
+                AllocatedBytesAtStart);
   });
 
   // Check statistic.
@@ -246,20 +260,24 @@ TEST(ConcurrentHashTableTest, AddStringEntriesParallel) {
 }
 
 TEST(ConcurrentHashTableTest, AddStringEntriesParallelWithResize) {
+  PerThreadBumpPtrAllocator Allocator;
   const size_t NumElements = 20000;
-  ConcurrentHashTableByPtr<
-      std::string, String, SimpleAllocator,
-      ConcurrentHashTableInfoByPtr<std::string, String, SimpleAllocator>>
+  ConcurrentHashTableByPtr<std::string, String, PerThreadBumpPtrAllocator,
+                           ConcurrentHashTableInfoByPtr<
+                               std::string, String, PerThreadBumpPtrAllocator>>
       HashTable(Allocator, 100);
 
   // Check parallel insertion.
   parallelFor(0, NumElements, [&](size_t I) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
+    BumpPtrAllocator &ThreadLocalAllocator =
+        Allocator.getThreadLocalAllocator();
+    size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
     std::string StringForElement = formatv("{0}", I);
     std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
     EXPECT_TRUE(Entry.second);
     EXPECT_TRUE(Entry.first->getKey() == StringForElement);
-    EXPECT_TRUE(Allocator.getBytesAllocated() > AllocatedBytesAtStart);
+    EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() >
+                AllocatedBytesAtStart);
   });
 
   std::string StatisticString;
@@ -273,13 +291,16 @@ TEST(ConcurrentHashTableTest, AddStringEntriesParallelWithResize) {
 
   // Check parallel insertion of duplicates.
   parallelFor(0, NumElements, [&](size_t I) {
-    size_t AllocatedBytesAtStart = Allocator.getBytesAllocated();
+    BumpPtrAllocator &ThreadLocalAllocator =
+        Allocator.getThreadLocalAllocator();
+    size_t AllocatedBytesAtStart = ThreadLocalAllocator.getBytesAllocated();
     std::string StringForElement = formatv("{0}", I);
     std::pair<String *, bool> Entry = HashTable.insert(StringForElement);
     EXPECT_FALSE(Entry.second);
     EXPECT_TRUE(Entry.first->getKey() == StringForElement);
     // Check no additional bytes were allocated for duplicate.
-    EXPECT_TRUE(Allocator.getBytesAllocated() == AllocatedBytesAtStart);
+    EXPECT_TRUE(ThreadLocalAllocator.getBytesAllocated() ==
+                AllocatedBytesAtStart);
   });
 
   // Check statistic.

diff --git a/llvm/unittests/DWARFLinkerParallel/StringPoolTest.cpp b/llvm/unittests/DWARFLinkerParallel/StringPoolTest.cpp
index 2e36cdecd8b30..11ae18c9e8621 100644
--- a/llvm/unittests/DWARFLinkerParallel/StringPoolTest.cpp
+++ b/llvm/unittests/DWARFLinkerParallel/StringPoolTest.cpp
@@ -19,24 +19,31 @@ namespace {
 TEST(StringPoolTest, TestStringPool) {
   StringPool Strings;
 
-  std::pair<StringEntry *, bool> Entry = Strings.insert("test");
-  EXPECT_TRUE(Entry.second);
-  EXPECT_TRUE(Entry.first->getKey() == "test");
-  EXPECT_TRUE(Entry.first->second == nullptr);
-
-  StringEntry *EntryPtr = Entry.first;
-
-  Entry = Strings.insert("test");
-  EXPECT_FALSE(Entry.second);
-  EXPECT_TRUE(Entry.first->getKey() == "test");
-  EXPECT_TRUE(Entry.first->second == nullptr);
-  EXPECT_TRUE(EntryPtr == Entry.first);
-
-  Entry = Strings.insert("test2");
-  EXPECT_TRUE(Entry.second);
-  EXPECT_TRUE(Entry.first->getKey() == "test2");
-  EXPECT_TRUE(Entry.first->second == nullptr);
-  EXPECT_TRUE(EntryPtr != Entry.first);
+  // StringPool uses PerThreadBumpPtrAllocator which should be accessed from
+  // threads created by ThreadPoolExecutor. Use TaskGroup to run on
+  // ThreadPoolExecutor threads.
+  parallel::TaskGroup tg;
+
+  tg.spawn([&]() {
+    std::pair<StringEntry *, bool> Entry = Strings.insert("test");
+    EXPECT_TRUE(Entry.second);
+    EXPECT_TRUE(Entry.first->getKey() == "test");
+    EXPECT_TRUE(Entry.first->second == nullptr);
+
+    StringEntry *EntryPtr = Entry.first;
+
+    Entry = Strings.insert("test");
+    EXPECT_FALSE(Entry.second);
+    EXPECT_TRUE(Entry.first->getKey() == "test");
+    EXPECT_TRUE(Entry.first->second == nullptr);
+    EXPECT_TRUE(EntryPtr == Entry.first);
+
+    Entry = Strings.insert("test2");
+    EXPECT_TRUE(Entry.second);
+    EXPECT_TRUE(Entry.first->getKey() == "test2");
+    EXPECT_TRUE(Entry.first->second == nullptr);
+    EXPECT_TRUE(EntryPtr != Entry.first);
+  });
 }
 
 TEST(StringPoolTest, TestStringPoolParallel) {

diff --git a/llvm/unittests/DWARFLinkerParallel/StringTableTest.cpp b/llvm/unittests/DWARFLinkerParallel/StringTableTest.cpp
index 40795f5d38ae6..02bb95001acd4 100644
--- a/llvm/unittests/DWARFLinkerParallel/StringTableTest.cpp
+++ b/llvm/unittests/DWARFLinkerParallel/StringTableTest.cpp
@@ -29,41 +29,48 @@ TEST(StringPoolTest, TestStringTable) {
   StringPool Strings;
   StringTable OutStrings(Strings, nullptr);
 
-  // Check string insertion.
-  StringEntry *FirstPtr = Strings.insert(InputStrings[0].Str).first;
-  StringEntry *SecondPtr = Strings.insert(InputStrings[1].Str).first;
-  StringEntry *ThirdPtr = Strings.insert(InputStrings[2].Str).first;
-
-  FirstPtr = OutStrings.add(FirstPtr);
-  SecondPtr = OutStrings.add(SecondPtr);
-  ThirdPtr = OutStrings.add(ThirdPtr);
-
-  // Check fields of inserted strings.
-  EXPECT_TRUE(FirstPtr->getKey() == InputStrings[0].Str);
-  EXPECT_TRUE(FirstPtr->getValue()->Offset == InputStrings[0].Offset);
-  EXPECT_TRUE(FirstPtr->getValue()->Index == InputStrings[0].Idx);
-
-  EXPECT_TRUE(SecondPtr->getKey() == InputStrings[1].Str);
-  EXPECT_TRUE(SecondPtr->getValue()->Offset == InputStrings[1].Offset);
-  EXPECT_TRUE(SecondPtr->getValue()->Index == InputStrings[1].Idx);
-
-  EXPECT_TRUE(ThirdPtr->getKey() == InputStrings[2].Str);
-  EXPECT_TRUE(ThirdPtr->getValue()->Offset == InputStrings[2].Offset);
-  EXPECT_TRUE(ThirdPtr->getValue()->Index == InputStrings[2].Idx);
-
-  // Check order enumerated strings.
-  uint64_t CurIdx = 0;
-  std::function<void(DwarfStringPoolEntryRef)> checkStr =
-      [&](DwarfStringPoolEntryRef Entry) {
-        EXPECT_TRUE(Entry.getEntry().isIndexed());
-        EXPECT_TRUE(Entry.getIndex() == CurIdx);
-        EXPECT_TRUE(Entry.getOffset() == InputStrings[CurIdx].Offset);
-        EXPECT_TRUE(Entry.getString() == InputStrings[CurIdx].Str);
-
-        CurIdx++;
-      };
-
-  OutStrings.forEach(checkStr);
+  // StringPool uses PerThreadBumpPtrAllocator which should be accessed from
+  // threads created by ThreadPoolExecutor. Use TaskGroup to run on
+  // ThreadPoolExecutor threads.
+  parallel::TaskGroup tg;
+
+  tg.spawn([&]() {
+    // Check string insertion.
+    StringEntry *FirstPtr = Strings.insert(InputStrings[0].Str).first;
+    StringEntry *SecondPtr = Strings.insert(InputStrings[1].Str).first;
+    StringEntry *ThirdPtr = Strings.insert(InputStrings[2].Str).first;
+
+    FirstPtr = OutStrings.add(FirstPtr);
+    SecondPtr = OutStrings.add(SecondPtr);
+    ThirdPtr = OutStrings.add(ThirdPtr);
+
+    // Check fields of inserted strings.
+    EXPECT_TRUE(FirstPtr->getKey() == InputStrings[0].Str);
+    EXPECT_TRUE(FirstPtr->getValue()->Offset == InputStrings[0].Offset);
+    EXPECT_TRUE(FirstPtr->getValue()->Index == InputStrings[0].Idx);
+
+    EXPECT_TRUE(SecondPtr->getKey() == InputStrings[1].Str);
+    EXPECT_TRUE(SecondPtr->getValue()->Offset == InputStrings[1].Offset);
+    EXPECT_TRUE(SecondPtr->getValue()->Index == InputStrings[1].Idx);
+
+    EXPECT_TRUE(ThirdPtr->getKey() == InputStrings[2].Str);
+    EXPECT_TRUE(ThirdPtr->getValue()->Offset == InputStrings[2].Offset);
+    EXPECT_TRUE(ThirdPtr->getValue()->Index == InputStrings[2].Idx);
+
+    // Check the order of the enumerated strings.
+    uint64_t CurIdx = 0;
+    std::function<void(DwarfStringPoolEntryRef)> checkStr =
+        [&](DwarfStringPoolEntryRef Entry) {
+          EXPECT_TRUE(Entry.getEntry().isIndexed());
+          EXPECT_TRUE(Entry.getIndex() == CurIdx);
+          EXPECT_TRUE(Entry.getOffset() == InputStrings[CurIdx].Offset);
+          EXPECT_TRUE(Entry.getString() == InputStrings[CurIdx].Str);
+
+          CurIdx++;
+        };
+
+    OutStrings.forEach(checkStr);
+  });
 }
 
 TEST(StringPoolTest, TestStringTableWithTranslator) {
@@ -80,25 +87,32 @@ TEST(StringPoolTest, TestStringTableWithTranslator) {
   StringPool Strings;
   StringTable OutStrings(Strings, TranslatorFunc);
 
-  StringEntry *FirstPtr = Strings.insert("first").first;
-  StringEntry *SecondPtr = Strings.insert("second").first;
-  StringEntry *ThirdPtr = Strings.insert("third").first;
+  // StringPool uses PerThreadBumpPtrAllocator which should be accessed from
+  // threads created by ThreadPoolExecutor. Use TaskGroup to run on
+  // ThreadPoolExecutor threads.
+  parallel::TaskGroup tg;
 
-  FirstPtr = OutStrings.add(FirstPtr);
-  SecondPtr = OutStrings.add(SecondPtr);
-  ThirdPtr = OutStrings.add(ThirdPtr);
+  tg.spawn([&]() {
+    StringEntry *FirstPtr = Strings.insert("first").first;
+    StringEntry *SecondPtr = Strings.insert("second").first;
+    StringEntry *ThirdPtr = Strings.insert("third").first;
 
-  EXPECT_TRUE(FirstPtr->getKey() == "tsrif0");
-  EXPECT_TRUE(FirstPtr->getValue()->Offset == 0);
-  EXPECT_TRUE(FirstPtr->getValue()->Index == 0);
+    FirstPtr = OutStrings.add(FirstPtr);
+    SecondPtr = OutStrings.add(SecondPtr);
+    ThirdPtr = OutStrings.add(ThirdPtr);
 
-  EXPECT_TRUE(SecondPtr->getKey() == "dnoces0");
-  EXPECT_TRUE(SecondPtr->getValue()->Offset == 7);
-  EXPECT_TRUE(SecondPtr->getValue()->Index == 1);
+    EXPECT_TRUE(FirstPtr->getKey() == "tsrif0");
+    EXPECT_TRUE(FirstPtr->getValue()->Offset == 0);
+    EXPECT_TRUE(FirstPtr->getValue()->Index == 0);
 
-  EXPECT_TRUE(ThirdPtr->getKey() == "driht0");
-  EXPECT_TRUE(ThirdPtr->getValue()->Offset == 15);
-  EXPECT_TRUE(ThirdPtr->getValue()->Index == 2);
+    EXPECT_TRUE(SecondPtr->getKey() == "dnoces0");
+    EXPECT_TRUE(SecondPtr->getValue()->Offset == 7);
+    EXPECT_TRUE(SecondPtr->getValue()->Index == 1);
+
+    EXPECT_TRUE(ThirdPtr->getKey() == "driht0");
+    EXPECT_TRUE(ThirdPtr->getValue()->Offset == 15);
+    EXPECT_TRUE(ThirdPtr->getValue()->Index == 2);
+  });
 }
 
 } // anonymous namespace

diff  --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt
index d6e4fa5869379..d1185827db202 100644
--- a/llvm/unittests/Support/CMakeLists.txt
+++ b/llvm/unittests/Support/CMakeLists.txt
@@ -63,6 +63,7 @@ add_llvm_unittest(SupportTests
   OptimizedStructLayoutTest.cpp
   ParallelTest.cpp
   Path.cpp
+  PerThreadBumpPtrAllocatorTest.cpp
   ProcessTest.cpp
   ProgramTest.cpp
   RegexTest.cpp

diff  --git a/llvm/unittests/Support/PerThreadBumpPtrAllocatorTest.cpp b/llvm/unittests/Support/PerThreadBumpPtrAllocatorTest.cpp
new file mode 100644
index 0000000000000..d30de997f0fd1
--- /dev/null
+++ b/llvm/unittests/Support/PerThreadBumpPtrAllocatorTest.cpp
@@ -0,0 +1,56 @@
+//===- PerThreadBumpPtrAllocatorTest.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/PerThreadBumpPtrAllocator.h"
+#include "llvm/Support/Parallel.h"
+#include "gtest/gtest.h"
+#include <cstdlib>
+
+using namespace llvm;
+using namespace parallel;
+
+namespace {
+
+TEST(PerThreadBumpPtrAllocatorTest, Simple) {
+  PerThreadBumpPtrAllocator Allocator;
+
+  parallel::TaskGroup tg;
+
+  tg.spawn([&]() {
+    uint64_t *Var =
+        (uint64_t *)Allocator.Allocate(sizeof(uint64_t), alignof(uint64_t));
+    *Var = 0xFE;
+    EXPECT_EQ(0xFEul, *Var);
+    EXPECT_EQ(sizeof(uint64_t), Allocator.getBytesAllocated());
+    EXPECT_TRUE(Allocator.getBytesAllocated() <= Allocator.getTotalMemory());
+
+    PerThreadBumpPtrAllocator Allocator2(std::move(Allocator));
+
+    EXPECT_EQ(sizeof(uint64_t), Allocator2.getBytesAllocated());
+    EXPECT_TRUE(Allocator2.getBytesAllocated() <= Allocator2.getTotalMemory());
+
+    EXPECT_EQ(0xFEul, *Var);
+  });
+}
+
+TEST(PerThreadBumpPtrAllocatorTest, ParallelAllocation) {
+  PerThreadBumpPtrAllocator Allocator;
+
+  static size_t constexpr NumAllocations = 1000;
+
+  parallelFor(0, NumAllocations, [&](size_t Idx) {
+    uint64_t *ptr =
+        (uint64_t *)Allocator.Allocate(sizeof(uint64_t), alignof(uint64_t));
+    *ptr = Idx;
+  });
+
+  EXPECT_EQ(sizeof(uint64_t) * NumAllocations, Allocator.getBytesAllocated());
+  EXPECT_EQ(Allocator.getNumberOfAllocators(), parallel::getThreadCount());
+}
+
+} // anonymous namespace


        


More information about the llvm-commits mailing list