[Mlir-commits] [mlir] a5ddd92 - [mlir][async] Allow to call async.execute inside async.func

Eugene Zhulenev llvmlistbot at llvm.org
Fri Jan 13 16:04:28 PST 2023


Author: yijiagu
Date: 2023-01-13T16:04:24-08:00
New Revision: a5ddd92035042204d42e108631142694692eabf1

URL: https://github.com/llvm/llvm-project/commit/a5ddd92035042204d42e108631142694692eabf1
DIFF: https://github.com/llvm/llvm-project/commit/a5ddd92035042204d42e108631142694692eabf1.diff

LOG: [mlir][async] Allow to call async.execute inside async.func

This change adds support for calling `async.execute` inside `async.func`.
For example:

```
async.func @async_func_call_func(%arg0: f32, %arg1: memref<1xf32>) -> !async.token {
  %token = async.execute {
    %c0 = arith.constant 0 : index
    memref.store %arg0, %arg1[%c0] : memref<1xf32>
    async.yield
  }
  async.await %token : !async.token
  return
}
```

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D141730

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/test/Dialect/Async/async-to-async-runtime.mlir
    mlir/test/mlir-cpu-runner/async-func.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index e2120ed644c34..92bc5ee7a8665 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -840,8 +840,9 @@ void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
 
   target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
       [coros](Operation *op) {
+        auto exec = op->getParentOfType<ExecuteOp>();
         auto func = op->getParentOfType<func::FuncOp>();
-        return coros->find(func) == coros->end();
+        return exec || coros->find(func) == coros->end();
       });
 }
 

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index 38a88cc9de5b8..635a86ecdb4be 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -455,3 +455,31 @@ async.func @async_func_await(%arg0: f32, %arg1: !async.value<f32>)
 // CHECK-SAME: !async.value<f32>
 // CHECK: async.coro.suspend %[[SAVED]]
 // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// -----
+// Async execute inside async func
+
+// CHECK-LABEL: @execute_in_async_func
+async.func @execute_in_async_func(%arg0: f32, %arg1: memref<1xf32>)
+             -> !async.token {
+  %token = async.execute {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  async.await %token : !async.token
+  return
+}
+// Call the function outlined from the async.execute operation.
+// CHECK: %[[RES:.*]] = call @async_execute_fn(
+// CHECK-SAME:    %[[VALUE:arg[0-9]+]],
+// CHECK-SAME:    %[[MEMREF:arg[0-9]+]]
+// CHECK-SAME:  ) : (f32, memref<1xf32>) -> !async.token
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME:    %[[VALUE:arg[0-9]+]]: f32,
+// CHECK-SAME:    %[[MEMREF:arg[0-9]+]]: memref<1xf32>
+// CHECK-SAME:  ) -> !async.token
+// CHECK:         %[[CST:.*]] = arith.constant 0 : index
+// CHECK:         memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]

diff  --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
index 6f21ba906b222..89b5358bfdba0 100644
--- a/mlir/test/mlir-cpu-runner/async-func.mlir
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -64,6 +64,19 @@ async.func @async_func_passed_memref(%arg0 : !async.value<memref<f32>>) -> !asyn
   return
 }
 
+async.func @async_execute_in_async_func(%arg0 : !async.value<memref<f32>>) -> !async.token {
+  %token0 = async.execute {
+    %unwrapped = async.await %arg0 : !async.value<memref<f32>>
+    %0 = memref.load %unwrapped[] : memref<f32>
+    %1 = arith.addf %0, %0 : f32
+    memref.store %1, %unwrapped[] : memref<f32>
+    async.yield
+  }
+
+  async.await %token0 : !async.token
+  return
+}
+
 
 func.func @main() {
   %false = arith.constant 0 : i1
@@ -140,6 +153,17 @@ func.func @main() {
   // CHECK-NEXT: [0.5]
   call @printMemrefF32(%6) : (memref<*xf32>) -> ()
 
+  // ------------------------------------------------------------------------ //
+  // async.execute inside async.func
+  // ------------------------------------------------------------------------ //
+  %token4 = async.call @async_execute_in_async_func(%result1) : (!async.value<memref<f32>>) -> !async.token
+  async.await %token4 : !async.token
+
+  // CHECK: Unranked Memref
+  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+  // CHECK-NEXT: [1]
+  call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
   memref.dealloc %5 : memref<f32>
 
   return


        


More information about the Mlir-commits mailing list