[llvm] c6543cc - llvm.coro.id.async lowering: Parameterize how to restore the current continuation's context and restart the pipeline after splitting

Arnold Schwaighofer via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 6 06:23:41 PST 2020


Author: Arnold Schwaighofer
Date: 2020-11-06T06:22:46-08:00
New Revision: c6543cc6b8f107b58e7205d8fc64865a508bacba

URL: https://github.com/llvm/llvm-project/commit/c6543cc6b8f107b58e7205d8fc64865a508bacba
DIFF: https://github.com/llvm/llvm-project/commit/c6543cc6b8f107b58e7205d8fc64865a508bacba.diff

LOG: llvm.coro.id.async lowering: Parameterize how to restore the current continuation's context and restart the pipeline after splitting

The `llvm.coro.suspend.async` intrinsic takes a function pointer as its
argument that describes how to restore the current continuation's
context from the context argument of the continuation function. Previously,
we assumed that the current context could be restored by loading from the
context argument's first pointer field (`first_arg->caller_context`).

This allows, for example, defining suspension points that reuse the
current context.

Also:

llvm.coro.id.async lowering: Add llvm.coro.prepare.async intrinsic

It blocks inlining of the async coroutine until after the coroutine has been
split.

Also, change the position of the context size field in the async function
pointer struct:

   struct async_function_pointer {
     uint32_t relative_function_pointer_to_async_impl;
     uint32_t context_size;
   }

And make the position of the `async context` argument configurable. The
position is specified by the `llvm.coro.id.async` intrinsic.

rdar://70097093

Differential Revision: https://reviews.llvm.org/D90783

Added: 
    

Modified: 
    llvm/docs/Coroutines.rst
    llvm/include/llvm/IR/Intrinsics.td
    llvm/lib/Transforms/Coroutines/CoroElide.cpp
    llvm/lib/Transforms/Coroutines/CoroInstr.h
    llvm/lib/Transforms/Coroutines/CoroInternal.h
    llvm/lib/Transforms/Coroutines/CoroSplit.cpp
    llvm/lib/Transforms/Coroutines/Coroutines.cpp
    llvm/test/Transforms/Coroutines/coro-async.ll

Removed: 
    


################################################################################
diff  --git a/llvm/docs/Coroutines.rst b/llvm/docs/Coroutines.rst
index 69e9802f386e..77fb77d9a967 100644
--- a/llvm/docs/Coroutines.rst
+++ b/llvm/docs/Coroutines.rst
@@ -181,30 +181,47 @@ In async-continuation lowering, signaled by the use of `llvm.coro.id.async`,
 handling of control-flow must be handled explicitly by the frontend.
 
 In this lowering, a coroutine is assumed to take the current `async context` as
-its first argument. It is used to marshal arguments and return values of the
-coroutine. Therefore a async coroutine returns `void`.
+one of its arguments (the argument position is determined by
+`llvm.coro.id.async`). It is used to marshal arguments and return values of the
+coroutine. Therefore an async coroutine returns `void`.
 
 .. code-block:: llvm
 
   define swiftcc void @async_coroutine(i8* %async.ctxt, i8*, i8*) {
   }
 
+Values that are live across a suspend point need to be stored in the coroutine
+frame to be available in the continuation function. This frame is stored as a
+tail to the `async context`.
 
-Every suspend point takes an `async context` argument which provides the context
-and the coroutine frame of the callee function. Every
-suspend point has an associated `resume function` denoted by the
-`llvm.coro.async.resume` intrinsic. The coroutine is resumed by
-calling this `resume function` passing the `async context` as the first
-argument. It is assumed that the `resume function` can restore its (the
-caller's) `async context` by loading the first field in the `async context`.
+Every suspend point takes a `context projection function` argument which
+describes how to obtain the continuation's `async context`, and every suspend
+point has an associated `resume function` denoted by the
+`llvm.coro.async.resume` intrinsic. The coroutine is resumed by calling this
+`resume function` passing the `async context` as one of its arguments.
+The `resume function` can restore its (the caller's) `async context`
+by applying a `context projection function` that is provided by the frontend as
+a parameter to the `llvm.coro.suspend.async` intrinsic.
 
 .. code-block:: c
 
+  // For example:
   struct async_context {
     struct async_context *caller_context;
     ...
   }
 
+  char *context_projection_function(struct async_context *callee_ctxt) {
+     return callee_ctxt->caller_context;
+  }
+
+.. code-block:: llvm
+
+  %resume_func_ptr = call i8* @llvm.coro.async.resume()
+  call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
+                                              i8* %resume_func_ptr,
+                                              i8* %context_projection_function, ...)
+
 The frontend should provide a `async function pointer` struct associated with
 each async coroutine by `llvm.coro.id.async`'s argument. The initial size and
 alignment of the `async context` must be provided as arguments to the
@@ -216,8 +233,8 @@ to obtain the required size.
 .. code-block:: c
 
   struct async_function_pointer {
-    uint32_t context_size;
     uint32_t relative_function_pointer_to_async_impl;
+    uint32_t context_size;
   }
 
 Lowering will split an async coroutine into a ramp function and one resume
@@ -231,6 +248,14 @@ to model the transfer to the callee function. It will be tail called by
 lowering and therefore must have the same signature and calling convention as
 the async coroutine.
 
+.. code-block:: llvm
+
+  call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
+                   i8* %resume_func_ptr,
+                   i8* %context_projection_function,
+                   i8* (bitcast void (i8*, i8*, i8*)* to i8*) %suspend_function,
+                   i8* %arg1, i8* %arg2, i8 %arg3)
+
 Coroutines by Example
 =====================
 
@@ -1482,10 +1507,11 @@ to the coroutine:
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 ::
 
-  declare {i8*, i8*, i8*} @llvm.coro.suspend.async(i8* <resume function>,
-                                                   i8* <callee context>,
-                                                   ... <function to call>
-                                                   ... <arguments to function>)
+  declare {i8*, i8*, i8*} @llvm.coro.suspend.async(
+                             i8* <resume function>,
+                             i8* <context projection function>,
+                             ... <function to call>
+                             ... <arguments to function>)
 
 Overview:
 """""""""
@@ -1500,8 +1526,9 @@ The first argument should be the result of the `llvm.coro.async.resume` intrinsi
 Lowering will replace this intrinsic with the resume function for this suspend
 point.
 
-The second argument is the `async context` allocation for the callee. It should
-provide storage the `async context` header and the coroutine frame.
+The second argument is the `context projection function`. It should describe
+how to restore the `async context` in the continuation function from the first
+argument of the continuation function. Its type is `i8* (i8*)`.
 
 The third argument is the function that models tranfer to the callee at the
 suspend point. It should take 3 arguments. Lowering will `musttail` call this
@@ -1516,6 +1543,26 @@ The result of the intrinsic are mapped to the arguments of the resume function.
 Execution is suspended at this intrinsic and resumed when the resume function is
 called.
 
+.. _coro.prepare.async:
+
+'llvm.coro.prepare.async' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+::
+
+  declare i8* @llvm.coro.prepare.async(i8* <coroutine function>)
+
+Overview:
+"""""""""
+
+The '``llvm.coro.prepare.async``' intrinsic is used to block inlining of the
+async coroutine until after coroutine splitting.
+
+Arguments:
+""""""""""
+
+The first argument should be an async coroutine of type `void (i8*, i8*, i8*)`.
+Lowering will replace this intrinsic with its coroutine function argument.
+
 .. _coro.suspend.retcon:
 
 'llvm.coro.suspend.retcon' Intrinsic

diff  --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 85afbfe9ab3b..025160b96151 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1202,6 +1202,8 @@ def int_coro_async_resume : Intrinsic<[llvm_ptr_ty],
 def int_coro_suspend_async : Intrinsic<[llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty],
     [llvm_ptr_ty, llvm_ptr_ty, llvm_vararg_ty],
     []>;
+def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty],
+                                       [IntrNoMem]>;
 def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty],
                                [WriteOnly<ArgIndex<1>>]>;
 

diff  --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index 9d364b3097c1..a0b8168f8717 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -366,7 +366,7 @@ static bool replaceDevirtTrigger(Function &F) {
 }
 
 static bool declaresCoroElideIntrinsics(Module &M) {
-  return coro::declaresIntrinsics(M, {"llvm.coro.id"});
+  return coro::declaresIntrinsics(M, {"llvm.coro.id", "llvm.coro.id.async"});
 }
 
 PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {

diff  --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h
index 5dbb94d10985..5f6ff68b9254 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h
@@ -295,6 +295,11 @@ class LLVM_LIBRARY_VISIBILITY CoroIdAsyncInst : public AnyCoroIdInst {
   /// The async context parameter.
   Value *getStorage() const { return getArgOperand(StorageArg); }
 
+  unsigned getStorageArgumentIndex() const {
+    auto *Arg = cast<Argument>(getArgOperand(StorageArg)->stripPointerCasts());
+    return Arg->getArgNo();
+  }
+
   /// Return the async function pointer address. This should be the address of
   /// a async function pointer struct for the current async function.
   /// struct async_function_pointer {
@@ -504,11 +509,14 @@ inline CoroSaveInst *AnyCoroSuspendInst::getCoroSave() const {
 
 /// This represents the llvm.coro.suspend.async instruction.
 class LLVM_LIBRARY_VISIBILITY CoroSuspendAsyncInst : public AnyCoroSuspendInst {
-  enum { ResumeFunctionArg, AsyncContextArg, MustTailCallFuncArg };
+  enum { ResumeFunctionArg, AsyncContextProjectionArg, MustTailCallFuncArg };
 
 public:
-  Value *getAsyncContext() const {
-    return getArgOperand(AsyncContextArg)->stripPointerCasts();
+  void checkWellFormed() const;
+
+  Function *getAsyncContextProjectionFunction() const {
+    return cast<Function>(
+        getArgOperand(AsyncContextProjectionArg)->stripPointerCasts());
   }
 
   CoroAsyncResumeInst *getResumeFunction() const {

diff  --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
index 855c5c4b582d..e2f129c38d92 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -34,10 +34,12 @@ void initializeCoroCleanupLegacyPass(PassRegistry &);
 // CoroElide pass that triggers a restart of the pipeline by CGPassManager.
 // When CoroSplit pass sees the same coroutine the second time, it splits it up,
 // adds coroutine subfunctions to the SCC to be processed by IPO pipeline.
-
+// Async lowering similarly triggers a restart of the pipeline after it has
+// split the coroutine.
 #define CORO_PRESPLIT_ATTR "coroutine.presplit"
 #define UNPREPARED_FOR_SPLIT "0"
 #define PREPARED_FOR_SPLIT "1"
+#define ASYNC_RESTART_AFTER_SPLIT "2"
 
 #define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger"
 
@@ -141,6 +143,7 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
   struct AsyncLoweringStorage {
     FunctionType *AsyncFuncTy;
     Value *Context;
+    unsigned ContextArgNo;
     uint64_t ContextHeaderSize;
     uint64_t ContextAlignment;
     uint64_t FrameOffset; // Start of the frame.

diff  --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 1f4c2b0a8cd9..e46916b2d24b 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -654,20 +654,23 @@ Value *CoroCloner::deriveNewFramePointer() {
   // In switch-lowering, the argument is the frame pointer.
   case coro::ABI::Switch:
     return &*NewF->arg_begin();
+  // In async-lowering, one of the arguments is an async context as determined
+  // by the `llvm.coro.id.async` intrinsic. We can retrieve the async context of
+  // the resume function from the async context projection function associated
+  // with the active suspend. The frame is located as a tail to the async
+  // context header.
   case coro::ABI::Async: {
-    auto *CalleeContext = &*NewF->arg_begin();
+    auto *CalleeContext = NewF->getArg(Shape.AsyncLowering.ContextArgNo);
     auto *FramePtrTy = Shape.FrameTy->getPointerTo();
-    // The caller context is assumed to be stored at the begining of the callee
-    // context.
-    // struct async_context {
-    //    struct async_context *caller;
-    //    ...
-    auto &Context = Builder.getContext();
-    auto *Int8PtrPtrTy = Type::getInt8PtrTy(Context)->getPointerTo();
-    auto *CallerContextAddr =
-        Builder.CreateBitOrPointerCast(CalleeContext, Int8PtrPtrTy);
-    auto *CallerContext = Builder.CreateLoad(CallerContextAddr);
+    auto *ProjectionFunc = cast<CoroSuspendAsyncInst>(ActiveSuspend)
+                               ->getAsyncContextProjectionFunction();
+    // Calling i8* (i8*)
+    auto *CallerContext = Builder.CreateCall(
+        cast<FunctionType>(ProjectionFunc->getType()->getPointerElementType()),
+        ProjectionFunc, CalleeContext);
+    CallerContext->setCallingConv(ProjectionFunc->getCallingConv());
     // The frame is located after the async_context header.
+    auto &Context = Builder.getContext();
     auto *FramePtrAddr = Builder.CreateConstInBoundsGEP1_32(
         Type::getInt8Ty(Context), CallerContext,
         Shape.AsyncLowering.FrameOffset, "async.ctx.frameptr");
@@ -871,12 +874,12 @@ static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) {
 
   auto *FuncPtrStruct = cast<ConstantStruct>(
       Shape.AsyncLowering.AsyncFuncPointer->getInitializer());
-  auto *OrigContextSize = FuncPtrStruct->getOperand(0);
-  auto *OrigRelativeFunOffset = FuncPtrStruct->getOperand(1);
+  auto *OrigRelativeFunOffset = FuncPtrStruct->getOperand(0);
+  auto *OrigContextSize = FuncPtrStruct->getOperand(1);
   auto *NewContextSize = ConstantInt::get(OrigContextSize->getType(),
                                           Shape.AsyncLowering.ContextSize);
   auto *NewFuncPtrStruct = ConstantStruct::get(
-      FuncPtrStruct->getType(), NewContextSize, OrigRelativeFunOffset);
+      FuncPtrStruct->getType(), OrigRelativeFunOffset, NewContextSize);
 
   Shape.AsyncLowering.AsyncFuncPointer->setInitializer(NewFuncPtrStruct);
 }
@@ -1671,7 +1674,10 @@ static void updateCallGraphAfterCoroutineSplit(
 // When we see the coroutine the first time, we insert an indirect call to a
 // devirt trigger function and mark the coroutine that it is now ready for
 // split.
-static void prepareForSplit(Function &F, CallGraph &CG) {
+// Async lowering uses this after it has split the function to restart the
+// pipeline.
+static void prepareForSplit(Function &F, CallGraph &CG,
+                            bool MarkForAsyncRestart = false) {
   Module &M = *F.getParent();
   LLVMContext &Context = F.getContext();
 #ifndef NDEBUG
@@ -1679,7 +1685,9 @@ static void prepareForSplit(Function &F, CallGraph &CG) {
   assert(DevirtFn && "coro.devirt.trigger function not found");
 #endif
 
-  F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
+  F.addFnAttr(CORO_PRESPLIT_ATTR, MarkForAsyncRestart
+                                      ? ASYNC_RESTART_AFTER_SPLIT
+                                      : PREPARED_FOR_SPLIT);
 
   // Insert an indirect call sequence that will be devirtualized by CoroElide
   // pass:
@@ -1687,7 +1695,9 @@ static void prepareForSplit(Function &F, CallGraph &CG) {
   //    %1 = bitcast i8* %0 to void(i8*)*
   //    call void %1(i8* null)
   coro::LowererBase Lowerer(M);
-  Instruction *InsertPt = F.getEntryBlock().getTerminator();
+  Instruction *InsertPt =
+      MarkForAsyncRestart ? F.getEntryBlock().getFirstNonPHIOrDbgOrLifetime()
+                          : F.getEntryBlock().getTerminator();
   auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
   auto *DevirtFnAddr =
       Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
@@ -1849,8 +1859,17 @@ static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) {
 }
 
 static bool declaresCoroSplitIntrinsics(const Module &M) {
-  return coro::declaresIntrinsics(
-      M, {"llvm.coro.begin", "llvm.coro.prepare.retcon"});
+  return coro::declaresIntrinsics(M, {"llvm.coro.begin",
+                                      "llvm.coro.prepare.retcon",
+                                      "llvm.coro.prepare.async"});
+}
+
+static void addPrepareFunction(const Module &M,
+                               SmallVectorImpl<Function *> &Fns,
+                               StringRef Name) {
+  auto *PrepareFn = M.getFunction(Name);
+  if (PrepareFn && !PrepareFn->use_empty())
+    Fns.push_back(PrepareFn);
 }
 
 PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
@@ -1866,10 +1885,10 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
   if (!declaresCoroSplitIntrinsics(M))
     return PreservedAnalyses::all();
 
-  // Check for uses of llvm.coro.prepare.retcon.
-  auto *PrepareFn = M.getFunction("llvm.coro.prepare.retcon");
-  if (PrepareFn && PrepareFn->use_empty())
-    PrepareFn = nullptr;
+  // Check for uses of llvm.coro.prepare.retcon/async.
+  SmallVector<Function *, 2> PrepareFns;
+  addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.retcon");
+  addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.async");
 
   // Find coroutines for processing.
   SmallVector<LazyCallGraph::Node *, 4> Coroutines;
@@ -1877,11 +1896,14 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
     if (N.getFunction().hasFnAttribute(CORO_PRESPLIT_ATTR))
       Coroutines.push_back(&N);
 
-  if (Coroutines.empty() && !PrepareFn)
+  if (Coroutines.empty() && PrepareFns.empty())
     return PreservedAnalyses::all();
 
-  if (Coroutines.empty())
-    replaceAllPrepares(PrepareFn, CG, C);
+  if (Coroutines.empty()) {
+    for (auto *PrepareFn : PrepareFns) {
+      replaceAllPrepares(PrepareFn, CG, C);
+    }
+  }
 
   // Split all the coroutines.
   for (LazyCallGraph::Node *N : Coroutines) {
@@ -1911,10 +1933,18 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
     SmallVector<Function *, 4> Clones;
     const coro::Shape Shape = splitCoroutine(F, Clones, ReuseFrameSlot);
     updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);
+
+    if (Shape.ABI == coro::ABI::Async && !Shape.CoroSuspends.empty()) {
+      // We want the inliner to be run on the newly inserted functions.
+      UR.CWorklist.insert(&C);
+    }
   }
 
-  if (PrepareFn)
-    replaceAllPrepares(PrepareFn, CG, C);
+  if (!PrepareFns.empty()) {
+    for (auto *PrepareFn : PrepareFns) {
+      replaceAllPrepares(PrepareFn, CG, C);
+    }
+  }
 
   return PreservedAnalyses::none();
 }
@@ -1952,10 +1982,10 @@ struct CoroSplitLegacy : public CallGraphSCCPass {
       return false;
 
     // Check for uses of llvm.coro.prepare.retcon.
-    auto PrepareFn =
-      SCC.getCallGraph().getModule().getFunction("llvm.coro.prepare.retcon");
-    if (PrepareFn && PrepareFn->use_empty())
-      PrepareFn = nullptr;
+    SmallVector<Function *, 2> PrepareFns;
+    auto &M = SCC.getCallGraph().getModule();
+    addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.retcon");
+    addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.async");
 
     // Find coroutines for processing.
     SmallVector<Function *, 4> Coroutines;
@@ -1964,13 +1994,17 @@ struct CoroSplitLegacy : public CallGraphSCCPass {
         if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
           Coroutines.push_back(F);
 
-    if (Coroutines.empty() && !PrepareFn)
+    if (Coroutines.empty() && PrepareFns.empty())
       return false;
 
     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
 
-    if (Coroutines.empty())
-      return replaceAllPrepares(PrepareFn, CG);
+    if (Coroutines.empty()) {
+      bool Changed = false;
+      for (auto *PrepareFn : PrepareFns)
+        Changed |= replaceAllPrepares(PrepareFn, CG);
+      return Changed;
+    }
 
     createDevirtTriggerFunc(CG, SCC);
 
@@ -1980,6 +2014,12 @@ struct CoroSplitLegacy : public CallGraphSCCPass {
       StringRef Value = Attr.getValueAsString();
       LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                         << "' state: " << Value << "\n");
+      // Async lowering marks coroutines to trigger a restart of the pipeline
+      // after it has split them.
+      if (Value == ASYNC_RESTART_AFTER_SPLIT) {
+        F->removeFnAttr(CORO_PRESPLIT_ATTR);
+        continue;
+      }
       if (Value == UNPREPARED_FOR_SPLIT) {
         prepareForSplit(*F, CG);
         continue;
@@ -1989,9 +2029,15 @@ struct CoroSplitLegacy : public CallGraphSCCPass {
       SmallVector<Function *, 4> Clones;
       const coro::Shape Shape = splitCoroutine(*F, Clones, ReuseFrameSlot);
       updateCallGraphAfterCoroutineSplit(*F, Shape, Clones, CG, SCC);
+      if (Shape.ABI == coro::ABI::Async) {
+        // Restart SCC passes.
+        // Mark function for CoroElide pass. It will devirtualize causing a
+        // restart of the SCC pipeline.
+        prepareForSplit(*F, CG, true /*MarkForAsyncRestart*/);
+      }
     }
 
-    if (PrepareFn)
+    for (auto *PrepareFn : PrepareFns)
       replaceAllPrepares(PrepareFn, CG);
 
     return true;

diff  --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index 932bd4015993..fc0de47eeacc 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -139,6 +139,7 @@ static bool isCoroutineIntrinsicName(StringRef Name) {
       "llvm.coro.id.retcon.once",
       "llvm.coro.noop",
       "llvm.coro.param",
+      "llvm.coro.prepare.async",
       "llvm.coro.prepare.retcon",
       "llvm.coro.promise",
       "llvm.coro.resume",
@@ -276,6 +277,7 @@ void coro::Shape::buildFrom(Function &F) {
         break;
       case Intrinsic::coro_suspend_async: {
         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
+        Suspend->checkWellFormed();
         CoroSuspends.push_back(Suspend);
         break;
       }
@@ -386,6 +388,7 @@ void coro::Shape::buildFrom(Function &F) {
     AsyncId->checkWellFormed();
     this->ABI = coro::ABI::Async;
     this->AsyncLowering.Context = AsyncId->getStorage();
+    this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
     this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
     this->AsyncLowering.ContextAlignment =
         AsyncId->getStorageAlignment().value();
@@ -688,6 +691,27 @@ void CoroIdAsyncInst::checkWellFormed() const {
   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
 }
 
+static void checkAsyncContextProjectFunction(const Instruction *I,
+                                             Function *F) {
+  auto *FunTy = cast<FunctionType>(F->getType()->getPointerElementType());
+  if (!FunTy->getReturnType()->isPointerTy() ||
+      !FunTy->getReturnType()->getPointerElementType()->isIntegerTy(8))
+    fail(I,
+         "llvm.coro.suspend.async resume function projection function must "
+         "return an i8* type",
+         F);
+  if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy() ||
+      !FunTy->getParamType(0)->getPointerElementType()->isIntegerTy(8))
+    fail(I,
+         "llvm.coro.suspend.async resume function projection function must "
+         "take one i8* type as parameter",
+         F);
+}
+
+void CoroSuspendAsyncInst::checkWellFormed() const {
+  checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
+}
+
 void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
   unwrap(PM)->add(createCoroEarlyLegacyPass());
 }

diff  --git a/llvm/test/Transforms/Coroutines/coro-async.ll b/llvm/test/Transforms/Coroutines/coro-async.ll
index 79e5b66dd9cf..4c75dba11c67 100644
--- a/llvm/test/Transforms/Coroutines/coro-async.ll
+++ b/llvm/test/Transforms/Coroutines/coro-async.ll
@@ -13,18 +13,19 @@ target datalayout = "p:64:64:64"
 declare void @my_other_async_function(i8* %async.ctxt)
 
 ; The current async function (the caller).
-; This struct describes an async function. The first field is the size needed
-; for the async context of the current async function, the second field is the
-; relative offset to the async function implementation.
+; This struct describes an async function. The first field is the
+; relative offset to the async function implementation, the second field is the
+; size needed for the async context of the current async function.
+
 @my_async_function_fp = constant <{ i32, i32 }>
-  <{ i32 128,    ; Initial async context size without space for frame
-     i32 trunc ( ; Relative pointer to async function
+  <{ i32 trunc ( ; Relative pointer to async function
        i64 sub (
          i64 ptrtoint (void (i8*, %async.task*, %async.actor*)* @my_async_function to i64),
          i64 ptrtoint (i32* getelementptr inbounds (<{ i32, i32 }>, <{ i32, i32 }>* @my_async_function_fp, i32 0, i32 1) to i64)
        )
-     to i32)
-  }>
+     to i32),
+     i32 128    ; Initial async context size without space for frame
+}>
 
 ; Function that implements the dispatch to the callee function.
 define swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) {
@@ -35,13 +36,23 @@ define swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* %asy
 declare void @some_user(i64)
 declare void @some_may_write(i64*)
 
+define i8* @resume_context_projection(i8* %ctxt) {
+entry:
+  %resume_ctxt_addr = bitcast i8* %ctxt to i8**
+  %resume_ctxt = load i8*, i8** %resume_ctxt_addr, align 8
+  ret i8* %resume_ctxt
+}
+
+
 define swiftcc void @my_async_function(i8* %async.ctxt, %async.task* %task, %async.actor* %actor)  {
 entry:
   %tmp = alloca { i64, i64 }, align 8
   %proj.1 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 0
   %proj.2 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 1
 
-  %id = call token @llvm.coro.id.async(i32 128, i32 16, i8* %async.ctxt, i8* bitcast (<{i32, i32}>* @my_async_function_fp to i8*))
+  %id = call token @llvm.coro.id.async(i32 128, i32 16,
+          i8* %async.ctxt,
+          i8* bitcast (<{i32, i32}>* @my_async_function_fp to i8*))
   %hdl = call i8* @llvm.coro.begin(token %id, i8* null)
   store i64 0, i64* %proj.1, align 8
   store i64 1, i64* %proj.2, align 8
@@ -66,10 +77,10 @@ entry:
   ; store caller context into callee context
   %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0
   store i8* %async.ctxt, i8** %callee_context.caller_context.addr
-
+  %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8*
   %res = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
                                                   i8* %resume.func_ptr,
-                                                  i8* %callee_context,
+                                                  i8* %resume_proj_fun,
                                                   void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
                                                   i8* %callee_context, %async.task* %task, %async.actor *%actor)
 
@@ -87,8 +98,8 @@ entry:
 }
 
 ; Make sure we update the async function pointer
-; CHECK: @my_async_function_fp = constant <{ i32, i32 }> <{ i32 168,
-; CHECK: @my_async_function2_fp = constant <{ i32, i32 }> <{ i32 168,
+; CHECK: @my_async_function_fp = constant <{ i32, i32 }> <{ {{.*}}, i32 168 }
+; CHECK: @my_async_function2_fp = constant <{ i32, i32 }> <{ {{.*}}, i32 168 }
 
 ; CHECK-LABEL: define swiftcc void @my_async_function(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) {
 ; CHECK: entry:
@@ -115,11 +126,11 @@ entry:
 ; CHECK:   store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function.resume.0 to i8*), i8** [[RETURN_TO_CALLER_ADDR]]
 ; CHECK:   [[CALLER_CONTEXT_ADDR:%.*]] = bitcast i8* [[CALLEE_CTXT]] to i8**
 ; CHECK:   store i8* %async.ctxt, i8** [[CALLER_CONTEXT_ADDR]]
-; CHECK:   musttail call swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
+; CHECK:   musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
 ; CHECK:   ret void
 ; CHECK: }
 
-; CHECK-LABEL: define internal swiftcc void @my_async_function.resume.0(i8* %0, i8* %1, i8* %2) {
+; CHECK-LABEL: define internal swiftcc void @my_async_function.resume.0(i8* nocapture readonly %0, i8* %1, i8* nocapture readnone %2) {
 ; CHECK: entryresume.0:
 ; CHECK:   [[CALLER_CONTEXT_ADDR:%.*]] = bitcast i8* %0 to i8**
 ; CHECK:   [[CALLER_CONTEXT:%.*]] = load i8*, i8** [[CALLER_CONTEXT_ADDR]]
@@ -130,7 +141,7 @@ entry:
 ; CHECK:   [[ACTOR_RELOAD_ADDR:%.*]] = getelementptr inbounds i8, i8* [[CALLER_CONTEXT]], i64 152
 ; CHECK:   [[CAST2:%.*]] = bitcast i8* [[ACTOR_RELOAD_ADDR]] to %async.actor**
 ; CHECK:   [[ACTOR_RELOAD:%.*]] = load %async.actor*, %async.actor** [[CAST2]]
-; CHECK:   [[ADDR1:%.*]] = getelementptr inbounds i8, i8* %4, i64 144
+; CHECK:   [[ADDR1:%.*]] = getelementptr inbounds i8, i8* [[CALLER_CONTEXT]], i64 144
 ; CHECK:   [[ASYNC_CTXT_RELOAD_ADDR:%.*]] = bitcast i8* [[ADDR1]] to i8**
 ; CHECK:   [[ASYNC_CTXT_RELOAD:%.*]] = load i8*, i8** [[ASYNC_CTXT_RELOAD_ADDR]]
 ; CHECK:   [[ALLOCA_PRJ2:%.*]] = getelementptr inbounds i8, i8* [[CALLER_CONTEXT]], i64 136
@@ -147,16 +158,16 @@ entry:
 ; CHECK: }
 
 @my_async_function2_fp = constant <{ i32, i32 }>
-  <{ i32 128,    ; Initial async context size without space for frame
-     i32 trunc ( ; Relative pointer to async function
+  <{ i32 trunc ( ; Relative pointer to async function
        i64 sub (
-         i64 ptrtoint (void (i8*, %async.task*, %async.actor*)* @my_async_function2 to i64),
+         i64 ptrtoint (void (%async.task*, %async.actor*, i8*)* @my_async_function2 to i64),
          i64 ptrtoint (i32* getelementptr inbounds (<{ i32, i32 }>, <{ i32, i32 }>* @my_async_function2_fp, i32 0, i32 1) to i64)
        )
-     to i32)
+     to i32),
+     i32 128    ; Initial async context size without space for frame
   }>
 
-define swiftcc void @my_async_function2(i8* %async.ctxt, %async.task* %task, %async.actor* %actor)  {
+define swiftcc void @my_async_function2(%async.task* %task, %async.actor* %actor, i8* %async.ctxt)  {
 entry:
 
   %id = call token @llvm.coro.id.async(i32 128, i32 16, i8* %async.ctxt, i8* bitcast (<{i32, i32}>* @my_async_function2_fp to i8*))
@@ -173,13 +184,14 @@ entry:
   store i8* %resume.func_ptr, i8** %return_to_caller.addr
   %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0
   store i8* %async.ctxt, i8** %callee_context.caller_context.addr
+  %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8*
   %res = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
                                                   i8* %resume.func_ptr,
-                                                  i8* %callee_context,
+                                                  i8* %resume_proj_fun,
                                                   void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
                                                   i8* %callee_context, %async.task* %task, %async.actor *%actor)
 
-  %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 1
+  %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 0
   %task.2 =  bitcast i8* %continuation_task_arg to %async.task*
 
 	%callee_context.0.1 = bitcast i8* %callee_context to %async.ctxt*
@@ -189,14 +201,15 @@ entry:
   store i8* %resume.func_ptr.1, i8** %return_to_caller.addr.1
   %callee_context.caller_context.addr.1 = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0.1, i32 0, i32 0
   store i8* %async.ctxt, i8** %callee_context.caller_context.addr.1
+  %resume_proj_fun.2 = bitcast i8*(i8*)* @resume_context_projection to i8*
   %res.2 = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
                                                   i8* %resume.func_ptr.1,
-                                                  i8* %callee_context,
+                                                  i8* %resume_proj_fun.2,
                                                   void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
                                                   i8* %callee_context, %async.task* %task, %async.actor *%actor)
 
   call void @llvm.coro.async.context.dealloc(i8* %callee_context)
-  %continuation_actor_arg = extractvalue {i8*, i8*, i8*} %res.2, 2
+  %continuation_actor_arg = extractvalue {i8*, i8*, i8*} %res.2, 1
   %actor.2 =  bitcast i8* %continuation_actor_arg to %async.actor*
 
   tail call swiftcc void @asyncReturn(i8* %async.ctxt, %async.task* %task.2, %async.actor* %actor.2)
@@ -204,32 +217,47 @@ entry:
   unreachable
 }
 
-; CHECK-LABEL: define swiftcc void @my_async_function2(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) {
+; CHECK-LABEL: define swiftcc void @my_async_function2(%async.task* %task, %async.actor* %actor, i8* %async.ctxt) {
+; CHECK: store i8* %async.ctxt,
 ; CHECK: store %async.actor* %actor,
 ; CHECK: store %async.task* %task,
-; CHECK: store i8* %async.ctxt,
 ; CHECK: [[CALLEE_CTXT:%.*]] =  tail call i8* @llvm.coro.async.context.alloc(
 ; CHECK: store i8* [[CALLEE_CTXT]],
 ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function2.resume.0 to i8*),
 ; CHECK: store i8* %async.ctxt,
-; CHECK: musttail call swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
+; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
 ; CHECK: ret void
 
-; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.0(i8* %0, i8* %1, i8* %2) {
-; CHECK: [[CALLEE_CTXT_ADDR:%.*]] = bitcast i8* %0 to i8**
+; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.0(i8* %0, i8* nocapture readnone %1, i8* nocapture readonly %2) {
+; CHECK: [[CALLEE_CTXT_ADDR:%.*]] = bitcast i8* %2 to i8**
 ; CHECK: [[CALLEE_CTXT:%.*]] = load i8*, i8** [[CALLEE_CTXT_ADDR]]
 ; CHECK: [[CALLEE_CTXT_SPILL_ADDR:%.*]] = getelementptr inbounds i8, i8* [[CALLEE_CTXT]], i64 152
 ; CHECK: [[CALLEE_CTXT_SPILL_ADDR2:%.*]] = bitcast i8* [[CALLEE_CTXT_SPILL_ADDR]] to i8**
 ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function2.resume.1 to i8*),
 ; CHECK: [[CALLEE_CTXT_RELOAD:%.*]] = load i8*, i8** [[CALLEE_CTXT_SPILL_ADDR2]]
-; CHECK: musttail call swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* [[CALLEE_CTXT_RELOAD]]
+; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT_RELOAD]]
 ; CHECK: ret void
 
-; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.1(i8* %0, i8* %1, i8* %2) {
-; CHECK: [[ACTOR_ARG:%.*]] = bitcast i8* %2
+; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.1(i8* nocapture readnone %0, i8* %1, i8* nocapture readonly %2) {
+; CHECK: bitcast i8* %2 to i8**
+; CHECK: [[ACTOR_ARG:%.*]] = bitcast i8* %1
 ; CHECK: tail call swiftcc void @asyncReturn({{.*}}[[ACTOR_ARG]])
 ; CHECK: ret void
 
+define swiftcc void @top_level_caller(i8* %ctxt, i8* %task, i8* %actor) {
+  %prepare = call i8* @llvm.coro.prepare.async(i8* bitcast (void (i8*, %async.task*,  %async.actor*)* @my_async_function to i8*))
+  %f = bitcast i8* %prepare to void (i8*, i8*, i8*)*
+  call swiftcc void %f(i8* %ctxt, i8* %task, i8* %actor)
+  ret void
+}
+
+; CHECK-LABEL: define swiftcc void @top_level_caller(i8* %ctxt, i8* %task, i8* %actor)
+; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function.resume.0
+; CHECK: store i8* %ctxt
+; CHECK: tail call swiftcc void @asyncSuspend
+; CHECK: ret void
+
+declare i8* @llvm.coro.prepare.async(i8*)
 declare token @llvm.coro.id.async(i32, i32, i8*, i8*)
 declare i8* @llvm.coro.begin(token, i8*)
 declare i1 @llvm.coro.end(i8*, i1)
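For reference, the new `@my_async_function2_fp` constant above stores the relative function pointer first and the context size second, with the relative offset computed against the address of the *second* field (the `getelementptr ... i32 0, i32 1` base in the `sub`). A runtime consumer might resolve it roughly like this; a minimal C sketch, assuming the `async_function_pointer` layout from the commit message, with `resolve_async_impl` being an illustrative helper name, not part of the patch:

```c
#include <assert.h>
#include <stdint.h>

/* Mirrors the <{ i32, i32 }> layout after this patch:
   relative function pointer first, context size second. */
struct async_function_pointer {
  int32_t relative_function_pointer_to_async_impl;
  uint32_t context_size;
};

/* The IR constant stores (impl - &fp->context_size), i.e. the offset is
   relative to the second field, matching the getelementptr i32 0, i32 1
   base used in the sub expression above. */
static void *resolve_async_impl(struct async_function_pointer *fp) {
  return (char *)&fp->context_size + fp->relative_function_pointer_to_async_impl;
}
```
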


        

