[clang] [Clang] Coroutines: Properly Check if `await_suspend` return type convertible to `std::coroutine_handle<>` (PR #85684)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Fri Mar 22 11:58:10 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/85684

From 6887adae7500c4791a8620fa5b558e195e2c64cc Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Mon, 18 Mar 2024 10:45:20 -0700
Subject: [PATCH] Check if Coroutine await_suspend type returns the right type

---
 .../clang/Basic/DiagnosticSemaKinds.td        |   2 +-
 clang/lib/Sema/SemaCoroutine.cpp              | 119 +++++++++++++-----
 clang/test/SemaCXX/coroutines.cpp             |  28 ++++-
 3 files changed, 108 insertions(+), 41 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fc727cef9cd835..796b3d9d5e1190 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11707,7 +11707,7 @@ def err_coroutine_promise_new_requires_nothrow : Error<
 def note_coroutine_promise_call_implicitly_required : Note<
   "call to %0 implicitly required by coroutine function here">;
 def err_await_suspend_invalid_return_type : Error<
-  "return type of 'await_suspend' is required to be 'void' or 'bool' (have %0)"
+  "return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have %0)"
 >;
 def note_await_ready_no_bool_conversion : Note<
   "return type of 'await_ready' is required to be contextually convertible to 'bool'"
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 736632857efc36..2e81a83b62df51 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -137,12 +137,8 @@ static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
   return PromiseType;
 }
 
-/// Look up the std::coroutine_handle<PromiseType>.
-static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
-                                          SourceLocation Loc) {
-  if (PromiseType.isNull())
-    return QualType();
-
+static ClassTemplateDecl *lookupCoroutineHandleTemplate(Sema &S,
+                                                        SourceLocation Loc) {
   NamespaceDecl *CoroNamespace = S.getStdNamespace();
   assert(CoroNamespace && "Should already be diagnosed");
 
@@ -151,18 +147,32 @@ static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
   if (!S.LookupQualifiedName(Result, CoroNamespace)) {
     S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
         << "std::coroutine_handle";
-    return QualType();
+    return nullptr;
   }
 
-  ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
+  auto *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
+
   if (!CoroHandle) {
     Result.suppressDiagnostics();
     // We found something weird. Complain about the first thing we found.
     NamedDecl *Found = *Result.begin();
     S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
-    return QualType();
+    return nullptr;
   }
 
+  return CoroHandle;
+}
+
+/// Look up the std::coroutine_handle<PromiseType>.
+static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
+                                          SourceLocation Loc) {
+  if (PromiseType.isNull())
+    return QualType();
+
+  ClassTemplateDecl *CoroHandle = lookupCoroutineHandleTemplate(S, Loc);
+  if (!CoroHandle)
+    return QualType();
+
   // Form template argument list for coroutine_handle<Promise>.
   TemplateArgumentListInfo Args(Loc, Loc);
   Args.addArgument(TemplateArgumentLoc(
@@ -331,16 +341,12 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
 // coroutine.
 static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
                            SourceLocation Loc) {
-  if (RetType->isReferenceType())
-    return nullptr;
+  assert(!RetType->isReferenceType() &&
+         "Should have diagnosed reference types.");
   Type const *T = RetType.getTypePtr();
   if (!T->isClassType() && !T->isStructureType())
     return nullptr;
 
-  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
-  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
-  // a private function in SemaExprCXX.cpp
-
   ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", std::nullopt);
   if (AddressExpr.isInvalid())
     return nullptr;
@@ -358,6 +364,30 @@ static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
   return S.MaybeCreateExprWithCleanups(JustAddress);
 }
 
+static bool isSpecializationOfCoroutineHandle(Sema &S, QualType Ty,
+                                              SourceLocation Loc) {
+  auto *CoroutineHandleClassTemplateDecl =
+      lookupCoroutineHandleTemplate(S, Loc);
+
+  if (!CoroutineHandleClassTemplateDecl)
+    return false;
+
+  auto *RecordTy = Ty->getAs<RecordType>();
+  if (!RecordTy)
+    return false;
+
+  auto *D = RecordTy->getDecl();
+  if (!D)
+    return false;
+
+  auto *SpecializationDecl = dyn_cast<ClassTemplateSpecializationDecl>(D);
+  if (!SpecializationDecl)
+    return false;
+
+  return CoroutineHandleClassTemplateDecl->getCanonicalDecl() ==
+         SpecializationDecl->getSpecializedTemplate()->getCanonicalDecl();
+}
+
 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
 /// expression.
 /// The generated AST tries to clean up temporary objects as early as
@@ -418,39 +448,60 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
     return Calls;
   }
   Expr *CoroHandle = CoroHandleRes.get();
-  CallExpr *AwaitSuspend = cast_or_null<CallExpr>(
-      BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle));
+  auto *AwaitSuspend = [&]() -> CallExpr * {
+    auto *SubExpr = BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle);
+    if (!SubExpr)
+      return nullptr;
+    if (auto *E = dyn_cast<CXXBindTemporaryExpr>(SubExpr)) {
+      // This happens when await_suspend return type is not trivially
+      // destructible. This doesn't happen for the permitted return types of
+      // such function. Diagnose it later.
+      return cast_or_null<CallExpr>(E->getSubExpr());
+    } else {
+      return cast_or_null<CallExpr>(SubExpr);
+    }
+  }();
+
   if (!AwaitSuspend)
     return Calls;
+
   if (!AwaitSuspend->getType()->isDependentType()) {
+    auto InvalidAwaitSuspendReturnType = [&](QualType RetType) {
+      // non-class prvalues always have cv-unqualified types
+      S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
+             diag::err_await_suspend_invalid_return_type)
+          << RetType;
+      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+          << AwaitSuspend->getDirectCallee();
+      Calls.IsInvalid = true;
+    };
+
     // [expr.await]p3 [...]
     //   - await-suspend is the expression e.await_suspend(h), which shall be
     //     a prvalue of type void, bool, or std::coroutine_handle<Z> for some
     //     type Z.
     QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
 
-    // Support for coroutine_handle returning await_suspend.
-    if (Expr *TailCallSuspend =
-            maybeTailCall(S, RetType, AwaitSuspend, Loc))
+    if (RetType->isReferenceType()) {
+      InvalidAwaitSuspendReturnType(RetType);
+    } else if (RetType->isBooleanType() || RetType->isVoidType()) {
+      Calls.Results[ACT::ACT_Suspend] =
+          S.MaybeCreateExprWithCleanups(AwaitSuspend);
+    } else if (isSpecializationOfCoroutineHandle(S, RetType, Loc)) {
+      // Support for coroutine_handle returning await_suspend.
+      //
       // Note that we don't wrap the expression with ExprWithCleanups here
       // because that might interfere with tailcall contract (e.g. inserting
       // clean up instructions in-between tailcall and return). Instead
       // ExprWithCleanups is wrapped within maybeTailCall() prior to the resume
       // call.
-      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
-    else {
-      // non-class prvalues always have cv-unqualified types
-      if (RetType->isReferenceType() ||
-          (!RetType->isBooleanType() && !RetType->isVoidType())) {
-        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
-               diag::err_await_suspend_invalid_return_type)
-            << RetType;
-        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
-            << AwaitSuspend->getDirectCallee();
-        Calls.IsInvalid = true;
-      } else
-        Calls.Results[ACT::ACT_Suspend] =
-            S.MaybeCreateExprWithCleanups(AwaitSuspend);
+      Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc);
+      if (TailCallSuspend)
+        Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
+      else
+        InvalidAwaitSuspendReturnType(RetType);
+    } else {
+      InvalidAwaitSuspendReturnType(RetType);
     }
   }
 
diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp
index 2292932583fff6..14c4a2a8d9b45e 100644
--- a/clang/test/SemaCXX/coroutines.cpp
+++ b/clang/test/SemaCXX/coroutines.cpp
@@ -1005,12 +1005,24 @@ coro<promise_no_return_func> no_return_value_or_return_void_3() {
   co_return 43; // expected-error {{no member named 'return_value'}}
 }
 
-struct bad_await_suspend_return {
+struct non_trivial_destruction_type {
+  ~non_trivial_destruction_type();
+};
+
+struct bad_await_suspend_return_1 {
   bool await_ready();
-  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'char')}}
+  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'char')}}
   char await_suspend(std::coroutine_handle<>);
   void await_resume();
 };
+
+struct bad_await_suspend_return_2 {
+  bool await_ready();
+  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'non_trivial_destruction_type')}}
+  non_trivial_destruction_type await_suspend(std::coroutine_handle<>);
+  void await_resume();
+};
+
 struct bad_await_ready_return {
   // expected-note@+1 {{return type of 'await_ready' is required to be contextually convertible to 'bool'}}
   void await_ready();
@@ -1028,8 +1040,8 @@ struct await_ready_explicit_bool {
 template <class SuspendTy>
 struct await_suspend_type_test {
   bool await_ready();
-  // expected-error@+2 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &')}}
-  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &&')}}
+  // expected-error@+2 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'bool &')}}
+  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'bool &&')}}
   SuspendTy await_suspend(std::coroutine_handle<>);
   // cxx20_23-warning@-1 {{volatile-qualified return type 'const volatile bool' is deprecated}}
   void await_resume();
@@ -1042,8 +1054,12 @@ void test_bad_suspend() {
     co_await a; // expected-note {{call to 'await_ready' implicitly required by coroutine function here}}
   }
   {
-    bad_await_suspend_return b;
-    co_await b; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
+    bad_await_suspend_return_1 b1;
+    co_await b1; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
+  }
+  {
+    bad_await_suspend_return_2 b2;
+    co_await b2; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
   }
   {
     await_ready_explicit_bool c;



More information about the cfe-commits mailing list