[clang] [llvm] [Clang] C++20 Coroutines: Introduce Frontend Attribute [[clang::coro_await_elidable]] (PR #99282)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Wed Aug 28 14:34:25 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/99282

>From ef2b868666579df534d63daeebd2b201600ce292 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <yuxuanchen1997 at outlook.com>
Date: Tue, 4 Jun 2024 23:22:00 -0700
Subject: [PATCH] [Clang] C++20 Coroutines: Introduce Frontend Attribute
 [[clang::coro_await_elidable]]

---
 clang/docs/ReleaseNotes.rst                   |  3 +
 clang/include/clang/AST/Expr.h                |  3 +
 clang/include/clang/AST/Stmt.h                |  5 +-
 clang/include/clang/Basic/Attr.td             |  8 ++
 clang/include/clang/Basic/AttrDocs.td         | 33 ++++++-
 clang/lib/AST/Expr.cpp                        |  2 +
 clang/lib/CodeGen/CGBlocks.cpp                |  5 +-
 clang/lib/CodeGen/CGCUDARuntime.cpp           |  5 +-
 clang/lib/CodeGen/CGCUDARuntime.h             |  8 +-
 clang/lib/CodeGen/CGCXXABI.h                  | 10 +--
 clang/lib/CodeGen/CGClass.cpp                 | 16 ++--
 clang/lib/CodeGen/CGExpr.cpp                  | 55 ++++++++----
 clang/lib/CodeGen/CGExprCXX.cpp               | 60 +++++++------
 clang/lib/CodeGen/CodeGenFunction.h           | 64 ++++++++------
 clang/lib/CodeGen/ItaniumCXXABI.cpp           | 16 ++--
 clang/lib/CodeGen/MicrosoftCXXABI.cpp         | 18 ++--
 clang/lib/Sema/SemaCoroutine.cpp              | 24 ++++-
 clang/test/CodeGenCoroutines/Inputs/utility.h | 13 +++
 .../CodeGenCoroutines/coro-await-elidable.cpp | 87 +++++++++++++++++++
 ...a-attribute-supported-attributes-list.test |  1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |  1 +
 llvm/include/llvm/IR/Attributes.td            |  4 +
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |  2 +
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |  2 +
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   |  1 +
 25 files changed, 336 insertions(+), 110 deletions(-)
 create mode 100644 clang/test/CodeGenCoroutines/Inputs/utility.h
 create mode 100644 clang/test/CodeGenCoroutines/coro-await-elidable.cpp

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 4289b3374a810b..9038d18fd11fdc 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -225,6 +225,9 @@ Attribute Changes in Clang
   more cases where the returned reference outlives the object.
   (#GH100567)
 
+- Introduced a new attribute ``[[clang::coro_await_elidable]]`` on coroutine return types
+  to express elidability at call sites where the coroutine is co_awaited as a prvalue.
+
 Improvements to Clang's diagnostics
 -----------------------------------
 
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 7bacf028192c65..77ce378df26dcf 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -2991,6 +2991,9 @@ class CallExpr : public Expr {
 
   bool hasStoredFPFeatures() const { return CallExprBits.HasFPFeatures; }
 
+  bool isCoroElideSafe() const { return CallExprBits.IsCoroElideSafe; }
+  void setCoroElideSafe(bool V = true) { CallExprBits.IsCoroElideSafe = V; }
+
   Decl *getCalleeDecl() { return getCallee()->getReferencedDeclOfCallee(); }
   const Decl *getCalleeDecl() const {
     return getCallee()->getReferencedDeclOfCallee();
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index f1a2aac0a8b2f8..7aed83e9c68bb7 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -561,8 +561,11 @@ class alignas(void *) Stmt {
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasFPFeatures : 1;
 
+    /// True if the coroutine frame allocation for this call is safe to elide.
+    unsigned IsCoroElideSafe : 1;
+
     /// Padding used to align OffsetToTrailingObjects to a byte multiple.
-    unsigned : 24 - 3 - NumExprBits;
+    unsigned : 24 - 4 - NumExprBits;
 
     /// The offset in bytes from the this pointer to the start of the
     /// trailing objects belonging to CallExpr. Intentionally byte sized
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index a83e908899c83b..711e3adf0892ed 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -1220,6 +1220,14 @@ def CoroDisableLifetimeBound : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def CoroAwaitElidable : InheritableAttr {
+  let Spellings = [Clang<"coro_await_elidable">];
+  let Subjects = SubjectList<[CXXRecord]>;
+  let LangOpts = [CPlusPlus];
+  let Documentation = [CoroAwaitElidableDoc];
+  let SimpleHandler = 1;
+}
+
 // OSObject-based attributes.
 def OSConsumed : InheritableParamAttr {
   let Spellings = [Clang<"os_consumed">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index c2b9d7cb93c309..a1ac58c9bc80a8 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8218,6 +8218,38 @@ but do not pass them to the underlying coroutine or pass them by value.
 }];
 }
 
+def CoroAwaitElidableDoc : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+The ``[[clang::coro_await_elidable]]`` attribute is a class attribute that can be
+applied to a coroutine return type.
+
+When a coroutine function that returns such a type calls another coroutine function,
+the compiler performs heap allocation elision when the call to the coroutine function
+is immediately co_awaited as a prvalue. In this case, the coroutine frame for the
+callee will be a local variable within the enclosing braces in the caller's stack
+frame. This local variable, like other variables in coroutines, may be collected
+into the caller's coroutine frame, which may itself be allocated on the heap.
+
+Example:
+
+.. code-block:: c++
+
+  class [[clang::coro_await_elidable]] Task { ... };
+
+  Task foo();
+  Task bar() {
+    co_await foo(); // foo()'s coroutine frame on this line is elidable
+    auto t = foo(); // foo()'s coroutine frame on this line is NOT elidable
+    co_await t;
+  }
+
+The behavior is undefined if the caller coroutine is destroyed earlier than the
+callee coroutine.
+
+}];
+}
+
 def CountedByDocs : Documentation {
   let Category = DocCatField;
   let Content = [{
@@ -8377,4 +8409,3 @@ Declares that a function potentially allocates heap memory, and prevents any pot
 of ``nonallocating`` by the compiler.
   }];
 }
-
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 3309619850f34a..30904d8721f6d2 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1474,6 +1474,7 @@ CallExpr::CallExpr(StmtClass SC, Expr *Fn, ArrayRef<Expr *> PreArgs,
   this->computeDependence();
 
   CallExprBits.HasFPFeatures = FPFeatures.requiresTrailingStorage();
+  CallExprBits.IsCoroElideSafe = false;
   if (hasStoredFPFeatures())
     setStoredFPFeatures(FPFeatures);
 }
@@ -1489,6 +1490,7 @@ CallExpr::CallExpr(StmtClass SC, unsigned NumPreArgs, unsigned NumArgs,
   assert((CallExprBits.OffsetToTrailingObjects == OffsetToTrailingObjects) &&
          "OffsetToTrailingObjects overflow!");
   CallExprBits.HasFPFeatures = HasFPFeatures;
+  CallExprBits.IsCoroElideSafe = false;
 }
 
 CallExpr *CallExpr::Create(const ASTContext &Ctx, Expr *Fn,
diff --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp
index 066139b1c78c7f..684fda74407313 100644
--- a/clang/lib/CodeGen/CGBlocks.cpp
+++ b/clang/lib/CodeGen/CGBlocks.cpp
@@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() {
 }
 
 RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
-                                          ReturnValueSlot ReturnValue) {
+                                          ReturnValueSlot ReturnValue,
+                                          llvm::CallBase **CallOrInvoke) {
   const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>();
   llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee());
   llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType();
@@ -1220,7 +1221,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
   CGCallee Callee(CGCalleeInfo(), Func);
 
   // And call the block.
-  return EmitCall(FnInfo, Callee, ReturnValue, Args);
+  return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke);
 }
 
 Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) {
diff --git a/clang/lib/CodeGen/CGCUDARuntime.cpp b/clang/lib/CodeGen/CGCUDARuntime.cpp
index c14a9d3f2bbbcf..1e1da1e2411a76 100644
--- a/clang/lib/CodeGen/CGCUDARuntime.cpp
+++ b/clang/lib/CodeGen/CGCUDARuntime.cpp
@@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {}
 
 RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
                                              const CUDAKernelCallExpr *E,
-                                             ReturnValueSlot ReturnValue) {
+                                             ReturnValueSlot ReturnValue,
+                                             llvm::CallBase **CallOrInvoke) {
   llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok");
   llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end");
 
@@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
 
   eval.begin(CGF);
   CGF.EmitBlock(ConfigOKBlock);
-  CGF.EmitSimpleCallExpr(E, ReturnValue);
+  CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke);
   CGF.EmitBranch(ContBlock);
 
   CGF.EmitBlock(ContBlock);
diff --git a/clang/lib/CodeGen/CGCUDARuntime.h b/clang/lib/CodeGen/CGCUDARuntime.h
index 8030d632cc3d28..86f776004ee7c1 100644
--- a/clang/lib/CodeGen/CGCUDARuntime.h
+++ b/clang/lib/CodeGen/CGCUDARuntime.h
@@ -21,6 +21,7 @@
 #include "llvm/IR/GlobalValue.h"
 
 namespace llvm {
+class CallBase;
 class Function;
 class GlobalVariable;
 }
@@ -82,9 +83,10 @@ class CGCUDARuntime {
   CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {}
   virtual ~CGCUDARuntime();
 
-  virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
-                                        const CUDAKernelCallExpr *E,
-                                        ReturnValueSlot ReturnValue);
+  virtual RValue
+  EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E,
+                         ReturnValueSlot ReturnValue,
+                         llvm::CallBase **CallOrInvoke = nullptr);
 
   /// Emits a kernel launch stub.
   virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0;
diff --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h
index 7dcc539111996b..687ff7fb844445 100644
--- a/clang/lib/CodeGen/CGCXXABI.h
+++ b/clang/lib/CodeGen/CGCXXABI.h
@@ -485,11 +485,11 @@ class CGCXXABI {
       llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>;
 
   /// Emit the ABI-specific virtual destructor call.
-  virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
-                                                 const CXXDestructorDecl *Dtor,
-                                                 CXXDtorType DtorType,
-                                                 Address This,
-                                                 DeleteOrMemberCallExpr E) = 0;
+  virtual llvm::Value *
+  EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
+                            CXXDtorType DtorType, Address This,
+                            DeleteOrMemberCallExpr E,
+                            llvm::CallBase **CallOrInvoke) = 0;
 
   virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF,
                                                 GlobalDecl GD,
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index e5ba50de3462da..352955749a6332 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2192,15 +2192,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF,
   return true;
 }
 
-void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
-                                             CXXCtorType Type,
-                                             bool ForVirtualBase,
-                                             bool Delegating,
-                                             Address This,
-                                             CallArgList &Args,
-                                             AggValueSlot::Overlap_t Overlap,
-                                             SourceLocation Loc,
-                                             bool NewPointerIsChecked) {
+void CodeGenFunction::EmitCXXConstructorCall(
+    const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase,
+    bool Delegating, Address This, CallArgList &Args,
+    AggValueSlot::Overlap_t Overlap, SourceLocation Loc,
+    bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) {
   const CXXRecordDecl *ClassDecl = D->getParent();
 
   if (!NewPointerIsChecked)
@@ -2248,7 +2244,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
   const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall(
       Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs);
   CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type));
-  EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc);
+  EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc);
 
   // Generate vtable assumptions if we're constructing a complete object
   // with a vtable.  We don't do this for base subobjects for two reasons:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 426fccb721db84..e946b5ce5c0f1a 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -33,6 +33,7 @@
 #include "clang/Basic/SourceManager.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Intrinsics.h"
@@ -5489,16 +5490,30 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV,
 //===--------------------------------------------------------------------===//
 
 RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
-                                     ReturnValueSlot ReturnValue) {
+                                     ReturnValueSlot ReturnValue,
+                                     llvm::CallBase **CallOrInvoke) {
+  llvm::CallBase *CallOrInvokeStorage;
+  if (!CallOrInvoke) {
+    CallOrInvoke = &CallOrInvokeStorage;
+  }
+
+  auto AddCoroElideSafeOnExit = llvm::make_scope_exit([&] {
+    if (E->isCoroElideSafe()) {
+      auto *I = *CallOrInvoke;
+      if (I)
+        I->addFnAttr(llvm::Attribute::CoroElideSafe);
+    }
+  });
+
   // Builtins never have block type.
   if (E->getCallee()->getType()->isBlockPointerType())
-    return EmitBlockCallExpr(E, ReturnValue);
+    return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke);
 
   if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
-    return EmitCXXMemberCallExpr(CE, ReturnValue);
+    return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke);
 
   if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E))
-    return EmitCUDAKernelCallExpr(CE, ReturnValue);
+    return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke);
 
   // A CXXOperatorCallExpr is created even for explicit object methods, but
   // these should be treated like static function call.
@@ -5506,7 +5521,7 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
     if (const auto *MD =
             dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl());
         MD && MD->isImplicitObjectMemberFunction())
-      return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue);
+      return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke);
 
   CGCallee callee = EmitCallee(E->getCallee());
 
@@ -5519,14 +5534,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
     return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr());
   }
 
-  return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue);
+  return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue,
+                  /*Chain=*/nullptr, CallOrInvoke);
 }
 
 /// Emit a CallExpr without considering whether it might be a subclass.
 RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E,
-                                           ReturnValueSlot ReturnValue) {
+                                           ReturnValueSlot ReturnValue,
+                                           llvm::CallBase **CallOrInvoke) {
   CGCallee Callee = EmitCallee(E->getCallee());
-  return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue);
+  return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue,
+                  /*Chain=*/nullptr, CallOrInvoke);
 }
 
 // Detect the unusual situation where an inline version is shadowed by a
@@ -5730,8 +5748,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) {
   llvm_unreachable("bad evaluation kind");
 }
 
-LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) {
-  RValue RV = EmitCallExpr(E);
+LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E,
+                                           llvm::CallBase **CallOrInvoke) {
+  RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke);
 
   if (!RV.isScalar())
     return MakeAddrLValue(RV.getAggregateAddress(), E->getType(),
@@ -5854,9 +5873,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) {
                         AlignmentSource::Decl);
 }
 
-RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee,
-                                 const CallExpr *E, ReturnValueSlot ReturnValue,
-                                 llvm::Value *Chain) {
+RValue CodeGenFunction::EmitCall(QualType CalleeType,
+                                 const CGCallee &OrigCallee, const CallExpr *E,
+                                 ReturnValueSlot ReturnValue,
+                                 llvm::Value *Chain,
+                                 llvm::CallBase **CallOrInvoke) {
   // Get the actual function type. The callee type will always be a pointer to
   // function type or a block pointer type.
   assert(CalleeType->isFunctionPointerType() &&
@@ -6076,8 +6097,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
         Address(Handle, Handle->getType(), CGM.getPointerAlign()));
     Callee.setFunctionPointer(Stub);
   }
-  llvm::CallBase *CallOrInvoke = nullptr;
-  RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke,
+  llvm::CallBase *LocalCallOrInvoke = nullptr;
+  RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke,
                          E == MustTailCall, E->getExprLoc());
 
   // Generate function declaration DISuprogram in order to be used
@@ -6086,11 +6107,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
     if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) {
       FunctionArgList Args;
       QualType ResTy = BuildFunctionArgList(CalleeDecl, Args);
-      DI->EmitFuncDeclForCallSite(CallOrInvoke,
+      DI->EmitFuncDeclForCallSite(LocalCallOrInvoke,
                                   DI->getFunctionType(CalleeDecl, ResTy, Args),
                                   CalleeDecl);
     }
   }
+  if (CallOrInvoke)
+    *CallOrInvoke = LocalCallOrInvoke;
 
   return Call;
 }
diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
index 8eb6ab7381acbc..1214bb054fb8df 100644
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -84,23 +84,24 @@ commonEmitCXXMemberOrOperatorCall(CodeGenFunction &CGF, GlobalDecl GD,
 
 RValue CodeGenFunction::EmitCXXMemberOrOperatorCall(
     const CXXMethodDecl *MD, const CGCallee &Callee,
-    ReturnValueSlot ReturnValue,
-    llvm::Value *This, llvm::Value *ImplicitParam, QualType ImplicitParamTy,
-    const CallExpr *CE, CallArgList *RtlArgs) {
+    ReturnValueSlot ReturnValue, llvm::Value *This, llvm::Value *ImplicitParam,
+    QualType ImplicitParamTy, const CallExpr *CE, CallArgList *RtlArgs,
+    llvm::CallBase **CallOrInvoke) {
   const FunctionProtoType *FPT = MD->getType()->castAs<FunctionProtoType>();
   CallArgList Args;
   MemberCallInfo CallInfo = commonEmitCXXMemberOrOperatorCall(
       *this, MD, This, ImplicitParam, ImplicitParamTy, CE, Args, RtlArgs);
   auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall(
       Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize);
-  return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr,
+  return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke,
                   CE && CE == MustTailCall,
                   CE ? CE->getExprLoc() : SourceLocation());
 }
 
 RValue CodeGenFunction::EmitCXXDestructorCall(
     GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy,
-    llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE) {
+    llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE,
+    llvm::CallBase **CallOrInvoke) {
   const CXXMethodDecl *DtorDecl = cast<CXXMethodDecl>(Dtor.getDecl());
 
   assert(!ThisTy.isNull());
@@ -120,7 +121,8 @@ RValue CodeGenFunction::EmitCXXDestructorCall(
   commonEmitCXXMemberOrOperatorCall(*this, Dtor, This, ImplicitParam,
                                     ImplicitParamTy, CE, Args, nullptr);
   return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee,
-                  ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall,
+                  ReturnValueSlot(), Args, CallOrInvoke,
+                  CE && CE == MustTailCall,
                   CE ? CE->getExprLoc() : SourceLocation{});
 }
 
@@ -186,11 +188,12 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) {
 // Note: This function also emit constructor calls to support a MSVC
 // extensions allowing explicit constructor function call.
 RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
-                                              ReturnValueSlot ReturnValue) {
+                                              ReturnValueSlot ReturnValue,
+                                              llvm::CallBase **CallOrInvoke) {
   const Expr *callee = CE->getCallee()->IgnoreParens();
 
   if (isa<BinaryOperator>(callee))
-    return EmitCXXMemberPointerCallExpr(CE, ReturnValue);
+    return EmitCXXMemberPointerCallExpr(CE, ReturnValue, CallOrInvoke);
 
   const MemberExpr *ME = cast<MemberExpr>(callee);
   const CXXMethodDecl *MD = cast<CXXMethodDecl>(ME->getMemberDecl());
@@ -200,7 +203,7 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
     CGCallee callee =
         CGCallee::forDirect(CGM.GetAddrOfFunction(MD), GlobalDecl(MD));
     return EmitCall(getContext().getPointerType(MD->getType()), callee, CE,
-                    ReturnValue);
+                    ReturnValue, /*Chain=*/nullptr, CallOrInvoke);
   }
 
   bool HasQualifier = ME->hasQualifier();
@@ -208,14 +211,15 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
   bool IsArrow = ME->isArrow();
   const Expr *Base = ME->getBase();
 
-  return EmitCXXMemberOrOperatorMemberCallExpr(
-      CE, MD, ReturnValue, HasQualifier, Qualifier, IsArrow, Base);
+  return EmitCXXMemberOrOperatorMemberCallExpr(CE, MD, ReturnValue,
+                                               HasQualifier, Qualifier, IsArrow,
+                                               Base, CallOrInvoke);
 }
 
 RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
     const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue,
     bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow,
-    const Expr *Base) {
+    const Expr *Base, llvm::CallBase **CallOrInvoke) {
   assert(isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE));
 
   // Compute the object pointer.
@@ -300,7 +304,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
     EmitCXXConstructorCall(Ctor, Ctor_Complete, /*ForVirtualBase=*/false,
                            /*Delegating=*/false, This.getAddress(), Args,
                            AggValueSlot::DoesNotOverlap, CE->getExprLoc(),
-                           /*NewPointerIsChecked=*/false);
+                           /*NewPointerIsChecked=*/false, CallOrInvoke);
     return RValue::get(nullptr);
   }
 
@@ -374,9 +378,9 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
            "Destructor shouldn't have explicit parameters");
     assert(ReturnValue.isNull() && "Destructor shouldn't have return value");
     if (UseVirtualCall) {
-      CGM.getCXXABI().EmitVirtualDestructorCall(*this, Dtor, Dtor_Complete,
-                                                This.getAddress(),
-                                                cast<CXXMemberCallExpr>(CE));
+      CGM.getCXXABI().EmitVirtualDestructorCall(
+          *this, Dtor, Dtor_Complete, This.getAddress(),
+          cast<CXXMemberCallExpr>(CE), CallOrInvoke);
     } else {
       GlobalDecl GD(Dtor, Dtor_Complete);
       CGCallee Callee;
@@ -393,7 +397,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
           IsArrow ? Base->getType()->getPointeeType() : Base->getType();
       EmitCXXDestructorCall(GD, Callee, This.getPointer(*this), ThisTy,
                             /*ImplicitParam=*/nullptr,
-                            /*ImplicitParamTy=*/QualType(), CE);
+                            /*ImplicitParamTy=*/QualType(), CE, CallOrInvoke);
     }
     return RValue::get(nullptr);
   }
@@ -435,12 +439,13 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
 
   return EmitCXXMemberOrOperatorCall(
       CalleeDecl, Callee, ReturnValue, This.getPointer(*this),
-      /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs);
+      /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs, CallOrInvoke);
 }
 
 RValue
 CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
-                                              ReturnValueSlot ReturnValue) {
+                                              ReturnValueSlot ReturnValue,
+                                              llvm::CallBase **CallOrInvoke) {
   const BinaryOperator *BO =
       cast<BinaryOperator>(E->getCallee()->IgnoreParens());
   const Expr *BaseExpr = BO->getLHS();
@@ -484,24 +489,25 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
   EmitCallArgs(Args, FPT, E->arguments());
   return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required,
                                                       /*PrefixSize=*/0),
-                  Callee, ReturnValue, Args, nullptr, E == MustTailCall,
+                  Callee, ReturnValue, Args, CallOrInvoke, E == MustTailCall,
                   E->getExprLoc());
 }
 
-RValue
-CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E,
-                                               const CXXMethodDecl *MD,
-                                               ReturnValueSlot ReturnValue) {
+RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr(
+    const CXXOperatorCallExpr *E, const CXXMethodDecl *MD,
+    ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) {
   assert(MD->isImplicitObjectMemberFunction() &&
          "Trying to emit a member call expr on a static method!");
   return EmitCXXMemberOrOperatorMemberCallExpr(
       E, MD, ReturnValue, /*HasQualifier=*/false, /*Qualifier=*/nullptr,
-      /*IsArrow=*/false, E->getArg(0));
+      /*IsArrow=*/false, E->getArg(0), CallOrInvoke);
 }
 
 RValue CodeGenFunction::EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E,
-                                               ReturnValueSlot ReturnValue) {
-  return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue);
+                                               ReturnValueSlot ReturnValue,
+                                               llvm::CallBase **CallOrInvoke) {
+  return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue,
+                                                     CallOrInvoke);
 }
 
 static void EmitNullBaseClassInitialization(CodeGenFunction &CGF,
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 05f85f8b95bfa2..d116dbc0eff50d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3150,7 +3150,8 @@ class CodeGenFunction : public CodeGenTypeCache {
                               bool ForVirtualBase, bool Delegating,
                               Address This, CallArgList &Args,
                               AggValueSlot::Overlap_t Overlap,
-                              SourceLocation Loc, bool NewPointerIsChecked);
+                              SourceLocation Loc, bool NewPointerIsChecked,
+                              llvm::CallBase **CallOrInvoke = nullptr);
 
   /// Emit assumption load for all bases. Requires to be called only on
   /// most-derived class and not under construction of the object.
@@ -4270,7 +4271,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitBinaryOperatorLValue(const BinaryOperator *E);
   LValue EmitCompoundAssignmentLValue(const CompoundAssignOperator *E);
   // Note: only available for agg return types
-  LValue EmitCallExprLValue(const CallExpr *E);
+  LValue EmitCallExprLValue(const CallExpr *E,
+                            llvm::CallBase **CallOrInvoke = nullptr);
   // Note: only available for agg return types
   LValue EmitVAArgExprLValue(const VAArgExpr *E);
   LValue EmitDeclRefLValue(const DeclRefExpr *E);
@@ -4381,21 +4383,27 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// LLVM arguments and the types they were derived from.
   RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
                   ReturnValueSlot ReturnValue, const CallArgList &Args,
-                  llvm::CallBase **callOrInvoke, bool IsMustTail,
+                  llvm::CallBase **CallOrInvoke, bool IsMustTail,
                   SourceLocation Loc,
                   bool IsVirtualFunctionPointerThunk = false);
   RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
                   ReturnValueSlot ReturnValue, const CallArgList &Args,
-                  llvm::CallBase **callOrInvoke = nullptr,
+                  llvm::CallBase **CallOrInvoke = nullptr,
                   bool IsMustTail = false) {
-    return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke,
+    return EmitCall(CallInfo, Callee, ReturnValue, Args, CallOrInvoke,
                     IsMustTail, SourceLocation());
   }
   RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E,
-                  ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr);
+                  ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr,
+                  llvm::CallBase **CallOrInvoke = nullptr);
+
+  // If a call or invoke instruction is emitted for this CallExpr and
+  // `CallOrInvoke` is non-null, the instruction is stored in `*CallOrInvoke`.
   RValue EmitCallExpr(const CallExpr *E,
-                      ReturnValueSlot ReturnValue = ReturnValueSlot());
-  RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue);
+                      ReturnValueSlot ReturnValue = ReturnValueSlot(),
+                      llvm::CallBase **CallOrInvoke = nullptr);
+  RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue,
+                            llvm::CallBase **CallOrInvoke = nullptr);
   CGCallee EmitCallee(const Expr *E);
 
   void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl);
@@ -4499,25 +4507,23 @@ class CodeGenFunction : public CodeGenTypeCache {
   void callCStructCopyAssignmentOperator(LValue Dst, LValue Src);
   void callCStructMoveAssignmentOperator(LValue Dst, LValue Src);
 
-  RValue
-  EmitCXXMemberOrOperatorCall(const CXXMethodDecl *Method,
-                              const CGCallee &Callee,
-                              ReturnValueSlot ReturnValue, llvm::Value *This,
-                              llvm::Value *ImplicitParam,
-                              QualType ImplicitParamTy, const CallExpr *E,
-                              CallArgList *RtlArgs);
+  RValue EmitCXXMemberOrOperatorCall(
+      const CXXMethodDecl *Method, const CGCallee &Callee,
+      ReturnValueSlot ReturnValue, llvm::Value *This,
+      llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *E,
+      CallArgList *RtlArgs, llvm::CallBase **CallOrInvoke);
   RValue EmitCXXDestructorCall(GlobalDecl Dtor, const CGCallee &Callee,
                                llvm::Value *This, QualType ThisTy,
                                llvm::Value *ImplicitParam,
-                               QualType ImplicitParamTy, const CallExpr *E);
+                               QualType ImplicitParamTy, const CallExpr *E,
+                               llvm::CallBase **CallOrInvoke = nullptr);
   RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E,
-                               ReturnValueSlot ReturnValue);
-  RValue EmitCXXMemberOrOperatorMemberCallExpr(const CallExpr *CE,
-                                               const CXXMethodDecl *MD,
-                                               ReturnValueSlot ReturnValue,
-                                               bool HasQualifier,
-                                               NestedNameSpecifier *Qualifier,
-                                               bool IsArrow, const Expr *Base);
+                               ReturnValueSlot ReturnValue,
+                               llvm::CallBase **CallOrInvoke = nullptr);
+  RValue EmitCXXMemberOrOperatorMemberCallExpr(
+      const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue,
+      bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow,
+      const Expr *Base, llvm::CallBase **CallOrInvoke);
   // Compute the object pointer.
   Address EmitCXXMemberDataPointerAddress(const Expr *E, Address base,
                                           llvm::Value *memberPtr,
@@ -4525,15 +4531,18 @@ class CodeGenFunction : public CodeGenTypeCache {
                                           LValueBaseInfo *BaseInfo = nullptr,
                                           TBAAAccessInfo *TBAAInfo = nullptr);
   RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
-                                      ReturnValueSlot ReturnValue);
+                                      ReturnValueSlot ReturnValue,
+                                      llvm::CallBase **CallOrInvoke);
 
   RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E,
                                        const CXXMethodDecl *MD,
-                                       ReturnValueSlot ReturnValue);
+                                       ReturnValueSlot ReturnValue,
+                                       llvm::CallBase **CallOrInvoke);
   RValue EmitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *E);
 
   RValue EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E,
-                                ReturnValueSlot ReturnValue);
+                                ReturnValueSlot ReturnValue,
+                                llvm::CallBase **CallOrInvoke);
 
   RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E);
   RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E);
@@ -4555,7 +4564,8 @@ class CodeGenFunction : public CodeGenTypeCache {
       const analyze_os_log::OSLogBufferLayout &Layout,
       CharUnits BufferAlignment);
 
-  RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue);
+  RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue,
+                           llvm::CallBase **CallOrInvoke);
 
   /// EmitTargetBuiltinExpr - Emit the given builtin call. Returns 0 if the call
   /// is unhandled by the current target.
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 0cde8a192eda08..28b0880d013e5b 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -315,10 +315,11 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI {
                                      Address This, llvm::Type *Ty,
                                      SourceLocation Loc) override;
 
-  llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
-                                         const CXXDestructorDecl *Dtor,
-                                         CXXDtorType DtorType, Address This,
-                                         DeleteOrMemberCallExpr E) override;
+  llvm::Value *
+  EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
+                            CXXDtorType DtorType, Address This,
+                            DeleteOrMemberCallExpr E,
+                            llvm::CallBase **CallOrInvoke) override;
 
   void emitVirtualInheritanceTables(const CXXRecordDecl *RD) override;
 
@@ -1396,7 +1397,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
   // FIXME: Provide a source location here even though there's no
   // CXXMemberCallExpr for dtor call.
   CXXDtorType DtorType = UseGlobalDelete ? Dtor_Complete : Dtor_Deleting;
-  EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE);
+  EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE,
+                            /*CallOrInvoke=*/nullptr);
 
   if (UseGlobalDelete)
     CGF.PopCleanupBlock();
@@ -2233,7 +2235,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF,
 
 llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall(
     CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType,
-    Address This, DeleteOrMemberCallExpr E) {
+    Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) {
   auto *CE = E.dyn_cast<const CXXMemberCallExpr *>();
   auto *D = E.dyn_cast<const CXXDeleteExpr *>();
   assert((CE != nullptr) ^ (D != nullptr));
@@ -2254,7 +2256,7 @@ llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall(
   }
 
   CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy,
-                            nullptr, QualType(), nullptr);
+                            nullptr, QualType(), nullptr, CallOrInvoke);
   return nullptr;
 }
 
diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
index 76d0191a7e63ad..79dcdc04b0996f 100644
--- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp
+++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
@@ -334,10 +334,11 @@ class MicrosoftCXXABI : public CGCXXABI {
                                      Address This, llvm::Type *Ty,
                                      SourceLocation Loc) override;
 
-  llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
-                                         const CXXDestructorDecl *Dtor,
-                                         CXXDtorType DtorType, Address This,
-                                         DeleteOrMemberCallExpr E) override;
+  llvm::Value *
+  EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
+                            CXXDtorType DtorType, Address This,
+                            DeleteOrMemberCallExpr E,
+                            llvm::CallBase **CallOrInvoke) override;
 
   void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD,
                                         CallArgList &CallArgs) override {
@@ -901,7 +902,8 @@ void MicrosoftCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
   // CXXMemberCallExpr for dtor call.
   bool UseGlobalDelete = DE->isGlobalDelete();
   CXXDtorType DtorType = UseGlobalDelete ? Dtor_Complete : Dtor_Deleting;
-  llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE);
+  llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE,
+                                                  /*CallOrInvoke=*/nullptr);
   if (UseGlobalDelete)
     CGF.EmitDeleteCall(DE->getOperatorDelete(), MDThis, ElementType);
 }
@@ -1685,7 +1687,7 @@ void MicrosoftCXXABI::EmitDestructorCall(CodeGenFunction &CGF,
   CGF.EmitCXXDestructorCall(GD, Callee, CGF.getAsNaturalPointerTo(This, ThisTy),
                             ThisTy,
                             /*ImplicitParam=*/Implicit,
-                            /*ImplicitParamTy=*/QualType(), nullptr);
+                            /*ImplicitParamTy=*/QualType(), /*E=*/nullptr);
   if (BaseDtorEndBB) {
     // Complete object handler should continue to be the remaining
     CGF.Builder.CreateBr(BaseDtorEndBB);
@@ -2001,7 +2003,7 @@ CGCallee MicrosoftCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF,
 
 llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall(
     CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType,
-    Address This, DeleteOrMemberCallExpr E) {
+    Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) {
   auto *CE = E.dyn_cast<const CXXMemberCallExpr *>();
   auto *D = E.dyn_cast<const CXXDeleteExpr *>();
   assert((CE != nullptr) ^ (D != nullptr));
@@ -2031,7 +2033,7 @@ llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall(
   This = adjustThisArgumentForVirtualFunctionCall(CGF, GD, This, true);
   RValue RV =
       CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy,
-                                ImplicitParam, Context.IntTy, CE);
+                                ImplicitParam, Context.IntTy, CE, CallOrInvoke);
   return RV.getScalarVal();
 }
 
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 1bb8955f6f8792..a574d56646f3a2 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -844,6 +844,19 @@ ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) {
   return CoawaitOp;
 }
 
+// True iff QT resolves to a CXXRecordDecl carrying CoroAwaitElidableAttr.
+static bool isAttributedCoroAwaitElidable(const QualType &QT) {
+  const CXXRecordDecl *RD = QT->getAsCXXRecordDecl();
+  return RD != nullptr && RD->hasAttr<CoroAwaitElidableAttr>();
+}
+
+// True iff Operand is a prvalue whose type passes the check above; operands
+// of any other value category are rejected regardless of their type.
+static bool isCoroAwaitElidableCall(Expr *Operand) {
+  return Operand->isPRValue() &&
+         isAttributedCoroAwaitElidable(Operand->getType());
+}
+
 // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
 // DependentCoawaitExpr if needed.
 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
@@ -867,7 +880,16 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
-  auto *Transformed = Operand;
+  bool AwaitElidable =
+      isCoroAwaitElidableCall(Operand) &&
+      isAttributedCoroAwaitElidable(
+          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
+
+  if (AwaitElidable)
+    if (auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit()))
+      Call->setCoroElideSafe();
+
+  Expr *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
     ExprResult R =
         buildPromiseCall(*this, Promise, Loc, "await_transform", Operand);
diff --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h
new file mode 100644
index 00000000000000..43c6d27823bd47
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/Inputs/utility.h
@@ -0,0 +1,13 @@
+// This is a mock file for <utility>
+
+namespace std {
+// Maps T, T & and T && to T, mirroring std::remove_reference.
+template <typename T> struct remove_reference { using type = T; };
+template <typename T> struct remove_reference<T &> { using type = T; };
+template <typename T> struct remove_reference<T &&> { using type = T; };
+// Casts `t` to an rvalue reference, mirroring std::move.
+template <typename T>
+constexpr typename remove_reference<T>::type &&move(T &&t) noexcept {
+  return static_cast<typename remove_reference<T>::type &&>(t);
+}
+} // namespace std
diff --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
new file mode 100644
index 00000000000000..8512995dfad45a
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
@@ -0,0 +1,87 @@
+// This file tests the coro_await_elidable attribute semantics.
+// RUN: %clang_cc1 -triple=x86_64-unknown-linux-gnu -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
+
+#include "Inputs/coroutine.h"
+#include "Inputs/utility.h"
+
+template <typename T>
+struct [[clang::coro_await_elidable]] Task { // task type carrying the attribute under test
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; } // never ready
+
+      template <typename P>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept {
+        if (!coro) // null handle: fall back to a no-op coroutine
+          return std::noop_coroutine();
+        return coro.promise().continuation; // hand over to the stored continuation
+      }
+      void await_resume() noexcept {}
+    };
+
+    Task get_return_object() noexcept { // the handle converts via Task's constructor
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_value(T x) noexcept { // stores the co_return operand
+      value = x;
+    }
+
+    std::coroutine_handle<> continuation; // returned by FinalAwaiter; never assigned in this file
+    T value;
+  };
+
+  Task(std::coroutine_handle<promise_type> handle) : handle(handle) {}
+  ~Task() { // destroys the coroutine when a handle is held
+    if (handle)
+      handle.destroy();
+  }
+
+  struct Awaiter {
+    Awaiter(Task *t) : task(t) {}
+    bool await_ready() const noexcept { return false; }
+    void await_suspend(std::coroutine_handle<void> continuation) noexcept {}
+    T await_resume() noexcept {
+      return task->handle.promise().value; // value written by return_value()
+    }
+
+    Task *task;
+  };
+
+  auto operator co_await() { // awaiting a Task goes through Awaiter
+    return Awaiter{this};
+  }
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+// CHECK-LABEL: define{{.*}} @_Z6calleev{{.*}} {
+Task<int> callee() { // coroutine called by the two test functions below
+  co_return 1;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z8elidablev{{.*}} {
+Task<int> elidable() {
+  // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task
+  // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]]) #[[ELIDE_SAFE:.+]]
+  co_return co_await callee(); // operand is a prvalue of the attributed Task type
+}
+
+// CHECK-LABEL: define{{.*}} @_Z11nonelidablev{{.*}} {
+Task<int> nonelidable() {
+  // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task
+  auto t = callee();
+  // Because we aren't co_awaiting a prvalue, we cannot elide here.
+  // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]])
+  // CHECK-NOT: #[[ELIDE_SAFE]]
+  co_await t;            // lvalue operand
+  co_await std::move(t); // xvalue operand
+
+  co_return 1;
+}
+
+// CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe }
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index 5ebbd29b316bfa..da15c48c9ff2c1 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -59,6 +59,7 @@
 // CHECK-NEXT: ConsumableAutoCast (SubjectMatchRule_record)
 // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record)
 // CHECK-NEXT: Convergent (SubjectMatchRule_function)
+// CHECK-NEXT: CoroAwaitElidable (SubjectMatchRule_record)
 // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function)
 // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record)
 // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record)
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 8a2e6583af87c5..79573e864557d3 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -760,6 +760,7 @@ enum AttributeKindCodes {
   ATTR_KIND_HYBRID_PATCHABLE = 95,
   ATTR_KIND_SANITIZE_REALTIME = 96,
   ATTR_KIND_NO_SANITIZE_REALTIME = 97,
+  ATTR_KIND_CORO_ELIDE_SAFE = 98,
 };
 
 enum ComdatSelectionKindCodes {
diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index 80936c0ee83355..8feaadd4231351 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -348,6 +348,10 @@ def PresplitCoroutine : EnumAttr<"presplitcoroutine", [FnAttr]>;
 /// The coroutine would only be destroyed when it is complete.
 def CoroDestroyOnlyWhenComplete : EnumAttr<"coro_only_destroy_when_complete", [FnAttr]>;
 
+/// The coroutine call meets the elide requirement. Hints the optimization
+/// pipeline to perform elision on the call or invoke instruction.
+def CoroElideSafe : EnumAttr<"coro_elide_safe", [FnAttr]>;
+
 /// Target-independent string attributes.
 def LessPreciseFPMAD : StrBoolAttr<"less-precise-fpmad">;
 def NoInfsFPMath : StrBoolAttr<"no-infs-fp-math">;
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 974a05023c72a5..4912497e20ef61 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2189,6 +2189,8 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) {
     return Attribute::Range;
   case bitc::ATTR_KIND_INITIALIZES:
     return Attribute::Initializes;
+  case bitc::ATTR_KIND_CORO_ELIDE_SAFE:
+    return Attribute::CoroElideSafe;
   }
 }
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 3c5097f4af7c56..1ad44abba7fc79 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -883,6 +883,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_WRITABLE;
   case Attribute::CoroDestroyOnlyWhenComplete:
     return bitc::ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE;
+  case Attribute::CoroElideSafe:
+    return bitc::ATTR_KIND_CORO_ELIDE_SAFE;
   case Attribute::DeadOnUnwind:
     return bitc::ATTR_KIND_DEAD_ON_UNWIND;
   case Attribute::Range:
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index cf00299812bb7f..3ecf7e3f4fa795 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -916,6 +916,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
       case Attribute::Memory:
       case Attribute::NoFPClass:
       case Attribute::CoroDestroyOnlyWhenComplete:
+      case Attribute::CoroElideSafe:
         continue;
       // Those attributes should be safe to propagate to the extracted function.
       case Attribute::AlwaysInline:



More information about the cfe-commits mailing list