[llvm-branch-commits] [mlir] 94e645f - [mlir] Async: Add numWorkerThreads argument to createAsyncParallelForPass

Eugene Zhulenev via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 8 10:34:25 PST 2020


Author: Eugene Zhulenev
Date: 2020-12-08T10:30:14-08:00
New Revision: 94e645f9cce8fba26b4aec069103794f1779065f

URL: https://github.com/llvm/llvm-project/commit/94e645f9cce8fba26b4aec069103794f1779065f
DIFF: https://github.com/llvm/llvm-project/commit/94e645f9cce8fba26b4aec069103794f1779065f.diff

LOG: [mlir] Async: Add numWorkerThreads argument to createAsyncParallelForPass

Add an option to pass the number of worker threads, which is used to select the number of async regions created by the `scf.parallel` to async transformation.
```
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass(int numWorkerThreads);
```

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D92835
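
The commit message does not spell out how the worker-thread count maps onto async regions. The following is only a rough model. It assumes the pass splits a one-dimensional iteration space into at most `numWorkerThreads` equally sized contiguous blocks, one per async region. `Block` and `splitIterationSpace` are hypothetical names, not MLIR APIs, and the real pass may size its blocks differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model, not MLIR code: a half-open range of loop iterations
// that one async region would execute.
struct Block {
  int64_t begin;
  int64_t end;
};

// Splits [0, tripCount) into at most numWorkerThreads contiguous blocks.
// Assumption: blocks are sized by ceiling division so every iteration is
// covered exactly once; the real pass may choose block sizes differently.
std::vector<Block> splitIterationSpace(int64_t tripCount, int numWorkerThreads) {
  assert(numWorkerThreads >= 1);  // same precondition as the new constructor
  std::vector<Block> blocks;
  if (tripCount <= 0)
    return blocks;
  // Never create more blocks than there are iterations.
  int64_t numBlocks = std::min<int64_t>(tripCount, numWorkerThreads);
  int64_t blockSize = (tripCount + numBlocks - 1) / numBlocks;  // ceil div
  for (int64_t begin = 0; begin < tripCount; begin += blockSize)
    blocks.push_back({begin, std::min(begin + blockSize, tripCount)});
  return blocks;
}
```

For example, `splitIterationSpace(10, 4)` yields `[0,3) [3,6) [6,9) [9,10)`, which is four blocks. `splitIterationSpace(2, 8)` yields only two blocks because there are only two iterations to distribute.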

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 9716bde76593..ab5abdc28611 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,6 +19,9 @@ namespace mlir {
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
 
+std::unique_ptr<OperationPass<FuncOp>>
+createAsyncParallelForPass(int numWorkerThreads);
+
 std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index c6508610c796..d6553974bc38 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -96,6 +96,10 @@ struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 struct AsyncParallelForPass
     : public AsyncParallelForBase<AsyncParallelForPass> {
   AsyncParallelForPass() = default;
+  AsyncParallelForPass(int numWorkerThreads) {
+    assert(numWorkerThreads >= 1);
+    numConcurrentAsyncExecute = numWorkerThreads;
+  }
   void runOnFunction() override;
 };
 
@@ -276,3 +280,8 @@ void AsyncParallelForPass::runOnFunction() {
 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncParallelForPass(int numWorkerThreads) {
+  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
+}
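
The new constructor asserts `numWorkerThreads >= 1`, so a caller that derives the count from the host must guard against zero. Below is a minimal sketch of the same overload pattern with no MLIR dependency. It assumes a caller that sizes the pass from `std::thread::hardware_concurrency()`. `AsyncParallelForPassModel`, `createModelPass`, and `pickNumWorkerThreads` are hypothetical names, and the default of 4 is a placeholder rather than the pass's actual option default.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <thread>

// Hypothetical stand-in for AsyncParallelForPass, without MLIR dependencies.
struct AsyncParallelForPassModel {
  // Placeholder default; in MLIR the default comes from the pass option
  // definition, which this commit does not show.
  int numConcurrentAsyncExecute = 4;

  AsyncParallelForPassModel() = default;
  // Marked explicit here so an int does not silently convert to a pass.
  explicit AsyncParallelForPassModel(int numWorkerThreads) {
    assert(numWorkerThreads >= 1);
    numConcurrentAsyncExecute = numWorkerThreads;
  }
};

// Overloaded factories mirroring the two createAsyncParallelForPass entry points.
std::unique_ptr<AsyncParallelForPassModel> createModelPass() {
  return std::make_unique<AsyncParallelForPassModel>();
}

std::unique_ptr<AsyncParallelForPassModel> createModelPass(int numWorkerThreads) {
  return std::make_unique<AsyncParallelForPassModel>(numWorkerThreads);
}

// std::thread::hardware_concurrency() returns 0 when the value is not
// computable, which would violate the constructor's assert, so clamp to 1.
int pickNumWorkerThreads() {
  unsigned hw = std::thread::hardware_concurrency();
  return static_cast<int>(std::max(1u, hw));
}
```

In real MLIR code of this period, the result would presumably be added to a function-level pass manager. For example, `pm.nest<FuncOp>().addPass(mlir::createAsyncParallelForPass(pickNumWorkerThreads()));`, where `pm` is an `mlir::PassManager`. Clamping at the call site keeps the constructor's assert as a pure precondition check.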
