[Mlir-commits] [mlir] 39957aa - [mlir] Add error state and error propagation to async runtime values
Eugene Zhulenev
llvmlistbot at llvm.org
Thu May 27 09:28:54 PDT 2021
Author: Eugene Zhulenev
Date: 2021-05-27T09:28:47-07:00
New Revision: 39957aa4243cb9aec3a7114c0ecf710ecce96b72
URL: https://github.com/llvm/llvm-project/commit/39957aa4243cb9aec3a7114c0ecf710ecce96b72
DIFF: https://github.com/llvm/llvm-project/commit/39957aa4243cb9aec3a7114c0ecf710ecce96b72.diff
LOG: [mlir] Add error state and error propagation to async runtime values
Depends On D103102
Not yet implemented:
1. Error handling after synchronous await
2. Error handling for async groups
These will be addressed in follow-up PRs.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D103109
Added:
mlir/test/mlir-cpu-runner/async-error.mlir
Modified:
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
mlir/test/Dialect/Async/async-to-async-runtime.mlir
mlir/test/Dialect/Async/runtime.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index fbfc529c0b824..5f6eece6cf69d 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -343,7 +343,7 @@ def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
}
def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
- let summary = "switches token or value available state";
+ let summary = "switches token or value to available state";
let description = [{
The `async.runtime.set_available` operation switches async token or value
state to available.
@@ -353,11 +353,37 @@ def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
let assemblyFormat = "$operand attr-dict `:` type($operand)";
}
+def Async_RuntimeSetErrorOp : Async_Op<"runtime.set_error"> {
+ let summary = "switches token or value to error state";
+ let description = [{
+ The `async.runtime.set_error` operation switches async token or value
+ state to error.
+ }];
+
+ let arguments = (ins Async_AnyValueOrTokenType:$operand);
+ let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def Async_RuntimeIsErrorOp : Async_Op<"runtime.is_error"> {
+ let summary = "returns true if token or value is in error state";
+ let description = [{
+ The `async.runtime.is_error` operation returns true if the token or value
+ (either of the async runtime values) is in the error state. It is the
+ caller's responsibility to check the error state after the call to `await`
+ or after resuming from `await_and_resume`.
+ }];
+
+ let arguments = (ins Async_AnyValueOrTokenType:$operand);
+ let results = (outs I1:$is_error);
+
+ let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
def Async_RuntimeAwaitOp : Async_Op<"runtime.await"> {
let summary = "blocks the caller thread until the operand becomes available";
let description = [{
The `async.runtime.await` operation blocks the caller thread until the
- operand becomes available.
+ operand becomes available or enters the error state.
}];
let arguments = (ins Async_AnyAsyncType:$operand);
@@ -379,8 +405,8 @@ def Async_RuntimeAwaitAndResumeOp : Async_Op<"runtime.await_and_resume"> {
let summary = "awaits the async operand and resumes the coroutine";
let description = [{
The `async.runtime.await_and_resume` operation awaits for the operand to
- become available and resumes the coroutine on a thread managed by the
- runtime.
+ become available or enter the error state, and resumes the coroutine on a
+ thread managed by the runtime.
}];
let arguments = (ins Async_AnyAsyncType:$operand,
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index d5cede323ed92..df9adebbd6206 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -76,6 +76,18 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *);
// Switches `async.value` to ready state and runs all awaiters.
extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *);
+// Switches `async.token` to error state and runs all awaiters.
+extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *);
+
+// Switches `async.value` to error state and runs all awaiters.
+extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *);
+
+// Returns true if token is in the error state.
+extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *);
+
+// Returns true if value is in the error state.
+extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *);
+
// Blocks the caller thread until the token becomes ready.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 7a24b75640ec1..55d1714171774 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -35,6 +35,10 @@ static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
+static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
+static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
+static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
+static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
@@ -101,6 +105,26 @@ struct AsyncAPI {
return FunctionType::get(ctx, {value}, {});
}
+ static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
+ return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+ }
+
+ static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
+ auto value = opaquePointerType(ctx);
+ return FunctionType::get(ctx, {value}, {});
+ }
+
+ static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
+ auto i1 = IntegerType::get(ctx, 1);
+ return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
+ }
+
+ static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
+ auto value = opaquePointerType(ctx);
+ auto i1 = IntegerType::get(ctx, 1);
+ return FunctionType::get(ctx, {value}, {i1});
+ }
+
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
@@ -173,6 +197,10 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
+ addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
+ addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
+ addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
+ addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
@@ -560,17 +588,53 @@ class RuntimeSetAvailableOpLowering
matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type operandType = op.operand().getType();
+ rewriter.replaceOpWithNewOp<CallOp>(
+ op, operandType.isa<TokenType>() ? kEmplaceToken : kEmplaceValue,
+ TypeRange(), operands);
+ return success();
+ }
+};
+} // namespace
- if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
- rewriter.create<CallOp>(op->getLoc(),
- operandType.isa<TokenType>() ? kEmplaceToken
- : kEmplaceValue,
- TypeRange(), operands);
- rewriter.eraseOp(op);
- return success();
- }
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.set_error to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
- return rewriter.notifyMatchFailure(op, "unsupported async type");
+namespace {
+class RuntimeSetErrorOpLowering
+ : public OpConversionPattern<RuntimeSetErrorOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type operandType = op.operand().getType();
+ rewriter.replaceOpWithNewOp<CallOp>(
+ op, operandType.isa<TokenType>() ? kSetTokenError : kSetValueError,
+ TypeRange(), operands);
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.is_error to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type operandType = op.operand().getType();
+ rewriter.replaceOpWithNewOp<CallOp>(
+ op, operandType.isa<TokenType>() ? kIsTokenError : kIsValueError,
+ rewriter.getI1Type(), operands);
+ return success();
}
};
} // namespace
@@ -889,7 +953,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
patterns.add<ReturnOpOpConversion>(converter, ctx);
// Lower async.runtime operations to the async runtime API calls.
- patterns.add<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
+ patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
+ RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
RuntimeDropRefOpLowering>(converter, ctx);
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index a9bcfe46fba54..baa7d5d8c46ae 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -52,6 +52,8 @@ class AsyncToAsyncRuntimePass
/// operation to enable non-blocking waiting via coroutine suspension.
namespace {
struct CoroMachinery {
+ FuncOp func;
+
// Async execute region returns a completion token, and an async value for
// each yielded value.
//
@@ -63,6 +65,7 @@ struct CoroMachinery {
llvm::SmallVector<Value, 4> returnValues; // returned async values
Value coroHandle; // coroutine handle (!async.coro.handle value)
+ Block *setError; // switch completion token and all values to error state
Block *cleanup; // coroutine cleanup block
Block *suspend; // coroutine suspension block
};
@@ -74,6 +77,7 @@ struct CoroMachinery {
/// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
///
/// - `entry` block sets up the coroutine.
+/// - `set_error` block sets completion token and async values state to error.
/// - `cleanup` block cleans up the coroutine state.
/// - `suspend block after the @llvm.coro.end() defines what value will be
/// returned to the initial caller of a coroutine. Everything before the
@@ -91,6 +95,11 @@ struct CoroMachinery {
/// %hdl = async.coro.begin %id // create a coroutine handle
/// br ^cleanup
///
+/// ^set_error: // this block is created lazily only if needed (see code below)
+/// async.runtime.set_error %token : !async.token
+/// async.runtime.set_error %value : !async.value<T>
+/// br ^cleanup
+///
/// ^cleanup:
/// async.coro.free %hdl // delete the coroutine state
/// br ^suspend
@@ -163,14 +172,39 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// continuations, and will conditionally branch to cleanup or suspend blocks.
CoroMachinery machinery;
+ machinery.func = func;
machinery.asyncToken = retToken;
machinery.returnValues = retValues;
machinery.coroHandle = coroHdlOp.handle();
+ machinery.setError = nullptr; // created lazily only if needed
machinery.cleanup = cleanupBlock;
machinery.suspend = suspendBlock;
return machinery;
}
+// Lazily creates the `set_error` block only if it is required for lowering to
+// the runtime operations (see, for example, the lowering of the assert operation).
+static Block *setupSetErrorBlock(CoroMachinery &coro) {
+ if (coro.setError)
+ return coro.setError;
+
+ coro.setError = coro.func.addBlock();
+ coro.setError->moveBefore(coro.cleanup);
+
+ auto builder =
+ ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
+
+ // Coroutine set_error block: set error on token and all returned values.
+ builder.create<RuntimeSetErrorOp>(coro.asyncToken);
+ for (Value retValue : coro.returnValues)
+ builder.create<RuntimeSetErrorOp>(retValue);
+
+ // Branch into the cleanup block.
+ builder.create<BranchOp>(coro.cleanup);
+
+ return coro.setError;
+}
+
/// Outline the body region attached to the `async.execute` op into a standalone
/// function.
///
@@ -316,9 +350,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
using AwaitAdaptor = typename AwaitType::Adaptor;
public:
- AwaitOpLoweringBase(
- MLIRContext *ctx,
- const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+ AwaitOpLoweringBase(MLIRContext *ctx,
+ llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
: OpConversionPattern<AwaitType>(ctx),
outlinedFunctions(outlinedFunctions) {}
@@ -346,7 +379,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside the coroutine we convert await operation into coroutine suspension
// point, and resume execution asynchronously.
if (isInCoroutine) {
- const CoroMachinery &coro = outlined->getSecond();
+ CoroMachinery &coro = outlined->getSecond();
Block *suspended = op->getBlock();
ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
@@ -366,8 +399,25 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
coro.cleanup);
- // Make sure that replacement value will be constructed in resume block.
- rewriter.setInsertionPointToStart(resume);
+ // TODO: Async groups do not yet support runtime errors.
+ if (!std::is_same<AwaitAllOp, AwaitType>::value) {
+ // Split the resume block into error checking and continuation.
+ Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
+
+ // Check if the awaited value is in the error state.
+ builder.setInsertionPointToStart(resume);
+ auto isError = builder.create<RuntimeIsErrorOp>(
+ loc, rewriter.getI1Type(), operand);
+ builder.create<CondBranchOp>(isError,
+ /*trueDest=*/setupSetErrorBlock(coro),
+ /*trueArgs=*/ArrayRef<Value>(),
+ /*falseDest=*/continuation,
+ /*falseArgs=*/ArrayRef<Value>());
+
+ // Make sure that replacement value will be constructed in the
+ // continuation block.
+ rewriter.setInsertionPointToStart(continuation);
+ }
}
// Erase or replace the await operation with the new value.
@@ -385,7 +435,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
}
private:
- const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+ llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
/// Lowering for `async.await` with a token operand.
@@ -437,12 +487,12 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
LogicalResult
matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- // Check if yield operation is inside the outlined coroutine function.
+ // Check if yield operation is inside the async coroutine function.
auto func = op->template getParentOfType<FuncOp>();
auto outlined = outlinedFunctions.find(func);
if (outlined == outlinedFunctions.end())
return rewriter.notifyMatchFailure(
- op, "operation is not inside the outlined async.execute function");
+ op, "operation is not inside the async coroutine function");
Location loc = op->getLoc();
const CoroMachinery &coro = outlined->getSecond();
@@ -466,6 +516,46 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
};
+//===----------------------------------------------------------------------===//
+// Convert a std.assert operation into a cond_br to the `set_error` block.
+//===----------------------------------------------------------------------===//
+
+class AssertOpLowering : public OpConversionPattern<AssertOp> {
+public:
+ AssertOpLowering(MLIRContext *ctx,
+ llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+ : OpConversionPattern<AssertOp>(ctx),
+ outlinedFunctions(outlinedFunctions) {}
+
+ LogicalResult
+ matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check if assert operation is inside the async coroutine function.
+ auto func = op->template getParentOfType<FuncOp>();
+ auto outlined = outlinedFunctions.find(func);
+ if (outlined == outlinedFunctions.end())
+ return rewriter.notifyMatchFailure(
+ op, "operation is not inside the async coroutine function");
+
+ Location loc = op->getLoc();
+ CoroMachinery &coro = outlined->getSecond();
+
+ Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
+ rewriter.setInsertionPointToEnd(cont->getPrevNode());
+ rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(),
+ /*trueDest=*/cont,
+ /*trueArgs=*/ArrayRef<Value>(),
+ /*falseDest=*/setupSetErrorBlock(coro),
+ /*falseArgs=*/ArrayRef<Value>());
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+
+private:
+ llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+};
+
//===----------------------------------------------------------------------===//
void AsyncToAsyncRuntimePass::runOnOperation() {
@@ -495,12 +585,19 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
AwaitAllOpLowering, YieldOpLowering>(ctx,
outlinedFunctions);
+ // Lower assertions to conditional branches into error blocks.
+ asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
+
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
runtimeTarget.addLegalDialect<AsyncDialect>();
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
+ // Assertions must be converted to runtime errors.
+ runtimeTarget.addIllegalOp<AssertOp>();
+ runtimeTarget.addLegalOp<CondBranchOp>();
+
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
signalPassFailure();
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 35e114285b941..856d1c7b74f9c 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -77,6 +77,46 @@ class AsyncRuntime {
llvm::ThreadPool threadPool;
};
+// -------------------------------------------------------------------------- //
+// A state of the async runtime value (token, value or group).
+// -------------------------------------------------------------------------- //
+
+class State {
+public:
+ enum StateEnum : int8_t {
+ // The underlying value is not yet available for consumption.
+ kUnavailable = 0,
+ // The underlying value is available for consumption. This state cannot
+ // transition to any other state.
+ kAvailable = 1,
+ // The underlying value is available and contains an error. This state
+ // cannot transition to any other state.
+ kError = 2,
+ };
+
+ /* implicit */ State(StateEnum s) : state(s) {}
+ /* implicit */ operator StateEnum() { return state; }
+
+ bool isUnavailable() const { return state == kUnavailable; }
+ bool isAvailable() const { return state == kAvailable; }
+ bool isError() const { return state == kError; }
+ bool isAvailableOrError() const { return isAvailable() || isError(); }
+
+ const char *debug() const {
+ switch (state) {
+ case kUnavailable:
+ return "unavailable";
+ case kAvailable:
+ return "available";
+ case kError:
+ return "error";
+ }
+ }
+
+private:
+ StateEnum state;
+};
+
// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //
@@ -137,9 +177,9 @@ struct AsyncToken : public RefCounted {
// reference we must ensure that the token will be alive until the
// asynchronous operation is completed.
AsyncToken(AsyncRuntime *runtime)
- : RefCounted(runtime, /*count=*/2), ready(false) {}
+ : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
- std::atomic<bool> ready;
+ std::atomic<State::StateEnum> state;
// Pending awaiters are guarded by a mutex.
std::mutex mu;
@@ -153,9 +193,10 @@ struct AsyncToken : public RefCounted {
struct AsyncValue : public RefCounted {
// AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValue(AsyncRuntime *runtime, int32_t size)
- : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}
+ : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
+ storage(size) {}
- std::atomic<bool> ready;
+ std::atomic<State::StateEnum> state;
// Use vector of bytes to store async value payload.
std::vector<int8_t> storage;
@@ -182,7 +223,6 @@ struct AsyncGroup : public RefCounted {
std::vector<std::function<void()>> awaiters;
};
-
// Adds references to reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
RefCounted *refCounted = static_cast<RefCounted *>(ptr);
@@ -231,7 +271,7 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
}
};
- if (token->ready) {
+ if (State(token->state).isAvailableOrError()) {
// Update group pending tokens immediately and maybe run awaiters.
onTokenReady();
@@ -254,12 +294,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
return rank;
}
-// Switches `async.token` to ready state and runs all awaiters.
-extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
+// Switches `async.token` to available or error state (terminal state) and runs
+// all awaiters.
+static void setTokenState(AsyncToken *token, State state) {
+ assert(state.isAvailableOrError() && "must be terminal state");
+ assert(State(token->state).isUnavailable() && "token must be unavailable");
+
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
{
std::unique_lock<std::mutex> lock(token->mu);
- token->ready = true;
+ token->state = state;
token->cv.notify_all();
for (auto &awaiter : token->awaiters)
awaiter();
@@ -270,12 +314,14 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
token->dropRef();
}
-// Switches `async.value` to ready state and runs all awaiters.
-extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
+static void setValueState(AsyncValue *value, State state) {
+ assert(state.isAvailableOrError() && "must be terminal state");
+ assert(State(value->state).isUnavailable() && "value must be unavailable");
+
// Make sure that `dropRef` does not destroy the mutex owned by the lock.
{
std::unique_lock<std::mutex> lock(value->mu);
- value->ready = true;
+ value->state = state;
value->cv.notify_all();
for (auto &awaiter : value->awaiters)
awaiter();
@@ -286,16 +332,42 @@ extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
value->dropRef();
}
+extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
+ setTokenState(token, State::kAvailable);
+}
+
+extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
+ setValueState(value, State::kAvailable);
+}
+
+extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
+ setTokenState(token, State::kError);
+}
+
+extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
+ setValueState(value, State::kError);
+}
+
+extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
+ return State(token->state).isError();
+}
+
+extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
+ return State(value->state).isError();
+}
+
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
std::unique_lock<std::mutex> lock(token->mu);
- if (!token->ready)
- token->cv.wait(lock, [token] { return token->ready.load(); });
+ if (!State(token->state).isAvailableOrError())
+ token->cv.wait(
+ lock, [token] { return State(token->state).isAvailableOrError(); });
}
extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
std::unique_lock<std::mutex> lock(value->mu);
- if (!value->ready)
- value->cv.wait(lock, [value] { return value->ready.load(); });
+ if (!State(value->state).isAvailableOrError())
+ value->cv.wait(
+ lock, [value] { return State(value->state).isAvailableOrError(); });
}
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
@@ -306,6 +378,7 @@ extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
+ assert(!State(value->state).isError() && "unexpected error state");
return value->storage.data();
}
@@ -319,7 +392,7 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
CoroResume resume) {
auto execute = [handle, resume]() { (*resume)(handle); };
std::unique_lock<std::mutex> lock(token->mu);
- if (token->ready) {
+ if (State(token->state).isAvailableOrError()) {
lock.unlock();
execute();
} else {
@@ -332,7 +405,7 @@ extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
CoroResume resume) {
auto execute = [handle, resume]() { (*resume)(handle); };
std::unique_lock<std::mutex> lock(value->mu);
- if (value->ready) {
+ if (State(value->state).isAvailableOrError()) {
lock.unlock();
execute();
} else {
@@ -402,6 +475,14 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
&mlir::runtime::mlirAsyncRuntimeEmplaceToken);
exportSymbol("mlirAsyncRuntimeEmplaceValue",
&mlir::runtime::mlirAsyncRuntimeEmplaceValue);
+ exportSymbol("mlirAsyncRuntimeSetTokenError",
+ &mlir::runtime::mlirAsyncRuntimeSetTokenError);
+ exportSymbol("mlirAsyncRuntimeSetValueError",
+ &mlir::runtime::mlirAsyncRuntimeSetValueError);
+ exportSymbol("mlirAsyncRuntimeIsTokenError",
+ &mlir::runtime::mlirAsyncRuntimeIsTokenError);
+ exportSymbol("mlirAsyncRuntimeIsValueError",
+ &mlir::runtime::mlirAsyncRuntimeIsValueError);
exportSymbol("mlirAsyncRuntimeAwaitToken",
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
exportSymbol("mlirAsyncRuntimeAwaitValue",
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
index 32c02128c103d..74c091e7575a1 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
@@ -43,6 +43,24 @@ func @set_value_available() {
return
}
+// CHECK-LABEL: @is_token_error
+func @is_token_error() -> i1 {
+ // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
+ %0 = async.runtime.create : !async.token
+ // CHECK: %[[ERR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+ %1 = async.runtime.is_error %0 : !async.token
+ return %1 : i1
+}
+
+// CHECK-LABEL: @is_value_error
+func @is_value_error() -> i1 {
+ // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+ %0 = async.runtime.create : !async.value<f32>
+ // CHECK: %[[ERR:.*]] = call @mlirAsyncRuntimeIsValueError(%[[VALUE]])
+ %1 = async.runtime.is_error %0 : !async.value<f32>
+ return %1 : i1
+}
+
// CHECK-LABEL: @await_token
func @await_token() {
// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index a57c283b3de29..08607f89e2e57 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -print-ir-after-all | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime \
+// RUN: | FileCheck %s --dump-input=always
// CHECK-LABEL: @execute_no_async_args
func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
@@ -101,11 +102,17 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
-// Set token available after second resumption.
+// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
+// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
+// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Set token available if the token is not in the error state.
+// CHECK: ^[[CONTINUATION:.*]]:
// CHECK: memref.store
// CHECK: async.runtime.set_available %[[TOKEN]]
+// CHECK: ^[[SET_ERROR]]:
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:
@@ -155,8 +162,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
-// Emplace result token after second resumption.
+// Check the error of the awaited token after resumption.
// CHECK: ^[[RESUME_1]]:
+// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
+// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Emplace result token after second resumption and error checking.
+// CHECK: ^[[CONTINUATION:.*]]:
// CHECK: memref.store
// CHECK: async.runtime.set_available %[[TOKEN]]
@@ -293,11 +305,65 @@ func @async_value_operands() {
// CHECK: async.coro.suspend
// CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
-// Load from the async.value argument.
+// Check the error state of the awaited value after resumption.
// CHECK: ^[[RESUME_1]]:
+// CHECK: %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
+// CHECK: cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Load from the async.value argument after error checking.
+// CHECK: ^[[CONTINUATION:.*]]:
// CHECK: %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
// CHECK: addf %[[LOADED]], %[[LOADED]] : f32
// CHECK: async.runtime.set_available %[[TOKEN]]
// CHECK: ^[[CLEANUP]]:
// CHECK: ^[[SUSPEND]]:
+
+// -----
+
+// CHECK-LABEL: @execute_asserttion
+func @execute_asserttion(%arg0: i1) {
+ %token = async.execute {
+ assert %arg0, "error"
+ async.yield
+ }
+ async.await %token : !async.token
+ return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME: %[[ARG0:.*]]: i1
+// CHECK-SAME: -> !async.token
+
+// Create token for return op, and mark a function as a coroutine.
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Initial coroutine suspension.
+// CHECK: async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// Resume coroutine after suspension.
+// CHECK: ^[[RESUME]]:
+// CHECK: cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
+
+// Set coroutine completion token to available state.
+// CHECK: ^[[SET_AVAILABLE]]:
+// CHECK: async.runtime.set_available %[[TOKEN]]
+// CHECK: br ^[[CLEANUP]]
+
+// Set coroutine completion token to error state.
+// CHECK: ^[[SET_ERROR]]:
+// CHECK: async.runtime.set_error %[[TOKEN]]
+// CHECK: br ^[[CLEANUP]]
+
+// Delete coroutine.
+// CHECK: ^[[CLEANUP]]:
+// CHECK: async.coro.free %[[ID]], %[[HDL]]
+
+// Suspend coroutine, and also a return statement for ramp function.
+// CHECK: ^[[SUSPEND]]:
+// CHECK: async.coro.end %[[HDL]]
+// CHECK: return %[[TOKEN]]
diff --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index 776b16a48e3bc..c8f6f65cb8469 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -38,6 +38,34 @@ func @set_value_available(%arg0: !async.value<f32>) {
return
}
+// CHECK-LABEL: @set_token_error
+func @set_token_error(%arg0: !async.token) {
+  // CHECK: async.runtime.set_error %arg0 : !async.token
+  async.runtime.set_error %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @set_value_error
+func @set_value_error(%arg0: !async.value<f32>) {
+  // CHECK: async.runtime.set_error %arg0 : !async.value<f32>
+  async.runtime.set_error %arg0 : !async.value<f32>
+  return
+}
+
+// CHECK-LABEL: @is_token_error
+func @is_token_error(%arg0: !async.token) -> i1 {
+  // CHECK: %{{.*}} = async.runtime.is_error %arg0 : !async.token
+  %is_err = async.runtime.is_error %arg0 : !async.token
+  return %is_err : i1
+}
+
+// CHECK-LABEL: @is_value_error
+func @is_value_error(%arg0: !async.value<f32>) -> i1 {
+  // CHECK: %{{.*}} = async.runtime.is_error %arg0 : !async.value<f32>
+  %is_err = async.runtime.is_error %arg0 : !async.value<f32>
+  return %is_err : i1
+}
+
// CHECK-LABEL: @await_token
func @await_token(%arg0: !async.token) {
// CHECK: async.runtime.await %arg0 : !async.token
diff --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir
new file mode 100644
index 0000000000000..139296d6614f4
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-error.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s -async-to-async-runtime \
+// RUN: -async-runtime-ref-counting \
+// RUN: -async-runtime-ref-counting-opt \
+// RUN: -convert-async-to-llvm \
+// RUN: -convert-linalg-to-loops \
+// RUN: -convert-scf-to-std \
+// RUN: -convert-linalg-to-llvm \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s --dump-input=always
+
+func @main() {
+  %c_false = constant 0 : i1
+
+  // ------------------------------------------------------------------------ //
+  // A trivial async region must complete without being marked as an error.
+  // ------------------------------------------------------------------------ //
+  %t0 = async.execute {
+    async.yield
+  }
+  async.runtime.await %t0 : !async.token
+
+  // CHECK: 0
+  %is_err0 = async.runtime.is_error %t0 : !async.token
+  vector.print %is_err0 : i1
+
+  // ------------------------------------------------------------------------ //
+  // A failed assertion inside an async region is converted to an async error.
+  // ------------------------------------------------------------------------ //
+  %t1 = async.execute {
+    assert %c_false, "error"
+    async.yield
+  }
+  async.runtime.await %t1 : !async.token
+
+  // CHECK: 1
+  %is_err1 = async.runtime.is_error %t1 : !async.token
+  vector.print %is_err1 : i1
+
+  // ------------------------------------------------------------------------ //
+  // An error in a nested async region propagates to the outer token.
+  // ------------------------------------------------------------------------ //
+  %t2 = async.execute {
+    %inner = async.execute {
+      assert %c_false, "error"
+      async.yield
+    }
+    async.await %inner : !async.token
+    async.yield
+  }
+  async.runtime.await %t2 : !async.token
+
+  // CHECK: 1
+  %is_err2 = async.runtime.is_error %t2 : !async.token
+  vector.print %is_err2 : i1
+
+  // ------------------------------------------------------------------------ //
+  // An error in a nested region propagates to the outer token and value.
+  // ------------------------------------------------------------------------ //
+  %t3, %v3 = async.execute -> !async.value<f32> {
+    %inner_t, %inner_v = async.execute -> !async.value<f32> {
+      assert %c_false, "error"
+      %cst = constant 123.45 : f32
+      async.yield %cst : f32
+    }
+    %res = async.await %inner_v : !async.value<f32>
+    async.yield %res : f32
+  }
+  async.runtime.await %t3 : !async.token
+  async.runtime.await %v3 : !async.value<f32>
+
+  // CHECK: 1
+  // CHECK: 1
+  %is_err3_tok = async.runtime.is_error %t3 : !async.token
+  %is_err3_val = async.runtime.is_error %v3 : !async.value<f32>
+  vector.print %is_err3_tok : i1
+  vector.print %is_err3_val : i1
+
+  return
+}
More information about the Mlir-commits
mailing list