[Mlir-commits] [mlir] 61dce0f - [mlir] Add async.await operation to async dialect
Eugene Zhulenev
llvmlistbot at llvm.org
Mon Oct 12 21:05:47 PDT 2020
Author: Eugene Zhulenev
Date: 2020-10-12T21:05:36-07:00
New Revision: 61dce0f308e35df1edbd3061af339a3aff8d1f35
URL: https://github.com/llvm/llvm-project/commit/61dce0f308e35df1edbd3061af339a3aff8d1f35
DIFF: https://github.com/llvm/llvm-project/commit/61dce0f308e35df1edbd3061af339a3aff8d1f35.diff
LOG: [mlir] Add async.await operation to async dialect
Add async.await operation to "unwrap" async.values
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D89137
Added:
mlir/test/Dialect/Async/verify.mlir
Modified:
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/test/Dialect/Async/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index fbdbdb92302f..2382253eff17 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -75,7 +75,6 @@ def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
-
let verifier = [{ return ::verify(*this); }];
}
@@ -94,4 +93,47 @@ def Async_YieldOp :
let verifier = [{ return ::verify(*this); }];
}
+// Blocks until the operand (an !async.token or !async.value) is ready. This op
+// is intentionally NOT NoSideEffect: awaiting on a token produces no results,
+// so with that trait the op would be trivially dead and could be erased by
+// DCE/canonicalization, silently dropping the synchronization point.
+def Async_AwaitOp : Async_Op<"await", []> {
+  let summary = "waits for the argument to become ready";
+  let description = [{
+    The `async.await` operation waits until the argument becomes ready. For an
+    `!async.token` argument the operation has no results; for an
+    `!async.value<T>` argument it unwraps and returns the underlying value of
+    type `T`.
+
+    Example:
+
+    ```mlir
+    %0 = ... : !async.token
+    async.await %0 : !async.token
+
+    %1 = ... : !async.value<f32>
+    %2 = async.await %1 : !async.value<f32>
+    ```
+  }];
+
+  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  // At most one result: present only when awaiting on an async value.
+  let results = (outs Optional<AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+
+  // The custom builder derives the result type from the operand type
+  // (implemented as AwaitOp::build in Async.cpp).
+  let builders = [
+    OpBuilder<"mlir::OpBuilder &builder, OperationState &result,"
+              "Value operand, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
+  let extraClassDeclaration = [{
+    // Returns the single result type, or None when the op has no result
+    // (e.g. when awaiting on a token).
+    Optional<Type> getResultType() {
+      if (getResultTypes().empty()) return None;
+      return getResultTypes()[0];
+    }
+  }];
+
+  // Only the operand type is spelled out; the result type is derived from it
+  // by the AwaitResultType custom directive.
+  let assemblyFormat = [{
+    attr-dict $operand `:` custom<AwaitResultType>(
+      type($operand), type($result)
+    )
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
#endif // ASYNC_OPS
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index eb5e65b5dc93..754665d8bafc 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -250,5 +250,54 @@ static LogicalResult verify(ExecuteOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+/// AwaitOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an `async.await` on `operand`, attaching `attrs`. The result list is
+/// derived from the operand: an `!async.value<T>` operand yields one result of
+/// type `T`; any other operand type adds no result.
+void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
+                    ArrayRef<NamedAttribute> attrs) {
+  result.addOperands({operand});
+  result.attributes.append(attrs.begin(), attrs.end());
+
+  Type operandType = operand.getType();
+  ValueType asyncValue = operandType.dyn_cast<ValueType>();
+  if (asyncValue)
+    result.addTypes(asyncValue.getValueType());
+}
+
+/// Parser for the `custom<AwaitResultType>` directive: reads the operand type
+/// and derives the result type from it instead of parsing it.
+static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
+                                        Type &resultType) {
+  if (parser.parseType(operandType))
+    return failure();
+
+  // Only `!async.value<T>` operands produce a result (of type T); for any other
+  // operand type `resultType` is left untouched.
+  ValueType asyncValue = operandType.dyn_cast<ValueType>();
+  if (asyncValue)
+    resultType = asyncValue.getValueType();
+
+  return success();
+}
+
+/// Printer for the `custom<AwaitResultType>` directive. Only the operand type
+/// is printed; the result type is implied by it and re-derived on parsing.
+static void printAwaitResultType(OpAsmPrinter &p, Type operandType,
+                                 Type /*resultType*/) {
+  p << operandType;
+}
+
+/// Verifies that the results of `async.await` agree with its operand type:
+/// awaiting on a token must have no results, and awaiting on an async value
+/// must have exactly one result whose type is the unwrapped value type.
+static LogicalResult verify(AwaitOp op) {
+  Type argType = op.operand().getType();
+
+  // Awaiting on a token does not have any results.
+  if (argType.isa<TokenType>() && !op.getResultTypes().empty())
+    return op.emitOpError("awaiting on a token must have empty result");
+
+  // Awaiting on a value unwraps the async value type. The result is declared
+  // Optional in ODS, so the generic form can omit it; check for presence
+  // before dereferencing rather than tripping an assertion (or UB).
+  if (auto value = argType.dyn_cast<ValueType>()) {
+    Optional<Type> resultType = op.getResultType();
+    if (!resultType)
+      return op.emitOpError("awaiting on an async value must have a result");
+    if (*resultType != value.getValueType())
+      return op.emitOpError()
+             << "result type " << *resultType
+             << " does not match async value type " << value.getValueType();
+  }
+
+  return success();
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index 371cea7e4c06..8784b6f05a08 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -106,3 +106,17 @@ func @empty_tokens_or_values_operands() {
%token4 = async.execute [] { async.yield }
return
}
+
+// CHECK-LABEL: @await_token
+func @await_token(%arg0: !async.token) {
+  // Pin the full printed form so the custom type printer is exercised too.
+  // CHECK: async.await %arg0 : !async.token
+  async.await %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @await_value
+func @await_value(%arg0: !async.value<f32>) -> f32 {
+  // Only the operand type is printed; the f32 result type is implied by it.
+  // CHECK: %[[VALUE:.*]] = async.await %arg0 : !async.value<f32>
+  %0 = async.await %arg0 : !async.value<f32>
+  // CHECK: return %[[VALUE]] : f32
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir
new file mode 100644
index 000000000000..9d6c43cbcf49
--- /dev/null
+++ b/mlir/test/Dialect/Async/verify.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// The RUN line pipes mlir-opt output into FileCheck, which requires at least
+// one CHECK directive in the file; this trivial function provides it so the
+// -verify-diagnostics cases below can share the same RUN line.
+// CHECK-LABEL: @no_op
+func @no_op(%arg0: !async.token) {
+ return
+}
+
+// -----
+
+func @wrong_async_await_arg_type(%arg0: f32) {
+  // expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}}
+  async.await %arg0 : f32
+  // Terminate the body so the only diagnostic is the one under test.
+  return
+}
+
+// -----
+
+func @wrong_async_await_result_type(%arg0: !async.value<f32>) {
+  // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}}
+  %0 = "async.await"(%arg0): (!async.value<f32>) -> f64
+  // Terminate the body so the only diagnostic is the one under test.
+  return
+}
More information about the Mlir-commits
mailing list