[llvm] e17a39b - [Clang] C++20 Coroutines: Introduce Frontend Attribute [[clang::coro_await_elidable]] (#99282)

via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 8 23:09:03 PDT 2024


Author: Yuxuan Chen
Date: 2024-09-08T23:08:58-07:00
New Revision: e17a39bc314f97231e440c9e68d9f46a9c07af6d

URL: https://github.com/llvm/llvm-project/commit/e17a39bc314f97231e440c9e68d9f46a9c07af6d
DIFF: https://github.com/llvm/llvm-project/commit/e17a39bc314f97231e440c9e68d9f46a9c07af6d.diff

LOG: [Clang] C++20 Coroutines: Introduce Frontend Attribute [[clang::coro_await_elidable]] (#99282)

This patch is the frontend implementation of the coroutine elide
improvement project detailed in this discourse post:
https://discourse.llvm.org/t/language-extension-for-better-more-deterministic-halo-for-c-coroutines/80044

This patch proposes a C++ struct/class attribute
`[[clang::coro_await_elidable]]`. The notion of an await-elidable task
gives developers and library authors certainty that coroutine heap
elision happens in a predictable way.

Originally, after we lower a coroutine to LLVM IR, CoroElide is
responsible for analyzing whether an elision can happen. Take this as
an example:
```
Task foo();
Task bar() {
  co_await foo();
}
```
For CoroElide to happen, the ramp function of `foo` must be inlined into
`bar`. This inlining happens after `foo` has been split but `bar` is
usually still a presplit coroutine. If `foo` is indeed a coroutine, the
inlined `coro.id` intrinsic of `foo` is visible within `bar`. CoroElide
then runs an analysis to figure out whether the SSA value of
`coro.begin()` of `foo` gets destroyed before `bar` terminates.

`Task` types are rarely simple enough for the destroy logic of the task
to reference the SSA value from `coro.begin()` directly. Hence, the pass
is ineffective for even the most trivial C++ `Task` types. Improving
CoroElide with more powerful analyses is possible, but it would not
give us predictability about when elision happens.

The approach we want to take with this language extension originates
from the philosophy that library implementations of `Task` types have
control over the structured concurrency guarantees we demand for
elision to happen. That is, the lifetime of the callee's frame is
shorter than that of the caller's.

``[[clang::coro_await_elidable]]`` is a class attribute which can be
applied to a coroutine return type.

When a coroutine function that returns such a type calls another
coroutine function, the compiler performs heap allocation elision when
the following conditions are all met:
- The callee coroutine function returns a type that is annotated with
``[[clang::coro_await_elidable]]``.
- In the caller coroutine, the return value of the callee is a prvalue
that is immediately `co_await`ed.

From the C++ perspective, this makes sense: the lifetime of the elided
callee frame cannot exceed that of the caller as long as we can
guarantee that the caller coroutine is never destroyed earlier than the
callee coroutine. This does not hold for arbitrary C++ programs.
However, the library that implements `Task` types and executors may
provide this guarantee to the compiler, giving the user certainty that
HALO will work on their programs.

After this patch, when compiling coroutines that return a type with this
attribute, the frontend checks the type of the operand of each
`co_await` expression (not of `operator co_await`). If that type is
also attributed with `[[clang::coro_await_elidable]]`, the FE annotates
the call or invoke instruction (with the `CoroElideSafe` function
attribute in this patch) as a hint for a later middle end pass to
perform the elision.

The original patch version is
https://github.com/llvm/llvm-project/pull/94693 and, as suggested, the
patch has been split into frontend and middle end parts as stacked PRs.

The middle end CoroSplit patch can be found at
https://github.com/llvm/llvm-project/pull/99283
The middle end transformation that performs the elide can be found at
https://github.com/llvm/llvm-project/pull/99285

Added: 
    clang/test/CodeGenCoroutines/Inputs/utility.h
    clang/test/CodeGenCoroutines/coro-await-elidable.cpp

Modified: 
    clang/docs/ReleaseNotes.rst
    clang/include/clang/AST/Expr.h
    clang/include/clang/AST/Stmt.h
    clang/include/clang/Basic/Attr.td
    clang/include/clang/Basic/AttrDocs.td
    clang/lib/AST/Expr.cpp
    clang/lib/CodeGen/CGBlocks.cpp
    clang/lib/CodeGen/CGCUDARuntime.cpp
    clang/lib/CodeGen/CGCUDARuntime.h
    clang/lib/CodeGen/CGCXXABI.h
    clang/lib/CodeGen/CGClass.cpp
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CGExprCXX.cpp
    clang/lib/CodeGen/CodeGenFunction.h
    clang/lib/CodeGen/ItaniumCXXABI.cpp
    clang/lib/CodeGen/MicrosoftCXXABI.cpp
    clang/lib/Sema/SemaCoroutine.cpp
    clang/test/Misc/pragma-attribute-supported-attributes-list.test
    llvm/include/llvm/Bitcode/LLVMBitCodes.h
    llvm/include/llvm/IR/Attributes.td
    llvm/lib/Bitcode/Reader/BitcodeReader.cpp
    llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
    llvm/lib/Transforms/Utils/CodeExtractor.cpp

Removed: 
    


################################################################################
diff  --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index f7c3194c91fa31..a300829cf0e32c 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -246,6 +246,9 @@ Attribute Changes in Clang
   instantiation by accidentally allowing it in C++ in some circumstances.
   (#GH106864)
 
+- Introduced a new attribute ``[[clang::coro_await_elidable]]`` on coroutine return types
+  to express elideability at call sites where the coroutine is co_awaited as a prvalue.
+
 Improvements to Clang's diagnostics
 -----------------------------------
 

diff  --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 65104acda93825..66c746cc25040f 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -2991,6 +2991,9 @@ class CallExpr : public Expr {
 
   bool hasStoredFPFeatures() const { return CallExprBits.HasFPFeatures; }
 
+  bool isCoroElideSafe() const { return CallExprBits.IsCoroElideSafe; }
+  void setCoroElideSafe(bool V = true) { CallExprBits.IsCoroElideSafe = V; }
+
   Decl *getCalleeDecl() { return getCallee()->getReferencedDeclOfCallee(); }
   const Decl *getCalleeDecl() const {
     return getCallee()->getReferencedDeclOfCallee();

diff  --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index f1a2aac0a8b2f8..7aed83e9c68bb7 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -561,8 +561,11 @@ class alignas(void *) Stmt {
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasFPFeatures : 1;
 
+    /// True if the call expression is a must-elide call to a coroutine.
+    unsigned IsCoroElideSafe : 1;
+
     /// Padding used to align OffsetToTrailingObjects to a byte multiple.
-    unsigned : 24 - 3 - NumExprBits;
+    unsigned : 24 - 4 - NumExprBits;
 
     /// The offset in bytes from the this pointer to the start of the
     /// trailing objects belonging to CallExpr. Intentionally byte sized

diff  --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 0c98f8e25a6fbb..9a7b163b2c6da8 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -1250,6 +1250,14 @@ def CoroDisableLifetimeBound : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def CoroAwaitElidable : InheritableAttr {
+  let Spellings = [Clang<"coro_await_elidable">];
+  let Subjects = SubjectList<[CXXRecord]>;
+  let LangOpts = [CPlusPlus];
+  let Documentation = [CoroAwaitElidableDoc];
+  let SimpleHandler = 1;
+}
+
 // OSObject-based attributes.
 def OSConsumed : InheritableParamAttr {
   let Spellings = [Clang<"os_consumed">];

diff  --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index ef077db298831f..546e5100b79dd9 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8255,6 +8255,38 @@ but do not pass them to the underlying coroutine or pass them by value.
 }];
 }
 
+def CoroAwaitElidableDoc : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+The ``[[clang::coro_await_elidable]]`` is a class attribute which can be applied
+to a coroutine return type.
+
+When a coroutine function that returns such a type calls another coroutine function,
+the compiler performs heap allocation elision when the call to the coroutine function
+is immediately co_awaited as a prvalue. In this case, the coroutine frame for the
+callee will be a local variable within the enclosing braces in the caller's stack
+frame. And the local variable, like other variables in coroutines, may be collected
+into the coroutine frame, which may be allocated on the heap.
+
+Example:
+
+.. code-block:: c++
+
+  class [[clang::coro_await_elidable]] Task { ... };
+
+  Task foo();
+  Task bar() {
+    co_await foo(); // foo()'s coroutine frame on this line is elidable
+    auto t = foo(); // foo()'s coroutine frame on this line is NOT elidable
+    co_await t;
+  }
+
+The behavior is undefined if the caller coroutine is destroyed earlier than the
+callee coroutine.
+
+}];
+}
+
 def CountedByDocs : Documentation {
   let Category = DocCatField;
   let Content = [{
@@ -8414,4 +8446,3 @@ Declares that a function potentially allocates heap memory, and prevents any pot
 of ``nonallocating`` by the compiler.
   }];
 }
-

diff  --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 27930db019a172..6545912ed160d9 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1475,6 +1475,7 @@ CallExpr::CallExpr(StmtClass SC, Expr *Fn, ArrayRef<Expr *> PreArgs,
   this->computeDependence();
 
   CallExprBits.HasFPFeatures = FPFeatures.requiresTrailingStorage();
+  CallExprBits.IsCoroElideSafe = false;
   if (hasStoredFPFeatures())
     setStoredFPFeatures(FPFeatures);
 }
@@ -1490,6 +1491,7 @@ CallExpr::CallExpr(StmtClass SC, unsigned NumPreArgs, unsigned NumArgs,
   assert((CallExprBits.OffsetToTrailingObjects == OffsetToTrailingObjects) &&
          "OffsetToTrailingObjects overflow!");
   CallExprBits.HasFPFeatures = HasFPFeatures;
+  CallExprBits.IsCoroElideSafe = false;
 }
 
 CallExpr *CallExpr::Create(const ASTContext &Ctx, Expr *Fn,

diff  --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp
index 066139b1c78c7f..684fda74407313 100644
--- a/clang/lib/CodeGen/CGBlocks.cpp
+++ b/clang/lib/CodeGen/CGBlocks.cpp
@@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() {
 }
 
 RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
-                                          ReturnValueSlot ReturnValue) {
+                                          ReturnValueSlot ReturnValue,
+                                          llvm::CallBase **CallOrInvoke) {
   const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>();
   llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee());
   llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType();
@@ -1220,7 +1221,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
   CGCallee Callee(CGCalleeInfo(), Func);
 
   // And call the block.
-  return EmitCall(FnInfo, Callee, ReturnValue, Args);
+  return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke);
 }
 
 Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) {

diff  --git a/clang/lib/CodeGen/CGCUDARuntime.cpp b/clang/lib/CodeGen/CGCUDARuntime.cpp
index c14a9d3f2bbbcf..1e1da1e2411a76 100644
--- a/clang/lib/CodeGen/CGCUDARuntime.cpp
+++ b/clang/lib/CodeGen/CGCUDARuntime.cpp
@@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {}
 
 RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
                                              const CUDAKernelCallExpr *E,
-                                             ReturnValueSlot ReturnValue) {
+                                             ReturnValueSlot ReturnValue,
+                                             llvm::CallBase **CallOrInvoke) {
   llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok");
   llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end");
 
@@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
 
   eval.begin(CGF);
   CGF.EmitBlock(ConfigOKBlock);
-  CGF.EmitSimpleCallExpr(E, ReturnValue);
+  CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke);
   CGF.EmitBranch(ContBlock);
 
   CGF.EmitBlock(ContBlock);

diff  --git a/clang/lib/CodeGen/CGCUDARuntime.h b/clang/lib/CodeGen/CGCUDARuntime.h
index 8030d632cc3d28..86f776004ee7c1 100644
--- a/clang/lib/CodeGen/CGCUDARuntime.h
+++ b/clang/lib/CodeGen/CGCUDARuntime.h
@@ -21,6 +21,7 @@
 #include "llvm/IR/GlobalValue.h"
 
 namespace llvm {
+class CallBase;
 class Function;
 class GlobalVariable;
 }
@@ -82,9 +83,10 @@ class CGCUDARuntime {
   CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {}
   virtual ~CGCUDARuntime();
 
-  virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
-                                        const CUDAKernelCallExpr *E,
-                                        ReturnValueSlot ReturnValue);
+  virtual RValue
+  EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E,
+                         ReturnValueSlot ReturnValue,
+                         llvm::CallBase **CallOrInvoke = nullptr);
 
   /// Emits a kernel launch stub.
   virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0;

diff  --git a/clang/lib/CodeGen/CGCXXABI.h b/clang/lib/CodeGen/CGCXXABI.h
index 7dcc539111996b..687ff7fb844445 100644
--- a/clang/lib/CodeGen/CGCXXABI.h
+++ b/clang/lib/CodeGen/CGCXXABI.h
@@ -485,11 +485,11 @@ class CGCXXABI {
       llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>;
 
   /// Emit the ABI-specific virtual destructor call.
-  virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
-                                                 const CXXDestructorDecl *Dtor,
-                                                 CXXDtorType DtorType,
-                                                 Address This,
-                                                 DeleteOrMemberCallExpr E) = 0;
+  virtual llvm::Value *
+  EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
+                            CXXDtorType DtorType, Address This,
+                            DeleteOrMemberCallExpr E,
+                            llvm::CallBase **CallOrInvoke) = 0;
 
   virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF,
                                                 GlobalDecl GD,

diff  --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index e5ba50de3462da..352955749a6332 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2192,15 +2192,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF,
   return true;
 }
 
-void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
-                                             CXXCtorType Type,
-                                             bool ForVirtualBase,
-                                             bool Delegating,
-                                             Address This,
-                                             CallArgList &Args,
-                                             AggValueSlot::Overlap_t Overlap,
-                                             SourceLocation Loc,
-                                             bool NewPointerIsChecked) {
+void CodeGenFunction::EmitCXXConstructorCall(
+    const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase,
+    bool Delegating, Address This, CallArgList &Args,
+    AggValueSlot::Overlap_t Overlap, SourceLocation Loc,
+    bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) {
   const CXXRecordDecl *ClassDecl = D->getParent();
 
   if (!NewPointerIsChecked)
@@ -2248,7 +2244,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
   const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall(
       Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs);
   CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type));
-  EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc);
+  EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc);
 
   // Generate vtable assumptions if we're constructing a complete object
   // with a vtable.  We don't do this for base subobjects for two reasons:

diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 99cd61b9e78953..35b5daaf6d4b55 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -33,6 +33,7 @@
 #include "clang/Basic/SourceManager.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Intrinsics.h"
@@ -5544,16 +5545,30 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV,
 //===--------------------------------------------------------------------===//
 
 RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
-                                     ReturnValueSlot ReturnValue) {
+                                     ReturnValueSlot ReturnValue,
+                                     llvm::CallBase **CallOrInvoke) {
+  llvm::CallBase *CallOrInvokeStorage;
+  if (!CallOrInvoke) {
+    CallOrInvoke = &CallOrInvokeStorage;
+  }
+
+  auto AddCoroElideSafeOnExit = llvm::make_scope_exit([&] {
+    if (E->isCoroElideSafe()) {
+      auto *I = *CallOrInvoke;
+      if (I)
+        I->addFnAttr(llvm::Attribute::CoroElideSafe);
+    }
+  });
+
   // Builtins never have block type.
   if (E->getCallee()->getType()->isBlockPointerType())
-    return EmitBlockCallExpr(E, ReturnValue);
+    return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke);
 
   if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
-    return EmitCXXMemberCallExpr(CE, ReturnValue);
+    return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke);
 
   if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E))
-    return EmitCUDAKernelCallExpr(CE, ReturnValue);
+    return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke);
 
   // A CXXOperatorCallExpr is created even for explicit object methods, but
   // these should be treated like static function call.
@@ -5561,7 +5576,7 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
     if (const auto *MD =
             dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl());
         MD && MD->isImplicitObjectMemberFunction())
-      return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue);
+      return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke);
 
   CGCallee callee = EmitCallee(E->getCallee());
 
@@ -5574,14 +5589,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
     return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr());
   }
 
-  return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue);
+  return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue,
+                  /*Chain=*/nullptr, CallOrInvoke);
 }
 
 /// Emit a CallExpr without considering whether it might be a subclass.
 RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E,
-                                           ReturnValueSlot ReturnValue) {
+                                           ReturnValueSlot ReturnValue,
+                                           llvm::CallBase **CallOrInvoke) {
   CGCallee Callee = EmitCallee(E->getCallee());
-  return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue);
+  return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue,
+                  /*Chain=*/nullptr, CallOrInvoke);
 }
 
 // Detect the unusual situation where an inline version is shadowed by a
@@ -5785,8 +5803,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) {
   llvm_unreachable("bad evaluation kind");
 }
 
-LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) {
-  RValue RV = EmitCallExpr(E);
+LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E,
+                                           llvm::CallBase **CallOrInvoke) {
+  RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke);
 
   if (!RV.isScalar())
     return MakeAddrLValue(RV.getAggregateAddress(), E->getType(),
@@ -5909,9 +5928,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) {
                         AlignmentSource::Decl);
 }
 
-RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee,
-                                 const CallExpr *E, ReturnValueSlot ReturnValue,
-                                 llvm::Value *Chain) {
+RValue CodeGenFunction::EmitCall(QualType CalleeType,
+                                 const CGCallee &OrigCallee, const CallExpr *E,
+                                 ReturnValueSlot ReturnValue,
+                                 llvm::Value *Chain,
+                                 llvm::CallBase **CallOrInvoke) {
   // Get the actual function type. The callee type will always be a pointer to
   // function type or a block pointer type.
   assert(CalleeType->isFunctionPointerType() &&
@@ -6131,8 +6152,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
         Address(Handle, Handle->getType(), CGM.getPointerAlign()));
     Callee.setFunctionPointer(Stub);
   }
-  llvm::CallBase *CallOrInvoke = nullptr;
-  RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke,
+  llvm::CallBase *LocalCallOrInvoke = nullptr;
+  RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke,
                          E == MustTailCall, E->getExprLoc());
 
   // Generate function declaration DISuprogram in order to be used
@@ -6141,11 +6162,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
     if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) {
       FunctionArgList Args;
       QualType ResTy = BuildFunctionArgList(CalleeDecl, Args);
-      DI->EmitFuncDeclForCallSite(CallOrInvoke,
+      DI->EmitFuncDeclForCallSite(LocalCallOrInvoke,
                                   DI->getFunctionType(CalleeDecl, ResTy, Args),
                                   CalleeDecl);
     }
   }
+  if (CallOrInvoke)
+    *CallOrInvoke = LocalCallOrInvoke;
 
   return Call;
 }

diff  --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
index 8eb6ab7381acbc..1214bb054fb8df 100644
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -84,23 +84,24 @@ commonEmitCXXMemberOrOperatorCall(CodeGenFunction &CGF, GlobalDecl GD,
 
 RValue CodeGenFunction::EmitCXXMemberOrOperatorCall(
     const CXXMethodDecl *MD, const CGCallee &Callee,
-    ReturnValueSlot ReturnValue,
-    llvm::Value *This, llvm::Value *ImplicitParam, QualType ImplicitParamTy,
-    const CallExpr *CE, CallArgList *RtlArgs) {
+    ReturnValueSlot ReturnValue, llvm::Value *This, llvm::Value *ImplicitParam,
+    QualType ImplicitParamTy, const CallExpr *CE, CallArgList *RtlArgs,
+    llvm::CallBase **CallOrInvoke) {
   const FunctionProtoType *FPT = MD->getType()->castAs<FunctionProtoType>();
   CallArgList Args;
   MemberCallInfo CallInfo = commonEmitCXXMemberOrOperatorCall(
       *this, MD, This, ImplicitParam, ImplicitParamTy, CE, Args, RtlArgs);
   auto &FnInfo = CGM.getTypes().arrangeCXXMethodCall(
       Args, FPT, CallInfo.ReqArgs, CallInfo.PrefixSize);
-  return EmitCall(FnInfo, Callee, ReturnValue, Args, nullptr,
+  return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke,
                   CE && CE == MustTailCall,
                   CE ? CE->getExprLoc() : SourceLocation());
 }
 
 RValue CodeGenFunction::EmitCXXDestructorCall(
     GlobalDecl Dtor, const CGCallee &Callee, llvm::Value *This, QualType ThisTy,
-    llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE) {
+    llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *CE,
+    llvm::CallBase **CallOrInvoke) {
   const CXXMethodDecl *DtorDecl = cast<CXXMethodDecl>(Dtor.getDecl());
 
   assert(!ThisTy.isNull());
@@ -120,7 +121,8 @@ RValue CodeGenFunction::EmitCXXDestructorCall(
   commonEmitCXXMemberOrOperatorCall(*this, Dtor, This, ImplicitParam,
                                     ImplicitParamTy, CE, Args, nullptr);
   return EmitCall(CGM.getTypes().arrangeCXXStructorDeclaration(Dtor), Callee,
-                  ReturnValueSlot(), Args, nullptr, CE && CE == MustTailCall,
+                  ReturnValueSlot(), Args, CallOrInvoke,
+                  CE && CE == MustTailCall,
                   CE ? CE->getExprLoc() : SourceLocation{});
 }
 
@@ -186,11 +188,12 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) {
 // Note: This function also emit constructor calls to support a MSVC
 // extensions allowing explicit constructor function call.
 RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
-                                              ReturnValueSlot ReturnValue) {
+                                              ReturnValueSlot ReturnValue,
+                                              llvm::CallBase **CallOrInvoke) {
   const Expr *callee = CE->getCallee()->IgnoreParens();
 
   if (isa<BinaryOperator>(callee))
-    return EmitCXXMemberPointerCallExpr(CE, ReturnValue);
+    return EmitCXXMemberPointerCallExpr(CE, ReturnValue, CallOrInvoke);
 
   const MemberExpr *ME = cast<MemberExpr>(callee);
   const CXXMethodDecl *MD = cast<CXXMethodDecl>(ME->getMemberDecl());
@@ -200,7 +203,7 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
     CGCallee callee =
         CGCallee::forDirect(CGM.GetAddrOfFunction(MD), GlobalDecl(MD));
     return EmitCall(getContext().getPointerType(MD->getType()), callee, CE,
-                    ReturnValue);
+                    ReturnValue, /*Chain=*/nullptr, CallOrInvoke);
   }
 
   bool HasQualifier = ME->hasQualifier();
@@ -208,14 +211,15 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
   bool IsArrow = ME->isArrow();
   const Expr *Base = ME->getBase();
 
-  return EmitCXXMemberOrOperatorMemberCallExpr(
-      CE, MD, ReturnValue, HasQualifier, Qualifier, IsArrow, Base);
+  return EmitCXXMemberOrOperatorMemberCallExpr(CE, MD, ReturnValue,
+                                               HasQualifier, Qualifier, IsArrow,
+                                               Base, CallOrInvoke);
 }
 
 RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
     const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue,
     bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow,
-    const Expr *Base) {
+    const Expr *Base, llvm::CallBase **CallOrInvoke) {
   assert(isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE));
 
   // Compute the object pointer.
@@ -300,7 +304,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
     EmitCXXConstructorCall(Ctor, Ctor_Complete, /*ForVirtualBase=*/false,
                            /*Delegating=*/false, This.getAddress(), Args,
                            AggValueSlot::DoesNotOverlap, CE->getExprLoc(),
-                           /*NewPointerIsChecked=*/false);
+                           /*NewPointerIsChecked=*/false, CallOrInvoke);
     return RValue::get(nullptr);
   }
 
@@ -374,9 +378,9 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
            "Destructor shouldn't have explicit parameters");
     assert(ReturnValue.isNull() && "Destructor shouldn't have return value");
     if (UseVirtualCall) {
-      CGM.getCXXABI().EmitVirtualDestructorCall(*this, Dtor, Dtor_Complete,
-                                                This.getAddress(),
-                                                cast<CXXMemberCallExpr>(CE));
+      CGM.getCXXABI().EmitVirtualDestructorCall(
+          *this, Dtor, Dtor_Complete, This.getAddress(),
+          cast<CXXMemberCallExpr>(CE), CallOrInvoke);
     } else {
       GlobalDecl GD(Dtor, Dtor_Complete);
       CGCallee Callee;
@@ -393,7 +397,7 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
           IsArrow ? Base->getType()->getPointeeType() : Base->getType();
       EmitCXXDestructorCall(GD, Callee, This.getPointer(*this), ThisTy,
                             /*ImplicitParam=*/nullptr,
-                            /*ImplicitParamTy=*/QualType(), CE);
+                            /*ImplicitParamTy=*/QualType(), CE, CallOrInvoke);
     }
     return RValue::get(nullptr);
   }
@@ -435,12 +439,13 @@ RValue CodeGenFunction::EmitCXXMemberOrOperatorMemberCallExpr(
 
   return EmitCXXMemberOrOperatorCall(
       CalleeDecl, Callee, ReturnValue, This.getPointer(*this),
-      /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs);
+      /*ImplicitParam=*/nullptr, QualType(), CE, RtlArgs, CallOrInvoke);
 }
 
 RValue
 CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
-                                              ReturnValueSlot ReturnValue) {
+                                              ReturnValueSlot ReturnValue,
+                                              llvm::CallBase **CallOrInvoke) {
   const BinaryOperator *BO =
       cast<BinaryOperator>(E->getCallee()->IgnoreParens());
   const Expr *BaseExpr = BO->getLHS();
@@ -484,24 +489,25 @@ CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
   EmitCallArgs(Args, FPT, E->arguments());
   return EmitCall(CGM.getTypes().arrangeCXXMethodCall(Args, FPT, required,
                                                       /*PrefixSize=*/0),
-                  Callee, ReturnValue, Args, nullptr, E == MustTailCall,
+                  Callee, ReturnValue, Args, CallOrInvoke, E == MustTailCall,
                   E->getExprLoc());
 }
 
-RValue
-CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E,
-                                               const CXXMethodDecl *MD,
-                                               ReturnValueSlot ReturnValue) {
+RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr(
+    const CXXOperatorCallExpr *E, const CXXMethodDecl *MD,
+    ReturnValueSlot ReturnValue, llvm::CallBase **CallOrInvoke) {
   assert(MD->isImplicitObjectMemberFunction() &&
          "Trying to emit a member call expr on a static method!");
   return EmitCXXMemberOrOperatorMemberCallExpr(
       E, MD, ReturnValue, /*HasQualifier=*/false, /*Qualifier=*/nullptr,
-      /*IsArrow=*/false, E->getArg(0));
+      /*IsArrow=*/false, E->getArg(0), CallOrInvoke);
 }
 
 RValue CodeGenFunction::EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E,
-                                               ReturnValueSlot ReturnValue) {
-  return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue);
+                                               ReturnValueSlot ReturnValue,
+                                               llvm::CallBase **CallOrInvoke) {
+  return CGM.getCUDARuntime().EmitCUDAKernelCallExpr(*this, E, ReturnValue,
+                                                     CallOrInvoke);
 }
 
 static void EmitNullBaseClassInitialization(CodeGenFunction &CGF,

diff  --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 9b93e9673ec5f8..5892d6ac6f88a5 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3149,7 +3149,8 @@ class CodeGenFunction : public CodeGenTypeCache {
                               bool ForVirtualBase, bool Delegating,
                               Address This, CallArgList &Args,
                               AggValueSlot::Overlap_t Overlap,
-                              SourceLocation Loc, bool NewPointerIsChecked);
+                              SourceLocation Loc, bool NewPointerIsChecked,
+                              llvm::CallBase **CallOrInvoke = nullptr);
 
   /// Emit assumption load for all bases. Requires to be called only on
   /// most-derived class and not under construction of the object.
@@ -4269,7 +4270,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitBinaryOperatorLValue(const BinaryOperator *E);
   LValue EmitCompoundAssignmentLValue(const CompoundAssignOperator *E);
   // Note: only available for agg return types
-  LValue EmitCallExprLValue(const CallExpr *E);
+  LValue EmitCallExprLValue(const CallExpr *E,
+                            llvm::CallBase **CallOrInvoke = nullptr);
   // Note: only available for agg return types
   LValue EmitVAArgExprLValue(const VAArgExpr *E);
   LValue EmitDeclRefLValue(const DeclRefExpr *E);
@@ -4382,21 +4384,27 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// LLVM arguments and the types they were derived from.
   RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
                   ReturnValueSlot ReturnValue, const CallArgList &Args,
-                  llvm::CallBase **callOrInvoke, bool IsMustTail,
+                  llvm::CallBase **CallOrInvoke, bool IsMustTail,
                   SourceLocation Loc,
                   bool IsVirtualFunctionPointerThunk = false);
   RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
                   ReturnValueSlot ReturnValue, const CallArgList &Args,
-                  llvm::CallBase **callOrInvoke = nullptr,
+                  llvm::CallBase **CallOrInvoke = nullptr,
                   bool IsMustTail = false) {
-    return EmitCall(CallInfo, Callee, ReturnValue, Args, callOrInvoke,
+    return EmitCall(CallInfo, Callee, ReturnValue, Args, CallOrInvoke,
                     IsMustTail, SourceLocation());
   }
   RValue EmitCall(QualType FnType, const CGCallee &Callee, const CallExpr *E,
-                  ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr);
+                  ReturnValueSlot ReturnValue, llvm::Value *Chain = nullptr,
+                  llvm::CallBase **CallOrInvoke = nullptr);
+
+  // If a Call or Invoke instruction was emitted for this CallExpr, this method
+  // writes the pointer to `CallOrInvoke` if it's not null.
   RValue EmitCallExpr(const CallExpr *E,
-                      ReturnValueSlot ReturnValue = ReturnValueSlot());
-  RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue);
+                      ReturnValueSlot ReturnValue = ReturnValueSlot(),
+                      llvm::CallBase **CallOrInvoke = nullptr);
+  RValue EmitSimpleCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue,
+                            llvm::CallBase **CallOrInvoke = nullptr);
   CGCallee EmitCallee(const Expr *E);
 
   void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl);
@@ -4500,25 +4508,23 @@ class CodeGenFunction : public CodeGenTypeCache {
   void callCStructCopyAssignmentOperator(LValue Dst, LValue Src);
   void callCStructMoveAssignmentOperator(LValue Dst, LValue Src);
 
-  RValue
-  EmitCXXMemberOrOperatorCall(const CXXMethodDecl *Method,
-                              const CGCallee &Callee,
-                              ReturnValueSlot ReturnValue, llvm::Value *This,
-                              llvm::Value *ImplicitParam,
-                              QualType ImplicitParamTy, const CallExpr *E,
-                              CallArgList *RtlArgs);
+  RValue EmitCXXMemberOrOperatorCall(
+      const CXXMethodDecl *Method, const CGCallee &Callee,
+      ReturnValueSlot ReturnValue, llvm::Value *This,
+      llvm::Value *ImplicitParam, QualType ImplicitParamTy, const CallExpr *E,
+      CallArgList *RtlArgs, llvm::CallBase **CallOrInvoke);
   RValue EmitCXXDestructorCall(GlobalDecl Dtor, const CGCallee &Callee,
                                llvm::Value *This, QualType ThisTy,
                                llvm::Value *ImplicitParam,
-                               QualType ImplicitParamTy, const CallExpr *E);
+                               QualType ImplicitParamTy, const CallExpr *E,
+                               llvm::CallBase **CallOrInvoke = nullptr);
   RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E,
-                               ReturnValueSlot ReturnValue);
-  RValue EmitCXXMemberOrOperatorMemberCallExpr(const CallExpr *CE,
-                                               const CXXMethodDecl *MD,
-                                               ReturnValueSlot ReturnValue,
-                                               bool HasQualifier,
-                                               NestedNameSpecifier *Qualifier,
-                                               bool IsArrow, const Expr *Base);
+                               ReturnValueSlot ReturnValue,
+                               llvm::CallBase **CallOrInvoke = nullptr);
+  RValue EmitCXXMemberOrOperatorMemberCallExpr(
+      const CallExpr *CE, const CXXMethodDecl *MD, ReturnValueSlot ReturnValue,
+      bool HasQualifier, NestedNameSpecifier *Qualifier, bool IsArrow,
+      const Expr *Base, llvm::CallBase **CallOrInvoke);
   // Compute the object pointer.
   Address EmitCXXMemberDataPointerAddress(const Expr *E, Address base,
                                           llvm::Value *memberPtr,
@@ -4526,15 +4532,18 @@ class CodeGenFunction : public CodeGenTypeCache {
                                           LValueBaseInfo *BaseInfo = nullptr,
                                           TBAAAccessInfo *TBAAInfo = nullptr);
   RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E,
-                                      ReturnValueSlot ReturnValue);
+                                      ReturnValueSlot ReturnValue,
+                                      llvm::CallBase **CallOrInvoke);
 
   RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E,
                                        const CXXMethodDecl *MD,
-                                       ReturnValueSlot ReturnValue);
+                                       ReturnValueSlot ReturnValue,
+                                       llvm::CallBase **CallOrInvoke);
   RValue EmitCXXPseudoDestructorExpr(const CXXPseudoDestructorExpr *E);
 
   RValue EmitCUDAKernelCallExpr(const CUDAKernelCallExpr *E,
-                                ReturnValueSlot ReturnValue);
+                                ReturnValueSlot ReturnValue,
+                                llvm::CallBase **CallOrInvoke);
 
   RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E);
   RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E);
@@ -4556,7 +4565,8 @@ class CodeGenFunction : public CodeGenTypeCache {
       const analyze_os_log::OSLogBufferLayout &Layout,
       CharUnits BufferAlignment);
 
-  RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue);
+  RValue EmitBlockCallExpr(const CallExpr *E, ReturnValueSlot ReturnValue,
+                           llvm::CallBase **CallOrInvoke);
 
   /// EmitTargetBuiltinExpr - Emit the given builtin call. Returns 0 if the call
   /// is unhandled by the current target.

diff  --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index fb1eb72d9f3404..dcc35d5689831e 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -315,10 +315,11 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI {
                                      Address This, llvm::Type *Ty,
                                      SourceLocation Loc) override;
 
-  llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
-                                         const CXXDestructorDecl *Dtor,
-                                         CXXDtorType DtorType, Address This,
-                                         DeleteOrMemberCallExpr E) override;
+  llvm::Value *
+  EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
+                            CXXDtorType DtorType, Address This,
+                            DeleteOrMemberCallExpr E,
+                            llvm::CallBase **CallOrInvoke) override;
 
   void emitVirtualInheritanceTables(const CXXRecordDecl *RD) override;
 
@@ -1399,7 +1400,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
   // FIXME: Provide a source location here even though there's no
   // CXXMemberCallExpr for dtor call.
   CXXDtorType DtorType = UseGlobalDelete ? Dtor_Complete : Dtor_Deleting;
-  EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE);
+  EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE,
+                            /*CallOrInvoke=*/nullptr);
 
   if (UseGlobalDelete)
     CGF.PopCleanupBlock();
@@ -2236,7 +2238,7 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF,
 
 llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall(
     CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType,
-    Address This, DeleteOrMemberCallExpr E) {
+    Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) {
   auto *CE = E.dyn_cast<const CXXMemberCallExpr *>();
   auto *D = E.dyn_cast<const CXXDeleteExpr *>();
   assert((CE != nullptr) ^ (D != nullptr));
@@ -2257,7 +2259,7 @@ llvm::Value *ItaniumCXXABI::EmitVirtualDestructorCall(
   }
 
   CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy,
-                            nullptr, QualType(), nullptr);
+                            nullptr, QualType(), nullptr, CallOrInvoke);
   return nullptr;
 }
 

diff  --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
index 76d0191a7e63ad..79dcdc04b0996f 100644
--- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp
+++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
@@ -334,10 +334,11 @@ class MicrosoftCXXABI : public CGCXXABI {
                                      Address This, llvm::Type *Ty,
                                      SourceLocation Loc) override;
 
-  llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
-                                         const CXXDestructorDecl *Dtor,
-                                         CXXDtorType DtorType, Address This,
-                                         DeleteOrMemberCallExpr E) override;
+  llvm::Value *
+  EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
+                            CXXDtorType DtorType, Address This,
+                            DeleteOrMemberCallExpr E,
+                            llvm::CallBase **CallOrInvoke) override;
 
   void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF, GlobalDecl GD,
                                         CallArgList &CallArgs) override {
@@ -901,7 +902,8 @@ void MicrosoftCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
   // CXXMemberCallExpr for dtor call.
   bool UseGlobalDelete = DE->isGlobalDelete();
   CXXDtorType DtorType = UseGlobalDelete ? Dtor_Complete : Dtor_Deleting;
-  llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE);
+  llvm::Value *MDThis = EmitVirtualDestructorCall(CGF, Dtor, DtorType, Ptr, DE,
+                                                  /*CallOrInvoke=*/nullptr);
   if (UseGlobalDelete)
     CGF.EmitDeleteCall(DE->getOperatorDelete(), MDThis, ElementType);
 }
@@ -1685,7 +1687,7 @@ void MicrosoftCXXABI::EmitDestructorCall(CodeGenFunction &CGF,
   CGF.EmitCXXDestructorCall(GD, Callee, CGF.getAsNaturalPointerTo(This, ThisTy),
                             ThisTy,
                             /*ImplicitParam=*/Implicit,
-                            /*ImplicitParamTy=*/QualType(), nullptr);
+                            /*ImplicitParamTy=*/QualType(), /*E=*/nullptr);
   if (BaseDtorEndBB) {
     // Complete object handler should continue to be the remaining
     CGF.Builder.CreateBr(BaseDtorEndBB);
@@ -2001,7 +2003,7 @@ CGCallee MicrosoftCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF,
 
 llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall(
     CodeGenFunction &CGF, const CXXDestructorDecl *Dtor, CXXDtorType DtorType,
-    Address This, DeleteOrMemberCallExpr E) {
+    Address This, DeleteOrMemberCallExpr E, llvm::CallBase **CallOrInvoke) {
   auto *CE = E.dyn_cast<const CXXMemberCallExpr *>();
   auto *D = E.dyn_cast<const CXXDeleteExpr *>();
   assert((CE != nullptr) ^ (D != nullptr));
@@ -2031,7 +2033,7 @@ llvm::Value *MicrosoftCXXABI::EmitVirtualDestructorCall(
   This = adjustThisArgumentForVirtualFunctionCall(CGF, GD, This, true);
   RValue RV =
       CGF.EmitCXXDestructorCall(GD, Callee, This.emitRawPointer(CGF), ThisTy,
-                                ImplicitParam, Context.IntTy, CE);
+                                ImplicitParam, Context.IntTy, CE, CallOrInvoke);
   return RV.getScalarVal();
 }
 

diff  --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 1bb8955f6f8792..a574d56646f3a2 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -844,6 +844,19 @@ ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) {
   return CoawaitOp;
 }
 
+static bool isAttributedCoroAwaitElidable(const QualType &QT) {
+  auto *Record = QT->getAsCXXRecordDecl();
+  return Record && Record->hasAttr<CoroAwaitElidableAttr>();
+}
+
+static bool isCoroAwaitElidableCall(Expr *Operand) {
+  if (!Operand->isPRValue()) {
+    return false;
+  }
+
+  return isAttributedCoroAwaitElidable(Operand->getType());
+}
+
 // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
 // DependentCoawaitExpr if needed.
 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
@@ -867,7 +880,16 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
-  auto *Transformed = Operand;
+  bool AwaitElidable =
+      isCoroAwaitElidableCall(Operand) &&
+      isAttributedCoroAwaitElidable(
+          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
+
+  if (AwaitElidable)
+    if (auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit()))
+      Call->setCoroElideSafe();
+
+  Expr *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
     ExprResult R =
         buildPromiseCall(*this, Promise, Loc, "await_transform", Operand);

diff  --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h
new file mode 100644
index 00000000000000..43c6d27823bd47
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/Inputs/utility.h
@@ -0,0 +1,13 @@
+// This is a mock file for <utility>
+
+namespace std {
+
+template <typename T> struct remove_reference { using type = T; };
+template <typename T> struct remove_reference<T &> { using type = T; };
+template <typename T> struct remove_reference<T &&> { using type = T; };
+
+template <typename T>
+constexpr typename std::remove_reference<T>::type&& move(T &&t) noexcept {
+  return static_cast<typename std::remove_reference<T>::type &&>(t);
+}
+}

diff  --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
new file mode 100644
index 00000000000000..8512995dfad45a
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
@@ -0,0 +1,87 @@
+// This file tests the coro_await_elidable attribute semantics.
+// RUN: %clang_cc1 -triple=x86_64-unknown-linux-gnu -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
+
+#include "Inputs/coroutine.h"
+#include "Inputs/utility.h"
+
+template <typename T>
+struct [[clang::coro_await_elidable]] Task {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+
+      template <typename P>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept {
+        if (!coro)
+          return std::noop_coroutine();
+        return coro.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+
+    Task get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_value(T x) noexcept {
+      value = x;
+    }
+
+    std::coroutine_handle<> continuation;
+    T value;
+  };
+
+  Task(std::coroutine_handle<promise_type> handle) : handle(handle) {}
+  ~Task() {
+    if (handle)
+      handle.destroy();
+  }
+
+  struct Awaiter {
+    Awaiter(Task *t) : task(t) {}
+    bool await_ready() const noexcept { return false; }
+    void await_suspend(std::coroutine_handle<void> continuation) noexcept {}
+    T await_resume() noexcept {
+      return task->handle.promise().value;
+    }
+
+    Task *task;
+  };
+
+  auto operator co_await() {
+    return Awaiter{this};
+  }
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+// CHECK-LABEL: define{{.*}} @_Z6calleev{{.*}} {
+Task<int> callee() {
+  co_return 1;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z8elidablev{{.*}} {
+Task<int> elidable() {
+  // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task
+  // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]]) #[[ELIDE_SAFE:.+]]
+  co_return co_await callee();
+}
+
+// CHECK-LABEL: define{{.*}} @_Z11nonelidablev{{.*}} {
+Task<int> nonelidable() {
+  // CHECK: %[[TASK_OBJ:.+]] = alloca %struct.Task
+  auto t = callee();
+  // Because we aren't co_awaiting a prvalue, we cannot elide here.
+  // CHECK: call void @_Z6calleev(ptr dead_on_unwind writable sret(%struct.Task) align 8 %[[TASK_OBJ]])
+  // CHECK-NOT: #[[ELIDE_SAFE]]
+  co_await t;
+  co_await std::move(t);
+
+  co_return 1;
+}
+
+// CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe }
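
Note on the test above: `isCoroAwaitElidableCall` in SemaCoroutine.cpp only accepts a prvalue operand, which is why `co_await callee()` is marked while `co_await t` and `co_await std::move(t)` are not. The following is a minimal standalone sketch of those three value categories, not code from this patch. `MockTask`, `categoryOf`, and `CATEGORY_OF` are hypothetical names introduced for illustration only, and the sketch uses plain expressions rather than coroutines.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-in for an attributed task type (not the Task<T> above).
struct MockTask {};
MockTask callee(); // declaration only; it is used in unevaluated operands

enum class Category { PRValue, XValue, LValue };

// decltype((expr)) yields T for a prvalue, T&& for an xvalue, T& for an lvalue.
template <typename T> constexpr Category categoryOf() {
  if constexpr (std::is_lvalue_reference_v<T>)
    return Category::LValue;
  else if constexpr (std::is_rvalue_reference_v<T>)
    return Category::XValue;
  else
    return Category::PRValue;
}

#define CATEGORY_OF(expr) categoryOf<decltype((expr))>()

// Mirrors `co_await callee()`: the operand is a direct call result.
constexpr Category directCall() { return CATEGORY_OF(callee()); }

// Mirrors `co_await t`: the operand names a variable.
constexpr Category namedVariable() {
  [[maybe_unused]] MockTask t{};
  return CATEGORY_OF(t);
}

// Mirrors `co_await std::move(t)`: the operand is an xvalue.
constexpr Category movedVariable() {
  [[maybe_unused]] MockTask t{};
  return CATEGORY_OF(std::move(t));
}

static_assert(directCall() == Category::PRValue);
static_assert(namedVariable() == Category::LValue);
static_assert(movedVariable() == Category::XValue);
```

Only the first case passes an `isPRValue()`-style check, which matches the `CHECK-NOT: #[[ELIDE_SAFE]]` expectation in `nonelidable()`.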

diff  --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index eca86331149028..baa1816358b156 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -59,6 +59,7 @@
 // CHECK-NEXT: ConsumableAutoCast (SubjectMatchRule_record)
 // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record)
 // CHECK-NEXT: Convergent (SubjectMatchRule_function)
+// CHECK-NEXT: CoroAwaitElidable (SubjectMatchRule_record)
 // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function)
 // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record)
 // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record)

diff  --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 49a48f1c1510c3..05ed148148d7cb 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -761,6 +761,8 @@ enum AttributeKindCodes {
   ATTR_KIND_INITIALIZES = 94,
   ATTR_KIND_HYBRID_PATCHABLE = 95,
   ATTR_KIND_SANITIZE_REALTIME = 96,
+  ATTR_KIND_NO_SANITIZE_REALTIME = 97,
+  ATTR_KIND_CORO_ELIDE_SAFE = 98,
 };
 
 enum ComdatSelectionKindCodes {

diff  --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index 891e34fec0c798..f3ef1e707675eb 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -345,6 +345,10 @@ def PresplitCoroutine : EnumAttr<"presplitcoroutine", [FnAttr]>;
 /// The coroutine would only be destroyed when it is complete.
 def CoroDestroyOnlyWhenComplete : EnumAttr<"coro_only_destroy_when_complete", [FnAttr]>;
 
+/// The coroutine call meets the elide requirement. Hint the optimization
+/// pipeline to perform elide on the call or invoke instruction.
+def CoroElideSafe : EnumAttr<"coro_elide_safe", [FnAttr]>;
+
 /// Target-independent string attributes.
 def LessPreciseFPMAD : StrBoolAttr<"less-precise-fpmad">;
 def NoInfsFPMath : StrBoolAttr<"no-infs-fp-math">;

diff  --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 4faee1fb9ada22..6b8edbca19a015 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -2190,6 +2190,8 @@ static Attribute::AttrKind getAttrFromCode(uint64_t Code) {
     return Attribute::Range;
   case bitc::ATTR_KIND_INITIALIZES:
     return Attribute::Initializes;
+  case bitc::ATTR_KIND_CORO_ELIDE_SAFE:
+    return Attribute::CoroElideSafe;
   }
 }
 

diff  --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 1aeaf0955fbbf3..a5942153dc2d64 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -885,6 +885,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_WRITABLE;
   case Attribute::CoroDestroyOnlyWhenComplete:
     return bitc::ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE;
+  case Attribute::CoroElideSafe:
+    return bitc::ATTR_KIND_CORO_ELIDE_SAFE;
   case Attribute::DeadOnUnwind:
     return bitc::ATTR_KIND_DEAD_ON_UNWIND;
   case Attribute::Range:

diff  --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index d378c6c3a4b01c..895b588a9e5ac3 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -916,6 +916,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
       case Attribute::Memory:
       case Attribute::NoFPClass:
       case Attribute::CoroDestroyOnlyWhenComplete:
+      case Attribute::CoroElideSafe:
         continue;
       // Those attributes should be safe to propagate to the extracted function.
       case Attribute::AlwaysInline:
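
Note on the CodeGen changes above: most of them thread an optional `llvm::CallBase **CallOrInvoke` out-parameter through the call-emission helpers, so that an outer caller can get hold of the instruction that was actually emitted. The following is a minimal standalone sketch of that pattern, not code from this patch. `MockCall` and `MockEmitter` are hypothetical stand-ins for `llvm::CallBase` and `CodeGenFunction`, and what the outermost layer does with the pointer (setting a flag) is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for llvm::CallBase.
struct MockCall {
  std::string Callee;
  bool CoroElideSafe = false;
};

// Hypothetical stand-in for CodeGenFunction; it owns the emitted calls.
class MockEmitter {
  std::vector<std::unique_ptr<MockCall>> Insts;

public:
  // Innermost layer: creates the call and, if the out-parameter is
  // non-null, writes the pointer to the new instruction through it.
  void emitCall(const std::string &Callee, MockCall **CallOrInvoke = nullptr) {
    Insts.push_back(std::make_unique<MockCall>(MockCall{Callee}));
    if (CallOrInvoke)
      *CallOrInvoke = Insts.back().get();
  }

  // Middle layer: does no work with the pointer and only forwards it.
  void emitSimpleCallExpr(const std::string &Callee,
                          MockCall **CallOrInvoke = nullptr) {
    emitCall(Callee, CallOrInvoke);
  }

  // Outermost layer: asks for the instruction and annotates it afterwards.
  void emitCallExpr(const std::string &Callee, bool IsCoroElideSafe) {
    MockCall *Call = nullptr;
    emitSimpleCallExpr(Callee, &Call);
    if (IsCoroElideSafe && Call)
      Call->CoroElideSafe = true;
  }

  const MockCall &last() const { return *Insts.back(); }
  std::size_t size() const { return Insts.size(); }
};

// Usage: only the call flagged by the caller ends up annotated.
inline void demo() {
  MockEmitter E;
  E.emitCallExpr("callee", /*IsCoroElideSafe=*/true);
  assert(E.last().CoroElideSafe);
  E.emitCallExpr("other", /*IsCoroElideSafe=*/false);
  assert(!E.last().CoroElideSafe);
}
```

Defaulting the parameter to `nullptr`, as several of the declarations in CodeGenFunction.h do, lets existing callers that do not need the instruction compile unchanged.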


        

