[clang] [Clang] Coroutines: Properly Check if `await_suspend` return type convertible to `std::coroutine_handle<>` (PR #85684)
Yuxuan Chen via cfe-commits
cfe-commits at lists.llvm.org
Fri Mar 22 11:58:10 PDT 2024
https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/85684
>From 6887adae7500c4791a8620fa5b558e195e2c64cc Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Mon, 18 Mar 2024 10:45:20 -0700
Subject: [PATCH] Check if Coroutine await_suspend type returns the right type
---
.../clang/Basic/DiagnosticSemaKinds.td | 2 +-
clang/lib/Sema/SemaCoroutine.cpp | 119 +++++++++++++-----
clang/test/SemaCXX/coroutines.cpp | 28 ++++-
3 files changed, 108 insertions(+), 41 deletions(-)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fc727cef9cd835..796b3d9d5e1190 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11707,7 +11707,7 @@ def err_coroutine_promise_new_requires_nothrow : Error<
def note_coroutine_promise_call_implicitly_required : Note<
"call to %0 implicitly required by coroutine function here">;
def err_await_suspend_invalid_return_type : Error<
- "return type of 'await_suspend' is required to be 'void' or 'bool' (have %0)"
+ "return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have %0)"
>;
def note_await_ready_no_bool_conversion : Note<
"return type of 'await_ready' is required to be contextually convertible to 'bool'"
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 736632857efc36..2e81a83b62df51 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -137,12 +137,8 @@ static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
return PromiseType;
}
-/// Look up the std::coroutine_handle<PromiseType>.
-static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
- SourceLocation Loc) {
- if (PromiseType.isNull())
- return QualType();
-
+static ClassTemplateDecl *lookupCoroutineHandleTemplate(Sema &S,
+ SourceLocation Loc) {
NamespaceDecl *CoroNamespace = S.getStdNamespace();
assert(CoroNamespace && "Should already be diagnosed");
@@ -151,18 +147,32 @@ static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
if (!S.LookupQualifiedName(Result, CoroNamespace)) {
S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
<< "std::coroutine_handle";
- return QualType();
+ return nullptr;
}
- ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
+ auto *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
+
if (!CoroHandle) {
Result.suppressDiagnostics();
// We found something weird. Complain about the first thing we found.
NamedDecl *Found = *Result.begin();
S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
- return QualType();
+ return nullptr;
}
+ return CoroHandle;
+}
+
+/// Look up the std::coroutine_handle<PromiseType>.
+static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
+ SourceLocation Loc) {
+ if (PromiseType.isNull())
+ return QualType();
+
+ ClassTemplateDecl *CoroHandle = lookupCoroutineHandleTemplate(S, Loc);
+ if (!CoroHandle)
+ return QualType();
+
// Form template argument list for coroutine_handle<Promise>.
TemplateArgumentListInfo Args(Loc, Loc);
Args.addArgument(TemplateArgumentLoc(
@@ -331,16 +341,12 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
// coroutine.
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
SourceLocation Loc) {
- if (RetType->isReferenceType())
- return nullptr;
+ assert(!RetType->isReferenceType() &&
+ "Should have diagnosed reference types.");
Type const *T = RetType.getTypePtr();
if (!T->isClassType() && !T->isStructureType())
return nullptr;
- // FIXME: Add convertability check to coroutine_handle<>. Possibly via
- // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
- // a private function in SemaExprCXX.cpp
-
ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", std::nullopt);
if (AddressExpr.isInvalid())
return nullptr;
@@ -358,6 +364,30 @@ static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
return S.MaybeCreateExprWithCleanups(JustAddress);
}
+/// Returns true if \p Ty is std::coroutine_handle<Z> for some type Z, i.e. a
+/// specialization of the std::coroutine_handle class template, as required
+/// of a non-void, non-bool await_suspend return type by [expr.await]p3.
+/// Only exact specializations match; types merely convertible to one do not.
+static bool isSpecializationOfCoroutineHandle(Sema &S, QualType Ty,
+                                              SourceLocation Loc) {
+  // Do the cheap structural checks first so that the std::coroutine_handle
+  // name lookup (which may diagnose) only runs for class template
+  // specializations.
+  const auto *RecordTy = Ty->getAs<RecordType>();
+  if (!RecordTy)
+    return false;
+
+  const auto *SpecializationDecl =
+      dyn_cast<ClassTemplateSpecializationDecl>(RecordTy->getDecl());
+  if (!SpecializationDecl)
+    return false;
+
+  ClassTemplateDecl *CoroHandleTemplate =
+      lookupCoroutineHandleTemplate(S, Loc);
+  if (!CoroHandleTemplate)
+    return false;
+
+  return CoroHandleTemplate->getCanonicalDecl() ==
+         SpecializationDecl->getSpecializedTemplate()->getCanonicalDecl();
+}
+
/// Build calls to await_ready, await_suspend, and await_resume for a co_await
/// expression.
/// The generated AST tries to clean up temporary objects as early as
@@ -418,39 +448,60 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
return Calls;
}
Expr *CoroHandle = CoroHandleRes.get();
- CallExpr *AwaitSuspend = cast_or_null<CallExpr>(
- BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle));
+ auto *AwaitSuspend = [&]() -> CallExpr * {
+ auto *SubExpr = BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle);
+ if (!SubExpr)
+ return nullptr;
+ if (auto *E = dyn_cast<CXXBindTemporaryExpr>(SubExpr)) {
+ // This happens when await_suspend return type is not trivially
+ // destructible. This doesn't happen for the permitted return types of
+ // such function. Diagnose it later.
+ return cast_or_null<CallExpr>(E->getSubExpr());
+ } else {
+ return cast_or_null<CallExpr>(SubExpr);
+ }
+ }();
+
if (!AwaitSuspend)
return Calls;
+
if (!AwaitSuspend->getType()->isDependentType()) {
+ auto InvalidAwaitSuspendReturnType = [&](QualType RetType) {
+ // non-class prvalues always have cv-unqualified types
+ S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
+ diag::err_await_suspend_invalid_return_type)
+ << RetType;
+ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+ << AwaitSuspend->getDirectCallee();
+ Calls.IsInvalid = true;
+ };
+
// [expr.await]p3 [...]
// - await-suspend is the expression e.await_suspend(h), which shall be
// a prvalue of type void, bool, or std::coroutine_handle<Z> for some
// type Z.
QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
- // Support for coroutine_handle returning await_suspend.
- if (Expr *TailCallSuspend =
- maybeTailCall(S, RetType, AwaitSuspend, Loc))
+ if (RetType->isReferenceType()) {
+ InvalidAwaitSuspendReturnType(RetType);
+ } else if (RetType->isBooleanType() || RetType->isVoidType()) {
+ Calls.Results[ACT::ACT_Suspend] =
+ S.MaybeCreateExprWithCleanups(AwaitSuspend);
+ } else if (isSpecializationOfCoroutineHandle(S, RetType, Loc)) {
+ // Support for coroutine_handle returning await_suspend.
+ //
// Note that we don't wrap the expression with ExprWithCleanups here
// because that might interfere with tailcall contract (e.g. inserting
// clean up instructions in-between tailcall and return). Instead
// ExprWithCleanups is wrapped within maybeTailCall() prior to the resume
// call.
- Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
- else {
- // non-class prvalues always have cv-unqualified types
- if (RetType->isReferenceType() ||
- (!RetType->isBooleanType() && !RetType->isVoidType())) {
- S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
- diag::err_await_suspend_invalid_return_type)
- << RetType;
- S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
- << AwaitSuspend->getDirectCallee();
- Calls.IsInvalid = true;
- } else
- Calls.Results[ACT::ACT_Suspend] =
- S.MaybeCreateExprWithCleanups(AwaitSuspend);
+ Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc);
+ if (TailCallSuspend)
+ Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
+ else
+ InvalidAwaitSuspendReturnType(RetType);
+ } else {
+ InvalidAwaitSuspendReturnType(RetType);
}
}
diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp
index 2292932583fff6..14c4a2a8d9b45e 100644
--- a/clang/test/SemaCXX/coroutines.cpp
+++ b/clang/test/SemaCXX/coroutines.cpp
@@ -1005,12 +1005,24 @@ coro<promise_no_return_func> no_return_value_or_return_void_3() {
co_return 43; // expected-error {{no member named 'return_value'}}
}
-struct bad_await_suspend_return {
+struct non_trivial_destruction_type {
+  ~non_trivial_destruction_type();
+};
+
+struct bad_await_suspend_return_1 {
   bool await_ready();
-  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'char')}}
+  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'char')}}
   char await_suspend(std::coroutine_handle<>);
   void await_resume();
 };
+
+struct bad_await_suspend_return_2 {
+  bool await_ready();
+  // expected-error@+1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'non_trivial_destruction_type')}}
+  non_trivial_destruction_type await_suspend(std::coroutine_handle<>);
+  void await_resume();
+};
+
struct bad_await_ready_return {
// expected-note at +1 {{return type of 'await_ready' is required to be contextually convertible to 'bool'}}
void await_ready();
@@ -1028,8 +1040,8 @@ struct await_ready_explicit_bool {
template <class SuspendTy>
struct await_suspend_type_test {
bool await_ready();
- // expected-error at +2 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &')}}
- // expected-error at +1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &&')}}
+ // expected-error at +2 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'bool &')}}
+ // expected-error at +1 {{return type of 'await_suspend' is required to be 'void', 'bool', or 'std::coroutine_handle' (have 'bool &&')}}
SuspendTy await_suspend(std::coroutine_handle<>);
// cxx20_23-warning at -1 {{volatile-qualified return type 'const volatile bool' is deprecated}}
void await_resume();
@@ -1042,8 +1054,12 @@ void test_bad_suspend() {
co_await a; // expected-note {{call to 'await_ready' implicitly required by coroutine function here}}
}
{
- bad_await_suspend_return b;
- co_await b; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
+ bad_await_suspend_return_1 b1;
+ co_await b1; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
+ }
+ {
+ bad_await_suspend_return_2 b2;
+ co_await b2; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
}
{
await_ready_explicit_bool c;
More information about the cfe-commits
mailing list