[clang] df2a4ea - [clang] Expose CoawaitExpr's operand in the AST

Nathan Ridge via cfe-commits cfe-commits at lists.llvm.org
Tue May 17 05:14:19 PDT 2022


Author: Nathan Ridge
Date: 2022-05-17T08:13:37-04:00
New Revision: df2a4eae6b190806c0b96ef3312975e1c97dbda0

URL: https://github.com/llvm/llvm-project/commit/df2a4eae6b190806c0b96ef3312975e1c97dbda0
DIFF: https://github.com/llvm/llvm-project/commit/df2a4eae6b190806c0b96ef3312975e1c97dbda0.diff

LOG: [clang] Expose CoawaitExpr's operand in the AST

Previously the Expr returned by getOperand() was actually the
subexpression common to the "ready", "suspend", and "resume"
expressions, which often isn't just the operand but e.g.
await_transform() called on the operand.

It's important for the AST to expose the operand as written
in the source for traversals and tools like clangd to work
correctly.

Fixes https://github.com/clangd/clangd/issues/939

Differential Revision: https://reviews.llvm.org/D115187

Added: 
    clang/test/SemaCXX/co_await-ast.cpp

Modified: 
    clang-tools-extra/clangd/unittests/FindTargetTests.cpp
    clang/include/clang/AST/ExprCXX.h
    clang/include/clang/Sema/Sema.h
    clang/lib/Sema/SemaChecking.cpp
    clang/lib/Sema/SemaCoroutine.cpp
    clang/lib/Sema/TreeTransform.h
    clang/test/AST/coroutine-locals-cleanup-exp-namespace.cpp
    clang/test/AST/coroutine-locals-cleanup.cpp
    clang/test/AST/coroutine-source-location-crash-exp-namespace.cpp
    clang/test/AST/coroutine-source-location-crash.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/clangd/unittests/FindTargetTests.cpp b/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
index c21114fa45b0..f7d547bd960c 100644
--- a/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
+++ b/clang-tools-extra/clangd/unittests/FindTargetTests.cpp
@@ -548,6 +548,50 @@ TEST_F(TargetDeclTest, Concept) {
                {"template <typename T, typename U> concept Fooable = true"});
 }
 
+TEST_F(TargetDeclTest, Coroutine) {
+  Flags.push_back("-std=c++20");
+
+  Code = R"cpp(
+    namespace std::experimental {
+    template <typename, typename...> struct coroutine_traits;
+    template <typename> struct coroutine_handle {
+      template <typename U>
+      coroutine_handle(coroutine_handle<U>&&) noexcept;
+      static coroutine_handle from_address(void* __addr) noexcept;
+    };
+    } // namespace std::experimental
+
+    struct executor {};
+    struct awaitable {};
+    struct awaitable_frame {
+      awaitable get_return_object();
+      void return_void();
+      void unhandled_exception();
+      struct result_t {
+        ~result_t();
+        bool await_ready() const noexcept;
+        void await_suspend(std::experimental::coroutine_handle<void>) noexcept;
+        void await_resume() const noexcept;
+      };
+      result_t initial_suspend() noexcept;
+      result_t final_suspend() noexcept;
+      result_t await_transform(executor) noexcept;
+    };
+
+    namespace std::experimental {
+    template <>
+    struct coroutine_traits<awaitable> {
+      typedef awaitable_frame promise_type;
+    };
+    } // namespace std::experimental
+
+    awaitable foo() {
+      co_await [[executor]]();
+    }
+  )cpp";
+  EXPECT_DECLS("RecordTypeLoc", "struct executor");
+}
+
 TEST_F(TargetDeclTest, FunctionTemplate) {
   Code = R"cpp(
     // Implicit specialization.

diff  --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h
index 3da9290c7dfb..967a74db916d 100644
--- a/clang/include/clang/AST/ExprCXX.h
+++ b/clang/include/clang/AST/ExprCXX.h
@@ -4698,18 +4698,19 @@ class CoroutineSuspendExpr : public Expr {
 
   SourceLocation KeywordLoc;
 
-  enum SubExpr { Common, Ready, Suspend, Resume, Count };
+  enum SubExpr { Operand, Common, Ready, Suspend, Resume, Count };
 
   Stmt *SubExprs[SubExpr::Count];
   OpaqueValueExpr *OpaqueValue = nullptr;
 
 public:
-  CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Common,
-                       Expr *Ready, Expr *Suspend, Expr *Resume,
+  CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand,
+                       Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume,
                        OpaqueValueExpr *OpaqueValue)
       : Expr(SC, Resume->getType(), Resume->getValueKind(),
              Resume->getObjectKind()),
         KeywordLoc(KeywordLoc), OpaqueValue(OpaqueValue) {
+    SubExprs[SubExpr::Operand] = Operand;
     SubExprs[SubExpr::Common] = Common;
     SubExprs[SubExpr::Ready] = Ready;
     SubExprs[SubExpr::Suspend] = Suspend;
@@ -4718,10 +4719,11 @@ class CoroutineSuspendExpr : public Expr {
   }
 
   CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, QualType Ty,
-                       Expr *Common)
+                       Expr *Operand, Expr *Common)
       : Expr(SC, Ty, VK_PRValue, OK_Ordinary), KeywordLoc(KeywordLoc) {
     assert(Common->isTypeDependent() && Ty->isDependentType() &&
            "wrong constructor for non-dependent co_await/co_yield expression");
+    SubExprs[SubExpr::Operand] = Operand;
     SubExprs[SubExpr::Common] = Common;
     SubExprs[SubExpr::Ready] = nullptr;
     SubExprs[SubExpr::Suspend] = nullptr;
@@ -4730,14 +4732,13 @@ class CoroutineSuspendExpr : public Expr {
   }
 
   CoroutineSuspendExpr(StmtClass SC, EmptyShell Empty) : Expr(SC, Empty) {
+    SubExprs[SubExpr::Operand] = nullptr;
     SubExprs[SubExpr::Common] = nullptr;
     SubExprs[SubExpr::Ready] = nullptr;
     SubExprs[SubExpr::Suspend] = nullptr;
     SubExprs[SubExpr::Resume] = nullptr;
   }
 
-  SourceLocation getKeywordLoc() const { return KeywordLoc; }
-
   Expr *getCommonExpr() const {
     return static_cast<Expr*>(SubExprs[SubExpr::Common]);
   }
@@ -4757,10 +4758,17 @@ class CoroutineSuspendExpr : public Expr {
     return static_cast<Expr*>(SubExprs[SubExpr::Resume]);
   }
 
+  // The syntactic operand written in the code
+  Expr *getOperand() const {
+    return static_cast<Expr *>(SubExprs[SubExpr::Operand]);
+  }
+
+  SourceLocation getKeywordLoc() const { return KeywordLoc; }
+
   SourceLocation getBeginLoc() const LLVM_READONLY { return KeywordLoc; }
 
   SourceLocation getEndLoc() const LLVM_READONLY {
-    return getCommonExpr()->getEndLoc();
+    return getOperand()->getEndLoc();
   }
 
   child_range children() {
@@ -4782,28 +4790,24 @@ class CoawaitExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
 
 public:
-  CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready,
-              Expr *Suspend, Expr *Resume, OpaqueValueExpr *OpaqueValue,
-              bool IsImplicit = false)
-      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready,
-                             Suspend, Resume, OpaqueValue) {
+  CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Common,
+              Expr *Ready, Expr *Suspend, Expr *Resume,
+              OpaqueValueExpr *OpaqueValue, bool IsImplicit = false)
+      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Common,
+                             Ready, Suspend, Resume, OpaqueValue) {
     CoawaitBits.IsImplicit = IsImplicit;
   }
 
   CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand,
-              bool IsImplicit = false)
-      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {
+              Expr *Common, bool IsImplicit = false)
+      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand,
+                             Common) {
     CoawaitBits.IsImplicit = IsImplicit;
   }
 
   CoawaitExpr(EmptyShell Empty)
       : CoroutineSuspendExpr(CoawaitExprClass, Empty) {}
 
-  Expr *getOperand() const {
-    // FIXME: Dig out the actual operand or store it.
-    return getCommonExpr();
-  }
-
   bool isImplicit() const { return CoawaitBits.IsImplicit; }
   void setIsImplicit(bool value = true) { CoawaitBits.IsImplicit = value; }
 
@@ -4867,20 +4871,18 @@ class CoyieldExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
 
 public:
-  CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Ready,
-              Expr *Suspend, Expr *Resume, OpaqueValueExpr *OpaqueValue)
-      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Ready,
-                             Suspend, Resume, OpaqueValue) {}
-  CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand)
-      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand) {}
+  CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Common,
+              Expr *Ready, Expr *Suspend, Expr *Resume,
+              OpaqueValueExpr *OpaqueValue)
+      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Common,
+                             Ready, Suspend, Resume, OpaqueValue) {}
+  CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand,
+              Expr *Common)
+      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand,
+                             Common) {}
   CoyieldExpr(EmptyShell Empty)
       : CoroutineSuspendExpr(CoyieldExprClass, Empty) {}
 
-  Expr *getOperand() const {
-    // FIXME: Dig out the actual operand or store it.
-    return getCommonExpr();
-  }
-
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == CoyieldExprClass;
   }

diff  --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e9bd7569984c..976ee59f8684 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -10426,10 +10426,13 @@ class Sema final {
   ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E);
   StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E);
 
-  ExprResult BuildResolvedCoawaitExpr(SourceLocation KwLoc, Expr *E,
-                                      bool IsImplicit = false);
-  ExprResult BuildUnresolvedCoawaitExpr(SourceLocation KwLoc, Expr *E,
-                                        UnresolvedLookupExpr* Lookup);
+  ExprResult BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc);
+  ExprResult BuildOperatorCoawaitCall(SourceLocation Loc, Expr *E,
+                                      UnresolvedLookupExpr *Lookup);
+  ExprResult BuildResolvedCoawaitExpr(SourceLocation KwLoc, Expr *Operand,
+                                      Expr *Awaiter, bool IsImplicit = false);
+  ExprResult BuildUnresolvedCoawaitExpr(SourceLocation KwLoc, Expr *Operand,
+                                        UnresolvedLookupExpr *Lookup);
   ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
   StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E,
                                bool IsImplicit = false);

diff  --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 4c4041a2d68a..453364c3ac3d 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -14095,6 +14095,13 @@ static void AnalyzeImplicitConversions(
     if (!ChildExpr)
       continue;
 
+    if (auto *CSE = dyn_cast<CoroutineSuspendExpr>(E))
+      if (ChildExpr == CSE->getOperand())
+        // Do not recurse over a CoroutineSuspendExpr's operand.
+        // The operand is also a subexpression of getCommonExpr(), and
+        // recursing into it directly would produce duplicate diagnostics.
+        continue;
+
     if (IsLogicalAndOperator &&
         isa<StringLiteral>(ChildExpr->IgnoreParenImpCasts()))
       // Ignore checking string literals that are in logical and operators.

diff  --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 8709e7bed207..b43b0b3eb957 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -246,44 +246,22 @@ static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
   return !Diagnosed;
 }
 
-static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
-                                                 SourceLocation Loc) {
-  DeclarationName OpName =
-      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
-  LookupResult Operators(SemaRef, OpName, SourceLocation(),
-                         Sema::LookupOperatorName);
-  SemaRef.LookupName(Operators, S);
-
-  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
-  const auto &Functions = Operators.asUnresolvedSet();
-  bool IsOverloaded =
-      Functions.size() > 1 ||
-      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
-  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
-      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
-      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
-      Functions.begin(), Functions.end());
-  assert(CoawaitOp);
-  return CoawaitOp;
-}
-
 /// Build a call to 'operator co_await' if there is a suitable operator for
 /// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
-                                           Expr *E,
-                                           UnresolvedLookupExpr *Lookup) {
+ExprResult Sema::BuildOperatorCoawaitCall(SourceLocation Loc, Expr *E,
+                                          UnresolvedLookupExpr *Lookup) {
   UnresolvedSet<16> Functions;
   Functions.append(Lookup->decls_begin(), Lookup->decls_end());
-  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
+  return CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
 }
 
 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
                                            SourceLocation Loc, Expr *E) {
-  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
+  ExprResult R = SemaRef.BuildOperatorCoawaitLookupExpr(S, Loc);
   if (R.isInvalid())
     return ExprError();
-  return buildOperatorCoawaitCall(SemaRef, Loc, E,
-                                  cast<UnresolvedLookupExpr>(R.get()));
+  return SemaRef.BuildOperatorCoawaitCall(Loc, E,
+                                          cast<UnresolvedLookupExpr>(R.get()));
 }
 
 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
@@ -727,14 +705,15 @@ bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
   SourceLocation Loc = Fn->getLocation();
   // Build the initial suspend point
   auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
-    ExprResult Suspend =
+    ExprResult Operand =
         buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
-    if (Suspend.isInvalid())
+    if (Operand.isInvalid())
       return StmtError();
-    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
+    ExprResult Suspend =
+        buildOperatorCoawaitCall(*this, SC, Loc, Operand.get());
     if (Suspend.isInvalid())
       return StmtError();
-    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
+    Suspend = BuildResolvedCoawaitExpr(Loc, Operand.get(), Suspend.get(),
                                        /*IsImplicit*/ true);
     Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
     if (Suspend.isInvalid()) {
@@ -815,88 +794,112 @@ ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
     if (R.isInvalid()) return ExprError();
     E = R.get();
   }
-  ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
+  ExprResult Lookup = BuildOperatorCoawaitLookupExpr(S, Loc);
   if (Lookup.isInvalid())
     return ExprError();
   return BuildUnresolvedCoawaitExpr(Loc, E,
                                    cast<UnresolvedLookupExpr>(Lookup.get()));
 }
 
-ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
+ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) {
+  DeclarationName OpName =
+      Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
+  LookupResult Operators(*this, OpName, SourceLocation(),
+                         Sema::LookupOperatorName);
+  LookupName(Operators, S);
+
+  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+  const auto &Functions = Operators.asUnresolvedSet();
+  bool IsOverloaded =
+      Functions.size() > 1 ||
+      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
+  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
+      Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
+      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
+      Functions.begin(), Functions.end());
+  assert(CoawaitOp);
+  return CoawaitOp;
+}
+
+// Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
+// DependentCoawaitExpr if needed.
+ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
                                             UnresolvedLookupExpr *Lookup) {
   auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
   if (!FSI)
     return ExprError();
 
-  if (E->hasPlaceholderType()) {
-    ExprResult R = CheckPlaceholderExpr(E);
+  if (Operand->hasPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(Operand);
     if (R.isInvalid())
       return ExprError();
-    E = R.get();
+    Operand = R.get();
   }
 
   auto *Promise = FSI->CoroutinePromise;
   if (Promise->getType()->isDependentType()) {
-    Expr *Res =
-        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
+    Expr *Res = new (Context)
+        DependentCoawaitExpr(Loc, Context.DependentTy, Operand, Lookup);
     return Res;
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
+  auto *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
-    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
+    ExprResult R =
+        buildPromiseCall(*this, Promise, Loc, "await_transform", Operand);
     if (R.isInvalid()) {
       Diag(Loc,
            diag::note_coroutine_promise_implicit_await_transform_required_here)
-          << E->getSourceRange();
+          << Operand->getSourceRange();
       return ExprError();
     }
-    E = R.get();
+    Transformed = R.get();
   }
-  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
-  if (Awaitable.isInvalid())
+  ExprResult Awaiter = BuildOperatorCoawaitCall(Loc, Transformed, Lookup);
+  if (Awaiter.isInvalid())
     return ExprError();
 
-  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
+  return BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get());
 }
 
-ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
-                                  bool IsImplicit) {
+ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
+                                          Expr *Awaiter, bool IsImplicit) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
   if (!Coroutine)
     return ExprError();
 
-  if (E->hasPlaceholderType()) {
-    ExprResult R = CheckPlaceholderExpr(E);
+  if (Awaiter->hasPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(Awaiter);
     if (R.isInvalid()) return ExprError();
-    E = R.get();
+    Awaiter = R.get();
   }
 
-  if (E->getType()->isDependentType()) {
+  if (Awaiter->getType()->isDependentType()) {
     Expr *Res = new (Context)
-        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
+        CoawaitExpr(Loc, Context.DependentTy, Operand, Awaiter, IsImplicit);
     return Res;
   }
 
   // If the expression is a temporary, materialize it as an lvalue so that we
   // can use it multiple times.
-  if (E->isPRValue())
-    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
+  if (Awaiter->isPRValue())
+    Awaiter = CreateMaterializeTemporaryExpr(Awaiter->getType(), Awaiter, true);
 
   // The location of the `co_await` token cannot be used when constructing
   // the member call expressions since it's before the location of `Expr`, which
   // is used as the start of the member call expression.
-  SourceLocation CallLoc = E->getExprLoc();
+  SourceLocation CallLoc = Awaiter->getExprLoc();
 
   // Build the await_ready, await_suspend, await_resume calls.
-  ReadySuspendResumeResult RSS = buildCoawaitCalls(
-      *this, Coroutine->CoroutinePromise, CallLoc, E);
+  ReadySuspendResumeResult RSS =
+      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, Awaiter);
   if (RSS.IsInvalid)
     return ExprError();
 
-  Expr *Res =
-      new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
-                                RSS.Results[2], RSS.OpaqueValue, IsImplicit);
+  Expr *Res = new (Context)
+      CoawaitExpr(Loc, Operand, Awaiter, RSS.Results[0], RSS.Results[1],
+                  RSS.Results[2], RSS.OpaqueValue, IsImplicit);
 
   return Res;
 }
@@ -933,8 +936,10 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
     E = R.get();
   }
 
+  Expr *Operand = E;
+
   if (E->getType()->isDependentType()) {
-    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
+    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, Operand, E);
     return Res;
   }
 
@@ -950,7 +955,7 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
     return ExprError();
 
   Expr *Res =
-      new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
+      new (Context) CoyieldExpr(Loc, Operand, E, RSS.Results[0], RSS.Results[1],
                                 RSS.Results[2], RSS.OpaqueValue);
 
   return Res;

diff  --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index aef757efc9e7..3d7279c47f7c 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -1470,9 +1470,28 @@ class TreeTransform {
   ///
   /// By default, performs semantic analysis to build the new expression.
   /// Subclasses may override this routine to provide different behavior.
-  ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result,
+  ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand,
+                                UnresolvedLookupExpr *OpCoawaitLookup,
                                 bool IsImplicit) {
-    return getSema().BuildResolvedCoawaitExpr(CoawaitLoc, Result, IsImplicit);
+    // This function rebuilds a coawait-expr given its operator.
+    // For an explicit coawait-expr, the rebuild involves the full set
+    // of transformations performed by BuildUnresolvedCoawaitExpr(),
+    // including calling await_transform().
+    // For an implicit coawait-expr, we need to rebuild the "operator
+    // coawait" but not await_transform(), so use BuildResolvedCoawaitExpr().
+    // This mirrors how the implicit CoawaitExpr is originally created
+    // in Sema::ActOnCoroutineBodyStart().
+    if (IsImplicit) {
+      ExprResult Suspend = getSema().BuildOperatorCoawaitCall(
+          CoawaitLoc, Operand, OpCoawaitLookup);
+      if (Suspend.isInvalid())
+        return ExprError();
+      return getSema().BuildResolvedCoawaitExpr(CoawaitLoc, Operand,
+                                                Suspend.get(), true);
+    }
+
+    return getSema().BuildUnresolvedCoawaitExpr(CoawaitLoc, Operand,
+                                                OpCoawaitLookup);
   }
 
   /// Build a new co_await expression.
@@ -7945,18 +7964,27 @@ TreeTransform<Derived>::TransformCoreturnStmt(CoreturnStmt *S) {
                                           S->isImplicit());
 }
 
-template<typename Derived>
-ExprResult
-TreeTransform<Derived>::TransformCoawaitExpr(CoawaitExpr *E) {
-  ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
-                                                        /*NotCopyInit*/false);
-  if (Result.isInvalid())
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformCoawaitExpr(CoawaitExpr *E) {
+  ExprResult Operand = getDerived().TransformInitializer(E->getOperand(),
+                                                         /*NotCopyInit*/ false);
+  if (Operand.isInvalid())
     return ExprError();
 
+  // Rebuild the common-expr from the operand rather than transforming it
+  // separately.
+
+  // FIXME: getCurScope() should not be used during template instantiation.
+  // We should pick up the set of unqualified lookup results for operator
+  // co_await during the initial parse.
+  ExprResult Lookup = getSema().BuildOperatorCoawaitLookupExpr(
+      getSema().getCurScope(), E->getKeywordLoc());
+
   // Always rebuild; we don't know if this needs to be injected into a new
   // context or if the promise type has changed.
-  return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(),
-                                         E->isImplicit());
+  return getDerived().RebuildCoawaitExpr(
+      E->getKeywordLoc(), Operand.get(),
+      cast<UnresolvedLookupExpr>(Lookup.get()), E->isImplicit());
 }
 
 template <typename Derived>

diff  --git a/clang/test/AST/coroutine-locals-cleanup-exp-namespace.cpp b/clang/test/AST/coroutine-locals-cleanup-exp-namespace.cpp
index 048c6778bd05..3122df9b6732 100644
--- a/clang/test/AST/coroutine-locals-cleanup-exp-namespace.cpp
+++ b/clang/test/AST/coroutine-locals-cleanup-exp-namespace.cpp
@@ -85,7 +85,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready
@@ -97,7 +98,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready

diff  --git a/clang/test/AST/coroutine-locals-cleanup.cpp b/clang/test/AST/coroutine-locals-cleanup.cpp
index 4e2fe6275de7..aa04a350e9a9 100644
--- a/clang/test/AST/coroutine-locals-cleanup.cpp
+++ b/clang/test/AST/coroutine-locals-cleanup.cpp
@@ -85,7 +85,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready
@@ -97,7 +98,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready

diff  --git a/clang/test/AST/coroutine-source-location-crash-exp-namespace.cpp b/clang/test/AST/coroutine-source-location-crash-exp-namespace.cpp
index 9995dee542e9..fb9aaa537c40 100644
--- a/clang/test/AST/coroutine-source-location-crash-exp-namespace.cpp
+++ b/clang/test/AST/coroutine-source-location-crash-exp-namespace.cpp
@@ -36,6 +36,7 @@ coro_t f(int n) {
   A a{};
   // CHECK: CoawaitExpr {{0x[0-9a-fA-F]+}} <col:3, col:12>
   // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
+  // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: CXXMemberCallExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: MemberExpr {{0x[0-9a-fA-F]+}} <col:12>
   co_await a;

diff  --git a/clang/test/AST/coroutine-source-location-crash.cpp b/clang/test/AST/coroutine-source-location-crash.cpp
index 9b18dc817fb5..fcf23d21d298 100644
--- a/clang/test/AST/coroutine-source-location-crash.cpp
+++ b/clang/test/AST/coroutine-source-location-crash.cpp
@@ -36,6 +36,7 @@ coro_t f(int n) {
   A a{};
   // CHECK: CoawaitExpr {{0x[0-9a-fA-F]+}} <col:3, col:12>
   // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
+  // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: CXXMemberCallExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: MemberExpr {{0x[0-9a-fA-F]+}} <col:12>
   co_await a;

diff  --git a/clang/test/SemaCXX/co_await-ast.cpp b/clang/test/SemaCXX/co_await-ast.cpp
new file mode 100644
index 000000000000..1221e8623c26
--- /dev/null
+++ b/clang/test/SemaCXX/co_await-ast.cpp
@@ -0,0 +1,97 @@
+// RUN: %clang_cc1 -std=c++20 -fsyntax-only -ast-dump -ast-dump-filter=foo %s | FileCheck %s --strict-whitespace
+
+namespace std {
+template <typename, typename...> struct coroutine_traits;
+template <typename> struct coroutine_handle {
+  template <typename U>
+  coroutine_handle(coroutine_handle<U> &&) noexcept;
+  static coroutine_handle from_address(void *__addr) noexcept;
+};
+} // namespace std
+
+struct executor {};
+struct awaitable {};
+struct awaitable_frame {
+  awaitable get_return_object();
+  void return_void();
+  void unhandled_exception();
+  struct result_t {
+    ~result_t();
+    bool await_ready() const noexcept;
+    void await_suspend(std::coroutine_handle<void>) noexcept;
+    void await_resume() const noexcept;
+  };
+  result_t initial_suspend() noexcept;
+  result_t final_suspend() noexcept;
+  result_t await_transform(executor) noexcept;
+};
+
+namespace std {
+template <>
+struct coroutine_traits<awaitable> {
+  typedef awaitable_frame promise_type;
+};
+} // namespace std
+
+awaitable foo() {
+  co_await executor();
+}
+
+// Check that CoawaitExpr contains the correct subexpressions, including
+// the operand expression as written in the source.
+
+// CHECK-LABEL: Dumping foo:
+// CHECK: FunctionDecl {{.*}} foo 'awaitable ()'
+// CHECK: `-CoroutineBodyStmt {{.*}}
+// CHECK:   |-CompoundStmt {{.*}}
+// CHECK:   | `-ExprWithCleanups {{.*}} 'void'
+// CHECK:   |   `-CoawaitExpr {{.*}} 'void'
+// CHECK:   |     |-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     | `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |     |   `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |     |     |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |     |     | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |     |     `-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |-ExprWithCleanups {{.*}} 'bool'
+// CHECK:   |     | `-CXXMemberCallExpr {{.*}} 'bool'
+// CHECK:   |     |   `-MemberExpr {{.*}} '<bound member function type>' .await_ready {{.*}}
+// CHECK:   |     |     `-ImplicitCastExpr {{.*}} 'const awaitable_frame::result_t' lvalue <NoOp>
+// CHECK:   |     |       `-OpaqueValueExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |         `-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |           `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |     |             `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |     |               |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |     |               | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |     |               `-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |-ExprWithCleanups {{.*}} 'void'
+// CHECK:   |     | `-CXXMemberCallExpr {{.*}} 'void'
+// CHECK:   |     |   |-MemberExpr {{.*}} '<bound member function type>' .await_suspend {{.*}}
+// CHECK:   |     |   | `-OpaqueValueExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |   |   `-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |   |     `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |     |   |       `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |     |   |         |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |     |   |         | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |     |   |         `-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |   `-ImplicitCastExpr {{.*}} 'std::coroutine_handle<void>':'std::coroutine_handle<void>' <ConstructorConversion>
+// CHECK:   |     |     `-CXXConstructExpr {{.*}} 'std::coroutine_handle<void>':'std::coroutine_handle<void>' 'void (coroutine_handle<awaitable_frame> &&) noexcept'
+// CHECK:   |     |       `-MaterializeTemporaryExpr {{.*}} 'std::coroutine_handle<awaitable_frame>' xvalue
+// CHECK:   |     |         `-CallExpr {{.*}} 'std::coroutine_handle<awaitable_frame>'
+// CHECK:   |     |           |-ImplicitCastExpr {{.*}} 'std::coroutine_handle<awaitable_frame> (*)(void *) noexcept' <FunctionToPointerDecay>
+// CHECK:   |     |           | `-DeclRefExpr {{.*}} 'std::coroutine_handle<awaitable_frame> (void *) noexcept' lvalue CXXMethod {{.*}} 'from_address' 'std::coroutine_handle<awaitable_frame> (void *) noexcept'
+// CHECK:   |     |           `-CallExpr {{.*}} 'void *'
+// CHECK:   |     |             `-ImplicitCastExpr {{.*}} 'void *(*)() noexcept' <FunctionToPointerDecay>
+// CHECK:   |     |               `-DeclRefExpr {{.*}} 'void *() noexcept' lvalue Function {{.*}} '__builtin_coro_frame' 'void *() noexcept'
+// CHECK:   |     `-CXXMemberCallExpr {{.*}} 'void'
+// CHECK:   |       `-MemberExpr {{.*}} '<bound member function type>' .await_resume {{.*}}
+// CHECK:   |         `-ImplicitCastExpr {{.*}} 'const awaitable_frame::result_t' lvalue <NoOp>
+// CHECK:   |           `-OpaqueValueExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |             `-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |               `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |                 `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |                   |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |                   | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |                   `-CXXTemporaryObjectExpr {{.*}} <col:12, col:21> 'executor' 'void () noexcept' zeroing
+
+// Rest of the generated coroutine statements omitted.


        


More information about the cfe-commits mailing list