[clang] [Clang][NFC] Refactor suspend emit logic in coroutine codegen (PR #73564)
Yuxuan Chen via cfe-commits
cfe-commits at lists.llvm.org
Tue Nov 28 19:51:46 PST 2023
https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/73564
>From 5f5ebec41c90366bf3c7ec1ee53154ba7afcb849 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 22 Nov 2023 13:38:54 -0800
Subject: [PATCH] Refactor how we generate RValue vs LValue coawait expressions
---
clang/lib/CodeGen/CGCoroutine.cpp | 332 +++++++++++++++++-------------
1 file changed, 185 insertions(+), 147 deletions(-)
diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index 888d30bfb3e1d6a..472c260f60068bd 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -129,49 +129,7 @@ static SmallString<32> buildSuspendPrefixStr(CGCoroData &Coro, AwaitKind Kind) {
return Prefix;
}
-// Check if function can throw based on prototype noexcept, also works for
-// destructors which are implicitly noexcept but can be marked noexcept(false).
-static bool FunctionCanThrow(const FunctionDecl *D) {
- const auto *Proto = D->getType()->getAs<FunctionProtoType>();
- if (!Proto) {
- // Function proto is not found, we conservatively assume throwing.
- return true;
- }
- return !isNoexceptExceptionSpec(Proto->getExceptionSpecType()) ||
- Proto->canThrow() != CT_Cannot;
-}
-
-static bool ResumeStmtCanThrow(const Stmt *S) {
- if (const auto *CE = dyn_cast<CallExpr>(S)) {
- const auto *Callee = CE->getDirectCallee();
- if (!Callee)
- // We don't have direct callee. Conservatively assume throwing.
- return true;
-
- if (FunctionCanThrow(Callee))
- return true;
-
- // Fall through to visit the children.
- }
-
- if (const auto *TE = dyn_cast<CXXBindTemporaryExpr>(S)) {
- // Special handling of CXXBindTemporaryExpr here as calling of Dtor of the
- // temporary is not part of `children()` as covered in the fall through.
- // We need to mark entire statement as throwing if the destructor of the
- // temporary throws.
- const auto *Dtor = TE->getTemporary()->getDestructor();
- if (FunctionCanThrow(Dtor))
- return true;
-
- // Fall through to visit the children.
- }
-
- for (const auto *child : S->children())
- if (ResumeStmtCanThrow(child))
- return true;
-
- return false;
-}
+namespace { // File-local helpers for emitting co_await/co_yield suspends.
// Emit suspend expression which roughly looks like:
//
@@ -200,117 +158,198 @@ static bool ResumeStmtCanThrow(const Stmt *S) {
//
// See llvm's docs/Coroutines.rst for more details.
//
-namespace {
- struct LValueOrRValue {
- LValue LV;
- RValue RV;
- };
-}
-static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Coro,
- CoroutineSuspendExpr const &S,
- AwaitKind Kind, AggValueSlot aggSlot,
- bool ignoreResult, bool forLValue) {
- auto *E = S.getCommonExpr();
-
- auto Binder =
- CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E);
- auto UnbindOnExit = llvm::make_scope_exit([&] { Binder.unbind(CGF); });
-
- auto Prefix = buildSuspendPrefixStr(Coro, Kind);
- BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready"));
- BasicBlock *SuspendBlock = CGF.createBasicBlock(Prefix + Twine(".suspend"));
- BasicBlock *CleanupBlock = CGF.createBasicBlock(Prefix + Twine(".cleanup"));
-
- // If expression is ready, no need to suspend.
- CGF.EmitBranchOnBoolExpr(S.getReadyExpr(), ReadyBlock, SuspendBlock, 0);
-
- // Otherwise, emit suspend logic.
- CGF.EmitBlock(SuspendBlock);
-
- auto &Builder = CGF.Builder;
- llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
- auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
- auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
-
- CGF.CurCoro.InSuspendBlock = true;
- auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
- CGF.CurCoro.InSuspendBlock = false;
-
- if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
- // Veto suspension if requested by bool returning await_suspend.
- BasicBlock *RealSuspendBlock =
- CGF.createBasicBlock(Prefix + Twine(".suspend.bool"));
- CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
- CGF.EmitBlock(RealSuspendBlock);
+// SuspendExpressionEmitter emits one such expression. The constructor binds
+// its OpaqueValueExpr to the common expression and the destructor unbinds it.
+// Call exactly one of EmitAsLValue() / EmitAsRValue(), once: each call emits
+// the full suspend sequence before the resume expression.
+class SuspendExpressionEmitter final {
+public:
+  SuspendExpressionEmitter(CodeGenFunction &CGF, CGCoroData &Coro,
+                           CoroutineSuspendExpr const &S, AwaitKind Kind)
+      : CGF(CGF), Coro(Coro), SuspendExpr(S), Kind(Kind),
+        SuspendPrefix(buildSuspendPrefixStr(Coro, Kind)),
+        Binder(CodeGenFunction::OpaqueValueMappingData::bind(
+            CGF, S.getOpaqueValue(), S.getCommonExpr())) {}
+  SuspendExpressionEmitter(const SuspendExpressionEmitter &) = delete;
+  SuspendExpressionEmitter(SuspendExpressionEmitter &&) = delete;
+
+  ~SuspendExpressionEmitter() { Binder.unbind(CGF); }
- // Emit the suspend point.
- const bool IsFinalSuspend = (Kind == AwaitKind::Final);
- llvm::Function *CoroSuspend =
- CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_suspend);
- auto *SuspendResult = Builder.CreateCall(
- CoroSuspend, {SaveCall, Builder.getInt1(IsFinalSuspend)});
-
- // Create a switch capturing three possible continuations.
- auto *Switch = Builder.CreateSwitch(SuspendResult, Coro.SuspendBB, 2);
- Switch->addCase(Builder.getInt8(0), ReadyBlock);
- Switch->addCase(Builder.getInt8(1), CleanupBlock);
-
- // Emit cleanup for this suspend point.
- CGF.EmitBlock(CleanupBlock);
- CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
-
- // Emit await_resume expression.
- CGF.EmitBlock(ReadyBlock);
-
- // Exception handling requires additional IR. If the 'await_resume' function
- // is marked as 'noexcept', we avoid generating this additional IR.
- CXXTryStmt *TryStmt = nullptr;
- if (Coro.ExceptionHandler && Kind == AwaitKind::Init &&
- ResumeStmtCanThrow(S.getResumeExpr())) {
- Coro.ResumeEHVar =
- CGF.CreateTempAlloca(Builder.getInt1Ty(), Prefix + Twine("resume.eh"));
- Builder.CreateFlagStore(true, Coro.ResumeEHVar);
-
- auto Loc = S.getResumeExpr()->getExprLoc();
- auto *Catch = new (CGF.getContext())
- CXXCatchStmt(Loc, /*exDecl=*/nullptr, Coro.ExceptionHandler);
- auto *TryBody = CompoundStmt::Create(CGF.getContext(), S.getResumeExpr(),
- FPOptionsOverride(), Loc, Loc);
- TryStmt = CXXTryStmt::Create(CGF.getContext(), Loc, TryBody, Catch);
- CGF.EnterCXXTryStmt(*TryStmt);
- CGF.EmitStmt(TryBody);
- // We don't use EmitCXXTryStmt here. We need to store to ResumeEHVar that
- // doesn't exist in the body.
- Builder.CreateFlagStore(false, Coro.ResumeEHVar);
- CGF.ExitCXXTryStmt(*TryStmt);
- LValueOrRValue Res;
- // We are not supposed to obtain the value from init suspend await_resume().
- Res.RV = RValue::getIgnored();
- return Res;
+
+  // Emits the suspend sequence, then the resume expression as an lvalue.
+  // Kind may be Init here (EmitCoawaitLValue forwards CurrentAwaitKind); if
+  // emitGuardedInitResume() consumed the resume, its value must not be used
+  // and an empty LValue is returned.
+  LValue EmitAsLValue() {
+    emitCommonBlocks();
+    if (emitGuardedInitResume())
+      return LValue();
+    return CGF.EmitLValue(SuspendExpr.getResumeExpr());
+  }
- LValueOrRValue Res;
- if (forLValue)
- Res.LV = CGF.EmitLValue(S.getResumeExpr());
- else
- Res.RV = CGF.EmitAnyExpr(S.getResumeExpr(), aggSlot, ignoreResult);
+
+  // Emits the suspend sequence, then the resume expression via EmitAnyExpr.
+  // A resume consumed by emitGuardedInitResume() yields RValue::getIgnored(),
+  // as its value is not supposed to be obtained.
+  RValue EmitAsRValue(AggValueSlot AggSlot, bool IgnoreResult) {
+    emitCommonBlocks();
+    if (emitGuardedInitResume())
+      return RValue::getIgnored();
+    return CGF.EmitAnyExpr(SuspendExpr.getResumeExpr(), AggSlot, IgnoreResult);
+  }
- return Res;
-}
+
+private:
+  CodeGenFunction &CGF;
+  CGCoroData &Coro;
+  CoroutineSuspendExpr const &SuspendExpr;
+  AwaitKind Kind;
+  SmallString<32> SuspendPrefix;
+  // Bound in the constructor, unbound in the destructor.
+  CodeGenFunction::OpaqueValueMappingData Binder;
+
+  // Emits the blocks shared by every use of a suspend expression and leaves
+  // the insertion point in the ready block. The caller then emits the resume
+  // expression as an lvalue or an rvalue; for the initial suspend it may first
+  // be wrapped in `try { ... } catch { ... }` by emitGuardedInitResume().
+  void emitCommonBlocks() {
+    BasicBlock *ReadyBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".ready"));
+    BasicBlock *SuspendBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".suspend"));
+    BasicBlock *CleanupBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".cleanup"));
+
+    // If expression is ready, no need to suspend.
+    CGF.EmitBranchOnBoolExpr(SuspendExpr.getReadyExpr(), ReadyBlock,
+                             SuspendBlock, 0);
+
+    // Otherwise, emit suspend logic.
+    CGF.EmitBlock(SuspendBlock);
+
+    auto &Builder = CGF.Builder;
+    llvm::Function *CoroSave =
+        CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
+    auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
+    auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
+
+    CGF.CurCoro.InSuspendBlock = true;
+    auto *SuspendRet = CGF.EmitScalarExpr(SuspendExpr.getSuspendExpr());
+    CGF.CurCoro.InSuspendBlock = false;
+
+    if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
+      // Veto suspension if requested by bool returning await_suspend.
+      BasicBlock *RealSuspendBlock =
+          CGF.createBasicBlock(SuspendPrefix + Twine(".suspend.bool"));
+      CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
+      CGF.EmitBlock(RealSuspendBlock);
+    }
+
+    // Emit the suspend point.
+    const bool IsFinalSuspend = (Kind == AwaitKind::Final);
+    llvm::Function *CoroSuspend =
+        CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_suspend);
+    auto *SuspendResult = Builder.CreateCall(
+        CoroSuspend, {SaveCall, Builder.getInt1(IsFinalSuspend)});
+
+    // Create a switch capturing three possible continuations.
+    auto *Switch = Builder.CreateSwitch(SuspendResult, Coro.SuspendBB, 2);
+    Switch->addCase(Builder.getInt8(0), ReadyBlock);
+    Switch->addCase(Builder.getInt8(1), CleanupBlock);
+
+    // Emit cleanup for this suspend point.
+    CGF.EmitBlock(CleanupBlock);
+    CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
+
+    // Emit await_resume expression.
+    CGF.EmitBlock(ReadyBlock);
+  }
+
+  // If this is the initial suspend, the coroutine has an exception handler and
+  // await_resume() may throw, emits `ResumeEHVar = true; try { await_resume();
+  // ResumeEHVar = false; } catch (...) { handler }` and returns true.
+  // Otherwise emits nothing and returns false.
+  bool emitGuardedInitResume() {
+    if (!Coro.ExceptionHandler || Kind != AwaitKind::Init ||
+        !resumeExprCanThrow())
+      return false;
+
+    auto &Builder = CGF.Builder;
+    auto *ResumeExpr = SuspendExpr.getResumeExpr();
+    Coro.ResumeEHVar = CGF.CreateTempAlloca(
+        Builder.getInt1Ty(), SuspendPrefix + Twine("resume.eh"));
+    Builder.CreateFlagStore(true, Coro.ResumeEHVar);
+
+    auto Loc = ResumeExpr->getExprLoc();
+    auto *Catch = new (CGF.getContext())
+        CXXCatchStmt(Loc, /*exDecl=*/nullptr, Coro.ExceptionHandler);
+    auto *TryBody = CompoundStmt::Create(CGF.getContext(), ResumeExpr,
+                                         FPOptionsOverride(), Loc, Loc);
+    auto *TryStmt = CXXTryStmt::Create(CGF.getContext(), Loc, TryBody, Catch);
+    CGF.EnterCXXTryStmt(*TryStmt);
+    CGF.EmitStmt(TryBody);
+    // We don't use EmitCXXTryStmt here. We need to store to ResumeEHVar that
+    // doesn't exist in the body.
+    Builder.CreateFlagStore(false, Coro.ResumeEHVar);
+    CGF.ExitCXXTryStmt(*TryStmt);
+    return true;
+  }
+
+  // Check if function can throw based on prototype noexcept, also works for
+  // destructors which are implicitly noexcept but can be marked
+  // noexcept(false).
+  static bool functionCanThrow(const FunctionDecl *D) {
+    const auto *Proto = D->getType()->getAs<FunctionProtoType>();
+    if (!Proto) {
+      // Function proto is not found, we conservatively assume throwing.
+      return true;
+    }
+    return !isNoexceptExceptionSpec(Proto->getExceptionSpecType()) ||
+           Proto->canThrow() != CT_Cannot;
+  }
+
+  // True if S, or any descendant, is a call with no direct callee, a call to a
+  // potentially-throwing function, or a bound temporary with a throwing dtor.
+  static bool resumeStmtCanThrow(const Stmt *S) {
+    if (const auto *CE = dyn_cast<CallExpr>(S)) {
+      // Without a direct callee, conservatively assume throwing.
+      const auto *Callee = CE->getDirectCallee();
+      if (!Callee || functionCanThrow(Callee))
+        return true;
+    }
+
+    // The destructor call of a bound temporary is not part of `children()`,
+    // so the whole statement counts as throwing if that destructor may throw.
+    if (const auto *TE = dyn_cast<CXXBindTemporaryExpr>(S))
+      if (functionCanThrow(TE->getTemporary()->getDestructor()))
+        return true;
+
+    // Otherwise visit the children.
+    for (const auto *Child : S->children())
+      if (resumeStmtCanThrow(Child))
+        return true;
+    return false;
+  }
+
+  // Returns true if SuspendExpr's resume expression may throw.
+  bool resumeExprCanThrow() const {
+    return resumeStmtCanThrow(SuspendExpr.getResumeExpr());
+  }
+};
+} // namespace
RValue CodeGenFunction::EmitCoawaitExpr(const CoawaitExpr &E,
AggValueSlot aggSlot,
bool ignoreResult) {
- return emitSuspendExpression(*this, *CurCoro.Data, E,
- CurCoro.Data->CurrentAwaitKind, aggSlot,
- ignoreResult, /*forLValue*/false).RV;
+  SuspendExpressionEmitter Emitter(*this, *CurCoro.Data, E,
+                                   CurCoro.Data->CurrentAwaitKind);
+  return Emitter.EmitAsRValue(aggSlot, ignoreResult);
}
RValue CodeGenFunction::EmitCoyieldExpr(const CoyieldExpr &E,
AggValueSlot aggSlot,
bool ignoreResult) {
- return emitSuspendExpression(*this, *CurCoro.Data, E, AwaitKind::Yield,
- aggSlot, ignoreResult, /*forLValue*/false).RV;
+  SuspendExpressionEmitter Emitter(*this, *CurCoro.Data, E, AwaitKind::Yield);
+  return Emitter.EmitAsRValue(aggSlot, ignoreResult);
}
void CodeGenFunction::EmitCoreturnStmt(CoreturnStmt const &S) {
@@ -343,9 +382,9 @@ CodeGenFunction::EmitCoawaitLValue(const CoawaitExpr *E) {
assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
"Can't have a scalar return unless the return type is a "
"reference type!");
- return emitSuspendExpression(*this, *CurCoro.Data, *E,
- CurCoro.Data->CurrentAwaitKind, AggValueSlot::ignored(),
- /*ignoreResult*/false, /*forLValue*/true).LV;
+  SuspendExpressionEmitter Emitter(*this, *CurCoro.Data, *E,
+                                   CurCoro.Data->CurrentAwaitKind);
+  return Emitter.EmitAsLValue();
}
LValue
@@ -353,9 +392,8 @@ CodeGenFunction::EmitCoyieldLValue(const CoyieldExpr *E) {
assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
"Can't have a scalar return unless the return type is a "
"reference type!");
- return emitSuspendExpression(*this, *CurCoro.Data, *E,
- AwaitKind::Yield, AggValueSlot::ignored(),
- /*ignoreResult*/false, /*forLValue*/true).LV;
+  SuspendExpressionEmitter Emitter(*this, *CurCoro.Data, *E, AwaitKind::Yield);
+  return Emitter.EmitAsLValue();
}
// Hunts for the parameter reference in the parameter copy/move declaration.
More information about the cfe-commits
mailing list