[lldb] [llvm] [mlir] Split the llvm::ThreadPool into an abstract base class and an implementation (PR #82094)

Mehdi Amini via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 2 15:30:46 PST 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/82094

>From 6864303679435f51ce899e348e49bfd11eb4146f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 16 Feb 2024 21:55:57 -0800
Subject: [PATCH] Split the llvm::ThreadPool into an abstract base class and an
 implementation

This decouples the public API used to enqueue tasks and wait for completion
from the actual implementation, and makes it possible for clients to provide
their own thread pool implementation.

https://discourse.llvm.org/t/construct-threadpool-from-vector-of-existing-threads/76883
---
 bolt/include/bolt/Core/ParallelUtilities.h    |   3 +-
 lldb/include/lldb/Core/Debugger.h             |   6 +-
 llvm/include/llvm/Debuginfod/Debuginfod.h     |   6 +-
 .../llvm/Support/BalancedPartitioning.h       |   7 +-
 llvm/include/llvm/Support/ThreadPool.h        | 184 +++++++++++-------
 llvm/lib/Debuginfod/Debuginfod.cpp            |   3 +-
 llvm/lib/Support/ThreadPool.cpp               |  41 ++--
 llvm/tools/llvm-cov/CoverageReport.h          |   4 +-
 llvm/tools/llvm-cov/SourceCoverageViewHTML.h  |   2 -
 llvm/unittests/Support/ThreadPool.cpp         | 112 +++++++----
 mlir/include/mlir/CAPI/Support.h              |   4 +-
 mlir/include/mlir/IR/MLIRContext.h            |   6 +-
 mlir/include/mlir/IR/Threading.h              |   2 +-
 mlir/lib/CAPI/IR/IR.cpp                       |   1 +
 mlir/lib/IR/MLIRContext.cpp                   |   8 +-
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |   4 +-
 16 files changed, 227 insertions(+), 166 deletions(-)

diff --git a/bolt/include/bolt/Core/ParallelUtilities.h b/bolt/include/bolt/Core/ParallelUtilities.h
index 7d3af47757bce6..e510525bc51d00 100644
--- a/bolt/include/bolt/Core/ParallelUtilities.h
+++ b/bolt/include/bolt/Core/ParallelUtilities.h
@@ -18,6 +18,7 @@
 
 #include "bolt/Core/MCPlusBuilder.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ThreadPool.h"
 
 using namespace llvm;
 
@@ -28,8 +29,6 @@ extern cl::opt<unsigned> TaskCount;
 } // namespace opts
 
 namespace llvm {
-class ThreadPool;
-
 namespace bolt {
 class BinaryContext;
 class BinaryFunction;
diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h
index b65ec1029ab24b..418c2403d020f4 100644
--- a/lldb/include/lldb/Core/Debugger.h
+++ b/lldb/include/lldb/Core/Debugger.h
@@ -52,7 +52,7 @@
 
 namespace llvm {
 class raw_ostream;
-class ThreadPool;
+class ThreadPoolInterface;
 } // namespace llvm
 
 namespace lldb_private {
@@ -499,8 +499,8 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
     return m_broadcaster_manager_sp;
   }
 
-  /// Shared thread poll. Use only with ThreadPoolTaskGroup.
-  static llvm::ThreadPool &GetThreadPool();
+  /// Shared thread pool. Use only with ThreadPoolTaskGroup.
+  static llvm::ThreadPoolInterface &GetThreadPool();
 
   /// Report warning events.
   ///
diff --git a/llvm/include/llvm/Debuginfod/Debuginfod.h b/llvm/include/llvm/Debuginfod/Debuginfod.h
index ef03948a706c06..99fe15ad859794 100644
--- a/llvm/include/llvm/Debuginfod/Debuginfod.h
+++ b/llvm/include/llvm/Debuginfod/Debuginfod.h
@@ -97,7 +97,7 @@ Expected<std::string> getCachedOrDownloadArtifact(
     StringRef UniqueKey, StringRef UrlPath, StringRef CacheDirectoryPath,
     ArrayRef<StringRef> DebuginfodUrls, std::chrono::milliseconds Timeout);
 
-class ThreadPool;
+class ThreadPoolInterface;
 
 struct DebuginfodLogEntry {
   std::string Message;
@@ -135,7 +135,7 @@ class DebuginfodCollection {
   // error.
   Expected<bool> updateIfStale();
   DebuginfodLog &Log;
-  ThreadPool &Pool;
+  ThreadPoolInterface &Pool;
   Timer UpdateTimer;
   sys::Mutex UpdateMutex;
 
@@ -145,7 +145,7 @@ class DebuginfodCollection {
 
 public:
   DebuginfodCollection(ArrayRef<StringRef> Paths, DebuginfodLog &Log,
-                       ThreadPool &Pool, double MinInterval);
+                       ThreadPoolInterface &Pool, double MinInterval);
   Error update();
   Error updateForever(std::chrono::milliseconds Interval);
   Expected<std::string> findDebugBinaryPath(object::BuildIDRef);
diff --git a/llvm/include/llvm/Support/BalancedPartitioning.h b/llvm/include/llvm/Support/BalancedPartitioning.h
index a8464ac0fe60e5..9738e742f7f1e9 100644
--- a/llvm/include/llvm/Support/BalancedPartitioning.h
+++ b/llvm/include/llvm/Support/BalancedPartitioning.h
@@ -50,7 +50,7 @@
 
 namespace llvm {
 
-class ThreadPool;
+class ThreadPoolInterface;
 /// A function with a set of utility nodes where it is beneficial to order two
 /// functions close together if they have similar utility nodes
 class BPFunctionNode {
@@ -115,7 +115,7 @@ class BalancedPartitioning {
   /// threads, so we need to track how many active threads that could spawn more
   /// threads.
   struct BPThreadPool {
-    ThreadPool &TheThreadPool;
+    ThreadPoolInterface &TheThreadPool;
     std::mutex mtx;
     std::condition_variable cv;
     /// The number of threads that could spawn more threads
@@ -128,7 +128,8 @@ class BalancedPartitioning {
     /// acceptable for other threads to add more tasks while blocking on this
     /// call.
     void wait();
-    BPThreadPool(ThreadPool &TheThreadPool) : TheThreadPool(TheThreadPool) {}
+    BPThreadPool(ThreadPoolInterface &TheThreadPool)
+        : TheThreadPool(TheThreadPool) {}
   };
 
   /// Run a recursive bisection of a given list of FunctionNodes
diff --git a/llvm/include/llvm/Support/ThreadPool.h b/llvm/include/llvm/Support/ThreadPool.h
index 03ebd35aa46dc4..93f02729f047aa 100644
--- a/llvm/include/llvm/Support/ThreadPool.h
+++ b/llvm/include/llvm/Support/ThreadPool.h
@@ -32,11 +32,8 @@ namespace llvm {
 
 class ThreadPoolTaskGroup;
 
-/// A ThreadPool for asynchronous parallel execution on a defined number of
-/// threads.
-///
-/// The pool keeps a vector of threads alive, waiting on a condition variable
-/// for some work to become available.
+/// This defines the abstract base interface for a ThreadPool allowing
+/// asynchronous parallel execution on a defined number of threads.
 ///
 /// It is possible to reuse one thread pool for different groups of tasks
 /// by grouping tasks using ThreadPoolTaskGroup. All tasks are processed using
@@ -49,16 +46,31 @@ class ThreadPoolTaskGroup;
 /// available threads are used up by tasks waiting for a task that has no thread
 /// left to run on (this includes waiting on the returned future). It should be
 /// generally safe to wait() for a group as long as groups do not form a cycle.
-class ThreadPool {
+class ThreadPoolInterface {
+  /// The actual method to enqueue a task to be defined by the concrete
+  /// implementation.
+  virtual void asyncEnqueue(std::function<void()> Task,
+                            ThreadPoolTaskGroup *Group) = 0;
+
 public:
-  /// Construct a pool using the hardware strategy \p S for mapping hardware
-  /// execution resources (threads, cores, CPUs)
-  /// Defaults to using the maximum execution resources in the system, but
-  /// accounting for the affinity mask.
-  ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
+  /// Destroying the pool will drain the pending tasks and wait. The current
+  /// thread may participate in the execution of the pending tasks.
+  virtual ~ThreadPoolInterface();
 
-  /// Blocking destructor: the pool will wait for all the threads to complete.
-  ~ThreadPool();
+  /// Blocking wait for all the threads to complete and the queue to be empty.
+  /// It is an error to try to add new tasks while blocking on this call.
+  /// Calling wait() from a task would deadlock waiting for itself.
+  virtual void wait() = 0;
+
+  /// Blocking wait for only all the threads in the given group to complete.
+  /// It is possible to wait even inside a task, but waiting (directly or
+  /// indirectly) on itself will deadlock. If called from a task running on a
+  /// worker thread, the call may process pending tasks while waiting in order
+  /// not to waste the thread.
+  virtual void wait(ThreadPoolTaskGroup &Group) = 0;
+
+  /// Returns the maximum number of workers this pool can eventually grow to.
+  virtual unsigned getMaxConcurrency() const = 0;
 
   /// Asynchronous submission of a task to the pool. The returned future can be
   /// used to wait for the task to finish and is *non-blocking* on destruction.
@@ -92,23 +104,51 @@ class ThreadPool {
                      &Group);
   }
 
+private:
+  /// Wrap \p Task in a deferred std::async and enqueue a task that waits on
+  /// it, so \p Task runs on the first thread to wait (pool or caller).
+  template <typename ResTy>
+  std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
+                                      ThreadPoolTaskGroup *Group) {
+    auto Future = std::async(std::launch::deferred, std::move(Task)).share();
+    asyncEnqueue([Future]() { Future.wait(); }, Group);
+    return Future;
+  }
+};
+
+#if LLVM_ENABLE_THREADS
+/// A ThreadPool implementation using std::threads.
+///
+/// The pool keeps a vector of threads alive, waiting on a condition variable
+/// for some work to become available.
+class StdThreadPool : public ThreadPoolInterface {
+public:
+  /// Construct a pool using the hardware strategy \p S for mapping hardware
+  /// execution resources (threads, cores, CPUs)
+  /// Defaults to using the maximum execution resources in the system, but
+  /// accounting for the affinity mask.
+  StdThreadPool(ThreadPoolStrategy S = hardware_concurrency());
+
+  /// Blocking destructor: the pool will wait for all the threads to complete.
+  ~StdThreadPool() override;
+
   /// Blocking wait for all the threads to complete and the queue to be empty.
   /// It is an error to try to add new tasks while blocking on this call.
   /// Calling wait() from a task would deadlock waiting for itself.
-  void wait();
+  void wait() override;
 
   /// Blocking wait for only all the threads in the given group to complete.
   /// It is possible to wait even inside a task, but waiting (directly or
   /// indirectly) on itself will deadlock. If called from a task running on a
   /// worker thread, the call may process pending tasks while waiting in order
   /// not to waste the thread.
-  void wait(ThreadPoolTaskGroup &Group);
+  void wait(ThreadPoolTaskGroup &Group) override;
 
-  // Returns the maximum number of worker threads in the pool, not the current
-  // number of threads!
-  unsigned getMaxConcurrency() const { return MaxThreadCount; }
+  /// Returns the maximum number of worker threads in the pool, not the current
+  /// number of threads!
+  unsigned getMaxConcurrency() const override { return MaxThreadCount; }
 
-  // TODO: misleading legacy name warning!
+  // TODO: Remove, misleading legacy name warning!
   LLVM_DEPRECATED("Use getMaxConcurrency instead", "getMaxConcurrency")
   unsigned getThreadCount() const { return MaxThreadCount; }
 
@@ -116,46 +156,14 @@ class ThreadPool {
   bool isWorkerThread() const;
 
 private:
-  /// Helpers to create a promise and a callable wrapper of \p Task that sets
-  /// the result of the promise. Returns the callable and a future to access the
-  /// result.
-  template <typename ResTy>
-  static std::pair<std::function<void()>, std::future<ResTy>>
-  createTaskAndFuture(std::function<ResTy()> Task) {
-    std::shared_ptr<std::promise<ResTy>> Promise =
-        std::make_shared<std::promise<ResTy>>();
-    auto F = Promise->get_future();
-    return {
-        [Promise = std::move(Promise), Task]() { Promise->set_value(Task()); },
-        std::move(F)};
-  }
-  static std::pair<std::function<void()>, std::future<void>>
-  createTaskAndFuture(std::function<void()> Task) {
-    std::shared_ptr<std::promise<void>> Promise =
-        std::make_shared<std::promise<void>>();
-    auto F = Promise->get_future();
-    return {[Promise = std::move(Promise), Task]() {
-              Task();
-              Promise->set_value();
-            },
-            std::move(F)};
-  }
-
   /// Returns true if all tasks in the given group have finished (nullptr means
   /// all tasks regardless of their group). QueueLock must be locked.
   bool workCompletedUnlocked(ThreadPoolTaskGroup *Group) const;
 
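-  /// Asynchronous submission of a task to the pool. The returned future can be
-  /// used to wait for the task to finish and is *non-blocking* on destruction.
+  /// Push \p Task on the queue, wake one worker and grow the pool if needed.
+  /// The future returned to clients is created by ThreadPoolInterface.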
-  template <typename ResTy>
-  std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
-                                      ThreadPoolTaskGroup *Group) {
-
-#if LLVM_ENABLE_THREADS
-    /// Wrap the Task in a std::function<void()> that sets the result of the
-    /// corresponding future.
-    auto R = createTaskAndFuture(Task);
-
+  void asyncEnqueue(std::function<void()> Task,
+                    ThreadPoolTaskGroup *Group) override {
     int requestedThreads;
     {
       // Lock the queue and push the new task
@@ -163,31 +171,18 @@ class ThreadPool {
 
       // Don't allow enqueueing after disabling the pool
       assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
-      Tasks.emplace_back(std::make_pair(std::move(R.first), Group));
+      Tasks.emplace_back(std::make_pair(std::move(Task), Group));
       requestedThreads = ActiveThreads + Tasks.size();
     }
     QueueCondition.notify_one();
     grow(requestedThreads);
-    return R.second.share();
-
-#else // LLVM_ENABLE_THREADS Disabled
-
-    // Get a Future with launch::deferred execution using std::async
-    auto Future = std::async(std::launch::deferred, std::move(Task)).share();
-    // Wrap the future so that both ThreadPool::wait() can operate and the
-    // returned future can be sync'ed on.
-    Tasks.emplace_back(std::make_pair([Future]() { Future.get(); }, Group));
-    return Future;
-#endif
   }
 
-#if LLVM_ENABLE_THREADS
-  // Grow to ensure that we have at least `requested` Threads, but do not go
-  // over MaxThreadCount.
+  /// Grow to ensure that we have at least `requested` Threads, but do not go
+  /// over MaxThreadCount.
   void grow(int requested);
 
   void processTasks(ThreadPoolTaskGroup *WaitingForGroup);
-#endif
 
   /// Threads in flight
   std::vector<llvm::thread> Threads;
@@ -209,10 +204,8 @@ class ThreadPool {
   /// Number of threads active for tasks in the given group (only non-zero).
   DenseMap<ThreadPoolTaskGroup *, unsigned> ActiveGroups;
 
-#if LLVM_ENABLE_THREADS // avoids warning for unused variable
   /// Signal for the destruction of the pool, asking thread to exit.
   bool EnableFlag = true;
-#endif
 
   const ThreadPoolStrategy Strategy;
 
@@ -220,6 +213,51 @@ class ThreadPool {
   const unsigned MaxThreadCount;
 };
 
+#endif // LLVM_ENABLE_THREADS
+
+/// A non-threaded implementation: tasks run on the thread calling wait().
+class SingleThreadExecutor : public ThreadPoolInterface {
+public:
+  /// Construct a non-threaded pool. \p S should request exactly one thread.
+  SingleThreadExecutor(ThreadPoolStrategy S = hardware_concurrency(1));
+
+  /// Blocking destructor: the pool will first execute the pending tasks.
+  ~SingleThreadExecutor() override;
+
+  /// Run all queued tasks, in FIFO order, on the calling thread.
+  void wait() override;
+
+  /// Runs every queued task, not only those of \p Group, then returns.
+  void wait(ThreadPoolTaskGroup &Group) override;
+
+  /// Always returns 1: there is no concurrency.
+  unsigned getMaxConcurrency() const override { return 1; }
+
+  // TODO: Remove, misleading legacy name warning!
+  LLVM_DEPRECATED("Use getMaxConcurrency instead", "getMaxConcurrency")
+  unsigned getThreadCount() const { return 1; }
+
+  /// Not supported: there are no worker threads, so this is a fatal error.
+  bool isWorkerThread() const;
+
+private:
+  /// Queue \p Task without running it; queued tasks are executed by wait(),
+  /// wait(Group) and the destructor, on the calling thread.
+  void asyncEnqueue(std::function<void()> Task,
+                    ThreadPoolTaskGroup *Group) override {
+    Tasks.emplace_back(std::make_pair(std::move(Task), Group));
+  }
+
+  /// Tasks waiting for execution in the pool.
+  std::deque<std::pair<std::function<void()>, ThreadPoolTaskGroup *>> Tasks;
+};
+
+#if LLVM_ENABLE_THREADS
+using ThreadPool = StdThreadPool;
+#else
+using ThreadPool = SingleThreadExecutor;
+#endif
+
 /// A group of tasks to be run on a thread pool. Thread pool tasks in different
 /// groups can run on the same threadpool but can be waited for separately.
 /// It is even possible for tasks of one group to submit and wait for tasks
@@ -227,7 +265,7 @@ class ThreadPool {
 class ThreadPoolTaskGroup {
 public:
   /// The ThreadPool argument is the thread pool to forward calls to.
-  ThreadPoolTaskGroup(ThreadPool &Pool) : Pool(Pool) {}
+  ThreadPoolTaskGroup(ThreadPoolInterface &Pool) : Pool(Pool) {}
 
   /// Blocking destructor: will wait for all the tasks in the group to complete
   /// by calling ThreadPool::wait().
@@ -244,7 +282,7 @@ class ThreadPoolTaskGroup {
   void wait() { Pool.wait(*this); }
 
 private:
-  ThreadPool &Pool;
+  ThreadPoolInterface &Pool;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Debuginfod/Debuginfod.cpp b/llvm/lib/Debuginfod/Debuginfod.cpp
index 1cf550721000eb..4c785117ae8ef7 100644
--- a/llvm/lib/Debuginfod/Debuginfod.cpp
+++ b/llvm/lib/Debuginfod/Debuginfod.cpp
@@ -348,7 +348,8 @@ DebuginfodLogEntry DebuginfodLog::pop() {
 }
 
 DebuginfodCollection::DebuginfodCollection(ArrayRef<StringRef> PathsRef,
-                                           DebuginfodLog &Log, ThreadPool &Pool,
+                                           DebuginfodLog &Log,
+                                           ThreadPoolInterface &Pool,
                                            double MinInterval)
     : Log(Log), Pool(Pool), MinInterval(MinInterval) {
   for (StringRef Path : PathsRef)
diff --git a/llvm/lib/Support/ThreadPool.cpp b/llvm/lib/Support/ThreadPool.cpp
index 4eef339000e198..27e0f220ac4ed6 100644
--- a/llvm/lib/Support/ThreadPool.cpp
+++ b/llvm/lib/Support/ThreadPool.cpp
@@ -14,16 +14,13 @@
 
 #include "llvm/Config/llvm-config.h"
 
-#if LLVM_ENABLE_THREADS
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Threading.h"
-#else
 #include "llvm/Support/raw_ostream.h"
-#endif
 
 using namespace llvm;
 
-#if LLVM_ENABLE_THREADS
+ThreadPoolInterface::~ThreadPoolInterface() = default;
 
 // A note on thread groups: Tasks are by default in no group (represented
 // by nullptr ThreadPoolTaskGroup pointer in the Tasks queue) and functionality
@@ -33,10 +30,12 @@ using namespace llvm;
 // queue, and functions called to work only on tasks from one group take that
 // pointer.
 
-ThreadPool::ThreadPool(ThreadPoolStrategy S)
+#if LLVM_ENABLE_THREADS
+
+StdThreadPool::StdThreadPool(ThreadPoolStrategy S)
     : Strategy(S), MaxThreadCount(S.compute_thread_count()) {}
 
-void ThreadPool::grow(int requested) {
+void StdThreadPool::grow(int requested) {
   llvm::sys::ScopedWriter LockGuard(ThreadsLock);
   if (Threads.size() >= MaxThreadCount)
     return; // Already hit the max thread pool size.
@@ -58,7 +57,7 @@ static LLVM_THREAD_LOCAL std::vector<ThreadPoolTaskGroup *>
 #endif
 
 // WaitingForGroup == nullptr means all tasks regardless of their group.
-void ThreadPool::processTasks(ThreadPoolTaskGroup *WaitingForGroup) {
+void StdThreadPool::processTasks(ThreadPoolTaskGroup *WaitingForGroup) {
   while (true) {
     std::function<void()> Task;
     ThreadPoolTaskGroup *GroupOfTask;
@@ -111,7 +110,7 @@ void ThreadPool::processTasks(ThreadPoolTaskGroup *WaitingForGroup) {
     bool Notify;
     bool NotifyGroup;
     {
-      // Adjust `ActiveThreads`, in case someone waits on ThreadPool::wait()
+      // Adjust `ActiveThreads`, in case someone waits on StdThreadPool::wait()
       std::lock_guard<std::mutex> LockGuard(QueueLock);
       --ActiveThreads;
       if (GroupOfTask != nullptr) {
@@ -123,7 +122,7 @@ void ThreadPool::processTasks(ThreadPoolTaskGroup *WaitingForGroup) {
       NotifyGroup = GroupOfTask != nullptr && Notify;
     }
     // Notify task completion if this is the last active thread, in case
-    // someone waits on ThreadPool::wait().
+    // someone waits on StdThreadPool::wait().
     if (Notify)
       CompletionCondition.notify_all();
     // If this was a task in a group, notify also threads waiting for tasks
@@ -134,7 +133,7 @@ void ThreadPool::processTasks(ThreadPoolTaskGroup *WaitingForGroup) {
   }
 }
 
-bool ThreadPool::workCompletedUnlocked(ThreadPoolTaskGroup *Group) const {
+bool StdThreadPool::workCompletedUnlocked(ThreadPoolTaskGroup *Group) const {
   if (Group == nullptr)
     return !ActiveThreads && Tasks.empty();
   return ActiveGroups.count(Group) == 0 &&
@@ -142,7 +141,7 @@ bool ThreadPool::workCompletedUnlocked(ThreadPoolTaskGroup *Group) const {
                        [Group](const auto &T) { return T.second == Group; });
 }
 
-void ThreadPool::wait() {
+void StdThreadPool::wait() {
   assert(!isWorkerThread()); // Would deadlock waiting for itself.
   // Wait for all threads to complete and the queue to be empty
   std::unique_lock<std::mutex> LockGuard(QueueLock);
@@ -150,7 +149,7 @@ void ThreadPool::wait() {
                            [&] { return workCompletedUnlocked(nullptr); });
 }
 
-void ThreadPool::wait(ThreadPoolTaskGroup &Group) {
+void StdThreadPool::wait(ThreadPoolTaskGroup &Group) {
   // Wait for all threads in the group to complete.
   if (!isWorkerThread()) {
     std::unique_lock<std::mutex> LockGuard(QueueLock);
@@ -167,7 +166,7 @@ void ThreadPool::wait(ThreadPoolTaskGroup &Group) {
   processTasks(&Group);
 }
 
-bool ThreadPool::isWorkerThread() const {
+bool StdThreadPool::isWorkerThread() const {
   llvm::sys::ScopedReader LockGuard(ThreadsLock);
   llvm::thread::id CurrentThreadId = llvm::this_thread::get_id();
   for (const llvm::thread &Thread : Threads)
@@ -177,7 +176,7 @@ bool ThreadPool::isWorkerThread() const {
 }
 
 // The destructor joins all threads, waiting for completion.
-ThreadPool::~ThreadPool() {
+StdThreadPool::~StdThreadPool() {
   {
     std::unique_lock<std::mutex> LockGuard(QueueLock);
     EnableFlag = false;
@@ -188,10 +187,10 @@ ThreadPool::~ThreadPool() {
     Worker.join();
 }
 
-#else // LLVM_ENABLE_THREADS Disabled
+#endif // LLVM_ENABLE_THREADS
 
-// No threads are launched, issue a warning if ThreadCount is not 0
+// No threads are launched, issue a warning if ThreadCount is not 1
-ThreadPool::ThreadPool(ThreadPoolStrategy S) : MaxThreadCount(1) {
+SingleThreadExecutor::SingleThreadExecutor(ThreadPoolStrategy S) {
   int ThreadCount = S.compute_thread_count();
   if (ThreadCount != 1) {
     errs() << "Warning: request a ThreadPool with " << ThreadCount
@@ -199,7 +198,7 @@ ThreadPool::ThreadPool(ThreadPoolStrategy S) : MaxThreadCount(1) {
   }
 }
 
-void ThreadPool::wait() {
+void SingleThreadExecutor::wait() {
   // Sequential implementation running the tasks
   while (!Tasks.empty()) {
     auto Task = std::move(Tasks.front().first);
@@ -208,16 +207,14 @@ void ThreadPool::wait() {
   }
 }
 
-void ThreadPool::wait(ThreadPoolTaskGroup &) {
+void SingleThreadExecutor::wait(ThreadPoolTaskGroup &) {
   // Simply wait for all, this works even if recursive (the running task
   // is already removed from the queue).
   wait();
 }
 
-bool ThreadPool::isWorkerThread() const {
-  report_fatal_error("LLVM compiled without multithreading");
+bool SingleThreadExecutor::isWorkerThread() const {
+  report_fatal_error("SingleThreadExecutor has no worker threads");
 }
 
-ThreadPool::~ThreadPool() { wait(); }
-
-#endif
+SingleThreadExecutor::~SingleThreadExecutor() { wait(); }
diff --git a/llvm/tools/llvm-cov/CoverageReport.h b/llvm/tools/llvm-cov/CoverageReport.h
index 60f751ca967528..b25ed5e35f9a77 100644
--- a/llvm/tools/llvm-cov/CoverageReport.h
+++ b/llvm/tools/llvm-cov/CoverageReport.h
@@ -20,7 +20,7 @@
 
 namespace llvm {
 
-class ThreadPool;
+class ThreadPoolInterface;
 
 /// Displays the code coverage report.
 class CoverageReport {
@@ -104,7 +104,7 @@ class DirectoryCoverageReport {
   /// For calling CoverageReport::prepareSingleFileReport asynchronously
   /// in prepareSubDirectoryReports(). It's not intended to be modified by
   /// generateSubDirectoryReport().
-  ThreadPool *TPool;
+  ThreadPoolInterface *TPool;
 
   /// One report level may correspond to multiple directory levels as we omit
   /// directories which have only one subentry. So we use this Stack to track
diff --git a/llvm/tools/llvm-cov/SourceCoverageViewHTML.h b/llvm/tools/llvm-cov/SourceCoverageViewHTML.h
index 7b97f05b946bd2..32313a3963c430 100644
--- a/llvm/tools/llvm-cov/SourceCoverageViewHTML.h
+++ b/llvm/tools/llvm-cov/SourceCoverageViewHTML.h
@@ -19,8 +19,6 @@ namespace llvm {
 
 using namespace coverage;
 
-class ThreadPool;
-
 struct FileCoverageSummary;
 
 /// A coverage printer for html output.
diff --git a/llvm/unittests/Support/ThreadPool.cpp b/llvm/unittests/Support/ThreadPool.cpp
index cce20b6dd1dfb5..8904f573100d28 100644
--- a/llvm/unittests/Support/ThreadPool.cpp
+++ b/llvm/unittests/Support/ThreadPool.cpp
@@ -27,11 +27,26 @@
 
 #include "gtest/gtest.h"
 
+namespace testing {
+namespace internal {
+// Specialize gtest construct to provide friendlier name in the output.
+#if LLVM_ENABLE_THREADS
+template <> std::string GetTypeName<llvm::StdThreadPool>() {
+  return "llvm::StdThreadPool";
+}
+#endif
+template <> std::string GetTypeName<llvm::SingleThreadExecutor>() {
+  return "llvm::SingleThreadExecutor";
+}
+} // namespace internal
+} // namespace testing
+
 using namespace llvm;
+namespace {
 
 // Fixture for the unittests, allowing to *temporarily* disable the unittests
 // on a particular platform
-class ThreadPoolTest : public testing::Test {
+template <typename ThreadPoolImpl> class ThreadPoolTest : public testing::Test {
   Triple Host;
   SmallVector<Triple::ArchType, 4> UnsupportedArchs;
   SmallVector<Triple::OSType, 4> UnsupportedOSs;
@@ -106,34 +121,42 @@ class ThreadPoolTest : public testing::Test {
   int CurrentPhase; // -1 = error, 0 = setup, 1 = ready, 2+ = custom
 };
 
+using ThreadPoolImpls = ::testing::Types<
+#if LLVM_ENABLE_THREADS
+    StdThreadPool,
+#endif
+    SingleThreadExecutor>;
+
+TYPED_TEST_SUITE(ThreadPoolTest, ThreadPoolImpls);
+
 #define CHECK_UNSUPPORTED()                                                    \
   do {                                                                         \
-    if (isUnsupportedOSOrEnvironment())                                        \
+    if (this->isUnsupportedOSOrEnvironment())                                  \
       GTEST_SKIP();                                                            \
   } while (0);
 
-TEST_F(ThreadPoolTest, AsyncBarrier) {
+TYPED_TEST(ThreadPoolTest, AsyncBarrier) {
   CHECK_UNSUPPORTED();
   // test that async & barrier work together properly.
 
   std::atomic_int checked_in{0};
 
-  ThreadPool Pool;
+  TypeParam Pool;
   for (size_t i = 0; i < 5; ++i) {
     Pool.async([this, &checked_in] {
-      waitForMainThread();
+      this->waitForMainThread();
       ++checked_in;
     });
   }
   ASSERT_EQ(0, checked_in);
-  setMainThreadReady();
+  this->setMainThreadReady();
   Pool.wait();
   ASSERT_EQ(5, checked_in);
 }
 
 static void TestFunc(std::atomic_int &checked_in, int i) { checked_in += i; }
 
-TEST_F(ThreadPoolTest, AsyncBarrierArgs) {
+TYPED_TEST(ThreadPoolTest, AsyncBarrierArgs) {
   CHECK_UNSUPPORTED();
   // Test that async works with a function requiring multiple parameters.
   std::atomic_int checked_in{0};
@@ -146,63 +169,63 @@ TEST_F(ThreadPoolTest, AsyncBarrierArgs) {
   ASSERT_EQ(10, checked_in);
 }
 
-TEST_F(ThreadPoolTest, Async) {
+TYPED_TEST(ThreadPoolTest, Async) {
   CHECK_UNSUPPORTED();
-  ThreadPool Pool;
+  TypeParam Pool;
   std::atomic_int i{0};
   Pool.async([this, &i] {
-    waitForMainThread();
+    this->waitForMainThread();
     ++i;
   });
   Pool.async([&i] { ++i; });
   ASSERT_NE(2, i.load());
-  setMainThreadReady();
+  this->setMainThreadReady();
   Pool.wait();
   ASSERT_EQ(2, i.load());
 }
 
-TEST_F(ThreadPoolTest, GetFuture) {
+TYPED_TEST(ThreadPoolTest, GetFuture) {
   CHECK_UNSUPPORTED();
-  ThreadPool Pool(hardware_concurrency(2));
+  TypeParam Pool(hardware_concurrency(2));
   std::atomic_int i{0};
   Pool.async([this, &i] {
-    waitForMainThread();
+    this->waitForMainThread();
     ++i;
   });
   // Force the future using get()
   Pool.async([&i] { ++i; }).get();
   ASSERT_NE(2, i.load());
-  setMainThreadReady();
+  this->setMainThreadReady();
   Pool.wait();
   ASSERT_EQ(2, i.load());
 }
 
-TEST_F(ThreadPoolTest, GetFutureWithResult) {
+TYPED_TEST(ThreadPoolTest, GetFutureWithResult) {
   CHECK_UNSUPPORTED();
-  ThreadPool Pool(hardware_concurrency(2));
+  TypeParam Pool(hardware_concurrency(2));
   auto F1 = Pool.async([] { return 1; });
   auto F2 = Pool.async([] { return 2; });
 
-  setMainThreadReady();
+  this->setMainThreadReady();
   Pool.wait();
   ASSERT_EQ(1, F1.get());
   ASSERT_EQ(2, F2.get());
 }
 
-TEST_F(ThreadPoolTest, GetFutureWithResultAndArgs) {
+TYPED_TEST(ThreadPoolTest, GetFutureWithResultAndArgs) {
   CHECK_UNSUPPORTED();
-  ThreadPool Pool(hardware_concurrency(2));
+  TypeParam Pool(hardware_concurrency(2));
   auto Fn = [](int x) { return x; };
   auto F1 = Pool.async(Fn, 1);
   auto F2 = Pool.async(Fn, 2);
 
-  setMainThreadReady();
+  this->setMainThreadReady();
   Pool.wait();
   ASSERT_EQ(1, F1.get());
   ASSERT_EQ(2, F2.get());
 }
 
-TEST_F(ThreadPoolTest, PoolDestruction) {
+TYPED_TEST(ThreadPoolTest, PoolDestruction) {
   CHECK_UNSUPPORTED();
   // Test that we are waiting on destruction
   std::atomic_int checked_in{0};
@@ -210,18 +233,18 @@ TEST_F(ThreadPoolTest, PoolDestruction) {
     ThreadPool Pool;
     for (size_t i = 0; i < 5; ++i) {
       Pool.async([this, &checked_in] {
-        waitForMainThread();
+        this->waitForMainThread();
         ++checked_in;
       });
     }
     ASSERT_EQ(0, checked_in);
-    setMainThreadReady();
+    this->setMainThreadReady();
   }
   ASSERT_EQ(5, checked_in);
 }
 
 // Check running tasks in different groups.
-TEST_F(ThreadPoolTest, Groups) {
+TYPED_TEST(ThreadPoolTest, Groups) {
   CHECK_UNSUPPORTED();
   // Need at least two threads, as the task in group2
   // might block a thread until all tasks in group1 finish.
@@ -229,7 +252,7 @@ TEST_F(ThreadPoolTest, Groups) {
   if (S.compute_thread_count() < 2)
     GTEST_SKIP();
   ThreadPool Pool(S);
-  PhaseResetHelper Helper(this);
+  typename TestFixture::PhaseResetHelper Helper(this);
   ThreadPoolTaskGroup Group1(Pool);
   ThreadPoolTaskGroup Group2(Pool);
 
@@ -241,30 +264,30 @@ TEST_F(ThreadPoolTest, Groups) {
 
   for (size_t i = 0; i < 5; ++i) {
     Group1.async([this, &checked_in1] {
-      waitForMainThread();
+      this->waitForMainThread();
       ++checked_in1;
     });
   }
   Group2.async([this, &checked_in2] {
-    waitForPhase(2);
+    this->waitForPhase(2);
     ++checked_in2;
   });
   ASSERT_EQ(0, checked_in1);
   ASSERT_EQ(0, checked_in2);
   // Start first group and wait for it.
-  setMainThreadReady();
+  this->setMainThreadReady();
   Group1.wait();
   ASSERT_EQ(5, checked_in1);
   // Second group has not yet finished, start it and wait for it.
   ASSERT_EQ(0, checked_in2);
-  setPhase(2);
+  this->setPhase(2);
   Group2.wait();
   ASSERT_EQ(5, checked_in1);
   ASSERT_EQ(1, checked_in2);
 }
 
 // Check recursive tasks.
-TEST_F(ThreadPoolTest, RecursiveGroups) {
+TYPED_TEST(ThreadPoolTest, RecursiveGroups) {
   CHECK_UNSUPPORTED();
   ThreadPool Pool;
   ThreadPoolTaskGroup Group(Pool);
@@ -273,7 +296,7 @@ TEST_F(ThreadPoolTest, RecursiveGroups) {
 
   for (size_t i = 0; i < 5; ++i) {
     Group.async([this, &Pool, &checked_in1] {
-      waitForMainThread();
+      this->waitForMainThread();
 
       ThreadPoolTaskGroup LocalGroup(Pool);
 
@@ -291,18 +314,18 @@ TEST_F(ThreadPoolTest, RecursiveGroups) {
     });
   }
   ASSERT_EQ(0, checked_in1);
-  setMainThreadReady();
+  this->setMainThreadReady();
   Group.wait();
   ASSERT_EQ(5, checked_in1);
 }
 
-TEST_F(ThreadPoolTest, RecursiveWaitDeadlock) {
+TYPED_TEST(ThreadPoolTest, RecursiveWaitDeadlock) {
   CHECK_UNSUPPORTED();
   ThreadPoolStrategy S = hardware_concurrency(2);
   if (S.compute_thread_count() < 2)
     GTEST_SKIP();
   ThreadPool Pool(S);
-  PhaseResetHelper Helper(this);
+  typename TestFixture::PhaseResetHelper Helper(this);
   ThreadPoolTaskGroup Group(Pool);
 
   // Test that a thread calling wait() for a group and is waiting for more tasks
@@ -312,17 +335,17 @@ TEST_F(ThreadPoolTest, RecursiveWaitDeadlock) {
   // Task A runs in the first thread. It finishes and leaves
   // the background thread waiting for more tasks.
   Group.async([this] {
-    waitForMainThread();
-    setPhase(2);
+    this->waitForMainThread();
+    this->setPhase(2);
   });
   // Task B is run in a second thread, it launches yet another
   // task C in a different group, which will be handled by the waiting
   // thread started above.
   Group.async([this, &Pool] {
-    waitForPhase(2);
+    this->waitForPhase(2);
     ThreadPoolTaskGroup LocalGroup(Pool);
     LocalGroup.async([this] {
-      waitForPhase(3);
+      this->waitForPhase(3);
       // Give the other thread enough time to check that there's no task
       // to process and suspend waiting for a notification. This is indeed racy,
       // but probably the best that can be done.
@@ -332,10 +355,10 @@ TEST_F(ThreadPoolTest, RecursiveWaitDeadlock) {
     // to finish. This test checks that it does not deadlock. If the
     // `NotifyGroup` handling in ThreadPool::processTasks() didn't take place,
     // this task B would be stuck waiting for tasks to arrive.
-    setPhase(3);
+    this->setPhase(3);
     LocalGroup.wait();
   });
-  setMainThreadReady();
+  this->setMainThreadReady();
   Group.wait();
 }
 
@@ -346,8 +369,9 @@ TEST_F(ThreadPoolTest, RecursiveWaitDeadlock) {
 // isn't implemented for Unix (need AffinityMask in Support/Unix/Program.inc).
 #ifdef _WIN32
 
+template <typename ThreadPoolImpl>
 SmallVector<llvm::BitVector, 0>
-ThreadPoolTest::RunOnAllSockets(ThreadPoolStrategy S) {
+ThreadPoolTest<ThreadPoolImpl>::RunOnAllSockets(ThreadPoolStrategy S) {
   llvm::SetVector<llvm::BitVector> ThreadsUsed;
   std::mutex Lock;
   {
@@ -363,7 +387,7 @@ ThreadPoolTest::RunOnAllSockets(ThreadPoolStrategy S) {
           ++Active;
           AllThreads.notify_one();
         }
-        waitForMainThread();
+        this->waitForMainThread();
         std::lock_guard<std::mutex> Guard(Lock);
         auto Mask = llvm::get_thread_affinity_mask();
         ThreadsUsed.insert(Mask);
@@ -375,12 +399,12 @@ ThreadPoolTest::RunOnAllSockets(ThreadPoolStrategy S) {
       AllThreads.wait(Guard,
                       [&]() { return Active == S.compute_thread_count(); });
     }
-    setMainThreadReady();
+    this->setMainThreadReady();
   }
   return ThreadsUsed.takeVector();
 }
 
-TEST_F(ThreadPoolTest, AllThreads_UseAllRessources) {
+TYPED_TEST(ThreadPoolTest, AllThreads_UseAllRessources) {
   CHECK_UNSUPPORTED();
   // After Windows 11, the OS is free to deploy the threads on any CPU socket.
   // We cannot relibly ensure that all thread affinity mask are covered,
@@ -391,7 +415,7 @@ TEST_F(ThreadPoolTest, AllThreads_UseAllRessources) {
   ASSERT_EQ(llvm::get_cpus(), ThreadsUsed.size());
 }
 
-TEST_F(ThreadPoolTest, AllThreads_OneThreadPerCore) {
+TYPED_TEST(ThreadPoolTest, AllThreads_OneThreadPerCore) {
   CHECK_UNSUPPORTED();
   // After Windows 11, the OS is free to deploy the threads on any CPU socket.
   // We cannot relibly ensure that all thread affinity mask are covered,
@@ -412,7 +436,7 @@ static cl::opt<std::string> ThreadPoolTestStringArg1("thread-pool-string-arg1");
 #define setenv(name, var, ignore) _putenv_s(name, var)
 #endif
 
-TEST_F(ThreadPoolTest, AffinityMask) {
+TYPED_TEST(ThreadPoolTest, AffinityMask) {
   CHECK_UNSUPPORTED();
 
   // Skip this test if less than 4 threads are available.
@@ -451,3 +475,5 @@ TEST_F(ThreadPoolTest, AffinityMask) {
 
 #endif // #ifdef _WIN32
 #endif // #if LLVM_ENABLE_THREADS == 1
+
+} // namespace
diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index 82aa05185858e3..622745256111e1 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -22,7 +22,7 @@
 #include "llvm/ADT/StringRef.h"
 
 namespace llvm {
-class ThreadPool;
+class ThreadPoolInterface;
 } // namespace llvm
 
 /// Converts a StringRef into its MLIR C API equivalent.
@@ -45,7 +45,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
   return mlir::success(mlirLogicalResultIsSuccess(res));
 }
 
-DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool)
+DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPoolInterface)
 DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
 
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d9e140bd75f726..2ad35d8f78ee35 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -17,7 +17,7 @@
 #include <vector>
 
 namespace llvm {
-class ThreadPool;
+class ThreadPoolInterface;
 } // namespace llvm
 
 namespace mlir {
@@ -162,7 +162,7 @@ class MLIRContext {
   /// The command line debugging flag `--mlir-disable-threading` will still
   /// prevent threading from being enabled and threading won't be enabled after
   /// this call in this case.
-  void setThreadPool(llvm::ThreadPool &pool);
+  void setThreadPool(llvm::ThreadPoolInterface &pool);
 
   /// Return the number of threads used by the thread pool in this context. The
   /// number of computed hardware threads can change over the lifetime of a
@@ -175,7 +175,7 @@ class MLIRContext {
   /// multithreading be enabled within the context, and should generally not be
   /// used directly. Users should instead prefer the threading utilities within
   /// Threading.h.
-  llvm::ThreadPool &getThreadPool();
+  llvm::ThreadPoolInterface &getThreadPool();
 
   /// Return true if we should attach the operation to diagnostics emitted via
   /// Operation::emit.
diff --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h
index 0f71dc27cf391f..3ceab6b3e883a5 100644
--- a/mlir/include/mlir/IR/Threading.h
+++ b/mlir/include/mlir/IR/Threading.h
@@ -66,7 +66,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
   };
 
   // Otherwise, process the elements in parallel.
-  llvm::ThreadPool &threadPool = context->getThreadPool();
+  llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
   llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
   size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
   for (unsigned i = 0; i < numActions; ++i)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a97cfe5b0ea364..cdb64f4ec4a40f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
 #include <memory>
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 58ebe4fdec202d..92568bd311e394 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -170,11 +170,11 @@ class MLIRContextImpl {
   /// It can't be nullptr when multi-threading is enabled. Otherwise if
   /// multi-threading is disabled, and the threadpool wasn't externally provided
   /// using `setThreadPool`, this will be nullptr.
-  llvm::ThreadPool *threadPool = nullptr;
+  llvm::ThreadPoolInterface *threadPool = nullptr;
 
   /// In case where the thread pool is owned by the context, this ensures
   /// destruction with the context.
-  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -626,7 +626,7 @@ void MLIRContext::disableMultithreading(bool disable) {
   }
 }
 
-void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
+void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
   assert(!isMultithreadingEnabled() &&
          "expected multi-threading to be disabled when setting a ThreadPool");
   impl->threadPool = &pool;
@@ -644,7 +644,7 @@ unsigned MLIRContext::getNumThreads() {
   return 1;
 }
 
-llvm::ThreadPool &MLIRContext::getThreadPool() {
+llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
   assert(isMultithreadingEnabled() &&
          "expected multi-threading to be enabled within the context");
   assert(impl->threadPool &&
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 2755a949fb947c..b62557153b4167 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -431,7 +431,7 @@ static LogicalResult processBuffer(raw_ostream &os,
                                    std::unique_ptr<MemoryBuffer> ownedBuffer,
                                    const MlirOptMainConfig &config,
                                    DialectRegistry &registry,
-                                   llvm::ThreadPool *threadPool) {
+                                   llvm::ThreadPoolInterface *threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   auto sourceMgr = std::make_shared<SourceMgr>();
   sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -517,7 +517,7 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
   // up into small pieces and checks each independently.
   // We use an explicit threadpool to avoid creating and joining/destroying
   // threads for each of the split.
-  ThreadPool *threadPool = nullptr;
+  ThreadPoolInterface *threadPool = nullptr;
 
   // Create a temporary context for the sake of checking if
   // --mlir-disable-threading was passed on the command line.



More information about the llvm-commits mailing list