[Mlir-commits] [mlir] d43b236 - [mlir:Async] Add the size parameter to the async.group

Eugene Zhulenev llvmlistbot at llvm.org
Fri Jun 25 10:26:57 PDT 2021


Author: Eugene Zhulenev
Date: 2021-06-25T10:26:50-07:00
New Revision: d43b23608ad664f02f56e965ca78916bde220950

URL: https://github.com/llvm/llvm-project/commit/d43b23608ad664f02f56e965ca78916bde220950
DIFF: https://github.com/llvm/llvm-project/commit/d43b23608ad664f02f56e965ca78916bde220950.diff

LOG: [mlir:Async] Add the size parameter to the async.group

Specify the `!async.group` size (the number of tokens that will be added to it) at construction time. The `async.await_all` operation can race with `async.execute` operations that are still adding tokens to the group; for this reason, the number of tokens that will be added to the group must be known up front.

Reviewed By: ftynse, herhut

Differential Revision: https://reviews.llvm.org/D104780

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/Async/async-to-async-runtime.mlir
    mlir/test/Dialect/Async/ops.mlir
    mlir/test/Dialect/Async/runtime.mlir
    mlir/test/mlir-cpu-runner/async-error.mlir
    mlir/test/mlir-cpu-runner/async-group.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 9ef218ebe560f..f9ddd67a7961d 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -160,20 +160,24 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
   let summary = "creates an empty async group";
   let description = [{
     The `async.create_group` allocates an empty async group. Async tokens or
-    values can be added to this group later.
+    values can be added to this group later. The size of the group must be
+    specified at construction time, and `await_all` operation will first
+    wait until the number of added tokens or values reaches the group size.
 
     Example:
 
     ```mlir
-    %0 = async.create_group
+    %size = ... : index
+    %group = async.create_group %size : !async.group
     ...
-    async.await_all %0
+    async.await_all %group
     ```
   }];
 
+  let arguments = (ins Index:$size);
   let results = (outs Async_GroupType:$result);
 
-  let assemblyFormat = "attr-dict";
+  let assemblyFormat = "$size `:` type($result) attr-dict";
 }
 
 def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
@@ -186,7 +190,7 @@ def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
     Example:
 
     ```mlir
-    %0 = async.create_group
+    %0 = async.create_group %size : !async.group
     %1 = ... : !async.token
     %2 = async.add_to_group %1, %0 : !async.token
     ```
@@ -209,7 +213,7 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
     Example:
 
     ```mlir
-    %0 = async.create_group
+    %0 = async.create_group %size : !async.group
 
     %1 = ... : !async.token
     %2 = async.add_to_group %1, %0 : !async.token
@@ -331,17 +335,28 @@ def Async_CoroSuspendOp : Async_Op<"coro.suspend", [Terminator]> {
 // Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`.
 
 def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
-  let summary = "creates an async runtime value (token, value or group)";
+  let summary = "creates an async runtime token or value";
   let description = [{
-    The `async.runtime.create` operation creates an async dialect value
-    (token, value or group). Tokens and values are created in non-ready state.
-    Groups are created in empty state.
+    The `async.runtime.create` operation creates an async dialect token or
+    value. Tokens and values are created in the non-ready state.
   }];
 
-  let results = (outs Async_AnyAsyncType:$result);
+  let results = (outs Async_AnyValueOrTokenType:$result);
   let assemblyFormat = "attr-dict `:` type($result)";
 }
 
+def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> {
+  let summary = "creates an async runtime group";
+  let description = [{
+    The `async.runtime.create_group` operation creates an async dialect group
+    of the given size. The group is created in the empty state.
+  }];
+
+  let arguments = (ins Index:$size);
+  let results = (outs Async_GroupType:$result);
+  let assemblyFormat = "$size `:` type($result) attr-dict";
+}
+
 def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
   let summary = "switches token or value to available state";
   let description = [{

diff  --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index 3b26bf61e622f..a101b28bb4282 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -66,7 +66,7 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken();
 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t);
 
 // Create a new `async.group` in empty state.
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup();
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size);
 
 extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
 

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index ff2460bb129c4..0156ede1b9b65 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -89,7 +89,8 @@ struct AsyncAPI {
   }
 
   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
-    return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
+    auto i64 = IntegerType::get(ctx, 64);
+    return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
   }
 
   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
@@ -543,11 +544,10 @@ class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
     TypeConverter *converter = getTypeConverter();
     Type resultType = op->getResultTypes()[0];
 
-    // Tokens and Groups lowered to function calls without arguments.
-    if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
-      rewriter.replaceOpWithNewOp<CallOp>(
-          op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
-          converter->convertType(resultType));
+    // Tokens creation maps to a simple function call.
+    if (resultType.isa<TokenType>()) {
+      rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
+                                          converter->convertType(resultType));
       return success();
     }
 
@@ -582,6 +582,29 @@ class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.create_group to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeCreateGroupOpLowering
+    : public OpConversionPattern<RuntimeCreateGroupOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Group creation lowers to a single runtime API call. The size operand
+    // arrives already converted (index -> i64, see createGroupFunctionType)
+    // and is forwarded unchanged as the only call argument.
+    Type groupType = getTypeConverter()->convertType(op->getResultTypes()[0]);
+    rewriter.replaceOpWithNewOp<CallOp>(op, kCreateGroup, groupType, operands);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Convert async.runtime.set_available to the corresponding runtime API call.
 //===----------------------------------------------------------------------===//
@@ -967,8 +990,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
 
   // Lower async.runtime operations that rely on LLVM type converter to convert
   // from async value payload type to the LLVM type.
-  patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
-               RuntimeLoadOpLowering>(llvmConverter, ctx);
+  patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
+               RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
+                                                              ctx);
 
   // Lower async coroutine operations to LLVM coroutine intrinsics.
   patterns

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index ce2bc7081faf1..ba09123199849 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -165,8 +165,14 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     numBlocks[i] = divup(tripCounts[i], blockSize[i]);
   }
 
+  // Total number of async compute blocks.
+  Value totalBlocks = numBlocks[0];
+  for (size_t i = 1; i < op.getNumLoops(); ++i)
+    totalBlocks = rewriter.create<MulIOp>(loc, totalBlocks, numBlocks[i]);
+
   // Create an async.group to wait on all async tokens from async execute ops.
-  auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
+  auto group =
+      rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx), totalBlocks);
 
   // Build a scf.for loop nest from the parallel operation.
 

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 0789a0ee68875..ea8e353925994 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -302,7 +302,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
 }
 
 //===----------------------------------------------------------------------===//
-// Convert async.create_group operation to async.runtime.create
+// Convert async.create_group operation to async.runtime.create_group
 //===----------------------------------------------------------------------===//
 
 namespace {
@@ -313,8 +313,8 @@ class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
   LogicalResult
   matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
-        op, GroupType::get(op->getContext()));
+    rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
+        op, GroupType::get(op->getContext()), operands);
     return success();
   }
 };

diff  --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 6bbb8e4052b10..a8aeaec60d3a7 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -211,8 +211,8 @@ struct AsyncValue : public RefCounted {
 // values to await on all of them together (wait for the completion of all
 // tokens or values added to the group).
 struct AsyncGroup : public RefCounted {
-  AsyncGroup(AsyncRuntime *runtime)
-      : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
+  AsyncGroup(AsyncRuntime *runtime, int64_t size)
+      : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
 
   std::atomic<int> pendingTokens;
   std::atomic<int> numErrors;
@@ -249,8 +249,8 @@ extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
 }
 
 // Create a new `async.group` in empty state.
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
-  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
+  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
   return group;
 }
 
@@ -261,13 +261,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
 
   // Get the rank of the token inside the group before we drop the reference.
   int rank = group->rank.fetch_add(1);
-  group->pendingTokens.fetch_add(1);
 
   auto onTokenReady = [group, token]() {
     // Increment the number of errors in the group.
     if (State(token->state).isError())
       group->numErrors.fetch_add(1);
 
+    // If pending tokens go below zero it means that more tokens than the group
+    // size were added to this group.
+    assert(group->pendingTokens > 0 && "wrong group size");
+
     // Run all group awaiters if it was the last token in the group.
     if (group->pendingTokens.fetch_sub(1) == 1) {
       group->cv.notify_all();

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
index 74c091e7575a1..9d57ef31dea84 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
 
 // CHECK-LABEL: @create_token
 func @create_token() {
@@ -20,8 +20,11 @@ func @create_value() {
 
 // CHECK-LABEL: @create_group
 func @create_group() {
-  // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
-  %0 = async.runtime.create : !async.group
+  // CHECK: %[[C:.*]] = constant 1 : index
+  // CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64
+  %c = constant 1 : index
+  // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]])
+  %0 = async.runtime.create_group %c : !async.group
   return
 }
 
@@ -81,8 +84,9 @@ func @await_value() {
 
 // CHECK-LABEL: @await_group
 func @await_group() {
+  %c = constant 1 : index
   // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
-  %0 = async.runtime.create : !async.group
+  %0 = async.runtime.create_group %c : !async.group
   // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
   async.runtime.await %0 : !async.group
   return
 }
 
 // CHECK-LABEL: @await_and_resume_group
 func @await_and_resume_group() {
+  %c = constant 1 : index
   %0 = async.coro.id
   // CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
   %1 = async.coro.begin %0
   // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
-  %2 = async.runtime.create : !async.group
+  %2 = async.runtime.create_group %c : !async.group
   // CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
   // CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
   // CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
@@ -168,10 +173,11 @@ func @load() -> f32 {
 
 // CHECK-LABEL: @add_token_to_group
 func @add_token_to_group() {
+  %c = constant 1 : index
   // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
   %0 = async.runtime.create : !async.token
   // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
-  %1 = async.runtime.create : !async.group
+  %1 = async.runtime.create_group %c : !async.group
   // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
   async.runtime.add_to_group %0, %1 : !async.token
   return

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index ba2a3914145c2..da96f306f0bf0 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -170,12 +170,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
 
 // CHECK-LABEL: async_group_await_all
 func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
-  // CHECK: %0 = call @mlirAsyncRuntimeCreateGroup()
-  %0 = async.create_group
+  %c = constant 1 : index
+  // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
+  %0 = async.create_group %c : !async.group
 
   // CHECK: %[[TOKEN:.*]] = call @async_execute_fn
   %token = async.execute { async.yield }
-  // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0)
+  // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
   async.add_to_group %token, %0 : !async.token
 
   // CHECK: call @async_execute_fn_0
@@ -184,7 +185,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
     async.yield
   }
 
-  // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0)
+  // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
   async.await_all %0
 
   return

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index b77b0d6e89a9a..7564f13352a72 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -179,8 +179,10 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
 
 // CHECK-LABEL: @async_group_await_all
 func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
-  // CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
-  %0 = async.create_group
+  // CHECK: %[[C:.*]] = constant 1 : index
+  %c = constant 1 : index
+  // CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group
+  %0 = async.create_group %c : !async.group
 
   // CHECK: %[[TOKEN:.*]] = call @async_execute_fn
   %token = async.execute { async.yield }

diff  --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index a95be650eff78..1ec2b6d6faa15 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -122,8 +122,10 @@ func @await_value(%arg0: !async.value<f32>) -> f32 {
 }
 
 // CHECK-LABEL: @create_group_and_await_all
-func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>) -> index {
-  %0 = async.create_group
+func @create_group_and_await_all(%arg0: !async.token,
+                                 %arg1: !async.value<f32>) -> index {
+  %c = constant 2 : index
+  %0 = async.create_group %c : !async.group
 
   // CHECK: async.add_to_group %arg0
   // CHECK: async.add_to_group %arg1

diff  --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index ede523f3c084a..1b39e6420b870 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -18,9 +18,11 @@ func @create_value() -> !async.value<f32> {
 
 // CHECK-LABEL: @create_group
 func @create_group() -> !async.group {
-  // CHECK: %0 = async.runtime.create : !async.group
-  %0 = async.runtime.create : !async.group
-  // CHECK: return %0 : !async.group
+  // CHECK: %[[C:.*]] = constant 10 : index
+  %c = constant 10 : index
+  // CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group
+  %0 = async.runtime.create_group %c : !async.group
+  // CHECK: return %[[V]] : !async.group
   return %0 : !async.group
 }
 

diff  --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir
index 63b9a0077f5d9..77616deba76e6 100644
--- a/mlir/test/mlir-cpu-runner/async-error.mlir
+++ b/mlir/test/mlir-cpu-runner/async-error.mlir
@@ -85,7 +85,8 @@ func @main() {
   // Check error propagation from a token to the group.
   // ------------------------------------------------------------------------ //
 
-  %group0 = async.create_group
+  %c2 = constant 2 : index
+  %group0 = async.create_group %c2 : !async.group
 
   %token4 = async.execute {
     async.yield

diff  --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
index 8216d1558c639..7df88776262c4 100644
--- a/mlir/test/mlir-cpu-runner/async-group.mlir
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -11,7 +11,10 @@
 // RUN: | FileCheck %s
 
 func @main() {
-  %group = async.create_group
+  %c1 = constant 1 : index
+  %c5 = constant 5 : index
+
+  %group = async.create_group %c5 : !async.group
 
   %token0 = async.execute { async.yield }
   %token1 = async.execute { async.yield }
@@ -30,7 +33,7 @@ func @main() {
     async.yield
   }
 
-  %group0 = async.create_group
+  %group0 = async.create_group %c1 : !async.group
   %5 = async.add_to_group %token5, %group0 : !async.token
   async.await_all %group0
 


        


More information about the Mlir-commits mailing list