[Mlir-commits] [mlir] 39957aa - [mlir] Add error state and error propagation to async runtime values

Eugene Zhulenev llvmlistbot at llvm.org
Thu May 27 09:28:54 PDT 2021


Author: Eugene Zhulenev
Date: 2021-05-27T09:28:47-07:00
New Revision: 39957aa4243cb9aec3a7114c0ecf710ecce96b72

URL: https://github.com/llvm/llvm-project/commit/39957aa4243cb9aec3a7114c0ecf710ecce96b72
DIFF: https://github.com/llvm/llvm-project/commit/39957aa4243cb9aec3a7114c0ecf710ecce96b72.diff

LOG: [mlir] Add error state and error propagation to async runtime values

Depends On D103102

Not yet implemented:
1. Error handling after synchronous await
2. Error handling for async groups

Will be addressed in the followup PRs

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D103109

Added: 
    mlir/test/mlir-cpu-runner/async-error.mlir

Modified: 
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
    mlir/test/Dialect/Async/async-to-async-runtime.mlir
    mlir/test/Dialect/Async/runtime.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index fbfc529c0b824..5f6eece6cf69d 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -343,7 +343,7 @@ def Async_RuntimeCreateOp : Async_Op<"runtime.create"> {
 }
 
 def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
-  let summary = "switches token or value available state";
+  let summary = "switches token or value to available state";
   let description = [{
     The `async.runtime.set_available` operation switches async token or value
     state to available.
@@ -353,11 +353,37 @@ def Async_RuntimeSetAvailableOp : Async_Op<"runtime.set_available"> {
   let assemblyFormat = "$operand attr-dict `:` type($operand)";
 }
 
+def Async_RuntimeSetErrorOp : Async_Op<"runtime.set_error"> {
+  let summary = "switches token or value to error state";
+  let description = [{
+    The `async.runtime.set_error` operation switches async token or value
+    state to error.
+  }];
+
+  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def Async_RuntimeIsErrorOp : Async_Op<"runtime.is_error"> {
+  let summary = "returns true if token or value is in error state";
+  let description = [{
+    The `async.runtime.is_error` operation returns true if the token or value
+    (async groups are not yet supported) is in the error state. It is the
+    caller's responsibility to check the error state after the call to `await`
+    or after resuming from `await_and_resume`.
+  }];
+
+  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  let results = (outs I1:$is_error);
+
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
 def Async_RuntimeAwaitOp : Async_Op<"runtime.await"> {
   let summary = "blocks the caller thread until the operand becomes available";
   let description = [{
     The `async.runtime.await` operation blocks the caller thread until the
-    operand becomes available.
+    operand becomes available or enters the error state.
   }];
 
   let arguments = (ins Async_AnyAsyncType:$operand);
@@ -379,8 +405,8 @@ def Async_RuntimeAwaitAndResumeOp : Async_Op<"runtime.await_and_resume"> {
   let summary = "awaits the async operand and resumes the coroutine";
   let description = [{
     The `async.runtime.await_and_resume` operation awaits for the operand to
-    become available and resumes the coroutine on a thread managed by the
-    runtime.
+    become available or enter the error state, and resumes the coroutine on a
+    thread managed by the runtime.
   }];
 
   let arguments = (ins Async_AnyAsyncType:$operand,

diff  --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index d5cede323ed92..df9adebbd6206 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -76,6 +76,18 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *);
 // Switches `async.value` to ready state and runs all awaiters.
 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *);
 
+// Switches `async.token` to error state and runs all awaiters.
+extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *);
+
+// Switches `async.value` to error state and runs all awaiters.
+extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *);
+
+// Returns true if token is in the error state.
+extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *);
+
+// Returns true if value is in the error state.
+extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *);
+
 // Blocks the caller thread until the token becomes ready.
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *);
 

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 7a24b75640ec1..55d1714171774 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -35,6 +35,10 @@ static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
 static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
+static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
+static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
+static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
+static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
 static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
@@ -101,6 +105,26 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {value}, {});
   }
 
+  static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
+    return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
+  }
+
+  static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {value}, {});
+  }
+
+  static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
+    auto i1 = IntegerType::get(ctx, 1);
+    return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
+  }
+
+  static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    auto i1 = IntegerType::get(ctx, 1);
+    return FunctionType::get(ctx, {value}, {i1});
+  }
+
   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
@@ -173,6 +197,10 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
+  addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
+  addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
+  addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
+  addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
@@ -560,17 +588,53 @@ class RuntimeSetAvailableOpLowering
   matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     Type operandType = op.operand().getType();
+    rewriter.replaceOpWithNewOp<CallOp>(
+        op, operandType.isa<TokenType>() ? kEmplaceToken : kEmplaceValue,
+        TypeRange(), operands);
+    return success();
+  }
+};
+} // namespace
 
-    if (operandType.isa<TokenType>() || operandType.isa<ValueType>()) {
-      rewriter.create<CallOp>(op->getLoc(),
-                              operandType.isa<TokenType>() ? kEmplaceToken
-                                                           : kEmplaceValue,
-                              TypeRange(), operands);
-      rewriter.eraseOp(op);
-      return success();
-    }
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.set_error to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
 
-    return rewriter.notifyMatchFailure(op, "unsupported async type");
+namespace {
+class RuntimeSetErrorOpLowering
+    : public OpConversionPattern<RuntimeSetErrorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(RuntimeSetErrorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type operandType = op.operand().getType();
+    rewriter.replaceOpWithNewOp<CallOp>(
+        op, operandType.isa<TokenType>() ? kSetTokenError : kSetValueError,
+        TypeRange(), operands);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.is_error to the corresponding runtime API call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(RuntimeIsErrorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type operandType = op.operand().getType();
+    rewriter.replaceOpWithNewOp<CallOp>(
+        op, operandType.isa<TokenType>() ? kIsTokenError : kIsValueError,
+        rewriter.getI1Type(), operands);
+    return success();
   }
 };
 } // namespace
@@ -889,7 +953,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   patterns.add<ReturnOpOpConversion>(converter, ctx);
 
   // Lower async.runtime operations to the async runtime API calls.
-  patterns.add<RuntimeSetAvailableOpLowering, RuntimeAwaitOpLowering,
+  patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
+               RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
                RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
                RuntimeDropRefOpLowering>(converter, ctx);

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index a9bcfe46fba54..baa7d5d8c46ae 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -52,6 +52,8 @@ class AsyncToAsyncRuntimePass
 /// operation to enable non-blocking waiting via coroutine suspension.
 namespace {
 struct CoroMachinery {
+  FuncOp func;
+
   // Async execute region returns a completion token, and an async value for
   // each yielded value.
   //
@@ -63,6 +65,7 @@ struct CoroMachinery {
   llvm::SmallVector<Value, 4> returnValues; // returned async values
 
   Value coroHandle; // coroutine handle (!async.coro.handle value)
+  Block *setError;  // switch completion token and all values to error state
   Block *cleanup;   // coroutine cleanup block
   Block *suspend;   // coroutine suspension block
 };
@@ -74,6 +77,7 @@ struct CoroMachinery {
 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
 ///
 ///  - `entry` block sets up the coroutine.
+///  - `set_error` block sets completion token and async values state to error.
 ///  - `cleanup` block cleans up the coroutine state.
 ///  - `suspend block after the @llvm.coro.end() defines what value will be
 ///    returned to the initial caller of a coroutine. Everything before the
@@ -91,6 +95,11 @@ struct CoroMachinery {
 ///       %hdl = async.coro.begin %id              // create a coroutine handle
 ///       br ^cleanup
 ///
+///     ^set_error: // created lazily, only if needed (see code below)
+///       async.runtime.set_error %token : !async.token
+///       async.runtime.set_error %value : !async.value<T>
+///       br ^cleanup
+///
 ///     ^cleanup:
 ///       async.coro.free %hdl // delete the coroutine state
 ///       br ^suspend
@@ -163,14 +172,39 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   // continuations, and will conditionally branch to cleanup or suspend blocks.
 
   CoroMachinery machinery;
+  machinery.func = func;
   machinery.asyncToken = retToken;
   machinery.returnValues = retValues;
   machinery.coroHandle = coroHdlOp.handle();
+  machinery.setError = nullptr; // created lazily only if needed
   machinery.cleanup = cleanupBlock;
   machinery.suspend = suspendBlock;
   return machinery;
 }
 
+// Lazily creates the `set_error` block the first time a lowering needs it
+// (e.g. `async.await` or `assert` lowering) and caches it in `coro.setError`.
+static Block *setupSetErrorBlock(CoroMachinery &coro) {
+  if (coro.setError)
+    return coro.setError;
+
+  coro.setError = coro.func.addBlock();
+  coro.setError->moveBefore(coro.cleanup);
+
+  auto builder =
+      ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
+
+  // Coroutine set_error block: set error on token and all returned values.
+  builder.create<RuntimeSetErrorOp>(coro.asyncToken);
+  for (Value retValue : coro.returnValues)
+    builder.create<RuntimeSetErrorOp>(retValue);
+
+  // Branch into the cleanup block.
+  builder.create<BranchOp>(coro.cleanup);
+
+  return coro.setError;
+}
+
 /// Outline the body region attached to the `async.execute` op into a standalone
 /// function.
 ///
@@ -316,9 +350,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
   using AwaitAdaptor = typename AwaitType::Adaptor;
 
 public:
-  AwaitOpLoweringBase(
-      MLIRContext *ctx,
-      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+  AwaitOpLoweringBase(MLIRContext *ctx,
+                      llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
       : OpConversionPattern<AwaitType>(ctx),
         outlinedFunctions(outlinedFunctions) {}
 
@@ -346,7 +379,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     // Inside the coroutine we convert await operation into coroutine suspension
     // point, and resume execution asynchronously.
     if (isInCoroutine) {
-      const CoroMachinery &coro = outlined->getSecond();
+      CoroMachinery &coro = outlined->getSecond();
       Block *suspended = op->getBlock();
 
       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
@@ -366,8 +399,25 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
       builder.create<CoroSuspendOp>(coroSaveOp.state(), coro.suspend, resume,
                                     coro.cleanup);
 
-      // Make sure that replacement value will be constructed in resume block.
-      rewriter.setInsertionPointToStart(resume);
+      // TODO: Async groups do not yet support runtime errors.
+      if (!std::is_same<AwaitAllOp, AwaitType>::value) {
+        // Split the resume block into error checking and continuation.
+        Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
+
+        // Check if the awaited value is in the error state.
+        builder.setInsertionPointToStart(resume);
+        auto isError = builder.create<RuntimeIsErrorOp>(
+            loc, rewriter.getI1Type(), operand);
+        builder.create<CondBranchOp>(isError,
+                                     /*trueDest=*/setupSetErrorBlock(coro),
+                                     /*trueArgs=*/ArrayRef<Value>(),
+                                     /*falseDest=*/continuation,
+                                     /*falseArgs=*/ArrayRef<Value>());
+
+        // Make sure that replacement value will be constructed in the
+        // continuation block.
+        rewriter.setInsertionPointToStart(continuation);
+      }
     }
 
     // Erase or replace the await operation with the new value.
@@ -385,7 +435,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
   }
 
 private:
-  const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
 };
 
 /// Lowering for `async.await` with a token operand.
@@ -437,12 +487,12 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
   LogicalResult
   matchAndRewrite(async::YieldOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // Check if yield operation is inside the outlined coroutine function.
+    // Check if yield operation is inside the async coroutine function.
     auto func = op->template getParentOfType<FuncOp>();
     auto outlined = outlinedFunctions.find(func);
     if (outlined == outlinedFunctions.end())
       return rewriter.notifyMatchFailure(
-          op, "operation is not inside the outlined async.execute function");
+          op, "operation is not inside the async coroutine function");
 
     Location loc = op->getLoc();
     const CoroMachinery &coro = outlined->getSecond();
@@ -466,6 +516,46 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
 };
 
+//===----------------------------------------------------------------------===//
+// Convert std.assert operation to cond_br into `set_error` block.
+//===----------------------------------------------------------------------===//
+
+class AssertOpLowering : public OpConversionPattern<AssertOp> {
+public:
+  AssertOpLowering(MLIRContext *ctx,
+                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+      : OpConversionPattern<AssertOp>(ctx),
+        outlinedFunctions(outlinedFunctions) {}
+
+  LogicalResult
+  matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check if assert operation is inside the async coroutine function.
+    auto func = op->template getParentOfType<FuncOp>();
+    auto outlined = outlinedFunctions.find(func);
+    if (outlined == outlinedFunctions.end())
+      return rewriter.notifyMatchFailure(
+          op, "operation is not inside the async coroutine function");
+
+    Location loc = op->getLoc();
+    CoroMachinery &coro = outlined->getSecond();
+
+    Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
+    rewriter.setInsertionPointToEnd(cont->getPrevNode());
+    rewriter.create<CondBranchOp>(loc, AssertOpAdaptor(operands).arg(),
+                                  /*trueDest=*/cont,
+                                  /*trueArgs=*/ArrayRef<Value>(),
+                                  /*falseDest=*/setupSetErrorBlock(coro),
+                                  /*falseArgs=*/ArrayRef<Value>());
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+
+private:
+  llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+};
+
 //===----------------------------------------------------------------------===//
 
 void AsyncToAsyncRuntimePass::runOnOperation() {
@@ -495,12 +585,19 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
                     AwaitAllOpLowering, YieldOpLowering>(ctx,
                                                          outlinedFunctions);
 
+  // Lower assertions to conditional branches into error blocks.
+  asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
+
   // All high level async operations must be lowered to the runtime operations.
   ConversionTarget runtimeTarget(*ctx);
   runtimeTarget.addLegalDialect<AsyncDialect>();
   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
 
+  // Assertions must be converted to runtime errors.
+  runtimeTarget.addIllegalOp<AssertOp>();
+  runtimeTarget.addLegalOp<CondBranchOp>();
+
   if (failed(applyPartialConversion(module, runtimeTarget,
                                     std::move(asyncPatterns)))) {
     signalPassFailure();

diff  --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 35e114285b941..856d1c7b74f9c 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -77,6 +77,46 @@ class AsyncRuntime {
   llvm::ThreadPool threadPool;
 };
 
+// -------------------------------------------------------------------------- //
+// A state of the async runtime value (token, value or group).
+// -------------------------------------------------------------------------- //
+
+class State {
+public:
+  enum StateEnum : int8_t {
+    // The underlying value is not yet available for consumption.
+    kUnavailable = 0,
+    // The underlying value is available for consumption. This state cannot
+    // transition to any other state.
+    kAvailable = 1,
+    // The underlying value is available and contains an error. This state
+    // cannot transition to any other state.
+    kError = 2,
+  };
+
+  /* implicit */ State(StateEnum s) : state(s) {}
+  /* implicit */ operator StateEnum() { return state; }
+
+  bool isUnavailable() const { return state == kUnavailable; }
+  bool isAvailable() const { return state == kAvailable; }
+  bool isError() const { return state == kError; }
+  bool isAvailableOrError() const { return isAvailable() || isError(); }
+
+  const char *debug() const {
+    switch (state) {
+    case kUnavailable:
+      return "unavailable";
+    case kAvailable:
+      return "available";
+    case kError:
+      return "error";
+    }
+  }
+
+private:
+  StateEnum state;
+};
+
 // -------------------------------------------------------------------------- //
 // A base class for all reference counted objects created by the async runtime.
 // -------------------------------------------------------------------------- //
@@ -137,9 +177,9 @@ struct AsyncToken : public RefCounted {
   // reference we must ensure that the token will be alive until the
   // asynchronous operation is completed.
   AsyncToken(AsyncRuntime *runtime)
-      : RefCounted(runtime, /*count=*/2), ready(false) {}
+      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
 
-  std::atomic<bool> ready;
+  std::atomic<State::StateEnum> state;
 
   // Pending awaiters are guarded by a mutex.
   std::mutex mu;
@@ -153,9 +193,10 @@ struct AsyncToken : public RefCounted {
 struct AsyncValue : public RefCounted {
   // AsyncValue similar to an AsyncToken created with a reference count of 2.
   AsyncValue(AsyncRuntime *runtime, int32_t size)
-      : RefCounted(runtime, /*count=*/2), ready(false), storage(size) {}
+      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
+        storage(size) {}
 
-  std::atomic<bool> ready;
+  std::atomic<State::StateEnum> state;
 
   // Use vector of bytes to store async value payload.
   std::vector<int8_t> storage;
@@ -182,7 +223,6 @@ struct AsyncGroup : public RefCounted {
   std::vector<std::function<void()>> awaiters;
 };
 
-
 // Adds references to reference counted runtime object.
 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
@@ -231,7 +271,7 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
     }
   };
 
-  if (token->ready) {
+  if (State(token->state).isAvailableOrError()) {
     // Update group pending tokens immediately and maybe run awaiters.
     onTokenReady();
 
@@ -254,12 +294,16 @@ extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
   return rank;
 }
 
-// Switches `async.token` to ready state and runs all awaiters.
-extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
+// Switches `async.token` to available or error state (terminal state) and runs
+// all awaiters.
+static void setTokenState(AsyncToken *token, State state) {
+  assert(state.isAvailableOrError() && "must be terminal state");
+  assert(State(token->state).isUnavailable() && "token must be unavailable");
+
   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
   {
     std::unique_lock<std::mutex> lock(token->mu);
-    token->ready = true;
+    token->state = state;
     token->cv.notify_all();
     for (auto &awaiter : token->awaiters)
       awaiter();
@@ -270,12 +314,14 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
   token->dropRef();
 }
 
-// Switches `async.value` to ready state and runs all awaiters.
-extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
+static void setValueState(AsyncValue *value, State state) {
+  assert(state.isAvailableOrError() && "must be terminal state");
+  assert(State(value->state).isUnavailable() && "value must be unavailable");
+
   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
   {
     std::unique_lock<std::mutex> lock(value->mu);
-    value->ready = true;
+    value->state = state;
     value->cv.notify_all();
     for (auto &awaiter : value->awaiters)
       awaiter();
@@ -286,16 +332,42 @@ extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
   value->dropRef();
 }
 
+extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
+  setTokenState(token, State::kAvailable);
+}
+
+extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
+  setValueState(value, State::kAvailable);
+}
+
+extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
+  setTokenState(token, State::kError);
+}
+
+extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
+  setValueState(value, State::kError);
+}
+
+extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
+  return State(token->state).isError();
+}
+
+extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
+  return State(value->state).isError();
+}
+
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
   std::unique_lock<std::mutex> lock(token->mu);
-  if (!token->ready)
-    token->cv.wait(lock, [token] { return token->ready.load(); });
+  if (!State(token->state).isAvailableOrError())
+    token->cv.wait(
+        lock, [token] { return State(token->state).isAvailableOrError(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
   std::unique_lock<std::mutex> lock(value->mu);
-  if (!value->ready)
-    value->cv.wait(lock, [value] { return value->ready.load(); });
+  if (!State(value->state).isAvailableOrError())
+    value->cv.wait(
+        lock, [value] { return State(value->state).isAvailableOrError(); });
 }
 
 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
@@ -306,6 +378,7 @@ extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
 
 // Returns a pointer to the storage owned by the async value.
 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
+  assert(!State(value->state).isError() && "unexpected error state");
   return value->storage.data();
 }
 
@@ -319,7 +392,7 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                      CoroResume resume) {
   auto execute = [handle, resume]() { (*resume)(handle); };
   std::unique_lock<std::mutex> lock(token->mu);
-  if (token->ready) {
+  if (State(token->state).isAvailableOrError()) {
     lock.unlock();
     execute();
   } else {
@@ -332,7 +405,7 @@ extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                      CoroResume resume) {
   auto execute = [handle, resume]() { (*resume)(handle); };
   std::unique_lock<std::mutex> lock(value->mu);
-  if (value->ready) {
+  if (State(value->state).isAvailableOrError()) {
     lock.unlock();
     execute();
   } else {
@@ -402,6 +475,14 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
                &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
   exportSymbol("mlirAsyncRuntimeEmplaceValue",
                &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
+  exportSymbol("mlirAsyncRuntimeSetTokenError",
+               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
+  exportSymbol("mlirAsyncRuntimeSetValueError",
+               &mlir::runtime::mlirAsyncRuntimeSetValueError);
+  exportSymbol("mlirAsyncRuntimeIsTokenError",
+               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
+  exportSymbol("mlirAsyncRuntimeIsValueError",
+               &mlir::runtime::mlirAsyncRuntimeIsValueError);
   exportSymbol("mlirAsyncRuntimeAwaitToken",
                &mlir::runtime::mlirAsyncRuntimeAwaitToken);
   exportSymbol("mlirAsyncRuntimeAwaitValue",

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
index 32c02128c103d..74c091e7575a1 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-runtime-to-llvm.mlir
@@ -43,6 +43,24 @@ func @set_value_available() {
   return
 }
 
+// CHECK-LABEL: @is_token_error
+func @is_token_error() -> i1 {
+  // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken
+  %0 = async.runtime.create : !async.token
+  // CHECK: %[[ERR:.*]] = call @mlirAsyncRuntimeIsTokenError(%[[TOKEN]])
+  %1 = async.runtime.is_error %0 : !async.token
+  return %1 : i1
+}
+
+// CHECK-LABEL: @is_value_error
+func @is_value_error() -> i1 {
+  // CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+  %0 = async.runtime.create : !async.value<f32>
+  // CHECK: %[[ERR:.*]] = call @mlirAsyncRuntimeIsValueError(%[[VALUE]])
+  %1 = async.runtime.is_error %0 : !async.value<f32>
+  return %1 : i1
+}
+
 // CHECK-LABEL: @await_token
 func @await_token() {
   // CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index a57c283b3de29..08607f89e2e57 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime -print-ir-after-all | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -split-input-file -async-to-async-runtime                  \
+// RUN:   | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @execute_no_async_args
 func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
@@ -101,11 +102,17 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 // CHECK:   async.coro.suspend %[[SAVED]]
 // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
 
-// Set token available after second resumption.
+// Check the error state of the awaited token after resumption.
 // CHECK: ^[[RESUME_1]]:
+// CHECK:   %[[ERR:.*]] = async.runtime.is_error %[[INNER_TOKEN]]
+// CHECK:   cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Set token available if the token is not in the error state.
+// CHECK: ^[[CONTINUATION:.*]]:
 // CHECK:   memref.store
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 
+// CHECK: ^[[SET_ERROR]]:
 // CHECK: ^[[CLEANUP]]:
 // CHECK: ^[[SUSPEND]]:
 
@@ -155,8 +162,13 @@ func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
 // CHECK:   async.coro.suspend %[[SAVED]]
 // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
 
-// Emplace result token after second resumption.
+// Check the error of the awaited token after resumption.
 // CHECK: ^[[RESUME_1]]:
+// CHECK:   %[[ERR:.*]] = async.runtime.is_error %[[ARG0]]
+// CHECK:   cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Emplace result token only after resumption and a successful error check.
+// CHECK: ^[[CONTINUATION]]:
 // CHECK:   memref.store
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 
@@ -293,11 +305,65 @@ func @async_value_operands() {
 // CHECK:   async.coro.suspend
 // CHECK-SAME: ^[[SUSPEND]], ^[[RESUME_1:.*]], ^[[CLEANUP]]
 
-// Load from the async.value argument.
+// Check the error state of the awaited value after resumption.
 // CHECK: ^[[RESUME_1]]:
+// CHECK:   %[[ERR:.*]] = async.runtime.is_error %[[ARG]]
+// CHECK:   cond_br %[[ERR]], ^[[SET_ERROR:.*]], ^[[CONTINUATION:.*]]
+
+// Load from the async.value argument after error checking.
+// CHECK: ^[[CONTINUATION]]:
 // CHECK:   %[[LOADED:.*]] = async.runtime.load %[[ARG]] : !async.value<f32
 // CHECK:   addf %[[LOADED]], %[[LOADED]] : f32
 // CHECK:   async.runtime.set_available %[[TOKEN]]
 
 // CHECK: ^[[CLEANUP]]:
 // CHECK: ^[[SUSPEND]]:
+
+// -----
+
+// CHECK-LABEL: @execute_assertion
+func @execute_assertion(%arg0: i1) {
+  %token = async.execute {
+    assert %arg0, "error"
+    async.yield
+  }
+  async.await %token : !async.token
+  return
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME:  %[[ARG0:.*]]: i1
+// CHECK-SAME:  -> !async.token
+
+// Create the token returned by the ramp function; mark it as a coroutine.
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+
+// Initial coroutine suspension.
+// CHECK:      async.coro.suspend
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// After resumption the assert condition selects the completion state.
+// CHECK: ^[[RESUME]]:
+// CHECK:   cond_br %[[ARG0]], ^[[SET_AVAILABLE:.*]], ^[[SET_ERROR:.*]]
+
+// Set coroutine completion token to the available state.
+// CHECK: ^[[SET_AVAILABLE]]:
+// CHECK:   async.runtime.set_available %[[TOKEN]]
+// CHECK:   br ^[[CLEANUP]]
+
+// Set coroutine completion token to the error state.
+// CHECK: ^[[SET_ERROR]]:
+// CHECK:   async.runtime.set_error %[[TOKEN]]
+// CHECK:   br ^[[CLEANUP]]
+
+// Free coroutine resources.
+// CHECK: ^[[CLEANUP]]:
+// CHECK:   async.coro.free %[[ID]], %[[HDL]]
+
+// Suspend block: ends the coroutine and returns the token from the ramp.
+// CHECK: ^[[SUSPEND]]:
+// CHECK:   async.coro.end %[[HDL]]
+// CHECK:   return %[[TOKEN]]

diff  --git a/mlir/test/Dialect/Async/runtime.mlir b/mlir/test/Dialect/Async/runtime.mlir
index 776b16a48e3bc..c8f6f65cb8469 100644
--- a/mlir/test/Dialect/Async/runtime.mlir
+++ b/mlir/test/Dialect/Async/runtime.mlir
@@ -38,6 +38,34 @@ func @set_value_available(%arg0: !async.value<f32>) {
   return
 }
 
+// CHECK-LABEL: @set_token_error
+func @set_token_error(%arg0: !async.token) {
+  // CHECK: async.runtime.set_error %arg0 : !async.token
+  async.runtime.set_error %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @set_value_error
+func @set_value_error(%arg0: !async.value<f32>) {
+  // CHECK: async.runtime.set_error %arg0 : !async.value<f32>
+  async.runtime.set_error %arg0 : !async.value<f32>
+  return
+}
+
+// CHECK-LABEL: @is_token_error
+func @is_token_error(%arg0: !async.token) -> i1 {
+  // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.token
+  %is_error = async.runtime.is_error %arg0 : !async.token
+  return %is_error : i1
+}
+
+// CHECK-LABEL: @is_value_error
+func @is_value_error(%arg0: !async.value<f32>) -> i1 {
+  // CHECK: %[[ERR:.*]] = async.runtime.is_error %arg0 : !async.value<f32>
+  %is_error = async.runtime.is_error %arg0 : !async.value<f32>
+  return %is_error : i1
+}
+
 // CHECK-LABEL: @await_token
 func @await_token(%arg0: !async.token) {
   // CHECK: async.runtime.await %arg0 : !async.token

diff  --git a/mlir/test/mlir-cpu-runner/async-error.mlir b/mlir/test/mlir-cpu-runner/async-error.mlir
new file mode 100644
index 0000000000000..139296d6614f4
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-error.mlir
@@ -0,0 +1,85 @@
+// RUN:   mlir-opt %s -async-to-async-runtime                                  \
+// RUN:               -async-runtime-ref-counting                              \
+// RUN:               -async-runtime-ref-counting-opt                          \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-linalg-to-loops                                 \
+// RUN:               -convert-scf-to-std                                      \
+// RUN:               -convert-linalg-to-llvm                                  \
+// RUN:               -convert-vector-to-llvm                                  \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext  \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext    \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext   \
+// RUN: | FileCheck %s --dump-input=always
+
+func @main() {
+  %false = constant 0 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check that a simple async region completes without errors.
+  // ------------------------------------------------------------------------ //
+  %token0 = async.execute {
+    async.yield
+  }
+  async.runtime.await %token0 : !async.token
+
+  // CHECK: 0
+  %err0 = async.runtime.is_error %token0 : !async.token
+  vector.print %err0 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check that a failed assertion in the async region becomes an async error.
+  // ------------------------------------------------------------------------ //
+  %token1 = async.execute {
+    assert %false, "error"
+    async.yield
+  }
+  async.runtime.await %token1 : !async.token
+
+  // CHECK: 1
+  %err1 = async.runtime.is_error %token1 : !async.token
+  vector.print %err1 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check that an error in a nested region propagates to the outer token.
+  // ------------------------------------------------------------------------ //
+  %token2 = async.execute {
+    %token = async.execute {
+      assert %false, "error"
+      async.yield
+    }
+    async.await %token : !async.token
+    async.yield
+  }
+  async.runtime.await %token2 : !async.token
+
+  // CHECK: 1
+  %err2 = async.runtime.is_error %token2 : !async.token
+  vector.print %err2 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check error propagation from a nested region through async values.
+  // ------------------------------------------------------------------------ //
+  %token3, %value3 = async.execute -> !async.value<f32> {
+    %token, %value = async.execute -> !async.value<f32> {
+      assert %false, "error"
+      %0 = constant 123.45 : f32
+      async.yield %0 : f32
+    }
+    %ret = async.await %value : !async.value<f32>
+    async.yield %ret : f32
+  }
+  async.runtime.await %token3 : !async.token
+  async.runtime.await %value3 : !async.value<f32>
+
+  // CHECK: 1
+  // CHECK: 1
+  %err3_0 = async.runtime.is_error %token3 : !async.token
+  %err3_1 = async.runtime.is_error %value3 : !async.value<f32>
+  vector.print %err3_0 : i1
+  vector.print %err3_1 : i1
+
+  return
+}


        


More information about the Mlir-commits mailing list