[clang] [coroutines][coro_lifetimebound] Detect lifetime issues with lambda captures (PR #77066)

Utkarsh Saxena via cfe-commits cfe-commits at lists.llvm.org
Thu Jan 18 02:54:57 PST 2024


https://github.com/usx95 updated https://github.com/llvm/llvm-project/pull/77066

>From 3e0d0ab6c4fc6cba68285816a95e423bc18e8e55 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Fri, 5 Jan 2024 10:11:20 +0100
Subject: [PATCH 01/16] [coroutines] Detect lifetime issues with coroutine
 lambda captures

---
 clang/lib/Sema/SemaInit.cpp               | 20 +++++--
 clang/test/SemaCXX/coro-lifetimebound.cpp | 64 +++++++++++++++++++++--
 2 files changed, 76 insertions(+), 8 deletions(-)

diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 60c0e3e74204ec..c100bf11454786 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -12,6 +12,7 @@
 
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/DeclObjC.h"
+#include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/ExprOpenMP.h"
@@ -33,6 +34,7 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -7575,15 +7577,27 @@ static void visitLifetimeBoundArguments(IndirectLocalPath &Path, Expr *Call,
     Path.pop_back();
   };
 
-  if (ObjectArg && implicitObjectParamIsLifetimeBound(Callee))
-    VisitLifetimeBoundArg(Callee, ObjectArg);
-
   bool CheckCoroCall = false;
   if (const auto *RD = Callee->getReturnType()->getAsRecordDecl()) {
     CheckCoroCall = RD->hasAttr<CoroLifetimeBoundAttr>() &&
                     RD->hasAttr<CoroReturnTypeAttr>() &&
                     !Callee->hasAttr<CoroDisableLifetimeBoundAttr>();
   }
+
+  if (ObjectArg) {
+    bool CheckCoroObjArg = CheckCoroCall;
+    // Ignore `__promise.get_return_object()` as it not lifetimebound.
+    if (Callee->getDeclName().isIdentifier() &&
+        Callee->getName() == "get_return_object")
+      CheckCoroObjArg = false;
+    // Coroutine lambda objects with empty capture list are not lifetimebound.
+    if (auto *LE = dyn_cast<LambdaExpr>(ObjectArg->IgnoreImplicit());
+        LE && LE->captures().empty())
+      CheckCoroObjArg = false;
+    if (implicitObjectParamIsLifetimeBound(Callee) || CheckCoroObjArg)
+      VisitLifetimeBoundArg(Callee, ObjectArg);
+  }
+
   for (unsigned I = 0,
                 N = std::min<unsigned>(Callee->getNumParams(), Args.size());
        I != N; ++I) {
diff --git a/clang/test/SemaCXX/coro-lifetimebound.cpp b/clang/test/SemaCXX/coro-lifetimebound.cpp
index 3fc7ca70a14a12..319134450e4b6f 100644
--- a/clang/test/SemaCXX/coro-lifetimebound.cpp
+++ b/clang/test/SemaCXX/coro-lifetimebound.cpp
@@ -64,6 +64,10 @@ Co<int> bar_coro(const int &b, int c) {
       : bar_coro(0, 1); // expected-warning {{returning address of local temporary object}}
 }
 
+// =============================================================================
+// Lambdas
+// =============================================================================
+namespace lambdas {
 void lambdas() {
   auto unsafe_lambda = [] [[clang::coro_wrapper]] (int b) {
     return foo_coro(b); // expected-warning {{address of stack memory associated with parameter}}
@@ -84,15 +88,47 @@ void lambdas() {
     co_return x + co_await foo_coro(b);
   };
 }
+
+Co<int> lambda_captures() {
+  int a = 1;
+  // Temporary lambda object dies.
+  auto lamb = [a](int x, const int& y) -> Co<int> { // expected-warning {{temporary whose address is used as value of local variable 'lamb'}}
+    co_return x + y + a;
+  }(1, a);
+  // Object dies but it has no capture.
+  auto no_capture = []() -> Co<int> { co_return 1; }();
+  auto bad_no_capture = [](const int& a) -> Co<int> { co_return a; }(1); // expected-warning {{temporary}}
+  // Temporary lambda object with lifetime extension under co_await.
+  int res = co_await [a](int x, const int& y) -> Co<int> {
+    co_return x + y + a;
+  }(1, a);
+  co_return 1;
+}
+} // namespace lambdas
+
 // =============================================================================
-// Safe usage when parameters are value
+// Member coroutines
 // =============================================================================
-namespace by_value {
-Co<int> value_coro(int b) { co_return co_await foo_coro(b); }
-[[clang::coro_wrapper]] Co<int> wrapper1(int b) { return value_coro(b); }
-[[clang::coro_wrapper]] Co<int> wrapper2(const int& b) { return value_coro(b); }
+namespace member_coroutines{
+struct S {
+  Co<int> member(const int& a) { co_return a; }  
+};
+
+Co<int> use() {
+  S s;
+  int a = 1;
+  auto test1 = s.member(1);  // expected-warning {{temporary whose address is used as value of local variable}}
+  auto test2 = s.member(a);
+  auto test3 = S{}.member(a);  // expected-warning {{temporary whose address is used as value of local variable}}
+  co_return 1;
 }
 
+[[clang::coro_wrapper]] Co<int> wrapper(const int& a) {
+  S s;
+  return s.member(a); // expected-warning {{address of stack memory}}
+}
+} // member_coroutines
+
 // =============================================================================
 // Lifetime bound but not a Coroutine Return Type: No analysis.
 // =============================================================================
@@ -129,4 +165,22 @@ Co<int> foo_wrapper(const int& x) { return foo(x); }
   // The call to foo_wrapper is wrapper is safe.
   return foo_wrapper(1);
 }
+
+struct S{
+[[clang::coro_wrapper, clang::coro_disable_lifetimebound]] 
+Co<int> member(const int& x) { return foo(x); }
+};
+
+Co<int> use() {
+  S s;
+  int a = 1;
+  auto test1 = s.member(1); // param is not flagged.
+  auto test2 = S{}.member(a); // 'this' is not flagged.
+  co_return 1;
+}
+
+[[clang::coro_wrapper]] Co<int> return_stack_addr(const int& a) {
+  S s;
+  return s.member(a); // return of stack addr is not flagged.
+}
 } // namespace disable_lifetimebound

>From 0d8d78e30dff896880e674240c8bf47b13c527f5 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Fri, 5 Jan 2024 11:39:41 +0100
Subject: [PATCH 02/16] add more tests

---
 clang/lib/Sema/SemaInit.cpp               |  4 ++--
 clang/test/SemaCXX/coro-lifetimebound.cpp | 17 +++++++++++++++--
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index c100bf11454786..1aa6be826699bb 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -7586,8 +7586,8 @@ static void visitLifetimeBoundArguments(IndirectLocalPath &Path, Expr *Call,
 
   if (ObjectArg) {
     bool CheckCoroObjArg = CheckCoroCall;
-    // Ignore `__promise.get_return_object()` as it not lifetimebound.
-    if (Callee->getDeclName().isIdentifier() &&
+    // Ignore `__promise.get_return_object()` as it is not lifetimebound.
+    if (CheckCoroObjArg && Callee->getDeclName().isIdentifier() &&
         Callee->getName() == "get_return_object")
       CheckCoroObjArg = false;
     // Coroutine lambda objects with empty capture list are not lifetimebound.
diff --git a/clang/test/SemaCXX/coro-lifetimebound.cpp b/clang/test/SemaCXX/coro-lifetimebound.cpp
index 319134450e4b6f..b61013ec057381 100644
--- a/clang/test/SemaCXX/coro-lifetimebound.cpp
+++ b/clang/test/SemaCXX/coro-lifetimebound.cpp
@@ -102,6 +102,10 @@ Co<int> lambda_captures() {
   int res = co_await [a](int x, const int& y) -> Co<int> {
     co_return x + y + a;
   }(1, a);
+  // Lambda object on stack should be fine.
+  auto lamb2 = [a]() -> Co<int> { co_return a; };
+  auto on_stack = lamb2();
+  auto res2 = co_await on_stack;
   co_return 1;
 }
 } // namespace lambdas
@@ -111,7 +115,7 @@ Co<int> lambda_captures() {
 // =============================================================================
 namespace member_coroutines{
 struct S {
-  Co<int> member(const int& a) { co_return a; }  
+  Co<int> member(const int& a) { co_return a; }
 };
 
 Co<int> use() {
@@ -129,6 +133,15 @@ Co<int> use() {
 }
 } // member_coroutines
 
+// =============================================================================
+// Safe usage when parameters are value
+// =============================================================================
+namespace by_value {
+Co<int> value_coro(int b) { co_return co_await foo_coro(b); }
+[[clang::coro_wrapper]] Co<int> wrapper1(int b) { return value_coro(b); }
+[[clang::coro_wrapper]] Co<int> wrapper2(const int& b) { return value_coro(b); }
+} // namespace by_value
+
 // =============================================================================
 // Lifetime bound but not a Coroutine Return Type: No analysis.
 // =============================================================================
@@ -158,7 +171,7 @@ CoNoCRT<int> bar(int a) {
 namespace disable_lifetimebound {
 Co<int> foo(int x) {  co_return x; }
 
-[[clang::coro_wrapper, clang::coro_disable_lifetimebound]] 
+[[clang::coro_wrapper, clang::coro_disable_lifetimebound]]
 Co<int> foo_wrapper(const int& x) { return foo(x); }
 
 [[clang::coro_wrapper]] Co<int> caller() {

>From e4b801784962946eedf8b5e7c229871588b1621b Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Wed, 10 Jan 2024 13:06:16 +0000
Subject: [PATCH 03/16] Refactor isGetReturnObject()

---
 clang/include/clang/Sema/Sema.h | 6 ++++++
 clang/lib/Sema/SemaDecl.cpp     | 3 +--
 clang/lib/Sema/SemaInit.cpp     | 3 +--
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 8f44adef38159e..a17a9444d201b6 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11219,6 +11219,12 @@ class Sema final {
   bool buildCoroutineParameterMoves(SourceLocation Loc);
   VarDecl *buildCoroutinePromise(SourceLocation Loc);
   void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
+  // Heuristaclly tells if the function is promise::get_return_object().
+  static bool isGetReturnObject(const FunctionDecl *FD) {
+    return isa_and_nonnull<CXXMethodDecl>(FD) &&
+           FD->getDeclName().isIdentifier() &&
+           FD->getName().equals("get_return_object") && FD->param_empty();
+  }
 
   // As a clang extension, enforces that a non-coroutine function must be marked
   // with [[clang::coro_wrapper]] if it returns a type marked with
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8e46c4984d93dc..2f93e803d11c69 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15846,8 +15846,7 @@ void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
   if (!RD || !RD->getUnderlyingDecl()->hasAttr<CoroReturnTypeAttr>())
     return;
   // Allow `get_return_object()`.
-  if (FD->getDeclName().isIdentifier() &&
-      FD->getName().equals("get_return_object") && FD->param_empty())
+  if (isGetReturnObject(FD))
     return;
   if (!FD->hasAttr<CoroWrapperAttr>())
     Diag(FD->getLocation(), diag::err_coroutine_return_type) << RD;
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 1aa6be826699bb..848184ff4b724e 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -7587,8 +7587,7 @@ static void visitLifetimeBoundArguments(IndirectLocalPath &Path, Expr *Call,
   if (ObjectArg) {
     bool CheckCoroObjArg = CheckCoroCall;
     // Ignore `__promise.get_return_object()` as it is not lifetimebound.
-    if (CheckCoroObjArg && Callee->getDeclName().isIdentifier() &&
-        Callee->getName() == "get_return_object")
+    if (CheckCoroObjArg && Sema::isGetReturnObject(Callee))
       CheckCoroObjArg = false;
     // Coroutine lambda objects with empty capture list are not lifetimebound.
     if (auto *LE = dyn_cast<LambdaExpr>(ObjectArg->IgnoreImplicit());

>From 4c05d54e07d4f650bf5694a3ddc365078b854627 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Thu, 11 Jan 2024 15:03:41 +0000
Subject: [PATCH 04/16] implicitly add coro_wrapper and
 coro_disable_lifetimebound to get_return_object

---
 clang/include/clang/Sema/Sema.h           |  6 ------
 clang/lib/Sema/SemaCoroutine.cpp          | 22 ++++++++++++++++++++++
 clang/lib/Sema/SemaDecl.cpp               |  7 +++++--
 clang/lib/Sema/SemaInit.cpp               |  3 ---
 clang/test/SemaCXX/coro-lifetimebound.cpp |  2 +-
 5 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index a17a9444d201b6..8f44adef38159e 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11219,12 +11219,6 @@ class Sema final {
   bool buildCoroutineParameterMoves(SourceLocation Loc);
   VarDecl *buildCoroutinePromise(SourceLocation Loc);
   void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
-  // Heuristaclly tells if the function is promise::get_return_object().
-  static bool isGetReturnObject(const FunctionDecl *FD) {
-    return isa_and_nonnull<CXXMethodDecl>(FD) &&
-           FD->getDeclName().isIdentifier() &&
-           FD->getName().equals("get_return_object") && FD->param_empty();
-  }
 
   // As a clang extension, enforces that a non-coroutine function must be marked
   // with [[clang::coro_wrapper]] if it returns a type marked with
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index bee80db8d166a6..0b03ca4ea1fa9c 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -297,6 +297,26 @@ struct ReadySuspendResumeResult {
   bool IsInvalid;
 };
 
+// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
+// attributes to `get_return_object`.
+static void handleGetReturnObject(Sema &S, MemberExpr *ME) {
+  if (!ME || !ME->getMemberDecl() || !ME->getMemberDecl()->getAsFunction())
+    return;
+  auto *MD = ME->getMemberDecl()->getAsFunction();
+  auto* RetType = MD->getReturnType()->getAsRecordDecl();
+  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
+    return;
+  // `get_return_object` should be allowed to return coro_return_type.
+  if (!MD->hasAttr<CoroWrapperAttr>())
+    MD->addAttr(
+        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
+  // Object arg of `__promise.get_return_object()` is not lifetimebound.
+  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
+      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
+    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
+        S.getASTContext(), MD->getLocation()));
+}
+
 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
                                   StringRef Name, MultiExprArg Args) {
   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
@@ -319,6 +339,8 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
     return ExprError();
   }
 
+  if (Name.equals("get_return_object"))
+    handleGetReturnObject(S, dyn_cast<MemberExpr>(Result.get()));
   auto EndLoc = Args.empty() ? Loc : Args.back()->getEndLoc();
   return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, EndLoc, nullptr);
 }
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 2f93e803d11c69..efd3232261123b 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15845,8 +15845,11 @@ void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
   RecordDecl *RD = FD->getReturnType()->getAsRecordDecl();
   if (!RD || !RD->getUnderlyingDecl()->hasAttr<CoroReturnTypeAttr>())
     return;
-  // Allow `get_return_object()`.
-  if (isGetReturnObject(FD))
+  // Allow some_promise_type::get_return_object().
+  // Since we are still in the promise definition, we can only do this
+  // heuristically as the promise may not be yet associated to a coroutine.
+  if (isa<CXXMethodDecl>(FD) && FD->getDeclName().isIdentifier() &&
+      FD->getName().equals("get_return_object") && FD->param_empty())
     return;
   if (!FD->hasAttr<CoroWrapperAttr>())
     Diag(FD->getLocation(), diag::err_coroutine_return_type) << RD;
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 848184ff4b724e..05e71a0b4960ab 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -7586,9 +7586,6 @@ static void visitLifetimeBoundArguments(IndirectLocalPath &Path, Expr *Call,
 
   if (ObjectArg) {
     bool CheckCoroObjArg = CheckCoroCall;
-    // Ignore `__promise.get_return_object()` as it is not lifetimebound.
-    if (CheckCoroObjArg && Sema::isGetReturnObject(Callee))
-      CheckCoroObjArg = false;
     // Coroutine lambda objects with empty capture list are not lifetimebound.
     if (auto *LE = dyn_cast<LambdaExpr>(ObjectArg->IgnoreImplicit());
         LE && LE->captures().empty())
diff --git a/clang/test/SemaCXX/coro-lifetimebound.cpp b/clang/test/SemaCXX/coro-lifetimebound.cpp
index b61013ec057381..9e96a296562a05 100644
--- a/clang/test/SemaCXX/coro-lifetimebound.cpp
+++ b/clang/test/SemaCXX/coro-lifetimebound.cpp
@@ -180,7 +180,7 @@ Co<int> foo_wrapper(const int& x) { return foo(x); }
 }
 
 struct S{
-[[clang::coro_wrapper, clang::coro_disable_lifetimebound]] 
+[[clang::coro_wrapper, clang::coro_disable_lifetimebound]]
 Co<int> member(const int& x) { return foo(x); }
 };
 

>From 9858f46932f8cf8775c11496b2fe069410c513f4 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Thu, 11 Jan 2024 15:16:46 +0000
Subject: [PATCH 05/16] format

---
 clang/lib/Sema/SemaCoroutine.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 0b03ca4ea1fa9c..ef08e1ee750d17 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -303,7 +303,7 @@ static void handleGetReturnObject(Sema &S, MemberExpr *ME) {
   if (!ME || !ME->getMemberDecl() || !ME->getMemberDecl()->getAsFunction())
     return;
   auto *MD = ME->getMemberDecl()->getAsFunction();
-  auto* RetType = MD->getReturnType()->getAsRecordDecl();
+  auto *RetType = MD->getReturnType()->getAsRecordDecl();
   if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
     return;
   // `get_return_object` should be allowed to return coro_return_type.

>From 070fffdd072d30a9d64e7b9bf0f46fcf16cb977a Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Fri, 12 Jan 2024 15:02:09 +0000
Subject: [PATCH 06/16] moved the check to makeReturnObject

---
 clang/lib/Sema/SemaCoroutine.cpp | 50 ++++++++++++++++++--------------
 1 file changed, 28 insertions(+), 22 deletions(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index ef08e1ee750d17..da4aa658d322ac 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -16,6 +16,7 @@
 #include "CoroutineStmtBuilder.h"
 #include "clang/AST/ASTLambda.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/Basic/Builtins.h"
@@ -297,26 +298,6 @@ struct ReadySuspendResumeResult {
   bool IsInvalid;
 };
 
-// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
-// attributes to `get_return_object`.
-static void handleGetReturnObject(Sema &S, MemberExpr *ME) {
-  if (!ME || !ME->getMemberDecl() || !ME->getMemberDecl()->getAsFunction())
-    return;
-  auto *MD = ME->getMemberDecl()->getAsFunction();
-  auto *RetType = MD->getReturnType()->getAsRecordDecl();
-  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
-    return;
-  // `get_return_object` should be allowed to return coro_return_type.
-  if (!MD->hasAttr<CoroWrapperAttr>())
-    MD->addAttr(
-        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
-  // Object arg of `__promise.get_return_object()` is not lifetimebound.
-  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
-      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
-    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
-        S.getASTContext(), MD->getLocation()));
-}
-
 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
                                   StringRef Name, MultiExprArg Args) {
   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
@@ -339,8 +320,6 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
     return ExprError();
   }
 
-  if (Name.equals("get_return_object"))
-    handleGetReturnObject(S, dyn_cast<MemberExpr>(Result.get()));
   auto EndLoc = Args.empty() ? Loc : Args.back()->getEndLoc();
   return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, EndLoc, nullptr);
 }
@@ -1818,6 +1797,32 @@ bool CoroutineStmtBuilder::makeOnException() {
   return true;
 }
 
+// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
+// attributes to the function `get_return_object`.
+static void handleGetReturnObject(Sema &S, Expr *E) {
+  if(auto* TE = dyn_cast<CXXBindTemporaryExpr>(E))
+    E = TE->getSubExpr();
+  auto* CE = dyn_cast<CallExpr>(E);
+  assert(CE);
+  auto *MD = CE->getDirectCallee();
+  if (!MD)
+    return;
+  // This analysis is done only for types marked with
+  // [[clang::coro_return_type]].
+  auto *RetType = MD->getReturnType()->getAsRecordDecl();
+  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
+    return;
+  // `get_return_object` should be allowed to return coro_return_type.
+  if (!MD->hasAttr<CoroWrapperAttr>())
+    MD->addAttr(
+        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
+  // Object arg of `__promise.get_return_object()` is not lifetimebound.
+  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
+      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
+    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
+        S.getASTContext(), MD->getLocation()));
+}
+
 bool CoroutineStmtBuilder::makeReturnObject() {
   // [dcl.fct.def.coroutine]p7
   // The expression promise.get_return_object() is used to initialize the
@@ -1827,6 +1832,7 @@ bool CoroutineStmtBuilder::makeReturnObject() {
   if (ReturnObject.isInvalid())
     return false;
 
+  handleGetReturnObject(S, ReturnObject.get());
   this->ReturnValue = ReturnObject.get();
   return true;
 }

>From 4052cfa00494847074fce992aad3c727e91a84d2 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Fri, 12 Jan 2024 15:55:02 +0000
Subject: [PATCH 07/16] format

---
 clang/lib/Sema/SemaCoroutine.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index da4aa658d322ac..fba8f38b42613d 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1800,9 +1800,9 @@ bool CoroutineStmtBuilder::makeOnException() {
 // Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
 // attributes to the function `get_return_object`.
 static void handleGetReturnObject(Sema &S, Expr *E) {
-  if(auto* TE = dyn_cast<CXXBindTemporaryExpr>(E))
+  if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
     E = TE->getSubExpr();
-  auto* CE = dyn_cast<CallExpr>(E);
+  auto *CE = dyn_cast<CallExpr>(E);
   assert(CE);
   auto *MD = CE->getDirectCallee();
   if (!MD)

>From ac43a7535b7f369f7a529f350ab61a4d3ae825dd Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Mon, 15 Jan 2024 15:00:58 +0100
Subject: [PATCH 08/16] Use cast instead of dyn_cast and assert

Co-authored-by: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
---
 clang/lib/Sema/SemaCoroutine.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index fba8f38b42613d..466c17e11f3977 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1802,8 +1802,7 @@ bool CoroutineStmtBuilder::makeOnException() {
 static void handleGetReturnObject(Sema &S, Expr *E) {
   if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
     E = TE->getSubExpr();
-  auto *CE = dyn_cast<CallExpr>(E);
-  assert(CE);
+  auto *CE = cast<CallExpr>(E);
   auto *MD = CE->getDirectCallee();
   if (!MD)
     return;

>From d6ed9900f9ade2162bd0247cf3f765b18d79eb83 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Mon, 15 Jan 2024 15:06:28 +0100
Subject: [PATCH 09/16] Update clang/lib/Sema/SemaCoroutine.cpp

Co-authored-by: Chuanqi Xu <yedeng.yd at linux.alibaba.com>
---
 clang/lib/Sema/SemaCoroutine.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 466c17e11f3977..98c7b54c38350c 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1798,7 +1798,7 @@ bool CoroutineStmtBuilder::makeOnException() {
 }
 
 // Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
-// attributes to the function `get_return_object`.
+// attributes to the function `get_return_object` if its return type is marked with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for `get_return_object`.
 static void handleGetReturnObject(Sema &S, Expr *E) {
   if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
     E = TE->getSubExpr();

>From 8b74a3f8359892745e81abcaef361b33f5dc17e6 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Mon, 15 Jan 2024 14:07:04 +0000
Subject: [PATCH 10/16] remove unused include

---
 clang/lib/Sema/SemaInit.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 05e71a0b4960ab..d98108c1b0abc4 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -34,7 +34,6 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 

>From 832db83f7e3854b478bd00404e95ce4cf79a9dd8 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Mon, 15 Jan 2024 14:40:09 +0000
Subject: [PATCH 11/16] format

---
 clang/lib/Sema/SemaCoroutine.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 98c7b54c38350c..e64fb2cd8ef576 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1798,7 +1798,9 @@ bool CoroutineStmtBuilder::makeOnException() {
 }
 
 // Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
-// attributes to the function `get_return_object` if its return type is marked with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for `get_return_object`.
+// attributes to the function `get_return_object` if its return type is marked
+// with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for
+// `get_return_object`.
 static void handleGetReturnObject(Sema &S, Expr *E) {
   if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
     E = TE->getSubExpr();

>From 5080f85ed9adf8141e06d04690b9c24d779df18e Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Tue, 16 Jan 2024 15:32:20 +0000
Subject: [PATCH 12/16] handle get_return_object_on_allocation_failure

---
 clang/lib/Sema/SemaCoroutine.cpp              | 55 ++++++++++---------
 clang/lib/Sema/SemaDecl.cpp                   | 15 ++++-
 .../SemaCXX/coro-return-type-and-wrapper.cpp  | 12 ++++
 3 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index e64fb2cd8ef576..8923ce2bc7138c 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1321,6 +1321,33 @@ static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
   return false;
 }
 
+// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
+// attributes to the function `get_return_object` if its return type is marked
+// with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for
+// `get_return_object`.
+static void handleGetReturnObject(Sema &S, Expr *E) {
+  if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
+    E = TE->getSubExpr();
+  auto *CE = cast<CallExpr>(E);
+  auto *MD = CE->getDirectCallee();
+  if (!MD)
+    return;
+  // This analysis is done only for types marked with
+  // [[clang::coro_return_type]].
+  auto *RetType = MD->getReturnType()->getAsRecordDecl();
+  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
+    return;
+  // `get_return_object` should be allowed to return coro_return_type.
+  if (!MD->hasAttr<CoroWrapperAttr>())
+    MD->addAttr(
+        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
+  // Object arg of `__promise.get_return_object()` is not lifetimebound.
+  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
+      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
+    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
+        S.getASTContext(), MD->getLocation()));
+}
+
 bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
   assert(!IsPromiseDependentType &&
          "cannot make statement while the promise type is dependent");
@@ -1354,6 +1381,7 @@ bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
   if (ReturnObjectOnAllocationFailure.isInvalid())
     return false;
 
+  handleGetReturnObject(S, ReturnObjectOnAllocationFailure.get());
   StmtResult ReturnStmt =
       S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
   if (ReturnStmt.isInvalid()) {
@@ -1797,33 +1825,6 @@ bool CoroutineStmtBuilder::makeOnException() {
   return true;
 }
 
-// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
-// attributes to the function `get_return_object` if its return type is marked
-// with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for
-// `get_return_object`.
-static void handleGetReturnObject(Sema &S, Expr *E) {
-  if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
-    E = TE->getSubExpr();
-  auto *CE = cast<CallExpr>(E);
-  auto *MD = CE->getDirectCallee();
-  if (!MD)
-    return;
-  // This analysis is done only for types marked with
-  // [[clang::coro_return_type]].
-  auto *RetType = MD->getReturnType()->getAsRecordDecl();
-  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
-    return;
-  // `get_return_object` should be allowed to return coro_return_type.
-  if (!MD->hasAttr<CoroWrapperAttr>())
-    MD->addAttr(
-        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
-  // Object arg of `__promise.get_return_object()` is not lifetimebound.
-  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
-      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
-    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
-        S.getASTContext(), MD->getLocation()));
-}
-
 bool CoroutineStmtBuilder::makeReturnObject() {
   // [dcl.fct.def.coroutine]p7
   // The expression promise.get_return_object() is used to initialize the
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index efd3232261123b..db3e258e57b4e1 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15841,6 +15841,18 @@ static void diagnoseImplicitlyRetainedSelf(Sema &S) {
           << FixItHint::CreateInsertion(P.first, "self->");
 }
 
+static bool IsGetReturnType(FunctionDecl* FD) {
+  return isa<CXXMethodDecl>(FD) && FD->param_empty() &&
+         FD->getDeclName().isIdentifier() &&
+         FD->getName().equals("get_return_object");
+}
+
+static bool IsGetReturnTypeOnAllocFailure(FunctionDecl* FD) {
+  return FD->isStatic() && FD->param_empty() &&
+         FD->getDeclName().isIdentifier() &&
+         FD->getName().equals("get_return_object_on_allocation_failure");
+}
+
 void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
   RecordDecl *RD = FD->getReturnType()->getAsRecordDecl();
   if (!RD || !RD->getUnderlyingDecl()->hasAttr<CoroReturnTypeAttr>())
@@ -15848,8 +15860,7 @@ void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
   // Allow some_promise_type::get_return_object().
   // Since we are still in the promise definition, we can only do this
   // heuristically as the promise may not be yet associated to a coroutine.
-  if (isa<CXXMethodDecl>(FD) && FD->getDeclName().isIdentifier() &&
-      FD->getName().equals("get_return_object") && FD->param_empty())
+  if (IsGetReturnType(FD) || IsGetReturnTypeOnAllocFailure(FD))
     return;
   if (!FD->hasAttr<CoroWrapperAttr>())
     Diag(FD->getLocation(), diag::err_coroutine_return_type) << RD;
diff --git a/clang/test/SemaCXX/coro-return-type-and-wrapper.cpp b/clang/test/SemaCXX/coro-return-type-and-wrapper.cpp
index ac49e03ba9d90a..b08e1c9c065a0e 100644
--- a/clang/test/SemaCXX/coro-return-type-and-wrapper.cpp
+++ b/clang/test/SemaCXX/coro-return-type-and-wrapper.cpp
@@ -5,11 +5,23 @@ using std::suspend_always;
 using std::suspend_never;
 
 
+namespace std {
+  struct nothrow_t {};
+  constexpr nothrow_t nothrow = {};
+}
+
+using SizeT = decltype(sizeof(int));
+
+void* operator new(SizeT __sz, const std::nothrow_t&) noexcept;
+
 template <typename T> struct [[clang::coro_return_type]] Gen {
   struct promise_type {
     Gen<T> get_return_object() {
       return {};
     }
+    static Gen<T> get_return_object_on_allocation_failure() {
+      return {};
+    }
     suspend_always initial_suspend();
     suspend_always final_suspend() noexcept;
     void unhandled_exception();

>From 0889e48e05573b17bf192275e11ddf1e915c7fe2 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Tue, 16 Jan 2024 15:36:04 +0000
Subject: [PATCH 13/16] no need to add attrs to
 get_return_object_on_allocation_failure as it is static

---
 clang/lib/Sema/SemaCoroutine.cpp | 55 ++++++++++++++++----------------
 1 file changed, 27 insertions(+), 28 deletions(-)

diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 8923ce2bc7138c..e64fb2cd8ef576 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1321,33 +1321,6 @@ static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
   return false;
 }
 
-// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
-// attributes to the function `get_return_object` if its return type is marked
-// with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for
-// `get_return_object`.
-static void handleGetReturnObject(Sema &S, Expr *E) {
-  if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
-    E = TE->getSubExpr();
-  auto *CE = cast<CallExpr>(E);
-  auto *MD = CE->getDirectCallee();
-  if (!MD)
-    return;
-  // This analysis is done only for types marked with
-  // [[clang::coro_return_type]].
-  auto *RetType = MD->getReturnType()->getAsRecordDecl();
-  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
-    return;
-  // `get_return_object` should be allowed to return coro_return_type.
-  if (!MD->hasAttr<CoroWrapperAttr>())
-    MD->addAttr(
-        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
-  // Object arg of `__promise.get_return_object()` is not lifetimebound.
-  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
-      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
-    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
-        S.getASTContext(), MD->getLocation()));
-}
-
 bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
   assert(!IsPromiseDependentType &&
          "cannot make statement while the promise type is dependent");
@@ -1381,7 +1354,6 @@ bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
   if (ReturnObjectOnAllocationFailure.isInvalid())
     return false;
 
-  handleGetReturnObject(S, ReturnObjectOnAllocationFailure.get());
   StmtResult ReturnStmt =
       S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
   if (ReturnStmt.isInvalid()) {
@@ -1825,6 +1797,33 @@ bool CoroutineStmtBuilder::makeOnException() {
   return true;
 }
 
+// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
+// attributes to the function `get_return_object` if its return type is marked
+// with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for
+// `get_return_object`.
+static void handleGetReturnObject(Sema &S, Expr *E) {
+  if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
+    E = TE->getSubExpr();
+  auto *CE = cast<CallExpr>(E);
+  auto *MD = CE->getDirectCallee();
+  if (!MD)
+    return;
+  // This analysis is done only for types marked with
+  // [[clang::coro_return_type]].
+  auto *RetType = MD->getReturnType()->getAsRecordDecl();
+  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
+    return;
+  // `get_return_object` should be allowed to return coro_return_type.
+  if (!MD->hasAttr<CoroWrapperAttr>())
+    MD->addAttr(
+        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
+  // Object arg of `__promise.get_return_object()` is not lifetimebound.
+  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
+      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
+    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
+        S.getASTContext(), MD->getLocation()));
+}
+
 bool CoroutineStmtBuilder::makeReturnObject() {
   // [dcl.fct.def.coroutine]p7
   // The expression promise.get_return_object() is used to initialize the

>From 45ec564aae1c02fdf966406de2f9b295df951f5b Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Tue, 16 Jan 2024 15:37:21 +0000
Subject: [PATCH 14/16] format

---
 clang/lib/Sema/SemaDecl.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index db3e258e57b4e1..49a6ad75e59ef9 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15841,13 +15841,13 @@ static void diagnoseImplicitlyRetainedSelf(Sema &S) {
           << FixItHint::CreateInsertion(P.first, "self->");
 }
 
-static bool IsGetReturnType(FunctionDecl* FD) {
+static bool IsGetReturnType(FunctionDecl *FD) {
   return isa<CXXMethodDecl>(FD) && FD->param_empty() &&
          FD->getDeclName().isIdentifier() &&
          FD->getName().equals("get_return_object");
 }
 
-static bool IsGetReturnTypeOnAllocFailure(FunctionDecl* FD) {
+static bool IsGetReturnTypeOnAllocFailure(FunctionDecl *FD) {
   return FD->isStatic() && FD->param_empty() &&
          FD->getDeclName().isIdentifier() &&
          FD->getName().equals("get_return_object_on_allocation_failure");

>From c9031aca4fea376fc8e45e98168782f9ed4683ad Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Wed, 17 Jan 2024 11:59:35 +0000
Subject: [PATCH 15/16] revert to name matching heuristics

---
 clang/include/clang/Sema/Sema.h  |  5 +++++
 clang/lib/Sema/SemaCoroutine.cpp | 28 ----------------------------
 clang/lib/Sema/SemaDecl.cpp      |  8 +++-----
 clang/lib/Sema/SemaInit.cpp      |  4 ++++
 4 files changed, 12 insertions(+), 33 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 8f44adef38159e..dae4c3eca6e4aa 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11220,6 +11220,11 @@ class Sema final {
   VarDecl *buildCoroutinePromise(SourceLocation Loc);
   void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
 
+  // Heuristically tells if the function is get_return_object by matching
+  // function name.
+  static bool IsGetReturnObject(const FunctionDecl *FD);
+  static bool IsGetReturnTypeOnAllocFailure(const FunctionDecl *FD);
+
   // As a clang extension, enforces that a non-coroutine function must be marked
   // with [[clang::coro_wrapper]] if it returns a type marked with
   // [[clang::coro_return_type]].
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index e64fb2cd8ef576..0e0f8f67dcd73e 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -1797,33 +1797,6 @@ bool CoroutineStmtBuilder::makeOnException() {
   return true;
 }
 
-// Adds [[clang::coro_wrapper]] and [[clang::coro_disable_lifetimebound]]
-// attributes to the function `get_return_object` if its return type is marked
-// with `[[clang::coro_return_type]]` to avoid false-positive diagnostic for
-// `get_return_object`.
-static void handleGetReturnObject(Sema &S, Expr *E) {
-  if (auto *TE = dyn_cast<CXXBindTemporaryExpr>(E))
-    E = TE->getSubExpr();
-  auto *CE = cast<CallExpr>(E);
-  auto *MD = CE->getDirectCallee();
-  if (!MD)
-    return;
-  // This analysis is done only for types marked with
-  // [[clang::coro_return_type]].
-  auto *RetType = MD->getReturnType()->getAsRecordDecl();
-  if (!RetType || !RetType->hasAttr<CoroReturnTypeAttr>())
-    return;
-  // `get_return_object` should be allowed to return coro_return_type.
-  if (!MD->hasAttr<CoroWrapperAttr>())
-    MD->addAttr(
-        CoroWrapperAttr::CreateImplicit(S.getASTContext(), MD->getLocation()));
-  // Object arg of `__promise.get_return_object()` is not lifetimebound.
-  if (RetType->hasAttr<CoroLifetimeBoundAttr>() &&
-      !MD->hasAttr<CoroDisableLifetimeBoundAttr>())
-    MD->addAttr(CoroDisableLifetimeBoundAttr::CreateImplicit(
-        S.getASTContext(), MD->getLocation()));
-}
-
 bool CoroutineStmtBuilder::makeReturnObject() {
   // [dcl.fct.def.coroutine]p7
   // The expression promise.get_return_object() is used to initialize the
@@ -1833,7 +1806,6 @@ bool CoroutineStmtBuilder::makeReturnObject() {
   if (ReturnObject.isInvalid())
     return false;
 
-  handleGetReturnObject(S, ReturnObject.get());
   this->ReturnValue = ReturnObject.get();
   return true;
 }
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 49a6ad75e59ef9..57254933d51537 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15841,13 +15841,13 @@ static void diagnoseImplicitlyRetainedSelf(Sema &S) {
           << FixItHint::CreateInsertion(P.first, "self->");
 }
 
-static bool IsGetReturnType(FunctionDecl *FD) {
+bool Sema::IsGetReturnObject(const FunctionDecl *FD) {
   return isa<CXXMethodDecl>(FD) && FD->param_empty() &&
          FD->getDeclName().isIdentifier() &&
          FD->getName().equals("get_return_object");
 }
 
-static bool IsGetReturnTypeOnAllocFailure(FunctionDecl *FD) {
+bool Sema::IsGetReturnTypeOnAllocFailure(const FunctionDecl *FD) {
   return FD->isStatic() && FD->param_empty() &&
          FD->getDeclName().isIdentifier() &&
          FD->getName().equals("get_return_object_on_allocation_failure");
@@ -15858,9 +15858,7 @@ void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
   if (!RD || !RD->getUnderlyingDecl()->hasAttr<CoroReturnTypeAttr>())
     return;
   // Allow some_promise_type::get_return_object().
-  // Since we are still in the promise definition, we can only do this
-  // heuristically as the promise may not be yet associated to a coroutine.
-  if (IsGetReturnType(FD) || IsGetReturnTypeOnAllocFailure(FD))
+  if (IsGetReturnObject(FD) || IsGetReturnTypeOnAllocFailure(FD))
     return;
   if (!FD->hasAttr<CoroWrapperAttr>())
     Diag(FD->getLocation(), diag::err_coroutine_return_type) << RD;
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index d98108c1b0abc4..6e05899856c3e1 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -7589,6 +7589,10 @@ static void visitLifetimeBoundArguments(IndirectLocalPath &Path, Expr *Call,
     if (auto *LE = dyn_cast<LambdaExpr>(ObjectArg->IgnoreImplicit());
         LE && LE->captures().empty())
       CheckCoroObjArg = false;
+    // Allow `get_return_object()` as the object param (__promise) is not
+    // lifetimebound.
+    if (Sema::IsGetReturnObject(Callee))
+      CheckCoroObjArg = false;
     if (implicitObjectParamIsLifetimeBound(Callee) || CheckCoroObjArg)
       VisitLifetimeBoundArg(Callee, ObjectArg);
   }

>From e9ce46e3c1e2ba19c934fab991b517b1a5d579b0 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <usx at google.com>
Date: Thu, 18 Jan 2024 10:54:35 +0000
Subject: [PATCH 16/16] addressed comments

---
 clang/include/clang/Sema/Sema.h |  8 ++++----
 clang/lib/Sema/SemaDecl.cpp     | 17 ++++++++++-------
 clang/lib/Sema/SemaInit.cpp     |  2 +-
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index dae4c3eca6e4aa..744ec165616228 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11220,10 +11220,10 @@ class Sema final {
   VarDecl *buildCoroutinePromise(SourceLocation Loc);
   void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
 
-  // Heuristically tells if the function is get_return_object by matching
-  // function name.
-  static bool IsGetReturnObject(const FunctionDecl *FD);
-  static bool IsGetReturnTypeOnAllocFailure(const FunctionDecl *FD);
+  // Heuristically tells if the function is the `get_return_object` (or static
+  // `get_return_object_on_allocation_failure`) promise_type member by name.
+  static bool CanBeGetReturnObject(const FunctionDecl *FD);
+  static bool CanBeGetReturnTypeOnAllocFailure(const FunctionDecl *FD);
 
   // As a clang extension, enforces that a non-coroutine function must be marked
   // with [[clang::coro_wrapper]] if it returns a type marked with
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 57254933d51537..c8b3e62c078d36 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15841,16 +15841,19 @@ static void diagnoseImplicitlyRetainedSelf(Sema &S) {
           << FixItHint::CreateInsertion(P.first, "self->");
 }
 
-bool Sema::IsGetReturnObject(const FunctionDecl *FD) {
+static bool methodHasName(const FunctionDecl* FD, StringRef Name) {
   return isa<CXXMethodDecl>(FD) && FD->param_empty() &&
          FD->getDeclName().isIdentifier() &&
-         FD->getName().equals("get_return_object");
+         FD->getName().equals(Name);
 }
 
-bool Sema::IsGetReturnTypeOnAllocFailure(const FunctionDecl *FD) {
-  return FD->isStatic() && FD->param_empty() &&
-         FD->getDeclName().isIdentifier() &&
-         FD->getName().equals("get_return_object_on_allocation_failure");
+bool Sema::CanBeGetReturnObject(const FunctionDecl *FD) {
+  return methodHasName(FD, "get_return_object");
+}
+
+bool Sema::CanBeGetReturnTypeOnAllocFailure(const FunctionDecl *FD) {
+  return FD->isStatic() &&
+         methodHasName(FD, "get_return_object_on_allocation_failure");
 }
 
 void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
@@ -15858,7 +15861,7 @@ void Sema::CheckCoroutineWrapper(FunctionDecl *FD) {
   if (!RD || !RD->getUnderlyingDecl()->hasAttr<CoroReturnTypeAttr>())
     return;
   // Allow some_promise_type::get_return_object().
-  if (IsGetReturnObject(FD) || IsGetReturnTypeOnAllocFailure(FD))
+  if (CanBeGetReturnObject(FD) || CanBeGetReturnTypeOnAllocFailure(FD))
     return;
   if (!FD->hasAttr<CoroWrapperAttr>())
     Diag(FD->getLocation(), diag::err_coroutine_return_type) << RD;
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 6e05899856c3e1..422487ba6bcee9 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -7591,7 +7591,7 @@ static void visitLifetimeBoundArguments(IndirectLocalPath &Path, Expr *Call,
       CheckCoroObjArg = false;
     // Allow `get_return_object()` as the object param (__promise) is not
     // lifetimebound.
-    if (Sema::IsGetReturnObject(Callee))
+    if (Sema::CanBeGetReturnObject(Callee))
       CheckCoroObjArg = false;
     if (implicitObjectParamIsLifetimeBound(Callee) || CheckCoroObjArg)
       VisitLifetimeBoundArg(Callee, ObjectArg);



More information about the cfe-commits mailing list