[llvm] [mlir] Split the llvm::ThreadPool into an abstract base class and an implementation (PR #82094)

Mehdi Amini via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 19 18:11:36 PST 2024


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/82094

>From 07e8499ccaa428adacf28e73dd947aa88f9f1f98 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 16 Feb 2024 21:55:57 -0800
Subject: [PATCH] Split the llvm::ThreadPool into an abstract base class and an
 implementation

This decouples the public API used to enqueue tasks and wait for completion
from the actual implementation, and opens up the possibility for clients to
supply their own thread pool implementation.
---
 llvm/include/llvm/Support/ThreadPool.h  | 144 +++++++++++++++---------
 llvm/lib/Support/ThreadPool.cpp         |   2 +
 mlir/include/mlir/IR/MLIRContext.h      |   6 +-
 mlir/include/mlir/IR/Threading.h        |   2 +-
 mlir/lib/CAPI/IR/IR.cpp                 |   1 +
 mlir/lib/IR/MLIRContext.cpp             |   8 +-
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp |   4 +-
 7 files changed, 104 insertions(+), 63 deletions(-)

diff --git a/llvm/include/llvm/Support/ThreadPool.h b/llvm/include/llvm/Support/ThreadPool.h
index 03ebd35aa46dc4..af83e7870aa16b 100644
--- a/llvm/include/llvm/Support/ThreadPool.h
+++ b/llvm/include/llvm/Support/ThreadPool.h
@@ -32,11 +32,8 @@ namespace llvm {
 
 class ThreadPoolTaskGroup;
 
-/// A ThreadPool for asynchronous parallel execution on a defined number of
-/// threads.
-///
-/// The pool keeps a vector of threads alive, waiting on a condition variable
-/// for some work to become available.
+/// This defines the abstract base interface for a ThreadPool allowing
+/// asynchronous parallel execution on a defined number of threads.
 ///
 /// It is possible to reuse one thread pool for different groups of tasks
 /// by grouping tasks using ThreadPoolTaskGroup. All tasks are processed using
@@ -49,16 +46,31 @@ class ThreadPoolTaskGroup;
 /// available threads are used up by tasks waiting for a task that has no thread
 /// left to run on (this includes waiting on the returned future). It should be
 /// generally safe to wait() for a group as long as groups do not form a cycle.
-class ThreadPool {
+class ThreadPoolInterface {
+  // The actual method to enqueue a task to be defined by the concrete
+  // implementation.
+  virtual void asyncEnqueue(std::function<void()> Task,
+                            ThreadPoolTaskGroup *Group) = 0;
+
 public:
-  /// Construct a pool using the hardware strategy \p S for mapping hardware
-  /// execution resources (threads, cores, CPUs)
-  /// Defaults to using the maximum execution resources in the system, but
-  /// accounting for the affinity mask.
-  ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
+  // Destroying the pool will drain the pending tasks and wait. The current
+  // thread may participate in the execution of the pending tasks.
+  virtual ~ThreadPoolInterface();
 
-  /// Blocking destructor: the pool will wait for all the threads to complete.
-  ~ThreadPool();
+  /// Blocking wait for all the threads to complete and the queue to be empty.
+  /// It is an error to try to add new tasks while blocking on this call.
+  /// Calling wait() from a task would deadlock waiting for itself.
+  virtual void wait() = 0;
+
+  /// Blocking wait for only all the threads in the given group to complete.
+  /// It is possible to wait even inside a task, but waiting (directly or
+  /// indirectly) on itself will deadlock. If called from a task running on a
+  /// worker thread, the call may process pending tasks while waiting in order
+  /// not to waste the thread.
+  virtual void wait(ThreadPoolTaskGroup &Group) = 0;
+
+  // Returns the maximum number of workers this pool can eventually grow to.
+  virtual unsigned getMaxConcurrency() const = 0;
 
   /// Asynchronous submission of a task to the pool. The returned future can be
   /// used to wait for the task to finish and is *non-blocking* on destruction.
@@ -92,30 +104,32 @@ class ThreadPool {
                      &Group);
   }
 
-  /// Blocking wait for all the threads to complete and the queue to be empty.
-  /// It is an error to try to add new tasks while blocking on this call.
-  /// Calling wait() from a task would deadlock waiting for itself.
-  void wait();
+private:
+  /// Asynchronous submission of a task to the pool. The returned future can be
+  /// used to wait for the task to finish and is *non-blocking* on destruction.
+  template <typename ResTy>
+  std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
+                                      ThreadPoolTaskGroup *Group) {
 
-  /// Blocking wait for only all the threads in the given group to complete.
-  /// It is possible to wait even inside a task, but waiting (directly or
-  /// indirectly) on itself will deadlock. If called from a task running on a
-  /// worker thread, the call may process pending tasks while waiting in order
-  /// not to waste the thread.
-  void wait(ThreadPoolTaskGroup &Group);
+#if LLVM_ENABLE_THREADS
+    /// Wrap the Task in a std::function<void()> that sets the result of the
+    /// corresponding future.
+    auto R = createTaskAndFuture(Task);
 
-  // Returns the maximum number of worker threads in the pool, not the current
-  // number of threads!
-  unsigned getMaxConcurrency() const { return MaxThreadCount; }
+    asyncEnqueue(std::move(R.first), Group);
+    return R.second.share();
 
-  // TODO: misleading legacy name warning!
-  LLVM_DEPRECATED("Use getMaxConcurrency instead", "getMaxConcurrency")
-  unsigned getThreadCount() const { return MaxThreadCount; }
+#else // LLVM_ENABLE_THREADS Disabled
 
-  /// Returns true if the current thread is a worker thread of this thread pool.
-  bool isWorkerThread() const;
+    // Get a Future with launch::deferred execution using std::async
+    auto Future = std::async(std::launch::deferred, std::move(Task)).share();
+    // Wrap the future so that both ThreadPool::wait() can operate and the
+    // returned future can be sync'ed on.
+    asyncEnqueue([Future]() { Future.get(); }, Group);
+    return Future;
+#endif
+  }
 
-private:
   /// Helpers to create a promise and a callable wrapper of \p Task that sets
   /// the result of the promise. Returns the callable and a future to access the
   /// result.
@@ -140,22 +154,56 @@ class ThreadPool {
             },
             std::move(F)};
   }
+};
+
+/// A ThreadPool implementation using std::threads.
+///
+/// The pool keeps a vector of threads alive, waiting on a condition variable
+/// for some work to become available.
+class ThreadPool : public ThreadPoolInterface {
+public:
+  /// Construct a pool using the hardware strategy \p S for mapping hardware
+  /// execution resources (threads, cores, CPUs)
+  /// Defaults to using the maximum execution resources in the system, but
+  /// accounting for the affinity mask.
+  ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
 
+  /// Blocking destructor: the pool will wait for all the threads to complete.
+  ~ThreadPool() override;
+
+  /// Blocking wait for all the threads to complete and the queue to be empty.
+  /// It is an error to try to add new tasks while blocking on this call.
+  /// Calling wait() from a task would deadlock waiting for itself.
+  void wait() override;
+
+  /// Blocking wait for only all the threads in the given group to complete.
+  /// It is possible to wait even inside a task, but waiting (directly or
+  /// indirectly) on itself will deadlock. If called from a task running on a
+  /// worker thread, the call may process pending tasks while waiting in order
+  /// not to waste the thread.
+  void wait(ThreadPoolTaskGroup &Group) override;
+
+  // Returns the maximum number of worker threads in the pool, not the current
+  // number of threads!
+  unsigned getMaxConcurrency() const override { return MaxThreadCount; }
+
+  // TODO: misleading legacy name warning!
+  LLVM_DEPRECATED("Use getMaxConcurrency instead", "getMaxConcurrency")
+  unsigned getThreadCount() const { return MaxThreadCount; }
+
+  /// Returns true if the current thread is a worker thread of this thread pool.
+  bool isWorkerThread() const;
+
+private:
   /// Returns true if all tasks in the given group have finished (nullptr means
   /// all tasks regardless of their group). QueueLock must be locked.
   bool workCompletedUnlocked(ThreadPoolTaskGroup *Group) const;
 
   /// Asynchronous submission of a task to the pool. The returned future can be
   /// used to wait for the task to finish and is *non-blocking* on destruction.
-  template <typename ResTy>
-  std::shared_future<ResTy> asyncImpl(std::function<ResTy()> Task,
-                                      ThreadPoolTaskGroup *Group) {
-
+  void asyncEnqueue(std::function<void()> Task,
+                    ThreadPoolTaskGroup *Group) override {
 #if LLVM_ENABLE_THREADS
-    /// Wrap the Task in a std::function<void()> that sets the result of the
-    /// corresponding future.
-    auto R = createTaskAndFuture(Task);
-
     int requestedThreads;
     {
       // Lock the queue and push the new task
@@ -163,21 +211,11 @@ class ThreadPool {
 
       // Don't allow enqueueing after disabling the pool
       assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
-      Tasks.emplace_back(std::make_pair(std::move(R.first), Group));
+      Tasks.emplace_back(std::make_pair(std::move(Task), Group));
       requestedThreads = ActiveThreads + Tasks.size();
     }
     QueueCondition.notify_one();
     grow(requestedThreads);
-    return R.second.share();
-
-#else // LLVM_ENABLE_THREADS Disabled
-
-    // Get a Future with launch::deferred execution using std::async
-    auto Future = std::async(std::launch::deferred, std::move(Task)).share();
-    // Wrap the future so that both ThreadPool::wait() can operate and the
-    // returned future can be sync'ed on.
-    Tasks.emplace_back(std::make_pair([Future]() { Future.get(); }, Group));
-    return Future;
 #endif
   }
 
@@ -227,7 +265,7 @@ class ThreadPool {
 class ThreadPoolTaskGroup {
 public:
   /// The ThreadPool argument is the thread pool to forward calls to.
-  ThreadPoolTaskGroup(ThreadPool &Pool) : Pool(Pool) {}
+  ThreadPoolTaskGroup(ThreadPoolInterface &Pool) : Pool(Pool) {}
 
   /// Blocking destructor: will wait for all the tasks in the group to complete
   /// by calling ThreadPool::wait().
@@ -244,7 +282,7 @@ class ThreadPoolTaskGroup {
   void wait() { Pool.wait(*this); }
 
 private:
-  ThreadPool &Pool;
+  ThreadPoolInterface &Pool;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Support/ThreadPool.cpp b/llvm/lib/Support/ThreadPool.cpp
index 4eef339000e198..16567d90bd1ad6 100644
--- a/llvm/lib/Support/ThreadPool.cpp
+++ b/llvm/lib/Support/ThreadPool.cpp
@@ -23,6 +23,8 @@
 
 using namespace llvm;
 
+ThreadPoolInterface::~ThreadPoolInterface() = default;
+
 #if LLVM_ENABLE_THREADS
 
 // A note on thread groups: Tasks are by default in no group (represented
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index d9e140bd75f726..2ad35d8f78ee35 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -17,7 +17,7 @@
 #include <vector>
 
 namespace llvm {
-class ThreadPool;
+class ThreadPoolInterface;
 } // namespace llvm
 
 namespace mlir {
@@ -162,7 +162,7 @@ class MLIRContext {
   /// The command line debugging flag `--mlir-disable-threading` will still
   /// prevent threading from being enabled and threading won't be enabled after
   /// this call in this case.
-  void setThreadPool(llvm::ThreadPool &pool);
+  void setThreadPool(llvm::ThreadPoolInterface &pool);
 
   /// Return the number of threads used by the thread pool in this context. The
   /// number of computed hardware threads can change over the lifetime of a
@@ -175,7 +175,7 @@ class MLIRContext {
   /// multithreading be enabled within the context, and should generally not be
   /// used directly. Users should instead prefer the threading utilities within
   /// Threading.h.
-  llvm::ThreadPool &getThreadPool();
+  llvm::ThreadPoolInterface &getThreadPool();
 
   /// Return true if we should attach the operation to diagnostics emitted via
   /// Operation::emit.
diff --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h
index 0f71dc27cf391f..3ceab6b3e883a5 100644
--- a/mlir/include/mlir/IR/Threading.h
+++ b/mlir/include/mlir/IR/Threading.h
@@ -66,7 +66,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
   };
 
   // Otherwise, process the elements in parallel.
-  llvm::ThreadPool &threadPool = context->getThreadPool();
+  llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
   llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
   size_t numActions = std::min(numElements, threadPool.getMaxConcurrency());
   for (unsigned i = 0; i < numActions; ++i)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a97cfe5b0ea364..cdb64f4ec4a40f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
 #include <memory>
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 58ebe4fdec202d..92568bd311e394 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -170,11 +170,11 @@ class MLIRContextImpl {
   /// It can't be nullptr when multi-threading is enabled. Otherwise if
   /// multi-threading is disabled, and the threadpool wasn't externally provided
   /// using `setThreadPool`, this will be nullptr.
-  llvm::ThreadPool *threadPool = nullptr;
+  llvm::ThreadPoolInterface *threadPool = nullptr;
 
   /// In case where the thread pool is owned by the context, this ensures
   /// destruction with the context.
-  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
 
   /// An allocator used for AbstractAttribute and AbstractType objects.
   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
@@ -626,7 +626,7 @@ void MLIRContext::disableMultithreading(bool disable) {
   }
 }
 
-void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
+void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
   assert(!isMultithreadingEnabled() &&
          "expected multi-threading to be disabled when setting a ThreadPool");
   impl->threadPool = &pool;
@@ -644,7 +644,7 @@ unsigned MLIRContext::getNumThreads() {
   return 1;
 }
 
-llvm::ThreadPool &MLIRContext::getThreadPool() {
+llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
   assert(isMultithreadingEnabled() &&
          "expected multi-threading to be enabled within the context");
   assert(impl->threadPool &&
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index f01c7631decb77..668221eadd5fc1 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -429,7 +429,7 @@ static LogicalResult processBuffer(raw_ostream &os,
                                    std::unique_ptr<MemoryBuffer> ownedBuffer,
                                    const MlirOptMainConfig &config,
                                    DialectRegistry &registry,
-                                   llvm::ThreadPool *threadPool) {
+                                   llvm::ThreadPoolInterface *threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   auto sourceMgr = std::make_shared<SourceMgr>();
   sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -515,7 +515,7 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
   // up into small pieces and checks each independently.
   // We use an explicit threadpool to avoid creating and joining/destroying
   // threads for each of the split.
-  ThreadPool *threadPool = nullptr;
+  ThreadPoolInterface *threadPool = nullptr;
 
   // Create a temporary context for the sake of checking if
   // --mlir-disable-threading was passed on the command line.



More information about the llvm-commits mailing list