[clang] [Clang] Coroutines: Properly Check if `await_suspend` return type convertible to `std::coroutine_handle<>` (PR #85684)
Yuxuan Chen via cfe-commits
cfe-commits at lists.llvm.org
Tue Mar 19 14:27:28 PDT 2024
https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/85684
>From 08de54f02038795924a6e5fdbcf51a496fcedf56 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Mon, 18 Mar 2024 10:45:20 -0700
Subject: [PATCH] Check if Coroutine await_suspend type returns the right type
---
.../clang/Basic/DiagnosticSemaKinds.td | 2 +-
clang/include/clang/Sema/Sema.h | 2 +
clang/lib/Sema/SemaCoroutine.cpp | 75 +++++++++++------
clang/lib/Sema/SemaExprCXX.cpp | 84 +++++++++----------
clang/test/SemaCXX/coroutines.cpp | 28 +++++--
5 files changed, 116 insertions(+), 75 deletions(-)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e97902564af08..f99170445c76b6 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11701,7 +11701,7 @@ def err_coroutine_promise_new_requires_nothrow : Error<
def note_coroutine_promise_call_implicitly_required : Note<
"call to %0 implicitly required by coroutine function here">;
def err_await_suspend_invalid_return_type : Error<
- "return type of 'await_suspend' is required to be 'void' or 'bool' (have %0)"
+ "return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have %0)"
>;
def note_await_ready_no_bool_conversion : Note<
"return type of 'await_ready' is required to be contextually convertible to 'bool'"
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 95ea5ebc7f1ac1..4976ff96b03d5b 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7011,6 +7011,12 @@ class Sema final {
ExprResult BuildTypeTrait(TypeTrait Kind, SourceLocation KWLoc,
ArrayRef<TypeSourceInfo *> Args,
SourceLocation RParenLoc);
+  /// Returns the result of the binary type trait \p BTT (for example
+  /// BTT_IsConvertible) applied to \p LhsT and \p RhsT; neither type may be
+  /// dependent. \p KeyLoc is used for diagnostics (e.g. incomplete-type errors)
+  /// and as the location of the synthesized operand expressions.
+  bool EvaluateBinaryTypeTrait(TypeTrait BTT, QualType LhsT, QualType RhsT,
+                               SourceLocation KeyLoc);
/// ActOnArrayTypeTrait - Parsed one of the binary type trait support
/// pseudo-functions.
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 736632857efc36..fbe230737404fa 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -331,16 +331,12 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
// coroutine.
static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
SourceLocation Loc) {
- if (RetType->isReferenceType())
- return nullptr;
+ assert(!RetType->isReferenceType() &&
+ "Should have diagnosed reference types.");
Type const *T = RetType.getTypePtr();
if (!T->isClassType() && !T->isStructureType())
return nullptr;
- // FIXME: Add convertability check to coroutine_handle<>. Possibly via
- // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
- // a private function in SemaExprCXX.cpp
-
ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", std::nullopt);
if (AddressExpr.isInvalid())
return nullptr;
@@ -358,6 +354,21 @@ static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
return S.MaybeCreateExprWithCleanups(JustAddress);
}
+/// Returns true if \p Ty is convertible to the type-erased handle type
+/// std::coroutine_handle<> (looked up as std::coroutine_handle<void>).
+///
+/// Returns false when that handle type cannot be formed (a null QualType),
+/// so a null type is never handed to Sema::EvaluateBinaryTypeTrait.
+static bool isConvertibleToCoroutineHandle(Sema &S, QualType Ty,
+                                           SourceLocation Loc) {
+  QualType ErasedHandleType =
+      lookupCoroutineHandleType(S, S.Context.VoidTy, Loc);
+  if (ErasedHandleType.isNull())
+    return false;
+  return S.EvaluateBinaryTypeTrait(BTT_IsConvertible, Ty, ErasedHandleType,
+                                   Loc);
+}
+
/// Build calls to await_ready, await_suspend, and await_resume for a co_await
/// expression.
/// The generated AST tries to clean up temporary objects as early as
@@ -418,39 +422,60 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
return Calls;
}
Expr *CoroHandle = CoroHandleRes.get();
- CallExpr *AwaitSuspend = cast_or_null<CallExpr>(
- BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle));
+ auto *AwaitSuspend = [&]() -> CallExpr * {
+ auto *SubExpr = BuildSubExpr(ACT::ACT_Suspend, "await_suspend", CoroHandle);
+ if (!SubExpr)
+ return nullptr;
+ if (auto *E = dyn_cast<CXXBindTemporaryExpr>(SubExpr)) {
+ // This happens when await_suspend return type is not trivially
+ // destructible. This doesn't happen for the permitted return types of
+ // such function. Diagnose it later.
+ return cast_or_null<CallExpr>(E->getSubExpr());
+ } else {
+ return cast_or_null<CallExpr>(SubExpr);
+ }
+ }();
+
if (!AwaitSuspend)
return Calls;
+
if (!AwaitSuspend->getType()->isDependentType()) {
+ auto InvalidAwaitSuspendReturnType = [&](QualType RetType) {
+ // non-class prvalues always have cv-unqualified types
+ S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
+ diag::err_await_suspend_invalid_return_type)
+ << RetType;
+ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+ << AwaitSuspend->getDirectCallee();
+ Calls.IsInvalid = true;
+ };
+
// [expr.await]p3 [...]
// - await-suspend is the expression e.await_suspend(h), which shall be
// a prvalue of type void, bool, or std::coroutine_handle<Z> for some
// type Z.
QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
- // Support for coroutine_handle returning await_suspend.
- if (Expr *TailCallSuspend =
- maybeTailCall(S, RetType, AwaitSuspend, Loc))
+ if (RetType->isReferenceType()) {
+ InvalidAwaitSuspendReturnType(RetType);
+ } else if (RetType->isBooleanType() || RetType->isVoidType()) {
+ Calls.Results[ACT::ACT_Suspend] =
+ S.MaybeCreateExprWithCleanups(AwaitSuspend);
+ } else if (isConvertibleToCoroutineHandle(S, RetType, Loc)) {
+ // Support for coroutine_handle returning await_suspend.
+ //
// Note that we don't wrap the expression with ExprWithCleanups here
// because that might interfere with tailcall contract (e.g. inserting
// clean up instructions in-between tailcall and return). Instead
// ExprWithCleanups is wrapped within maybeTailCall() prior to the resume
// call.
- Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
- else {
- // non-class prvalues always have cv-unqualified types
- if (RetType->isReferenceType() ||
- (!RetType->isBooleanType() && !RetType->isVoidType())) {
- S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
- diag::err_await_suspend_invalid_return_type)
- << RetType;
- S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
- << AwaitSuspend->getDirectCallee();
- Calls.IsInvalid = true;
- } else
- Calls.Results[ACT::ACT_Suspend] =
- S.MaybeCreateExprWithCleanups(AwaitSuspend);
+ Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc);
+ if (TailCallSuspend)
+ Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
+ else
+ InvalidAwaitSuspendReturnType(RetType);
+ } else {
+ InvalidAwaitSuspendReturnType(RetType);
}
}
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index c34a40fa7c81ac..db04e59a91332d 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5559,9 +5559,6 @@ static bool EvaluateUnaryTypeTrait(Sema &Self, TypeTrait UTT,
}
}
-static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
- QualType RhsT, SourceLocation KeyLoc);
-
static bool EvaluateBooleanTypeTrait(Sema &S, TypeTrait Kind,
SourceLocation KWLoc,
ArrayRef<TypeSourceInfo *> Args,
@@ -5576,8 +5573,8 @@ static bool EvaluateBooleanTypeTrait(Sema &S, TypeTrait Kind,
// Evaluate ReferenceBindsToTemporary and ReferenceConstructsFromTemporary
// alongside the IsConstructible traits to avoid duplication.
if (Kind <= BTT_Last && Kind != BTT_ReferenceBindsToTemporary && Kind != BTT_ReferenceConstructsFromTemporary)
- return EvaluateBinaryTypeTrait(S, Kind, Args[0]->getType(),
- Args[1]->getType(), RParenLoc);
+ return S.EvaluateBinaryTypeTrait(Kind, Args[0]->getType(),
+ Args[1]->getType(), RParenLoc);
switch (Kind) {
case clang::BTT_ReferenceBindsToTemporary:
@@ -5674,7 +5671,8 @@ static bool EvaluateBooleanTypeTrait(Sema &S, TypeTrait Kind,
QualType TPtr = S.Context.getPointerType(S.BuiltinRemoveReference(T, UnaryTransformType::RemoveCVRef, {}));
QualType UPtr = S.Context.getPointerType(S.BuiltinRemoveReference(U, UnaryTransformType::RemoveCVRef, {}));
- return EvaluateBinaryTypeTrait(S, TypeTrait::BTT_IsConvertibleTo, UPtr, TPtr, RParenLoc);
+ return S.EvaluateBinaryTypeTrait(TypeTrait::BTT_IsConvertibleTo, UPtr,
+ TPtr, RParenLoc);
}
if (Kind == clang::TT_IsNothrowConstructible)
@@ -5807,8 +5805,8 @@ ExprResult Sema::ActOnTypeTrait(TypeTrait Kind, SourceLocation KWLoc,
return BuildTypeTrait(Kind, KWLoc, ConvertedArgs, RParenLoc);
}
-static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
- QualType RhsT, SourceLocation KeyLoc) {
+bool Sema::EvaluateBinaryTypeTrait(TypeTrait BTT, QualType LhsT, QualType RhsT,
+ SourceLocation KeyLoc) {
assert(!LhsT->isDependentType() && !RhsT->isDependentType() &&
"Cannot evaluate traits of dependent types");
@@ -5832,15 +5830,15 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
if (!BaseInterface || !DerivedInterface)
return false;
- if (Self.RequireCompleteType(
+ if (RequireCompleteType(
KeyLoc, RhsT, diag::err_incomplete_type_used_in_type_trait_expr))
return false;
return BaseInterface->isSuperClassOf(DerivedInterface);
}
- assert(Self.Context.hasSameUnqualifiedType(LhsT, RhsT)
- == (lhsRecord == rhsRecord));
+ assert(Context.hasSameUnqualifiedType(LhsT, RhsT) ==
+ (lhsRecord == rhsRecord));
// Unions are never base classes, and never have base classes.
// It doesn't matter if they are complete or not. See PR#41843
@@ -5856,21 +5854,21 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
// If Base and Derived are class types and are different types
// (ignoring possible cv-qualifiers) then Derived shall be a
// complete type.
- if (Self.RequireCompleteType(KeyLoc, RhsT,
- diag::err_incomplete_type_used_in_type_trait_expr))
+ if (RequireCompleteType(KeyLoc, RhsT,
+ diag::err_incomplete_type_used_in_type_trait_expr))
return false;
return cast<CXXRecordDecl>(rhsRecord->getDecl())
->isDerivedFrom(cast<CXXRecordDecl>(lhsRecord->getDecl()));
}
case BTT_IsSame:
- return Self.Context.hasSameType(LhsT, RhsT);
+ return Context.hasSameType(LhsT, RhsT);
case BTT_TypeCompatible: {
// GCC ignores cv-qualifiers on arrays for this builtin.
Qualifiers LhsQuals, RhsQuals;
- QualType Lhs = Self.getASTContext().getUnqualifiedArrayType(LhsT, LhsQuals);
- QualType Rhs = Self.getASTContext().getUnqualifiedArrayType(RhsT, RhsQuals);
- return Self.Context.typesAreCompatible(Lhs, Rhs);
+ QualType Lhs = getASTContext().getUnqualifiedArrayType(LhsT, LhsQuals);
+ QualType Rhs = getASTContext().getUnqualifiedArrayType(RhsT, RhsQuals);
+ return Context.typesAreCompatible(Lhs, Rhs);
}
case BTT_IsConvertible:
case BTT_IsConvertibleTo:
@@ -5909,16 +5907,16 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
return LhsT->isVoidType();
// A function definition requires a complete, non-abstract return type.
- if (!Self.isCompleteType(KeyLoc, RhsT) || Self.isAbstractType(KeyLoc, RhsT))
+ if (!isCompleteType(KeyLoc, RhsT) || isAbstractType(KeyLoc, RhsT))
return false;
// Compute the result of add_rvalue_reference.
if (LhsT->isObjectType() || LhsT->isFunctionType())
- LhsT = Self.Context.getRValueReferenceType(LhsT);
+ LhsT = Context.getRValueReferenceType(LhsT);
// Build a fake source and destination for initialization.
InitializedEntity To(InitializedEntity::InitializeTemporary(RhsT));
- OpaqueValueExpr From(KeyLoc, LhsT.getNonLValueExprType(Self.Context),
+ OpaqueValueExpr From(KeyLoc, LhsT.getNonLValueExprType(Context),
Expr::getValueKindForType(LhsT));
Expr *FromPtr = &From;
InitializationKind Kind(InitializationKind::CreateCopy(KeyLoc,
@@ -5927,21 +5925,21 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
// Perform the initialization in an unevaluated context within a SFINAE
// trap at translation unit scope.
EnterExpressionEvaluationContext Unevaluated(
- Self, Sema::ExpressionEvaluationContext::Unevaluated);
- Sema::SFINAETrap SFINAE(Self, /*AccessCheckingSFINAE=*/true);
- Sema::ContextRAII TUContext(Self, Self.Context.getTranslationUnitDecl());
- InitializationSequence Init(Self, To, Kind, FromPtr);
+ *this, Sema::ExpressionEvaluationContext::Unevaluated);
+ Sema::SFINAETrap SFINAE(*this, /*AccessCheckingSFINAE=*/true);
+ Sema::ContextRAII TUContext(*this, Context.getTranslationUnitDecl());
+ InitializationSequence Init(*this, To, Kind, FromPtr);
if (Init.Failed())
return false;
- ExprResult Result = Init.Perform(Self, To, Kind, FromPtr);
+ ExprResult Result = Init.Perform(*this, To, Kind, FromPtr);
if (Result.isInvalid() || SFINAE.hasErrorOccurred())
return false;
if (BTT != BTT_IsNothrowConvertible)
return true;
- return Self.canThrow(Result.get()) == CT_Cannot;
+ return canThrow(Result.get()) == CT_Cannot;
}
case BTT_IsAssignable:
@@ -5959,12 +5957,12 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
// For both, T and U shall be complete types, (possibly cv-qualified)
// void, or arrays of unknown bound.
if (!LhsT->isVoidType() && !LhsT->isIncompleteArrayType() &&
- Self.RequireCompleteType(KeyLoc, LhsT,
- diag::err_incomplete_type_used_in_type_trait_expr))
+ RequireCompleteType(KeyLoc, LhsT,
+ diag::err_incomplete_type_used_in_type_trait_expr))
return false;
if (!RhsT->isVoidType() && !RhsT->isIncompleteArrayType() &&
- Self.RequireCompleteType(KeyLoc, RhsT,
- diag::err_incomplete_type_used_in_type_trait_expr))
+ RequireCompleteType(KeyLoc, RhsT,
+ diag::err_incomplete_type_used_in_type_trait_expr))
return false;
// cv void is never assignable.
@@ -5974,27 +5972,27 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
// Build expressions that emulate the effect of declval<T>() and
// declval<U>().
if (LhsT->isObjectType() || LhsT->isFunctionType())
- LhsT = Self.Context.getRValueReferenceType(LhsT);
+ LhsT = Context.getRValueReferenceType(LhsT);
if (RhsT->isObjectType() || RhsT->isFunctionType())
- RhsT = Self.Context.getRValueReferenceType(RhsT);
- OpaqueValueExpr Lhs(KeyLoc, LhsT.getNonLValueExprType(Self.Context),
+ RhsT = Context.getRValueReferenceType(RhsT);
+ OpaqueValueExpr Lhs(KeyLoc, LhsT.getNonLValueExprType(Context),
Expr::getValueKindForType(LhsT));
- OpaqueValueExpr Rhs(KeyLoc, RhsT.getNonLValueExprType(Self.Context),
+ OpaqueValueExpr Rhs(KeyLoc, RhsT.getNonLValueExprType(Context),
Expr::getValueKindForType(RhsT));
// Attempt the assignment in an unevaluated context within a SFINAE
// trap at translation unit scope.
EnterExpressionEvaluationContext Unevaluated(
- Self, Sema::ExpressionEvaluationContext::Unevaluated);
- Sema::SFINAETrap SFINAE(Self, /*AccessCheckingSFINAE=*/true);
- Sema::ContextRAII TUContext(Self, Self.Context.getTranslationUnitDecl());
- ExprResult Result = Self.BuildBinOp(/*S=*/nullptr, KeyLoc, BO_Assign, &Lhs,
- &Rhs);
+ *this, Sema::ExpressionEvaluationContext::Unevaluated);
+ Sema::SFINAETrap SFINAE(*this, /*AccessCheckingSFINAE=*/true);
+ Sema::ContextRAII TUContext(*this, Context.getTranslationUnitDecl());
+ ExprResult Result =
+ BuildBinOp(/*S=*/nullptr, KeyLoc, BO_Assign, &Lhs, &Rhs);
if (Result.isInvalid())
return false;
// Treat the assignment as unused for the purpose of -Wdeprecated-volatile.
- Self.CheckUnusedVolatileAssignment(Result.get());
+ CheckUnusedVolatileAssignment(Result.get());
if (SFINAE.hasErrorOccurred())
return false;
@@ -6003,7 +6001,7 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
return true;
if (BTT == BTT_IsNothrowAssignable)
- return Self.canThrow(Result.get()) == CT_Cannot;
+ return canThrow(Result.get()) == CT_Cannot;
if (BTT == BTT_IsTriviallyAssignable) {
// Under Objective-C ARC and Weak, if the destination has non-trivial
@@ -6011,14 +6009,14 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT, QualType LhsT,
if (LhsT.getNonReferenceType().hasNonTrivialObjCLifetime())
return false;
- return !Result.get()->hasNonTrivialCall(Self.Context);
+ return !Result.get()->hasNonTrivialCall(Context);
}
llvm_unreachable("unhandled type trait");
return false;
}
case BTT_IsLayoutCompatible: {
- return Self.IsLayoutCompatible(LhsT, RhsT);
+ return IsLayoutCompatible(LhsT, RhsT);
}
default: llvm_unreachable("not a BTT");
}
diff --git a/clang/test/SemaCXX/coroutines.cpp b/clang/test/SemaCXX/coroutines.cpp
index 2292932583fff6..cdd9be4c201d3f 100644
--- a/clang/test/SemaCXX/coroutines.cpp
+++ b/clang/test/SemaCXX/coroutines.cpp
@@ -1005,12 +1005,26 @@ coro<promise_no_return_func> no_return_value_or_return_void_3() {
co_return 43; // expected-error {{no member named 'return_value'}}
}
-struct bad_await_suspend_return {
+struct non_trivial_destruction_type {
+ ~non_trivial_destruction_type();
+};
+
+struct bad_await_suspend_return_1 {
bool await_ready();
- // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'char')}}
+ // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'char')}}
char await_suspend(std::coroutine_handle<>);
void await_resume();
};
+
+// Same as above, but the invalid return type has a non-trivial destructor,
+// so the await_suspend call is wrapped in a CXXBindTemporaryExpr.
+struct bad_await_suspend_return_2 {
+ bool await_ready();
+ // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'non_trivial_destruction_type')}}
+ non_trivial_destruction_type await_suspend(std::coroutine_handle<>);
+ void await_resume();
+};
+
struct bad_await_ready_return {
// expected-note@+1 {{return type of 'await_ready' is required to be contextually convertible to 'bool'}}
void await_ready();
@@ -1028,8 +1042,8 @@ struct await_ready_explicit_bool {
template <class SuspendTy>
struct await_suspend_type_test {
bool await_ready();
- // expected-error@+2 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &')}}
- // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' (have 'bool &&')}}
+ // expected-error@+2 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'bool &')}}
+ // expected-error@+1 {{return type of 'await_suspend' is required to be 'void' or 'bool' or convertible to 'std::coroutine_handle<>' (have 'bool &&')}}
SuspendTy await_suspend(std::coroutine_handle<>);
// cxx20_23-warning@-1 {{volatile-qualified return type 'const volatile bool' is deprecated}}
void await_resume();
@@ -1042,8 +1056,12 @@ void test_bad_suspend() {
co_await a; // expected-note {{call to 'await_ready' implicitly required by coroutine function here}}
}
{
- bad_await_suspend_return b;
- co_await b; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
+ bad_await_suspend_return_1 b1;
+ co_await b1; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
+ }
+ {
+ bad_await_suspend_return_2 b2;
+ co_await b2; // expected-note {{call to 'await_suspend' implicitly required by coroutine function here}}
}
{
await_ready_explicit_bool c;
More information about the cfe-commits
mailing list