[Mlir-commits] [mlir] 655af65 - [MLIR] Add async.value type to Async dialect

Eugene Zhulenev llvmlistbot at llvm.org
Wed Sep 30 11:30:14 PDT 2020


Author: Eugene Zhulenev
Date: 2020-09-30T11:30:06-07:00
New Revision: 655af658c93bf7f133341e7eb5a2dfa176282781

URL: https://github.com/llvm/llvm-project/commit/655af658c93bf7f133341e7eb5a2dfa176282781
DIFF: https://github.com/llvm/llvm-project/commit/655af658c93bf7f133341e7eb5a2dfa176282781.diff

LOG: [MLIR] Add async.value type to Async dialect

Return values from async regions as !async.value<...>.

Reviewed By: mehdi_amini, csigg

Differential Revision: https://reviews.llvm.org/D88510

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Async/IR/Async.h
    mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/test/Dialect/Async/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index f61d07b7d0df..b1cf25ecea57 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -22,12 +22,28 @@
 namespace mlir {
 namespace async {
 
+// Forward declaration; the definition lives in
+// mlir/lib/Dialect/Async/IR/Async.cpp.
+namespace detail {
+struct ValueTypeStorage;
+} // namespace detail
+
 /// The token type to represent asynchronous operation completion.
 class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 public:
   using Base::Base;
 };
 
+/// The value type to represent values returned from asynchronous operations.
+class ValueType
+    : public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create an async ValueType with the provided value type.
+  static ValueType get(Type valueType);
+
+  /// Returns the wrapped type, i.e. `T` for `!async.value<T>`.
+  Type getValueType();
+};
+
 } // namespace async
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
index ac67e9f1609d..2097f05747dd 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
@@ -39,4 +39,24 @@ def Async_TokenType : DialectType<AsyncDialect,
   }];
 }
 
+// Type constraint matching `!async.value<T>` where the wrapped type `T`
+// satisfies the constraint `type`. The wrapped type is substituted for
+// `$_self` in `type.predicate` via SubstLeaves.
+class Async_ValueType<Type type>
+    : DialectType<AsyncDialect,
+        And<[
+          CPred<"$_self.isa<::mlir::async::ValueType>()">,
+          SubstLeaves<"$_self",
+                      "$_self.cast<::mlir::async::ValueType>().getValueType()",
+                      type.predicate>
+       ]>, "async value type with " # type.description # " underlying type"> {
+  let typeDescription = [{
+    `async.value` represents a value returned by asynchronous operations,
+    which may or may not be available currently, but will be available at some
+    point in the future.
+  }];
+
+  Type valueType = type;
+}
+
+// Type constraint matching any `!async.value<T>`, regardless of `T`.
+def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
+                                    "async value type">;
+
 #endif // ASYNC_BASE_TD

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index b84f7c402801..2dcc9a8f86fd 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -40,24 +40,24 @@ def Async_ExecuteOp : Async_Op<"execute"> {
     state). All dependencies must be made explicit with async execute arguments
     (`async.token` or `async.value`).
 
-    Example:
-
     ```mlir
-    %0 = async.execute {
-      "compute0"(...)
-      async.yield
-    } : !async.token
+    %done, %values = async.execute {
+      %0 = "compute0"(...) : !some.type
+      async.yield %0 : !some.type
+    } : !async.token, !async.value<!some.type>
 
-    %1 = "compute1"(...)
+    %1 = "compute1"(...) : !some.type
     ```
   }];
 
   // TODO: Take async.tokens/async.values as arguments.
   let arguments = (ins );
-  let results = (outs Async_TokenType:$done);
+  let results = (outs Async_TokenType:$done,
+                      Variadic<Async_AnyValueType>:$values);
   let regions = (region SizedRegion<1>:$body);
 
-  let assemblyFormat = "$body attr-dict `:` type($done)";
+  let printer = [{ return ::mlir::async::print(p, *this); }];
+  let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
 }
 
 def Async_YieldOp :
@@ -71,6 +71,8 @@ def Async_YieldOp :
   let arguments = (ins Variadic<AnyType>:$operands);
 
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+  let verifier = [{ return ::mlir::async::verify(*this); }];
 }
 
 #endif // ASYNC_OPS

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 61057870d301..4d9ede13f19c 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -19,8 +19,8 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
-using namespace mlir;
-using namespace mlir::async;
+namespace mlir {
+namespace async {
 
 void AsyncDialect::initialize() {
   addOperations<
@@ -28,6 +28,7 @@ void AsyncDialect::initialize() {
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
       >();
   addTypes<TokenType>();
+  addTypes<ValueType>();
 }
 
 /// Parse a type registered to this dialect.
@@ -39,6 +40,15 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
   if (keyword == "token")
     return TokenType::get(getContext());
 
+  if (keyword == "value") {
+    Type ty;
+    if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
+      parser.emitError(parser.getNameLoc(), "failed to parse async value type");
+      return Type();
+    }
+    return ValueType::get(ty);
+  }
+
   parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
   return Type();
 }
@@ -46,9 +56,113 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
 /// Print a type registered to this dialect.
 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
-      .Case<TokenType>([&](Type) { os << "token"; })
+      .Case<TokenType>([&](TokenType) { os << "token"; })
+      // Printed as `value<T>`, the same form parseType accepts after the
+      // "value" keyword.
+      .Case<ValueType>([&](ValueType valueTy) {
+        os << "value<";
+        os.printType(valueTy.getValueType());
+        os << '>';
+      })
       .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
 }
 
+//===----------------------------------------------------------------------===//
+// ValueType
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+
+/// Uniqued storage backing `!async.value<T>`; it holds only the wrapped type.
+struct ValueTypeStorage : public TypeStorage {
+  /// Instances are uniqued by the wrapped type alone.
+  using KeyTy = Type;
+
+  ValueTypeStorage(Type wrapped) : valueType(wrapped) {}
+
+  bool operator==(const KeyTy &other) const { return valueType == other; }
+
+  /// Placement-constructs a new storage instance in the uniquer's allocator.
+  static ValueTypeStorage *construct(TypeStorageAllocator &alloc,
+                                     Type wrapped) {
+    auto *mem = alloc.allocate<ValueTypeStorage>();
+    return new (mem) ValueTypeStorage(wrapped);
+  }
+
+  /// The type `T` of `!async.value<T>`.
+  Type valueType;
+};
+
+} // namespace detail
+
+/// Returns the uniqued `!async.value<T>` wrapping `valueType`, created in the
+/// same context as `valueType`.
+ValueType ValueType::get(Type valueType) {
+  MLIRContext *context = valueType.getContext();
+  return Base::get(context, valueType);
+}
+
+/// Returns the wrapped type `T` of `!async.value<T>`.
+Type ValueType::getValueType() {
+  auto *storage = getImpl();
+  return storage->valueType;
+}
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the operands of an `async.yield` match, both in number and in
+/// type, the wrapped types of the `!async.value<T>` results of the enclosing
+/// `async.execute`.
+static LogicalResult verify(YieldOp op) {
+  // Get the underlying value types from async values returned from the
+  // parent `async.execute` operation.
+  // NOTE(review): assumes the yield is always nested in an ExecuteOp (e.g. via
+  // a HasParent trait not visible here); otherwise executeOp would be null.
+  auto executeOp = op.getParentOfType<ExecuteOp>();
+  auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
+    return result.getType().cast<ValueType>().getValueType();
+  });
+
+  // Use the four-iterator std::equal so that the operand count is checked
+  // too. The three-iterator form reads past the end of a shorter operand list
+  // and silently accepts extra trailing operands.
+  auto operandTypes = op.getOperandTypes();
+  if (!std::equal(types.begin(), types.end(), operandTypes.begin(),
+                  operandTypes.end()))
+    return op.emitOpError("Operand types do not match the types returned from "
+                          "the parent ExecuteOp");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteOp
+//===----------------------------------------------------------------------===//
+
+/// Prints an `async.execute` in the form
+///   async.execute { ... } attr-dict : !async.token (, !async.value<T>)*
+/// which is the form read back by parseExecuteOp.
+static void print(OpAsmPrinter &p, ExecuteOp op) {
+  p << "async.execute ";
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+
+  // The completion token type always comes first; each async value type
+  // follows it, preceded by a comma.
+  p << " : ";
+  p.printType(op.done().getType());
+  for (OpResult value : op.values()) {
+    p << ", ";
+    p.printType(value.getType());
+  }
+}
+
+/// Parses the `async.execute` form emitted by the printer:
+///   async.execute { ... } attr-dict : !async.token (, !async.value<T>)*
+/// Emits a diagnostic and returns failure if any part is malformed.
+static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
+  MLIRContext *ctx = result.getContext();
+
+  // Parse asynchronous region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
+                         /*enableNameShadowing=*/false))
+    return failure();
+
+  // Parse operation attributes.
+  NamedAttrList attrs;
+  if (parser.parseOptionalAttrDict(attrs))
+    return failure();
+  result.addAttributes(attrs);
+
+  // Parse result types; remember where they start for diagnostics.
+  llvm::SMLoc typesLoc = parser.getCurrentLocation();
+  SmallVector<Type, 4> resultTypes;
+  if (parser.parseColonTypeList(resultTypes))
+    return failure();
+
+  // First result type must be an async token type. Report the problem instead
+  // of returning a bare failure(), which would leave the user with no message.
+  if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
+    return parser.emitError(typesLoc,
+                            "expected '!async.token' as the first result type");
+  parser.addTypesToList(resultTypes, result.types);
+
+  return success();
+}
+
+} // namespace async
+} // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"

diff  --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
index 2f5d0123e215..d23bc003dd3a 100644
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -1,16 +1,46 @@
 // RUN: mlir-opt  %s | FileCheck %s
 
-// CHECK-LABEL: @identity
-func @identity(%arg0 : !async.token) -> !async.token {
+// An !async.token argument should round-trip through the function unchanged.
+// CHECK-LABEL: @identity_token
+func @identity_token(%arg0 : !async.token) -> !async.token {
   // CHECK: return %arg0 : !async.token
   return %arg0 : !async.token
 }
 
+// An !async.value<f32> argument should round-trip through the function
+// unchanged.
+// CHECK-LABEL: @identity_value
+func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
+  // CHECK: return %arg0 : !async.value<f32>
+  return %arg0 : !async.value<f32>
+}
+
+// An async.execute that yields nothing produces only the completion token.
 // CHECK-LABEL: @empty_async_execute
 func @empty_async_execute() -> !async.token {
-  %0 = async.execute {
+  %done = async.execute {
     async.yield
   } : !async.token
 
-  return %0 : !async.token
+  // CHECK: return %done : !async.token
+  return %done : !async.token
+}
+
+// Yielding one f32 produces the token plus one !async.value<f32> result.
+// CHECK-LABEL: @return_async_value
+func @return_async_value() -> !async.value<f32> {
+  %done, %values = async.execute {
+    %cst = constant 1.000000e+00 : f32
+    async.yield %cst : f32
+  } : !async.token, !async.value<f32>
+
+  // CHECK: return %values : !async.value<f32>
+  return %values : !async.value<f32>
+}
+
+// Yielding two f32 values produces the token plus two !async.value<f32>
+// results.
+// CHECK-LABEL: @return_async_values
+func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
+  %done, %values:2 = async.execute {
+    %cst1 = constant 1.000000e+00 : f32
+    %cst2 = constant 2.000000e+00 : f32
+    async.yield %cst1, %cst2 : f32, f32
+  } : !async.token, !async.value<f32>, !async.value<f32>
+
+  // CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
+  return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
 }


        


More information about the Mlir-commits mailing list