[Mlir-commits] [mlir] f81f880 - [mlir] Lower async.func with async.coro and async.runtime operations

Eugene Zhulenev llvmlistbot at llvm.org
Mon Nov 7 09:54:03 PST 2022


Author: yijiagu
Date: 2022-11-07T09:53:58-08:00
New Revision: f81f880871e04ef0284af14a141a58905e81cdd9

URL: https://github.com/llvm/llvm-project/commit/f81f880871e04ef0284af14a141a58905e81cdd9
DIFF: https://github.com/llvm/llvm-project/commit/f81f880871e04ef0284af14a141a58905e81cdd9.diff

LOG: [mlir] Lower async.func with async.coro and async.runtime operations

Lower async.func with async.coro and async.runtime operations

- This patch extends the AsyncToAsyncRuntime pass to lower async.func ops to func.func ops with a coroutine CFG.
Example:

```
async.func @foo() -> !async.value<f32> {
  %cst = arith.constant 42.0 : f32
  return %cst: f32
}
```

After lowering:

```
func.func @foo() -> !async.value<f32> attributes {passthrough = ["presplitcoroutine"]} {
    %0 = async.runtime.create : !async.value<f32>
    %1 = async.coro.id
    %2 = async.coro.begin %1
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %cst = arith.constant 4.200000e+01 : f32
    async.runtime.store %cst, %0 : <f32>
    async.runtime.set_available %0 : !async.value<f32>
    cf.br ^bb2
  ^bb2:  // pred: ^bb1
    async.coro.free %1, %2
    cf.br ^bb3
  ^bb3:  // pred: ^bb2
    async.coro.end %2
    return %0 : !async.value<f32>
}
```

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D137462

Added: 
    mlir/test/mlir-cpu-runner/async-func.mlir

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/test/Dialect/Async/async-to-async-runtime.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 38f3717c70f9b..66c5b731b6e76 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -51,10 +51,6 @@ class AsyncToAsyncRuntimePass
 
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// async.execute op outlining to the coroutine functions.
-//===----------------------------------------------------------------------===//
-
 /// Function targeted for coroutine transformation has two additional blocks at
 /// the end: coroutine cleanup and coroutine suspension.
 ///
@@ -64,6 +60,12 @@ namespace {
 struct CoroMachinery {
   func::FuncOp func;
 
+  // Async function returns an optional token, followed by some async values
+  //
+  //  async.func @foo() -> !async.value<T> {
+  //    %cst = arith.constant 42.0 : T
+  //    return %cst: T
+  //  }
   // Async execute region returns a completion token, and an async value for
   // each yielded value.
   //
@@ -71,12 +73,12 @@ struct CoroMachinery {
   //     %0 = arith.constant ... : T
   //     async.yield %0 : T
   //   }
-  Value asyncToken; // token representing completion of the async region
+  Optional<Value> asyncToken;               // returned completion token
   llvm::SmallVector<Value, 4> returnValues; // returned async values
 
   Value coroHandle; // coroutine handle (!async.coro.getHandle value)
   Block *entry;     // coroutine entry block
-  Block *setError;  // switch completion token and all values to error state
+  Optional<Block *> setError; // set returned values to error state
   Block *cleanup;   // coroutine cleanup block
   Block *suspend;   // coroutine suspension block
 };
@@ -87,13 +89,9 @@ struct CoroMachinery {
 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
 /// that branches into preexisting entry block. Also inserts trailing blocks.
 ///
-/// The result types of the passed `func` must start with an `async.token`
+/// The result types of `func` must start with an optional `async.token`
 /// and be continued with some number of `async.value`s.
 ///
-/// The func given to this function needs to have been preprocessed to have
-/// either branch or yield ops as terminators. Branches to the cleanup block are
-/// inserted after each yield.
-///
 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
 ///
 ///  - `entry` block sets up the coroutine.
@@ -110,7 +108,7 @@ struct CoroMachinery {
 ///     ^entry(<function-arguments>):
 ///       %token = <async token> : !async.token    // create async runtime token
 ///       %value = <async value> : !async.value<T> // create async value
-///       %id = async.coro.getId                      // create a coroutine id
+///       %id = async.coro.getId                   // create a coroutine id
 ///       %hdl = async.coro.begin %id              // create a coroutine handle
 ///       cf.br ^preexisting_entry_block
 ///
@@ -142,11 +140,20 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
   // ------------------------------------------------------------------------ //
   // Allocate async token/values that we will return from a ramp function.
   // ------------------------------------------------------------------------ //
-  auto retToken =
-      builder.create<RuntimeCreateOp>(TokenType::get(ctx)).getResult();
+
+  // We treat TokenType as state update marker to represent side-effects of
+  // async computations
+  bool isStateful = func.getCallableResults().front().isa<TokenType>();
+
+  Optional<Value> retToken;
+  if (isStateful)
+    retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
 
   llvm::SmallVector<Value, 4> retValues;
-  for (auto resType : func.getCallableResults().drop_front())
+  ArrayRef<Type> resValueTypes = isStateful
+                                     ? func.getCallableResults().drop_front()
+                                     : func.getCallableResults();
+  for (auto resType : resValueTypes)
     retValues.emplace_back(
         builder.create<RuntimeCreateOp>(resType).getResult());
 
@@ -179,26 +186,17 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
   // Mark the end of a coroutine: async.coro.end
   builder.create<CoroEndOp>(coroHdlOp.getHandle());
 
-  // Return created `async.token` and `async.values` from the suspend block.
-  // This will be the return value of a coroutine ramp function.
-  SmallVector<Value, 4> ret{retToken};
+  // Return created optional `async.token` and `async.values` from the suspend
+  // block. This will be the return value of a coroutine ramp function.
+  SmallVector<Value, 4> ret;
+  if (retToken)
+    ret.push_back(*retToken);
   ret.insert(ret.end(), retValues.begin(), retValues.end());
   builder.create<func::ReturnOp>(ret);
 
   // `async.await` op lowering will create resume blocks for async
   // continuations, and will conditionally branch to cleanup or suspend blocks.
 
-  for (Block &block : func.getBody().getBlocks()) {
-    if (&block == entryBlock || &block == cleanupBlock ||
-        &block == suspendBlock)
-      continue;
-    Operation *terminator = block.getTerminator();
-    if (auto yield = dyn_cast<YieldOp>(terminator)) {
-      builder.setInsertionPointToEnd(&block);
-      builder.create<cf::BranchOp>(cleanupBlock);
-    }
-  }
-
   // The switch-resumed API based coroutine should be marked with
   // coroutine.presplit attribute to mark the function as a coroutine.
   func->setAttr("passthrough", builder.getArrayAttr(
@@ -210,7 +208,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
   machinery.returnValues = retValues;
   machinery.coroHandle = coroHdlOp.getHandle();
   machinery.entry = entryBlock;
-  machinery.setError = nullptr; // created lazily only if needed
+  machinery.setError = None; // created lazily only if needed
   machinery.cleanup = cleanupBlock;
   machinery.suspend = suspendBlock;
   return machinery;
@@ -220,25 +218,31 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
 // runtime operations (see for example lowering of assert operation).
 static Block *setupSetErrorBlock(CoroMachinery &coro) {
   if (coro.setError)
-    return coro.setError;
+    return *coro.setError; // Already created: reuse the existing block.
 
   coro.setError = coro.func.addBlock();
-  coro.setError->moveBefore(coro.cleanup);
+  (*coro.setError)->moveBefore(coro.cleanup); // Keep it before cleanup.
 
   auto builder =
-      ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError);
+      ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
 
   // Coroutine set_error block: set error on token and all returned values.
-  builder.create<RuntimeSetErrorOp>(coro.asyncToken);
+  if (coro.asyncToken) // The token is optional (value-only async.func).
+    builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
+
   for (Value retValue : coro.returnValues)
     builder.create<RuntimeSetErrorOp>(retValue);
 
   // Branch into the cleanup block.
   builder.create<cf::BranchOp>(coro.cleanup);
 
-  return coro.setError;
+  return *coro.setError;
 }
 
+//===----------------------------------------------------------------------===//
+// async.execute op outlining to the coroutine functions.
+//===----------------------------------------------------------------------===//
+
 /// Outline the body region attached to the `async.execute` op into a standalone
 /// function.
 ///
@@ -382,6 +386,118 @@ class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Convert async.func, async.return and async.call operations to non-blocking
+// operations based on LLVM coroutines.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Convert async.func to func.func and set up its coroutine CFG.
+//===----------------------------------------------------------------------===//
+
+class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
+public:
+  AsyncFuncOpLowering(MLIRContext *ctx,
+                      llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+      : OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
+
+  LogicalResult
+  matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    auto newFuncOp =
+        rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
+
+    SymbolTable::setSymbolVisibility(newFuncOp,
+                                     SymbolTable::getSymbolVisibility(op));
+    // Copy every attribute except the symbol name (set by the builder).
+    for (const auto &namedAttr : op->getAttrs()) {
+      if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
+        newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+    }
+
+    rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+                                newFuncOp.end());
+    // NOTE(review): setupCoroMachinery builds IR directly, not via rewriter.
+    CoroMachinery coro = setupCoroMachinery(newFuncOp);
+    coros[newFuncOp] = coro; // Looked up by return/await/assert lowering.
+    // No initial suspend: the coroutine body starts eagerly (hot start).
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+};
+
+//===----------------------------------------------------------------------===//
+// Convert async.call to func.call; the callee lowers to a func.func.
+//===----------------------------------------------------------------------===//
+
+class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
+public:
+  explicit AsyncCallOpLowering(MLIRContext *ctx)
+      : OpConversionPattern<async::CallOp>(ctx) {}
+
+  LogicalResult
+  matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<func::CallOp>(
+        op, op.getCallee(), op.getResultTypes(), adaptor.getOperands());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Convert async.return to async.runtime ops plus a branch to cleanup.
+//===----------------------------------------------------------------------===//
+
+class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
+public:
+  AsyncReturnOpLowering(MLIRContext *ctx,
+                        llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+      : OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
+
+  LogicalResult
+  matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto func = op->template getParentOfType<func::FuncOp>();
+    auto funcCoro = coros.find(func);
+    if (funcCoro == coros.end())
+      return rewriter.notifyMatchFailure(
+          op, "operation is not inside the async coroutine function");
+
+    Location loc = op->getLoc();
+    const CoroMachinery &coro = funcCoro->getSecond();
+    rewriter.setInsertionPointAfter(op);
+
+    // Store return values into the async values storage and switch async
+    // values state to available.
+    for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
+      Value returnValue = std::get<0>(tuple);
+      Value asyncValue = std::get<1>(tuple);
+      rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
+      rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
+    }
+
+    if (coro.asyncToken)
+      // Switch the coroutine completion token to available state.
+      rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
+
+    rewriter.eraseOp(op); // The cleanup branch below is the new terminator.
+    rewriter.create<cf::BranchOp>(loc, coro.cleanup);
+    return success();
+  }
+
+private:
+  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Convert async.await and async.await_all operations to the async.runtime.await
 // or async.runtime.await_and_resume operations.
@@ -393,11 +509,9 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
   using AwaitAdaptor = typename AwaitType::Adaptor;
 
 public:
-  AwaitOpLoweringBase(
-      MLIRContext *ctx,
-      llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
-      : OpConversionPattern<AwaitType>(ctx),
-        outlinedFunctions(outlinedFunctions) {}
+  AwaitOpLoweringBase(MLIRContext *ctx,
+                      llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+      : OpConversionPattern<AwaitType>(ctx), coros(coros) {}
 
   LogicalResult
   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
@@ -409,8 +523,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
 
     // Check if await operation is inside the outlined coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto outlined = outlinedFunctions.find(func);
-    const bool isInCoroutine = outlined != outlinedFunctions.end();
+    auto funcCoro = coros.find(func);
+    const bool isInCoroutine = funcCoro != coros.end();
 
     Location loc = op->getLoc();
     Value operand = adaptor.getOperand();
@@ -436,7 +550,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     // Inside the coroutine we convert await operation into coroutine suspension
     // point, and resume execution asynchronously.
     if (isInCoroutine) {
-      CoroMachinery &coro = outlined->getSecond();
+      CoroMachinery &coro = funcCoro->getSecond();
       Block *suspended = op->getBlock();
 
       ImplicitLocOpBuilder builder(loc, op, rewriter.getListener());
@@ -488,7 +602,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
+  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
 };
 
 /// Lowering for `async.await` with a token operand.
@@ -531,24 +645,22 @@ class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
 
 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
 public:
-  YieldOpLowering(
-      MLIRContext *ctx,
-      const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
-      : OpConversionPattern<async::YieldOp>(ctx),
-        outlinedFunctions(outlinedFunctions) {}
+  YieldOpLowering(MLIRContext *ctx,
+                  const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+      : OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
 
   LogicalResult
   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if yield operation is inside the async coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto outlined = outlinedFunctions.find(func);
-    if (outlined == outlinedFunctions.end())
+    auto funcCoro = coros.find(func);
+    if (funcCoro == coros.end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
     Location loc = op->getLoc();
-    const CoroMachinery &coro = outlined->getSecond();
+    const CoroMachinery &coro = funcCoro->getSecond();
 
     // Store yielded values into the async values storage and switch async
     // values state to available.
@@ -559,14 +671,18 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
     }
 
-    // Switch the coroutine completion token to available state.
-    rewriter.replaceOpWithNewOp<RuntimeSetAvailableOp>(op, coro.asyncToken);
+    if (coro.asyncToken)
+      // Switch the coroutine completion token to available state.
+      rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
+
+    rewriter.eraseOp(op);
+    rewriter.create<cf::BranchOp>(loc, coro.cleanup);
 
     return success();
   }
 
 private:
-  const llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
+  const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
 };
 
 //===----------------------------------------------------------------------===//
@@ -575,24 +691,22 @@ class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
 
 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
 public:
-  AssertOpLowering(
-      MLIRContext *ctx,
-      llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions)
-      : OpConversionPattern<cf::AssertOp>(ctx),
-        outlinedFunctions(outlinedFunctions) {}
+  AssertOpLowering(MLIRContext *ctx,
+                   llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
+      : OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if assert operation is inside the async coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto outlined = outlinedFunctions.find(func);
-    if (outlined == outlinedFunctions.end())
+    auto funcCoro = coros.find(func);
+    if (funcCoro == coros.end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
     Location loc = op->getLoc();
-    CoroMachinery &coro = outlined->getSecond();
+    CoroMachinery &coro = funcCoro->getSecond();
 
     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
     rewriter.setInsertionPointToEnd(cont->getPrevNode());
@@ -607,7 +721,7 @@ class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &outlinedFunctions;
+  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
 };
 
 //===----------------------------------------------------------------------===//
@@ -615,22 +729,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   ModuleOp module = getOperation();
   SymbolTable symbolTable(module);
 
-  // Outline all `async.execute` body regions into async functions (coroutines).
-  llvm::DenseMap<func::FuncOp, CoroMachinery> outlinedFunctions;
+  // Functions with coroutine CFG setups, which are results of outlining
+  // `async.execute` body regions and converting async.func.
+  llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
 
   module.walk([&](ExecuteOp execute) {
-    outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute));
+    coros.insert(outlineExecuteOp(symbolTable, execute));
   });
 
   LLVM_DEBUG({
-    llvm::dbgs() << "Outlined " << outlinedFunctions.size()
+    llvm::dbgs() << "Outlined " << coros.size()
                  << " functions built from async.execute operations\n";
   });
 
   // Returns true if operation is inside the coroutine.
   auto isInCoroutine = [&](Operation *op) -> bool {
     auto parentFunc = op->getParentOfType<func::FuncOp>();
-    return outlinedFunctions.find(parentFunc) != outlinedFunctions.end();
+    return coros.find(parentFunc) != coros.end();
   };
 
   // Lower async operations to async.runtime operations.
@@ -646,18 +761,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   // Async lowering does not use type converter because it must preserve all
   // types for async.runtime operations.
   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
+
+  // Lower async.func to func.func with coroutine cfg.
+  asyncPatterns.add<AsyncCallOpLowering>(ctx);
+  asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
+
   asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
-                    AwaitAllOpLowering, YieldOpLowering>(ctx,
-                                                         outlinedFunctions);
+                    AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
 
   // Lower assertions to conditional branches into error blocks.
-  asyncPatterns.add<AssertOpLowering>(ctx, outlinedFunctions);
+  asyncPatterns.add<AssertOpLowering>(ctx, coros);
 
   // All high level async operations must be lowered to the runtime operations.
   ConversionTarget runtimeTarget(*ctx);
-  runtimeTarget.addLegalDialect<AsyncDialect>();
+  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
-  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
+  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
+                             async::FuncOp, async::CallOp, async::ReturnOp>();
 
   // Decide if structured control flow has to be lowered to branch-based CFG.
   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
@@ -675,7 +795,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
       [&](cf::AssertOp op) -> bool {
         auto func = op->getParentOfType<func::FuncOp>();
-        return outlinedFunctions.find(func) == outlinedFunctions.end();
+        return coros.find(func) == coros.end();
       });
 
   if (failed(applyPartialConversion(module, runtimeTarget,

diff  --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
index d7ebfb9e77926..1551e55c90c08 100644
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -433,3 +433,25 @@ func.func @clone_constants(%arg0: f32, %arg1: memref<1xf32>) {
 // CHECK-SAME:  ) -> !async.token
 // CHECK:         %[[CST:.*]] = arith.constant 0 : index
 // CHECK:         memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
+
+// -----
+// Async functions should be non-blocking.
+
+// CHECK-LABEL: @async_func_await
+async.func @async_func_await(%arg0: f32, %arg1: !async.value<f32>)
+              -> !async.token {
+  %0 = async.await %arg1 : !async.value<f32> // Should suspend, not block.
+  return
+}
+// Create token for return op, and mark a function as a coroutine.
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token
+// CHECK: %[[ID:.*]] = async.coro.id
+// CHECK: %[[HDL:.*]] = async.coro.begin
+// CHECK:   cf.br ^[[ORIGIN_ENTRY:.*]]
+
+// CHECK: ^[[ORIGIN_ENTRY]]:
+// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]]
+// CHECK: async.runtime.await_and_resume %[[arg1:.*]], %[[HDL]] :
+// CHECK-SAME: !async.value<f32>
+// CHECK: async.coro.suspend %[[SAVED]]
+// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]

diff  --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
new file mode 100644
index 0000000000000..8b3d728d4667f
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -0,0 +1,149 @@
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext  \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext    \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_async_runtime%shlibext   \
+// RUN: | FileCheck %s --dump-input=always
+
+// FIXME: https://github.com/llvm/llvm-project/issues/57231
+// UNSUPPORTED: hwasan
+
+async.func @async_func_empty() -> !async.token { // Completes without error.
+  return
+}
+
+async.func @async_func_assert() -> !async.token {
+  %false = arith.constant 0 : i1
+  cf.assert %false, "error" // Always fails, so the token is set to error.
+  return
+}
+
+async.func @async_func_nested_assert() -> !async.token {
+  %token0 = async.call @async_func_assert() : () -> !async.token
+  async.await %token0 : !async.token // Propagates the callee's error.
+  return
+}
+
+async.func @async_func_value_assert() -> !async.value<f32> {
+  %false = arith.constant 0 : i1
+  cf.assert %false, "error" // Always fails, so the value is set to error.
+  %0 = arith.constant 123.45 : f32
+  return %0 : f32
+}
+
+async.func @async_func_value_nested_assert() -> !async.value<f32> {
+  %value0 = async.call @async_func_value_assert() : () -> !async.value<f32>
+  %ret = async.await %value0 : !async.value<f32> // Propagates the error.
+  return %ret : f32
+}
+
+async.func @async_func_return_value() -> !async.value<f32> {
+  %0 = arith.constant 456.789 : f32
+  return %0 : f32
+}
+
+async.func @async_func_non_blocking_await() -> !async.value<f32> {
+  %value0 = async.call @async_func_return_value() : () -> !async.value<f32>
+  %1 = async.await %value0 : !async.value<f32> // Suspends, does not block.
+  return %1 : f32
+}
+
+async.func @async_func_inside_memref() -> !async.value<memref<f32>> {
+  %0 = memref.alloc() : memref<f32>
+  %c0 = arith.constant 0.25 : f32
+  memref.store %c0, %0[] : memref<f32>
+  return %0 : memref<f32>
+}
+
+async.func @async_func_passed_memref(%arg0 : !async.value<memref<f32>>) -> !async.token {
+  %unwrapped = async.await %arg0 : !async.value<memref<f32>>
+  %0 = memref.load %unwrapped[] : memref<f32>
+  %1 = arith.addf %0, %0 : f32 // Doubles the stored value (0.25 -> 0.5).
+  memref.store %1, %unwrapped[] : memref<f32>
+  return
+}
+
+
+func.func @main() {
+  // Each section calls an async.func and prints what the caller observes.
+
+  // ------------------------------------------------------------------------ //
+  // Check that simple async.func completes without errors.
+  // ------------------------------------------------------------------------ //
+  %token0 = async.call @async_func_empty() : () -> !async.token
+  async.runtime.await %token0 : !async.token
+
+  // CHECK: 0
+  %err0 = async.runtime.is_error %token0 : !async.token
+  vector.print %err0 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check that an assertion in the async.func is converted to an async error.
+  // ------------------------------------------------------------------------ //
+  %token1 = async.call @async_func_assert() : () -> !async.token
+  async.runtime.await %token1 : !async.token
+
+  // CHECK: 1
+  %err1 = async.runtime.is_error %token1 : !async.token
+  vector.print %err1 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check error propagation from the nested async.func.
+  // ------------------------------------------------------------------------ //
+  %token2 = async.call @async_func_nested_assert() : () -> !async.token
+  async.runtime.await %token2 : !async.token
+
+  // CHECK: 1
+  %err2 = async.runtime.is_error %token2 : !async.token
+  vector.print %err2 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Check error propagation from the nested async.func with async values.
+  // ------------------------------------------------------------------------ //
+  %value3 = async.call @async_func_value_nested_assert() : () -> !async.value<f32>
+  async.runtime.await %value3 : !async.value<f32>
+
+  // CHECK: 1
+  %err3_0 = async.runtime.is_error %value3 : !async.value<f32>
+  vector.print %err3_0 : i1
+
+  // ------------------------------------------------------------------------ //
+  // Non-blocking async.await inside the async.func
+  // ------------------------------------------------------------------------ //
+  %result0 = async.call @async_func_non_blocking_await() : () -> !async.value<f32>
+  %4 = async.await %result0 : !async.value<f32>
+
+  // CHECK: 456.789
+  vector.print %4 : f32
+
+  // ------------------------------------------------------------------------ //
+  // Memref allocated inside async.func.
+  // ------------------------------------------------------------------------ //
+  %result1 = async.call @async_func_inside_memref() : () -> !async.value<memref<f32>>
+  %5 = async.await %result1 : !async.value<memref<f32>>
+  %6 = memref.cast %5 : memref<f32> to memref<*xf32>
+
+  // CHECK: Unranked Memref
+  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+  // CHECK-NEXT: [0.25]
+  call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
+  // ------------------------------------------------------------------------ //
+  // Memref passed as async.func parameter
+  // ------------------------------------------------------------------------ //
+  %token3 = async.call @async_func_passed_memref(%result1) : (!async.value<memref<f32>>) -> !async.token
+  async.await %token3 : !async.token
+
+  // CHECK: Unranked Memref
+  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+  // CHECK-NEXT: [0.5]
+  call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
+  memref.dealloc %5 : memref<f32>
+
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>) // Resolved from -shared-libs.
+  attributes { llvm.emit_c_interface }


        


More information about the Mlir-commits mailing list