[llvm] [Coroutines][NFC] Refactor CoroSplit for Switch Resume ABI (PR #80758)

Yuxuan Chen via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 15:04:27 PST 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/80758

>From f9b13c03b21bb4d96d23c3262f990a698c84483e Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <yuxuanchen1997 at outlook.com>
Date: Mon, 5 Feb 2024 14:28:11 -0800
Subject: [PATCH] put switch resume splitting in one struct

---
 llvm/lib/Transforms/Coroutines/CoroSplit.cpp | 471 ++++++++++---------
 1 file changed, 237 insertions(+), 234 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 7758b52abc204..a552b3ac20085 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -407,104 +407,6 @@ static void replaceCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape,
   End->eraseFromParent();
 }
 
-// Create an entry block for a resume function with a switch that will jump to
-// suspend points.
-static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
-  assert(Shape.ABI == coro::ABI::Switch);
-  LLVMContext &C = F.getContext();
-
-  // resume.entry:
-  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
-  //  i32 2
-  //  % index = load i32, i32* %index.addr
-  //  switch i32 %index, label %unreachable [
-  //    i32 0, label %resume.0
-  //    i32 1, label %resume.1
-  //    ...
-  //  ]
-
-  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
-  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
-
-  IRBuilder<> Builder(NewEntry);
-  auto *FramePtr = Shape.FramePtr;
-  auto *FrameTy = Shape.FrameTy;
-  auto *GepIndex = Builder.CreateStructGEP(
-      FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
-  auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
-  auto *Switch =
-      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
-  Shape.SwitchLowering.ResumeSwitch = Switch;
-
-  size_t SuspendIndex = 0;
-  for (auto *AnyS : Shape.CoroSuspends) {
-    auto *S = cast<CoroSuspendInst>(AnyS);
-    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
-
-    // Replace CoroSave with a store to Index:
-    //    %index.addr = getelementptr %f.frame... (index field number)
-    //    store i32 %IndexVal, i32* %index.addr1
-    auto *Save = S->getCoroSave();
-    Builder.SetInsertPoint(Save);
-    if (S->isFinal()) {
-      // The coroutine should be marked done if it reaches the final suspend
-      // point.
-      markCoroutineAsDone(Builder, Shape, FramePtr);
-    } else {
-      auto *GepIndex = Builder.CreateStructGEP(
-          FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
-      Builder.CreateStore(IndexVal, GepIndex);
-    }
-
-    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
-    Save->eraseFromParent();
-
-    // Split block before and after coro.suspend and add a jump from an entry
-    // switch:
-    //
-    //  whateverBB:
-    //    whatever
-    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
-    //    switch i8 %0, label %suspend[i8 0, label %resume
-    //                                 i8 1, label %cleanup]
-    // becomes:
-    //
-    //  whateverBB:
-    //     whatever
-    //     br label %resume.0.landing
-    //
-    //  resume.0: ; <--- jump from the switch in the resume.entry
-    //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
-    //     br label %resume.0.landing
-    //
-    //  resume.0.landing:
-    //     %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
-    //     switch i8 % 1, label %suspend [i8 0, label %resume
-    //                                    i8 1, label %cleanup]
-
-    auto *SuspendBB = S->getParent();
-    auto *ResumeBB =
-        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
-    auto *LandingBB = ResumeBB->splitBasicBlock(
-        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
-    Switch->addCase(IndexVal, ResumeBB);
-
-    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
-    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "");
-    PN->insertBefore(LandingBB->begin());
-    S->replaceAllUsesWith(PN);
-    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
-    PN->addIncoming(S, ResumeBB);
-
-    ++SuspendIndex;
-  }
-
-  Builder.SetInsertPoint(UnreachBB);
-  Builder.CreateUnreachable();
-
-  Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
-}
-
 // In the resume function, we remove the last case  (when coro::Shape is built,
 // the final suspend point (if present) is always the last element of
 // CoroSuspends array) since it is an undefined behavior to resume a coroutine
@@ -1161,16 +1063,6 @@ void CoroCloner::create() {
                           /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup);
 }
 
-// Create a resume clone by cloning the body of the original function, setting
-// new entry block and replacing coro.suspend an appropriate value to force
-// resume or cleanup pass for every suspend point.
-static Function *createClone(Function &F, const Twine &Suffix,
-                             coro::Shape &Shape, CoroCloner::Kind FKind) {
-  CoroCloner Cloner(F, Suffix, Shape, FKind);
-  Cloner.create();
-  return Cloner.getFunction();
-}
-
 static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) {
   assert(Shape.ABI == coro::ABI::Async);
 
@@ -1212,67 +1104,6 @@ static void replaceFrameSizeAndAlignment(coro::Shape &Shape) {
   }
 }
 
-// Create a global constant array containing pointers to functions provided and
-// set Info parameter of CoroBegin to point at this constant. Example:
-//
-//   @f.resumers = internal constant [2 x void(%f.frame*)*]
-//                    [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
-//   define void @f() {
-//     ...
-//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
-//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
-//
-// Assumes that all the functions have the same signature.
-static void setCoroInfo(Function &F, coro::Shape &Shape,
-                        ArrayRef<Function *> Fns) {
-  // This only works under the switch-lowering ABI because coro elision
-  // only works on the switch-lowering ABI.
-  assert(Shape.ABI == coro::ABI::Switch);
-
-  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
-  assert(!Args.empty());
-  Function *Part = *Fns.begin();
-  Module *M = Part->getParent();
-  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
-
-  auto *ConstVal = ConstantArray::get(ArrTy, Args);
-  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
-                                GlobalVariable::PrivateLinkage, ConstVal,
-                                F.getName() + Twine(".resumers"));
-
-  // Update coro.begin instruction to refer to this constant.
-  LLVMContext &C = F.getContext();
-  auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C));
-  Shape.getSwitchCoroId()->setInfo(BC);
-}
-
-// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
-static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
-                            Function *DestroyFn, Function *CleanupFn) {
-  assert(Shape.ABI == coro::ABI::Switch);
-
-  IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr());
-
-  auto *ResumeAddr = Builder.CreateStructGEP(
-      Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
-      "resume.addr");
-  Builder.CreateStore(ResumeFn, ResumeAddr);
-
-  Value *DestroyOrCleanupFn = DestroyFn;
-
-  CoroIdInst *CoroId = Shape.getSwitchCoroId();
-  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
-    // If there is a CoroAlloc and it returns false (meaning we elide the
-    // allocation, use CleanupFn instead of DestroyFn).
-    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
-  }
-
-  auto *DestroyAddr = Builder.CreateStructGEP(
-      Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
-      "destroy.addr");
-  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
-}
-
 static void postSplitCleanup(Function &F) {
   removeUnreachableBlocks(F);
 
@@ -1447,34 +1278,6 @@ static bool shouldBeMustTail(const CallInst &CI, const Function &F) {
   return true;
 }
 
-// Add musttail to any resume instructions that is immediately followed by a
-// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
-// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
-// This transformation is done only in the resume part of the coroutine that has
-// identical signature and calling convention as the coro.resume call.
-static void addMustTailToCoroResumes(Function &F, TargetTransformInfo &TTI) {
-  bool changed = false;
-
-  // Collect potential resume instructions.
-  SmallVector<CallInst *, 4> Resumes;
-  for (auto &I : instructions(F))
-    if (auto *Call = dyn_cast<CallInst>(&I))
-      if (shouldBeMustTail(*Call, F))
-        Resumes.push_back(Call);
-
-  // Set musttail on those that are followed by a ret instruction.
-  for (CallInst *Call : Resumes)
-    // Skip targets which don't support tail call on the specific case.
-    if (TTI.supportsTailCallFor(Call) &&
-        simplifyTerminatorLeadingToRet(Call->getNextNode())) {
-      Call->setTailCallKind(CallInst::TCK_MustTail);
-      changed = true;
-    }
-
-  if (changed)
-    removeUnreachableBlocks(F);
-}
-
 // Coroutine has no suspend points. Remove heap allocation for the coroutine
 // frame if possible.
 static void handleNoSuspendCoroutine(coro::Shape &Shape) {
@@ -1678,44 +1481,244 @@ static void simplifySuspendPoints(coro::Shape &Shape) {
   }
 }
 
-static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
-                                 SmallVectorImpl<Function *> &Clones,
-                                 TargetTransformInfo &TTI) {
-  assert(Shape.ABI == coro::ABI::Switch);
-
-  createResumeEntryBlock(F, Shape);
-  auto ResumeClone = createClone(F, ".resume", Shape,
-                                 CoroCloner::Kind::SwitchResume);
-  auto DestroyClone = createClone(F, ".destroy", Shape,
-                                  CoroCloner::Kind::SwitchUnwind);
-  auto CleanupClone = createClone(F, ".cleanup", Shape,
-                                  CoroCloner::Kind::SwitchCleanup);
-
-  postSplitCleanup(*ResumeClone);
-  postSplitCleanup(*DestroyClone);
-  postSplitCleanup(*CleanupClone);
-
-  // Adding musttail call to support symmetric transfer.
-  // Skip targets which don't support tail call.
-  //
-  // FIXME: Could we support symmetric transfer effectively without musttail
-  // call?
-  if (TTI.supportsTailCalls())
-    addMustTailToCoroResumes(*ResumeClone, TTI);
+namespace {
 
-  // Store addresses resume/destroy/cleanup functions in the coroutine frame.
-  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
+struct SwitchCoroutineSplitter {
+  static void split(Function &F, coro::Shape &Shape,
+                    SmallVectorImpl<Function *> &Clones,
+                    TargetTransformInfo &TTI) {
+    assert(Shape.ABI == coro::ABI::Switch);
 
-  assert(Clones.empty());
-  Clones.push_back(ResumeClone);
-  Clones.push_back(DestroyClone);
-  Clones.push_back(CleanupClone);
-
-  // Create a constant array referring to resume/destroy/clone functions pointed
-  // by the last argument of @llvm.coro.info, so that CoroElide pass can
-  // determined correct function to call.
-  setCoroInfo(F, Shape, Clones);
-}
+    createResumeEntryBlock(F, Shape);
+    auto *ResumeClone =
+        createClone(F, ".resume", Shape, CoroCloner::Kind::SwitchResume);
+    auto *DestroyClone =
+        createClone(F, ".destroy", Shape, CoroCloner::Kind::SwitchUnwind);
+    auto *CleanupClone =
+        createClone(F, ".cleanup", Shape, CoroCloner::Kind::SwitchCleanup);
+
+    postSplitCleanup(*ResumeClone);
+    postSplitCleanup(*DestroyClone);
+    postSplitCleanup(*CleanupClone);
+
+    // Adding musttail call to support symmetric transfer.
+    // Skip targets which don't support tail call.
+    //
+    // FIXME: Could we support symmetric transfer effectively without musttail
+    // call?
+    if (TTI.supportsTailCalls())
+      addMustTailToCoroResumes(*ResumeClone, TTI);
+
+    // Store resume/destroy/cleanup function addresses in the coroutine frame.
+    updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
+
+    assert(Clones.empty());
+    Clones.push_back(ResumeClone);
+    Clones.push_back(DestroyClone);
+    Clones.push_back(CleanupClone);
+
+    // Create a constant array referring to resume/destroy/cleanup functions
+    // pointed to by the last argument of @llvm.coro.id, so that CoroElide pass
+    // can determine the correct function to call.
+    setCoroInfo(F, Shape, Clones);
+  }
+
+private:
+  // Create a resume clone by cloning the body of the original function, setting
+  // a new entry block and replacing coro.suspend with an appropriate value to
+  // force resume or cleanup pass for every suspend point.
+  static Function *createClone(Function &F, const Twine &Suffix,
+                               coro::Shape &Shape, CoroCloner::Kind FKind) {
+    CoroCloner Cloner(F, Suffix, Shape, FKind);
+    Cloner.create();
+    return Cloner.getFunction();
+  }
+
+  // Create an entry block for a resume function with a switch that will jump to
+  // suspend points.
+  static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
+    LLVMContext &C = F.getContext();
+    // resume.entry:
+    //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr,
+    //                                       i32 0, i32 2
+    //  %index = load i32, i32* %index.addr
+    //  switch i32 %index, label %unreachable [
+    //    i32 0, label %resume.0
+    //    i32 1, label %resume.1
+    //    ...
+    //  ]
+
+    auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
+    auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
+
+    IRBuilder<> Builder(NewEntry);
+    auto *FramePtr = Shape.FramePtr;
+    auto *FrameTy = Shape.FrameTy;
+    auto *GepIndex = Builder.CreateStructGEP(
+        FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
+    auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
+    auto *Switch =
+        Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
+    Shape.SwitchLowering.ResumeSwitch = Switch;
+
+    size_t SuspendIndex = 0;
+    for (auto *AnyS : Shape.CoroSuspends) {
+      auto *S = cast<CoroSuspendInst>(AnyS);
+      ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
+
+      // Replace CoroSave with a store to Index:
+      //    %index.addr = getelementptr %f.frame... (index field number)
+      //    store i32 %IndexVal, i32* %index.addr1
+      auto *Save = S->getCoroSave();
+      Builder.SetInsertPoint(Save);
+      if (S->isFinal()) {
+        // The coroutine should be marked done if it reaches the final suspend
+        // point.
+        markCoroutineAsDone(Builder, Shape, FramePtr);
+      } else {
+        auto *GepIndex = Builder.CreateStructGEP(
+            FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
+        Builder.CreateStore(IndexVal, GepIndex);
+      }
+
+      Save->replaceAllUsesWith(ConstantTokenNone::get(C));
+      Save->eraseFromParent();
+
+      // Split block before and after coro.suspend and add a jump from an entry
+      // switch:
+      //
+      //  whateverBB:
+      //    whatever
+      //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+      //    switch i8 %0, label %suspend[i8 0, label %resume
+      //                                 i8 1, label %cleanup]
+      // becomes:
+      //
+      //  whateverBB:
+      //     whatever
+      //     br label %resume.0.landing
+      //
+      //  resume.0: ; <--- jump from the switch in the resume.entry
+      //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
+      //     br label %resume.0.landing
+      //
+      //  resume.0.landing:
+      //     %1 = phi i8 [-1, %whateverBB], [%0, %resume.0]
+      //     switch i8 %1, label %suspend [i8 0, label %resume
+      //                                   i8 1, label %cleanup]
+
+      auto *SuspendBB = S->getParent();
+      auto *ResumeBB =
+          SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
+      auto *LandingBB = ResumeBB->splitBasicBlock(
+          S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
+      Switch->addCase(IndexVal, ResumeBB);
+
+      cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
+      auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "");
+      PN->insertBefore(LandingBB->begin());
+      S->replaceAllUsesWith(PN);
+      PN->addIncoming(Builder.getInt8(-1), SuspendBB);
+      PN->addIncoming(S, ResumeBB);
+
+      ++SuspendIndex;
+    }
+
+    Builder.SetInsertPoint(UnreachBB);
+    Builder.CreateUnreachable();
+
+    Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
+  }
+
+  // Add musttail to any resume instruction that is immediately followed by a
+  // suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
+  // for symmetrical coroutine control transfer (C++ Coroutines TS extension).
+  // This transformation is done only in the resume part of the coroutine that
+  // has identical signature and calling convention as the coro.resume call.
+  static void addMustTailToCoroResumes(Function &F, TargetTransformInfo &TTI) {
+    bool Changed = false;
+
+    // Collect potential resume instructions.
+    SmallVector<CallInst *, 4> Resumes;
+    for (auto &I : instructions(F))
+      if (auto *Call = dyn_cast<CallInst>(&I))
+        if (shouldBeMustTail(*Call, F))
+          Resumes.push_back(Call);
+
+    // Set musttail on those that are followed by a ret instruction.
+    for (CallInst *Call : Resumes)
+      // Skip targets which don't support tail call on the specific case.
+      if (TTI.supportsTailCallFor(Call) &&
+          simplifyTerminatorLeadingToRet(Call->getNextNode())) {
+        Call->setTailCallKind(CallInst::TCK_MustTail);
+        Changed = true;
+      }
+
+    if (Changed)
+      removeUnreachableBlocks(F);
+  }
+
+  // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
+  static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
+                              Function *DestroyFn, Function *CleanupFn) {
+    IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr());
+
+    auto *ResumeAddr = Builder.CreateStructGEP(
+        Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
+        "resume.addr");
+    Builder.CreateStore(ResumeFn, ResumeAddr);
+
+    Value *DestroyOrCleanupFn = DestroyFn;
+
+    CoroIdInst *CoroId = Shape.getSwitchCoroId();
+    if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
+      // If there is a CoroAlloc and it returns false (meaning we elide the
+      // allocation), use CleanupFn instead of DestroyFn.
+      DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
+    }
+
+    auto *DestroyAddr = Builder.CreateStructGEP(
+        Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
+        "destroy.addr");
+    Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
+  }
+
+  // Create a global constant array containing pointers to functions provided
+  // and set Info parameter of CoroBegin to point at this constant. Example:
+  //
+  //   @f.resumers = internal constant [2 x void(%f.frame*)*]
+  //                    [void(%f.frame*)* @f.resume, void(%f.frame*)*
+  //                    @f.destroy]
+  //   define void @f() {
+  //     ...
+  //     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
+  //                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to
+  //                    i8*))
+  //
+  // Assumes that all the functions have the same signature.
+  static void setCoroInfo(Function &F, coro::Shape &Shape,
+                          ArrayRef<Function *> Fns) {
+    // This only works under the switch-lowering ABI because coro elision
+    // only works on the switch-lowering ABI.
+    SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
+    assert(!Args.empty());
+    Function *Part = *Fns.begin();
+    Module *M = Part->getParent();
+    auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
+
+    auto *ConstVal = ConstantArray::get(ArrTy, Args);
+    auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
+                                  GlobalVariable::PrivateLinkage, ConstVal,
+                                  F.getName() + Twine(".resumers"));
+
+    // Update the coro.id instruction to refer to this constant.
+    LLVMContext &C = F.getContext();
+    auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C));
+    Shape.getSwitchCoroId()->setInfo(BC);
+  }
+};
+
+} // namespace
 
 static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend,
                                        Value *Continuation) {
@@ -2027,7 +2030,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
   } else {
     switch (Shape.ABI) {
     case coro::ABI::Switch:
-      splitSwitchCoroutine(F, Shape, Clones, TTI);
+      SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
       break;
     case coro::ABI::Async:
       splitAsyncCoroutine(F, Shape, Clones);



More information about the llvm-commits mailing list