[Mlir-commits] [mlir] a5ddd92 - [mlir][async] Allow to call async.execute inside async.func
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Jan 13 16:04:28 PST 2023
Author: yijiagu
Date: 2023-01-13T16:04:24-08:00
New Revision: a5ddd92035042204d42e108631142694692eabf1
URL: https://github.com/llvm/llvm-project/commit/a5ddd92035042204d42e108631142694692eabf1
DIFF: https://github.com/llvm/llvm-project/commit/a5ddd92035042204d42e108631142694692eabf1.diff
LOG: [mlir][async] Allow to call async.execute inside async.func
This change adds support for calling async.execute inside async.func.
Ex.
```
async.func @async_func_call_func(%arg0: f32, %arg1: memref<1xf32>) -> !async.token {
%token = async.execute {
%c0 = arith.constant 0 : index
memref.store %arg0, %arg1[%c0] : memref<1xf32>
async.yield
}
async.await %token : !async.token
return
}
```
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D141730
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/test/Dialect/Async/async-to-async-runtime.mlir
mlir/test/mlir-cpu-runner/async-func.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index e2120ed644c34..92bc5ee7a8665 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -840,8 +840,9 @@ void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
[coros](Operation *op) {
+ auto exec = op->getParentOfType<ExecuteOp>();
auto func = op->getParentOfType<func::FuncOp>();
- return coros->find(func) == coros->end();
+ return exec || coros->find(func) == coros->end();
});
}
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 38a88cc9de5b8..635a86ecdb4be 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -455,3 +455,31 @@ async.func @async_func_await(%arg0: f32, %arg1: !async.value<f32>)
// CHECK-SAME: !async.value<f32>
// CHECK: async.coro.suspend %[[SAVED]]
// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// -----
+// Async execute inside async func
+
+// CHECK-LABEL: @execute_in_async_func
+async.func @execute_in_async_func(%arg0: f32, %arg1: memref<1xf32>)
+ -> !async.token {
+ %token = async.execute {
+ %c0 = arith.constant 0 : index
+ memref.store %arg0, %arg1[%c0] : memref<1xf32>
+ async.yield
+ }
+ async.await %token : !async.token
+ return
+}
+// Call the outlined async.execute function.
+// CHECK: %[[RES:.*]] = call @async_execute_fn(
+// CHECK-SAME: %[[VALUE:arg[0-9]+]],
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]
+// CHECK-SAME: ) : (f32, memref<1xf32>) -> !async.token
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
+// CHECK-SAME: ) -> !async.token
+// CHECK: %[[CST:.*]] = arith.constant 0 : index
+// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
diff --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
index 6f21ba906b222..89b5358bfdba0 100644
--- a/mlir/test/mlir-cpu-runner/async-func.mlir
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -64,6 +64,19 @@ async.func @async_func_passed_memref(%arg0 : !async.value<memref<f32>>) -> !asyn
return
}
+async.func @async_execute_in_async_func(%arg0 : !async.value<memref<f32>>) -> !async.token {
+ %token0 = async.execute {
+ %unwrapped = async.await %arg0 : !async.value<memref<f32>>
+ %0 = memref.load %unwrapped[] : memref<f32>
+ %1 = arith.addf %0, %0 : f32
+ memref.store %1, %unwrapped[] : memref<f32>
+ async.yield
+ }
+
+ async.await %token0 : !async.token
+ return
+}
+
func.func @main() {
%false = arith.constant 0 : i1
@@ -140,6 +153,17 @@ func.func @main() {
// CHECK-NEXT: [0.5]
call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+ // ------------------------------------------------------------------------ //
+ // async.execute inside async.func
+ // ------------------------------------------------------------------------ //
+ %token4 = async.call @async_execute_in_async_func(%result1) : (!async.value<memref<f32>>) -> !async.token
+ async.await %token4 : !async.token
+
+ // CHECK: Unranked Memref
+ // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+ // CHECK-NEXT: [1]
+ call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
memref.dealloc %5 : memref<f32>
return
More information about the Mlir-commits
mailing list