[Mlir-commits] [mlir] bdde959 - Remove unnecessary async group creates and awaits.

Eugene Zhulenev llvmlistbot at llvm.org
Tue Sep 28 14:52:17 PDT 2021


Author: bakhtiyar
Date: 2021-09-28T14:52:08-07:00
New Revision: bdde959533f05d7d191cabce4d62216754802014

URL: https://github.com/llvm/llvm-project/commit/bdde959533f05d7d191cabce4d62216754802014
DIFF: https://github.com/llvm/llvm-project/commit/bdde959533f05d7d191cabce4d62216754802014.diff

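LOG: Remove unnecessary async group creates and awaits.

The async.group is only needed when compute blocks are dispatched
asynchronously. This change moves the async.create_group and
async.await_all ops from around the scf.if into its async-dispatch (else)
branch. As a result, the single-block (syncDispatch) path no longer creates
or awaits a group.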

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D110605

Added:
    (none)

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir

Removed:
    (none)


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 8ccf97e269500..2928c26adfe00 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -508,12 +508,6 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   Value c0 = b.create<ConstantIndexOp>(0);
   Value c1 = b.create<ConstantIndexOp>(1);
 
-  // Create an async.group to wait on all async tokens from the concurrent
-  // execution of multiple parallel compute function. First block will be
-  // executed synchronously in the caller thread.
-  Value groupSize = b.create<SubIOp>(blockCount, c1);
-  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
-
   // Appends operands shared by async dispatch and parallel compute functions to
   // the given operands vector.
   auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
@@ -543,6 +537,12 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   };
 
   auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    // Create an async.group to wait on all async tokens from the concurrent
+    // execution of multiple parallel compute function. First block will be
+    // executed synchronously in the caller thread.
+    Value groupSize = b.create<SubIOp>(blockCount, c1);
+    Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+
     ImplicitLocOpBuilder nb(loc, nestedBuilder);
 
     // Launch async dispatch function for [0, blockCount) range.
@@ -551,14 +551,15 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
 
     nb.create<CallOp>(asyncDispatchFunction.sym_name(),
                       asyncDispatchFunction.getCallableResults(), operands);
+
+    // Wait for the completion of all parallel compute operations.
+    b.create<AwaitAllOp>(group);
+
     nb.create<scf::YieldOp>();
   };
 
   // Dispatch either single block compute function, or launch async dispatch.
   b.create<scf::IfOp>(TypeRange(), isSingleBlock, syncDispatch, asyncDispatch);
-
-  // Wait for the completion of all parallel compute operations.
-  b.create<AwaitAllOp>(group);
 }
 
 // Dispatch parallel compute functions by submitting all async compute tasks

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
index 7bc7edc27b486..c8d1e6d3c0cc0 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -12,13 +12,13 @@ func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
 
   // CHECK:      scf.if %[[IS_NOOP]] {
   // CHECK-NEXT: } else {
-  // CHECK:        %[[GROUP:.*]] = async.create_group
-  // CHECK:        scf.if {{.*}} {
+    // CHECK:        scf.if {{.*}} {
   // CHECK:          call @parallel_compute_fn(%[[C0]]
   // CHECK:        } else {
+  // CHECK:          %[[GROUP:.*]] = async.create_group
   // CHECK:          call @async_dispatch_fn
+  // CHECK:          async.await_all %[[GROUP]]
   // CHECK:        }
-  // CHECK:        async.await_all %[[GROUP]]
   // CHECK:      }
   scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
     %one = constant 1.0 : f32


        


More information about the Mlir-commits mailing list