[clang] [Clang] Propagate elide safe context through [[clang::coro_await_elidable_argument]] (PR #108474)

Yuxuan Chen via cfe-commits cfe-commits at lists.llvm.org
Tue Sep 17 09:49:23 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/108474

From c32b36e249cb1062cc05618181bf4cb4fdcd2133 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Thu, 12 Sep 2024 17:13:57 -0700
Subject: [PATCH] [Clang] Propagate elide safe context through
 [[clang::coro_must_await]]

---
 clang/docs/ReleaseNotes.rst                   |  5 +-
 clang/include/clang/Basic/Attr.td             |  8 ++
 clang/include/clang/Basic/AttrDocs.td         | 83 ++++++++++++++++---
 clang/lib/Sema/SemaCoroutine.cpp              | 40 ++++++---
 .../CodeGenCoroutines/coro-await-elidable.cpp | 40 +++++++++
 ...a-attribute-supported-attributes-list.test |  1 +
 6 files changed, 153 insertions(+), 24 deletions(-)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 8228055a1d861a..16a4ce54dbdcc8 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -249,7 +249,10 @@ Attribute Changes in Clang
   (#GH106864)
 
 - Introduced a new attribute ``[[clang::coro_await_elidable]]`` on coroutine return types
-  to express elideability at call sites where the coroutine is co_awaited as a prvalue.
+  to express elidability at call sites where the coroutine is invoked under a safe elide context.
+
+- Introduced a new attribute ``[[clang::coro_await_elidable_argument]]`` on function parameters
+  to propagate the safe elide context to the matching arguments when the function itself is called under a safe elide context.
 
 Improvements to Clang's diagnostics
 -----------------------------------
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 35b9716e13ff21..ce86116680d7a3 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -1258,6 +1258,14 @@ def CoroAwaitElidable : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def CoroAwaitElidableArgument : InheritableAttr {
+  let Spellings = [Clang<"coro_await_elidable_argument">];
+  let Subjects = SubjectList<[ParmVar]>;
+  let LangOpts = [CPlusPlus];
+  let Documentation = [CoroAwaitElidableArgumentDoc];
+  let SimpleHandler = 1;
+}
+
 // OSObject-based attributes.
 def OSConsumed : InheritableParamAttr {
   let Spellings = [Clang<"os_consumed">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index cc9bc499c9cc24..8ef151b3f2fddb 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8258,15 +8258,23 @@ but do not pass them to the underlying coroutine or pass them by value.
 def CoroAwaitElidableDoc : Documentation {
   let Category = DocCatDecl;
   let Content = [{
-The ``[[clang::coro_await_elidable]]`` is a class attribute which can be applied
-to a coroutine return type.
+The ``[[clang::coro_await_elidable]]`` is a class attribute which can be
+applied to a coroutine return type. It provides a hint to the compiler to apply
+Heap Allocation Elision more aggressively.
 
-When a coroutine function that returns such a type calls another coroutine function,
-the compiler performs heap allocation elision when the call to the coroutine function
-is immediately co_awaited as a prvalue. In this case, the coroutine frame for the
-callee will be a local variable within the enclosing braces in the caller's stack
-frame. And the local variable, like other variables in coroutines, may be collected
-into the coroutine frame, which may be allocated on the heap.
+When a coroutine function returns such a type, a direct call expression therein
+that returns a prvalue of a type attributed ``[[clang::coro_await_elidable]]``
+is said to be under a safe elide context if one of the following is true:
+- it is the immediate right-hand side operand to a co_await expression.
+- it is an argument to a ``[[clang::coro_await_elidable_argument]]`` parameter
+or parameter pack of another direct call expression under a safe elide context.
+
+Do note that the safe elide context applies only to the call expression itself,
+and the context does not transitively include any of its subexpressions unless
+exceptional rules of ``[[clang::coro_await_elidable_argument]]`` apply.
+
+The compiler performs heap allocation elision on call expressions under a safe
+elide context, if the callee is a coroutine.
 
 Example:
 
@@ -8281,8 +8289,63 @@ Example:
     co_await t;
   }
 
-The behavior is undefined if the caller coroutine is destroyed earlier than the
-callee coroutine.
+Such elision replaces the heap allocated activation frame of the callee coroutine
+with a local variable within the enclosing braces in the caller's stack frame.
+The local variable, like other variables in coroutines, may be collected into the
+coroutine frame, which may be allocated on the heap. The behavior is undefined
+if the caller coroutine is destroyed earlier than the callee coroutine.
+
+}];
+}
+
+def CoroAwaitElidableArgumentDoc : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+
+The ``[[clang::coro_await_elidable_argument]]`` is a function parameter attribute.
+It works in conjunction with ``[[clang::coro_await_elidable]]`` to propagate a
+safe elide context to a parameter or parameter pack if the function is called
+under a safe elide context.
+
+This is sometimes necessary on utility functions used to compose or modify the
+behavior of a callee coroutine.
+
+Example:
+
+.. code-block:: c++
+
+  template <typename T>
+  class [[clang::coro_await_elidable]] Task { ... };
+
+  template <typename... T>
+  class [[clang::coro_await_elidable]] WhenAll { ... };
+
+  // `when_all` is a utility function that composes coroutines. It does not
+  // need to be a coroutine to propagate.
+  template <typename... T>
+  WhenAll<T...> when_all([[clang::coro_await_elidable_argument]] Task<T>... tasks);
+
+  Task<int> foo();
+  Task<int> bar();
+  Task<void> example1() {
+    // `when_all`, `foo`, and `bar` are all elide safe because `when_all` is
+    // under a safe elide context and, thanks to the [[clang::coro_await_elidable_argument]]
+    // attribute, such context is propagated to foo and bar.
+    co_await when_all(foo(), bar());
+  }
+
+  Task<void> example2() {
+    // `when_all` and `bar` are elide safe. `foo` is not elide safe.
+    auto f = foo();
+    co_await when_all(f, bar());
+  }
+
+
+  Task<void> example3() {
+    // None of the calls are elide safe.
+    auto t = when_all(foo(), bar());
+    co_await t;
+  }
 
 }];
 }
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index a574d56646f3a2..89a0beadc61f3d 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -849,12 +849,28 @@ static bool isAttributedCoroAwaitElidable(const QualType &QT) {
   return Record && Record->hasAttr<CoroAwaitElidableAttr>();
 }
 
-static bool isCoroAwaitElidableCall(Expr *Operand) {
-  if (!Operand->isPRValue()) {
-    return false;
-  }
+static void applySafeElideContext(Expr *Operand) {
+  auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit());
+  if (!Call || !Call->isPRValue())
+    return;
+
+  if (!isAttributedCoroAwaitElidable(Call->getType()))
+    return;
+
+  Call->setCoroElideSafe();
 
-  return isAttributedCoroAwaitElidable(Operand->getType());
+  // Propagate the context to arguments bound to attributed parameters.
+  auto *Fn = llvm::dyn_cast_if_present<FunctionDecl>(Call->getCalleeDecl());
+  if (!Fn)
+    return;
+
+  size_t ParmIdx = 0;
+  for (ParmVarDecl *PD : Fn->parameters()) {
+    if (PD->hasAttr<CoroAwaitElidableArgumentAttr>())
+      applySafeElideContext(Call->getArg(ParmIdx));
+
+    ParmIdx++;
+  }
 }
 
 // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
@@ -880,14 +896,12 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
-  bool AwaitElidable =
-      isCoroAwaitElidableCall(Operand) &&
-      isAttributedCoroAwaitElidable(
-          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
-
-  if (AwaitElidable)
-    if (auto *Call = dyn_cast<CallExpr>(Operand->IgnoreImplicit()))
-      Call->setCoroElideSafe();
+
+  bool CurFnAwaitElidable = isAttributedCoroAwaitElidable(
+      getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
+
+  if (CurFnAwaitElidable)
+    applySafeElideContext(Operand);
 
   Expr *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
diff --git a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
index 8512995dfad45a..deb19b4a500437 100644
--- a/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
+++ b/clang/test/CodeGenCoroutines/coro-await-elidable.cpp
@@ -84,4 +84,44 @@ Task<int> nonelidable() {
   co_return 1;
 }
 
+// CHECK-LABEL: define{{.*}} @_Z8addTasksO4TaskIiES1_{{.*}} {
+Task<int> addTasks([[clang::coro_await_elidable_argument]] Task<int> &&t1, Task<int> &&t2) {
+  int i1 = co_await t1;
+  int i2 = co_await t2;
+  co_return i1 + i2;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z10returnSamei{{.*}} {
+Task<int> returnSame(int i) {
+  co_return i;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z21elidableWithMustAwaitv{{.*}} {
+Task<int> elidableWithMustAwait() {
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 2) #[[ELIDE_SAFE]]
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 3){{$}}
+  co_return co_await addTasks(returnSame(2), returnSame(3));
+}
+
+template <typename... Args>
+Task<int> sumAll([[clang::coro_await_elidable_argument]] Args && ... tasks);
+
+// CHECK-LABEL: define{{.*}} @_Z16elidableWithPackv{{.*}} {
+Task<int> elidableWithPack() {
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 1){{$}}
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 2) #[[ELIDE_SAFE]]
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 3) #[[ELIDE_SAFE]]
+  auto t = returnSame(1);
+  co_return co_await sumAll(t, returnSame(2), returnSame(3));
+}
+
+
+// CHECK-LABEL: define{{.*}} @_Z25elidableWithPackRecursivev{{.*}} {
+Task<int> elidableWithPackRecursive() {
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 1) #[[ELIDE_SAFE]]
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 2){{$}}
+  // CHECK: call void @_Z10returnSamei(ptr {{.*}}, i32 noundef 3) #[[ELIDE_SAFE]]
+  co_return co_await sumAll(addTasks(returnSame(1), returnSame(2)), returnSame(3));
+}
+
 // CHECK: attributes #[[ELIDE_SAFE]] = { coro_elide_safe }
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index baa1816358b156..914f94c08a9fd9 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -60,6 +60,7 @@
 // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record)
 // CHECK-NEXT: Convergent (SubjectMatchRule_function)
 // CHECK-NEXT: CoroAwaitElidable (SubjectMatchRule_record)
+// CHECK-NEXT: CoroAwaitElidableArgument (SubjectMatchRule_variable_is_parameter)
 // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function)
 // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record)
 // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record)
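
The propagation rule implemented by `applySafeElideContext` in the patch above can be modeled outside of Clang. What follows is a minimal sketch under stated assumptions, not the patch's code: `ToyCall`, its fields, and `modelRecursiveExample` are hypothetical names invented for illustration, and the model ignores everything about real ASTs except the three checks the patch performs (direct prvalue call, attributed return type, attributed parameter).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a call expression. This is not a Clang type.
struct ToyCall {
  bool isPRValueCall = true;          // false would model a named variable such as `t`
  bool returnsElidableType = true;    // return type has [[clang::coro_await_elidable]]
  std::vector<bool> paramAttributed;  // parameter has [[clang::coro_await_elidable_argument]]
  std::vector<ToyCall *> args;        // argument expressions, in parameter order
  bool elideSafe = false;             // models the effect of CallExpr::setCoroElideSafe()
};

// Mirrors the shape of the recursion in the patch: mark the call, then descend
// only into arguments that are bound to attributed parameters.
void applySafeElideContext(ToyCall *call) {
  if (!call || !call->isPRValueCall)
    return;
  if (!call->returnsElidableType)
    return;

  call->elideSafe = true;

  for (std::size_t i = 0;
       i < call->paramAttributed.size() && i < call->args.size(); ++i) {
    if (call->paramAttributed[i])
      applySafeElideContext(call->args[i]);
  }
}

// Models `co_await sumAll(addTasks(returnSame(1), returnSame(2)), returnSame(3))`
// and returns the elide-safe flags of the three returnSame calls.
std::vector<bool> modelRecursiveExample() {
  ToyCall r1, r2, r3, add, sum;

  // addTasks: only the first parameter is attributed.
  add.paramAttributed = {true, false};
  add.args = {&r1, &r2};

  // sumAll: every parameter in the pack is attributed.
  sum.paramAttributed = {true, true};
  sum.args = {&add, &r3};

  // The co_await operand is the root of the safe elide context.
  applySafeElideContext(&sum);

  return {r1.elideSafe, r2.elideSafe, r3.elideSafe};
}
```

In this model the flags for `returnSame(1)`, `returnSame(2)`, and `returnSame(3)` come out as safe, not safe, and safe respectively, which is the same pattern the CHECK lines in `elidableWithPackRecursive` expect.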


