[Mlir-commits] [mlir] f507aa1 - [mlir] Implement lowering to LLVM of async.execute ops with token dependencies

Eugene Zhulenev llvmlistbot at llvm.org
Fri Oct 30 05:59:12 PDT 2020


Author: Eugene Zhulenev
Date: 2020-10-30T05:59:03-07:00
New Revision: f507aa17b791af3028d321ccfb83c5bfaf315f02

URL: https://github.com/llvm/llvm-project/commit/f507aa17b791af3028d321ccfb83c5bfaf315f02
DIFF: https://github.com/llvm/llvm-project/commit/f507aa17b791af3028d321ccfb83c5bfaf315f02.diff

LOG: [mlir] Implement lowering to LLVM of async.execute ops with token dependencies

Add support for lowering `async.execute` operations with token dependencies.

Example:

```
%dep = ... : !async.token
%token = async.execute[%dep] {
...
}
```

Token dependencies are lowered to `async.await` operations inside the outlined coroutine body.

Reviewed By: herhut, mehdi_amini, ftynse

Differential Revision: https://reviews.llvm.org/D89958

Added: 
    

Modified: 
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
    mlir/test/mlir-cpu-runner/async.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 9a99bf00c08c..5233d1db179b 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -462,14 +462,15 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
 
   OpBuilder moduleBuilder(module.getBody()->getTerminator());
 
-  // Get values captured by the async region
-  llvm::SetVector<mlir::Value> usedAbove;
-  getUsedValuesDefinedAbove(execute.body(), usedAbove);
-
-  // Collect types of the captured values.
-  auto usedAboveTypes =
-      llvm::map_range(usedAbove, [](Value value) { return value.getType(); });
-  SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end());
+  // Collect all outlined function inputs.
+  llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
+                                              execute.dependencies().end());
+  getUsedValuesDefinedAbove(execute.body(), functionInputs);
+
+  // Collect types for the outlined function inputs and outputs.
+  auto typesRange = llvm::map_range(
+      functionInputs, [](Value value) { return value.getType(); });
+  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
   auto outputTypes = execute.getResultTypes();
 
   auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
@@ -510,14 +511,19 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
   Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
                                      entryBlock->getTerminator());
 
-  // Map from values defined above the execute op to the function arguments.
+  // Await on all dependencies before starting to execute the body region.
+  builder.setInsertionPointToStart(resume);
+  for (size_t i = 0; i < execute.dependencies().size(); ++i)
+    builder.create<AwaitOp>(loc, func.getArgument(i));
+
+  // Map from function inputs defined above the execute op to the function
+  // arguments.
   BlockAndValueMapping valueMapping;
-  valueMapping.map(usedAbove, func.getArguments());
+  valueMapping.map(functionInputs, func.getArguments());
 
   // Clone all operations from the execute operation body into the outlined
   // function body, and replace all `async.yield` operations with a call
   // to async runtime to emplace the result token.
-  builder.setInsertionPointToStart(resume);
   for (Operation &op : execute.body().getOps()) {
     if (isa<async::YieldOp>(op)) {
       builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
@@ -528,9 +534,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
 
   // Replace the original `async.execute` with a call to outlined function.
   OpBuilder callBuilder(execute);
-  SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end());
-  auto callOutlinedFunc = callBuilder.create<CallOp>(
-      loc, func.getName(), execute.getResultTypes(), usedAboveArgs);
+  auto callOutlinedFunc =
+      callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
+                                 functionInputs.getArrayRef());
   execute.replaceAllUsesWith(callOutlinedFunc.getResults());
   execute.erase();
 
@@ -673,13 +679,11 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
 
   WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
-    // We currently do not support execute operations that take async
-    // token dependencies, async value arguments or produce async results.
-    if (!execute.dependencies().empty() || !execute.operands().empty() ||
-        !execute.results().empty()) {
-      execute.emitOpError(
-          "Can't outline async.execute op with async dependencies, arguments "
-          "or returned async results");
+    // We currently do not support execute operations that have async value
+    // operands or produce async results.
+    if (!execute.operands().empty() || !execute.results().empty()) {
+      execute.emitOpError("can't outline async.execute op with async value "
+                          "operands or returned async results");
       return WalkResult::interrupt();
     }
 

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index f838b8fb57dc..f8287dd0a360 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -15,7 +15,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
 }
 
 // Function outlined from the async.execute operation.
-// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
+// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
 // CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
 
 // Create token for return op, and mark a function as a coroutine.
@@ -79,7 +79,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 }
 
 // Function outlined from the inner async.execute operation.
-// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
+// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
 // CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
 // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
@@ -89,7 +89,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
 
 // Function outlined from the outer async.execute operation.
-// CHECK: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
+// CHECK-LABEL: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
 // CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
 // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
 // CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
@@ -108,4 +108,52 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
 // CHECK: store %arg2, %arg1[%c0] : memref<1xf32>
 // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
 
+// -----
+
+// CHECK-LABEL: async_execute_token_dependency
+func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
+  // CHECK: %0 = call @async_execute_fn(%arg0, %arg1)
+  %token = async.execute {
+    %c0 = constant 0 : index
+    store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  // CHECK: %1 = call @async_execute_fn_0(%0, %arg0, %arg1)
+  %token_0 = async.execute [%token] {
+    %c0 = constant 0 : index
+    store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  return
+}
+
+// Function outlined from the first async.execute operation.
+// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
+// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
+// CHECK: call @mlirAsyncRuntimeExecute
+// CHECK: llvm.call @llvm.coro.suspend
+// CHECK: store %arg0, %arg1[%c0] : memref<1xf32>
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
+
+// Function outlined from the second async.execute operation with dependency.
+// CHECK-LABEL: func @async_execute_fn_0(%arg0: !llvm.ptr<i8>, %arg1: f32, %arg2: memref<1xf32>)
+// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
+// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
+// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
+
+// Suspend coroutine in the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Suspend coroutine second time waiting for the completion of token dependency.
+// CHECK: llvm.call @llvm.coro.save
+// CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute(%arg0, %[[HDL_1]],
+// CHECK: llvm.call @llvm.coro.suspend
+
+// Emplace result token after second resumption.
+// CHECK: store %arg1, %arg2[%c0] : memref<1xf32>
+// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
+
 

diff  --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir
index ce9cdb14d4c7..9ec4de98cc6b 100644
--- a/mlir/test/mlir-cpu-runner/async.mlir
+++ b/mlir/test/mlir-cpu-runner/async.mlir
@@ -41,8 +41,15 @@ func @main() {
     call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
     call @print_memref_f32(%U): (memref<*xf32>) -> ()
 
-    %inner = async.execute {
+    // No op async region to create a token for testing async dependency.
+    %noop = async.execute {
       // CHECK: Current thread id: [[THREAD1:.*]]
+      call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
+      async.yield
+    }
+
+    %inner = async.execute [%noop] {
+      // CHECK: Current thread id: [[THREAD2:.*]]
       // CHECK: [1, 2, 3, 0]
       store %c3, %A[%i2]: memref<4xf32>
       call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
@@ -52,7 +59,7 @@ func @main() {
     }
     async.await %inner : !async.token
 
-    // CHECK: Current thread id: [[THREAD2:.*]]
+    // CHECK: Current thread id: [[THREAD3:.*]]
     // CHECK: [1, 2, 3, 4]
     store %c4, %A[%i3]: memref<4xf32>
     call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()


        


More information about the Mlir-commits mailing list