r323381 - [coroutines] Pass coro func args to promise ctor

Brian Gesiak via cfe-commits cfe-commits at lists.llvm.org
Wed Jan 24 14:15:42 PST 2018


Author: modocache
Date: Wed Jan 24 14:15:42 2018
New Revision: 323381

URL: http://llvm.org/viewvc/llvm-project?rev=323381&view=rev
Log:
[coroutines] Pass coro func args to promise ctor

Summary:
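Use the coroutine function's arguments to initialize the promise object, but
only if the promise type defines a constructor that accepts those arguments.
Otherwise, fall back to the promise type's default constructor.

To support this, parameter moves are now built in Sema before the promise
variable is declared, so that class-type arguments are passed to the promise
constructor via their coroutine-frame copies; scalar arguments are passed
directly.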

Test Plan: check-clang

Reviewers: rsmith, GorNishanov, eric_niebler

Reviewed By: GorNishanov

Subscribers: toby-allsopp, lewissbaker, EricWF, cfe-commits

Differential Revision: https://reviews.llvm.org/D41820

Modified:
    cfe/trunk/include/clang/Sema/ScopeInfo.h
    cfe/trunk/include/clang/Sema/Sema.h
    cfe/trunk/lib/Sema/CoroutineStmtBuilder.h
    cfe/trunk/lib/Sema/ScopeInfo.cpp
    cfe/trunk/lib/Sema/SemaCoroutine.cpp
    cfe/trunk/lib/Sema/TreeTransform.h
    cfe/trunk/test/CodeGenCoroutines/coro-params.cpp
    cfe/trunk/test/SemaCXX/coroutines.cpp

Modified: cfe/trunk/include/clang/Sema/ScopeInfo.h
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/include/clang/Sema/ScopeInfo.h?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/include/clang/Sema/ScopeInfo.h (original)
+++ cfe/trunk/include/clang/Sema/ScopeInfo.h Wed Jan 24 14:15:42 2018
@@ -22,6 +22,7 @@
 #include "clang/Sema/CleanupInfo.h"
 #include "clang/Sema/Ownership.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
@@ -172,6 +173,10 @@ public:
   /// \brief The promise object for this coroutine, if any.
   VarDecl *CoroutinePromise = nullptr;
 
+  /// \brief A mapping between the coroutine function parameters that were moved
+  /// to the coroutine frame, and their move statements.
+  llvm::SmallMapVector<ParmVarDecl *, Stmt *, 4> CoroutineParameterMoves;
+
   /// \brief The initial and final coroutine suspend points.
   std::pair<Stmt *, Stmt *> CoroutineSuspends;
 

Modified: cfe/trunk/include/clang/Sema/Sema.h
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/include/clang/Sema/Sema.h?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/include/clang/Sema/Sema.h (original)
+++ cfe/trunk/include/clang/Sema/Sema.h Wed Jan 24 14:15:42 2018
@@ -8478,6 +8478,7 @@ public:
   StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E,
                                bool IsImplicit = false);
   StmtResult BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs);
+  bool buildCoroutineParameterMoves(SourceLocation Loc);
   VarDecl *buildCoroutinePromise(SourceLocation Loc);
   void CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body);
 

Modified: cfe/trunk/lib/Sema/CoroutineStmtBuilder.h
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/Sema/CoroutineStmtBuilder.h?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/lib/Sema/CoroutineStmtBuilder.h (original)
+++ cfe/trunk/lib/Sema/CoroutineStmtBuilder.h Wed Jan 24 14:15:42 2018
@@ -51,9 +51,6 @@ public:
   /// name lookup.
   bool buildDependentStatements();
 
-  /// \brief Build just parameter moves. To use for rebuilding in TreeTransform.
-  bool buildParameterMoves();
-
   bool isInvalid() const { return !this->IsValid; }
 
 private:
@@ -65,7 +62,6 @@ private:
   bool makeReturnObject();
   bool makeGroDeclAndReturnStmt();
   bool makeReturnOnAllocFailure();
-  bool makeParamMoves();
 };
 
 } // end namespace clang

Modified: cfe/trunk/lib/Sema/ScopeInfo.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/Sema/ScopeInfo.cpp?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/lib/Sema/ScopeInfo.cpp (original)
+++ cfe/trunk/lib/Sema/ScopeInfo.cpp Wed Jan 24 14:15:42 2018
@@ -43,6 +43,7 @@ void FunctionScopeInfo::Clear() {
   // Coroutine state
   FirstCoroutineStmtLoc = SourceLocation();
   CoroutinePromise = nullptr;
+  CoroutineParameterMoves.clear();
   NeedsCoroutineSuspends = true;
   CoroutineSuspends.first = nullptr;
   CoroutineSuspends.second = nullptr;

Modified: cfe/trunk/lib/Sema/SemaCoroutine.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/Sema/SemaCoroutine.cpp?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/lib/Sema/SemaCoroutine.cpp (original)
+++ cfe/trunk/lib/Sema/SemaCoroutine.cpp Wed Jan 24 14:15:42 2018
@@ -494,9 +494,67 @@ VarDecl *Sema::buildCoroutinePromise(Sou
   CheckVariableDeclarationType(VD);
   if (VD->isInvalidDecl())
     return nullptr;
-  ActOnUninitializedDecl(VD);
+
+  auto *ScopeInfo = getCurFunction();
+  // Build a list of arguments, based on the coroutine functions arguments,
+  // that will be passed to the promise type's constructor.
+  llvm::SmallVector<Expr *, 4> CtorArgExprs;
+  auto &Moves = ScopeInfo->CoroutineParameterMoves;
+  for (auto *PD : FD->parameters()) {
+    if (PD->getType()->isDependentType())
+      continue;
+
+    auto RefExpr = ExprEmpty();
+    auto Move = Moves.find(PD);
+    if (Move != Moves.end()) {
+      // If a reference to the function parameter exists in the coroutine
+      // frame, use that reference.
+      auto *MoveDecl =
+          cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
+      RefExpr = BuildDeclRefExpr(MoveDecl, MoveDecl->getType(),
+                                 ExprValueKind::VK_LValue, FD->getLocation());
+    } else {
+      // If the function parameter doesn't exist in the coroutine frame, it
+      // must be a scalar value. Use it directly.
+      assert(!PD->getType()->getAsCXXRecordDecl() &&
+             "Non-scalar types should have been moved and inserted into the "
+             "parameter moves map");
+      RefExpr =
+          BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
+                           ExprValueKind::VK_LValue, FD->getLocation());
+    }
+
+    if (RefExpr.isInvalid())
+      return nullptr;
+    CtorArgExprs.push_back(RefExpr.get());
+  }
+
+  // Create an initialization sequence for the promise type using the
+  // constructor arguments, wrapped in a parenthesized list expression.
+  Expr *PLE = new (Context) ParenListExpr(Context, FD->getLocation(),
+                                          CtorArgExprs, FD->getLocation());
+  InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
+  InitializationKind Kind = InitializationKind::CreateForInit(
+      VD->getLocation(), /*DirectInit=*/true, PLE);
+  InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
+                                 /*TopLevelOfInitList=*/false,
+                                 /*TreatUnavailableAsInvalid=*/false);
+
+  // Attempt to initialize the promise type with the arguments.
+  // If that fails, fall back to the promise type's default constructor.
+  if (InitSeq) {
+    ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
+    if (Result.isInvalid()) {
+      VD->setInvalidDecl();
+    } else if (Result.get()) {
+      VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
+      VD->setInitStyle(VarDecl::CallInit);
+      CheckCompleteVariableDeclaration(VD);
+    }
+  } else
+    ActOnUninitializedDecl(VD);
+
   FD->addDecl(VD);
-  assert(!VD->isInvalidDecl());
   return VD;
 }
 
@@ -518,6 +576,9 @@ static FunctionScopeInfo *checkCoroutine
   if (ScopeInfo->CoroutinePromise)
     return ScopeInfo;
 
+  if (!S.buildCoroutineParameterMoves(Loc))
+    return nullptr;
+
   ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
   if (!ScopeInfo->CoroutinePromise)
     return nullptr;
@@ -861,6 +922,11 @@ CoroutineStmtBuilder::CoroutineStmtBuild
           !Fn.CoroutinePromise ||
           Fn.CoroutinePromise->getType()->isDependentType()) {
   this->Body = Body;
+
+  for (auto KV : Fn.CoroutineParameterMoves)
+    this->ParamMovesVector.push_back(KV.second);
+  this->ParamMoves = this->ParamMovesVector;
+
   if (!IsPromiseDependentType) {
     PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
     assert(PromiseRecordDecl && "Type should have already been checked");
@@ -870,7 +936,7 @@ CoroutineStmtBuilder::CoroutineStmtBuild
 
 bool CoroutineStmtBuilder::buildStatements() {
   assert(this->IsValid && "coroutine already invalid");
-  this->IsValid = makeReturnObject() && makeParamMoves();
+  this->IsValid = makeReturnObject();
   if (this->IsValid && !IsPromiseDependentType)
     buildDependentStatements();
   return this->IsValid;
@@ -886,12 +952,6 @@ bool CoroutineStmtBuilder::buildDependen
   return this->IsValid;
 }
 
-bool CoroutineStmtBuilder::buildParameterMoves() {
-  assert(this->IsValid && "coroutine already invalid");
-  assert(this->ParamMoves.empty() && "param moves already built");
-  return this->IsValid = makeParamMoves();
-}
-
 bool CoroutineStmtBuilder::makePromiseStmt() {
   // Form a declaration statement for the promise declaration, so that AST
   // visitors can more easily find it.
@@ -1304,47 +1364,50 @@ static Expr *castForMoving(Sema &S, Expr
       .get();
 }
 
-
 /// \brief Build a variable declaration for move parameter.
 static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
                              IdentifierInfo *II) {
   TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
-  VarDecl *Decl =
-      VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type, TInfo, SC_None);
+  VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
+                                  TInfo, SC_None);
   Decl->setImplicit();
   return Decl;
 }
 
-bool CoroutineStmtBuilder::makeParamMoves() {
-  for (auto *paramDecl : FD.parameters()) {
-    auto Ty = paramDecl->getType();
-    if (Ty->isDependentType())
+// Build statements that move coroutine function parameters to the coroutine
+// frame, and store them on the function scope info.
+bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
+  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
+  auto *FD = cast<FunctionDecl>(CurContext);
+
+  auto *ScopeInfo = getCurFunction();
+  assert(ScopeInfo->CoroutineParameterMoves.empty() &&
+         "Should not build parameter moves twice");
+
+  for (auto *PD : FD->parameters()) {
+    if (PD->getType()->isDependentType())
       continue;
 
-    // No need to copy scalars, llvm will take care of them.
-    if (Ty->getAsCXXRecordDecl()) {
-      ExprResult ParamRef =
-          S.BuildDeclRefExpr(paramDecl, paramDecl->getType(),
-                             ExprValueKind::VK_LValue, Loc); // FIXME: scope?
-      if (ParamRef.isInvalid())
+    // No need to copy scalars, LLVM will take care of them.
+    if (PD->getType()->getAsCXXRecordDecl()) {
+      ExprResult PDRefExpr = BuildDeclRefExpr(
+          PD, PD->getType(), ExprValueKind::VK_LValue, Loc); // FIXME: scope?
+      if (PDRefExpr.isInvalid())
         return false;
 
-      Expr *RCast = castForMoving(S, ParamRef.get());
+      Expr *CExpr = castForMoving(*this, PDRefExpr.get());
 
-      auto D = buildVarDecl(S, Loc, Ty, paramDecl->getIdentifier());
-      S.AddInitializerToDecl(D, RCast, /*DirectInit=*/true);
+      auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
+      AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
 
       // Convert decl to a statement.
-      StmtResult Stmt = S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(D), Loc, Loc);
+      StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
       if (Stmt.isInvalid())
         return false;
 
-      ParamMovesVector.push_back(Stmt.get());
+      ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
     }
   }
-
-  // Convert to ArrayRef in CtorArgs structure that builder inherits from.
-  ParamMoves = ParamMovesVector;
   return true;
 }
 

Modified: cfe/trunk/lib/Sema/TreeTransform.h
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/lib/Sema/TreeTransform.h?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/lib/Sema/TreeTransform.h (original)
+++ cfe/trunk/lib/Sema/TreeTransform.h Wed Jan 24 14:15:42 2018
@@ -6944,6 +6944,8 @@ TreeTransform<Derived>::TransformCorouti
 
   // The new CoroutinePromise object needs to be built and put into the current
   // FunctionScopeInfo before any transformations or rebuilding occurs.
+  if (!SemaRef.buildCoroutineParameterMoves(FD->getLocation()))
+    return StmtError();
   auto *Promise = SemaRef.buildCoroutinePromise(FD->getLocation());
   if (!Promise)
     return StmtError();
@@ -7034,8 +7036,6 @@ TreeTransform<Derived>::TransformCorouti
       Builder.ReturnStmt = Res.get();
     }
   }
-  if (!Builder.buildParameterMoves())
-    return StmtError();
 
   return getDerived().RebuildCoroutineBodyStmt(Builder);
 }

Modified: cfe/trunk/test/CodeGenCoroutines/coro-params.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/test/CodeGenCoroutines/coro-params.cpp?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/test/CodeGenCoroutines/coro-params.cpp (original)
+++ cfe/trunk/test/CodeGenCoroutines/coro-params.cpp Wed Jan 24 14:15:42 2018
@@ -1,6 +1,7 @@
 // Verifies that parameters are copied with move constructors
 // Verifies that parameter copies are destroyed
 // Vefifies that parameter copies are used in the body of the coroutine
+// Verifies that parameter copies are used to construct the promise type, if that type has a matching constructor
 // RUN: %clang_cc1 -std=c++1z -fcoroutines-ts -triple=x86_64-unknown-linux-gnu -emit-llvm -o - %s -disable-llvm-passes -fexceptions | FileCheck %s
 
 namespace std::experimental {
@@ -127,3 +128,31 @@ struct B {
 void call_dependent_params() {
   dependent_params(A{}, B{}, B{});
 }
+
+// Test that, when the promise type has a constructor whose signature matches
+// that of the coroutine function, that constructor is used. This is an
+// experimental feature that will be proposed for the Coroutines TS.
+
+struct promise_matching_constructor {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double> {
+  struct promise_type {
+    promise_type(promise_matching_constructor, int, float, double) {}
+    promise_type() = delete;
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+    void unhandled_exception() {}
+  };
+};
+
+// CHECK-LABEL: void @_Z38coroutine_matching_promise_constructor28promise_matching_constructorifd(i32, float, double)
+void coroutine_matching_promise_constructor(promise_matching_constructor, int, float, double) {
+  // CHECK: %[[INT:.+]] = load i32, i32* %.addr, align 4
+  // CHECK: %[[FLOAT:.+]] = load float, float* %.addr1, align 4
+  // CHECK: %[[DOUBLE:.+]] = load double, double* %.addr2, align 8
+  // CHECK: invoke void @_ZNSt12experimental16coroutine_traitsIJv28promise_matching_constructorifdEE12promise_typeC1ES1_ifd(%"struct.std::experimental::coroutine_traits<void, promise_matching_constructor, int, float, double>::promise_type"* %__promise, i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
+  co_return;
+}

Modified: cfe/trunk/test/SemaCXX/coroutines.cpp
URL: http://llvm.org/viewvc/llvm-project/cfe/trunk/test/SemaCXX/coroutines.cpp?rev=323381&r1=323380&r2=323381&view=diff
==============================================================================
--- cfe/trunk/test/SemaCXX/coroutines.cpp (original)
+++ cfe/trunk/test/SemaCXX/coroutines.cpp Wed Jan 24 14:15:42 2018
@@ -1171,4 +1171,73 @@ template CoroMemberTag DepTestType<int>:
 
 template CoroMemberTag DepTestType<int>::test_static_template<void>(const char *volatile &, unsigned);
 
+struct bad_promise_deleted_constructor {
+  // expected-note at +1 {{'bad_promise_deleted_constructor' has been explicitly marked deleted here}}
+  bad_promise_deleted_constructor() = delete;
+  coro<bad_promise_deleted_constructor> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+};
+
+coro<bad_promise_deleted_constructor>
+bad_coroutine_calls_deleted_promise_constructor() {
+  // expected-error at -1 {{call to deleted constructor of 'std::experimental::coroutine_traits<coro<CoroHandleMemberFunctionTest::bad_promise_deleted_constructor>>::promise_type' (aka 'CoroHandleMemberFunctionTest::bad_promise_deleted_constructor')}}
+  co_return;
+}
+
+// Test that, when the promise type has a constructor whose signature matches
+// that of the coroutine function, that constructor is used. If no matching
+// constructor exists, the default constructor is used as a fallback. If no
+// matching constructors exist at all, an error is emitted. This is an
+// experimental feature that will be proposed for the Coroutines TS.
+
+struct good_promise_default_constructor {
+  good_promise_default_constructor(double, float, int);
+  good_promise_default_constructor() = default;
+  coro<good_promise_default_constructor> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+};
+
+coro<good_promise_default_constructor>
+good_coroutine_calls_default_constructor() {
+  co_return;
+}
+
+struct good_promise_custom_constructor {
+  good_promise_custom_constructor(double, float, int);
+  good_promise_custom_constructor() = delete;
+  coro<good_promise_custom_constructor> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+};
+
+coro<good_promise_custom_constructor>
+good_coroutine_calls_custom_constructor(double, float, int) {
+  co_return;
+}
+
+struct bad_promise_no_matching_constructor {
+  bad_promise_no_matching_constructor(int, int, int);
+  // expected-note at +1 {{'bad_promise_no_matching_constructor' has been explicitly marked deleted here}}
+  bad_promise_no_matching_constructor() = delete;
+  coro<bad_promise_no_matching_constructor> get_return_object();
+  suspend_always initial_suspend();
+  suspend_always final_suspend();
+  void return_void();
+  void unhandled_exception();
+};
+
+coro<bad_promise_no_matching_constructor>
+bad_coroutine_calls_with_no_matching_constructor(int, int) {
+  // expected-error at -1 {{call to deleted constructor of 'std::experimental::coroutine_traits<coro<CoroHandleMemberFunctionTest::bad_promise_no_matching_constructor>, int, int>::promise_type' (aka 'CoroHandleMemberFunctionTest::bad_promise_no_matching_constructor')}}
+  co_return;
+}
+
 } // namespace CoroHandleMemberFunctionTest




More information about the cfe-commits mailing list