[Mlir-commits] [mlir] 072e0aa - Enable the use of ThreadPoolTaskGroup in MLIR threading helper to enable nested parallelism
Mehdi Amini
llvmlistbot at llvm.org
Fri May 6 12:40:40 PDT 2022
Author: Mehdi Amini
Date: 2022-05-06T19:40:22Z
New Revision: 072e0aabbc457b8802dcf7b483e3acebfbde1c33
URL: https://github.com/llvm/llvm-project/commit/072e0aabbc457b8802dcf7b483e3acebfbde1c33
DIFF: https://github.com/llvm/llvm-project/commit/072e0aabbc457b8802dcf7b483e3acebfbde1c33.diff
LOG: Enable the use of ThreadPoolTaskGroup in MLIR threading helper to enable nested parallelism
The LLVM ThreadPool recently gained the concept of a
ThreadPoolTaskGroup: a way to "partition" the thread pool
into groups of tasks, enabling nested parallelism through this
grouping at every level of nesting.
We use this feature in the MLIR threading abstraction to fix a
long-standing TODO and enable nested parallelism.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D124902
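The task-group idea described above can be illustrated with a toy, self-contained C++ sketch. It is not LLVM's implementation: `MiniPool`, `MiniGroup`, and `nestedCount` are hypothetical names. The sketch models only the property that the new comment in the diff relies on. Waiting on a group lets the waiting thread run that group's queued tasks instead of merely blocking. With a one-thread pool, a wait that only blocked would deadlock in the nested case: the single worker would sit in the inner wait while the inner tasks stayed queued. Here the nested case completes.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

class MiniPool; // Hypothetical toy pool; NOT llvm::ThreadPool.

// A group of tasks sharing a pool; wait() waits only for this group's tasks.
struct MiniGroup {
  explicit MiniGroup(MiniPool &pool) : pool(pool) {}
  void async(std::function<void()> fn);
  void wait();
  MiniPool &pool;
  std::size_t pending = 0; // Guarded by pool.mutex.
};

class MiniPool {
public:
  explicit MiniPool(unsigned numThreads) {
    for (unsigned i = 0; i < numThreads; ++i)
      workers.emplace_back([this] { workerLoop(); });
  }
  ~MiniPool() {
    { std::lock_guard<std::mutex> lock(mutex); stopping = true; }
    cv.notify_all();
    for (std::thread &t : workers) t.join();
  }

private:
  friend struct MiniGroup;
  struct Task { std::function<void()> fn; MiniGroup *group; };

  // Pops and runs one queued task (only from group `only`, if non-null).
  // Called with `lock` held; the lock is released while the task runs.
  bool runOne(std::unique_lock<std::mutex> &lock, MiniGroup *only) {
    for (auto it = queue.begin(); it != queue.end(); ++it) {
      if (only && it->group != only) continue;
      Task task = std::move(*it);
      queue.erase(it);
      lock.unlock();
      task.fn();
      lock.lock();
      if (--task.group->pending == 0) cv.notify_all();
      return true;
    }
    return false;
  }

  void workerLoop() {
    std::unique_lock<std::mutex> lock(mutex);
    for (;;) {
      if (runOne(lock, nullptr)) continue;
      if (stopping) return;
      cv.wait(lock);
    }
  }

  std::mutex mutex;
  std::condition_variable cv;
  std::deque<Task> queue;
  std::vector<std::thread> workers;
  bool stopping = false;
};

void MiniGroup::async(std::function<void()> fn) {
  { std::lock_guard<std::mutex> lock(pool.mutex);
    ++pending;
    pool.queue.push_back({std::move(fn), this}); }
  pool.cv.notify_all();
}

// Rather than only blocking, the waiting thread runs this group's queued
// tasks, so a worker waiting on a nested group keeps making progress.
void MiniGroup::wait() {
  std::unique_lock<std::mutex> lock(pool.mutex);
  while (pending != 0)
    if (!pool.runOne(lock, this)) pool.cv.wait(lock);
}

// Nested parallelism: each outer task creates and waits on an inner group
// while running on a pool worker thread.
int nestedCount(MiniPool &pool, int outer, int inner) {
  std::atomic<int> count{0};
  MiniGroup outerGroup(pool);
  for (int i = 0; i < outer; ++i)
    outerGroup.async([&pool, &count, inner] {
      MiniGroup innerGroup(pool);
      for (int j = 0; j < inner; ++j)
        innerGroup.async([&count] { count.fetch_add(1); });
      innerGroup.wait();
    });
  outerGroup.wait();
  return count.load();
}
```

For example, with `MiniPool pool(1);`, `nestedCount(pool, 4, 8)` returns 32 instead of hanging. The linear scan of a single shared queue is chosen for brevity, not performance.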
Added:
Modified:
mlir/include/mlir/IR/Threading.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h
index 0d60e95b54c11..dd99039e298ac 100644
--- a/mlir/include/mlir/IR/Threading.h
+++ b/mlir/include/mlir/IR/Threading.h
@@ -41,10 +41,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
// If multithreading is disabled or there is a small number of elements,
// process the elements directly on this thread.
- // FIXME: ThreadPool should allow work stealing to avoid deadlocks when
- // scheduling work within a worker thread.
- if (!context->isMultithreadingEnabled() || numElements <= 1 ||
- context->getThreadPool().isWorkerThread()) {
+ if (!context->isMultithreadingEnabled() || numElements <= 1) {
for (; begin != end; ++begin)
if (failed(func(*begin)))
return failure();
@@ -70,16 +67,14 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
// Otherwise, process the elements in parallel.
llvm::ThreadPool &threadPool = context->getThreadPool();
+ llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
size_t numActions = std::min(numElements, threadPool.getThreadCount());
- SmallVector<std::shared_future<void>> threadFutures;
- threadFutures.reserve(numActions - 1);
- for (unsigned i = 1; i < numActions; ++i)
- threadFutures.emplace_back(threadPool.async(processFn));
- processFn();
-
- // Wait for all of the threads to finish.
- for (std::shared_future<void> &future : threadFutures)
- future.wait();
+ for (unsigned i = 0; i < numActions; ++i)
+ tasksGroup.async(processFn);
+ // If the current thread is a worker thread from the pool, then waiting for
+ // the task group allows the current thread to also participate in processing
+ // tasks from the group, which avoids any deadlock/starvation.
+ tasksGroup.wait();
return failure(processingFailed);
}
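The hunk above replaces the hand-rolled futures with a task group but keeps the same work distribution: up to `numActions` copies of `processFn` run concurrently, and the result is `failure(processingFailed)`. The body of `processFn` lies outside this diff. The sketch below therefore assumes that `processFn` claims element indices from a shared atomic counter and sets an atomic `processingFailed` flag on the first failure. It also substitutes plain `std::thread`s, joined at the end, for the task group. `parallelForEachSketch` is a hypothetical name, and a `bool`-returning callback stands in for `LogicalResult`.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical sketch of the failable parallel-for-each pattern.
// Returns true on success, false if any call to `func` returned false.
template <typename T, typename Func>
bool parallelForEachSketch(const std::vector<T> &items, unsigned numThreads,
                           Func func) {
  std::size_t numElements = items.size();
  // Serial fallback, mirroring the early exit for small inputs.
  if (numThreads <= 1 || numElements <= 1) {
    for (const T &item : items)
      if (!func(item))
        return false;
    return true;
  }
  std::atomic<std::size_t> curIndex{0};
  std::atomic<bool> processingFailed{false};
  // Each worker claims the next unprocessed index until the range is
  // exhausted or some element has failed.
  auto processFn = [&] {
    while (!processingFailed.load()) {
      std::size_t index = curIndex.fetch_add(1);
      if (index >= numElements)
        break;
      if (!func(items[index]))
        processingFailed = true;
    }
  };
  std::size_t numActions = std::min<std::size_t>(numElements, numThreads);
  std::vector<std::thread> threads;
  for (std::size_t i = 0; i < numActions; ++i)
    threads.emplace_back(processFn);
  // Stand-in for tasksGroup.wait(): block until every worker is done.
  for (std::thread &t : threads)
    t.join();
  return !processingFailed.load();
}
```

One difference from the diff: here the calling thread only joins, whereas `tasksGroup.wait()` lets the caller execute queued tasks from the group itself, which is what makes calling this helper from inside a pool worker safe.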