[llvm] 4313376 - [coro] Async coroutines: Allow more than 3 arguments in the dispatch function

Arnold Schwaighofer via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 11 15:27:46 PST 2020


Author: Arnold Schwaighofer
Date: 2020-11-11T15:25:28-08:00
New Revision: 431337662ee01bedb2f2a45fba960bfc7388adb6

URL: https://github.com/llvm/llvm-project/commit/431337662ee01bedb2f2a45fba960bfc7388adb6
DIFF: https://github.com/llvm/llvm-project/commit/431337662ee01bedb2f2a45fba960bfc7388adb6.diff

LOG: [coro] Async coroutines: Allow more than 3 arguments in the dispatch function

We need to be able to call function pointers, so the dispatch function now
takes the callee as an extra argument. Inline the dispatch function at each
suspend point.

Also inline the context projection function.

Transfer debug locations from the suspend point to the inlined functions.

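Use the index of the async context argument, instead of the argument value
itself, as the storage operand of coro.id.async. This removes the spurious
("fake") use of the async context argument that the intrinsic call
previously introduced.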

Coerce the arguments of the tail-call function at each suspend point to that
function's parameter types. The LLVM optimizer seems to drop casts on values
passed to a varargs intrinsic (llvm.coro.suspend.async).

rdar://70097093

Differential Revision: https://reviews.llvm.org/D91098

Added: 
    

Modified: 
    llvm/include/llvm/IR/Intrinsics.td
    llvm/lib/Transforms/Coroutines/CoroFrame.cpp
    llvm/lib/Transforms/Coroutines/CoroInstr.h
    llvm/lib/Transforms/Coroutines/CoroSplit.cpp
    llvm/lib/Transforms/Coroutines/Coroutines.cpp
    llvm/test/Transforms/Coroutines/coro-async.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8ea27402decc..e0f3d67a62dd 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1188,7 +1188,7 @@ def int_coro_id_retcon_once : Intrinsic<[llvm_token_ty],
     []>;
 def int_coro_alloc : Intrinsic<[llvm_i1_ty], [llvm_token_ty], []>;
 def int_coro_id_async : Intrinsic<[llvm_token_ty],
-  [llvm_i32_ty, llvm_i32_ty, llvm_ptr_ty, llvm_ptr_ty],
+  [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_ptr_ty],
   []>;
 def int_coro_async_context_alloc : Intrinsic<[llvm_ptr_ty],
     [llvm_ptr_ty, llvm_ptr_ty],

diff  --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
index 5265977ea5a0..27064d2da5da 100644
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -1911,8 +1911,7 @@ static void sinkSpillUsesAfterCoroBegin(Function &F,
     for (User *U : Def->users()) {
       auto Inst = cast<Instruction>(U);
       if (Inst->getParent() != CoroBegin->getParent() ||
-          Dom.dominates(CoroBegin, Inst) ||
-          isa<CoroIdAsyncInst>(Inst) /*'fake' use of async context argument*/)
+          Dom.dominates(CoroBegin, Inst))
         continue;
       if (ToMove.insert(Inst))
         Worklist.push_back(Inst);

diff  --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h
index 5f6ff68b9254..1b8333922473 100644
--- a/llvm/lib/Transforms/Coroutines/CoroInstr.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h
@@ -293,11 +293,13 @@ class LLVM_LIBRARY_VISIBILITY CoroIdAsyncInst : public AnyCoroIdInst {
   }
 
   /// The async context parameter.
-  Value *getStorage() const { return getArgOperand(StorageArg); }
+  Value *getStorage() const {
+    return getParent()->getParent()->getArg(getStorageArgumentIndex());
+  }
 
   unsigned getStorageArgumentIndex() const {
-    auto *Arg = cast<Argument>(getArgOperand(StorageArg)->stripPointerCasts());
-    return Arg->getArgNo();
+    auto *Arg = cast<ConstantInt>(getArgOperand(StorageArg));
+    return Arg->getZExtValue();
   }
 
   /// Return the async function pointer address. This should be the address of

diff  --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index c13ead98a0e2..fd1f074921fb 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -668,16 +668,23 @@ Value *CoroCloner::deriveNewFramePointer() {
     auto *FramePtrTy = Shape.FrameTy->getPointerTo();
     auto *ProjectionFunc = cast<CoroSuspendAsyncInst>(ActiveSuspend)
                                ->getAsyncContextProjectionFunction();
+    auto DbgLoc =
+        cast<CoroSuspendAsyncInst>(VMap[ActiveSuspend])->getDebugLoc();
     // Calling i8* (i8*)
     auto *CallerContext = Builder.CreateCall(
         cast<FunctionType>(ProjectionFunc->getType()->getPointerElementType()),
         ProjectionFunc, CalleeContext);
     CallerContext->setCallingConv(ProjectionFunc->getCallingConv());
+    CallerContext->setDebugLoc(DbgLoc);
     // The frame is located after the async_context header.
     auto &Context = Builder.getContext();
     auto *FramePtrAddr = Builder.CreateConstInBoundsGEP1_32(
         Type::getInt8Ty(Context), CallerContext,
         Shape.AsyncLowering.FrameOffset, "async.ctx.frameptr");
+    // Inline the projection function.
+    InlineFunctionInfo InlineInfo;
+    auto InlineRes = InlineFunction(*CallerContext, InlineInfo);
+    assert(InlineRes.isSuccess());
     return Builder.CreateBitCast(FramePtrAddr, FramePtrTy);
   }
   // In continuation-lowering, the argument is the opaque storage.
@@ -1364,6 +1371,22 @@ static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend,
   Suspend->setOperand(0, UndefValue::get(Int8PtrTy));
 }
 
+/// Coerce the arguments in \p FnArgs according to \p FnTy in \p CallArgs.
+static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
+                            ArrayRef<Value *> FnArgs,
+                            SmallVectorImpl<Value *> &CallArgs) {
+  size_t ArgIdx = 0;
+  for (auto paramTy : FnTy->params()) {
+    assert(ArgIdx < FnArgs.size());
+    if (paramTy != FnArgs[ArgIdx]->getType())
+      CallArgs.push_back(
+          Builder.CreateBitOrPointerCast(FnArgs[ArgIdx], paramTy));
+    else
+      CallArgs.push_back(FnArgs[ArgIdx]);
+    ++ArgIdx;
+  }
+}
+
 static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
                                 SmallVectorImpl<Function *> &Clones) {
   assert(Shape.ABI == coro::ABI::Async);
@@ -1420,14 +1443,23 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
 
     IRBuilder<> Builder(ReturnBB);
 
-    // Insert the call to the tail call function.
-    auto *Fun = Suspend->getMustTailCallFunction();
+    // Insert the call to the tail call function and inline it.
+    auto *Fn = Suspend->getMustTailCallFunction();
+    auto DbgLoc = Suspend->getDebugLoc();
     SmallVector<Value *, 8> Args(Suspend->operand_values());
-    auto *TailCall = Builder.CreateCall(
-        cast<FunctionType>(Fun->getType()->getPointerElementType()), Fun,
-        ArrayRef<Value *>(Args).drop_front(3).drop_back(1));
-    TailCall->setTailCallKind(CallInst::TCK_MustTail);
-    TailCall->setCallingConv(Fun->getCallingConv());
+    auto FnArgs = ArrayRef<Value *>(Args).drop_front(3).drop_back(1);
+    auto FnTy = cast<FunctionType>(Fn->getType()->getPointerElementType());
+    // Coerce the arguments, llvm optimizations seem to ignore the types in
+    // vaarg functions and throws away casts in optimized mode.
+    SmallVector<Value *, 8> CallArgs;
+    coerceArguments(Builder, FnTy, FnArgs, CallArgs);
+    auto *TailCall = Builder.CreateCall(FnTy, Fn, CallArgs);
+    TailCall->setDebugLoc(DbgLoc);
+    TailCall->setTailCall();
+    TailCall->setCallingConv(Fn->getCallingConv());
+    InlineFunctionInfo FnInfo;
+    auto InlineRes = InlineFunction(*TailCall, FnInfo);
+    assert(InlineRes.isSuccess() && "Expected inlining to succeed");
     Builder.CreateRetVoid();
 
     // Replace the lvm.coro.async.resume intrisic call.

diff  --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index fc0de47eeacc..726899f9c04c 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -683,11 +683,12 @@ static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
 }
 
 void CoroIdAsyncInst::checkWellFormed() const {
-  // TODO: check that the StorageArg is a parameter of this function.
   checkConstantInt(this, getArgOperand(SizeArg),
                    "size argument to coro.id.async must be constant");
   checkConstantInt(this, getArgOperand(AlignArg),
                    "alignment argument to coro.id.async must be constant");
+  checkConstantInt(this, getArgOperand(StorageArg),
+                   "storage argument offset to coro.id.async must be constant");
   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
 }
 

diff  --git a/llvm/test/Transforms/Coroutines/coro-async.ll b/llvm/test/Transforms/Coroutines/coro-async.ll
index 4c75dba11c67..35b23bb33a6e 100644
--- a/llvm/test/Transforms/Coroutines/coro-async.ll
+++ b/llvm/test/Transforms/Coroutines/coro-async.ll
@@ -28,8 +28,9 @@ declare void @my_other_async_function(i8* %async.ctxt)
 }>
 
 ; Function that implements the dispatch to the callee function.
-define swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* %async.ctxt, %async.task* %task, %async.actor* %actor) {
-  musttail call swiftcc void @asyncSuspend(i8* %async.ctxt, %async.task* %task, %async.actor* %actor)
+define swiftcc void @my_async_function.my_other_async_function_fp.apply(i8* %fnPtr, i8* %async.ctxt, %async.task* %task, %async.actor* %actor) {
+  %callee = bitcast i8* %fnPtr to void(i8*, %async.task*, %async.actor*)*
+  tail call swiftcc void %callee(i8* %async.ctxt, %async.task* %task, %async.actor* %actor)
   ret void
 }
 
@@ -50,8 +51,7 @@ entry:
   %proj.1 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 0
   %proj.2 = getelementptr inbounds { i64, i64 }, { i64, i64 }* %tmp, i64 0, i32 1
 
-  %id = call token @llvm.coro.id.async(i32 128, i32 16,
-          i8* %async.ctxt,
+  %id = call token @llvm.coro.id.async(i32 128, i32 16, i32 0,
           i8* bitcast (<{i32, i32}>* @my_async_function_fp to i8*))
   %hdl = call i8* @llvm.coro.begin(token %id, i8* null)
   store i64 0, i64* %proj.1, align 8
@@ -78,11 +78,12 @@ entry:
   %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0
   store i8* %async.ctxt, i8** %callee_context.caller_context.addr
   %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8*
+  %callee = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8*
   %res = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
                                                   i8* %resume.func_ptr,
                                                   i8* %resume_proj_fun,
-                                                  void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
-                                                  i8* %callee_context, %async.task* %task, %async.actor *%actor)
+                                                  void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
+                                                  i8* %callee, i8* %callee_context, %async.task* %task, %async.actor *%actor)
 
   call void @llvm.coro.async.context.dealloc(i8* %callee_context)
   %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 1
@@ -126,7 +127,7 @@ entry:
 ; CHECK:   store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function.resume.0 to i8*), i8** [[RETURN_TO_CALLER_ADDR]]
 ; CHECK:   [[CALLER_CONTEXT_ADDR:%.*]] = bitcast i8* [[CALLEE_CTXT]] to i8**
 ; CHECK:   store i8* %async.ctxt, i8** [[CALLER_CONTEXT_ADDR]]
-; CHECK:   musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
+; CHECK:   tail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
 ; CHECK:   ret void
 ; CHECK: }
 
@@ -170,7 +171,7 @@ entry:
 define swiftcc void @my_async_function2(%async.task* %task, %async.actor* %actor, i8* %async.ctxt)  {
 entry:
 
-  %id = call token @llvm.coro.id.async(i32 128, i32 16, i8* %async.ctxt, i8* bitcast (<{i32, i32}>* @my_async_function2_fp to i8*))
+  %id = call token @llvm.coro.id.async(i32 128, i32 16, i32 2, i8* bitcast (<{i32, i32}>* @my_async_function2_fp to i8*))
   %hdl = call i8* @llvm.coro.begin(token %id, i8* null)
   ; setup callee context
   %arg0 = bitcast %async.task* %task to i8*
@@ -185,11 +186,12 @@ entry:
   %callee_context.caller_context.addr = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0, i32 0, i32 0
   store i8* %async.ctxt, i8** %callee_context.caller_context.addr
   %resume_proj_fun = bitcast i8*(i8*)* @resume_context_projection to i8*
+  %callee = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8*
   %res = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
                                                   i8* %resume.func_ptr,
                                                   i8* %resume_proj_fun,
-                                                  void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
-                                                  i8* %callee_context, %async.task* %task, %async.actor *%actor)
+                                                  void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
+                                                  i8* %callee, i8* %callee_context, %async.task* %task, %async.actor *%actor)
 
   %continuation_task_arg = extractvalue {i8*, i8*, i8*} %res, 0
   %task.2 =  bitcast i8* %continuation_task_arg to %async.task*
@@ -202,11 +204,12 @@ entry:
   %callee_context.caller_context.addr.1 = getelementptr inbounds %async.ctxt, %async.ctxt* %callee_context.0.1, i32 0, i32 0
   store i8* %async.ctxt, i8** %callee_context.caller_context.addr.1
   %resume_proj_fun.2 = bitcast i8*(i8*)* @resume_context_projection to i8*
+  %callee.2 = bitcast void(i8*, %async.task*, %async.actor*)* @asyncSuspend to i8*
   %res.2 = call {i8*, i8*, i8*} (i8*, i8*, ...) @llvm.coro.suspend.async(
                                                   i8* %resume.func_ptr.1,
                                                   i8* %resume_proj_fun.2,
-                                                  void (i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
-                                                  i8* %callee_context, %async.task* %task, %async.actor *%actor)
+                                                  void (i8*, i8*, %async.task*, %async.actor*)* @my_async_function.my_other_async_function_fp.apply,
+                                                  i8* %callee.2, i8* %callee_context, %async.task* %task, %async.actor *%actor)
 
   call void @llvm.coro.async.context.dealloc(i8* %callee_context)
   %continuation_actor_arg = extractvalue {i8*, i8*, i8*} %res.2, 1
@@ -225,7 +228,7 @@ entry:
 ; CHECK: store i8* [[CALLEE_CTXT]],
 ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function2.resume.0 to i8*),
 ; CHECK: store i8* %async.ctxt,
-; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
+; CHECK: tail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT]], %async.task* %task, %async.actor* %actor)
 ; CHECK: ret void
 
 ; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.0(i8* %0, i8* nocapture readnone %1, i8* nocapture readonly %2) {
@@ -235,7 +238,7 @@ entry:
 ; CHECK: [[CALLEE_CTXT_SPILL_ADDR2:%.*]] = bitcast i8* [[CALLEE_CTXT_SPILL_ADDR]] to i8**
 ; CHECK: store i8* bitcast (void (i8*, i8*, i8*)* @my_async_function2.resume.1 to i8*),
 ; CHECK: [[CALLLE_CTXT_RELOAD:%.*]] = load i8*, i8** [[CALLEE_CTXT_SPILL_ADDR2]]
-; CHECK: musttail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT_RELOAD]]
+; CHECK: tail call swiftcc void @asyncSuspend(i8* [[CALLEE_CTXT_RELOAD]]
 ; CHECK: ret void
 
 ; CHECK-LABEL: define internal swiftcc void @my_async_function2.resume.1(i8* nocapture readnone %0, i8* %1, i8* nocapture readonly %2) {
@@ -258,7 +261,7 @@ define swiftcc void @top_level_caller(i8* %ctxt, i8* %task, i8* %actor) {
 ; CHECK: ret void
 
 declare i8* @llvm.coro.prepare.async(i8*)
-declare token @llvm.coro.id.async(i32, i32, i8*, i8*)
+declare token @llvm.coro.id.async(i32, i32, i32, i8*)
 declare i8* @llvm.coro.begin(token, i8*)
 declare i1 @llvm.coro.end(i8*, i1)
 declare {i8*, i8*, i8*} @llvm.coro.suspend.async(i8*, i8*, ...)


        


More information about the llvm-commits mailing list