[clang] [Clang] Propagate elide safe context through [[clang::coro_must_await]] (PR #108474)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 12 17:35:56 PDT 2024


https://github.com/yuxuanchen1997 created https://github.com/llvm/llvm-project/pull/108474

None

>From a4736c1effa479692157dbe735fa873b233f98bd Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Thu, 12 Sep 2024 17:13:57 -0700
Subject: [PATCH] [Clang] Propagate elide safe context through
 [[clang::coro_must_await]]

---
 clang/include/clang/Basic/Attr.td             |  8 ++++
 clang/include/clang/Basic/AttrDocs.td         | 12 +++++
 clang/lib/Sema/SemaCoroutine.cpp              | 48 ++++++++++++++-----
 .../CodeGenCoroutines/coro-await-elidable.cpp | 19 ++++++++
 4 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 9a7b163b2c6da8..cd29f08a3320e7 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -1258,6 +1258,14 @@ def CoroAwaitElidable : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def CoroMustAwait : InheritableAttr {
+  let Spellings = [Clang<"coro_must_await">];
+  let Subjects = SubjectList<[ParmVar]>;
+  let LangOpts = [CPlusPlus];
+  let Documentation = [CoroMustAwaitDoc];
+  let SimpleHandler = 1;
+}
+
 // OSObject-based attributes.
 def OSConsumed : InheritableParamAttr {
   let Spellings = [Clang<"os_consumed">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 9f72456d2da678..2403cb4d2f994a 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8301,6 +8301,18 @@ callee coroutine.
 }];
 }
 
+def CoroMustAwaitDoc : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+A direct call expression that returns a prvalue of a type attributed with
+[[clang::coro_await_elidable]] is said to be under a SafeElide context
+if one of the following is true:
+- it is the operand of a co_await expression.
+- it is an argument to a [[clang::coro_must_await]] parameter or
+parameter pack of another direct call expression under a SafeElide context.
+}];
+}
+
 def CountedByDocs : Documentation {
   let Category = DocCatField;
   let Content = [{
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index a574d56646f3a2..17ffd24a3eb3e4 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -849,12 +849,40 @@ static bool isAttributedCoroAwaitElidable(const QualType &QT) {
   return Record && Record->hasAttr<CoroAwaitElidableAttr>();
 }
 
-static bool isCoroAwaitElidableCall(Expr *Operand) {
+static void applyAwaitElidableContext(Expr *Operand) {
   if (!Operand->isPRValue()) {
-    return false;
+    return;
   }
 
-  return isAttributedCoroAwaitElidable(Operand->getType());
+  auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit());
+  if (!Call)
+    return;
+
+  if (!isAttributedCoroAwaitElidable(Call->getType()))
+    return;
+
+  Call->setCoroElideSafe();
+
+  // Recurse into arguments bound to [[clang::coro_must_await]] parameters.
+  auto *Fn = llvm::dyn_cast_if_present<FunctionDecl>(Call->getCalleeDecl());
+  if (!Fn)
+    return;
+
+  size_t ParmIdx = 0;
+  for (ParmVarDecl *PD : Fn->parameters()) {
+    if (PD->hasAttr<CoroMustAwaitAttr>()) {
+      if (PD->isParameterPack()) {
+        size_t NumArgs = Call->getNumArgs();
+        for (size_t ArgIdx = ParmIdx; ArgIdx < NumArgs; ArgIdx++) {
+          applyAwaitElidableContext(Call->getArg(ArgIdx));
+        }
+        break;
+      } else {
+        applyAwaitElidableContext(Call->getArg(ParmIdx));
+      }
+    }
+    ParmIdx++;
+  }
 }
 
 // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
@@ -880,14 +908,12 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
-  bool AwaitElidable =
-      isCoroAwaitElidableCall(Operand) &&
-      isAttributedCoroAwaitElidable(
-          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
-
-  if (AwaitElidable)
-    if (auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit()))
-      Call->setCoroElideSafe();
+
+  bool CurFnAwaitElidable = isAttributedCoroAwaitElidable(
+      getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
+
+  if (CurFnAwaitElidable)
+    applyAwaitElidableContext(Operand);
 
   Expr *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
diff --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
index 8512995dfad45a..9e28f0351875cb 100644
--- a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
+++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
@@ -84,4 +84,23 @@ Task<int> nonelidable() {
   co_return 1;
 }
 
+// CHECK-LABEL: define{{.*}} @_Z8addTasks4TaskIiES0_{{.*}} {
+Task<int> addTasks([[clang::coro_must_await]] Task<int> t1, Task<int> t2) {
+  int i1 = co_await t1;
+  int i2 = co_await t2;
+  co_return i1 + i2;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z10returnSamei{{.*}} {
+Task<int> returnSame(int i) {
+  co_return i;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z21elidableWithMustAwaitv{{.*}} {
+Task<int> elidableWithMustAwait() {
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 2) #[[ELIDE_SAFE]]
+  // CHECK-NOT: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 3) #[[ELIDE_SAFE]]
+  co_return co_await addTasks(returnSame(2), returnSame(3));
+}
+
 // CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe }



More information about the cfe-commits mailing list