[llvm] [Coroutines][NFC] Refactor CoroSplit for Switch Resume ABI (PR #80758)
Yuxuan Chen via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 5 15:04:27 PST 2024
https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/80758
>From f9b13c03b21bb4d96d23c3262f990a698c84483e Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <yuxuanchen1997 at outlook.com>
Date: Mon, 5 Feb 2024 14:28:11 -0800
Subject: [PATCH] put switch resume splitting in one struct
---
llvm/lib/Transforms/Coroutines/CoroSplit.cpp | 471 ++++++++++---------
1 file changed, 237 insertions(+), 234 deletions(-)
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 7758b52abc204..a552b3ac20085 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -407,104 +407,6 @@ static void replaceCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape,
End->eraseFromParent();
}
-// Create an entry block for a resume function with a switch that will jump to
-// suspend points.
-static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
- assert(Shape.ABI == coro::ABI::Switch);
- LLVMContext &C = F.getContext();
-
- // resume.entry:
- // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
- // i32 2
- // % index = load i32, i32* %index.addr
- // switch i32 %index, label %unreachable [
- // i32 0, label %resume.0
- // i32 1, label %resume.1
- // ...
- // ]
-
- auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
- auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
-
- IRBuilder<> Builder(NewEntry);
- auto *FramePtr = Shape.FramePtr;
- auto *FrameTy = Shape.FrameTy;
- auto *GepIndex = Builder.CreateStructGEP(
- FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
- auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
- auto *Switch =
- Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
- Shape.SwitchLowering.ResumeSwitch = Switch;
-
- size_t SuspendIndex = 0;
- for (auto *AnyS : Shape.CoroSuspends) {
- auto *S = cast<CoroSuspendInst>(AnyS);
- ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
-
- // Replace CoroSave with a store to Index:
- // %index.addr = getelementptr %f.frame... (index field number)
- // store i32 %IndexVal, i32* %index.addr1
- auto *Save = S->getCoroSave();
- Builder.SetInsertPoint(Save);
- if (S->isFinal()) {
- // The coroutine should be marked done if it reaches the final suspend
- // point.
- markCoroutineAsDone(Builder, Shape, FramePtr);
- } else {
- auto *GepIndex = Builder.CreateStructGEP(
- FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
- Builder.CreateStore(IndexVal, GepIndex);
- }
-
- Save->replaceAllUsesWith(ConstantTokenNone::get(C));
- Save->eraseFromParent();
-
- // Split block before and after coro.suspend and add a jump from an entry
- // switch:
- //
- // whateverBB:
- // whatever
- // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
- // switch i8 %0, label %suspend[i8 0, label %resume
- // i8 1, label %cleanup]
- // becomes:
- //
- // whateverBB:
- // whatever
- // br label %resume.0.landing
- //
- // resume.0: ; <--- jump from the switch in the resume.entry
- // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
- // br label %resume.0.landing
- //
- // resume.0.landing:
- // %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
- // switch i8 % 1, label %suspend [i8 0, label %resume
- // i8 1, label %cleanup]
-
- auto *SuspendBB = S->getParent();
- auto *ResumeBB =
- SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
- auto *LandingBB = ResumeBB->splitBasicBlock(
- S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
- Switch->addCase(IndexVal, ResumeBB);
-
- cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
- auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "");
- PN->insertBefore(LandingBB->begin());
- S->replaceAllUsesWith(PN);
- PN->addIncoming(Builder.getInt8(-1), SuspendBB);
- PN->addIncoming(S, ResumeBB);
-
- ++SuspendIndex;
- }
-
- Builder.SetInsertPoint(UnreachBB);
- Builder.CreateUnreachable();
-
- Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
-}
-
// In the resume function, we remove the last case (when coro::Shape is built,
// the final suspend point (if present) is always the last element of
// CoroSuspends array) since it is an undefined behavior to resume a coroutine
@@ -1161,16 +1063,6 @@ void CoroCloner::create() {
/*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup);
}
-// Create a resume clone by cloning the body of the original function, setting
-// new entry block and replacing coro.suspend an appropriate value to force
-// resume or cleanup pass for every suspend point.
-static Function *createClone(Function &F, const Twine &Suffix,
- coro::Shape &Shape, CoroCloner::Kind FKind) {
- CoroCloner Cloner(F, Suffix, Shape, FKind);
- Cloner.create();
- return Cloner.getFunction();
-}
-
static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) {
assert(Shape.ABI == coro::ABI::Async);
@@ -1212,67 +1104,6 @@ static void replaceFrameSizeAndAlignment(coro::Shape &Shape) {
}
}
-// Create a global constant array containing pointers to functions provided and
-// set Info parameter of CoroBegin to point at this constant. Example:
-//
-// @f.resumers = internal constant [2 x void(%f.frame*)*]
-// [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
-// define void @f() {
-// ...
-// call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
-// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
-//
-// Assumes that all the functions have the same signature.
-static void setCoroInfo(Function &F, coro::Shape &Shape,
- ArrayRef<Function *> Fns) {
- // This only works under the switch-lowering ABI because coro elision
- // only works on the switch-lowering ABI.
- assert(Shape.ABI == coro::ABI::Switch);
-
- SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
- assert(!Args.empty());
- Function *Part = *Fns.begin();
- Module *M = Part->getParent();
- auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
-
- auto *ConstVal = ConstantArray::get(ArrTy, Args);
- auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
- GlobalVariable::PrivateLinkage, ConstVal,
- F.getName() + Twine(".resumers"));
-
- // Update coro.begin instruction to refer to this constant.
- LLVMContext &C = F.getContext();
- auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C));
- Shape.getSwitchCoroId()->setInfo(BC);
-}
-
-// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
-static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
- Function *DestroyFn, Function *CleanupFn) {
- assert(Shape.ABI == coro::ABI::Switch);
-
- IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr());
-
- auto *ResumeAddr = Builder.CreateStructGEP(
- Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
- "resume.addr");
- Builder.CreateStore(ResumeFn, ResumeAddr);
-
- Value *DestroyOrCleanupFn = DestroyFn;
-
- CoroIdInst *CoroId = Shape.getSwitchCoroId();
- if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
- // If there is a CoroAlloc and it returns false (meaning we elide the
- // allocation, use CleanupFn instead of DestroyFn).
- DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
- }
-
- auto *DestroyAddr = Builder.CreateStructGEP(
- Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
- "destroy.addr");
- Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
-}
-
static void postSplitCleanup(Function &F) {
removeUnreachableBlocks(F);
@@ -1447,34 +1278,6 @@ static bool shouldBeMustTail(const CallInst &CI, const Function &F) {
return true;
}
-// Add musttail to any resume instructions that is immediately followed by a
-// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
-// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
-// This transformation is done only in the resume part of the coroutine that has
-// identical signature and calling convention as the coro.resume call.
-static void addMustTailToCoroResumes(Function &F, TargetTransformInfo &TTI) {
- bool changed = false;
-
- // Collect potential resume instructions.
- SmallVector<CallInst *, 4> Resumes;
- for (auto &I : instructions(F))
- if (auto *Call = dyn_cast<CallInst>(&I))
- if (shouldBeMustTail(*Call, F))
- Resumes.push_back(Call);
-
- // Set musttail on those that are followed by a ret instruction.
- for (CallInst *Call : Resumes)
- // Skip targets which don't support tail call on the specific case.
- if (TTI.supportsTailCallFor(Call) &&
- simplifyTerminatorLeadingToRet(Call->getNextNode())) {
- Call->setTailCallKind(CallInst::TCK_MustTail);
- changed = true;
- }
-
- if (changed)
- removeUnreachableBlocks(F);
-}
-
// Coroutine has no suspend points. Remove heap allocation for the coroutine
// frame if possible.
static void handleNoSuspendCoroutine(coro::Shape &Shape) {
@@ -1678,44 +1481,244 @@ static void simplifySuspendPoints(coro::Shape &Shape) {
}
}
-static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
- SmallVectorImpl<Function *> &Clones,
- TargetTransformInfo &TTI) {
- assert(Shape.ABI == coro::ABI::Switch);
-
- createResumeEntryBlock(F, Shape);
- auto ResumeClone = createClone(F, ".resume", Shape,
- CoroCloner::Kind::SwitchResume);
- auto DestroyClone = createClone(F, ".destroy", Shape,
- CoroCloner::Kind::SwitchUnwind);
- auto CleanupClone = createClone(F, ".cleanup", Shape,
- CoroCloner::Kind::SwitchCleanup);
-
- postSplitCleanup(*ResumeClone);
- postSplitCleanup(*DestroyClone);
- postSplitCleanup(*CleanupClone);
-
- // Adding musttail call to support symmetric transfer.
- // Skip targets which don't support tail call.
- //
- // FIXME: Could we support symmetric transfer effectively without musttail
- // call?
- if (TTI.supportsTailCalls())
- addMustTailToCoroResumes(*ResumeClone, TTI);
+namespace {
- // Store addresses resume/destroy/cleanup functions in the coroutine frame.
- updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
+struct SwitchCoroutineSplitter {
+ static void split(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI) {
+ assert(Shape.ABI == coro::ABI::Switch);
- assert(Clones.empty());
- Clones.push_back(ResumeClone);
- Clones.push_back(DestroyClone);
- Clones.push_back(CleanupClone);
-
- // Create a constant array referring to resume/destroy/clone functions pointed
- // by the last argument of @llvm.coro.info, so that CoroElide pass can
- // determined correct function to call.
- setCoroInfo(F, Shape, Clones);
-}
+ createResumeEntryBlock(F, Shape);
+ auto *ResumeClone =
+ createClone(F, ".resume", Shape, CoroCloner::Kind::SwitchResume);
+ auto *DestroyClone =
+ createClone(F, ".destroy", Shape, CoroCloner::Kind::SwitchUnwind);
+ auto *CleanupClone =
+ createClone(F, ".cleanup", Shape, CoroCloner::Kind::SwitchCleanup);
+
+ postSplitCleanup(*ResumeClone);
+ postSplitCleanup(*DestroyClone);
+ postSplitCleanup(*CleanupClone);
+
+ // Adding musttail call to support symmetric transfer.
+ // Skip targets which don't support tail call.
+ //
+ // FIXME: Could we support symmetric transfer effectively without musttail
+ // call?
+ if (TTI.supportsTailCalls())
+ addMustTailToCoroResumes(*ResumeClone, TTI);
+
+ // Store addresses resume/destroy/cleanup functions in the coroutine frame.
+ updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
+
+ assert(Clones.empty());
+ Clones.push_back(ResumeClone);
+ Clones.push_back(DestroyClone);
+ Clones.push_back(CleanupClone);
+
+ // Create a constant array referring to the resume/destroy/cleanup functions,
+ // pointed to by the Info argument of @llvm.coro.id, so that the CoroElide
+ // pass can determine the correct function to call.
+ setCoroInfo(F, Shape, Clones);
+ }
+
+private:
+ // Create a resume clone by cloning the body of the original function, setting
+ // a new entry block and replacing coro.suspend with an appropriate value to
+ // force the resume or cleanup path for every suspend point.
+ static Function *createClone(Function &F, const Twine &Suffix,
+ coro::Shape &Shape, CoroCloner::Kind FKind) {
+ CoroCloner Cloner(F, Suffix, Shape, FKind);
+ Cloner.create();
+ return Cloner.getFunction();
+ }
+
+ // Create an entry block for a resume function with a switch that will jump to
+ // suspend points.
+ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
+ LLVMContext &C = F.getContext();
+
+ // resume.entry:
+ // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr,
+ // i32 0, i32 2
+ // %index = load i32, i32* %index.addr
+ // switch i32 %index, label %unreachable [
+ // i32 0, label %resume.0
+ // i32 1, label %resume.1
+ // ...
+ // ]
+
+ auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
+ auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
+
+ IRBuilder<> Builder(NewEntry);
+ auto *FramePtr = Shape.FramePtr;
+ auto *FrameTy = Shape.FrameTy;
+ auto *GepIndex = Builder.CreateStructGEP(
+ FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
+ auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
+ auto *Switch =
+ Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
+ Shape.SwitchLowering.ResumeSwitch = Switch;
+
+ size_t SuspendIndex = 0;
+ for (auto *AnyS : Shape.CoroSuspends) {
+ auto *S = cast<CoroSuspendInst>(AnyS);
+ ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
+
+ // Replace CoroSave with a store to Index:
+ // %index.addr = getelementptr %f.frame... (index field number)
+ // store i32 %IndexVal, i32* %index.addr1
+ auto *Save = S->getCoroSave();
+ Builder.SetInsertPoint(Save);
+ if (S->isFinal()) {
+ // The coroutine should be marked done if it reaches the final suspend
+ // point.
+ markCoroutineAsDone(Builder, Shape, FramePtr);
+ } else {
+ auto *GepIndex = Builder.CreateStructGEP(
+ FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
+ Builder.CreateStore(IndexVal, GepIndex);
+ }
+
+ Save->replaceAllUsesWith(ConstantTokenNone::get(C));
+ Save->eraseFromParent();
+
+ // Split block before and after coro.suspend and add a jump from an entry
+ // switch:
+ //
+ // whateverBB:
+ // whatever
+ // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+ // switch i8 %0, label %suspend[i8 0, label %resume
+ // i8 1, label %cleanup]
+ // becomes:
+ //
+ // whateverBB:
+ // whatever
+ // br label %resume.0.landing
+ //
+ // resume.0: ; <--- jump from the switch in the resume.entry
+ // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
+ // br label %resume.0.landing
+ //
+ // resume.0.landing:
+ // %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
+ // switch i8 %1, label %suspend [i8 0, label %resume
+ // i8 1, label %cleanup]
+
+ auto *SuspendBB = S->getParent();
+ auto *ResumeBB =
+ SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
+ auto *LandingBB = ResumeBB->splitBasicBlock(
+ S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
+ Switch->addCase(IndexVal, ResumeBB);
+
+ cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
+ auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "");
+ PN->insertBefore(LandingBB->begin());
+ S->replaceAllUsesWith(PN);
+ PN->addIncoming(Builder.getInt8(-1), SuspendBB);
+ PN->addIncoming(S, ResumeBB);
+
+ ++SuspendIndex;
+ }
+
+ Builder.SetInsertPoint(UnreachBB);
+ Builder.CreateUnreachable();
+
+ Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
+ }
+
+ // Add musttail to any resume instruction that is immediately followed by a
+ // suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
+ // for symmetrical coroutine control transfer (C++ Coroutines TS extension).
+ // This transformation is done only in the resume part of the coroutine that
+ // has identical signature and calling convention as the coro.resume call.
+ static void addMustTailToCoroResumes(Function &F, TargetTransformInfo &TTI) {
+ bool Changed = false;
+
+ // Collect potential resume instructions.
+ SmallVector<CallInst *, 4> Resumes;
+ for (auto &I : instructions(F))
+ if (auto *Call = dyn_cast<CallInst>(&I))
+ if (shouldBeMustTail(*Call, F))
+ Resumes.push_back(Call);
+
+ // Set musttail on those that are followed by a ret instruction.
+ for (CallInst *Call : Resumes)
+ // Skip targets which don't support tail call on the specific case.
+ if (TTI.supportsTailCallFor(Call) &&
+ simplifyTerminatorLeadingToRet(Call->getNextNode())) {
+ Call->setTailCallKind(CallInst::TCK_MustTail);
+ Changed = true;
+ }
+
+ if (Changed)
+ removeUnreachableBlocks(F);
+ }
+
+ // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
+ static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
+ Function *DestroyFn, Function *CleanupFn) {
+ IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr());
+
+ auto *ResumeAddr = Builder.CreateStructGEP(
+ Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
+ "resume.addr");
+ Builder.CreateStore(ResumeFn, ResumeAddr);
+
+ Value *DestroyOrCleanupFn = DestroyFn;
+
+ CoroIdInst *CoroId = Shape.getSwitchCoroId();
+ if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
+ // If there is a CoroAlloc and it returns false (meaning we elide the
+ // allocation), use CleanupFn instead of DestroyFn.
+ DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
+ }
+
+ auto *DestroyAddr = Builder.CreateStructGEP(
+ Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
+ "destroy.addr");
+ Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
+ }
+
+ // Create a global constant array containing pointers to functions provided
+ // and set Info parameter of CoroBegin to point at this constant. Example:
+ //
+ // @f.resumers = internal constant [2 x void(%f.frame*)*]
+ // [void(%f.frame*)* @f.resume, void(%f.frame*)*
+ // @f.destroy]
+ // define void @f() {
+ // ...
+ // call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
+ // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to
+ // i8*))
+ //
+ // Assumes that all the functions have the same signature.
+ static void setCoroInfo(Function &F, coro::Shape &Shape,
+ ArrayRef<Function *> Fns) {
+ // This only works under the switch-lowering ABI because coro elision
+ // only works on the switch-lowering ABI.
+ SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
+ assert(!Args.empty());
+ Function *Part = *Fns.begin();
+ Module *M = Part->getParent();
+ auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
+
+ auto *ConstVal = ConstantArray::get(ArrTy, Args);
+ auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
+ GlobalVariable::PrivateLinkage, ConstVal,
+ F.getName() + Twine(".resumers"));
+
+ // Update coro.begin instruction to refer to this constant.
+ LLVMContext &C = F.getContext();
+ auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C));
+ Shape.getSwitchCoroId()->setInfo(BC);
+ }
+};
+
+} // namespace
static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend,
Value *Continuation) {
@@ -2027,7 +2030,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
} else {
switch (Shape.ABI) {
case coro::ABI::Switch:
- splitSwitchCoroutine(F, Shape, Clones, TTI);
+ SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
splitAsyncCoroutine(F, Shape, Clones);
More information about the llvm-commits
mailing list