[Mlir-commits] [mlir] d43b236 - [mlir:Async] Add the size parameter to the async.group
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Jun 25 10:26:57 PDT 2021
Author: Eugene Zhulenev
Date: 2021-06-25T10:26:50-07:00
New Revision: d43b23608ad664f02f56e965ca78916bde220950
URL: https://github.com/llvm/llvm-project/commit/d43b23608ad664f02f56e965ca78916bde220950
DIFF: https://github.com/llvm/llvm-project/commit/d43b23608ad664f02f56e965ca78916bde220950.diff
LOG: [mlir:Async] Add the size parameter to the async.group
Specify the `!async.group` size (the number of tokens that will be added to it) at construction time. The `async.await_all` operation can potentially race with `async.execute` operations that keep updating the group; for this reason, the number of tokens that will be added to the group must be known up front.
Reviewed By: ftynse, herhut
Differential Revision: https://reviews.llvm.org/D104780
Added:
Modified:
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/Async/async-to-async-runtime.mlir
mlir/test/Dialect/Async/ops.mlir
mlir/test/Dialect/Async/runtime.mlir
mlir/test/mlir-cpu-runner/async-error.mlir
mlir/test/mlir-cpu-runner/async-group.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 9ef218ebe560f..f9ddd67a7961d 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -160,20 +160,24 @@ def Async_CreateGroupOp : Async_Op<"create_group", [NoSideEffect]> {
let summary = "creates an empty async group";
let description = [{
The `async.create_group` allocates an empty async group. Async tokens or
- values can be added to this group later.
+ values can be added to this group later. The size of the group must be
+ specified at construction time, and `await_all` operation will first
+ wait until the number of added tokens or values reaches the group size.
Example:
```mlir
- %0 = async.create_group
+ %size = ... : index
+ %group = async.create_group %size : !async.group
...
- async.await_all %0
+ async.await_all %group
```
}];
+ let arguments = (ins Index:$size);
let results = (outs Async_GroupType:$result);
- let assemblyFormat = "attr-dict";
+ let assemblyFormat = "$size `:` type($result) attr-dict";
}
def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
@@ -186,7 +190,7 @@ def Async_AddToGroupOp : Async_Op<"add_to_group", []> {
Example:
```mlir
- %0 = async.create_group
+ %0 = async.create_group %size : !async.group
%1 = ... : !async.token
%2 = async.add_to_group %1, %0 : !async.token
```
@@ -209,7 +213,7 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
Example:
```mlir
- %0 = async.create_group
+ %0 = async.create_group %size : !async.group
%1 = ... : !async.token
%2 = async.add_to_group %1, %0 : !async.token
@@ -331,17 +335,28 @@ def Async_CoroSuspendOp : Async_Op<"coro.suspend", [Terminator]> {
// Runtime API defined in the `ExecutionEngine/AsyncRuntime.h`.
def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
- let summary = "creates an async runtime value (token, value or group)";
+ let summary = "creates an async runtime token or value";
let description = [{
- The `async.runtime.create` operation creates an async dialect value
- (token, value or group). Tokens and values are created in non-ready state.
- Groups are created in empty state.
+ The `async.runtime.create` operation creates an async dialect token or
+ value. Tokens and values are created in the non-ready state.
}];
- let results = (outs Async_AnyAsyncType:$result);
+ let results = (outs Async_AnyValueOrTokenType:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}
+def Async_RuntimeCreateGroupOp : Async_Op<"runtime.create_group"> {
+ let summary = "creates an async runtime group";
+ let description = [{
+ The `async.runtime.create_group` operation creates an async dialect group
+ of the given size. Group created in the empty state.
+ }];
+
+ let arguments = (ins Index:$size);
+ let results = (outs Async_GroupType:$result);
+ let assemblyFormat = "$size `:` type($result) attr-dict ";
+}
+
def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
let summary = "switches token or value to available state";
let description = [{
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index 3b26bf61e622f..a101b28bb4282 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -66,7 +66,7 @@ extern "C" AsyncToken *mlirAsyncRuntimeCreateToken();
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t);
// Create a new `async.group` in empty state.
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup();
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size);
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index ff2460bb129c4..0156ede1b9b65 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -89,7 +89,8 @@ struct AsyncAPI {
}
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
- return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
+ auto i64 = IntegerType::get(ctx, 64);
+ return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
}
static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
@@ -543,11 +544,10 @@ class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
TypeConverter *converter = getTypeConverter();
Type resultType = op->getResultTypes()[0];
- // Tokens and Groups lowered to function calls without arguments.
- if (resultType.isa<TokenType>() || resultType.isa<GroupType>()) {
- rewriter.replaceOpWithNewOp<CallOp>(
- op, resultType.isa<TokenType>() ? kCreateToken : kCreateGroup,
- converter->convertType(resultType));
+ // Tokens creation maps to a simple function call.
+ if (resultType.isa<TokenType>()) {
+ rewriter.replaceOpWithNewOp<CallOp>(op, kCreateToken,
+ converter->convertType(resultType));
return success();
}
@@ -582,6 +582,29 @@ class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.create_group to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeCreateGroupOpLowering
+ : public OpConversionPattern<RuntimeCreateGroupOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ TypeConverter *converter = getTypeConverter();
+ Type resultType = op->getResultTypes()[0];
+
+ rewriter.replaceOpWithNewOp<CallOp>(
+ op, kCreateGroup, converter->convertType(resultType), operands);
+ return success();
+ }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// Convert async.runtime.set_available to the corresponding runtime API call.
//===----------------------------------------------------------------------===//
@@ -967,8 +990,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
// Lower async.runtime operations that rely on LLVM type converter to convert
// from async value payload type to the LLVM type.
- patterns.add<RuntimeCreateOpLowering, RuntimeStoreOpLowering,
- RuntimeLoadOpLowering>(llvmConverter, ctx);
+ patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
+ RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
+ ctx);
// Lower async coroutine operations to LLVM coroutine intrinsics.
patterns
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index ce2bc7081faf1..ba09123199849 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -165,8 +165,14 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
numBlocks[i] = divup(tripCounts[i], blockSize[i]);
}
+ // Total number of async compute blocks.
+ Value totalBlocks = numBlocks[0];
+ for (size_t i = 1; i < op.getNumLoops(); ++i)
+ totalBlocks = rewriter.create<MulIOp>(loc, totalBlocks, numBlocks[i]);
+
// Create an async.group to wait on all async tokens from async execute ops.
- auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
+ auto group =
+ rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx), totalBlocks);
// Build a scf.for loop nest from the parallel operation.
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 0789a0ee68875..ea8e353925994 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -302,7 +302,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
}
//===----------------------------------------------------------------------===//
-// Convert async.create_group operation to async.runtime.create
+// Convert async.create_group operation to async.runtime.create_group
//===----------------------------------------------------------------------===//
namespace {
@@ -313,8 +313,8 @@ class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
LogicalResult
matchAndRewrite(CreateGroupOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<RuntimeCreateOp>(
- op, GroupType::get(op->getContext()));
+ rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
+ op, GroupType::get(op->getContext()), operands);
return success();
}
};
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 6bbb8e4052b10..a8aeaec60d3a7 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -211,8 +211,8 @@ struct AsyncValue : public RefCounted {
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
- AsyncGroup(AsyncRuntime *runtime)
- : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
+ AsyncGroup(AsyncRuntime *runtime, int64_t size)
+ : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
std::atomic<int> pendingTokens;
std::atomic<int> numErrors;
@@ -249,8 +249,8 @@ extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
}
// Create a new `async.group` in empty state.
-extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
- AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
+extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
+ AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
return group;
}
@@ -261,13 +261,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
// Get the rank of the token inside the group before we drop the reference.
int rank = group->rank.fetch_add(1);
- group->pendingTokens.fetch_add(1);
auto onTokenReady = [group, token]() {
// Increment the number of errors in the group.
if (State(token->state).isError())
group->numErrors.fetch_add(1);
+ // If pending tokens go below zero it means that more tokens than the group
+ // size were added to this group.
+ assert(group->pendingTokens > 0 && "wrong group size");
+
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
index 74c091e7575a1..9d57ef31dea84 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-async-to-llvm | FileCheck %s --dump-input=always
// CHECK-LABEL: @create_token
func @create_token() {
@@ -20,8 +20,11 @@ func @create_value() {
// CHECK-LABEL: @create_group
func @create_group() {
- // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
- %0 = async.runtime.create : !async.group
+ // CHECK: %[[C:.*]] = constant 1 : index
+ // CHECK: %[[S:.*]] = llvm.mlir.cast %[[C]] : index to i64
+ %c = constant 1 : index
+ // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup(%[[S]])
+ %0 = async.runtime.create_group %c: !async.group
return
}
@@ -81,8 +84,9 @@ func @await_value() {
// CHECK-LABEL: @await_group
func @await_group() {
+ %c = constant 1 : index
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
- %0 = async.runtime.create : !async.group
+ %0 = async.runtime.create_group %c: !async.group
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
async.runtime.await %0 : !async.group
return
@@ -118,11 +122,12 @@ func @await_and_resume_value() {
// CHECK-LABEL: @await_and_resume_group
func @await_and_resume_group() {
+ %c = constant 1 : index
%0 = async.coro.id
// CHECK: %[[HDL:.*]] = llvm.intr.coro.begin
%1 = async.coro.begin %0
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateGroup
- %2 = async.runtime.create : !async.group
+ %2 = async.runtime.create_group %c : !async.group
// CHECK: %[[RESUME:.*]] = llvm.mlir.addressof @__resume
// CHECK: call @mlirAsyncRuntimeAwaitAllInGroupAndExecute
// CHECK-SAME: (%[[TOKEN]], %[[HDL]], %[[RESUME]])
@@ -168,10 +173,11 @@ func @load() -> f32 {
// CHECK-LABEL: @add_token_to_group
func @add_token_to_group() {
+ %c = constant 1 : index
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
%0 = async.runtime.create : !async.token
// CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
- %1 = async.runtime.create : !async.group
+ %1 = async.runtime.create_group %c : !async.group
// CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
async.runtime.add_to_group %0, %1 : !async.token
return
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index ba2a3914145c2..da96f306f0bf0 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -170,12 +170,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK-LABEL: async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
- // CHECK: %0 = call @mlirAsyncRuntimeCreateGroup()
- %0 = async.create_group
+ %c = constant 1 : index
+ // CHECK: %[[GROUP:.*]] = call @mlirAsyncRuntimeCreateGroup
+ %0 = async.create_group %c : !async.group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }
- // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %0)
+ // CHECK: call @mlirAsyncRuntimeAddTokenToGroup(%[[TOKEN]], %[[GROUP]])
async.add_to_group %token, %0 : !async.token
// CHECK: call @async_execute_fn_0
@@ -184,7 +185,7 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
async.yield
}
- // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%0)
+ // CHECK: call @mlirAsyncRuntimeAwaitAllInGroup(%[[GROUP]])
async.await_all %0
return
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index b77b0d6e89a9a..7564f13352a72 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -179,8 +179,10 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK-LABEL: @async_group_await_all
func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
- // CHECK: %[[GROUP:.*]] = async.runtime.create : !async.group
- %0 = async.create_group
+ // CHECK: %[[C:.*]] = constant 1 : index
+ %c = constant 1 : index
+ // CHECK: %[[GROUP:.*]] = async.runtime.create_group %[[C]] : !async.group
+ %0 = async.create_group %c : !async.group
// CHECK: %[[TOKEN:.*]] = call @async_execute_fn
%token = async.execute { async.yield }
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index a95be650eff78..1ec2b6d6faa15 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -122,8 +122,10 @@ func @await_value(%arg0: !async.value<f32>) -> f32 {
}
// CHECK-LABEL: @create_group_and_await_all
-func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value<f32>) -> index {
- %0 = async.create_group
+func @create_group_and_await_all(%arg0: !async.token,
+ %arg1: !async.value<f32>) -> index {
+ %c = constant 2 : index
+ %0 = async.create_group %c : !async.group
// CHECK: async.add_to_group %arg0
// CHECK: async.add_to_group %arg1
diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index ede523f3c084a..1b39e6420b870 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -18,9 +18,11 @@ func @create_value() -> !async.value<f32> {
// CHECK-LABEL: @create_group
func @create_group() -> !async.group {
- // CHECK: %0 = async.runtime.create : !async.group
- %0 = async.runtime.create : !async.group
- // CHECK: return %0 : !async.group
+ // CHECK: %[[C:.*]] = constant 10 : index
+ %c = constant 10 : index
+ // CHECK: %[[V:.*]] = async.runtime.create_group %[[C]] : !async.group
+ %0 = async.runtime.create_group %c : !async.group
+ // CHECK: return %[[V]] : !async.group
return %0 : !async.group
}
diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir
index 63b9a0077f5d9..77616deba76e6 100644
--- a/mlir/test/mlir-cpu-runner/async-error.mlir
+++ b/mlir/test/mlir-cpu-runner/async-error.mlir
@@ -85,7 +85,8 @@ func @main() {
// Check error propagation from a token to the group.
// ------------------------------------------------------------------------ //
- %group0 = async.create_group
+ %c2 = constant 2 : index
+ %group0 = async.create_group %c2 : !async.group
%token4 = async.execute {
async.yield
diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
index 8216d1558c639..7df88776262c4 100644
--- a/mlir/test/mlir-cpu-runner/async-group.mlir
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -11,7 +11,10 @@
// RUN: | FileCheck %s
func @main() {
- %group = async.create_group
+ %c1 = constant 1 : index
+ %c5 = constant 5 : index
+
+ %group = async.create_group %c5 : !async.group
%token0 = async.execute { async.yield }
%token1 = async.execute { async.yield }
@@ -30,7 +33,7 @@ func @main() {
async.yield
}
- %group0 = async.create_group
+ %group0 = async.create_group %c1 : !async.group
%5 = async.add_to_group %token5, %group0 : !async.token
async.await_all %group0
More information about the Mlir-commits
mailing list