[clang] [llvm] [Clang][Coroutines] Introducing the `[[clang::coro_inplace_task]]` attribute (PR #94693)

Yuxuan Chen via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 17 22:57:28 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/94693

>From 093cd09a5b479deaabd3013be1fd6849f6c174d6 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <yuxuanchen1997 at outlook.com>
Date: Tue, 4 Jun 2024 23:22:00 -0700
Subject: [PATCH 1/2] [Clang] Introduce [[clang::structured_concurrency]]

---
 clang/include/clang/AST/ExprCXX.h             |  5 ++
 clang/include/clang/AST/Stmt.h                |  1 +
 clang/include/clang/Basic/Attr.td             |  8 ++
 clang/include/clang/Basic/AttrDocs.td         | 20 +++++
 clang/lib/CodeGen/CGExpr.cpp                  |  6 +-
 clang/lib/CodeGen/CodeGenFunction.h           |  3 +
 clang/lib/Sema/SemaCoroutine.cpp              | 22 ++++-
 clang/test/CodeGenCoroutines/Inputs/utility.h | 13 +++
 .../coro-structured-concurrency.cpp           | 84 +++++++++++++++++++
 ...a-attribute-supported-attributes-list.test |  1 +
 llvm/include/llvm/IR/Intrinsics.td            |  3 +
 .../lib/Transforms/Coroutines/CoroCleanup.cpp | 11 ++-
 llvm/lib/Transforms/Coroutines/CoroElide.cpp  | 56 ++++++++++++-
 llvm/lib/Transforms/Coroutines/Coroutines.cpp |  1 +
 .../coro-elide-structured-concurrency.ll      | 62 ++++++++++++++
 15 files changed, 288 insertions(+), 8 deletions(-)
 create mode 100644 clang/test/CodeGenCoroutines/Inputs/utility.h
 create mode 100644 clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp
 create mode 100644 llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll

diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h
index c2feac525c1ea..5e35099db9c84 100644
--- a/clang/include/clang/AST/ExprCXX.h
+++ b/clang/include/clang/AST/ExprCXX.h
@@ -5213,6 +5213,11 @@ class CoawaitExpr : public CoroutineSuspendExpr {
   bool isImplicit() const { return CoawaitBits.IsImplicit; }
   void setIsImplicit(bool value = true) { CoawaitBits.IsImplicit = value; }
 
+  bool isInplaceCall() const { return CoawaitBits.IsInplaceCall; }
+  void setIsInplaceCall(bool value = true) {
+    CoawaitBits.IsInplaceCall = value;
+  }
+
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == CoawaitExprClass;
   }
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index 9cd7a364cd3f1..81f67fd266b7e 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -1163,6 +1163,7 @@ class alignas(void *) Stmt {
 
     LLVM_PREFERRED_TYPE(bool)
     unsigned IsImplicit : 1;
+    unsigned IsInplaceCall : 1;
   };
 
   //===--- Obj-C Expression bitfields classes ---===//
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index b70b0c8b836a5..7c291978a27ed 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -1212,6 +1212,14 @@ def CoroDisableLifetimeBound : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def CoroInplaceTask : InheritableAttr {
+  let Spellings = [Clang<"coro_inplace_task">];
+  let Subjects = SubjectList<[CXXRecord]>;
+  let LangOpts = [CPlusPlus];
+  let Documentation = [CoroInplaceTaskDoc];
+  let SimpleHandler = 1;
+}
+
 // OSObject-based attributes.
 def OSConsumed : InheritableParamAttr {
   let Spellings = [Clang<"os_consumed">];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 70d5dfa8aaf86..61253cfb4af92 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8015,6 +8015,26 @@ but do not pass them to the underlying coroutine or pass them by value.
 }];
 }
 
+def CoroInplaceTaskDoc : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+``[[clang::coro_inplace_task]]`` is a class attribute which can be applied to
+a coroutine return type.
+
+When a coroutine function that returns such a type calls another coroutine function,
+the compiler performs heap allocation elision when all of the following conditions are met:
+
+- The callee coroutine function returns a type annotated with ``[[clang::coro_inplace_task]]``.
+- The callee coroutine function is inlined.
+- In the caller coroutine, the return value of the callee is a prvalue, and
+- The temporary expression containing the callee coroutine object is immediately co_awaited.
+
+The behavior is undefined if the caller coroutine is destroyed earlier than the
+callee coroutine.
+
+  }];
+}
+
 def CountedByDocs : Documentation {
   let Category = DocCatField;
   let Content = [{
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index b6718a46e8c50..43b4936e5180a 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -618,7 +618,11 @@ EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *M) {
     }
   }
 
-  return MakeAddrLValue(Object, M->getType(), AlignmentSource::Decl);
+  auto Ret = MakeAddrLValue(Object, M->getType(), AlignmentSource::Decl);
+  if (TemporaryValues.contains(M)) {
+    TemporaryValues[M] = Ret.getPointer(*this);
+  }
+  return Ret;
 }
 
 RValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 8525f66082a4e..6dd3da8aacaeb 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -369,6 +369,9 @@ class CodeGenFunction : public CodeGenTypeCache {
   };
   CGCoroInfo CurCoro;
 
+  llvm::SmallDenseMap<const MaterializeTemporaryExpr *, llvm::Value *>
+      TemporaryValues;
+
   bool isCoroutine() const {
     return CurCoro.Data != nullptr;
   }
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 81334c817b2af..2dbce5e54c76b 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -825,6 +825,19 @@ ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) {
   return CoawaitOp;
 }
 
+static bool isAttributedCoroInplaceTask(const QualType &QT) {
+  auto *Record = QT->getAsCXXRecordDecl();
+  return Record && Record->hasAttr<CoroInplaceTaskAttr>();
+}
+
+static bool isCoroInplaceCall(Expr *Operand) {
+  if (!Operand->isPRValue()) {
+    return false;
+  }
+
+  return isAttributedCoroInplaceTask(Operand->getType());
+}
+
 // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
 // DependentCoawaitExpr if needed.
 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
@@ -864,7 +877,14 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
   if (Awaiter.isInvalid())
     return ExprError();
 
-  return BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get());
+  auto Res = BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get());
+  if (!Res.isInvalid() && isCoroInplaceCall(Operand) &&
+      isAttributedCoroInplaceTask(
+          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType())) {
+    // BuildResolvedCoawaitExpr must return a CoawaitExpr, if valid.
+    Res.getAs<CoawaitExpr>()->setIsInplaceCall();
+  }
+  return Res;
 }
 
 ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
diff --git a/clang/test/CodeGenCoroutines/Inputs/utility.h b/clang/test/CodeGenCoroutines/Inputs/utility.h
new file mode 100644
index 0000000000000..43c6d27823bd4
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/Inputs/utility.h
@@ -0,0 +1,13 @@
+// This is a mock file for <utility>
+
+namespace std {
+
+template <typename T> struct remove_reference { using type = T; };
+template <typename T> struct remove_reference<T &> { using type = T; };
+template <typename T> struct remove_reference<T &&> { using type = T; };
+
+template <typename T>
+constexpr typename std::remove_reference<T>::type&& move(T &&t) noexcept {
+  return static_cast<typename std::remove_reference<T>::type &&>(t);
+}
+}
diff --git a/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp b/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp
new file mode 100644
index 0000000000000..2569643221da0
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/coro-structured-concurrency.cpp
@@ -0,0 +1,84 @@
+// This file tests the coro_inplace_task attribute semantics.
+// RUN: %clang_cc1 -std=c++20 -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
+
+#include "Inputs/coroutine.h"
+#include "Inputs/utility.h"
+
+template <typename T>
+struct [[clang::coro_inplace_task]] Task {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+
+      template <typename P>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept {
+        if (!coro)
+          return std::noop_coroutine();
+        return coro.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+
+    Task get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_value(T x) noexcept {
+      value = x;
+    }
+
+    std::coroutine_handle<> continuation;
+    T value;
+  };
+
+  Task(std::coroutine_handle<promise_type> handle) : handle(handle) {}
+  ~Task() {
+    if (handle)
+      handle.destroy();
+  }
+
+  struct Awaiter {
+    Awaiter(Task *t) : task(t) {}
+    bool await_ready() const noexcept { return false; }
+    void await_suspend(std::coroutine_handle<void> continuation) noexcept {}
+    T await_resume() noexcept {
+      return task->handle.promise().value;
+    }
+
+    Task *task;
+  };
+
+  auto operator co_await() {
+    return Awaiter{this};
+  }
+
+private:
+  std::coroutine_handle<promise_type> handle;
+};
+
+// CHECK-LABEL: define{{.*}} @_Z6calleev 
+Task<int> callee() {
+  co_return 1;
+}
+
+// CHECK-LABEL: define{{.*}} @_Z8elidablev 
+Task<int> elidable() {
+  // CHECK: %[[TARK_OBJ:.+]] = alloca %struct.Task
+  // CHECK: call void @llvm.coro.safe.elide(ptr %[[TARK_OBJ:.+]])
+  co_return co_await callee();
+}
+
+// CHECK-LABEL: define{{.*}} @_Z11nonelidablev 
+Task<int> nonelidable() {
+  // CHECK: %[[TARK_OBJ:.+]] = alloca %struct.Task
+  auto t = callee();
+  // Because we aren't co_awaiting a prvalue, we cannot elide here.
+  // CHECK-NOT: call void @llvm.coro.safe.elide(ptr %[[TARK_OBJ:.+]])
+  co_await t;
+  co_await std::move(t);
+  
+  co_return 1;
+}
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index 99732694f72a5..c37e0ac9fec46 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -62,6 +62,7 @@
 // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record)
 // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record)
 // CHECK-NEXT: CoroReturnType (SubjectMatchRule_record)
+// CHECK-NEXT: CoroInplaceTask (SubjectMatchRule_record)
 // CHECK-NEXT: CoroWrapper (SubjectMatchRule_function)
 // CHECK-NEXT: DLLExport (SubjectMatchRule_function, SubjectMatchRule_variable, SubjectMatchRule_record, SubjectMatchRule_objc_interface)
 // CHECK-NEXT: DLLImport (SubjectMatchRule_function, SubjectMatchRule_variable, SubjectMatchRule_record, SubjectMatchRule_objc_interface)
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index ef500329d1fb9..7b17f3061269c 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1728,6 +1728,9 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic<
     [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
      NoCapture<ArgIndex<0>>]>;
 
+def int_coro_safe_elide : DefaultAttrsIntrinsic<
+    [], [llvm_ptr_ty], []>;
+
 ///===-------------------------- Other Intrinsics --------------------------===//
 //
 // TODO: We should introduce a new memory kind fo traps (and other side effects
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
index 3e3825fcd50e2..71229eae5cb47 100644
--- a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -8,10 +8,11 @@
 
 #include "llvm/Transforms/Coroutines/CoroCleanup.h"
 #include "CoroInternal.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/IR/Function.h"
 #include "llvm/Transforms/Scalar/SimplifyCFG.h"
 
 using namespace llvm;
@@ -80,7 +81,7 @@ bool Lowerer::lower(Function &F) {
         } else
           continue;
         break;
-      case Intrinsic::coro_async_size_replace:
+      case Intrinsic::coro_async_size_replace: {
         auto *Target = cast<ConstantStruct>(
             cast<GlobalVariable>(II->getArgOperand(0)->stripPointerCasts())
                 ->getInitializer());
@@ -98,6 +99,9 @@ bool Lowerer::lower(Function &F) {
         Target->replaceAllUsesWith(NewFuncPtrStruct);
         break;
       }
+      case Intrinsic::coro_safe_elide:
+        break;
+      }
       II->eraseFromParent();
       Changed = true;
     }
@@ -111,7 +115,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) {
       M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr",
           "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon",
           "llvm.coro.id.async", "llvm.coro.id.retcon.once",
-          "llvm.coro.async.size.replace", "llvm.coro.async.resume"});
+          "llvm.coro.async.size.replace", "llvm.coro.async.resume",
+          "llvm.coro.safe.elide"});
 }
 
 PreservedAnalyses CoroCleanupPass::run(Module &M,
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index 74b5ccb7b9b71..dd2f72410c931 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -7,12 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Coroutines/CoroElide.h"
+#include "CoroInstr.h"
 #include "CoroInternal.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -56,7 +58,8 @@ class FunctionElideInfo {
 class CoroIdElider {
 public:
   CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA,
-               DominatorTree &DT, OptimizationRemarkEmitter &ORE);
+               DominatorTree &DT, PostDominatorTree &PDT,
+               OptimizationRemarkEmitter &ORE);
   void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign);
   bool lifetimeEligibleForElide() const;
   bool attemptElide();
@@ -68,6 +71,7 @@ class CoroIdElider {
   FunctionElideInfo &FEI;
   AAResults &AA;
   DominatorTree &DT;
+  PostDominatorTree &PDT;
   OptimizationRemarkEmitter &ORE;
 
   SmallVector<CoroBeginInst *, 1> CoroBegins;
@@ -183,8 +187,9 @@ void FunctionElideInfo::collectPostSplitCoroIds() {
 
 CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI,
                            AAResults &AA, DominatorTree &DT,
+                           PostDominatorTree &PDT,
                            OptimizationRemarkEmitter &ORE)
-    : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), ORE(ORE) {
+    : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), PDT(PDT), ORE(ORE) {
   // Collect all coro.begin and coro.allocs associated with this coro.id.
   for (User *U : CoroId->users()) {
     if (auto *CB = dyn_cast<CoroBeginInst>(U))
@@ -336,6 +341,41 @@ bool CoroIdElider::canCoroBeginEscape(
   return false;
 }
 
+// FIXME: This does not account for stores to task objects whose coroutine
+// handle is not at offset zero.
+static const StoreInst *getPostDominatingStoreToTask(const CoroBeginInst *CB,
+                                                     PostDominatorTree &PDT) {
+  const StoreInst *OnlyStore = nullptr;
+
+  for (auto *U : CB->users()) {
+    auto *Store = dyn_cast<StoreInst>(U);
+    if (Store && Store->getValueOperand() == CB) {
+      if (OnlyStore) {
+        // Store must be unique. one coro begin getting stored to multiple
+        // stores is not accepted.
+        return nullptr;
+      }
+      OnlyStore = Store;
+    }
+  }
+
+  if (!OnlyStore || !PDT.dominates(OnlyStore, CB)) {
+    return nullptr;
+  }
+
+  return OnlyStore;
+}
+
+static bool isMarkedSafeElide(const llvm::Value *V) {
+  for (auto *U : V->users()) {
+    auto *II = dyn_cast<IntrinsicInst>(U);
+    if (II && (II->getIntrinsicID() == Intrinsic::coro_safe_elide)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool CoroIdElider::lifetimeEligibleForElide() const {
   // If no CoroAllocs, we cannot suppress allocation, so elision is not
   // possible.
@@ -364,6 +404,15 @@ bool CoroIdElider::lifetimeEligibleForElide() const {
 
   // Filter out the coro.destroy that lie along exceptional paths.
   for (const auto *CB : CoroBegins) {
+    // This might be too strong of a condition but should be very safe.
+    // If the CB is unconditionally stored into a "Task Like Object",
+    // and such object is "safe elide".
+    if (auto *MaybeStoreToTask = getPostDominatingStoreToTask(CB, PDT)) {
+      auto Dest = MaybeStoreToTask->getPointerOperand();
+      if (isMarkedSafeElide(Dest))
+        continue;
+    }
+
     auto It = DestroyAddr.find(CB);
 
     // FIXME: If we have not found any destroys for this coro.begin, we
@@ -476,11 +525,12 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
 
   AAResults &AA = AM.getResult<AAManager>(F);
   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F);
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 
   bool Changed = false;
   for (auto *CII : FEI.getCoroIds()) {
-    CoroIdElider CIE(CII, FEI, AA, DT, ORE);
+    CoroIdElider CIE(CII, FEI, AA, DT, PDT, ORE);
     Changed |= CIE.attemptElide();
   }
 
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
index 1a92bc1636257..48c02e5406b75 100644
--- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -86,6 +86,7 @@ static const char *const CoroIntrinsics[] = {
     "llvm.coro.prepare.retcon",
     "llvm.coro.promise",
     "llvm.coro.resume",
+    "llvm.coro.safe.elide",
     "llvm.coro.save",
     "llvm.coro.size",
     "llvm.coro.subfn.addr",
diff --git a/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll b/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll
new file mode 100644
index 0000000000000..97c615d00d238
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-elide-structured-concurrency.ll
@@ -0,0 +1,62 @@
+; Tests that, once the callee is inlined, coro-elide removes the heap allocation
+; of a coroutine whose task object is marked with llvm.coro.safe.elide.
+; RUN: opt < %s -S -passes='inline,coro-elide' | FileCheck %s
+
+%struct.Task = type { ptr }
+
+declare void @print(i32) nounwind
+
+; resume part of the coroutine
+define fastcc void @callee.resume(ptr dereferenceable(1)) {
+  tail call void @print(i32 0)
+  ret void
+}
+
+; destroy part of the coroutine
+define fastcc void @callee.destroy(ptr) {
+  tail call void @print(i32 1)
+  ret void
+}
+
+; cleanup part of the coroutine
+define fastcc void @callee.cleanup(ptr) {
+  tail call void @print(i32 2)
+  ret void
+}
+
+@callee.resumers = internal constant [3 x ptr] [
+  ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup]
+
+declare void @alloc(i1) nounwind
+
+; CHECK: define ptr @callee()
+define ptr @callee() {
+entry:
+  %task = alloca %struct.Task, align 8
+  %id = call token @llvm.coro.id(i32 0, ptr null,
+                          ptr @callee,
+                          ptr @callee.resumers)
+  %alloc = call i1 @llvm.coro.alloc(token %id)
+  %hdl = call ptr @llvm.coro.begin(token %id, ptr null)
+  store ptr %hdl, ptr %task
+  ret ptr %task
+}
+
+; CHECK: define ptr @caller()
+define ptr @caller() {
+entry:
+  %task = call ptr @callee()
+
+  ; CHECK: %[[id:.+]] = call token @llvm.coro.id(i32 0, ptr null, ptr @callee, ptr @callee.resumers)
+  ; CHECK-NOT: call i1 @llvm.coro.alloc(token %[[id]])
+  call void @llvm.coro.safe.elide(ptr %task)
+
+  ret ptr %task
+}
+
+declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+declare ptr @llvm.coro.begin(token, ptr)
+declare ptr @llvm.coro.frame()
+declare ptr @llvm.coro.subfn.addr(ptr, i8)
+declare i1 @llvm.coro.alloc(token)
+declare void @llvm.coro.safe.elide(ptr)

>From 7488b1861f4d382dc7dffc67a87ac52ca72e4ad4 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Mon, 17 Jun 2024 17:00:42 -0700
Subject: [PATCH 2/2] [Clang] address upstream comments, use new name and Sema
 approach

---
 clang/include/clang/AST/Expr.h                |  7 ++++
 clang/include/clang/AST/ExprCXX.h             | 22 ++++++----
 clang/include/clang/AST/Stmt.h                |  7 +++-
 clang/lib/CodeGen/CGCoroutine.cpp             | 31 +++++++++++---
 clang/lib/CodeGen/CGExpr.cpp                  |  6 +--
 clang/lib/CodeGen/CodeGenFunction.h           |  3 --
 clang/lib/Sema/SemaCoroutine.cpp              | 42 +++++++++++++++++--
 clang/lib/Serialization/ASTReaderStmt.cpp     |  8 +++-
 clang/lib/Serialization/ASTWriterStmt.cpp     |  3 +-
 ...a-attribute-supported-attributes-list.test |  2 +-
 10 files changed, 101 insertions(+), 30 deletions(-)

diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index f2bf667636dc9..809e17e07f104 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -3127,6 +3127,13 @@ class CallExpr : public Expr {
     return getUnusedResultAttr(Ctx) != nullptr;
   }
 
+  bool isCoroutineInplaceTaskCall() const {
+    return CallExprBits.IsCoroutineInplaceTaskCall;
+  }
+  void setIsCoroutineInplaceTaskCall(bool value = true) {
+    CallExprBits.IsCoroutineInplaceTaskCall = value;
+  }
+
   SourceLocation getRParenLoc() const { return RParenLoc; }
   void setRParenLoc(SourceLocation L) { RParenLoc = L; }
 
diff --git a/clang/include/clang/AST/ExprCXX.h b/clang/include/clang/AST/ExprCXX.h
index 5e35099db9c84..e93f0856beb52 100644
--- a/clang/include/clang/AST/ExprCXX.h
+++ b/clang/include/clang/AST/ExprCXX.h
@@ -5082,7 +5082,8 @@ class CoroutineSuspendExpr : public Expr {
   enum SubExpr { Operand, Common, Ready, Suspend, Resume, Count };
 
   Stmt *SubExprs[SubExpr::Count];
-  OpaqueValueExpr *OpaqueValue = nullptr;
+  OpaqueValueExpr *CommonExprOpaqueValue = nullptr;
+  OpaqueValueExpr *OperandOpaqueValue = nullptr;
 
 public:
   // These types correspond to the three C++ 'await_suspend' return variants
@@ -5090,10 +5091,10 @@ class CoroutineSuspendExpr : public Expr {
 
   CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand,
                        Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume,
-                       OpaqueValueExpr *OpaqueValue)
+                       OpaqueValueExpr *CommonExprOpaqueValue)
       : Expr(SC, Resume->getType(), Resume->getValueKind(),
              Resume->getObjectKind()),
-        KeywordLoc(KeywordLoc), OpaqueValue(OpaqueValue) {
+        KeywordLoc(KeywordLoc), CommonExprOpaqueValue(CommonExprOpaqueValue) {
     SubExprs[SubExpr::Operand] = Operand;
     SubExprs[SubExpr::Common] = Common;
     SubExprs[SubExpr::Ready] = Ready;
@@ -5128,7 +5129,12 @@ class CoroutineSuspendExpr : public Expr {
   }
 
   /// getOpaqueValue - Return the opaque value placeholder.
-  OpaqueValueExpr *getOpaqueValue() const { return OpaqueValue; }
+  OpaqueValueExpr *getCommonExprOpaqueValue() const {
+    return CommonExprOpaqueValue;
+  }
+
+  OpaqueValueExpr *getOperandOpaqueValue() const { return OperandOpaqueValue; }
+  void setOperandOpaqueValue(OpaqueValueExpr *E) { OperandOpaqueValue = E; }
 
   Expr *getReadyExpr() const {
     return static_cast<Expr*>(SubExprs[SubExpr::Ready]);
@@ -5194,9 +5200,9 @@ class CoawaitExpr : public CoroutineSuspendExpr {
 public:
   CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Common,
               Expr *Ready, Expr *Suspend, Expr *Resume,
-              OpaqueValueExpr *OpaqueValue, bool IsImplicit = false)
+              OpaqueValueExpr *CommonExprOpaqueValue, bool IsImplicit = false)
       : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Common,
-                             Ready, Suspend, Resume, OpaqueValue) {
+                             Ready, Suspend, Resume, CommonExprOpaqueValue) {
     CoawaitBits.IsImplicit = IsImplicit;
   }
 
@@ -5280,9 +5286,9 @@ class CoyieldExpr : public CoroutineSuspendExpr {
 public:
   CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Common,
               Expr *Ready, Expr *Suspend, Expr *Resume,
-              OpaqueValueExpr *OpaqueValue)
+              OpaqueValueExpr *CommonExprOpaqueValue)
       : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Common,
-                             Ready, Suspend, Resume, OpaqueValue) {}
+                             Ready, Suspend, Resume, CommonExprOpaqueValue) {}
   CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand,
               Expr *Common)
       : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand,
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index 81f67fd266b7e..7f2f119e26a9c 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -561,8 +561,11 @@ class alignas(void *) Stmt {
     LLVM_PREFERRED_TYPE(bool)
     unsigned HasFPFeatures : 1;
 
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned IsCoroutineInplaceTaskCall : 1;
+
     /// Padding used to align OffsetToTrailingObjects to a byte multiple.
-    unsigned : 24 - 3 - NumExprBits;
+    unsigned : 24 - 4 - NumExprBits;
 
     /// The offset in bytes from the this pointer to the start of the
     /// trailing objects belonging to CallExpr. Intentionally byte sized
@@ -1163,6 +1166,8 @@ class alignas(void *) Stmt {
 
     LLVM_PREFERRED_TYPE(bool)
     unsigned IsImplicit : 1;
+
+    LLVM_PREFERRED_TYPE(bool)
     unsigned IsInplaceCall : 1;
   };
 
diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index b4c724422c14a..3985cf2ed8776 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtVisitor.h"
+#include "llvm/IR/Intrinsics.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -224,9 +225,26 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
                                     AwaitKind Kind, AggValueSlot aggSlot,
                                     bool ignoreResult, bool forLValue) {
   auto *E = S.getCommonExpr();
+  auto &Builder = CGF.Builder;
 
-  auto CommonBinder =
-      CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E);
+  // S.getOperandOpaqueValue() may be null, in this case it maps to nothing.
+  std::optional<CodeGenFunction::OpaqueValueMapping> OperandMapping = std::nullopt;
+  auto CallOV = S.getOperandOpaqueValue();
+  if (CallOV) {
+    OperandMapping.emplace(CGF, CallOV);
+    LValue LV = CGF.getOrCreateOpaqueLValueMapping(CallOV);
+    llvm::Value *Value = LV.getPointer(CGF);
+    // for (auto *U : Value->users()) {
+    //   if (auto *Call = cast<llvm::CallBase>(U)) {
+    //     Call->dump();
+    //   }
+    // }
+    auto SafeElide = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_safe_elide);
+    if (cast<CallExpr>(CallOV->getSourceExpr())->isCoroutineInplaceTaskCall())
+      Builder.CreateCall(SafeElide, Value);
+  }
+  auto CommonBinder = CodeGenFunction::OpaqueValueMappingData::bind(
+      CGF, S.getCommonExprOpaqueValue(), E);
   auto UnbindCommonOnExit =
       llvm::make_scope_exit([&] { CommonBinder.unbind(CGF); });
 
@@ -241,7 +259,6 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
   // Otherwise, emit suspend logic.
   CGF.EmitBlock(SuspendBlock);
 
-  auto &Builder = CGF.Builder;
   llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
   auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
   auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
@@ -256,7 +273,8 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
 
   SmallVector<llvm::Value *, 3> SuspendIntrinsicCallArgs;
   SuspendIntrinsicCallArgs.push_back(
-      CGF.getOrCreateOpaqueLValueMapping(S.getOpaqueValue()).getPointer(CGF));
+      CGF.getOrCreateOpaqueLValueMapping(S.getCommonExprOpaqueValue())
+          .getPointer(CGF));
 
   SuspendIntrinsicCallArgs.push_back(CGF.CurCoro.Data->CoroBegin);
   SuspendIntrinsicCallArgs.push_back(SuspendWrapper);
@@ -455,7 +473,7 @@ CodeGenFunction::generateAwaitSuspendWrapper(Twine const &CoroName,
       Builder.CreateLoad(GetAddrOfLocalVar(&FrameDecl));
 
   auto AwaiterBinder = CodeGenFunction::OpaqueValueMappingData::bind(
-      *this, S.getOpaqueValue(), AwaiterLValue);
+      *this, S.getCommonExprOpaqueValue(), AwaiterLValue);
 
   auto *SuspendRet = EmitScalarExpr(S.getSuspendExpr());
 
@@ -473,6 +491,9 @@ CodeGenFunction::generateAwaitSuspendWrapper(Twine const &CoroName,
 
 LValue
 CodeGenFunction::EmitCoawaitLValue(const CoawaitExpr *E) {
+  if (E->isInplaceCall()) {
+    llvm::dbgs() << "Inplace call!\n";
+  }
   assert(getCoroutineSuspendExprReturnType(getContext(), E)->isReferenceType() &&
          "Can't have a scalar return unless the return type is a "
          "reference type!");
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 43b4936e5180a..b6718a46e8c50 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -618,11 +618,7 @@ EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *M) {
     }
   }
 
-  auto Ret = MakeAddrLValue(Object, M->getType(), AlignmentSource::Decl);
-  if (TemporaryValues.contains(M)) {
-    TemporaryValues[M] = Ret.getPointer(*this);
-  }
-  return Ret;
+  return MakeAddrLValue(Object, M->getType(), AlignmentSource::Decl);
 }
 
 RValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 6dd3da8aacaeb..8525f66082a4e 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -369,9 +369,6 @@ class CodeGenFunction : public CodeGenTypeCache {
   };
   CGCoroInfo CurCoro;
 
-  llvm::SmallDenseMap<const MaterializeTemporaryExpr *, llvm::Value *>
-      TemporaryValues;
-
   bool isCoroutine() const {
     return CurCoro.Data != nullptr;
   }
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 2dbce5e54c76b..4aeda4e2a93c2 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -15,6 +15,7 @@
 
 #include "CoroutineStmtBuilder.h"
 #include "clang/AST/ASTLambda.h"
+#include "clang/AST/ComputeDependence.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
@@ -838,6 +839,19 @@ static bool isCoroInplaceCall(Expr *Operand) {
   return isAttributedCoroInplaceTask(Operand->getType());
 }
 
+template <typename DesiredExpr>
+DesiredExpr *getExprWrappedByTemporary(Expr *E) {
+  if (auto *BTE = dyn_cast<CXXBindTemporaryExpr>(E)) {
+    E = BTE->getSubExpr();
+  }
+
+  if (auto *S = dyn_cast<DesiredExpr>(E)) {
+    return S;
+  }
+
+  return nullptr;
+}
+
 // Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
 // DependentCoawaitExpr if needed.
 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
@@ -861,6 +875,26 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
+  bool InplaceCall =
+      isCoroInplaceCall(Operand) &&
+      isAttributedCoroInplaceTask(
+          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType());
+
+  OpaqueValueExpr *OpaqueCallExpr = nullptr;
+
+  if (InplaceCall) {
+    if (auto *Temporary = dyn_cast<CXXBindTemporaryExpr>(Operand)) {
+      auto *SubExpr = Temporary->getSubExpr();
+      if (CallExpr *Call = dyn_cast<CallExpr>(SubExpr)) {
+        Call->setIsCoroutineInplaceTaskCall();
+        OpaqueCallExpr = new (Context)
+            OpaqueValueExpr(Call->getRParenLoc(), Call->getType(),
+                            Call->getValueKind(), Call->getObjectKind(), Call);
+        Temporary->setSubExpr(OpaqueCallExpr);
+      }
+    }
+  }
+
   auto *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
     ExprResult R =
@@ -878,11 +912,11 @@ ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
     return ExprError();
 
   auto Res = BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get());
-  if (!Res.isInvalid() && isCoroInplaceCall(Operand) &&
-      isAttributedCoroInplaceTask(
-          getCurFunctionDecl(/*AllowLambda=*/true)->getReturnType())) {
+  if (!Res.isInvalid() && InplaceCall) {
     // BuildResolvedCoawaitExpr must return a CoawaitExpr, if valid.
-    Res.getAs<CoawaitExpr>()->setIsInplaceCall();
+    CoawaitExpr *CE = Res.getAs<CoawaitExpr>();
+    CE->setIsInplaceCall();
+    CE->setOperandOpaqueValue(OpaqueCallExpr);
   }
   return Res;
 }
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 67ef170251914..22d58f16178d7 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -483,7 +483,9 @@ void ASTStmtReader::VisitCoawaitExpr(CoawaitExpr *E) {
   E->KeywordLoc = readSourceLocation();
   for (auto &SubExpr: E->SubExprs)
     SubExpr = Record.readSubStmt();
-  E->OpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt());
+  E->CommonExprOpaqueValue =
+      cast_or_null<OpaqueValueExpr>(Record.readSubStmt());
+  E->OperandOpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt());
   E->setIsImplicit(Record.readInt() != 0);
 }
 
@@ -492,7 +494,9 @@ void ASTStmtReader::VisitCoyieldExpr(CoyieldExpr *E) {
   E->KeywordLoc = readSourceLocation();
   for (auto &SubExpr: E->SubExprs)
     SubExpr = Record.readSubStmt();
-  E->OpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt());
+  E->CommonExprOpaqueValue =
+      cast_or_null<OpaqueValueExpr>(Record.readSubStmt());
+  E->OperandOpaqueValue = cast_or_null<OpaqueValueExpr>(Record.readSubStmt());
 }
 
 void ASTStmtReader::VisitDependentCoawaitExpr(DependentCoawaitExpr *E) {
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 1ba6d5501fd10..86ddd086321cf 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -443,7 +443,8 @@ void ASTStmtWriter::VisitCoroutineSuspendExpr(CoroutineSuspendExpr *E) {
   Record.AddSourceLocation(E->getKeywordLoc());
   for (Stmt *S : E->children())
     Record.AddStmt(S);
-  Record.AddStmt(E->getOpaqueValue());
+  Record.AddStmt(E->getCommonExprOpaqueValue());
+  Record.AddStmt(E->getOperandOpaqueValue());
 }
 
 void ASTStmtWriter::VisitCoawaitExpr(CoawaitExpr *E) {
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index c37e0ac9fec46..068192c173fcd 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -59,10 +59,10 @@
 // CHECK-NEXT: ConsumableSetOnRead (SubjectMatchRule_record)
 // CHECK-NEXT: Convergent (SubjectMatchRule_function)
 // CHECK-NEXT: CoroDisableLifetimeBound (SubjectMatchRule_function)
+// CHECK-NEXT: CoroInplaceTask (SubjectMatchRule_record)
 // CHECK-NEXT: CoroLifetimeBound (SubjectMatchRule_record)
 // CHECK-NEXT: CoroOnlyDestroyWhenComplete (SubjectMatchRule_record)
 // CHECK-NEXT: CoroReturnType (SubjectMatchRule_record)
-// CHECK-NEXT: CoroInplaceTask (SubjectMatchRule_record)
 // CHECK-NEXT: CoroWrapper (SubjectMatchRule_function)
 // CHECK-NEXT: DLLExport (SubjectMatchRule_function, SubjectMatchRule_variable, SubjectMatchRule_record, SubjectMatchRule_objc_interface)
 // CHECK-NEXT: DLLImport (SubjectMatchRule_function, SubjectMatchRule_variable, SubjectMatchRule_record, SubjectMatchRule_objc_interface)



More information about the llvm-commits mailing list