[clang] [llvm] [Clang] CGCoroutine: Skip moving parameters if the allocation decision is false (PR #81195)

Yuxuan Chen via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 16:03:41 PST 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/81195

>From 1e03c2ec24c7bc6a303266a7023a56d6449a46e5 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 7 Feb 2024 16:05:42 -0800
Subject: [PATCH] Skip moving parameters if the allocation decision is false

---
 clang/lib/CodeGen/CGCoroutine.cpp            | 120 ++++++++---
 clang/test/CodeGenCoroutines/coro-gro.cpp    |   6 +-
 clang/test/CodeGenCoroutines/coro-params.cpp |  48 +++--
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 211 ++++++++++++++++---
 4 files changed, 311 insertions(+), 74 deletions(-)

diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp
index 888d30bfb3e1d6..b2933650367c12 100644
--- a/clang/lib/CodeGen/CGCoroutine.cpp
+++ b/clang/lib/CodeGen/CGCoroutine.cpp
@@ -389,25 +389,12 @@ namespace {
     ParamReferenceReplacerRAII(CodeGenFunction::DeclMapTy &LocalDeclMap)
         : LocalDeclMap(LocalDeclMap) {}
 
-    void addCopy(DeclStmt const *PM) {
-      // Figure out what param it refers to.
-
-      assert(PM->isSingleDecl());
-      VarDecl const*VD = static_cast<VarDecl const*>(PM->getSingleDecl());
-      Expr const *InitExpr = VD->getInit();
-      GetParamRef Visitor;
-      Visitor.Visit(const_cast<Expr*>(InitExpr));
-      assert(Visitor.Expr);
-      DeclRefExpr *DREOrig = Visitor.Expr;
-      auto *PD = DREOrig->getDecl();
-
-      auto it = LocalDeclMap.find(PD);
-      assert(it != LocalDeclMap.end() && "parameter is not found");
-      SavedLocals.insert({ PD, it->second });
-
-      auto copyIt = LocalDeclMap.find(VD);
-      assert(copyIt != LocalDeclMap.end() && "parameter copy is not found");
-      it->second = copyIt->getSecond();
+    void substAddress(ValueDecl *D, Address Addr) {
+      auto it = LocalDeclMap.find(D);
+      assert(it != LocalDeclMap.end() && "original decl is not found");
+      SavedLocals.insert({D, it->second});
+      // Rebind D to Addr; its previous address is kept in SavedLocals.
+      it->second = Addr;
     }
 
     ~ParamReferenceReplacerRAII() {
@@ -629,6 +616,70 @@ struct GetReturnObjectManager {
     Builder.CreateStore(Builder.getTrue(), GroActiveFlag);
   }
 };
+
+/// Returns the declaration of the original parameter that the parameter-move
+/// variable \p VD is initialized from.
+static ValueDecl *getOriginalParamDeclForParamMove(VarDecl const *VD) {
+  Expr const *InitExpr = VD->getInit();
+  GetParamRef Visitor;
+  Visitor.Visit(const_cast<Expr *>(InitExpr));
+  assert(Visitor.Expr && "parameter move does not refer to a parameter");
+  return Visitor.Expr->getDecl();
+}
+
+/// Emits the coroutine's parameter moves (copies of the parameters made
+/// before the promise is constructed) and makes their cleanups conditional.
+struct ParamMoveManager {
+  ParamMoveManager(CodeGenFunction &CGF,
+                   llvm::ArrayRef<const Stmt *> ParamMoves)
+      : CGF(CGF) {
+    ParamMovesVarDecls.reserve(ParamMoves.size());
+    for (const Stmt *S : ParamMoves) {
+      const auto *PMStmt = cast<DeclStmt>(S);
+      assert(PMStmt->isSingleDecl());
+      ParamMovesVarDecls.push_back(cast<VarDecl>(PMStmt->getSingleDecl()));
+    }
+  }
+
+  /// Emits every parameter move together with its cleanups.
+  ///
+  /// The moves are emitted only on the path where coro.alloc is true, so
+  /// their cleanups must not run unconditionally: every cleanup scope pushed
+  /// here is guarded by \p PMCleanupActiveFlag, which must be true exactly
+  /// when the moves were emitted.
+  void EmitMovesWithCleanup(Address PMCleanupActiveFlag) {
+    // Create parameter copies. We do it before creating a promise, since an
+    // evolution of coroutine TS may allow promise constructor to observe
+    // parameter copies.
+    for (const VarDecl *VD : ParamMovesVarDecls) {
+      auto Emission = CGF.EmitAutoVarAlloca(*VD);
+      CGF.EmitAutoVarInit(Emission);
+      auto OldTop = CGF.EHStack.stable_begin();
+      CGF.EmitAutoVarCleanups(Emission);
+      auto Top = CGF.EHStack.stable_begin();
+
+      // Make every cleanup scope pushed for this variable conditional.
+      for (auto I = CGF.EHStack.find(Top), E = CGF.EHStack.find(OldTop);
+           I != E; ++I) {
+        if (auto *Cleanup = dyn_cast<EHCleanupScope>(&*I)) {
+          assert(!Cleanup->hasActiveFlag() &&
+                 "cleanup already has active flag?");
+          Cleanup->setActiveFlag(PMCleanupActiveFlag);
+          Cleanup->setTestFlagInEHCleanup();
+          Cleanup->setTestFlagInNormalCleanup();
+        }
+      }
+    }
+  }
+
+  llvm::ArrayRef<const VarDecl *> GetParamMovesVarDecls() const {
+    return ParamMovesVarDecls;
+  }
+
+private:
+  CodeGenFunction &CGF;
+  SmallVector<const VarDecl *> ParamMovesVarDecls;
+};
 } // namespace
 
 static void emitBodyAndFallthrough(CodeGenFunction &CGF,
@@ -648,6 +692,8 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
   auto *EntryBB = Builder.GetInsertBlock();
   auto *AllocBB = createBasicBlock("coro.alloc");
   auto *InitBB = createBasicBlock("coro.init");
+  auto *ParamMoveBB = createBasicBlock("coro.param.move");
+  auto *AfterParamMoveBB = createBasicBlock("coro.after.param.move");
   auto *FinalBB = createBasicBlock("coro.final");
   auto *RetBB = createBasicBlock("coro.ret");
 
@@ -664,6 +710,11 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
   auto *CoroAlloc = Builder.CreateCall(
       CGM.getIntrinsic(llvm::Intrinsic::coro_alloc), {CoroId});
 
+  // The parameter copies are only made when coro.alloc is true. Their
+  // cleanups test this flag so that they only run if the copies exist.
+  auto PMCleanupActiveFlag = CreateTempAlloca(
+      Builder.getInt1Ty(), CharUnits::One(), "param.move.cleanup.active");
+  Builder.CreateStore(CoroAlloc, PMCleanupActiveFlag);
   Builder.CreateCondBr(CoroAlloc, AllocBB, InitBB);
 
   EmitBlock(AllocBB);
@@ -719,15 +769,50 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
         DI->getCoroutineParameterMappings().insert(
             {std::get<0>(Pair), std::get<1>(Pair)});
 
-    // Create parameter copies. We do it before creating a promise, since an
-    // evolution of coroutine TS may allow promise constructor to observe
-    // parameter copies.
-    for (auto *PM : S.getParamMoves()) {
-      EmitStmt(PM);
-      ParamReplacer.addCopy(cast<DeclStmt>(PM));
-      // TODO: if(CoroParam(...)) need to surround ctor and dtor
-      // for the copy, so that llvm can elide it if the copy is
-      // not needed.
+    // The parameter copies are only needed when the coroutine frame is
+    // heap-allocated (coro.alloc is true); otherwise the original parameters
+    // are referenced in place. The current block is the phis' incoming block
+    // for the path that skips the copies.
+    //
+    // NOTE(review): on that path the body refers to the caller-owned
+    // parameter objects. Under the Itanium ABI the caller destroys parameters
+    // passed indirectly at the end of the call's full-expression, possibly
+    // while an elided coroutine is still suspended -- confirm this cannot
+    // dangle.
+    auto *SkipParamMoveBB = Builder.GetInsertBlock();
+    Builder.CreateCondBr(CoroAlloc, ParamMoveBB, AfterParamMoveBB);
+
+    EmitBlock(ParamMoveBB);
+    ParamMoveManager Mover{*this, ParamMoves};
+    Mover.EmitMovesWithCleanup(PMCleanupActiveFlag);
+    // Emitting the copies may have moved the insertion point out of
+    // ParamMoveBB (e.g. past an invoke of a potentially-throwing move
+    // constructor), so the phis must name the block that actually branches.
+    auto *ParamMoveEndBB = Builder.GetInsertBlock();
+    Builder.CreateBr(AfterParamMoveBB);
+
+    EmitBlock(AfterParamMoveBB);
+
+    // Rebind each parameter to a phi that selects its copy or the original
+    // parameter, so code emitted from here on uses whichever one exists.
+    for (auto *VD : Mover.GetParamMovesVarDecls()) {
+      auto NewIt = LocalDeclMap.find(VD);
+      assert(NewIt != LocalDeclMap.end() && "parameter copy is not found");
+      auto *OrigVD = getOriginalParamDeclForParamMove(VD);
+      auto OldIt = LocalDeclMap.find(OrigVD);
+      assert(OldIt != LocalDeclMap.end() && "parameter is not found");
+      Address NewAddr = NewIt->second;
+      Address OldAddr = OldIt->second;
+
+      auto *ParamPhi = Builder.CreatePHI(OldAddr.getPointer()->getType(), 2);
+      ParamPhi->addIncoming(NewAddr.getPointer(), ParamMoveEndBB);
+      ParamPhi->addIncoming(OldAddr.getPointer(), SkipParamMoveBB);
+
+      // Only the weaker of the two alignments holds for the merged address.
+      CharUnits Align =
+          std::min(NewAddr.getAlignment(), OldAddr.getAlignment());
+      ParamReplacer.substAddress(
+          OrigVD, Address(ParamPhi, OldAddr.getElementType(), Align));
     }
 
     EmitStmt(S.getPromiseDeclStmt());
diff --git a/clang/test/CodeGenCoroutines/coro-gro.cpp b/clang/test/CodeGenCoroutines/coro-gro.cpp
index d4c3ff589e340a..fb42c7e089b363 100644
--- a/clang/test/CodeGenCoroutines/coro-gro.cpp
+++ b/clang/test/CodeGenCoroutines/coro-gro.cpp
@@ -29,8 +29,11 @@ void doSomething() noexcept;
 // CHECK: define{{.*}} i32 @_Z1fv(
 int f() {
   // CHECK: %[[RetVal:.+]] = alloca i32
+  // CHECK: %[[ParamMoveCleanupActive:.+]] = alloca i1
   // CHECK: %[[GroActive:.+]] = alloca i1
   // CHECK: %[[CoroGro:.+]] = alloca %struct.GroType, {{.*}} !coro.outside.frame ![[OutFrameMetadata:.+]]
+  // CHECK: %[[CoroAlloc:.+]] = call i1 @llvm.coro.alloc
+  // CHECK-NEXT: store i1 %[[CoroAlloc]], ptr %[[ParamMoveCleanupActive]]
 
   // CHECK: %[[Size:.+]] = call i64 @llvm.coro.size.i64()
   // CHECK: call noalias noundef nonnull ptr @_Znwm(i64 noundef %[[Size]])
@@ -94,6 +97,7 @@ class invoker {
 // CHECK: define{{.*}} void @_Z1gv({{.*}} %[[AggRes:.+]])
 invoker g() {
   // CHECK: %[[ResultPtr:.+]] = alloca ptr
+  // CHECK-NEXT: %[[ParamMoveCleanupActive:.+]] = alloca i1
   // CHECK-NEXT: %[[Promise:.+]] = alloca %"class.invoker::invoker_promise"
 
   // CHECK: store ptr %[[AggRes]], ptr %[[ResultPtr]]
@@ -105,4 +109,4 @@ invoker g() {
   // CHECK: call void @_ZN7invoker15invoker_promise17get_return_objectEv({{.*}} %[[AggRes]]
   co_return;
 }
-// CHECK: ![[OutFrameMetadata]] = !{}
\ No newline at end of file
+// CHECK: ![[OutFrameMetadata]] = !{}
diff --git a/clang/test/CodeGenCoroutines/coro-params.cpp b/clang/test/CodeGenCoroutines/coro-params.cpp
index c5a61a53cb46ed..873936ac28aba4 100644
--- a/clang/test/CodeGenCoroutines/coro-params.cpp
+++ b/clang/test/CodeGenCoroutines/coro-params.cpp
@@ -64,22 +64,36 @@ void consume(int,int,int) noexcept;
 // TODO: Add support for CopyOnly params
 // CHECK: define{{.*}} void @_Z1fi8MoveOnly11MoveAndCopy(i32 noundef %val, ptr noundef %[[MoParam:.+]], ptr noundef %[[McParam:.+]]) #0 personality ptr @__gxx_personality_v0
 void f(int val, MoveOnly moParam, MoveAndCopy mcParam) {
+  // CHECK: %[[ValCopy:.+]] = alloca i32,
   // CHECK: %[[MoCopy:.+]] = alloca %struct.MoveOnly,
   // CHECK: %[[McCopy:.+]] = alloca %struct.MoveAndCopy,
   // CHECK: store i32 %val, ptr %[[ValAddr:.+]]
+  // CHECK: %[[AllocFlag:.+]] = call i1 @llvm.coro.alloc
+  // CHECK-NEXT: store i1 %[[AllocFlag]], ptr %[[CleanupActiveMem:[^,]+]]
 
   // CHECK: call ptr @llvm.coro.begin(
-  // CHECK: call void @_ZN8MoveOnlyC1EOS_(ptr {{[^,]*}} %[[MoCopy]], ptr noundef nonnull align 4 dereferenceable(4) %[[MoParam]])
-  // CHECK-NEXT: call void @llvm.lifetime.start.p0(
-  // CHECK-NEXT: call void @_ZN11MoveAndCopyC1EOS_(ptr {{[^,]*}} %[[McCopy]], ptr noundef nonnull align 4 dereferenceable(4) %[[McParam]]) #
-  // CHECK-NEXT: call void @llvm.lifetime.start.p0(
-  // CHECK-NEXT: invoke void @_ZNSt16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_typeC1Ev(
-
+  // CHECK: br i1 %[[AllocFlag]], label %coro.param.move, label %coro.after.param.move
+  // CHECK: coro.param.move:
+  // CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[ValCopy:.+]])
+  // CHECK-NEXT: %[[Temp:.+]] = load i32, ptr %val.addr, align 4
+  // CHECK-NEXT: store i32 %[[Temp]], ptr %[[ValCopy]], align 4
+  // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr %[[MoCopy]]) #{{[0-9]+}}
+  // CHECK-NEXT: call void @_ZN8MoveOnlyC1EOS_(ptr noundef nonnull align 4 dereferenceable(4) %[[MoCopy]], ptr noundef nonnull align 4 dereferenceable(4) %[[MoParam]]) #{{[0-9]+}}
+  // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr %[[McCopy]]) #{{[0-9]+}}
+  // CHECK-NEXT: call void @_ZN11MoveAndCopyC1EOS_(ptr noundef nonnull align 4 dereferenceable(4) %[[McCopy]], ptr noundef nonnull align 4 dereferenceable(4) %[[McParam]]) #{{[0-9]+}}
+  // CHECK-NEXT: br label %coro.after.param.move
+  //
+  // CHECK: %[[ValPhi:.+]] = phi ptr [ %[[ValCopy]], %coro.param.move ], [ %[[ValAddr:.+]], %coro.init ]
+  // CHECK-NEXT: %[[MoPhi:.+]] = phi ptr [ %[[MoCopy]], %coro.param.move ], [ %[[MoParam]], %coro.init ]
+  // CHECK-NEXT: %[[McPhi:.+]] = phi ptr [ %[[McCopy]], %coro.param.move ], [ %[[McParam]], %coro.init ]
+  // CHECK: invoke void @_ZNSt16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_typeC1Ev(
+
+  // CHECK: init.ready
   // CHECK: call void @_ZN14suspend_always12await_resumeEv(
   // CHECK: %[[IntParam:.+]] = load i32, ptr %{{.*}}
-  // CHECK: %[[MoGep:.+]] = getelementptr inbounds %struct.MoveOnly, ptr %[[MoCopy]], i32 0, i32 0
+  // CHECK: %[[MoGep:.+]] = getelementptr inbounds %struct.MoveOnly, ptr %[[MoPhi]], i32 0, i32 0
   // CHECK: %[[MoVal:.+]] = load i32, ptr %[[MoGep]]
-  // CHECK: %[[McGep:.+]] =  getelementptr inbounds %struct.MoveAndCopy, ptr %[[McCopy]], i32 0, i32 0
+  // CHECK: %[[McGep:.+]] =  getelementptr inbounds %struct.MoveAndCopy, ptr %[[McPhi]], i32 0, i32 0
   // CHECK: %[[McVal:.+]] = load i32, ptr %[[McGep]]
   // CHECK: call void @_Z7consumeiii(i32 noundef %[[IntParam]], i32 noundef %[[MoVal]], i32 noundef %[[McVal]])
 
@@ -90,13 +104,17 @@ void f(int val, MoveOnly moParam, MoveAndCopy mcParam) {
   // CHECK: call void @_ZNSt16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_type13final_suspendEv(
   // CHECK: call void @_ZN14suspend_always12await_resumeEv(
 
-  // Destroy promise, then parameter copies:
+  // Destroy the promise, then test the active flag of the parameter copies' cleanups:
   // CHECK: call void @_ZNSt16coroutine_traitsIJvi8MoveOnly11MoveAndCopyEE12promise_typeD1Ev(ptr {{[^,]*}} %__promise)
   // CHECK-NEXT: call void @llvm.lifetime.end.p0(
-  // CHECK-NEXT: call void @_ZN11MoveAndCopyD1Ev(ptr {{[^,]*}} %[[McCopy]])
-  // CHECK-NEXT: call void @llvm.lifetime.end.p0(
-  // CHECK-NEXT: call void @_ZN8MoveOnlyD1Ev(ptr {{[^,]*}} %[[MoCopy]]
-  // CHECK-NEXT: call void @llvm.lifetime.end.p0(
+  // CHECK-NEXT: %[[TempCleanupActive:.+]] = load i1, ptr %[[CleanupActiveMem:.+]]
+  // CHECK-NEXT: br i1 %[[TempCleanupActive]]
+
+  // The destructor calls may be in different cleanup blocks, so they are not required to be adjacent:
+  // CHECK: call void @_ZN11MoveAndCopyD1Ev(ptr {{[^,]*}} %[[McCopy]])
+  // CHECK: call void @_ZN8MoveOnlyD1Ev(ptr {{[^,]*}} %[[MoCopy]]
+
+  // CHECK: call void @llvm.lifetime.end.p0(
   // CHECK-NEXT: call void @llvm.lifetime.end.p0(
   // CHECK-NEXT: call ptr @llvm.coro.free(
 }
@@ -109,13 +127,13 @@ void dependent_params(T x, U, U y) {
   // CHECK-NEXT: %[[y_copy:.+]] = alloca %struct.B
 
   // CHECK: call ptr @llvm.coro.begin
-  // CHECK-NEXT: call void @llvm.lifetime.start.p0(
+  // CHECK: call void @llvm.lifetime.start.p0(
   // CHECK-NEXT: call void @_ZN1AC1EOS_(ptr {{[^,]*}} %[[x_copy]], ptr noundef nonnull align 4 dereferenceable(512) %x)
   // CHECK-NEXT: call void @llvm.lifetime.start.p0(
   // CHECK-NEXT: call void @_ZN1BC1EOS_(ptr {{[^,]*}} %[[unnamed_copy]], ptr noundef nonnull align 4 dereferenceable(512) %0)
   // CHECK-NEXT: call void @llvm.lifetime.start.p0(
   // CHECK-NEXT: call void @_ZN1BC1EOS_(ptr {{[^,]*}} %[[y_copy]], ptr noundef nonnull align 4 dereferenceable(512) %y)
-  // CHECK-NEXT: call void @llvm.lifetime.start.p0(
+  // CHECK: call void @llvm.lifetime.start.p0(
   // CHECK-NEXT: invoke void @_ZNSt16coroutine_traitsIJv1A1BS1_EE12promise_typeC1Ev(
 
   co_return;
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index d356a6d2e57594..51c155767c6c09 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -7,8 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Coroutines/CoroElide.h"
+#include "CoroInstr.h"
 #include "CoroInternal.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -32,13 +35,18 @@ static cl::opt<std::string> CoroElideInfoOutputFilename(
 #endif
 
 namespace {
 // Created on demand if the coro-elide pass has work to do.
 struct Lowerer : coro::LowererBase {
   SmallVector<CoroIdInst *, 4> CoroIds;
   SmallVector<CoroBeginInst *, 1> CoroBegins;
   SmallVector<CoroAllocInst *, 1> CoroAllocs;
-  SmallVector<CoroSubFnInst *, 4> ResumeAddr;
-  DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
+  SmallSet<CoroSubFnInst *, 4> ResumeAddr;
+  // The coro.begins of the coro.id being processed, partitioned so that
+  // coro.begins merged by a PHI node are analyzed as a single handle.
+  EquivalenceClasses<CoroBeginInst *> CoroBeginClasses;
+  // The coro.subfn.addr calls fetching the destroy function, keyed by the
+  // leader of each class in CoroBeginClasses.
+  DenseMap<CoroBeginInst *, SmallSet<CoroSubFnInst *, 4>> DestroyAddr;
   SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
 
   Lowerer(Module &M) : LowererBase(M) {}
@@ -47,24 +52,66 @@ struct Lowerer : coro::LowererBase {
                             AAResults &AA);
   bool shouldElide(Function *F, DominatorTree &DT) const;
   void collectPostSplitCoroIds(Function *F);
-  bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT,
-                     OptimizationRemarkEmitter &ORE);
-  bool hasEscapePath(const CoroBeginInst *,
+  bool processCoroId(Function &F, CoroIdInst *, AAResults &AA,
+                     DominatorTree &DT, OptimizationRemarkEmitter &ORE);
+  bool hasEscapePath(Function &F, const CoroBeginInst *,
                      const SmallPtrSetImpl<BasicBlock *> &) const;
 };
+
+/// Index of the loads and stores in a function, keyed by the pointer operand
+/// they access. It lets the pass follow a coroutine handle through memory: a
+/// value written by a store can be read back by the loads of the same pointer.
+/// Pointers are compared by identity only; no alias analysis is involved.
+struct StackAliases {
+  explicit StackAliases(Function &F) {
+    for (Instruction &I : instructions(F)) {
+      if (auto *Load = dyn_cast<LoadInst>(&I))
+        Loads[Load->getPointerOperand()].insert(Load);
+      else if (auto *Store = dyn_cast<StoreInst>(&I))
+        Stores[Store->getPointerOperand()].insert(Store);
+    }
+  }
+
+  /// Returns true if exactly one store writes through \p V.
+  bool valueHasSingleStore(llvm::Value *V) const {
+    auto It = Stores.find(V);
+    return It != Stores.end() && It->second.size() == 1;
+  }
+
+  /// Returns the loads that read through the pointer \p Store writes to.
+  const SmallPtrSetImpl<LoadInst *> &
+  getLoadsByStore(const StoreInst *Store) const {
+    auto It = Loads.find(Store->getPointerOperand());
+    return It != Loads.end() ? It->second : NoLoads;
+  }
+
+  /// Returns the stores that write through the pointer \p Load reads from.
+  const SmallPtrSetImpl<StoreInst *> &
+  getStoresByLoad(const LoadInst *Load) const {
+    auto It = Stores.find(Load->getPointerOperand());
+    return It != Stores.end() ? It->second : NoStores;
+  }
+
+private:
+  DenseMap<llvm::Value *, SmallPtrSet<StoreInst *, 4>> Stores;
+  DenseMap<llvm::Value *, SmallPtrSet<LoadInst *, 4>> Loads;
+  // Empty results for pointers without recorded loads or stores.
+  SmallPtrSet<StoreInst *, 4> NoStores;
+  SmallPtrSet<LoadInst *, 4> NoLoads;
+};
 } // end anonymous namespace
 
 // Go through the list of coro.subfn.addr intrinsics and replace them with the
 // provided constant.
-static void replaceWithConstant(Constant *Value,
-                                SmallVectorImpl<CoroSubFnInst *> &Users) {
+template <typename Range>
+static void replaceWithConstant(Constant *Value, Range &Users) {
   if (Users.empty())
     return;
 
   // See if we need to bitcast the constant to match the type of the intrinsic
   // being replaced. Note: All coro.subfn.addr intrinsics return the same type,
   // so we only need to examine the type of the first one in the list.
-  Type *IntrTy = Users.front()->getType();
+  Type *IntrTy = (*Users.begin())->getType();
   Type *ValueTy = Value->getType();
   if (ValueTy != IntrTy) {
     // May need to tweak the function type to match the type expected at the
@@ -178,11 +233,34 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
   removeTailCallAttribute(Frame, AA);
 }
 
-bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
+bool Lowerer::hasEscapePath(Function &F, const CoroBeginInst *CB,
                             const SmallPtrSetImpl<BasicBlock *> &TIs) const {
   const auto &It = DestroyAddr.find(CB);
   assert(It != DestroyAddr.end());
 
+  // Index of F's loads and stores, used to follow CB through stack slots.
+  StackAliases SA{F};
+
+  // A stack slot only carries the handle if it is an alloca that is written
+  // exactly once and is otherwise only loaded from (or given lifetime
+  // markers); any other use of the slot could leak the handle.
+  auto IsHandleSlot = [](const Value *Slot) {
+    if (!isa<AllocaInst>(Slot))
+      return false;
+    unsigned NumStores = 0;
+    for (const User *SlotUser : Slot->users()) {
+      if (const auto *SlotStore = dyn_cast<StoreInst>(SlotUser)) {
+        // Storing the slot's own address somewhere lets it escape.
+        if (SlotStore->getPointerOperand() != Slot)
+          return false;
+        ++NumStores;
+      } else if (!isa<LoadInst, LifetimeIntrinsic>(SlotUser)) {
+        return false;
+      }
+    }
+    return NumStores == 1;
+  };
+
   // Limit the number of blocks we visit.
   unsigned Limit = 32 * (1 + It->second.size());
 
@@ -196,11 +255,23 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
     Visited.insert(DA->getParent());
 
   SmallPtrSet<const BasicBlock *, 32> EscapingBBs;
+  // Loads that read CB back from a stack slot (see below).
+  SmallPtrSet<const LoadInst *, 4> AliasUsers;
   for (auto *U : CB->users()) {
     // The use from coroutine intrinsics are not a problem.
     if (isa<CoroFreeInst, CoroSubFnInst, CoroSaveInst>(U))
       continue;
 
+    // Storing CB into a slot that only ever holds the handle is not an escape
+    // by itself; the loads of that slot are examined below instead.
+    if (auto *SI = dyn_cast<StoreInst>(U);
+        SI && SI->getValueOperand() == CB &&
+        IsHandleSlot(SI->getPointerOperand())) {
+      for (const auto *Load : SA.getLoadsByStore(SI))
+        AliasUsers.insert(Load);
+      continue;
+    }
+
     // Think all other usages may be an escaping candidate conservatively.
     //
     // Note that the major user of switch ABI coroutine (the C++) will store
@@ -215,6 +283,18 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
     EscapingBBs.insert(cast<Instruction>(U)->getParent());
   }
 
+  // Uses of the handle read back from a stack slot are treated like the
+  // direct uses of CB above: anything but a coroutine intrinsic marks its
+  // block as a potential escape.
+  for (const LoadInst *Load : AliasUsers) {
+    for (const User *U : Load->users()) {
+      // The use from coroutine intrinsics are not a problem.
+      if (isa<CoroFreeInst, CoroSubFnInst, CoroSaveInst>(U))
+        continue;
+      EscapingBBs.insert(cast<Instruction>(U)->getParent());
+    }
+  }
+
   bool PotentiallyEscaped = false;
 
   do {
@@ -285,10 +362,17 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
 
     Terminators.insert(&B);
   }
+
+  // coro.begins that flow into a common PHI node form one class; each class
+  // is represented by its leader.
+  SmallPtrSet<CoroBeginInst *, 8> CoroBeginsToTest;
+  for (const auto &Class : CoroBeginClasses)
+    if (Class.isLeader())
+      CoroBeginsToTest.insert(Class.getData());
 
   // Filter out the coro.destroy that lie along exceptional paths.
   SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
-  for (const auto &It : DestroyAddr) {
+  for (auto *CB : CoroBeginsToTest) {
     // If every terminators is dominated by coro.destroy, we could know the
     // corresponding coro.begin wouldn't escape.
     //
@@ -298,20 +382,30 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
     //
     // hasEscapePath is relatively slow, so we avoid to run it as much as
     // possible.
+    //
+    // hasEscapePath only follows the uses of CB itself, so it cannot vouch
+    // for a class that also contains other coro.begins (merged with CB
+    // through a PHI node); such a class has to pass the dominance check.
+    auto It = DestroyAddr.find(CB);
+    if (It == DestroyAddr.end())
+      return false; // No coro.destroy found for this handle: cannot elide.
+    bool IsSingleHandle = std::next(CoroBeginClasses.findLeader(CB)) ==
+                          CoroBeginClasses.member_end();
+
     if (llvm::all_of(Terminators,
                      [&](auto *TI) {
-                       return llvm::any_of(It.second, [&](auto *DA) {
+                       return llvm::any_of(It->second, [&](auto *DA) {
                          return DT.dominates(DA, TI->getTerminator());
                        });
                      }) ||
-        !hasEscapePath(It.first, Terminators))
-      ReferencedCoroBegins.insert(It.first);
+        (IsSingleHandle && !hasEscapePath(*F, CB, Terminators)))
+      ReferencedCoroBegins.insert(CB);
   }
 
-  // If size of the set is the same as total number of coro.begin, that means we
-  // found a coro.free or coro.destroy referencing each coro.begin, so we can
-  // perform heap elision.
-  return ReferencedCoroBegins.size() == CoroBegins.size();
+  // If every class of coro.begins was proven not to escape, that means we
+  // found a coro.destroy referencing each of them, so we can perform heap
+  // elision.
+  return ReferencedCoroBegins.size() == CoroBeginsToTest.size();
 }
 
 void Lowerer::collectPostSplitCoroIds(Function *F) {
@@ -338,38 +437,131 @@ void Lowerer::collectPostSplitCoroIds(Function *F) {
   }
 }
 
-bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
+/// Returns true if every incoming value of \p Phi is a coro.begin of the
+/// coroutine identified by \p Id, i.e. the PHI merges nothing but handles to
+/// frames of that coroutine.
+static bool isPhiOfCoroBegins(const PHINode *Phi, const AnyCoroIdInst *Id) {
+  return llvm::all_of(Phi->incoming_values(), [&](const Value *In) {
+    const auto *CB = dyn_cast<CoroBeginInst>(In);
+    return CB && CB->getId() == Id;
+  });
+}
+
+/// Returns true if \p Slot is an alloca that is written by exactly one store
+/// and is otherwise only loaded from or given lifetime markers, so the stored
+/// value is the only value ever written to it.
+static bool isSingleStoreSlot(const Value *Slot) {
+  if (!isa<AllocaInst>(Slot))
+    return false;
+  unsigned NumStores = 0;
+  for (const User *SlotUser : Slot->users()) {
+    if (const auto *SlotStore = dyn_cast<StoreInst>(SlotUser)) {
+      // Storing the slot's own address somewhere lets it escape.
+      if (SlotStore->getPointerOperand() != Slot)
+        return false;
+      ++NumStores;
+    } else if (!isa<LoadInst, LifetimeIntrinsic>(SlotUser)) {
+      return false;
+    }
+  }
+  return NumStores == 1;
+}
+
+/// Collects into \p Res, without duplicates, the coro.subfn.addr intrinsics
+/// applied to any of \p Handles (the coro.begins of one class of the coroutine
+/// \p Id). Handles are followed through PHI nodes that merge only coro.begins
+/// of \p Id and through single-store stack slots they are stored to; no other
+/// use is looked through, so every result refers to a frame of \p Id.
+static void findReachableCoroSubFnForValue(
+    ArrayRef<CoroBeginInst *> Handles, const AnyCoroIdInst *Id,
+    const StackAliases &SA, SmallVectorImpl<CoroSubFnInst *> &Res) {
+  // The visited set also guards against cycles, e.g. through loop PHIs.
+  SmallPtrSet<Value *, 16> Visited;
+  SmallVector<Value *, 16> Worklist;
+  auto Enqueue = [&](Value *V) {
+    if (Visited.insert(V).second)
+      Worklist.push_back(V);
+  };
+
+  for (CoroBeginInst *CB : Handles)
+    Enqueue(CB);
+
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+    for (User *U : V->users()) {
+      if (auto *CSFI = dyn_cast<CoroSubFnInst>(U)) {
+        if (Visited.insert(CSFI).second)
+          Res.push_back(CSFI);
+      } else if (auto *Phi = dyn_cast<PHINode>(U)) {
+        if (isPhiOfCoroBegins(Phi, Id))
+          Enqueue(Phi);
+      } else if (auto *SI = dyn_cast<StoreInst>(U)) {
+        // Follow the handle only when it is the stored value (not the
+        // address); the loads of the slot then read it back.
+        if (SI->getValueOperand() == V &&
+            isSingleStoreSlot(SI->getPointerOperand()))
+          for (LoadInst *Load : SA.getLoadsByStore(SI))
+            Enqueue(Load);
+      }
+    }
+  }
+}
+
+bool Lowerer::processCoroId(Function &F, CoroIdInst *CoroId, AAResults &AA,
                             DominatorTree &DT, OptimizationRemarkEmitter &ORE) {
   CoroBegins.clear();
   CoroAllocs.clear();
   ResumeAddr.clear();
   DestroyAddr.clear();
+  // EquivalenceClasses has no clear(); start from a fresh object instead.
+  CoroBeginClasses = EquivalenceClasses<CoroBeginInst *>();
 
   // Collect all coro.begin and coro.allocs associated with this coro.id.
   for (User *U : CoroId->users()) {
-    if (auto *CB = dyn_cast<CoroBeginInst>(U))
+    if (auto *CB = dyn_cast<CoroBeginInst>(U)) {
       CoroBegins.push_back(CB);
-    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
+      CoroBeginClasses.insert(CB);
+    } else if (auto *CA = dyn_cast<CoroAllocInst>(U))
       CoroAllocs.push_back(CA);
   }
 
-  // Collect all coro.subfn.addrs associated with coro.begin.
-  // Note, we only devirtualize the calls if their coro.subfn.addr refers to
-  // coro.begin directly. If we run into cases where this check is too
-  // conservative, we can consider relaxing the check.
+  // Put coro.begins that flow into a common PHI node into one class so that
+  // the merged handle is analyzed as a single coroutine handle. Only PHIs
+  // that merge nothing but coro.begins of this coro.id are considered.
   for (CoroBeginInst *CB : CoroBegins) {
-    for (User *U : CB->users())
-      if (auto *II = dyn_cast<CoroSubFnInst>(U))
-        switch (II->getIndex()) {
-        case CoroSubFnInst::ResumeIndex:
-          ResumeAddr.push_back(II);
-          break;
-        case CoroSubFnInst::DestroyIndex:
-          DestroyAddr[CB].push_back(II);
-          break;
-        default:
-          llvm_unreachable("unexpected coro.subfn.addr constant");
-        }
+    for (User *U : CB->users()) {
+      auto *Phi = dyn_cast<PHINode>(U);
+      if (!Phi || !isPhiOfCoroBegins(Phi, CoroId))
+        continue;
+      for (Value *In : Phi->incoming_values())
+        CoroBeginClasses.unionSets(CB, cast<CoroBeginInst>(In));
+    }
+  }
+
+  StackAliases SA{F};
+
+  // Collect the coro.subfn.addrs applied to the handles of each class and
+  // file them under the class leader.
+  for (const auto &Class : CoroBeginClasses) {
+    if (!Class.isLeader())
+      continue;
+    CoroBeginInst *Leader = Class.getData();
+    SmallVector<CoroBeginInst *, 2> Handles(CoroBeginClasses.findLeader(Leader),
+                                            CoroBeginClasses.member_end());
+    SmallVector<CoroSubFnInst *, 4> SubFns;
+    findReachableCoroSubFnForValue(Handles, CoroId, SA, SubFns);
+    for (CoroSubFnInst *CSFI : SubFns) {
+      switch (CSFI->getIndex()) {
+      case CoroSubFnInst::ResumeIndex:
+        ResumeAddr.insert(CSFI);
+        break;
+      case CoroSubFnInst::DestroyIndex:
+        DestroyAddr[Leader].insert(CSFI);
+        break;
+      default:
+        llvm_unreachable("unexpected coro.subfn.addr constant");
+      }
+    }
   }
 
   // PostSplit coro.id refers to an array of subfunctions in its Info
@@ -383,6 +533,8 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
   replaceWithConstant(ResumeAddrConstant, ResumeAddr);
 
   bool ShouldElide = shouldElide(CoroId->getFunction(), DT);
+  LLVM_DEBUG(dbgs() << "ShouldElide " << CoroId->getCoroutine()->getName()
+                    << ": " << ShouldElide << "\n");
   if (!ShouldElide)
     ORE.emit([&]() {
       if (auto FrameSizeAndAlign =
@@ -466,7 +617,7 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
 
   bool Changed = false;
   for (auto *CII : L.CoroIds)
-    Changed |= L.processCoroId(CII, AA, DT, ORE);
+    Changed |= L.processCoroId(F, CII, AA, DT, ORE);
 
   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
 }



More information about the llvm-commits mailing list