[Mlir-commits] [mlir] 072e0aa - Enable the use of ThreadPoolTaskGroup in MLIR threading helper to enable nested parallelism

Mehdi Amini llvmlistbot at llvm.org
Fri May 6 12:40:40 PDT 2022


Author: Mehdi Amini
Date: 2022-05-06T19:40:22Z
New Revision: 072e0aabbc457b8802dcf7b483e3acebfbde1c33

URL: https://github.com/llvm/llvm-project/commit/072e0aabbc457b8802dcf7b483e3acebfbde1c33
DIFF: https://github.com/llvm/llvm-project/commit/072e0aabbc457b8802dcf7b483e3acebfbde1c33.diff

LOG: Enable the use of ThreadPoolTaskGroup in MLIR threading helper to enable nested parallelism

The LLVM ThreadPool recently gained ThreadPoolTaskGroup, a way to
"partition" the work submitted to the pool into groups of tasks. A
thread can wait on a single group, which enables nested parallelism
at every level of nesting.
We use this feature in the MLIR threading helper to resolve a
long-standing TODO and enable nested parallelism.
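
To illustrate why waiting on a task group makes nested parallelism safe,
here is a minimal standard-library sketch of the idea. It does not use the
LLVM API: MiniPool, MiniTaskGroup and runNestedSketch are hypothetical names
invented for this example, and the workers poll for work instead of blocking.
The property it tries to show is that wait() runs queued tasks on the calling
thread rather than blocking it, so a task that is itself running on a pool
worker can spawn a nested group and wait for it without starving the pool.

```cpp
#include <atomic>
#include <cassert>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for a thread pool; this is NOT llvm::ThreadPool.
class MiniPool {
public:
  explicit MiniPool(unsigned numThreads) {
    for (unsigned i = 0; i < numThreads; ++i)
      workers.emplace_back([this] {
        // Poll for work; a real pool would block on a condition variable.
        while (!stopping)
          if (!tryRunOne())
            std::this_thread::yield();
      });
  }
  ~MiniPool() {
    stopping = true;
    for (std::thread &worker : workers)
      worker.join();
  }
  void enqueue(std::function<void()> task) {
    std::lock_guard<std::mutex> lock(mutex);
    tasks.push_back(std::move(task));
  }
  // Runs one queued task on the calling thread; returns false if none queued.
  bool tryRunOne() {
    std::function<void()> task;
    {
      std::lock_guard<std::mutex> lock(mutex);
      if (tasks.empty())
        return false;
      task = std::move(tasks.front());
      tasks.pop_front();
    }
    task();
    return true;
  }

private:
  std::mutex mutex;
  std::deque<std::function<void()>> tasks;
  std::atomic<bool> stopping{false};
  std::vector<std::thread> workers;
};

// Hypothetical stand-in for a task group; this is NOT ThreadPoolTaskGroup.
class MiniTaskGroup {
public:
  explicit MiniTaskGroup(MiniPool &pool) : pool(pool) {}
  ~MiniTaskGroup() { assert(pending == 0 && "call wait() before destruction"); }
  void async(std::function<void()> fn) {
    ++pending;
    pool.enqueue([this, fn = std::move(fn)] {
      fn();
      --pending; // Last access to the group from this task.
    });
  }
  // Helping wait: the waiting thread runs queued tasks instead of blocking.
  void wait() {
    while (pending != 0)
      if (!pool.tryRunOne())
        std::this_thread::yield();
  }

private:
  MiniPool &pool;
  std::atomic<unsigned> pending{0};
};

// Each outer task spawns and waits on its own inner group (nested parallelism).
int runNestedSketch(unsigned numThreads, int outerTasks, int innerTasks) {
  MiniPool pool(numThreads);
  std::atomic<int> counter{0};
  MiniTaskGroup outer(pool);
  for (int i = 0; i < outerTasks; ++i)
    outer.async([&] {
      MiniTaskGroup inner(pool);
      for (int j = 0; j < innerTasks; ++j)
        inner.async([&counter] { ++counter; });
      inner.wait();
    });
  outer.wait();
  return counter.load();
}
```

Calling runNestedSketch(2, 4, 4) runs four outer tasks that each wait on four
inner tasks using only two workers. If wait() blocked instead of helping, the
workers could all end up waiting on inner tasks that no thread is free to run.
One simplification to keep in mind: this sketch lets a waiting thread run any
queued task, not only tasks from the group it waits on.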

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D124902

Added: 
    

Modified: 
    mlir/include/mlir/IR/Threading.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h
index 0d60e95b54c11..dd99039e298ac 100644
--- a/mlir/include/mlir/IR/Threading.h
+++ b/mlir/include/mlir/IR/Threading.h
@@ -41,10 +41,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
 
   // If multithreading is disabled or there is a small number of elements,
   // process the elements directly on this thread.
-  // FIXME: ThreadPool should allow work stealing to avoid deadlocks when
-  // scheduling work within a worker thread.
-  if (!context->isMultithreadingEnabled() || numElements <= 1 ||
-      context->getThreadPool().isWorkerThread()) {
+  if (!context->isMultithreadingEnabled() || numElements <= 1) {
     for (; begin != end; ++begin)
       if (failed(func(*begin)))
         return failure();
@@ -70,16 +67,14 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
 
   // Otherwise, process the elements in parallel.
   llvm::ThreadPool &threadPool = context->getThreadPool();
+  llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
   size_t numActions = std::min(numElements, threadPool.getThreadCount());
-  SmallVector<std::shared_future<void>> threadFutures;
-  threadFutures.reserve(numActions - 1);
-  for (unsigned i = 1; i < numActions; ++i)
-    threadFutures.emplace_back(threadPool.async(processFn));
-  processFn();
-
-  // Wait for all of the threads to finish.
-  for (std::shared_future<void> &future : threadFutures)
-    future.wait();
+  for (unsigned i = 0; i < numActions; ++i)
+    tasksGroup.async(processFn);
+  // If the current thread is a worker thread from the pool, then waiting for
+  // the task group allows the current thread to also participate in processing
+  // tasks from the group, which avoids any deadlock/starvation.
+  tasksGroup.wait();
   return failure(processingFailed);
 }
 


        

