[llvm-branch-commits] [mlir] 621ad46 - [mlir] Async: lowering async.value to LLVM

Eugene Zhulenev via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 25 02:28:39 PST 2020


Author: Eugene Zhulenev
Date: 2020-12-25T02:23:48-08:00
New Revision: 621ad468d99d4013a4298465f02707a5e9e89cae

URL: https://github.com/llvm/llvm-project/commit/621ad468d99d4013a4298465f02707a5e9e89cae
DIFF: https://github.com/llvm/llvm-project/commit/621ad468d99d4013a4298465f02707a5e9e89cae.diff

LOG: [mlir] Async: lowering async.value to LLVM

1. Add new methods to Async runtime API to support yielding async values
2. Add lowering from `async.yield` with value payload to the new runtime API calls

`async.value` lowering requires that the payload type be convertible to LLVM and supported by the `llvm.mlir.cast` (DialectCast) operation.

Reviewed By: csigg

Differential Revision: https://reviews.llvm.org/D93592

Added: 
    mlir/test/mlir-cpu-runner/async-value.mlir

Modified: 
    mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index e3d90198f36c..0fe44cd1c127 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -45,6 +45,12 @@ typedef struct AsyncToken AsyncToken;
 // Runtime implementation of `async.group` data type.
 typedef struct AsyncGroup AsyncGroup;
 
+// Runtime implementation of `async.value` data type.
+typedef struct AsyncValue AsyncValue;
+
+// Async value payload stored in memory owned by the async.value.
+using ValueStorage = void *;
+
 // Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
 // function is a coroutine handle and a resume function that continue coroutine
 // execution from a suspension point.
@@ -66,6 +72,13 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void
 // Create a new `async.token` in not-ready state.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
 
+// Create a new `async.value` in not-ready state. Size parameter specifies the
+// number of bytes that will be allocated for the async value storage. Storage
+// is owned by the `async.value` and deallocated when the async value is
+// destructed (reference count drops to zero).
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncValue *
+    mlirAsyncRuntimeCreateValue(int32_t);
+
 // Create a new `async.group` in empty state.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup();
 
@@ -76,14 +89,26 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
 mlirAsyncRuntimeEmplaceToken(AsyncToken *);
 
+// Switches `async.value` to ready state and runs all awaiters.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeEmplaceValue(AsyncValue *);
+
 // Blocks the caller thread until the token becomes ready.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
 mlirAsyncRuntimeAwaitToken(AsyncToken *);
 
+// Blocks the caller thread until the value becomes ready.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitValue(AsyncValue *);
+
 // Blocks the caller thread until the elements in the group become ready.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
 mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *);
 
+// Returns a pointer to the storage owned by the async value.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT ValueStorage
+mlirAsyncRuntimeGetValueStorage(AsyncValue *);
+
 // Executes the task (coro handle + resume function) in one of the threads
 // managed by the runtime.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
@@ -94,6 +119,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void
 mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
 
+// Executes the task (coro handle + resume function) in one of the threads
+// managed by the runtime after the value becomes ready.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle, CoroResume);
+
 // Executes the task (coro handle + resume function) in one of the threads
 // managed by the runtime after the all members of the group become ready.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT void

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 3daa70b0a952..f1d6264606be 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -9,9 +9,11 @@
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 
 #include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -36,23 +38,39 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
 static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
 static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
 static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
+static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
 static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
 static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
+static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
 static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
+static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
 static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
 static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
+static constexpr const char *kGetValueStorage =
+    "mlirAsyncRuntimeGetValueStorage";
 static constexpr const char *kAddTokenToGroup =
     "mlirAsyncRuntimeAddTokenToGroup";
-static constexpr const char *kAwaitAndExecute =
+static constexpr const char *kAwaitTokenAndExecute =
     "mlirAsyncRuntimeAwaitTokenAndExecute";
+static constexpr const char *kAwaitValueAndExecute =
+    "mlirAsyncRuntimeAwaitValueAndExecute";
 static constexpr const char *kAwaitAllAndExecute =
     "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
 
 namespace {
-// Async Runtime API function types.
+/// Async Runtime API function types.
+///
+/// Because we can't create API function signature for type parametrized
+/// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
+/// lowering all async data types become opaque pointers at runtime.
 struct AsyncAPI {
+  // All async types are lowered to opaque i8* LLVM pointers at runtime.
+  static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
+    return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+  }
+
   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
-    auto ref = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+    auto ref = opaquePointerType(ctx);
     auto count = IntegerType::get(ctx, 32);
     return FunctionType::get(ctx, {ref, count}, {});
   }
@@ -61,24 +79,46 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
   }
 
+  static FunctionType createValueFunctionType(MLIRContext *ctx) {
+    auto i32 = IntegerType::get(ctx, 32);
+    auto value = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {i32}, {value});
+  }
+
   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
     return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
   }
 
+  static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    auto storage = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {value}, {storage});
+  }
+
   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
+  static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {value}, {});
+  }
+
   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
   }
 
+  static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    return FunctionType::get(ctx, {value}, {});
+  }
+
   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
   }
 
   static FunctionType executeFunctionType(MLIRContext *ctx) {
-    auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+    auto hdl = opaquePointerType(ctx);
     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
     return FunctionType::get(ctx, {hdl, resume}, {});
   }
@@ -89,14 +129,21 @@ struct AsyncAPI {
                              {i64});
   }
 
-  static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+  static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
+    auto hdl = opaquePointerType(ctx);
     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
     return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
   }
 
+  static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
+    auto value = opaquePointerType(ctx);
+    auto hdl = opaquePointerType(ctx);
+    auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+    return FunctionType::get(ctx, {value, hdl, resume}, {});
+  }
+
   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
-    auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+    auto hdl = opaquePointerType(ctx);
     auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
@@ -104,13 +151,13 @@ struct AsyncAPI {
   // Auxiliary coroutine resume intrinsic wrapper.
   static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+    auto i8Ptr = opaquePointerType(ctx);
     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
   }
 };
 } // namespace
 
-// Adds Async Runtime C API declarations to the module.
+/// Adds Async Runtime C API declarations to the module.
 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
                                                          module.getBody());
@@ -125,13 +172,20 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
+  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
+  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
+  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
+  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
-  addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitTokenAndExecute,
+              AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitValueAndExecute,
+              AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
   addFuncDecl(kAwaitAllAndExecute,
               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
 }
@@ -215,9 +269,9 @@ static void addCRuntimeDeclarations(ModuleOp module) {
 
 static constexpr const char *kResume = "__resume";
 
-// A function that takes a coroutine handle and calls a `llvm.coro.resume`
-// intrinsics. We need this function to be able to pass it to the async
-// runtime execute API.
+/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
+/// intrinsics. We need this function to be able to pass it to the async
+/// runtime execute API.
 static void addResumeFunction(ModuleOp module) {
   MLIRContext *ctx = module.getContext();
 
@@ -248,49 +302,61 @@ static void addResumeFunction(ModuleOp module) {
 // async.execute op outlining to the coroutine functions.
 //===----------------------------------------------------------------------===//
 
-// Function targeted for coroutine transformation has two additional blocks at
-// the end: coroutine cleanup and coroutine suspension.
-//
-// async.await op lowering additionaly creates a resume block for each
-// operation to enable non-blocking waiting via coroutine suspension.
+/// Function targeted for coroutine transformation has two additional blocks at
+/// the end: coroutine cleanup and coroutine suspension.
+///
+/// async.await op lowering additionally creates a resume block for each
+/// operation to enable non-blocking waiting via coroutine suspension.
 namespace {
 struct CoroMachinery {
-  Value asyncToken;
+  // Async execute region returns a completion token, and an async value for
+  // each yielded value.
+  //
+  //   %token, %result = async.execute -> !async.value<T> {
+  //     %0 = constant ... : T
+  //     async.yield %0 : T
+  //   }
+  Value asyncToken; // token representing completion of the async region
+  llvm::SmallVector<Value, 4> returnValues; // returned async values
+
   Value coroHandle;
   Block *cleanup;
   Block *suspend;
 };
 } // namespace
 
-// Builds an coroutine template compatible with LLVM coroutines lowering.
-//
-//  - `entry` block sets up the coroutine.
-//  - `cleanup` block cleans up the coroutine state.
-//  - `suspend block after the @llvm.coro.end() defines what value will be
-//    returned to the initial caller of a coroutine. Everything before the
-//    @llvm.coro.end() will be executed at every suspension point.
-//
-// Coroutine structure (only the important bits):
-//
-//   func @async_execute_fn(<function-arguments>) -> !async.token {
-//     ^entryBlock(<function-arguments>):
-//       %token = <async token> : !async.token // create async runtime token
-//       %hdl = llvm.call @llvm.coro.id(...)   // create a coroutine handle
-//       br ^cleanup
-//
-//     ^cleanup:
-//       llvm.call @llvm.coro.free(...)        // delete coroutine state
-//       br ^suspend
-//
-//     ^suspend:
-//       llvm.call @llvm.coro.end(...)         // marks the end of a coroutine
-//       return %token : !async.token
-//   }
-//
-// The actual code for the async.execute operation body region will be inserted
-// before the entry block terminator.
-//
-//
+/// Builds a coroutine template compatible with LLVM coroutines lowering.
+///
+///  - `entry` block sets up the coroutine.
+///  - `cleanup` block cleans up the coroutine state.
+///  - `suspend` block after the @llvm.coro.end() defines what value will be
+///    returned to the initial caller of a coroutine. Everything before the
+///    @llvm.coro.end() will be executed at every suspension point.
+///
+/// Coroutine structure (only the important bits):
+///
+///   func @async_execute_fn(<function-arguments>)
+///        -> (!async.token, !async.value<T>)
+///   {
+///     ^entryBlock(<function-arguments>):
+///       %token = <async token> : !async.token    // create async runtime token
+///       %value = <async value> : !async.value<T> // create async value
+///       %hdl = llvm.call @llvm.coro.id(...)      // create a coroutine handle
+///       br ^cleanup
+///
+///     ^cleanup:
+///       llvm.call @llvm.coro.free(...)  // delete coroutine state
+///       br ^suspend
+///
+///     ^suspend:
+///       llvm.call @llvm.coro.end(...)  // marks the end of a coroutine
+///       return %token, %value : !async.token, !async.value<T>
+///   }
+///
+/// The actual code for the async.execute operation body region will be inserted
+/// before the entry block terminator.
+///
+///
 static CoroMachinery setupCoroMachinery(FuncOp func) {
   assert(func.getBody().empty() && "Function must have empty body");
 
@@ -312,6 +378,44 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   // ------------------------------------------------------------------------ //
   auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
 
+  // Async value operands and results must be convertible to LLVM types. This is
+  // verified before the function outlining.
+  LLVMTypeConverter converter(ctx);
+
+  // Returns the size requirements for the async value storage.
+  // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
+  auto sizeOf = [&](ValueType valueType) -> Value {
+    auto storedType = converter.convertType(valueType.getValueType());
+    auto storagePtrType =
+        LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
+
+    // %Size = getelementptr %T* null, int 1
+    // %SizeI = ptrtoint %T* %Size to i32
+    auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
+    auto one = builder.create<LLVM::ConstantOp>(loc, i32,
+                                                builder.getI32IntegerAttr(1));
+    auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
+                                           one.getResult());
+    auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);
+
+    // Cast to std type because runtime API defined using std types.
+    return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
+                                               size.getResult());
+  };
+
+  // We use the `async.value` type as a return type although it does not match
+  // the `kCreateValue` function signature, because it will be later lowered to
+  // the runtime type (opaque i8* pointer).
+  llvm::SmallVector<CallOp, 4> createValues;
+  for (auto resultType : func.getCallableResults().drop_front(1))
+    createValues.emplace_back(builder.create<CallOp>(
+        loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));
+
+  auto createdValues = llvm::map_range(
+      createValues, [](CallOp call) { return call.getResult(0); });
+  llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
+                                           createdValues.end());
+
   // ------------------------------------------------------------------------ //
   // Initialize coroutine: allocate frame, get coroutine handle.
   // ------------------------------------------------------------------------ //
@@ -371,9 +475,11 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
                                ValueRange({coroHdl.getResult(0), constFalse}));
 
-  // Return created `async.token` from the suspend block. This will be the
-  // return value of a coroutine ramp function.
-  builder.create<ReturnOp>(createToken.getResult(0));
+  // Return created `async.token` and `async.values` from the suspend block.
+  // This will be the return value of a coroutine ramp function.
+  SmallVector<Value, 4> ret{createToken.getResult(0)};
+  ret.insert(ret.end(), returnValues.begin(), returnValues.end());
+  builder.create<ReturnOp>(loc, ret);
 
   // Branch from the entry block to the cleanup block to create a valid CFG.
   builder.setInsertionPointToEnd(entryBlock);
@@ -383,39 +489,44 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   // `async.await` op lowering will create resume blocks for async
   // continuations, and will conditionally branch to cleanup or suspend blocks.
 
-  return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
-          suspendBlock};
+  CoroMachinery machinery;
+  machinery.asyncToken = createToken.getResult(0);
+  machinery.returnValues = returnValues;
+  machinery.coroHandle = coroHdl.getResult(0);
+  machinery.cleanup = cleanupBlock;
+  machinery.suspend = suspendBlock;
+  return machinery;
 }
 
-// Add a LLVM coroutine suspension point to the end of suspended block, to
-// resume execution in resume block. The caller is responsible for creating the
-// two suspended/resume blocks with the desired ops contained in each block.
-// This function merely provides the required control flow logic.
-//
-// `coroState` must be a value returned from the call to @llvm.coro.save(...)
-// intrinsic (saved coroutine state).
-//
-// Before:
-//
-//   ^bb0:
-//     "opBefore"(...)
-//     "op"(...)
-//   ^cleanup: ...
-//   ^suspend: ...
-//   ^resume:
-//     "op"(...)
-//
-// After:
-//
-//   ^bb0:
-//     "opBefore"(...)
-//     %suspend = llmv.call @llvm.coro.suspend(...)
-//     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
-//   ^resume:
-//     "op"(...)
-//   ^cleanup: ...
-//   ^suspend: ...
-//
+/// Add a LLVM coroutine suspension point to the end of suspended block, to
+/// resume execution in resume block. The caller is responsible for creating the
+/// two suspended/resume blocks with the desired ops contained in each block.
+/// This function merely provides the required control flow logic.
+///
+/// `coroState` must be a value returned from the call to @llvm.coro.save(...)
+/// intrinsic (saved coroutine state).
+///
+/// Before:
+///
+///   ^bb0:
+///     "opBefore"(...)
+///     "op"(...)
+///   ^cleanup: ...
+///   ^suspend: ...
+///   ^resume:
+///     "op"(...)
+///
+/// After:
+///
+///   ^bb0:
+///     "opBefore"(...)
+///     %suspend = llvm.call @llvm.coro.suspend(...)
+///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
+///   ^resume:
+///     "op"(...)
+///   ^cleanup: ...
+///   ^suspend: ...
+///
 static void addSuspensionPoint(CoroMachinery coro, Value coroState,
                                Operation *op, Block *suspended, Block *resume,
                                OpBuilder &builder) {
@@ -461,10 +572,10 @@ static void addSuspensionPoint(CoroMachinery coro, Value coroState,
                                  /*falseDest=*/coro.cleanup);
 }
 
-// Outline the body region attached to the `async.execute` op into a standalone
-// function.
-//
-// Note that this is not reversible transformation.
+/// Outline the body region attached to the `async.execute` op into a standalone
+/// function.
+///
+/// Note that this is not reversible transformation.
 static std::pair<FuncOp, CoroMachinery>
 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   ModuleOp module = execute->getParentOfType<ModuleOp>();
@@ -475,6 +586,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   // Collect all outlined function inputs.
   llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
                                               execute.dependencies().end());
+  assert(execute.operands().empty() && "operands are not supported");
   getUsedValuesDefinedAbove(execute.body(), functionInputs);
 
   // Collect types for the outlined function inputs and outputs.
@@ -535,15 +647,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   valueMapping.map(functionInputs, func.getArguments());
 
   // Clone all operations from the execute operation body into the outlined
-  // function body, and replace all `async.yield` operations with a call
-  // to async runtime to emplace the result token.
-  for (Operation &op : execute.body().getOps()) {
-    if (isa<async::YieldOp>(op)) {
-      builder.create<CallOp>(kEmplaceToken, TypeRange(), coro.asyncToken);
-      continue;
-    }
+  // function body.
+  for (Operation &op : execute.body().getOps())
     builder.clone(op, valueMapping);
-  }
 
   // Replace the original `async.execute` with a call to outlined function.
   ImplicitLocOpBuilder callBuilder(loc, execute);
@@ -560,42 +666,38 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
+/// their runtime type (opaque pointers) and does not convert any other types.
 class AsyncRuntimeTypeConverter : public TypeConverter {
 public:
-  AsyncRuntimeTypeConverter() { addConversion(convertType); }
-
-  static Type convertType(Type type) {
-    MLIRContext *ctx = type.getContext();
-    // Convert async tokens and groups to opaque pointers.
-    if (type.isa<TokenType, GroupType>())
-      return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
-    return type;
+  AsyncRuntimeTypeConverter() {
+    addConversion([](Type type) { return type; });
+    addConversion(convertAsyncTypes);
+  }
+
+  static Optional<Type> convertAsyncTypes(Type type) {
+    if (type.isa<TokenType, GroupType, ValueType>())
+      return AsyncAPI::opaquePointerType(type.getContext());
+    return llvm::None;
   }
 };
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// Convert types for all call operations to lowered async types.
+// Convert return operations that return async values from async regions.
 //===----------------------------------------------------------------------===//
 
 namespace {
-class CallOpOpConversion : public ConversionPattern {
+class ReturnOpOpConversion : public ConversionPattern {
 public:
-  explicit CallOpOpConversion(MLIRContext *ctx)
-      : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
+  explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    AsyncRuntimeTypeConverter converter;
-
-    SmallVector<Type, 5> resultTypes;
-    converter.convertTypes(op->getResultTypes(), resultTypes);
-
-    CallOp call = cast<CallOp>(op);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
-                                        operands);
-
+    rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
     return success();
   }
 };
@@ -611,8 +713,9 @@ namespace {
 template <typename RefCountingOp>
 class RefCountingOpLowering : public ConversionPattern {
 public:
-  explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
-      : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
+  explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
+                                 StringRef apiFunctionName)
+      : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
         apiFunctionName(apiFunctionName) {}
 
   LogicalResult
@@ -634,18 +737,18 @@ class RefCountingOpLowering : public ConversionPattern {
   StringRef apiFunctionName;
 };
 
-// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
+/// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
 class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
 public:
-  explicit AddRefOpLowering(MLIRContext *ctx)
-      : RefCountingOpLowering(ctx, kAddRef) {}
+  explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+      : RefCountingOpLowering(converter, ctx, kAddRef) {}
 };
 
-// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
+/// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
 class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
 public:
-  explicit DropRefOpLowering(MLIRContext *ctx)
-      : RefCountingOpLowering(ctx, kDropRef) {}
+  explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+      : RefCountingOpLowering(converter, ctx, kDropRef) {}
 };
 
 } // namespace
@@ -657,8 +760,9 @@ class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
 namespace {
 class CreateGroupOpLowering : public ConversionPattern {
 public:
-  explicit CreateGroupOpLowering(MLIRContext *ctx)
-      : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
+  explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
+                          ctx) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -677,8 +781,9 @@ class CreateGroupOpLowering : public ConversionPattern {
 namespace {
 class AddToGroupOpLowering : public ConversionPattern {
 public:
-  explicit AddToGroupOpLowering(MLIRContext *ctx)
-      : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
+  explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
+  }
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -706,10 +811,10 @@ template <typename AwaitType, typename AwaitableType>
 class AwaitOpLoweringBase : public ConversionPattern {
 protected:
   explicit AwaitOpLoweringBase(
-      MLIRContext *ctx,
+      TypeConverter &converter, MLIRContext *ctx,
       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
       StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
-      : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
+      : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
         outlinedFunctions(outlinedFunctions),
         blockingAwaitFuncName(blockingAwaitFuncName),
         coroAwaitFuncName(coroAwaitFuncName) {}
@@ -719,7 +824,7 @@ class AwaitOpLoweringBase : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     // We can only await on one the `AwaitableType` (for `await` it can be
-    // only a `token`, for `await_all` it is a `group`).
+    // a `token` or a `value`, for `await_all` it must be a `group`).
     auto await = cast<AwaitType>(op);
     if (!await.operand().getType().template isa<AwaitableType>())
       return failure();
@@ -768,44 +873,163 @@ class AwaitOpLoweringBase : public ConversionPattern {
       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
       addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
                          builder);
+
+      // Make sure that replacement value will be constructed in resume block.
+      rewriter.setInsertionPointToStart(resume);
     }
 
-    // Original operation was replaced by function call or suspension point.
-    rewriter.eraseOp(op);
+    // Replace or erase the await operation with the new value.
+    if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
+      rewriter.replaceOp(op, replaceWith);
+    else
+      rewriter.eraseOp(op);
 
     return success();
   }
 
+  virtual Value getReplacementValue(Operation *op, Value operand,
+                                    ConversionPatternRewriter &rewriter) const {
+    return Value();
+  }
+
 private:
   const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
   StringRef blockingAwaitFuncName;
   StringRef coroAwaitFuncName;
 };
 
-// Lowering for `async.await` operation (only token operands are supported).
-class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
+/// Lowering for `async.await` with a token operand.
+class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
 
 public:
-  explicit AwaitOpLowering(
-      MLIRContext *ctx,
+  explicit AwaitTokenOpLowering(
+      TypeConverter &converter, MLIRContext *ctx,
       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
-      : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
+      : Base(converter, ctx, outlinedFunctions, kAwaitToken,
+             kAwaitTokenAndExecute) {}
 };
 
-// Lowering for `async.await_all` operation.
+/// Lowering for `async.await` with a value operand.
+class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
+  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
+
+public:
+  explicit AwaitValueOpLowering(
+      TypeConverter &converter, MLIRContext *ctx,
+      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+      : Base(converter, ctx, outlinedFunctions, kAwaitValue,
+             kAwaitValueAndExecute) {}
+
+  Value
+  getReplacementValue(Operation *op, Value operand,
+                      ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+
+    // Get the underlying value type from the `async.value`.
+    auto await = cast<AwaitOp>(op);
+    auto valueType = await.operand().getType().cast<ValueType>().getValueType();
+
+    // Get a pointer to an async value storage from the runtime.
+    auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
+                                           TypeRange(i8Ptr), operand);
+
+    // Cast the i8* storage pointer to a pointer to the converted LLVM type.
+    auto llvmValueType = getTypeConverter()->convertType(valueType);
+    auto castedStorage = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
+        storage.getResult(0));
+
+    // Load from the async value storage.
+    auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
+
+    // Cast from LLVM type to the expected value type. This cast will become
+    // a no-op after lowering to LLVM.
+    return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded);
+  }
+};
+
+/// Lowering for `async.await_all` operation.
 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
 
 public:
   explicit AwaitAllOpLowering(
-      MLIRContext *ctx,
+      TypeConverter &converter, MLIRContext *ctx,
       const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
-      : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
+      : Base(converter, ctx, outlinedFunctions, kAwaitGroup,
+             kAwaitAllAndExecute) {}
 };
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// async.yield op lowerings to the corresponding async runtime function calls.
+//===----------------------------------------------------------------------===//
+
+class YieldOpLowering : public ConversionPattern {
+public:
+  explicit YieldOpLowering(
+      TypeConverter &converter, MLIRContext *ctx,
+      const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+      : ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
+                          ctx),
+        outlinedFunctions(outlinedFunctions) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Fail if the yield operation is not inside an outlined coroutine function.
+    auto func = op->template getParentOfType<FuncOp>();
+    auto outlined = outlinedFunctions.find(func);
+    if (outlined == outlinedFunctions.end())
+      return op->emitOpError(
+          "async.yield is not inside the outlined coroutine function");
+
+    Location loc = op->getLoc();
+    const CoroMachinery &coro = outlined->getSecond();
+
+    // Store each yielded value into its async value storage and emplace it.
+    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+
+    for (auto tuple : llvm::zip(operands, coro.returnValues)) {
+      // Store `yieldValue` into the `asyncValue` storage.
+      Value yieldValue = std::get<0>(tuple);
+      Value asyncValue = std::get<1>(tuple);
+
+      // Get an opaque i8* pointer to an async value storage from the runtime.
+      auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
+                                             TypeRange(i8Ptr), asyncValue);
+
+      // Cast storage pointer to a pointer to the yielded value type.
+      auto castedStorage = rewriter.create<LLVM::BitcastOp>(
+          loc,
+          LLVM::LLVMPointerType::get(
+              yieldValue.getType().cast<LLVM::LLVMType>()),
+          storage.getResult(0));
+
+      // Store the yielded value into the async value storage.
+      rewriter.create<LLVM::StoreOp>(loc, yieldValue,
+                                     castedStorage.getResult());
+
+      // Emplace the `async.value` to mark it ready.
+      rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
+    }
+
+    // Emplace the completion token to mark it ready.
+    rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
+
+    // Original operation was replaced by the function call(s).
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+
+private:
+  const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+};
+
 //===----------------------------------------------------------------------===//
 
 namespace {
@@ -818,15 +1042,38 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   ModuleOp module = getOperation();
   SymbolTable symbolTable(module);
 
+  MLIRContext *ctx = &getContext();
+
   // Outline all `async.execute` body regions into async functions (coroutines).
   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
 
+  // We use conversion to LLVM type to ensure that all `async.value` operands
+  // and results can be lowered to LLVM load and store operations.
+  LLVMTypeConverter llvmConverter(ctx);
+  llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
+
+  // Returns true if the `async.value` payload is convertible to LLVM.
+  auto isConvertibleToLlvm = [&](Type type) -> bool {
+    auto valueType = type.cast<ValueType>().getValueType();
+    return static_cast<bool>(llvmConverter.convertType(valueType));
+  };
+
   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
+    // All operands and results must be convertible to LLVM.
+    if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
+      execute.emitOpError("operands payload must be convertible to LLVM type");
+      return WalkResult::interrupt();
+    }
+    if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
+      execute.emitOpError("results payload must be convertible to LLVM type");
+      return WalkResult::interrupt();
+    }
+
     // We currently do not support execute operations that have async value
     // operands or produce async results.
-    if (!execute.operands().empty() || !execute.results().empty()) {
-      execute.emitOpError("can't outline async.execute op with async value "
-                          "operands or returned async results");
+    if (!execute.operands().empty()) {
+      execute.emitOpError(
+          "can't outline async.execute op with async value operands");
       return WalkResult::interrupt();
     }
 
@@ -852,26 +1099,44 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   addCoroutineIntrinsicsDeclarations(module);
   addCRuntimeDeclarations(module);
 
-  MLIRContext *ctx = &getContext();
-
   // Convert async dialect types and operations to LLVM dialect.
   AsyncRuntimeTypeConverter converter;
   OwningRewritePatternList patterns;
 
+  // Convert async types in function signatures and function calls.
   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
-  patterns.insert<CallOpOpConversion>(ctx);
-  patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
-  patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
-  patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
+  populateCallOpTypeConversionPattern(patterns, ctx, converter);
+
+  // Convert return operations inside async.execute regions.
+  patterns.insert<ReturnOpOpConversion>(converter, ctx);
+
+  // Lower async operations to async runtime API calls.
+  patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
+  patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);
+
+  // Use LLVM type converter to automatically convert between the async value
+  // payload type and LLVM type when loading/storing from/to the async
+  // value storage which is an opaque i8* pointer using LLVM load/store ops.
+  patterns
+      .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+          llvmConverter, ctx, outlinedFunctions);
+  patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);
 
   ConversionTarget target(*ctx);
   target.addLegalOp<ConstantOp>();
   target.addLegalDialect<LLVM::LLVMDialect>();
+
+  // All operations from Async dialect must be lowered to the runtime API calls.
   target.addIllegalDialect<AsyncDialect>();
+
+  // Add dynamic legality constraints to apply conversions defined above.
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  target.addDynamicallyLegalOp<CallOp>(
-      [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
+  target.addDynamicallyLegalOp<ReturnOp>(
+      [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
+  target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
+    return converter.isSignatureLegal(op.getCalleeType());
+  });
 
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
index 005712c6cb48..48c2e0155ae1 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
@@ -13,5 +13,7 @@ add_mlir_conversion_library(MLIRAsyncToLLVM
   LINK_LIBS PUBLIC
   MLIRAsync
   MLIRLLVMIR
+  MLIRStandardOpsTransforms
+  MLIRStandardToLLVM
   MLIRTransforms
   )

diff  --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 3bfed86aa996..45bdcb3733b8 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -114,6 +114,7 @@ static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
   return runtime.get();
 }
 
+// Async token provides a mechanism to signal asynchronous operation completion.
 struct AsyncToken : public RefCounted {
   // AsyncToken created with a reference count of 2 because it will be returned
   // to the `async.execute` caller and also will be later on emplaced by the
@@ -130,6 +131,28 @@ struct AsyncToken : public RefCounted {
   std::vector<std::function<void()>> awaiters;
 };
 
+// Async value provides a mechanism to access the result of asynchronous
+// operations. It owns the storage that is used to store/load the value of the
+// underlying type, and a flag to signal if the value is ready or not.
+struct AsyncValue : public RefCounted {
+  // Like an AsyncToken, an AsyncValue is created with a reference count of 2.
+  AsyncValue(AsyncRuntime *runtime, int32_t size)
+      : RefCounted(runtime, /*count=*/2), storage(size) {}
+
+  // Internal state below guarded by a mutex.
+  std::mutex mu;
+  std::condition_variable cv;
+
+  bool ready = false;
+  std::vector<std::function<void()>> awaiters;
+
+  // Byte buffer of `size` elements that holds the async value payload.
+  std::vector<int8_t> storage;
+};
+
+// Async group provides a mechanism to group together multiple async tokens or
+// values to await on all of them together (wait for the completion of all
+// tokens or values added to the group).
 struct AsyncGroup : public RefCounted {
   AsyncGroup(AsyncRuntime *runtime)
       : RefCounted(runtime), pendingTokens(0), rank(0) {}
@@ -159,12 +182,18 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
   refCounted->dropRef(count);
 }
 
-// Create a new `async.token` in not-ready state.
+// Creates a new `async.token` in not-ready state.
 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
   return token;
 }
 
+// Creates a new `async.value` in not-ready state with `size` bytes of storage.
+extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
+  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
+  return value;
+}
+
 // Create a new `async.group` in empty state.
 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
@@ -228,18 +257,45 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
   token->dropRef();
 }
 
+// Switches `async.value` to ready state and runs all awaiters.
+extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
+  // Scope the lock so that `dropRef` cannot destroy the mutex while it is held.
+  {
+    std::unique_lock<std::mutex> lock(value->mu);
+    value->ready = true;
+    value->cv.notify_all();
+    for (auto &awaiter : value->awaiters)
+      awaiter();
+  }
+
+  // Async values are created with a ref count of `2` to keep the value alive
+  // until the async task completes. Drop that reference once it is emplaced.
+  value->dropRef();
+}
+
 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
   std::unique_lock<std::mutex> lock(token->mu);
   if (!token->ready)
     token->cv.wait(lock, [token] { return token->ready; });
 }
 
+extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
+  std::unique_lock<std::mutex> lock(value->mu);
+  if (!value->ready)
+    value->cv.wait(lock, [value] { return value->ready; });
+}
+
 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
   std::unique_lock<std::mutex> lock(group->mu);
   if (group->pendingTokens != 0)
     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
 }
 
+// Returns a pointer to the storage owned by the async value.
+extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
+  return value->storage.data();
+}
+
 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
   (*resume)(handle);
 }
@@ -255,6 +311,17 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
     token->awaiters.push_back([execute]() { execute(); });
 }
 
+extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
+                                                     CoroHandle handle,
+                                                     CoroResume resume) {
+  std::unique_lock<std::mutex> lock(value->mu);
+  auto execute = [handle, resume]() { (*resume)(handle); };
+  if (value->ready)
+    execute();
+  else
+    value->awaiters.push_back([execute]() { execute(); });
+}
+
 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                           CoroHandle handle,
                                                           CoroResume resume) {

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index dadb28dbc082..dce0cf89628f 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -211,3 +211,44 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
 
 // Emplace result token.
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
+
+// -----
+
+// CHECK-LABEL: execute_and_return_f32
+func @execute_and_return_f32() -> f32 {
+ // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
+  %token, %result = async.execute -> !async.value<f32> {
+    %c0 = constant 123.0 : f32
+    async.yield %c0 : f32
+  }
+
+  // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
+  // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+  // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] :  !llvm.ptr<float>
+  // CHECK: %[[CASTED:.*]] = llvm.mlir.cast %[[LOADED]] : !llvm.float to f32
+  %0 = async.await %result : !async.value<f32>
+
+  return %0 : f32
+}
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn()
+// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Emplace result value.
+// CHECK: %[[CST:.*]] = constant 1.230000e+02 : f32
+// CHECK: %[[LLVM_CST:.*]] = llvm.mlir.cast %[[CST]] : f32 to !llvm.float
+// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+// CHECK: llvm.store %[[LLVM_CST]], %[[ST_F32]] : !llvm.ptr<float>
+// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
+
+// Emplace result token.
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])
+

diff  --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir
new file mode 100644
index 000000000000..44b3b29e8491
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-value.mlir
@@ -0,0 +1,64 @@
+// RUN:   mlir-opt %s -async-ref-counting                                      \
+// RUN:               -convert-async-to-llvm                                   \
+// RUN:               -convert-vector-to-llvm                                  \
+// RUN:               -convert-std-to-llvm                                     \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext  \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext    \
+// RUN:     -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext   \
+// RUN: | FileCheck %s --dump-input=always
+
+func @main() {
+
+  // ------------------------------------------------------------------------ //
+  // Blocking async.await outside of the async.execute.
+  // ------------------------------------------------------------------------ //
+  %token, %result = async.execute -> !async.value<f32> {
+    %0 = constant 123.456 : f32
+    async.yield %0 : f32
+  }
+  %1 = async.await %result : !async.value<f32>
+
+  // CHECK: 123.456
+  vector.print %1 : f32
+
+  // ------------------------------------------------------------------------ //
+  // Non-blocking async.await inside the async.execute
+  // ------------------------------------------------------------------------ //
+  %token0, %result0 = async.execute -> !async.value<f32> {
+    %token1, %result2 = async.execute -> !async.value<f32> {
+      %2 = constant 456.789 : f32
+      async.yield %2 : f32
+    }
+    %3 = async.await %result2 : !async.value<f32>
+    async.yield %3 : f32
+  }
+  %4 = async.await %result0 : !async.value<f32>
+
+  // CHECK: 456.789
+  vector.print %4 : f32
+
+  // ------------------------------------------------------------------------ //
+  // Memref allocated inside async.execute region.
+  // ------------------------------------------------------------------------ //
+  %token2, %result2 = async.execute[%token0] -> !async.value<memref<f32>> {
+    %5 = alloc() : memref<f32>
+    %c0 = constant 987.654 : f32
+    store %c0, %5[]: memref<f32>
+    async.yield %5 : memref<f32>
+  }
+  %6 = async.await %result2 : !async.value<memref<f32>>
+  %7 = memref_cast %6 :  memref<f32> to memref<*xf32>
+
+  // CHECK: Unranked Memref
+  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+  // CHECK-NEXT: [987.654]
+  call @print_memref_f32(%7): (memref<*xf32>) -> ()
+  dealloc %6 : memref<f32>
+
+  return
+}
+
+func private @print_memref_f32(memref<*xf32>)
+  attributes { llvm.emit_c_interface }


        


More information about the llvm-branch-commits mailing list