[Mlir-commits] [mlir] d8c84d2 - [mlir] Async: Add error propagation support to async groups
Eugene Zhulenev
llvmlistbot at llvm.org
Thu May 27 09:35:17 PDT 2021
Author: Eugene Zhulenev
Date: 2021-05-27T09:35:11-07:00
New Revision: d8c84d2a4efc87b756d9d3df42b80d6f8762f62a
URL: https://github.com/llvm/llvm-project/commit/d8c84d2a4efc87b756d9d3df42b80d6f8762f62a
DIFF: https://github.com/llvm/llvm-project/commit/d8c84d2a4efc87b756d9d3df42b80d6f8762f62a.diff
LOG: [mlir] Async: Add error propagation support to async groups
Depends On D103109
If any of the tokens/values added to the `!async.group` switches to the error state, then the group itself switches to the error state.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D103203
Added:
Modified:
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Dialect/Async/async-to-async-runtime.mlir
mlir/test/Dialect/Async/runtime.mlir
mlir/test/mlir-cpu-runner/async-error.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 5f6eece6cf69d..9ef218ebe560f 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -373,7 +373,7 @@ def Async_RuntimeIsErrorOp : Async_Op<"runtime.is_error"> {
resuming after `await_and_resume`.
}];
- let arguments = (ins Async_AnyValueOrTokenType:$operand);
+ let arguments = (ins Async_AnyAsyncType:$operand);
let results = (outs I1:$is_error);
let assemblyFormat = "$operand attr-dict `:` type($operand)";
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index df9adebbd6206..3b26bf61e622f 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -88,6 +88,10 @@ extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *);
// Returns true if value is in the error state.
extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *);
+// Returns true if group is in the error state (any of the tokens or values
+// added to the group are in the error state).
+extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *);
+
// Blocks the caller thread until the token becomes ready.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 55d1714171774..a66f246cf55ed 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "convert-async-to-llvm"
@@ -39,6 +40,7 @@ static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
+static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
@@ -125,6 +127,11 @@ struct AsyncAPI {
return FunctionType::get(ctx, {value}, {i1});
}
+ static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
+ auto i1 = IntegerType::get(ctx, 1);
+ return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
+ }
+
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
@@ -201,6 +208,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
+ addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
@@ -587,10 +595,13 @@ class RuntimeSetAvailableOpLowering
LogicalResult
matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
- rewriter.replaceOpWithNewOp<CallOp>(
- op, operandType.isa<TokenType>() ? kEmplaceToken : kEmplaceValue,
- TypeRange(), operands);
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kEmplaceToken; })
+ .Case<ValueType>([](Type) { return kEmplaceValue; });
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+
return success();
}
};
@@ -609,10 +620,13 @@ class RuntimeSetErrorOpLowering
LogicalResult
matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
- rewriter.replaceOpWithNewOp<CallOp>(
- op, operandType.isa<TokenType>() ? kSetTokenError : kSetValueError,
- TypeRange(), operands);
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kSetTokenError; })
+ .Case<ValueType>([](Type) { return kSetValueError; });
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, TypeRange(), operands);
+
return success();
}
};
@@ -630,10 +644,14 @@ class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
LogicalResult
matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
- rewriter.replaceOpWithNewOp<CallOp>(
- op, operandType.isa<TokenType>() ? kIsTokenError : kIsValueError,
- rewriter.getI1Type(), operands);
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kIsTokenError; })
+ .Case<GroupType>([](Type) { return kIsGroupError; })
+ .Case<ValueType>([](Type) { return kIsValueError; });
+
+ rewriter.replaceOpWithNewOp<CallOp>(op, apiFuncName, rewriter.getI1Type(),
+ operands);
return success();
}
};
@@ -651,17 +669,11 @@ class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
LogicalResult
matchAndRewrite(RuntimeAwaitOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
-
- StringRef apiFuncName;
- if (operandType.isa<TokenType>())
- apiFuncName = kAwaitToken;
- else if (operandType.isa<ValueType>())
- apiFuncName = kAwaitValue;
- else if (operandType.isa<GroupType>())
- apiFuncName = kAwaitGroup;
- else
- return rewriter.notifyMatchFailure(op, "unsupported async type");
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kAwaitToken; })
+ .Case<ValueType>([](Type) { return kAwaitValue; })
+ .Case<GroupType>([](Type) { return kAwaitGroup; });
rewriter.create<CallOp>(op->getLoc(), apiFuncName, TypeRange(), operands);
rewriter.eraseOp(op);
@@ -684,17 +696,11 @@ class RuntimeAwaitAndResumeOpLowering
LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- Type operandType = op.operand().getType();
-
- StringRef apiFuncName;
- if (operandType.isa<TokenType>())
- apiFuncName = kAwaitTokenAndExecute;
- else if (operandType.isa<ValueType>())
- apiFuncName = kAwaitValueAndExecute;
- else if (operandType.isa<GroupType>())
- apiFuncName = kAwaitAllAndExecute;
- else
- return rewriter.notifyMatchFailure(op, "unsupported async type");
+ StringRef apiFuncName =
+ TypeSwitch<Type, StringRef>(op.operand().getType())
+ .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
+ .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
+ .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index baa7d5d8c46ae..6ebf48ad64756 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -399,25 +399,22 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
- // TODO: Async groups do not yet support runtime errors.
- if (!std::is_same<AwaitAllOp, AwaitType>::value) {
- // Split the resume block into error checking and continuation.
- Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
-
- // Check if the awaited value is in the error state.
- builder.setInsertionPointToStart(resume);
- auto isError = builder.create<RuntimeIsErrorOp>(
- loc, rewriter.getI1Type(), operand);
- builder.create<CondBranchOp>(isError,
- /*trueDest=*/setupSetErrorBlock(coro),
- /*trueArgs=*/ArrayRef<Value>(),
- /*falseDest=*/continuation,
- /*falseArgs=*/ArrayRef<Value>());
-
- // Make sure that replacement value will be constructed in the
- // continuation block.
- rewriter.setInsertionPointToStart(continuation);
- }
+ // Split the resume block into error checking and continuation.
+ Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
+
+ // Check if the awaited value is in the error state.
+ builder.setInsertionPointToStart(resume);
+ auto isError =
+ builder.create<RuntimeIsErrorOp>(loc, rewriter.getI1Type(), operand);
+ builder.create<CondBranchOp>(isError,
+ /*trueDest=*/setupSetErrorBlock(coro),
+ /*trueArgs=*/ArrayRef<Value>(),
+ /*falseDest=*/continuation,
+ /*falseArgs=*/ArrayRef<Value>());
+
+ // Make sure that replacement value will be constructed in the
+ // continuation block.
+ rewriter.setInsertionPointToStart(continuation);
}
// Erase or replace the await operation with the new value.
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 856d1c7b74f9c..6bbb8e4052b10 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -212,9 +212,10 @@ struct AsyncValue : public RefCounted {
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
AsyncGroup(AsyncRuntime *runtime)
- : RefCounted(runtime), pendingTokens(0), rank(0) {}
+ : RefCounted(runtime), pendingTokens(0), numErrors(0), rank(0) {}
std::atomic<int> pendingTokens;
+ std::atomic<int> numErrors;
std::atomic<int> rank;
// Pending awaiters are guarded by a mutex.
@@ -262,7 +263,11 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
int rank = group->rank.fetch_add(1);
group->pendingTokens.fetch_add(1);
- auto onTokenReady = [group]() {
+ auto onTokenReady = [group, token]() {
+ // Increment the number of errors in the group.
+ if (State(token->state).isError())
+ group->numErrors.fetch_add(1);
+
// Run all group awaiters if it was the last token in the group.
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
@@ -356,6 +361,10 @@ extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
return State(value->state).isError();
}
+extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
+ return group->numErrors.load() > 0;
+}
+
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
std::unique_lock<std::mutex> lock(token->mu);
if (!State(token->state).isAvailableOrError())
@@ -483,6 +492,8 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
&mlir::runtime::mlirAsyncRuntimeIsTokenError);
exportSymbol("mlirAsyncRuntimeIsValueError",
&mlir::runtime::mlirAsyncRuntimeIsValueError);
+ exportSymbol("mlirAsyncRuntimeIsGroupError",
+ &mlir::runtime::mlirAsyncRuntimeIsGroupError);
exportSymbol("mlirAsyncRuntimeAwaitToken",
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
exportSymbol("mlirAsyncRuntimeAwaitValue",
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 08607f89e2e57..b77b0d6e89a9a 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -216,8 +216,13 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
-// Emplace result token.
+// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
+// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
+// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Emplace result token after error checking.
+// CHECK: ^[[CONTINUATION:.*]]:
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index c8f6f65cb8469..ede523f3c084a 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -66,6 +66,13 @@ func @is_value_error(%arg0: !async.value<f32>) -> i1 {
return %0 : i1
}
+// CHECK-LABEL: @is_group_error
+func @is_group_error(%arg0: !async.group) -> i1 {
+ // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.group
+ %0 = async.runtime.is_error %arg0 : !async.group
+ return %0 : i1
+}
+
// CHECK-LABEL: @await_token
func @await_token(%arg0: !async.token) {
// CHECK: async.runtime.await %arg0 : !async.token
diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir
index 139296d6614f4..63b9a0077f5d9 100644
--- a/mlir/test/mlir-cpu-runner/async-error.mlir
+++ b/mlir/test/mlir-cpu-runner/async-error.mlir
@@ -81,5 +81,29 @@ func @main() {
vector.print %err3_0 : i1
vector.print %err3_1 : i1
+ // ------------------------------------------------------------------------ //
+ // Check error propagation from a token to the group.
+ // ------------------------------------------------------------------------ //
+
+ %group0 = async.create_group
+
+ %token4 = async.execute {
+ async.yield
+ }
+
+ %token5 = async.execute {
+ assert %false, "error"
+ async.yield
+ }
+
+ %idx0 = async.add_to_group %token4, %group0 : !async.token
+ %idx1 = async.add_to_group %token5, %group0 : !async.token
+
+ async.runtime.await %group0 : !async.group
+
+ // CHECK: 1
+ %err4 = async.runtime.is_error %group0 : !async.group
+ vector.print %err4 : i1
+
return
}
More information about the Mlir-commits
mailing list