[Mlir-commits] [mlir] 61dce0f - [mlir] Add async.await operation to async dialect

Eugene Zhulenev llvmlistbot at llvm.org
Mon Oct 12 21:05:47 PDT 2020


Author: Eugene Zhulenev
Date: 2020-10-12T21:05:36-07:00
New Revision: 61dce0f308e35df1edbd3061af339a3aff8d1f35

URL: https://github.com/llvm/llvm-project/commit/61dce0f308e35df1edbd3061af339a3aff8d1f35
DIFF: https://github.com/llvm/llvm-project/commit/61dce0f308e35df1edbd3061af339a3aff8d1f35.diff

LOG: [mlir] Add async.await operation to async dialect

Add async.await operation to "unwrap" async.values

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D89137

Added: 
    mlir/test/Dialect/Async/verify.mlir

Modified: 
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/test/Dialect/Async/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index fbdbdb92302f..2382253eff17 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -75,7 +75,6 @@ def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
-
   let verifier = [{ return ::verify(*this); }];
 }
 
@@ -94,4 +93,47 @@ def Async_YieldOp :
   let verifier = [{ return ::verify(*this); }];
 }
 
+// Not NoSideEffect: a result-less token await would be erased as dead code.
+def Async_AwaitOp : Async_Op<"await"> {
+  let summary = "waits for the argument to become ready";
+  let description = [{
+    The `async.await` operation waits until the argument becomes ready, and for
+    the `async.value` arguments it unwraps the underlying value.
+
+    Example:
+
+    ```mlir
+    %0 = ... : !async.token
+    async.await %0 : !async.token
+
+    %1 = ... : !async.value<f32>
+    %2 = async.await %1 : !async.value<f32>
+    ```
+  }];
+
+  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  let results = (outs Optional<AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"mlir::OpBuilder &builder, OperationState &result,"
+              "Value operand, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
+  let extraClassDeclaration = [{
+    Optional<Type> getResultType() {
+      if (getResultTypes().empty()) return None;
+      return getResultTypes()[0];
+    }
+  }];
+
+  let assemblyFormat = [{
+    attr-dict $operand `:` custom<AwaitResultType>(
+      type($operand), type($result)
+    )
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // ASYNC_OPS

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index eb5e65b5dc93..754665d8bafc 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -250,5 +250,54 @@ static LogicalResult verify(ExecuteOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+/// AwaitOp
+//===----------------------------------------------------------------------===//
+
+void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
+                    ArrayRef<NamedAttribute> attrs) {
+  // Caller-supplied attributes are attached to the new operation unchanged.
+  result.attributes.append(attrs.begin(), attrs.end());
+  result.addOperands({operand});
+  // Only an async.value operand produces a result: its unwrapped value type.
+  if (auto asyncValue = operand.getType().dyn_cast<ValueType>())
+    result.addTypes(asyncValue.getValueType());
+}
+
+static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
+                                        Type &resultType) {
+  // Only the operand type is spelled out in the assembly format.
+  if (failed(parser.parseType(operandType)))
+    return failure();
+  // The result type is inferred: an async.value operand unwraps to its value
+  // type, while any other operand type leaves `resultType` untouched.
+  if (auto asyncValue = operandType.dyn_cast<ValueType>())
+    resultType = asyncValue.getValueType();
+  return success();
+}
+
+static void printAwaitResultType(OpAsmPrinter &p, Type operandType,
+                                 Type resultType) {
+  p.printType(operandType); // The result type is implied, so it is omitted.
+}
+
+static LogicalResult verify(AwaitOp op) {
+  Type argType = op.operand().getType();
+  // Awaiting on a token does not have any results.
+  if (argType.isa<TokenType>() && !op.getResultTypes().empty())
+    return op.emitOpError("awaiting on a token must have empty result");
+  // Awaiting on a value must yield exactly the unwrapped value type. The
+  // generic form can omit the result, so check before dereferencing it.
+  if (auto value = argType.dyn_cast<ValueType>()) {
+    Optional<Type> resultType = op.getResultType();
+    if (!resultType)
+      return op.emitOpError("awaiting on a value must have a result");
+    if (*resultType != value.getValueType())
+      return op.emitOpError() << "result type " << *resultType
+             << " does not match async value type " << value.getValueType();
+  }
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"

diff  --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index 371cea7e4c06..8784b6f05a08 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -106,3 +106,17 @@ func @empty_tokens_or_values_operands() {
   %token4 = async.execute [] { async.yield }
   return
 }
+
+// CHECK-LABEL: @await_token
+func @await_token(%arg0: !async.token) {
+  // CHECK: async.await %arg0
+  async.await %arg0 : !async.token  // Token await: no result is produced.
+  return
+}
+
+// CHECK-LABEL: @await_value
+func @await_value(%arg0: !async.value<f32>) -> f32 {
+  // CHECK: async.await %arg0
+  %0 = async.await %arg0 : !async.value<f32>  // Result type f32 is inferred.
+  return %0 : f32
+}

diff  --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir
new file mode 100644
index 000000000000..9d6c43cbcf49
--- /dev/null
+++ b/mlir/test/Dialect/Async/verify.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// FileCheck rejects a file with no CHECK statement; the tests below have none.
+// CHECK-LABEL: @no_op
+func @no_op(%arg0: !async.token) {
+  return
+}
+
+// -----
+
+func @wrong_async_await_arg_type(%arg0: f32) {
+  // expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}}
+  async.await %arg0 : f32  // f32 is neither an async value nor a token.
+}
+
+// -----
+
+func @wrong_async_await_result_type(%arg0: !async.value<f32>) {
+  // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}}
+  %0 = "async.await"(%arg0): (!async.value<f32>) -> f64  // Generic form: the custom syntax infers the result type and cannot mismatch.
+}


        


More information about the Mlir-commits mailing list