[clang] [Clang][NFC] Refactor suspend emit logic in coroutine codegen (PR #73564)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 28 19:51:46 PST 2023


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/73564

>From 5f5ebec41c90366bf3c7ec1ee53154ba7afcb849 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 22 Nov 2023 13:38:54 -0800
Subject: [PATCH] Refactor how we generate RValue vs LValue coawait expressions

---
 clang/lib/CodeGen/CGCoroutine.cpp | 332 +++++++++++++++++-------------
 1 file changed, 185 insertions(+), 147 deletions(-)

diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index 888d30bfb3e1d6a..472c260f60068bd 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -129,49 +129,7 @@ static SmallString<32> buildSuspendPrefixStr(CGCoroData &Coro, AwaitKind Kind) {
   return Prefix;
 }
 
-// Check if function can throw based on prototype noexcept, also works for
-// destructors which are implicitly noexcept but can be marked noexcept(false).
-static bool FunctionCanThrow(const FunctionDecl *D) {
-  const auto *Proto = D->getType()->getAs<FunctionProtoType>();
-  if (!Proto) {
-    // Function proto is not found, we conservatively assume throwing.
-    return true;
-  }
-  return !isNoexceptExceptionSpec(Proto->getExceptionSpecType()) ||
-         Proto->canThrow() != CT_Cannot;
-}
-
-static bool ResumeStmtCanThrow(const Stmt *S) {
-  if (const auto *CE = dyn_cast<CallExpr>(S)) {
-    const auto *Callee = CE->getDirectCallee();
-    if (!Callee)
-      // We don't have direct callee. Conservatively assume throwing.
-      return true;
-
-    if (FunctionCanThrow(Callee))
-      return true;
-
-    // Fall through to visit the children.
-  }
-
-  if (const auto *TE = dyn_cast<CXXBindTemporaryExpr>(S)) {
-    // Special handling of CXXBindTemporaryExpr here as calling of Dtor of the
-    // temporary is not part of `children()` as covered in the fall through.
-    // We need to mark entire statement as throwing if the destructor of the
-    // temporary throws.
-    const auto *Dtor = TE->getTemporary()->getDestructor();
-    if (FunctionCanThrow(Dtor))
-      return true;
-
-    // Fall through to visit the children.
-  }
-
-  for (const auto *child : S->children())
-    if (ResumeStmtCanThrow(child))
-      return true;
-
-  return false;
-}
+namespace {
 
 // Emit suspend expression which roughly looks like:
 //
@@ -200,117 +158,198 @@ static bool ResumeStmtCanThrow(const Stmt *S) {
 //
 //  See llvm's docs/Coroutines.rst for more details.
 //
-namespace {
-  struct LValueOrRValue {
-    LValue LV;
-    RValue RV;
-  };
-}
-static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Coro,
-                                    CoroutineSuspendExpr const &S,
-                                    AwaitKind Kind, AggValueSlot aggSlot,
-                                    bool ignoreResult, bool forLValue) {
-  auto *E = S.getCommonExpr();
-
-  auto Binder =
-      CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E);
-  auto UnbindOnExit = llvm::make_scope_exit([&] { Binder.unbind(CGF); });
-
-  auto Prefix = buildSuspendPrefixStr(Coro, Kind);
-  BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready"));
-  BasicBlock *SuspendBlock = CGF.createBasicBlock(Prefix + Twine(".suspend"));
-  BasicBlock *CleanupBlock = CGF.createBasicBlock(Prefix + Twine(".cleanup"));
-
-  // If expression is ready, no need to suspend.
-  CGF.EmitBranchOnBoolExpr(S.getReadyExpr(), ReadyBlock, SuspendBlock, 0);
-
-  // Otherwise, emit suspend logic.
-  CGF.EmitBlock(SuspendBlock);
-
-  auto &Builder = CGF.Builder;
-  llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
-  auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
-  auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
-
-  CGF.CurCoro.InSuspendBlock = true;
-  auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
-  CGF.CurCoro.InSuspendBlock = false;
-
-  if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
-    // Veto suspension if requested by bool returning await_suspend.
-    BasicBlock *RealSuspendBlock =
-        CGF.createBasicBlock(Prefix + Twine(".suspend.bool"));
-    CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
-    CGF.EmitBlock(RealSuspendBlock);
+class SuspendExpressionEmitter final {
+public:
+  SuspendExpressionEmitter(CodeGenFunction &CGF, CGCoroData &Coro,
+                           CoroutineSuspendExpr const &S, AwaitKind Kind)
+      : CGF(CGF), Coro(Coro), SuspendExpr(S), Kind(Kind),
+        SuspendPrefix(buildSuspendPrefixStr(Coro, Kind)) {
+    CommonExpr = SuspendExpr.getCommonExpr();
+    Binder = CodeGenFunction::OpaqueValueMappingData::bind(
+        CGF, SuspendExpr.getOpaqueValue(), CommonExpr);
   }
 
-  // Emit the suspend point.
-  const bool IsFinalSuspend = (Kind == AwaitKind::Final);
-  llvm::Function *CoroSuspend =
-      CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_suspend);
-  auto *SuspendResult = Builder.CreateCall(
-      CoroSuspend, {SaveCall, Builder.getInt1(IsFinalSuspend)});
-
-  // Create a switch capturing three possible continuations.
-  auto *Switch = Builder.CreateSwitch(SuspendResult, Coro.SuspendBB, 2);
-  Switch->addCase(Builder.getInt8(0), ReadyBlock);
-  Switch->addCase(Builder.getInt8(1), CleanupBlock);
-
-  // Emit cleanup for this suspend point.
-  CGF.EmitBlock(CleanupBlock);
-  CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
-
-  // Emit await_resume expression.
-  CGF.EmitBlock(ReadyBlock);
-
-  // Exception handling requires additional IR. If the 'await_resume' function
-  // is marked as 'noexcept', we avoid generating this additional IR.
-  CXXTryStmt *TryStmt = nullptr;
-  if (Coro.ExceptionHandler && Kind == AwaitKind::Init &&
-      ResumeStmtCanThrow(S.getResumeExpr())) {
-    Coro.ResumeEHVar =
-        CGF.CreateTempAlloca(Builder.getInt1Ty(), Prefix + Twine("resume.eh"));
-    Builder.CreateFlagStore(true, Coro.ResumeEHVar);
-
-    auto Loc = S.getResumeExpr()->getExprLoc();
-    auto *Catch = new (CGF.getContext())
-        CXXCatchStmt(Loc, /*exDecl=*/nullptr, Coro.ExceptionHandler);
-    auto *TryBody = CompoundStmt::Create(CGF.getContext(), S.getResumeExpr(),
-                                         FPOptionsOverride(), Loc, Loc);
-    TryStmt = CXXTryStmt::Create(CGF.getContext(), Loc, TryBody, Catch);
-    CGF.EnterCXXTryStmt(*TryStmt);
-    CGF.EmitStmt(TryBody);
-    // We don't use EmitCXXTryStmt here. We need to store to ResumeEHVar that
-    // doesn't exist in the body.
-    Builder.CreateFlagStore(false, Coro.ResumeEHVar);
-    CGF.ExitCXXTryStmt(*TryStmt);
-    LValueOrRValue Res;
-    // We are not supposed to obtain the value from init suspend await_resume().
-    Res.RV = RValue::getIgnored();
-    return Res;
+  SuspendExpressionEmitter(const SuspendExpressionEmitter &) = delete;
+  SuspendExpressionEmitter(SuspendExpressionEmitter &&) = delete;
+
+  ~SuspendExpressionEmitter() { Binder.unbind(CGF); }
+
+  LValue EmitAsLValue() {
+    assert(Kind != AwaitKind::Init);
+    emitCommonBlocks();
+    return CGF.EmitLValue(SuspendExpr.getResumeExpr());
   }
 
-  LValueOrRValue Res;
-  if (forLValue)
-    Res.LV = CGF.EmitLValue(S.getResumeExpr());
-  else
-    Res.RV = CGF.EmitAnyExpr(S.getResumeExpr(), aggSlot, ignoreResult);
+  RValue EmitAsRValue(AggValueSlot AggSlot, bool IgnoreResult) {
+    emitCommonBlocks();
+    auto &Builder = CGF.Builder;
+    auto *ResumeExpr = SuspendExpr.getResumeExpr();
+
+    // Exception handling requires additional IR. If the 'await_resume' function
+    // is marked as 'noexcept', we avoid generating this additional IR.
+    CXXTryStmt *TryStmt = nullptr;
+    if (Coro.ExceptionHandler && Kind == AwaitKind::Init &&
+        resumeExprCanThrow()) {
+      Coro.ResumeEHVar = CGF.CreateTempAlloca(
+          Builder.getInt1Ty(), SuspendPrefix + Twine("resume.eh"));
+      Builder.CreateFlagStore(true, Coro.ResumeEHVar);
+
+      auto Loc = ResumeExpr->getExprLoc();
+      auto *Catch = new (CGF.getContext())
+          CXXCatchStmt(Loc, /*exDecl=*/nullptr, Coro.ExceptionHandler);
+
+      auto *TryBody = CompoundStmt::Create(CGF.getContext(), ResumeExpr,
+                                           FPOptionsOverride(), Loc, Loc);
+      TryStmt = CXXTryStmt::Create(CGF.getContext(), Loc, TryBody, Catch);
+      CGF.EnterCXXTryStmt(*TryStmt);
+      CGF.EmitStmt(TryBody);
+      // We don't use EmitCXXTryStmt here because we need to emit a store to
+      // ResumeEHVar, and that store doesn't exist in the try body.
+      Builder.CreateFlagStore(false, Coro.ResumeEHVar);
+      CGF.ExitCXXTryStmt(*TryStmt);
+      // We are not supposed to obtain the value from init suspend
+      // await_resume().
+      return RValue::getIgnored();
+    }
 
-  return Res;
-}
+    auto Ret = CGF.EmitAnyExpr(ResumeExpr, AggSlot, IgnoreResult);
+    return Ret;
+  }
+
+private:
+  CodeGenFunction &CGF;
+  CGCoroData &Coro;
+  CoroutineSuspendExpr const &SuspendExpr;
+  AwaitKind Kind;
+  SmallString<32> SuspendPrefix;
+  Expr *CommonExpr;
+  CodeGenFunction::OpaqueValueMappingData Binder;
+
+  // Emit all the common blocks for this suspend expression up to the start of
+  // the ready block. From that point there are three possible continuations:
+  //   1) Emit the resume expression as an LValue;
+  //   2) Emit the resume expression as an RValue;
+  //   3) This suspend is the initial suspend of the coroutine: run
+  //      `try { <await_resume expr> } catch { ... }` and clear the ResumeEHVar
+  //      flag if it didn't throw. In that case we continue to the coroutine
+  //      function body; otherwise, we continue to the catch logic in the
+  //      coroutine's exception handler.
+  void emitCommonBlocks() {
+    BasicBlock *ReadyBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".ready"));
+    BasicBlock *SuspendBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".suspend"));
+    BasicBlock *CleanupBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".cleanup"));
+
+    // If expression is ready, no need to suspend.
+    CGF.EmitBranchOnBoolExpr(SuspendExpr.getReadyExpr(), ReadyBlock,
+                             SuspendBlock, 0);
+
+    // Otherwise, emit suspend logic.
+    CGF.EmitBlock(SuspendBlock);
+
+    auto &Builder = CGF.Builder;
+    llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
+    auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
+    auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
+
+    CGF.CurCoro.InSuspendBlock = true;
+    auto *SuspendRet = CGF.EmitScalarExpr(SuspendExpr.getSuspendExpr());
+    CGF.CurCoro.InSuspendBlock = false;
+
+    if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
+      // Veto suspension if requested by bool returning await_suspend.
+      BasicBlock *RealSuspendBlock =
+          CGF.createBasicBlock(SuspendPrefix + Twine(".suspend.bool"));
+      CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
+      CGF.EmitBlock(RealSuspendBlock);
+    }
+
+    // Emit the suspend point.
+    const bool IsFinalSuspend = (Kind == AwaitKind::Final);
+    llvm::Function *CoroSuspend =
+        CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_suspend);
+    auto *SuspendResult = Builder.CreateCall(
+        CoroSuspend, {SaveCall, Builder.getInt1(IsFinalSuspend)});
+
+    // Create a switch capturing three possible continuations.
+    auto *Switch = Builder.CreateSwitch(SuspendResult, Coro.SuspendBB, 2);
+    Switch->addCase(Builder.getInt8(0), ReadyBlock);
+    Switch->addCase(Builder.getInt8(1), CleanupBlock);
+
+    // Emit cleanup for this suspend point.
+    CGF.EmitBlock(CleanupBlock);
+    CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
+
+    // Emit await_resume expression.
+    CGF.EmitBlock(ReadyBlock);
+  }
+
+  // Check if function can throw based on prototype noexcept, also works for
+  // destructors which are implicitly noexcept but can be marked
+  // noexcept(false).
+  static bool functionCanThrow(const FunctionDecl *D) {
+    const auto *Proto = D->getType()->getAs<FunctionProtoType>();
+    if (!Proto) {
+      // Function proto is not found, we conservatively assume throwing.
+      return true;
+    }
+    return !isNoexceptExceptionSpec(Proto->getExceptionSpecType()) ||
+           Proto->canThrow() != CT_Cannot;
+  }
+
+  static bool resumeStmtCanThrow(const Stmt *S) {
+    if (const auto *CE = dyn_cast<CallExpr>(S)) {
+      const auto *Callee = CE->getDirectCallee();
+      if (!Callee)
+        // We don't have direct callee. Conservatively assume throwing.
+        return true;
+
+      if (functionCanThrow(Callee))
+        return true;
+
+      // Fall through to visit the children.
+    }
+
+    if (const auto *TE = dyn_cast<CXXBindTemporaryExpr>(S)) {
+      // Special handling of CXXBindTemporaryExpr here as calling of Dtor of the
+      // temporary is not part of `children()` as covered in the fall through.
+      // We need to mark entire statement as throwing if the destructor of the
+      // temporary throws.
+      const auto *Dtor = TE->getTemporary()->getDestructor();
+      if (functionCanThrow(Dtor))
+        return true;
+
+      // Fall through to visit the children.
+    }
+
+    for (const auto *child : S->children())
+      if (resumeStmtCanThrow(child))
+        return true;
+
+    return false;
+  }
+
+  bool resumeExprCanThrow() {
+    const Expr *E = SuspendExpr.getResumeExpr();
+    return resumeStmtCanThrow(E);
+  }
+};
+} // namespace
 
 RValue CodeGenFunction::EmitCoawaitExpr(const CoawaitExpr &E,
                                         AggValueSlot aggSlot,
                                         bool ignoreResult) {
-  return emitSuspendExpression(*this, *CurCoro.Data, E,
-                               CurCoro.Data->CurrentAwaitKind, aggSlot,
-                               ignoreResult, /*forLValue*/false).RV;
+  return SuspendExpressionEmitter(*this, *CurCoro.Data, E,
+                                  CurCoro.Data->CurrentAwaitKind)
+      .EmitAsRValue(aggSlot, ignoreResult);
 }
 RValue CodeGenFunction::EmitCoyieldExpr(const CoyieldExpr &E,
                                         AggValueSlot aggSlot,
                                         bool ignoreResult) {
-  return emitSuspendExpression(*this, *CurCoro.Data, E, AwaitKind::Yield,
-                               aggSlot, ignoreResult, /*forLValue*/false).RV;
+  return SuspendExpressionEmitter(*this, *CurCoro.Data, E, AwaitKind::Yield)
+      .EmitAsRValue(aggSlot, ignoreResult);
 }
 
 void CodeGenFunction::EmitCoreturnStmt(CoreturnStmt const &S) {
@@ -343,9 +382,9 @@ CodeGenFunction::EmitCoawaitLValue(const CoawaitExpr *E) {
   assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
          "Can't have a scalar return unless the return type is a "
          "reference type!");
-  return emitSuspendExpression(*this, *CurCoro.Data, *E,
-                               CurCoro.Data->CurrentAwaitKind, AggValueSlot::ignored(),
-                               /*ignoreResult*/false, /*forLValue*/true).LV;
+  return SuspendExpressionEmitter(*this, *CurCoro.Data, *E,
+                                  CurCoro.Data->CurrentAwaitKind)
+      .EmitAsLValue();
 }
 
 LValue
@@ -353,9 +392,8 @@ CodeGenFunction::EmitCoyieldLValue(const CoyieldExpr *E) {
   assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
          "Can't have a scalar return unless the return type is a "
          "reference type!");
-  return emitSuspendExpression(*this, *CurCoro.Data, *E,
-                               AwaitKind::Yield, AggValueSlot::ignored(),
-                               /*ignoreResult*/false, /*forLValue*/true).LV;
+  return SuspendExpressionEmitter(*this, *CurCoro.Data, *E, AwaitKind::Yield)
+      .EmitAsLValue();
 }
 
 // Hunts for the parameter reference in the parameter copy/move declaration.



More information about the cfe-commits mailing list