[Mlir-commits] [mlir] 86ad0af - [mlir:Async] Implement recursive async work splitting for scf.parallel operation (async-parallel-for pass)

Eugene Zhulenev llvmlistbot at llvm.org
Fri Jun 25 10:34:47 PDT 2021


Author: Eugene Zhulenev
Date: 2021-06-25T10:34:39-07:00
New Revision: 86ad0af87054c3cccd68d32e103a6f1f6c6194c7

URL: https://github.com/llvm/llvm-project/commit/86ad0af87054c3cccd68d32e103a6f1f6c6194c7
DIFF: https://github.com/llvm/llvm-project/commit/86ad0af87054c3cccd68d32e103a6f1f6c6194c7.diff

LOG: [mlir:Async] Implement recursive async work splitting for scf.parallel operation (async-parallel-for pass)

Depends On D104780

Recursive work splitting instead of sequential async task submission gives a ~20%-30% speedup in microbenchmarks.

Algorithm outline:
1. Collapse scf.parallel dimensions into a single dimension
2. Compute the block size for the parallel operation from the 1d problem size
3. Launch parallel tasks
4. Each parallel task reconstructs its own bounds in the original multi-dimensional iteration space
5. Each parallel task computes the original parallel operation body using an scf.for loop nest
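
For a 2d scf.parallel, the generated IR looks roughly like this (an
illustrative sketch, not the exact IR produced by the pass; the file header
comment in AsyncParallelFor.cpp below shows the precise structure):

  // Original operation:
  scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
    "do_some_compute"(%i, %j) : () -> ()
  }

  // After the pass (async dispatch enabled): the body is outlined into a
  // parallel compute function, and a dispatch function recursively splits
  // the [0, %block_count) block range into async.execute tasks.
  %group = async.create_group %group_size
  call @async_dispatch_fn(%group, %c0, %block_count, %block_size, ...)
  async.await_all %group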

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D104850

Added: 
    mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
    mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
    mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir

Modified: 
    mlir/include/mlir/Dialect/Async/Passes.h
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir

Removed: 
    mlir/test/Dialect/Async/async-parallel-for.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index d790835c76125..5c5dbe914aed5 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,8 +19,6 @@ namespace mlir {
 
 std::unique_ptr<Pass> createAsyncParallelForPass();
 
-std::unique_ptr<Pass> createAsyncParallelForPass(int numWorkerThreads);
-
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
 std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();

diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index d5640f3ae65a6..b770ac751ab13 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -12,15 +12,26 @@
 include "mlir/Pass/PassBase.td"
 
 def AsyncParallelFor : Pass<"async-parallel-for"> {
-  let summary = "Convert scf.parallel operations to multiple async regions "
+  let summary = "Convert scf.parallel operations to multiple async compute ops "
                 "executed concurrently for non-overlapping iteration ranges";
   let constructor = "mlir::createAsyncParallelForPass()";
+
   let options = [
-    Option<"numConcurrentAsyncExecute", "num-concurrent-async-execute",
-      "int32_t", /*default=*/"4",
-      "The number of async.execute operations that will be used for concurrent "
-      "loop execution.">
+    Option<"asyncDispatch", "async-dispatch",
+      "bool", /*default=*/"true",
+      "Dispatch async compute tasks using recursive work splitting. If `false` "
+      "async compute tasks will be launched using simple for loop in the "
+      "caller thread.">,
+
+    Option<"numWorkerThreads", "num-workers",
+      "int32_t", /*default=*/"8",
+      "The number of available workers to execute async operations.">,
+
+    Option<"targetBlockSize", "target-block-size",
+      "int32_t", /*default=*/"1000",
+      "The target block size for sharding parallel operation.">
   ];
+
   let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
 }
 

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index ba09123199849..02541479fb24b 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements scf.parallel to src.for + async.execute conversion pass.
+// This file implements scf.parallel to scf.for + async.execute conversion pass.
 //
 //===----------------------------------------------------------------------===//
 
@@ -16,8 +16,10 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 using namespace mlir;
 using namespace mlir::async;
@@ -31,243 +33,627 @@ namespace {
 //
 // Example:
 //
-//   scf.for (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
+//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
 //     "do_some_compute"(%i, %j): () -> ()
 //   }
 //
 // Converted to:
 //
-//   %c0 = constant 0 : index
-//   %c1 = constant 1 : index
+//   // Parallel compute function that executes the parallel body region for
+//   // a subset of the parallel iteration space defined by the one-dimensional
+//   // compute block index.
+//   func parallel_compute_function(%block_index : index, %block_size : index,
+//                                  <parallel operation properties>, ...) {
+//     // Compute multi-dimensional loop bounds for %block_index.
+//     %block_lbi, %block_lbj = ...
+//     %block_ubi, %block_ubj = ...
 //
-//   // Compute blocks sizes for each induction variable.
-//   %num_blocks_i = ... : index
-//   %num_blocks_j = ... : index
-//   %block_size_i = ... : index
-//   %block_size_j = ... : index
+//     // Clone parallel operation body into the scf.for loop nest.
+//     scf.for %i = %block_lbi to %block_ubi {
+//       scf.for %j = %block_lbj to %block_ubj {
+//         "do_some_compute"(%i, %j): () -> ()
+//       }
+//     }
+//   }
 //
-//   // Create an async group to track async execute ops.
-//   %group = async.create_group
+// And a dispatch function whose form depends on the `asyncDispatch` option.
 //
-//   scf.for %bi = %c0 to %num_blocks_i step %c1 {
-//     %block_start_i = ... : index
-//     %block_end_i   = ... : index
+// When async dispatch is on: (pseudocode)
 //
-//     scf.for %bj = %c0 to %num_blocks_j step %c1 {
-//       %block_start_j = ... : index
-//       %block_end_j   = ... : index
+//   %block_size = ... compute parallel compute block size
+//   %block_count = ... compute the number of compute blocks
 //
-//       // Execute the body of original parallel operation for the current
-//       // block.
-//       %token = async.execute {
-//         scf.for %i = %block_start_i to %block_end_i step %si {
-//           scf.for %j = %block_start_j to %block_end_j step %sj {
-//             "do_some_compute"(%i, %j): () -> ()
-//           }
-//         }
-//       }
-//
-//       // Add produced async token to the group.
-//       async.add_to_group %token, %group
+//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
+//     // Keep splitting block range until we reached a range of size 1.
+//     while (%block_end - %block_start > 1) {
+//       %mid_index = block_start + (block_end - block_start) / 2;
+//       async.execute { call @async_dispatch(%mid_index, %block_end); }
+//       %block_end = %mid_index
 //     }
+//
+//     // Call parallel compute function for a single block.
+//     call @parallel_compute_fn(%block_start, %block_size, ...);
 //   }
 //
-//   // Await completion of all async.execute operations.
-//   async.await_all %group
+//   // Launch async dispatch for [0, block_count) range.
+//   call @async_dispatch(%c0, %block_count);
 //
-// In this example outer loop launches inner block level loops as separate async
-// execute operations which will be executed concurrently.
+// When async dispatch is off:
 //
-// At the end it waits for the completiom of all async execute operations.
+//   %block_size = ... compute parallel compute block size
+//   %block_count = ... compute the number of compute blocks
 //
+//   scf.for %block_index = %c0 to %block_count {
+//      call @parallel_compute_fn(%block_index, %block_size, ...)
+//   }
+//
+struct AsyncParallelForPass
+    : public AsyncParallelForBase<AsyncParallelForPass> {
+  AsyncParallelForPass() = default;
+  void runOnOperation() override;
+};
+
 struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 public:
-  AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
-      : OpRewritePattern(ctx),
-        numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}
+  AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
+                          int32_t numWorkerThreads, int32_t targetBlockSize)
+      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
+        numWorkerThreads(numWorkerThreads), targetBlockSize(targetBlockSize) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp op,
                                 PatternRewriter &rewriter) const override;
 
 private:
-  int numConcurrentAsyncExecute;
+  // Maximum number of tasks per worker thread when sharding a parallel op.
+  static constexpr int32_t kMaxOversharding = 4;
+
+  bool asyncDispatch;
+  int32_t numWorkerThreads;
+  int32_t targetBlockSize;
 };
 
-struct AsyncParallelForPass
-    : public AsyncParallelForBase<AsyncParallelForPass> {
-  AsyncParallelForPass() = default;
-  AsyncParallelForPass(int numWorkerThreads) {
-    assert(numWorkerThreads >= 1);
-    numConcurrentAsyncExecute = numWorkerThreads;
-  }
-  void runOnOperation() override;
+struct ParallelComputeFunctionType {
+  FunctionType type;
+  llvm::SmallVector<Value> captures;
 };
 
-} // namespace
+struct ParallelComputeFunction {
+  FuncOp func;
+  llvm::SmallVector<Value> captures;
+};
 
-LogicalResult
-AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
-                                         PatternRewriter &rewriter) const {
-  // We do not currently support rewrite for parallel op with reductions.
-  if (op.getNumReductions() != 0)
-    return failure();
+} // namespace
 
-  MLIRContext *ctx = op.getContext();
-  Location loc = op.getLoc();
+// Converts a one-dimensional iteration index in the [0, tripCount) interval
+// into a multi-dimensional iteration coordinate.
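+//
+// For example (illustrative): with tripCounts = [5, 3], index 7 maps to
+// coords = [2, 1], because 7 % 3 == 1 and (7 / 3) % 5 == 2.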
+static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
+                                      const SmallVector<Value> &tripCounts) {
+  SmallVector<Value> coords(tripCounts.size());
+  assert(!tripCounts.empty() && "tripCounts must not be empty");
 
-  // Index constants used below.
-  auto indexTy = IndexType::get(ctx);
-  auto zero = IntegerAttr::get(indexTy, 0);
-  auto one = IntegerAttr::get(indexTy, 1);
-  auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
-  auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);
+  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
+    coords[i] = b.create<SignedRemIOp>(index, tripCounts[i]);
+    index = b.create<SignedDivIOp>(index, tripCounts[i]);
+  }
 
-  // Shorthand for signed integer ceil division operation.
-  auto divup = [&](Value x, Value y) -> Value {
-    return rewriter.create<SignedCeilDivIOp>(loc, x, y);
-  };
+  return coords;
+}
 
-  // Compute trip count for each loop induction variable:
-  //   tripCount = divUp(upperBound - lowerBound, step);
-  SmallVector<Value, 4> tripCounts(op.getNumLoops());
-  for (size_t i = 0; i < op.getNumLoops(); ++i) {
-    auto lb = op.lowerBound()[i];
-    auto ub = op.upperBound()[i];
-    auto step = op.step()[i];
-    auto range = rewriter.create<SubIOp>(loc, ub, lb);
-    tripCounts[i] = divup(range, step);
+// Returns a function type and implicit captures for a parallel compute
+// function. We'll need the list of implicit captures to set up the block and
+// value mapping when we clone the body of the parallel operation.
+static ParallelComputeFunctionType
+getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
+  // Values implicitly captured by the parallel operation.
+  llvm::SetVector<Value> captures;
+  getUsedValuesDefinedAbove(op.region(), op.region(), captures);
+
+  llvm::SmallVector<Type> inputs;
+  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());
+
+  Type indexTy = rewriter.getIndexType();
+
+  // One-dimensional iteration space defined by the block index and size.
+  inputs.push_back(indexTy); // blockIndex
+  inputs.push_back(indexTy); // blockSize
+
+  // Multi-dimensional parallel iteration space defined by the loop trip counts.
+  for (unsigned i = 0; i < op.getNumLoops(); ++i)
+    inputs.push_back(indexTy); // loop tripCount
+
+  // Parallel operation lower bound, upper bound and step.
+  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
+    inputs.push_back(indexTy); // lower bound
+    inputs.push_back(indexTy); // upper bound
+    inputs.push_back(indexTy); // step
   }
 
-  // The target number of concurrent async.execute ops.
-  auto numExecuteOps = rewriter.create<ConstantOp>(
-      loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));
-
-  // Blocks sizes configuration for each induction variable.
-
-  // We try to use maximum available concurrency in outer dimensions first
-  // (assuming that parallel induction variables are corresponding to some
-  // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>)
-  // we will try to parallelize iteration along the %d0. If %d0 is too small,
-  // we'll parallelize iteration over %d1, and so on.
-  SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
-  SmallVector<Value, 4> blockSize(op.getNumLoops());
-  SmallVector<Value, 4> numBlocks(op.getNumLoops());
-
-  // Compute block size and number of blocks along the first induction variable.
-  targetNumBlocks[0] = numExecuteOps;
-  blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
-  numBlocks[0] = divup(tripCounts[0], blockSize[0]);
-
-  // Assign remaining available concurrency to other induction variables.
-  for (size_t i = 1; i < op.getNumLoops(); ++i) {
-    targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
-    blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
-    numBlocks[i] = divup(tripCounts[i], blockSize[i]);
-  }
+  // Types of the implicit captures.
+  for (Value capture : captures)
+    inputs.push_back(capture.getType());
 
-  // Total number of async compute blocks.
-  Value totalBlocks = numBlocks[0];
-  for (size_t i = 1; i < op.getNumLoops(); ++i)
-    totalBlocks = rewriter.create<MulIOp>(loc, totalBlocks, numBlocks[i]);
+  // Convert captures to vector for later convenience.
+  SmallVector<Value> capturesVector(captures.begin(), captures.end());
+  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
+}
 
-  // Create an async.group to wait on all async tokens from async execute ops.
-  auto group =
-      rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx), totalBlocks);
+// Create a parallel compute function from the parallel operation.
+static ParallelComputeFunction
+createParallelComputeFunction(scf::ParallelOp op, PatternRewriter &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-  // Build a scf.for loop nest from the parallel operation.
+  ModuleOp module = op->getParentOfType<ModuleOp>();
+  b.setInsertionPointToStart(&module->getRegion(0).front());
 
-  // Lower/upper bounds for nest block level computations.
-  SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
-  SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
-  SmallVector<Value, 4> blockInductionVars(op.getNumLoops());
+  ParallelComputeFunctionType computeFuncType =
+      getParallelComputeFunctionType(op, rewriter);
 
+  FunctionType type = computeFuncType.type;
+  FuncOp func = FuncOp::create(op.getLoc(), "parallel_compute_fn", type);
+  func.setPrivate();
+
+  // Insert function into the module symbol table and assign it a unique name.
+  SymbolTable symbolTable(module);
+  symbolTable.insert(func);
+  rewriter.getListener()->notifyOperationInserted(func);
+
+  // Create function entry block.
+  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
+  b.setInsertionPointToEnd(block);
+
+  unsigned offset = 0; // argument offset for argument decoding
+
+  // Load multiple arguments into a values vector.
+  auto getArguments = [&](unsigned num_arguments) -> SmallVector<Value> {
+    SmallVector<Value> values(num_arguments);
+    for (unsigned i = 0; i < num_arguments; ++i)
+      values[i] = block->getArgument(offset++);
+    return values;
+  };
+
+  // Block iteration position defined by the block index and size.
+  Value blockIndex = block->getArgument(offset++);
+  Value blockSize = block->getArgument(offset++);
+
+  // Constants used below.
+  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
+  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+
+  // Multi-dimensional parallel iteration space defined by the loop trip counts.
+  SmallVector<Value> tripCounts = getArguments(op.getNumLoops());
+
+  // Compute a product of trip counts to get the size of the flattened
+  // one-dimensional iteration space.
+  Value tripCount = tripCounts[0];
+  for (unsigned i = 1; i < tripCounts.size(); ++i)
+    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
+
+  // Parallel operation lower bound, upper bound and step.
+  SmallVector<Value> lowerBound = getArguments(op.getNumLoops());
+  SmallVector<Value> upperBound = getArguments(op.getNumLoops());
+  SmallVector<Value> step = getArguments(op.getNumLoops());
+
+  // Remaining arguments are implicit captures of the parallel operation.
+  SmallVector<Value> captures = getArguments(block->getNumArguments() - offset);
+
+  // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
+  //   blockFirstIndex = blockIndex * blockSize
+  Value blockFirstIndex = b.create<MulIOp>(blockIndex, blockSize);
+
+  // The last one-dimensional index in the block defined by the `blockIndex`:
+  //   blockLastIndex = min((blockIndex + 1) * blockSize, tripCount) - 1
+  Value blockEnd0 = b.create<AddIOp>(blockIndex, c1);
+  Value blockEnd1 = b.create<MulIOp>(blockEnd0, blockSize);
+  Value blockEnd2 = b.create<CmpIOp>(CmpIPredicate::sge, blockEnd1, tripCount);
+  Value blockEnd3 = b.create<SelectOp>(blockEnd2, tripCount, blockEnd1);
+  Value blockLastIndex = b.create<SubIOp>(blockEnd3, c1);
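+  // For example (illustrative): blockIndex = 2, blockSize = 10 and
+  // tripCount = 25 give blockFirstIndex = 20 and blockLastIndex = 24.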
+
+  // Convert one-dimensional indices to multi-dimensional coordinates.
+  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
+  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);
+
+  // Compute the loop nest upper bounds from the block's last coordinates:
+  //   blockEndCoord[i] = blockLastCoord[i] + 1
+  //
+  // Block first and last coordinates can be the same along the outer compute
+  // dimension when the inner compute dimension contains multiple blocks.
+  SmallVector<Value> blockEndCoord(op.getNumLoops());
+  for (size_t i = 0; i < blockLastCoord.size(); ++i)
+    blockEndCoord[i] = b.create<AddIOp>(blockLastCoord[i], c1);
+
+  // Construct a loop nest out of scf.for operations that will iterate over
+  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
   using LoopBodyBuilder =
       std::function<void(OpBuilder &, Location, Value, ValueRange)>;
-  using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
+  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
+
+  // Parallel region induction variables computed from the multi-dimensional
+  // iteration coordinate using parallel operation bounds and step:
+  //
+  //   computeBlockInductionVars[loopIdx] =
+  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
+  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());
+
+  // We need to know if we are in the first or last iteration of the
+  // multi-dimensional loop for each loop in the nest, so that we can decide
+  // which loop bounds to use for the nested loops: bounds defined by the
+  // compute block interval, or bounds defined by the parallel operation.
+  //
+  // Example: 2d parallel operation
+  //                   i   j
+  //   loop sizes:   [50, 50]
+  //   first coord:  [25, 25]
+  //   last coord:   [30, 30]
+  //
+  // If `i` is equal to 25 then iteration over `j` should start at 25; when `i`
+  // is between 25 and 30 it should start at 0. The upper bound for `j` should
+  // be 50, except when `i` is equal to 30, when it should also be 30.
+  //
+  // The value at the i-th position specifies whether all loops in the [0, i)
+  // range of the loop nest are in their first/last iteration.
+  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
+  SmallVector<Value> isBlockLastCoord(op.getNumLoops());
 
   // Builds inner loop nest inside async.execute operation that does all the
   // work concurrently.
-  LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
-    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-      blockInductionVars[loopIdx] = iv;
+  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
+    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
+                        ValueRange args) {
+      ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+      // Compute induction variable for `loopIdx`.
+      computeBlockInductionVars[loopIdx] = nb.create<AddIOp>(
+          lowerBound[loopIdx], nb.create<MulIOp>(iv, step[loopIdx]));
+
+      // Check if we are inside the first or last iteration of the loop.
+      isBlockFirstCoord[loopIdx] =
+          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
+      isBlockLastCoord[loopIdx] =
+          nb.create<CmpIOp>(CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
+
+      // Check if the previous loop is in its first or last iteration.
+      if (loopIdx > 0) {
+        isBlockFirstCoord[loopIdx] = nb.create<AndOp>(
+            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
+        isBlockLastCoord[loopIdx] = nb.create<AndOp>(
+            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
+      }
 
-      // Continue building async loop nest.
+      // Keep building the loop nest.
       if (loopIdx < op.getNumLoops() - 1) {
-        b.create<scf::ForOp>(
-            loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
-            op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
-        b.create<scf::YieldOp>(loc);
+        // Select nested loop lower/upper bounds depending on our position in
+        // the multi-dimensional iteration space.
+        auto lb = nb.create<SelectOp>(isBlockFirstCoord[loopIdx],
+                                      blockFirstCoord[loopIdx + 1], c0);
+
+        auto ub = nb.create<SelectOp>(isBlockLastCoord[loopIdx],
+                                      blockEndCoord[loopIdx + 1],
+                                      tripCounts[loopIdx + 1]);
+
+        nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
+                              workLoopBuilder(loopIdx + 1));
+        nb.create<scf::YieldOp>(loc);
         return;
       }
 
-      // Copy the body of the parallel op with new loop bounds.
+      // Copy the body of the parallel op into the innermost loop.
       BlockAndValueMapping mapping;
-      mapping.map(op.getInductionVars(), blockInductionVars);
+      mapping.map(op.getInductionVars(), computeBlockInductionVars);
+      mapping.map(computeFuncType.captures, captures);
 
       for (auto &bodyOp : op.getLoopBody().getOps())
-        b.clone(bodyOp, mapping);
+        nb.clone(bodyOp, mapping);
     };
   };
 
-  // Builds a loop nest that does async execute op dispatching.
-  LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
-    return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-      auto lb = op.lowerBound()[loopIdx];
-      auto ub = op.upperBound()[loopIdx];
-      auto step = op.step()[loopIdx];
-
-      // Compute lower bound for the current block:
-      //   blockLowerBound = iv * blockSize * step + lowerBound
-      auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
-      auto s1 = b.create<MulIOp>(loc, s0, step);
-      auto s2 = b.create<AddIOp>(loc, s1, lb);
-      blockLowerBounds[loopIdx] = s2;
-
-      // Compute upper bound for the current block:
-      //   blockUpperBound = min(upperBound,
-      //                         blockLowerBound + blockSize * step)
-      auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
-      auto e1 = b.create<AddIOp>(loc, e0, s2);
-      auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
-      auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
-      blockUpperBounds[loopIdx] = e3;
-
-      // Continue building async dispatch loop nest.
-      if (loopIdx < op.getNumLoops() - 1) {
-        b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
-                             asyncLoopBuilder(loopIdx + 1));
-        b.create<scf::YieldOp>(loc);
-        return;
-      }
+  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
+                       workLoopBuilder(0));
+  b.create<ReturnOp>(ValueRange());
+
+  return {func, std::move(computeFuncType.captures)};
+}
 
-      // Build the inner loop nest that will do the actual work inside the
-      // `async.execute` body region.
-      auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
-                                    Location executeLoc,
-                                    ValueRange executeArgs) {
-        executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
-                                          blockUpperBounds[0], op.step()[0],
-                                          ValueRange(), workLoopBuilder(0));
-        executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
-      };
-
-      auto execute = b.create<ExecuteOp>(
-          loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
-          /*operands=*/ValueRange(), executeBodyBuilder);
-      auto rankType = IndexType::get(ctx);
-      b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
-      b.create<scf::YieldOp>(loc);
+// Creates a recursive async dispatch function for the given parallel compute
+// function. The dispatch function keeps splitting the block range into halves
+// until it reaches a single block, and then executes it inline.
+//
+// Function pseudocode (mix of C++ and MLIR):
+//
+//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
+//
+//     // Keep splitting the block range until we reach a range of size 1.
+//     while (%block_end - %block_start > 1) {
+//       %mid_index = block_start + (block_end - block_start) / 2;
+//       async.execute { call @async_dispatch(%mid_index, %block_end); }
+//       %block_end = %mid_index
+//     }
+//
+//     // Call parallel compute function for a single block.
+//     call @parallel_compute_fn(%block_start, %block_size, ...);
+//   }
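+//
+// For example (illustrative): dispatching the range [0, 8) launches async
+// dispatches for [4, 8), [2, 4) and [1, 2), and then executes block 0 inline
+// in the caller thread.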
+//
+static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
+                                          PatternRewriter &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Location loc = computeFunc.func.getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+
+  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
+  b.setInsertionPointToStart(&module->getRegion(0).front());
+
+  ArrayRef<Type> computeFuncInputTypes =
+      computeFunc.func.type().cast<FunctionType>().getInputs();
+
+  // Compared to the parallel compute function, the async dispatch function
+  // takes an additional !async.group argument. Also, instead of a single
+  // `blockIndex` it takes `blockStart` and `blockEnd` arguments that define
+  // the range of dispatched blocks.
+  SmallVector<Type> inputTypes;
+  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
+  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
+  inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
+
+  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
+  FuncOp func = FuncOp::create(loc, "async_dispatch_fn", type);
+  func.setPrivate();
+
+  // Insert function into the module symbol table and assign it a unique name.
+  SymbolTable symbolTable(module);
+  symbolTable.insert(func);
+  rewriter.getListener()->notifyOperationInserted(func);
+
+  // Create function entry block.
+  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs());
+  b.setInsertionPointToEnd(block);
+
+  Type indexTy = b.getIndexType();
+  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+  Value c2 = b.create<ConstantOp>(b.getIndexAttr(2));
+
+  // Get the async group that will track async dispatch completion.
+  Value group = block->getArgument(0);
+
+  // Get the block iteration range: [blockStart, blockEnd)
+  Value blockStart = block->getArgument(1);
+  Value blockEnd = block->getArgument(2);
+
+  // Create a work splitting while loop for the [blockStart, blockEnd) range.
+  SmallVector<Type> types = {indexTy, indexTy};
+  SmallVector<Value> operands = {blockStart, blockEnd};
+
+  // Create a recursive dispatch loop.
+  scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
+  Block *before = b.createBlock(&whileOp.before(), {}, types);
+  Block *after = b.createBlock(&whileOp.after(), {}, types);
+
+  // Set up the dispatch loop condition block: decide if we need to go into
+  // the `after` block and launch one more async dispatch.
+  {
+    b.setInsertionPointToEnd(before);
+    Value start = before->getArgument(0);
+    Value end = before->getArgument(1);
+    Value distance = b.create<SubIOp>(end, start);
+    Value dispatch = b.create<CmpIOp>(CmpIPredicate::sgt, distance, c1);
+    b.create<scf::ConditionOp>(dispatch, before->getArguments());
+  }
+
+  // Set up the async dispatch loop body: recursively call the dispatch
+  // function for the second half of the original range and go to the next
+  // iteration.
+  {
+    b.setInsertionPointToEnd(after);
+    Value start = after->getArgument(0);
+    Value end = after->getArgument(1);
+    Value distance = b.create<SubIOp>(end, start);
+    Value halfDistance = b.create<SignedDivIOp>(distance, c2);
+    Value midIndex = b.create<AddIOp>(after->getArgument(0), halfDistance);
+
+    // Recursively call this dispatch function inside the async.execute region.
+    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
+                                  Location executeLoc, ValueRange executeArgs) {
+      // Update the original `blockStart` and `blockEnd` with the new range.
+      SmallVector<Value> operands{block->getArguments().begin(),
+                                  block->getArguments().end()};
+      operands[1] = midIndex;
+      operands[2] = end;
+
+      executeBuilder.create<CallOp>(executeLoc, func.sym_name(),
+                                    func.getCallableResults(), operands);
+      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
     };
+
+    // Create an async.execute operation to dispatch half of the block range.
+    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
+                                       executeBodyBuilder);
+    b.create<AddToGroupOp>(indexTy, execute.token(), group);
+    b.create<scf::YieldOp>(ValueRange({after->getArgument(0), midIndex}));
+  }
+
+  // After dispatching async operations to process the tail of the block range,
+  // call the parallel compute function for the first block of the range.
+  b.setInsertionPointAfter(whileOp);
+
+  // Drop async dispatch specific arguments: async group, block start and end.
+  auto forwardedInputs = block->getArguments().drop_front(3);
+  SmallVector<Value> computeFuncOperands = {blockStart};
+  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
+
+  b.create<CallOp>(computeFunc.func.sym_name(),
+                   computeFunc.func.getCallableResults(), computeFuncOperands);
+  b.create<ReturnOp>(ValueRange());
+
+  return func;
+}
+
+// Launch async dispatch of the parallel compute function.
+static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
+                            ParallelComputeFunction &parallelComputeFunction,
+                            scf::ParallelOp op, Value blockSize,
+                            Value blockCount,
+                            const SmallVector<Value> &tripCounts) {
+  MLIRContext *ctx = op->getContext();
+
+  // Add one more level of indirection to dispatch parallel compute functions
+  // using async operations and recursive work splitting.
+  FuncOp asyncDispatchFunction =
+      createAsyncDispatchFunction(parallelComputeFunction, rewriter);
+
+  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
+  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+
+  // Create an async.group to wait on all async tokens from the concurrent
+  // execution of multiple parallel compute functions. The first block will be
+  // executed synchronously in the caller thread.
+  Value groupSize = b.create<SubIOp>(blockCount, c1);
+  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+
+  // Pack the async dispatch function operands to launch the work splitting.
+  SmallVector<Value> asyncDispatchOperands = {group, c0, blockCount, blockSize};
+  asyncDispatchOperands.append(tripCounts);
+  asyncDispatchOperands.append(op.lowerBound().begin(), op.lowerBound().end());
+  asyncDispatchOperands.append(op.upperBound().begin(), op.upperBound().end());
+  asyncDispatchOperands.append(op.step().begin(), op.step().end());
+  asyncDispatchOperands.append(parallelComputeFunction.captures);
+
+  // Launch async dispatch function for [0, blockCount) range.
+  b.create<CallOp>(asyncDispatchFunction.sym_name(),
+                   asyncDispatchFunction.getCallableResults(),
+                   asyncDispatchOperands);
+
+  // Wait for the completion of all parallel compute operations.
+  b.create<AwaitAllOp>(group);
+}
+
+// Dispatch parallel compute functions by submitting all async compute tasks
+// from a simple for loop in the caller thread.
+static void
+doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
+                     ParallelComputeFunction &parallelComputeFunction,
+                     scf::ParallelOp op, Value blockSize, Value blockCount,
+                     const SmallVector<Value> &tripCounts) {
+  MLIRContext *ctx = op->getContext();
+
+  FuncOp compute = parallelComputeFunction.func;
+
+  Value c0 = b.create<ConstantOp>(b.getIndexAttr(0));
+  Value c1 = b.create<ConstantOp>(b.getIndexAttr(1));
+
+  // Create an async.group to wait on all async tokens from the concurrent
+  // execution of multiple parallel compute functions. The first block will be
+  // executed synchronously in the caller thread.
+  Value groupSize = b.create<SubIOp>(blockCount, c1);
+  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+
+  // Call parallel compute function for all blocks.
+  using LoopBodyBuilder =
+      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
+
+  // Returns parallel compute function operands to process the given block.
+  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
+    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
+    computeFuncOperands.append(tripCounts);
+    computeFuncOperands.append(op.lowerBound().begin(), op.lowerBound().end());
+    computeFuncOperands.append(op.upperBound().begin(), op.upperBound().end());
+    computeFuncOperands.append(op.step().begin(), op.step().end());
+    computeFuncOperands.append(parallelComputeFunction.captures);
+    return computeFuncOperands;
   };
 
-  // Start building a loop nest from the first induction variable.
-  rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
-                              asyncLoopBuilder(0));
+  // Induction variable is the index of the block: [1, blockCount); block 0 is
+  // executed separately in the caller thread (see below).
+  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
+                                    Value iv, ValueRange args) {
+    ImplicitLocOpBuilder nb(loc, loopBuilder);
+
+    // Call parallel compute function inside the async.execute region.
+    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
+                                  Location executeLoc, ValueRange executeArgs) {
+      executeBuilder.create<CallOp>(executeLoc, compute.sym_name(),
+                                    compute.getCallableResults(),
+                                    computeFuncOperands(iv));
+      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
+    };
+
+    // Create an async.execute operation to launch the parallel compute function.
+    auto execute = nb.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
+                                        executeBodyBuilder);
+    nb.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
+    nb.create<scf::YieldOp>();
+  };
+
+  // Iterate over all compute blocks and launch parallel compute operations.
+  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);
+
+  // Call parallel compute function for the first block in the caller thread.
+  b.create<CallOp>(compute.sym_name(), compute.getCallableResults(),
+                   computeFuncOperands(c0));
+
+  // Wait for the completion of all async compute operations.
+  b.create<AwaitAllOp>(group);
+}
 
-  // Wait for the completion of all subtasks.
-  rewriter.create<AwaitAllOp>(loc, group.result());
+LogicalResult
+AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
+                                         PatternRewriter &rewriter) const {
+  // We do not currently support rewrite for parallel op with reductions.
+  if (op.getNumReductions() != 0)
+    return failure();
 
-  // Erase the original parallel operation.
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+  // Compute trip count for each loop induction variable:
+  //   tripCount = ceil_div(upperBound - lowerBound, step);
+  SmallVector<Value> tripCounts(op.getNumLoops());
+  for (size_t i = 0; i < op.getNumLoops(); ++i) {
+    auto lb = op.lowerBound()[i];
+    auto ub = op.upperBound()[i];
+    auto step = op.step()[i];
+    auto range = b.create<SubIOp>(ub, lb);
+    tripCounts[i] = b.create<SignedCeilDivIOp>(range, step);
+  }
+
+  // Compute a product of trip counts to get the size of the 1-dimensional
+  // iteration space for the scf.parallel operation.
+  Value tripCount = tripCounts[0];
+  for (size_t i = 1; i < tripCounts.size(); ++i)
+    tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
+
+  auto indexTy = b.getIndexType();
+
+  // Do not overload worker threads with too many compute blocks.
+  Value maxComputeBlocks = b.create<ConstantOp>(
+      indexTy, b.getIndexAttr(numWorkerThreads * kMaxOversharding));
+
+  // Target block size from the pass parameters.
+  Value targetComputeBlockSize =
+      b.create<ConstantOp>(indexTy, b.getIndexAttr(targetBlockSize));
+
+  // Compute parallel block size from the parallel problem size:
+  //   blockSize = min(tripCount,
+  //                   max(divup(tripCount, maxComputeBlocks),
+  //                       targetComputeBlockSize))
+  Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
+  Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
+  Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
+  Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
+  Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
+  Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
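+  // For example (illustrative): with the default num-workers=8 and
+  // target-block-size=1000, a trip count of 4000 gives maxComputeBlocks = 32
+  // and blockSize = min(4000, max(divup(4000, 32), 1000)) = 1000, so
+  // blockCount = 4.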
+
+  // Create a parallel compute function that takes a block id and computes the
+  // parallel operation body for a subset of the iteration space.
+  ParallelComputeFunction parallelComputeFunction =
+      createParallelComputeFunction(op, rewriter);
+
+  // Dispatch the parallel compute function using async recursive work
+  // splitting, or by submitting compute tasks sequentially from the caller
+  // thread.
+  if (asyncDispatch) {
+    doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+                    blockCount, tripCounts);
+  } else {
+    doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+                         blockCount, tripCounts);
+  }
+
+  // The parallel operation was replaced with a block iteration loop.
   rewriter.eraseOp(op);
 
   return success();
@@ -277,7 +663,8 @@ void AsyncParallelForPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
 
   RewritePatternSet patterns(ctx);
-  patterns.add<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
+  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
+                                        targetBlockSize);
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
@@ -286,7 +673,3 @@ void AsyncParallelForPass::runOnOperation() {
 std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
-
-std::unique_ptr<Pass> mlir::createAsyncParallelForPass(int numWorkerThreads) {
-  return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
-}

diff --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
new file mode 100644
index 0000000000000..72b2e01045482
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=true  \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: @loop_1d
+func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
+  // CHECK: %[[GROUP:.*]] = async.create_group
+  // CHECK: call @async_dispatch_fn
+  // CHECK: async.await_all %[[GROUP]]
+  scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
+    %one = constant 1.0 : f32
+    memref.store %one, %arg3[%i] : memref<?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn
+// CHECK:       scf.for
+// CHECK:         memref.store
+
+// CHECK-LABEL: func private @async_dispatch_fn
+// CHECK-SAME:    %[[GROUP:arg0]]: !async.group,
+// CHECK-SAME:    %[[BLOCK_START:arg1]]: index
+// CHECK-SAME:    %[[BLOCK_END:arg2]]: index
+
+// CHECK:         scf.while (%[[S:.*]] = %[[BLOCK_START]],
+// CHECK-SAME:               %[[E:.*]] = %[[BLOCK_END]])
+// CHECK:         } do {
+// CHECK:           %[[TOKEN:.*]] = async.execute
+// CHECK:             call @async_dispatch_fn
+// CHECK:             async.add_to_group
+// CHECK:         }
+
+// CHECK:         call @parallel_compute_fn(%[[BLOCK_START]]
+
+// -----
+
+// CHECK-LABEL: @loop_2d
+func @loop_2d(%arg0: index, %arg1: index, %arg2: index, // lb, ub, step
+              %arg3: index, %arg4: index, %arg5: index, // lb, ub, step
+              %arg6: memref<?x?xf32>) {
+  // CHECK: %[[GROUP:.*]] = async.create_group
+  // CHECK: call @async_dispatch_fn
+  // CHECK: async.await_all %[[GROUP]]
+  scf.parallel (%i0, %i1) = (%arg0, %arg3) to (%arg1, %arg4)
+                            step (%arg2, %arg5) {
+    %one = constant 1.0 : f32
+    memref.store %one, %arg6[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func private @parallel_compute_fn
+// CHECK:       scf.for
+// CHECK:         scf.for
+// CHECK:           memref.store
+
+// CHECK-LABEL: func private @async_dispatch_fn

diff --git a/mlir/test/Dialect/Async/async-parallel-for.mlir b/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
similarity index 58%
rename from mlir/test/Dialect/Async/async-parallel-for.mlir
rename to mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
index d498fe80906be..81ed85323e033 100644
--- a/mlir/test/Dialect/Async/async-parallel-for.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-seq-dispatch.mlir
@@ -1,44 +1,50 @@
-// RUN: mlir-opt %s -async-parallel-for | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=false  \
+// RUN: | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @loop_1d
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
   // CHECK: %[[GROUP:.*]] = async.create_group
   // CHECK: scf.for
-  // CHECK:   %[[TOKEN:.*]] = async.execute {
-  // CHECK:     scf.for
-  // CHECK:       memref.store
+  // CHECK:   %[[TOKEN:.*]] = async.execute
+  // CHECK:     call @parallel_compute_fn
   // CHECK:     async.yield
-  // CHECK:   }
   // CHECK:   async.add_to_group %[[TOKEN]], %[[GROUP]]
+  // CHECK: call @parallel_compute_fn
   // CHECK: async.await_all %[[GROUP]]
   scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
     %one = constant 1.0 : f32
     memref.store %one, %arg3[%i] : memref<?xf32>
   }
-
   return
 }
 
+// CHECK-LABEL: func private @parallel_compute_fn
+// CHECK:       scf.for
+// CHECK:         memref.store
+
+// -----
+
 // CHECK-LABEL: @loop_2d
 func @loop_2d(%arg0: index, %arg1: index, %arg2: index, // lb, ub, step
               %arg3: index, %arg4: index, %arg5: index, // lb, ub, step
               %arg6: memref<?x?xf32>) {
   // CHECK: %[[GROUP:.*]] = async.create_group
   // CHECK: scf.for
-  // CHECK:   scf.for
-  // CHECK:     %[[TOKEN:.*]] = async.execute {
-  // CHECK:       scf.for
-  // CHECK:         scf.for
-  // CHECK:           memref.store
-  // CHECK:       async.yield
-  // CHECK:     }
-  // CHECK:     async.add_to_group %[[TOKEN]], %[[GROUP]]
+  // CHECK:   %[[TOKEN:.*]] = async.execute
+  // CHECK:     call @parallel_compute_fn
+  // CHECK:     async.yield
+  // CHECK:   async.add_to_group %[[TOKEN]], %[[GROUP]]
+  // CHECK: call @parallel_compute_fn
   // CHECK: async.await_all %[[GROUP]]
   scf.parallel (%i0, %i1) = (%arg0, %arg3) to (%arg1, %arg4)
                             step (%arg2, %arg5) {
     %one = constant 1.0 : f32
     memref.store %one, %arg6[%i0, %i1] : memref<?x?xf32>
   }
-
   return
 }
+
+// CHECK-LABEL: func private @parallel_compute_fn
+// CHECK:       scf.for
+// CHECK:         scf.for
+// CHECK:           memref.store

diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
index d26e301760e9e..1ab6ff0630eda 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
@@ -1,12 +1,10 @@
 // RUN:   mlir-opt %s                                                          \
-// RUN:               -linalg-tile-to-parallel-loops="linalg-tile-sizes=256"   \
-// RUN:               -async-parallel-for="num-concurrent-async-execute=4"     \
+// RUN:               -convert-linalg-to-parallel-loops                        \
+// RUN:               -async-parallel-for                                      \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
-// RUN:               -async-runtime-ref-counting-opt                          \
+// FIXME:             -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
-// RUN:               -lower-affine                                            \
-// RUN:               -convert-linalg-to-loops                                 \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -std-expand                                              \
 // RUN:               -convert-vector-to-llvm                                  \

diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
new file mode 100644
index 0000000000000..e2e69c65ba08c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
@@ -0,0 +1,148 @@
+// RUN:   mlir-opt %s                                                          \
+// RUN:               -async-parallel-for                                      \
+// RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// FIXME:             -async-runtime-ref-counting-opt                          \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-linalg-to-loops                                 \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -std-expand                                              \
+// RUN:               -convert-vector-to-llvm                                  \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN: -e entry -entry-point-result=void -O3                                  \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext  \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext\
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s --dump-input=always
+
+// RUN:   mlir-opt %s                                                          \
+// RUN:               -async-parallel-for=async-dispatch=false                 \
+// RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// FIXME:             -async-runtime-ref-counting-opt                          \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-linalg-to-loops                                 \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -std-expand                                              \
+// RUN:               -convert-vector-to-llvm                                  \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN: -e entry -entry-point-result=void -O3                                  \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext  \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext\
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s --dump-input=always
+
+// RUN:   mlir-opt %s                                                          \
+// RUN:               -convert-linalg-to-loops                                 \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -convert-vector-to-llvm                                  \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN: -e entry -entry-point-result=void -O3                                  \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext  \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext\
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s --dump-input=always
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+
+func @scf_parallel(%lhs: memref<?x?xf32>,
+                   %rhs: memref<?x?xf32>,
+                   %sum: memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  %d0 = memref.dim %lhs, %c0 : memref<?x?xf32>
+  %d1 = memref.dim %lhs, %c1 : memref<?x?xf32>
+
+  scf.parallel (%i, %j) = (%c0, %c0) to (%d0, %d1) step (%c1, %c1) {
+    %lv = memref.load %lhs[%i, %j] : memref<?x?xf32>
+    %rv = memref.load %lhs[%i, %j] : memref<?x?xf32>
+    %r = addf %lv, %rv : f32
+    memref.store %r, %sum[%i, %j] : memref<?x?xf32>
+  }
+
+  return
+}
+
+func @entry() {
+  %f1 = constant 1.0 : f32
+  %f4 = constant 4.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cM = constant 1000 : index
+
+  //
+  // Sanity check for the function under test.
+  //
+
+  %LHS10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
+  %RHS10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
+  %DST10 = memref.alloc() {alignment = 64} : memref<1x10xf32>
+
+  linalg.fill(%f1, %LHS10) : f32, memref<1x10xf32>
+  linalg.fill(%f1, %RHS10) : f32, memref<1x10xf32>
+
+  %LHS = memref.cast %LHS10 : memref<1x10xf32> to memref<?x?xf32>
+  %RHS = memref.cast %RHS10 : memref<1x10xf32> to memref<?x?xf32>
+  %DST = memref.cast %DST10 : memref<1x10xf32> to memref<?x?xf32>
+
+  call @scf_parallel(%LHS, %RHS, %DST)
+    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+
+  // CHECK: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+  %U = memref.cast %DST10 :  memref<1x10xf32> to memref<*xf32>
+  call @print_memref_f32(%U): (memref<*xf32>) -> ()
+
+  memref.dealloc %LHS10: memref<1x10xf32>
+  memref.dealloc %RHS10: memref<1x10xf32>
+  memref.dealloc %DST10: memref<1x10xf32>
+
+  //
+  // Allocate data for microbenchmarks.
+  //
+
+  %LHS1024 = memref.alloc() {alignment = 64} : memref<1024x1024xf32>
+  %RHS1024 = memref.alloc() {alignment = 64} : memref<1024x1024xf32>
+  %DST1024 = memref.alloc() {alignment = 64} : memref<1024x1024xf32>
+
+  %LHS0 = memref.cast %LHS1024 : memref<1024x1024xf32> to memref<?x?xf32>
+  %RHS0 = memref.cast %RHS1024 : memref<1024x1024xf32> to memref<?x?xf32>
+  %DST0 = memref.cast %DST1024 : memref<1024x1024xf32> to memref<?x?xf32>
+
+  //
+  // Warm up.
+  //
+
+  call @scf_parallel(%LHS0, %RHS0, %DST0)
+    : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+
+  //
+  // Measure execution time.
+  //
+
+  %t0 = call @rtclock() : () -> f64
+  scf.for %i = %c0 to %cM step %c1 {
+    call @scf_parallel(%LHS0, %RHS0, %DST0)
+      : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+  }
+  %t1 = call @rtclock() : () -> f64
+  %t1024 = subf %t1, %t0 : f64
+
+  // Print timings.
+  vector.print %t1024 : f64
+
+  // Free.
+  memref.dealloc %LHS1024: memref<1024x1024xf32>
+  memref.dealloc %RHS1024: memref<1024x1024xf32>
+  memref.dealloc %DST1024: memref<1024x1024xf32>
+
+  return
+}
+
+func private @rtclock() -> f64
+
+func private @print_memref_f32(memref<*xf32>)
+  attributes { llvm.emit_c_interface }

diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index 9f05ec8065dc5..76a6b2f270531 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -1,7 +1,22 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
-// RUN:               -async-runtime-ref-counting-opt                          \
+// FIXME:             -async-runtime-ref-counting-opt                          \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:  -e entry -entry-point-result=void -O0                                 \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
+// RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
+// RUN:                                    num-workers=20                      \
+// RUN:                                    target-block-size=1"                \
+// RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// FIXME:             -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-std-to-llvm                                     \

diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
index 883a0bc4fab7b..0443e46116920 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir
@@ -1,7 +1,22 @@
 // RUN:   mlir-opt %s -async-parallel-for                                      \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
-// RUN:               -async-runtime-ref-counting-opt                          \
+// FIXME:             -async-runtime-ref-counting-opt                          \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:  -e entry -entry-point-result=void -O0                                 \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
+// RUN: | FileCheck %s --dump-input=always
+
+// RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
+// RUN:                                    num-workers=20                      \
+// RUN:                                    target-block-size=1"                \
+// RUN:               -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// FIXME:             -async-runtime-ref-counting-opt                          \
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-std-to-llvm                                     \


        

