[llvm-branch-commits] [mlir] 621ad46 - [mlir] Async: lowering async.value to LLVM
Eugene Zhulenev via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Dec 25 02:28:39 PST 2020
Author: Eugene Zhulenev
Date: 2020-12-25T02:23:48-08:00
New Revision: 621ad468d99d4013a4298465f02707a5e9e89cae
URL: https://github.com/llvm/llvm-project/commit/621ad468d99d4013a4298465f02707a5e9e89cae
DIFF: https://github.com/llvm/llvm-project/commit/621ad468d99d4013a4298465f02707a5e9e89cae.diff
LOG: [mlir] Async: lowering async.value to LLVM
1. Add new methods to the Async runtime API to support yielding async values
2. Add lowering from `async.yield` with a value payload to the new runtime API calls
`async.value` lowering requires that the payload type be convertible to LLVM and supported by the `llvm.mlir.cast` (DialectCast) operation.
Reviewed By: csigg
Differential Revision: https://reviews.llvm.org/D93592
Added:
mlir/test/mlir-cpu-runner/async-value.mlir
Modified:
mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index e3d90198f36c..0fe44cd1c127 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -45,6 +45,12 @@ typedef struct AsyncToken AsyncToken;
// Runtime implementation of `async.group` data type.
typedef struct AsyncGroup AsyncGroup;
+// Runtime implementation of `async.value` data type.
+typedef struct AsyncValue AsyncValue;
+
+// Async value payload stored in memory owned by the async.value.
+using ValueStorage = void *;
+
// Async runtime uses LLVM coroutines to represent asynchronous tasks. Task
// function is a coroutine handle and a resume function that continue coroutine
// execution from a suspension point.
@@ -66,6 +72,13 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void
// Create a new `async.token` in not-ready state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken();
+// Create a new `async.value` in not-ready state. Size parameter specifies the
+// number of bytes that will be allocated for the async value storage. Storage
+// is owned by the `async.value` and deallocated when the async value is
+// destructed (reference count drops to zero).
+extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncValue *
+ mlirAsyncRuntimeCreateValue(int32_t);
+
// Create a new `async.group` in empty state.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup();
@@ -76,14 +89,26 @@ mlirAsyncRuntimeAddTokenToGroup(AsyncToken *, AsyncGroup *);
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeEmplaceToken(AsyncToken *);
+// Switches `async.value` to ready state and runs all awaiters.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeEmplaceValue(AsyncValue *);
+
// Blocks the caller thread until the token becomes ready.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitToken(AsyncToken *);
+// Blocks the caller thread until the value becomes ready.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitValue(AsyncValue *);
+
// Blocks the caller thread until the elements in the group become ready.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *);
+// Returns a pointer to the storage owned by the async value.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT ValueStorage
+mlirAsyncRuntimeGetValueStorage(AsyncValue *);
+
// Executes the task (coro handle + resume function) in one of the threads
// managed by the runtime.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
@@ -94,6 +119,11 @@ extern "C" MLIR_ASYNCRUNTIME_EXPORT void mlirAsyncRuntimeExecute(CoroHandle,
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *, CoroHandle, CoroResume);
+// Executes the task (coro handle + resume function) in one of the threads
+// managed by the runtime after the value becomes ready.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle, CoroResume);
+
// Executes the task (coro handle + resume function) in one of the threads
// managed by the runtime after the all members of the group become ready.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 3daa70b0a952..f1d6264606be 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -9,9 +9,11 @@
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
@@ -36,23 +38,39 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
+static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
+static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
+static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
+static constexpr const char *kGetValueStorage =
+ "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
"mlirAsyncRuntimeAddTokenToGroup";
-static constexpr const char *kAwaitAndExecute =
+static constexpr const char *kAwaitTokenAndExecute =
"mlirAsyncRuntimeAwaitTokenAndExecute";
+static constexpr const char *kAwaitValueAndExecute =
+ "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
"mlirAsyncRuntimeAwaitAllInGroupAndExecute";
namespace {
-// Async Runtime API function types.
+/// Async Runtime API function types.
+///
+/// Because we can't create API function signature for type parametrized
+/// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
+/// lowering all async data types become opaque pointers at runtime.
struct AsyncAPI {
+ // All async types are lowered to opaque i8* LLVM pointers at runtime.
+ static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
+ return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ }
+
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
- auto ref = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ auto ref = opaquePointerType(ctx);
auto count = IntegerType::get(ctx, 32);
return FunctionType::get(ctx, {ref, count}, {});
}
@@ -61,24 +79,46 @@ struct AsyncAPI {
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
}
+ static FunctionType createValueFunctionType(MLIRContext *ctx) {
+ auto i32 = IntegerType::get(ctx, 32);
+ auto value = opaquePointerType(ctx);
+ return FunctionType::get(ctx, {i32}, {value});
+ }
+
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
}
+ static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
+ auto value = opaquePointerType(ctx);
+ auto storage = opaquePointerType(ctx);
+ return FunctionType::get(ctx, {value}, {storage});
+ }
+
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
+ static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
+ auto value = opaquePointerType(ctx);
+ return FunctionType::get(ctx, {value}, {});
+ }
+
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
+ static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
+ auto value = opaquePointerType(ctx);
+ return FunctionType::get(ctx, {value}, {});
+ }
+
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
}
static FunctionType executeFunctionType(MLIRContext *ctx) {
- auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ auto hdl = opaquePointerType(ctx);
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {hdl, resume}, {});
}
@@ -89,14 +129,21 @@ struct AsyncAPI {
{i64});
}
- static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
- auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
+ auto hdl = opaquePointerType(ctx);
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
}
+ static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
+ auto value = opaquePointerType(ctx);
+ auto hdl = opaquePointerType(ctx);
+ auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
+ return FunctionType::get(ctx, {value, hdl, resume}, {});
+ }
+
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
- auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ auto hdl = opaquePointerType(ctx);
auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
@@ -104,13 +151,13 @@ struct AsyncAPI {
// Auxiliary coroutine resume intrinsic wrapper.
static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
auto voidTy = LLVM::LLVMVoidType::get(ctx);
- auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
+ auto i8Ptr = opaquePointerType(ctx);
return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
}
};
} // namespace
-// Adds Async Runtime C API declarations to the module.
+/// Adds Async Runtime C API declarations to the module.
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
module.getBody());
@@ -125,13 +172,20 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
+ addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
+ addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
+ addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
+ addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
- addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
+ addFuncDecl(kAwaitTokenAndExecute,
+ AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
+ addFuncDecl(kAwaitValueAndExecute,
+ AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
addFuncDecl(kAwaitAllAndExecute,
AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}
@@ -215,9 +269,9 @@ static void addCRuntimeDeclarations(ModuleOp module) {
static constexpr const char *kResume = "__resume";
-// A function that takes a coroutine handle and calls a `llvm.coro.resume`
-// intrinsics. We need this function to be able to pass it to the async
-// runtime execute API.
+/// A function that takes a coroutine handle and calls a `llvm.coro.resume`
+/// intrinsic. We need this function to be able to pass it to the async
+/// runtime execute API.
static void addResumeFunction(ModuleOp module) {
MLIRContext *ctx = module.getContext();
@@ -248,49 +302,61 @@ static void addResumeFunction(ModuleOp module) {
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//
-// Function targeted for coroutine transformation has two additional blocks at
-// the end: coroutine cleanup and coroutine suspension.
-//
-// async.await op lowering additionaly creates a resume block for each
-// operation to enable non-blocking waiting via coroutine suspension.
+/// Function targeted for coroutine transformation has two additional blocks at
+/// the end: coroutine cleanup and coroutine suspension.
+///
+/// async.await op lowering additionally creates a resume block for each
+/// operation to enable non-blocking waiting via coroutine suspension.
namespace {
struct CoroMachinery {
- Value asyncToken;
+ // Async execute region returns a completion token, and an async value for
+ // each yielded value.
+ //
+ // %token, %result = async.execute -> !async.value<T> {
+ // %0 = constant ... : T
+ // async.yield %0 : T
+ // }
+ Value asyncToken; // token representing completion of the async region
+ llvm::SmallVector<Value, 4> returnValues; // returned async values
+
Value coroHandle;
Block *cleanup;
Block *suspend;
};
} // namespace
-// Builds an coroutine template compatible with LLVM coroutines lowering.
-//
-// - `entry` block sets up the coroutine.
-// - `cleanup` block cleans up the coroutine state.
-// - `suspend block after the @llvm.coro.end() defines what value will be
-// returned to the initial caller of a coroutine. Everything before the
-// @llvm.coro.end() will be executed at every suspension point.
-//
-// Coroutine structure (only the important bits):
-//
-// func @async_execute_fn(<function-arguments>) -> !async.token {
-// ^entryBlock(<function-arguments>):
-// %token = <async token> : !async.token // create async runtime token
-// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
-// br ^cleanup
-//
-// ^cleanup:
-// llvm.call @llvm.coro.free(...) // delete coroutine state
-// br ^suspend
-//
-// ^suspend:
-// llvm.call @llvm.coro.end(...) // marks the end of a coroutine
-// return %token : !async.token
-// }
-//
-// The actual code for the async.execute operation body region will be inserted
-// before the entry block terminator.
-//
-//
+/// Builds a coroutine template compatible with LLVM coroutines lowering.
+///
+/// - `entry` block sets up the coroutine.
+/// - `cleanup` block cleans up the coroutine state.
+/// - `suspend` block after the @llvm.coro.end() defines what value will be
+/// returned to the initial caller of a coroutine. Everything before the
+/// @llvm.coro.end() will be executed at every suspension point.
+///
+/// Coroutine structure (only the important bits):
+///
+/// func @async_execute_fn(<function-arguments>)
+/// -> (!async.token, !async.value<T>)
+/// {
+/// ^entryBlock(<function-arguments>):
+/// %token = <async token> : !async.token // create async runtime token
+/// %value = <async value> : !async.value<T> // create async value
+/// %hdl = llvm.call @llvm.coro.id(...) // create a coroutine handle
+/// br ^cleanup
+///
+/// ^cleanup:
+/// llvm.call @llvm.coro.free(...) // delete coroutine state
+/// br ^suspend
+///
+/// ^suspend:
+/// llvm.call @llvm.coro.end(...) // marks the end of a coroutine
+/// return %token, %value : !async.token, !async.value<T>
+/// }
+///
+/// The actual code for the async.execute operation body region will be inserted
+/// before the entry block terminator.
+///
+///
static CoroMachinery setupCoroMachinery(FuncOp func) {
assert(func.getBody().empty() && "Function must have empty body");
@@ -312,6 +378,44 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// ------------------------------------------------------------------------ //
auto createToken = builder.create<CallOp>(kCreateToken, TokenType::get(ctx));
+ // Async value operands and results must be convertible to LLVM types. This is
+ // verified before the function outlining.
+ LLVMTypeConverter converter(ctx);
+
+ // Returns the size requirements for the async value storage.
+ // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
+ auto sizeOf = [&](ValueType valueType) -> Value {
+ auto storedType = converter.convertType(valueType.getValueType());
+ auto storagePtrType =
+ LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
+
+ // %Size = getelementptr %T* null, int 1
+ // %SizeI = ptrtoint %T* %Size to i32
+ auto nullPtr = builder.create<LLVM::NullOp>(loc, storagePtrType);
+ auto one = builder.create<LLVM::ConstantOp>(loc, i32,
+ builder.getI32IntegerAttr(1));
+ auto gep = builder.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
+ one.getResult());
+ auto size = builder.create<LLVM::PtrToIntOp>(loc, i32, gep);
+
+ // Cast to std type because runtime API defined using std types.
+ return builder.create<LLVM::DialectCastOp>(loc, builder.getI32Type(),
+ size.getResult());
+ };
+
+ // We use the `async.value` type as a return type although it does not match
+ // the `kCreateValue` function signature, because it will be later lowered to
+ // the runtime type (opaque i8* pointer).
+ llvm::SmallVector<CallOp, 4> createValues;
+ for (auto resultType : func.getCallableResults().drop_front(1))
+ createValues.emplace_back(builder.create<CallOp>(
+ loc, kCreateValue, resultType, sizeOf(resultType.cast<ValueType>())));
+
+ auto createdValues = llvm::map_range(
+ createValues, [](CallOp call) { return call.getResult(0); });
+ llvm::SmallVector<Value, 4> returnValues(createdValues.begin(),
+ createdValues.end());
+
// ------------------------------------------------------------------------ //
// Initialize coroutine: allocate frame, get coroutine handle.
// ------------------------------------------------------------------------ //
@@ -371,9 +475,11 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
builder.create<LLVM::CallOp>(i1, builder.getSymbolRefAttr(kCoroEnd),
ValueRange({coroHdl.getResult(0), constFalse}));
- // Return created `async.token` from the suspend block. This will be the
- // return value of a coroutine ramp function.
- builder.create<ReturnOp>(createToken.getResult(0));
+ // Return created `async.token` and `async.values` from the suspend block.
+ // This will be the return value of a coroutine ramp function.
+ SmallVector<Value, 4> ret{createToken.getResult(0)};
+ ret.insert(ret.end(), returnValues.begin(), returnValues.end());
+ builder.create<ReturnOp>(loc, ret);
// Branch from the entry block to the cleanup block to create a valid CFG.
builder.setInsertionPointToEnd(entryBlock);
@@ -383,39 +489,44 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// `async.await` op lowering will create resume blocks for async
// continuations, and will conditionally branch to cleanup or suspend blocks.
- return {createToken.getResult(0), coroHdl.getResult(0), cleanupBlock,
- suspendBlock};
+ CoroMachinery machinery;
+ machinery.asyncToken = createToken.getResult(0);
+ machinery.returnValues = returnValues;
+ machinery.coroHandle = coroHdl.getResult(0);
+ machinery.cleanup = cleanupBlock;
+ machinery.suspend = suspendBlock;
+ return machinery;
}
-// Add a LLVM coroutine suspension point to the end of suspended block, to
-// resume execution in resume block. The caller is responsible for creating the
-// two suspended/resume blocks with the desired ops contained in each block.
-// This function merely provides the required control flow logic.
-//
-// `coroState` must be a value returned from the call to @llvm.coro.save(...)
-// intrinsic (saved coroutine state).
-//
-// Before:
-//
-// ^bb0:
-// "opBefore"(...)
-// "op"(...)
-// ^cleanup: ...
-// ^suspend: ...
-// ^resume:
-// "op"(...)
-//
-// After:
-//
-// ^bb0:
-// "opBefore"(...)
-// %suspend = llmv.call @llvm.coro.suspend(...)
-// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
-// ^resume:
-// "op"(...)
-// ^cleanup: ...
-// ^suspend: ...
-//
+/// Add a LLVM coroutine suspension point to the end of suspended block, to
+/// resume execution in resume block. The caller is responsible for creating the
+/// two suspended/resume blocks with the desired ops contained in each block.
+/// This function merely provides the required control flow logic.
+///
+/// `coroState` must be a value returned from the call to @llvm.coro.save(...)
+/// intrinsic (saved coroutine state).
+///
+/// Before:
+///
+/// ^bb0:
+/// "opBefore"(...)
+/// "op"(...)
+/// ^cleanup: ...
+/// ^suspend: ...
+/// ^resume:
+/// "op"(...)
+///
+/// After:
+///
+/// ^bb0:
+/// "opBefore"(...)
+/// %suspend = llvm.call @llvm.coro.suspend(...)
+/// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
+/// ^resume:
+/// "op"(...)
+/// ^cleanup: ...
+/// ^suspend: ...
+///
static void addSuspensionPoint(CoroMachinery coro, Value coroState,
Operation *op, Block *suspended, Block *resume,
OpBuilder &builder) {
@@ -461,10 +572,10 @@ static void addSuspensionPoint(CoroMachinery coro, Value coroState,
/*falseDest=*/coro.cleanup);
}
-// Outline the body region attached to the `async.execute` op into a standalone
-// function.
-//
-// Note that this is not reversible transformation.
+/// Outline the body region attached to the `async.execute` op into a standalone
+/// function.
+///
+/// Note that this is not a reversible transformation.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
ModuleOp module = execute->getParentOfType<ModuleOp>();
@@ -475,6 +586,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
// Collect all outlined function inputs.
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
execute.dependencies().end());
+ assert(execute.operands().empty() && "operands are not supported");
getUsedValuesDefinedAbove(execute.body(), functionInputs);
// Collect types for the outlined function inputs and outputs.
@@ -535,15 +647,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
valueMapping.map(functionInputs, func.getArguments());
// Clone all operations from the execute operation body into the outlined
- // function body, and replace all `async.yield` operations with a call
- // to async runtime to emplace the result token.
- for (Operation &op : execute.body().getOps()) {
- if (isa<async::YieldOp>(op)) {
- builder.create<CallOp>(kEmplaceToken, TypeRange(), coro.asyncToken);
- continue;
- }
+ // function body.
+ for (Operation &op : execute.body().getOps())
builder.clone(op, valueMapping);
- }
// Replace the original `async.execute` with a call to outlined function.
ImplicitLocOpBuilder callBuilder(loc, execute);
@@ -560,42 +666,38 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
//===----------------------------------------------------------------------===//
namespace {
+
+/// AsyncRuntimeTypeConverter only converts types from the Async dialect to
+/// their runtime type (opaque pointers) and does not convert any other types.
class AsyncRuntimeTypeConverter : public TypeConverter {
public:
- AsyncRuntimeTypeConverter() { addConversion(convertType); }
-
- static Type convertType(Type type) {
- MLIRContext *ctx = type.getContext();
- // Convert async tokens and groups to opaque pointers.
- if (type.isa<TokenType, GroupType>())
- return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8));
- return type;
+ AsyncRuntimeTypeConverter() {
+ addConversion([](Type type) { return type; });
+ addConversion(convertAsyncTypes);
+ }
+
+ static Optional<Type> convertAsyncTypes(Type type) {
+ if (type.isa<TokenType, GroupType, ValueType>())
+ return AsyncAPI::opaquePointerType(type.getContext());
+ return llvm::None;
}
};
} // namespace
//===----------------------------------------------------------------------===//
-// Convert types for all call operations to lowered async types.
+// Convert return operations that return async values from async regions.
//===----------------------------------------------------------------------===//
namespace {
-class CallOpOpConversion : public ConversionPattern {
+class ReturnOpOpConversion : public ConversionPattern {
public:
- explicit CallOpOpConversion(MLIRContext *ctx)
- : ConversionPattern(CallOp::getOperationName(), 1, ctx) {}
+ explicit ReturnOpOpConversion(TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(ReturnOp::getOperationName(), 1, converter, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- AsyncRuntimeTypeConverter converter;
-
- SmallVector<Type, 5> resultTypes;
- converter.convertTypes(op->getResultTypes(), resultTypes);
-
- CallOp call = cast<CallOp>(op);
- rewriter.replaceOpWithNewOp<CallOp>(op, resultTypes, call.callee(),
- operands);
-
+ rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
return success();
}
};
@@ -611,8 +713,9 @@ namespace {
template <typename RefCountingOp>
class RefCountingOpLowering : public ConversionPattern {
public:
- explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
- : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
+ explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
+ StringRef apiFunctionName)
+ : ConversionPattern(RefCountingOp::getOperationName(), 1, converter, ctx),
apiFunctionName(apiFunctionName) {}
LogicalResult
@@ -634,18 +737,18 @@ class RefCountingOpLowering : public ConversionPattern {
StringRef apiFunctionName;
};
-// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
+/// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
public:
- explicit AddRefOpLowering(MLIRContext *ctx)
- : RefCountingOpLowering(ctx, kAddRef) {}
+ explicit AddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ : RefCountingOpLowering(converter, ctx, kAddRef) {}
};
-// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
+/// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
public:
- explicit DropRefOpLowering(MLIRContext *ctx)
- : RefCountingOpLowering(ctx, kDropRef) {}
+ explicit DropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ : RefCountingOpLowering(converter, ctx, kDropRef) {}
};
} // namespace
@@ -657,8 +760,9 @@ class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
namespace {
class CreateGroupOpLowering : public ConversionPattern {
public:
- explicit CreateGroupOpLowering(MLIRContext *ctx)
- : ConversionPattern(CreateGroupOp::getOperationName(), 1, ctx) {}
+ explicit CreateGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(CreateGroupOp::getOperationName(), 1, converter,
+ ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -677,8 +781,9 @@ class CreateGroupOpLowering : public ConversionPattern {
namespace {
class AddToGroupOpLowering : public ConversionPattern {
public:
- explicit AddToGroupOpLowering(MLIRContext *ctx)
- : ConversionPattern(AddToGroupOp::getOperationName(), 1, ctx) {}
+ explicit AddToGroupOpLowering(TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(AddToGroupOp::getOperationName(), 1, converter, ctx) {
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -706,10 +811,10 @@ template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public ConversionPattern {
protected:
explicit AwaitOpLoweringBase(
- MLIRContext *ctx,
+ TypeConverter &converter, MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
StringRef blockingAwaitFuncName, StringRef coroAwaitFuncName)
- : ConversionPattern(AwaitType::getOperationName(), 1, ctx),
+ : ConversionPattern(AwaitType::getOperationName(), 1, converter, ctx),
outlinedFunctions(outlinedFunctions),
blockingAwaitFuncName(blockingAwaitFuncName),
coroAwaitFuncName(coroAwaitFuncName) {}
@@ -719,7 +824,7 @@ class AwaitOpLoweringBase : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// We can only await on one the `AwaitableType` (for `await` it can be
- // only a `token`, for `await_all` it is a `group`).
+ // a `token` or a `value`, for `await_all` it must be a `group`).
auto await = cast<AwaitType>(op);
if (!await.operand().getType().template isa<AwaitableType>())
return failure();
@@ -768,44 +873,163 @@ class AwaitOpLoweringBase : public ConversionPattern {
Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
builder);
+
+ // Make sure that replacement value will be constructed in resume block.
+ rewriter.setInsertionPointToStart(resume);
}
- // Original operation was replaced by function call or suspension point.
- rewriter.eraseOp(op);
+ // Replace or erase the await operation with the new value.
+ if (Value replaceWith = getReplacementValue(op, operands[0], rewriter))
+ rewriter.replaceOp(op, replaceWith);
+ else
+ rewriter.eraseOp(op);
return success();
}
+ virtual Value getReplacementValue(Operation *op, Value operand,
+ ConversionPatternRewriter &rewriter) const {
+ return Value();
+ }
+
private:
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
StringRef blockingAwaitFuncName;
StringRef coroAwaitFuncName;
};
-// Lowering for `async.await` operation (only token operands are supported).
-class AwaitOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
+/// Lowering for `async.await` with a token operand.
+class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
public:
- explicit AwaitOpLowering(
- MLIRContext *ctx,
+ explicit AwaitTokenOpLowering(
+ TypeConverter &converter, MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
- : Base(ctx, outlinedFunctions, kAwaitToken, kAwaitAndExecute) {}
+ : Base(converter, ctx, outlinedFunctions, kAwaitToken,
+ kAwaitTokenAndExecute) {}
};
-// Lowering for `async.await_all` operation.
+/// Lowering for `async.await` with a value operand.
+class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
+ using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
+
+public:
+ explicit AwaitValueOpLowering(
+ TypeConverter &converter, MLIRContext *ctx,
+ const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+ : Base(converter, ctx, outlinedFunctions, kAwaitValue,
+ kAwaitValueAndExecute) {}
+
+ Value
+ getReplacementValue(Operation *op, Value operand,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+
+ // Get the underlying value type from the `async.value`.
+ auto await = cast<AwaitOp>(op);
+ auto valueType = await.operand().getType().cast<ValueType>().getValueType();
+
+ // Get a pointer to an async value storage from the runtime.
+ auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
+ TypeRange(i8Ptr), operand);
+
+ // Cast from i8* to a pointer to the converted LLVM value type.
+ auto llvmValueType = getTypeConverter()->convertType(valueType);
+ auto castedStorage = rewriter.create<LLVM::BitcastOp>(
+ loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
+ storage.getResult(0));
+
+ // Load from the async value storage.
+ auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
+
+ // Cast from LLVM type to the expected value type. This cast will become
+ // no-op after lowering to LLVM.
+ return rewriter.create<LLVM::DialectCastOp>(loc, valueType, loaded);
+ }
+};
+
+/// Lowering for `async.await_all` operation (awaits an `async.group` operand).
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
public:
explicit AwaitAllOpLowering(
- MLIRContext *ctx,
+ TypeConverter &converter, MLIRContext *ctx,
const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
- : Base(ctx, outlinedFunctions, kAwaitGroup, kAwaitAllAndExecute) {}
+ : Base(converter, ctx, outlinedFunctions, kAwaitGroup,
+ kAwaitAllAndExecute) {}
};
} // namespace
+//===----------------------------------------------------------------------===//
+// async.yield lowering (inside outlined coroutines) to async runtime calls.
+//===----------------------------------------------------------------------===//
+
+class YieldOpLowering : public ConversionPattern {
+public:
+ explicit YieldOpLowering(
+ TypeConverter &converter, MLIRContext *ctx,
+ const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions)
+ : ConversionPattern(async::YieldOp::getOperationName(), 1, converter,
+ ctx),
+ outlinedFunctions(outlinedFunctions) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check if yield operation is inside the outlined coroutine function.
+ auto func = op->template getParentOfType<FuncOp>();
+ auto outlined = outlinedFunctions.find(func);
+ if (outlined == outlinedFunctions.end())
+ return op->emitOpError(
+ "async.yield is not inside the outlined coroutine function");
+
+ Location loc = op->getLoc();
+ const CoroMachinery &coro = outlined->getSecond();
+
+ // Store each yielded value into its async value storage and emplace it.
+ auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
+
+ for (auto tuple : llvm::zip(operands, coro.returnValues)) {
+ // Store `yieldValue` into the `asyncValue` storage.
+ Value yieldValue = std::get<0>(tuple);
+ Value asyncValue = std::get<1>(tuple);
+
+ // Get an opaque i8* pointer to an async value storage from the runtime.
+ auto storage = rewriter.create<CallOp>(loc, kGetValueStorage,
+ TypeRange(i8Ptr), asyncValue);
+
+ // Cast storage pointer to the yielded value type.
+ auto castedStorage = rewriter.create<LLVM::BitcastOp>(
+ loc,
+ LLVM::LLVMPointerType::get(
+ yieldValue.getType().cast<LLVM::LLVMType>()),
+ storage.getResult(0));
+
+ // Store the yielded value into the async value storage.
+ rewriter.create<LLVM::StoreOp>(loc, yieldValue,
+ castedStorage.getResult());
+
+ // Emplace the `async.value` to mark it ready.
+ rewriter.create<CallOp>(loc, kEmplaceValue, TypeRange(), asyncValue);
+ }
+
+ // Emplace the completion token to mark it ready.
+ rewriter.create<CallOp>(loc, kEmplaceToken, TypeRange(), coro.asyncToken);
+
+ // The runtime calls emitted above replace the yield, so it can be erased.
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+
+private:
+ const llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions;
+};
+
//===----------------------------------------------------------------------===//
namespace {
@@ -818,15 +1042,38 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
ModuleOp module = getOperation();
SymbolTable symbolTable(module);
+ MLIRContext *ctx = &getContext();
+
// Outline all `async.execute` body regions into async functions (coroutines).
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
+ // We use conversion to LLVM type to ensure that all `async.value` operands
+ // and results can be lowered to LLVM load and store operations.
+ LLVMTypeConverter llvmConverter(ctx);
+ llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
+
+ // Returns true if the `async.value` payload is convertible to LLVM.
+ auto isConvertibleToLlvm = [&](Type type) -> bool {
+ auto valueType = type.cast<ValueType>().getValueType();
+ return static_cast<bool>(llvmConverter.convertType(valueType));
+ };
+
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
+ // All operands and results must be convertible to LLVM.
+ if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) {
+ execute.emitOpError("operands payload must be convertible to LLVM type");
+ return WalkResult::interrupt();
+ }
+ if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) {
+ execute.emitOpError("results payload must be convertible to LLVM type");
+ return WalkResult::interrupt();
+ }
+
// We currently do not support execute operations that have async value
// operands or produce async results.
- if (!execute.operands().empty() || !execute.results().empty()) {
- execute.emitOpError("can't outline async.execute op with async value "
- "operands or returned async results");
+ if (!execute.operands().empty()) {
+ execute.emitOpError(
+ "can't outline async.execute op with async value operands");
return WalkResult::interrupt();
}
@@ -852,26 +1099,44 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
addCoroutineIntrinsicsDeclarations(module);
addCRuntimeDeclarations(module);
- MLIRContext *ctx = &getContext();
-
// Convert async dialect types and operations to LLVM dialect.
AsyncRuntimeTypeConverter converter;
OwningRewritePatternList patterns;
+ // Convert async types in function signatures and function calls.
populateFuncOpTypeConversionPattern(patterns, ctx, converter);
- patterns.insert<CallOpOpConversion>(ctx);
- patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
- patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
- patterns.insert<AwaitOpLowering, AwaitAllOpLowering>(ctx, outlinedFunctions);
+ populateCallOpTypeConversionPattern(patterns, ctx, converter);
+
+ // Convert return operations inside async.execute regions.
+ patterns.insert<ReturnOpOpConversion>(converter, ctx);
+
+ // Lower async operations to async runtime API calls.
+ patterns.insert<AddRefOpLowering, DropRefOpLowering>(converter, ctx);
+ patterns.insert<CreateGroupOpLowering, AddToGroupOpLowering>(converter, ctx);
+
+ // Use LLVM type converter to automatically convert between the async value
+ // payload type and LLVM type when loading/storing from/to the async
+ // value storage which is an opaque i8* pointer using LLVM load/store ops.
+ patterns
+ .insert<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+ llvmConverter, ctx, outlinedFunctions);
+ patterns.insert<YieldOpLowering>(llvmConverter, ctx, outlinedFunctions);
ConversionTarget target(*ctx);
target.addLegalOp<ConstantOp>();
target.addLegalDialect<LLVM::LLVMDialect>();
+
+ // All operations from Async dialect must be lowered to the runtime API calls.
target.addIllegalDialect<AsyncDialect>();
+
+ // Add dynamic legality constraints to apply conversions defined above.
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
- target.addDynamicallyLegalOp<CallOp>(
- [&](CallOp op) { return converter.isLegal(op.getResultTypes()); });
+ target.addDynamicallyLegalOp<ReturnOp>(
+ [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
+ target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
+ return converter.isSignatureLegal(op.getCalleeType());
+ });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
index 005712c6cb48..48c2e0155ae1 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/AsyncToLLVM/CMakeLists.txt
@@ -13,5 +13,7 @@ add_mlir_conversion_library(MLIRAsyncToLLVM
LINK_LIBS PUBLIC
MLIRAsync
MLIRLLVMIR
+ MLIRStandardOpsTransforms
+ MLIRStandardToLLVM
MLIRTransforms
)
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 3bfed86aa996..45bdcb3733b8 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -114,6 +114,7 @@ static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
return runtime.get();
}
+// Async token provides a mechanism to signal asynchronous operation completion.
struct AsyncToken : public RefCounted {
// AsyncToken created with a reference count of 2 because it will be returned
// to the `async.execute` caller and also will be later on emplaced by the
@@ -130,6 +131,28 @@ struct AsyncToken : public RefCounted {
std::vector<std::function<void()>> awaiters;
};
+// Async value provides a mechanism to access the result of asynchronous
+// operations. It owns the storage that is used to store/load the value of the
+// underlying type, and a flag to signal if the value is ready or not.
+struct AsyncValue : public RefCounted {
+ // Like AsyncToken, created with a ref count of 2; emplacing the value drops one.
+ AsyncValue(AsyncRuntime *runtime, int32_t size)
+ : RefCounted(runtime, /*count=*/2), storage(size) {}
+
+ // Internal state below guarded by a mutex.
+ std::mutex mu;
+ std::condition_variable cv;
+
+ bool ready = false;
+ std::vector<std::function<void()>> awaiters;
+
+ // Payload bytes, sized at construction (see mlirAsyncRuntimeGetValueStorage).
+ std::vector<int8_t> storage;
+};
+
+// Async group provides a mechanism to group together multiple async tokens or
+// values to await on all of them together (wait for the completion of all
+// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
AsyncGroup(AsyncRuntime *runtime)
: RefCounted(runtime), pendingTokens(0), rank(0) {}
@@ -159,12 +182,18 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
refCounted->dropRef(count);
}
-// Create a new `async.token` in not-ready state.
+// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
return token;
}
+// Creates a new `async.value` in not-ready state with `size` bytes of storage.
+extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
+ AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
+ return value;
+}
+
// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
@@ -228,18 +257,45 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
token->dropRef();
}
+// Switches `async.value` to ready state and runs all awaiters.
+extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
+ // Scope the lock so it is released before `dropRef` can destroy `value`.
+ {
+ std::unique_lock<std::mutex> lock(value->mu);
+ value->ready = true;
+ value->cv.notify_all();
+ for (auto &awaiter : value->awaiters) // Awaiters run with `mu` held.
+ awaiter();
+ }
+
+ // Async values created with a ref count `2` to keep value alive until the
+ // async task completes. Drop this reference explicitly when value emplaced.
+ value->dropRef();
+}
+
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
std::unique_lock<std::mutex> lock(token->mu);
if (!token->ready)
token->cv.wait(lock, [token] { return token->ready; });
}
+extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
+ std::unique_lock<std::mutex> lock(value->mu);
+ if (!value->ready)
+ value->cv.wait(lock, [value] { return value->ready; }); // Notified by EmplaceValue.
+}
+
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
std::unique_lock<std::mutex> lock(group->mu);
if (group->pendingTokens != 0)
group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}
+// Returns a pointer to the payload storage (`size` bytes) owned by the value.
+extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
+ return value->storage.data();
+}
+
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
(*resume)(handle);
}
@@ -255,6 +311,17 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
token->awaiters.push_back([execute]() { execute(); });
}
+extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
+ CoroHandle handle,
+ CoroResume resume) {
+ std::unique_lock<std::mutex> lock(value->mu);
+ auto execute = [handle, resume]() { (*resume)(handle); };
+ if (value->ready)
+ execute(); // Already emplaced: resume the coroutine inline.
+ else
+ value->awaiters.push_back([execute]() { execute(); }); // Run by EmplaceValue.
+}
+
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
CoroHandle handle,
CoroResume resume) {
diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index dadb28dbc082..dce0cf89628f 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -211,3 +211,44 @@ func @async_group_await_all(%arg0: f32, %arg1: memref<1xf32>) {
// Emplace result token.
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
+
+// -----
+
+// CHECK-LABEL: execute_and_return_f32
+func @execute_and_return_f32() -> f32 {
+ // CHECK: %[[RET:.*]]:2 = call @async_execute_fn
+ %token, %result = async.execute -> !async.value<f32> {
+ %c0 = constant 123.0 : f32
+ async.yield %c0 : f32
+ }
+
+ // CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[RET]]#1)
+ // CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+ // CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr<float>
+ // CHECK: %[[CASTED:.*]] = llvm.mlir.cast %[[LOADED]] : !llvm.float to f32
+ %0 = async.await %result : !async.value<f32>
+
+ return %0 : f32
+}
+
+// Function outlined from the async.execute operation above.
+// CHECK-LABEL: func private @async_execute_fn()
+// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[VALUE:.*]] = call @mlirAsyncRuntimeCreateValue
+// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin
+
+// Suspend the coroutine at the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Emplace result value.
+// CHECK: %[[CST:.*]] = constant 1.230000e+02 : f32
+// CHECK: %[[LLVM_CST:.*]] = llvm.mlir.cast %[[CST]] : f32 to !llvm.float
+// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%[[VALUE]])
+// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]]
+// CHECK: llvm.store %[[LLVM_CST]], %[[ST_F32]] : !llvm.ptr<float>
+// CHECK: call @mlirAsyncRuntimeEmplaceValue(%[[VALUE]])
+
+// Emplace result token.
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]])
+
diff --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir
new file mode 100644
index 000000000000..44b3b29e8491
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-value.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN: -convert-async-to-llvm \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
+// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_async_runtime%shlibext \
+// RUN: | FileCheck %s --dump-input=always
+
+func @main() {
+
+ // ------------------------------------------------------------------------ //
+ // Blocking async.await outside of the async.execute.
+ // ------------------------------------------------------------------------ //
+ %token, %result = async.execute -> !async.value<f32> {
+ %0 = constant 123.456 : f32
+ async.yield %0 : f32
+ }
+ %1 = async.await %result : !async.value<f32>
+
+ // CHECK: 123.456
+ vector.print %1 : f32
+
+ // ------------------------------------------------------------------------ //
+ // Non-blocking async.await inside the async.execute.
+ // ------------------------------------------------------------------------ //
+ %token0, %result0 = async.execute -> !async.value<f32> {
+ %token1, %result2 = async.execute -> !async.value<f32> {
+ %2 = constant 456.789 : f32
+ async.yield %2 : f32
+ }
+ %3 = async.await %result2 : !async.value<f32>
+ async.yield %3 : f32
+ }
+ %4 = async.await %result0 : !async.value<f32>
+
+ // CHECK: 456.789
+ vector.print %4 : f32
+
+ // ------------------------------------------------------------------------ //
+ // Memref allocated inside async.execute region.
+ // ------------------------------------------------------------------------ //
+ %token2, %result2 = async.execute[%token0] -> !async.value<memref<f32>> { // Not the nested-region %result2 above.
+ %5 = alloc() : memref<f32>
+ %c0 = constant 987.654 : f32
+ store %c0, %5[]: memref<f32>
+ async.yield %5 : memref<f32>
+ }
+ %6 = async.await %result2 : !async.value<memref<f32>>
+ %7 = memref_cast %6 : memref<f32> to memref<*xf32>
+
+ // CHECK: Unranked Memref
+ // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+ // CHECK-NEXT: [987.654]
+ call @print_memref_f32(%7): (memref<*xf32>) -> ()
+ dealloc %6 : memref<f32>
+
+ return
+}
+
+func private @print_memref_f32(memref<*xf32>)
+ attributes { llvm.emit_c_interface }
More information about the llvm-branch-commits
mailing list