[clang] [Clang][NFC] Refactor suspend emit logic in coroutine codegen (PR #73564)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Mon Nov 27 12:41:28 PST 2023


https://github.com/yuxuanchen1997 created https://github.com/llvm/llvm-project/pull/73564

Depends on https://github.com/llvm/llvm-project/pull/73160

I am not a big fan of `struct LValueOrRValue`. It appears to be intended for use like a union, without explicitly being declared as a `union`. While reading the wrong member carries no risk of UB, it is still annoying that the struct carries no tag indicating which member is populated.

This PR is not intended to make functional changes. Instead, I propose a helper class whose methods explicitly emit the suspend expression as either an LValue or an RValue.

Before:
```
emitSuspendExpression(*this, *CurCoro.Data, E, CurCoro.Data->CurrentAwaitKind, aggSlot, ignoreResult, /*forLValue*/false).RV
```
After:
```
SuspendExpressionEmitter(*this, *CurCoro.Data, E, CurCoro.Data->CurrentAwaitKind).EmitAsRValue(aggSlot, ignoreResult)
```

>From 08e2293255a504043fe404cceaeb3ff1fc0dc344 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Tue, 21 Nov 2023 21:38:12 -0800
Subject: [PATCH 1/2] add checks for nested noexcept in cxxtempexpr

---
 clang/lib/CodeGen/CGCoroutine.cpp             | 11 ++++-
 .../coro-init-await-nontrivial-return.cpp     | 44 ++++++++++++++++++-
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index aaf122c0f83bc47..8aebc5563757cba 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -129,7 +129,14 @@ static SmallString<32> buildSuspendPrefixStr(CGCoroData &Coro, AwaitKind Kind) {
   return Prefix;
 }
 
-static bool memberCallExpressionCanThrow(const Expr *E) {
+static bool ResumeExprCanThrow(const CoroutineSuspendExpr &S) {
+  const Expr *E = S.getResumeExpr();
+
+  // If the return type of await_resume is not void, get the CXXMemberCallExpr
+  // from its subexpr.
+  if (const auto *BindTempExpr = dyn_cast<CXXBindTemporaryExpr>(E)) {
+    E = BindTempExpr->getSubExpr();
+  }
   if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
     if (const auto *Proto =
             CE->getMethodDecl()->getType()->getAs<FunctionProtoType>())
@@ -233,7 +240,7 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
   // is marked as 'noexcept', we avoid generating this additional IR.
   CXXTryStmt *TryStmt = nullptr;
   if (Coro.ExceptionHandler && Kind == AwaitKind::Init &&
-      memberCallExpressionCanThrow(S.getResumeExpr())) {
+      ResumeExprCanThrow(S)) {
     Coro.ResumeEHVar =
         CGF.CreateTempAlloca(Builder.getInt1Ty(), Prefix + Twine("resume.eh"));
     Builder.CreateFlagStore(true, Coro.ResumeEHVar);
diff --git a/clang/test/CodeGenCoroutines/coro-init-await-nontrivial-return.cpp b/clang/test/CodeGenCoroutines/coro-init-await-nontrivial-return.cpp
index c4b8da327f5c140..5d24841091f339c 100644
--- a/clang/test/CodeGenCoroutines/coro-init-await-nontrivial-return.cpp
+++ b/clang/test/CodeGenCoroutines/coro-init-await-nontrivial-return.cpp
@@ -7,6 +7,7 @@ struct NontrivialType {
   ~NontrivialType() {}
 };
 
+namespace can_throw {
 struct Task {
     struct promise_type;
     using handle_type = std::coroutine_handle<promise_type>;
@@ -38,9 +39,48 @@ Task coro_create() {
     co_return;
 }
 
-// CHECK-LABEL: define{{.*}} ptr @_Z11coro_createv(
+// CHECK-LABEL: define{{.*}} ptr @_ZN9can_throw11coro_createEv(
 // CHECK: init.ready:
 // CHECK-NEXT: store i1 true, ptr {{.*}}
-// CHECK-NEXT: call void @_ZN4Task23initial_suspend_awaiter12await_resumeEv(
+// CHECK-NEXT: call void @_ZN9can_throw4Task23initial_suspend_awaiter12await_resumeEv(
 // CHECK-NEXT: call void @_ZN14NontrivialTypeD1Ev(
 // CHECK-NEXT: store i1 false, ptr {{.*}}
+}
+
+namespace no_throw {
+// Counterpart of can_throw: here await_resume() is noexcept, so the flag
+// stores expected around the call there must not be emitted.
+struct InitNoThrowTask {
+    struct promise_type;
+    using handle_type = std::coroutine_handle<promise_type>;
+
+    struct initial_suspend_awaiter {
+        bool await_ready() { return false; }
+
+        void await_suspend(handle_type) {}
+
+        NontrivialType await_resume() noexcept { return {}; }
+    };
+
+    struct promise_type {
+        void return_void() {}
+        void unhandled_exception() {}
+        initial_suspend_awaiter initial_suspend() { return {}; }
+        std::suspend_never final_suspend() noexcept { return {}; }
+        InitNoThrowTask get_return_object() {
+            return InitNoThrowTask{handle_type::from_promise(*this)};
+        }
+    };
+
+    handle_type handle;
+};
+
+InitNoThrowTask coro_create() {
+    co_return;
+}
+
+// CHECK-LABEL: define{{.*}} ptr @_ZN8no_throw11coro_createEv(
+// CHECK: init.ready:
+// CHECK-NEXT: call void @_ZN8no_throw15InitNoThrowTask23initial_suspend_awaiter12await_resumeEv(
+// CHECK-NEXT: call void @_ZN14NontrivialTypeD1Ev(
+}

>From 619d4517550d6b3226fbb48146419ae68bb18694 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 22 Nov 2023 13:38:54 -0800
Subject: [PATCH 2/2] Refactor how we generate RValue vs LValue coawait
 expressions

---
 clang/lib/CodeGen/CGCoroutine.cpp | 272 +++++++++++++++++-------------
 1 file changed, 152 insertions(+), 120 deletions(-)

diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index 8aebc5563757cba..b3da83158bbd9f5 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -129,22 +129,7 @@ static SmallString<32> buildSuspendPrefixStr(CGCoroData &Coro, AwaitKind Kind) {
   return Prefix;
 }
 
-static bool ResumeExprCanThrow(const CoroutineSuspendExpr &S) {
-  const Expr *E = S.getResumeExpr();
-
-  // If the return type of await_resume is not void, get the CXXMemberCallExpr
-  // from its subexpr.
-  if (const auto *BindTempExpr = dyn_cast<CXXBindTemporaryExpr>(E)) {
-    E = BindTempExpr->getSubExpr();
-  }
-  if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
-    if (const auto *Proto =
-            CE->getMethodDecl()->getType()->getAs<FunctionProtoType>())
-      if (isNoexceptExceptionSpec(Proto->getExceptionSpecType()) &&
-          Proto->canThrow() == CT_Cannot)
-        return false;
-  return true;
-}
+namespace {
 
 // Emit suspend expression which roughly looks like:
 //
@@ -173,117 +158,165 @@ static bool ResumeExprCanThrow(const CoroutineSuspendExpr &S) {
 //
 //  See llvm's docs/Coroutines.rst for more details.
 //
-namespace {
-  struct LValueOrRValue {
-    LValue LV;
-    RValue RV;
-  };
-}
-static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Coro,
-                                    CoroutineSuspendExpr const &S,
-                                    AwaitKind Kind, AggValueSlot aggSlot,
-                                    bool ignoreResult, bool forLValue) {
-  auto *E = S.getCommonExpr();
-
-  auto Binder =
-      CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E);
-  auto UnbindOnExit = llvm::make_scope_exit([&] { Binder.unbind(CGF); });
-
-  auto Prefix = buildSuspendPrefixStr(Coro, Kind);
-  BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready"));
-  BasicBlock *SuspendBlock = CGF.createBasicBlock(Prefix + Twine(".suspend"));
-  BasicBlock *CleanupBlock = CGF.createBasicBlock(Prefix + Twine(".cleanup"));
-
-  // If expression is ready, no need to suspend.
-  CGF.EmitBranchOnBoolExpr(S.getReadyExpr(), ReadyBlock, SuspendBlock, 0);
-
-  // Otherwise, emit suspend logic.
-  CGF.EmitBlock(SuspendBlock);
-
-  auto &Builder = CGF.Builder;
-  llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
-  auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
-  auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
-
-  CGF.CurCoro.InSuspendBlock = true;
-  auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
-  CGF.CurCoro.InSuspendBlock = false;
-
-  if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
-    // Veto suspension if requested by bool returning await_suspend.
-    BasicBlock *RealSuspendBlock =
-        CGF.createBasicBlock(Prefix + Twine(".suspend.bool"));
-    CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
-    CGF.EmitBlock(RealSuspendBlock);
+class SuspendExpressionEmitter final {
+public:
+  SuspendExpressionEmitter(CodeGenFunction &CGF, CGCoroData &Coro,
+                           CoroutineSuspendExpr const &S, AwaitKind Kind)
+      : CGF(CGF), Coro(Coro), SuspendExpr(S), Kind(Kind) {
+    // Bind first, as the original emitSuspendExpression did: binding emits the
+    // common expression, so nested suspends get their prefixes before ours.
+    Binder = CodeGenFunction::OpaqueValueMappingData::bind(
+        CGF, SuspendExpr.getOpaqueValue(), SuspendExpr.getCommonExpr());
+    SuspendPrefix = buildSuspendPrefixStr(Coro, Kind);
   }
 
-  // Emit the suspend point.
-  const bool IsFinalSuspend = (Kind == AwaitKind::Final);
-  llvm::Function *CoroSuspend =
-      CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_suspend);
-  auto *SuspendResult = Builder.CreateCall(
-      CoroSuspend, {SaveCall, Builder.getInt1(IsFinalSuspend)});
-
-  // Create a switch capturing three possible continuations.
-  auto *Switch = Builder.CreateSwitch(SuspendResult, Coro.SuspendBB, 2);
-  Switch->addCase(Builder.getInt8(0), ReadyBlock);
-  Switch->addCase(Builder.getInt8(1), CleanupBlock);
-
-  // Emit cleanup for this suspend point.
-  CGF.EmitBlock(CleanupBlock);
-  CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
-
-  // Emit await_resume expression.
-  CGF.EmitBlock(ReadyBlock);
-
-  // Exception handling requires additional IR. If the 'await_resume' function
-  // is marked as 'noexcept', we avoid generating this additional IR.
-  CXXTryStmt *TryStmt = nullptr;
-  if (Coro.ExceptionHandler && Kind == AwaitKind::Init &&
-      ResumeExprCanThrow(S)) {
-    Coro.ResumeEHVar =
-        CGF.CreateTempAlloca(Builder.getInt1Ty(), Prefix + Twine("resume.eh"));
-    Builder.CreateFlagStore(true, Coro.ResumeEHVar);
-
-    auto Loc = S.getResumeExpr()->getExprLoc();
-    auto *Catch = new (CGF.getContext())
-        CXXCatchStmt(Loc, /*exDecl=*/nullptr, Coro.ExceptionHandler);
-    auto *TryBody = CompoundStmt::Create(CGF.getContext(), S.getResumeExpr(),
-                                         FPOptionsOverride(), Loc, Loc);
-    TryStmt = CXXTryStmt::Create(CGF.getContext(), Loc, TryBody, Catch);
-    CGF.EnterCXXTryStmt(*TryStmt);
-    CGF.EmitStmt(TryBody);
-    // We don't use EmitCXXTryStmt here. We need to store to ResumeEHVar that
-    // doesn't exist in the body.
-    Builder.CreateFlagStore(false, Coro.ResumeEHVar);
-    CGF.ExitCXXTryStmt(*TryStmt);
-    LValueOrRValue Res;
-    // We are not supposed to obtain the value from init suspend await_resume().
-    Res.RV = RValue::getIgnored();
-    return Res;
+  SuspendExpressionEmitter(const SuspendExpressionEmitter &) = delete;
+  SuspendExpressionEmitter(SuspendExpressionEmitter &&) = delete;
+  ~SuspendExpressionEmitter() { Binder.unbind(CGF); }
+
+  LValue EmitAsLValue() {
+    emitCommonBlocks();
+    // Like the original emitSuspendExpression, also guard the initial suspend
+    // here; the guarded resume produces no value, so return an empty LValue.
+    if (emitGuardedInitResume())
+      return LValue();
+    return CGF.EmitLValue(SuspendExpr.getResumeExpr());
   }
 
-  LValueOrRValue Res;
-  if (forLValue)
-    Res.LV = CGF.EmitLValue(S.getResumeExpr());
-  else
-    Res.RV = CGF.EmitAnyExpr(S.getResumeExpr(), aggSlot, ignoreResult);
+  RValue EmitAsRValue(AggValueSlot AggSlot, bool IgnoreResult) {
+    emitCommonBlocks();
+    // We are not supposed to obtain the value from init suspend
+    // await_resume().
+    if (emitGuardedInitResume())
+      return RValue::getIgnored();
+    return CGF.EmitAnyExpr(SuspendExpr.getResumeExpr(), AggSlot, IgnoreResult);
+  }
 
-  return Res;
-}
+private:
+  CodeGenFunction &CGF;
+  CGCoroData &Coro;
+  CoroutineSuspendExpr const &SuspendExpr;
+  AwaitKind Kind;
+  SmallString<32> SuspendPrefix;
+  CodeGenFunction::OpaqueValueMappingData Binder;
+
+  // Emit the blocks shared by every outcome, ending with the ready block: the
+  // ready check, coro.save, await_suspend, coro.suspend and the cleanup block.
+  // After it, await_resume() is emitted as an LValue, as an RValue, or (for
+  // the initial suspend only) via emitGuardedInitResume().
+  void emitCommonBlocks() {
+    BasicBlock *ReadyBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".ready"));
+    BasicBlock *SuspendBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".suspend"));
+    BasicBlock *CleanupBlock =
+        CGF.createBasicBlock(SuspendPrefix + Twine(".cleanup"));
+
+    // If expression is ready, no need to suspend.
+    CGF.EmitBranchOnBoolExpr(SuspendExpr.getReadyExpr(), ReadyBlock,
+                             SuspendBlock, 0);
+
+    // Otherwise, emit suspend logic.
+    CGF.EmitBlock(SuspendBlock);
+
+    auto &Builder = CGF.Builder;
+    llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
+    auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
+    auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
+
+    CGF.CurCoro.InSuspendBlock = true;
+    auto *SuspendRet = CGF.EmitScalarExpr(SuspendExpr.getSuspendExpr());
+    CGF.CurCoro.InSuspendBlock = false;
+
+    if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
+      // Veto suspension if requested by bool returning await_suspend.
+      BasicBlock *RealSuspendBlock =
+          CGF.createBasicBlock(SuspendPrefix + Twine(".suspend.bool"));
+      CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
+      CGF.EmitBlock(RealSuspendBlock);
+    }
+
+    // Emit the suspend point.
+    const bool IsFinalSuspend = (Kind == AwaitKind::Final);
+    llvm::Function *CoroSuspend =
+        CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_suspend);
+    auto *SuspendResult = Builder.CreateCall(
+        CoroSuspend, {SaveCall, Builder.getInt1(IsFinalSuspend)});
+
+    // Create a switch capturing three possible continuations.
+    auto *Switch = Builder.CreateSwitch(SuspendResult, Coro.SuspendBB, 2);
+    Switch->addCase(Builder.getInt8(0), ReadyBlock);
+    Switch->addCase(Builder.getInt8(1), CleanupBlock);
+
+    // Emit cleanup for this suspend point.
+    CGF.EmitBlock(CleanupBlock);
+    CGF.EmitBranchThroughCleanup(Coro.CleanupJD);
+
+    // Emit await_resume expression.
+    CGF.EmitBlock(ReadyBlock);
+  }
+
+  // If this is the initial suspend, the coroutine has an exception handler,
+  // and await_resume() may throw, emit it as `try { await_resume(); } catch
+  // (...) { handler }` with Coro.ResumeEHVar stored true before and false
+  // after a normal return, and return true. Else emit nothing, return false.
+  bool emitGuardedInitResume() {
+    if (!Coro.ExceptionHandler || Kind != AwaitKind::Init ||
+        !resumeExprCanThrow())
+      return false;
+    auto &Builder = CGF.Builder;
+    auto *ResumeExpr = SuspendExpr.getResumeExpr();
+    Coro.ResumeEHVar = CGF.CreateTempAlloca(
+        Builder.getInt1Ty(), SuspendPrefix + Twine("resume.eh"));
+    Builder.CreateFlagStore(true, Coro.ResumeEHVar);
+
+    auto Loc = ResumeExpr->getExprLoc();
+    auto *Catch = new (CGF.getContext())
+        CXXCatchStmt(Loc, /*exDecl=*/nullptr, Coro.ExceptionHandler);
+    auto *TryBody = CompoundStmt::Create(CGF.getContext(), ResumeExpr,
+                                         FPOptionsOverride(), Loc, Loc);
+    auto *TryStmt = CXXTryStmt::Create(CGF.getContext(), Loc, TryBody, Catch);
+    CGF.EnterCXXTryStmt(*TryStmt);
+    CGF.EmitStmt(TryBody);
+    // We don't use EmitCXXTryStmt here. We need to store to ResumeEHVar that
+    // doesn't exist in the body.
+    Builder.CreateFlagStore(false, Coro.ResumeEHVar);
+    CGF.ExitCXXTryStmt(*TryStmt);
+    return true;
+  }
+
+  // Return false only if await_resume() is a member call declared noexcept
+  // (with a spec that cannot throw); otherwise conservatively return true.
+  bool resumeExprCanThrow() const {
+    const Expr *E = SuspendExpr.getResumeExpr();
+
+    // A returned temporary that needs destruction wraps the member call in a
+    // CXXBindTemporaryExpr; look through it.
+    if (const auto *BindTempExpr = dyn_cast<CXXBindTemporaryExpr>(E))
+      E = BindTempExpr->getSubExpr();
+    if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
+      if (const auto *Proto =
+              CE->getMethodDecl()->getType()->getAs<FunctionProtoType>())
+        if (isNoexceptExceptionSpec(Proto->getExceptionSpecType()) &&
+            Proto->canThrow() == CT_Cannot)
+          return false;
+    return true;
+  }
+};
+} // namespace
 
 RValue CodeGenFunction::EmitCoawaitExpr(const CoawaitExpr &E,
                                         AggValueSlot aggSlot,
                                         bool ignoreResult) {
-  return emitSuspendExpression(*this, *CurCoro.Data, E,
-                               CurCoro.Data->CurrentAwaitKind, aggSlot,
-                               ignoreResult, /*forLValue*/false).RV;
+  CGCoroData &Coro = *CurCoro.Data;
+  return SuspendExpressionEmitter(*this, Coro, E, Coro.CurrentAwaitKind)
+      .EmitAsRValue(aggSlot, ignoreResult);
 }
 RValue CodeGenFunction::EmitCoyieldExpr(const CoyieldExpr &E,
                                         AggValueSlot aggSlot,
                                         bool ignoreResult) {
-  return emitSuspendExpression(*this, *CurCoro.Data, E, AwaitKind::Yield,
-                               aggSlot, ignoreResult, /*forLValue*/false).RV;
+  SuspendExpressionEmitter Emitter(*this, *CurCoro.Data, E, AwaitKind::Yield);
+  return Emitter.EmitAsRValue(aggSlot, ignoreResult);
 }
 
 void CodeGenFunction::EmitCoreturnStmt(CoreturnStmt const &S) {
@@ -316,9 +349,9 @@ CodeGenFunction::EmitCoawaitLValue(const CoawaitExpr *E) {
   assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
          "Can't have a scalar return unless the return type is a "
          "reference type!");
-  return emitSuspendExpression(*this, *CurCoro.Data, *E,
-                               CurCoro.Data->CurrentAwaitKind, AggValueSlot::ignored(),
-                               /*ignoreResult*/false, /*forLValue*/true).LV;
+  return SuspendExpressionEmitter(*this, *CurCoro.Data, *E,
+                                  CurCoro.Data->CurrentAwaitKind)
+      .EmitAsLValue();
 }
 
 LValue
@@ -326,9 +359,8 @@ CodeGenFunction::EmitCoyieldLValue(const CoyieldExpr *E) {
   assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
          "Can't have a scalar return unless the return type is a "
          "reference type!");
-  return emitSuspendExpression(*this, *CurCoro.Data, *E,
-                               AwaitKind::Yield, AggValueSlot::ignored(),
-                               /*ignoreResult*/false, /*forLValue*/true).LV;
+  return SuspendExpressionEmitter(*this, *CurCoro.Data, *E, AwaitKind::Yield)
+      .EmitAsLValue();
 }
 
 // Hunts for the parameter reference in the parameter copy/move declaration.



More information about the cfe-commits mailing list