[llvm] [NFC] CoroElide: Refactor `Lowerer` into `CoroIdElider` (PR #91539)

Yuxuan Chen via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 22:43:04 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/91539

>From 65d83846a07a1ccb73293812a989bc6ef65caf39 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 8 May 2024 13:27:10 -0700
Subject: [PATCH 1/4] [NFC] CoroElide: clarify elideability requirements and
 refactor remarks

---
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 116 ++++++++++---------
 1 file changed, 61 insertions(+), 55 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index d356a6d2e5759..77eb956246518 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -45,7 +45,7 @@ struct Lowerer : coro::LowererBase {
 
   void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign,
                             AAResults &AA);
-  bool shouldElide(Function *F, DominatorTree &DT) const;
+  bool lifetimeEligibleForElide(CoroIdInst *CoroId, DominatorTree &DT) const;
   void collectPostSplitCoroIds(Function *F);
   bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT,
                      OptimizationRemarkEmitter &ORE);
@@ -261,15 +261,19 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
   return false;
 }
 
-bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
+bool Lowerer::lifetimeEligibleForElide(CoroIdInst *CoroId,
+                                       DominatorTree &DT) const {
   // If no CoroAllocs, we cannot suppress allocation, so elision is not
   // possible.
   if (CoroAllocs.empty())
     return false;
 
+  auto *ContainingFunction = CoroId->getFunction();
+
   // Check that for every coro.begin there is at least one coro.destroy directly
   // referencing the SSA value of that coro.begin along each
   // non-exceptional path.
+  //
   // If the value escaped, then coro.destroy would have been referencing a
   // memory location storing that value and not the virtual register.
 
@@ -277,7 +281,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
   // First gather all of the terminators for the function.
   // Consider the final coro.suspend as the real terminator when the current
   // function is a coroutine.
-  for (BasicBlock &B : *F) {
+  for (BasicBlock &B : *ContainingFunction) {
     auto *TI = B.getTerminator();
 
     if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI))
@@ -287,31 +291,40 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
   }
 
   // Filter out the coro.destroy that lie along exceptional paths.
-  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
-  for (const auto &It : DestroyAddr) {
+  for (const auto *CB : CoroBegins) {
+    auto It = DestroyAddr.find(CB);
+
+    // FIXME: If we have not found any destroys for this coro.begin, we
+    // disqualify this elide.
+    if (It == DestroyAddr.end())
+      return false;
+
+    const auto &CorrespondingDestroyAddrs = It->second;
+
     // If every terminators is dominated by coro.destroy, we could know the
     // corresponding coro.begin wouldn't escape.
-    //
+    auto DominatesTerminator = [&](auto *TI) {
+      return llvm::any_of(CorrespondingDestroyAddrs, [&](auto *Destroy) {
+        return DT.dominates(Destroy, TI->getTerminator());
+      });
+    };
+
+    if (llvm::all_of(Terminators, DominatesTerminator))
+      continue;
+
     // Otherwise hasEscapePath would decide whether there is any paths from
     // coro.begin to Terminators which not pass through any of the
-    // coro.destroys.
+    // coro.destroys. This is a slower analysis.
     //
     // hasEscapePath is relatively slow, so we avoid to run it as much as
     // possible.
-    if (llvm::all_of(Terminators,
-                     [&](auto *TI) {
-                       return llvm::any_of(It.second, [&](auto *DA) {
-                         return DT.dominates(DA, TI->getTerminator());
-                       });
-                     }) ||
-        !hasEscapePath(It.first, Terminators))
-      ReferencedCoroBegins.insert(It.first);
+    if (hasEscapePath(CB, Terminators))
+      return false;
   }
 
-  // If size of the set is the same as total number of coro.begin, that means we
-  // found a coro.free or coro.destroy referencing each coro.begin, so we can
-  // perform heap elision.
-  return ReferencedCoroBegins.size() == CoroBegins.size();
+  // We have checked all CoroBegins and their paths to the terminators without
+  // finding disqualifying code patterns, so we can perform heap elision.
+  return true;
 }
 
 void Lowerer::collectPostSplitCoroIds(Function *F) {
@@ -382,45 +395,30 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
 
   replaceWithConstant(ResumeAddrConstant, ResumeAddr);
 
-  bool ShouldElide = shouldElide(CoroId->getFunction(), DT);
-  if (!ShouldElide)
-    ORE.emit([&]() {
-      if (auto FrameSizeAndAlign =
-              getFrameLayout(cast<Function>(ResumeAddrConstant)))
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' not elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
-               << "' (frame_size="
-               << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
-               << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
-      else
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' not elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
-               << "' (frame_size=unknown, align=unknown)";
-    });
+  bool EligibleForElide = lifetimeEligibleForElide(CoroId, DT);
 
   auto *DestroyAddrConstant = Resumers->getAggregateElement(
-      ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
+      EligibleForElide ? CoroSubFnInst::CleanupIndex
+                       : CoroSubFnInst::DestroyIndex);
 
   for (auto &It : DestroyAddr)
     replaceWithConstant(DestroyAddrConstant, It.second);
 
-  if (ShouldElide) {
-    if (auto FrameSizeAndAlign =
-            getFrameLayout(cast<Function>(ResumeAddrConstant))) {
-      elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first,
-                           FrameSizeAndAlign->second, AA);
-      coro::replaceCoroFree(CoroId, /*Elide=*/true);
-      NumOfCoroElided++;
+  auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
+
+  if (EligibleForElide && FrameSizeAndAlign) {
+    elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first,
+                         FrameSizeAndAlign->second, AA);
+    coro::replaceCoroFree(CoroId, /*Elide=*/true);
+    NumOfCoroElided++;
+
 #ifndef NDEBUG
       if (!CoroElideInfoOutputFilename.empty())
         *getOrCreateLogFile()
             << "Elide " << CoroId->getCoroutine()->getName() << " in "
             << CoroId->getFunction()->getName() << "\n";
 #endif
+
       ORE.emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId)
                << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
@@ -430,15 +428,23 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
                << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
                << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
       });
-    } else {
-      ORE.emit([&]() {
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' not elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
-               << "' (frame_size=unknown, align=unknown)";
-      });
-    }
+  } else {
+    ORE.emit([&]() {
+      auto Remark = OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
+                    << "'"
+                    << ore::NV("callee", CoroId->getCoroutine()->getName())
+                    << "' not elided in '"
+                    << ore::NV("caller", CoroId->getFunction()->getName());
+
+      if (FrameSizeAndAlign)
+        return Remark << "' (frame_size="
+                      << ore::NV("frame_size", FrameSizeAndAlign->first)
+                      << ", align="
+                      << ore::NV("align", FrameSizeAndAlign->second.value())
+                      << ")";
+      else
+        return Remark << "' (frame_size=unknown, align=unknown)";
+    });
   }
 
   return true;

>From 5b143329c81194389f58ad7592b127135be54b01 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 8 May 2024 14:28:07 -0700
Subject: [PATCH 2/4] [NFC] refactor `Lowerer`

---
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 205 ++++++++++---------
 1 file changed, 110 insertions(+), 95 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index 77eb956246518..ef40479d646e3 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -33,24 +33,46 @@ static cl::opt<std::string> CoroElideInfoOutputFilename(
 
 namespace {
 // Created on demand if the coro-elide pass has work to do.
-struct Lowerer : coro::LowererBase {
+class FunctionElideManager {
+public:
+  FunctionElideManager(Function *F) : ContainingFunction(F) {
+    this->collectPostSplitCoroIds();
+  }
+
+  bool isElideNecessary() const { return !CoroIds.empty(); }
+
+  const SmallVectorImpl<CoroIdInst *> &getCoroIds() const { return CoroIds; }
+
+private:
+  Function *ContainingFunction;
   SmallVector<CoroIdInst *, 4> CoroIds;
+  SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
+
+  void collectPostSplitCoroIds();
+  friend class CoroIdElider;
+};
+
+class CoroIdElider {
+public:
+  CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM, AAResults &AA,
+               DominatorTree &DT, OptimizationRemarkEmitter &ORE);
+  void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign);
+  bool lifetimeEligibleForElide() const;
+  bool attemptElide();
+  bool CoroBeginCanEscape(const CoroBeginInst *,
+                          const SmallPtrSetImpl<BasicBlock *> &) const;
+
+private:
+  CoroIdInst *CoroId;
+  FunctionElideManager &FEM;
+  AAResults &AA;
+  DominatorTree &DT;
+  OptimizationRemarkEmitter &ORE;
+
   SmallVector<CoroBeginInst *, 1> CoroBegins;
   SmallVector<CoroAllocInst *, 1> CoroAllocs;
   SmallVector<CoroSubFnInst *, 4> ResumeAddr;
   DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
-  SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
-
-  Lowerer(Module &M) : LowererBase(M) {}
-
-  void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign,
-                            AAResults &AA);
-  bool lifetimeEligibleForElide(CoroIdInst *CoroId, DominatorTree &DT) const;
-  void collectPostSplitCoroIds(Function *F);
-  bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT,
-                     OptimizationRemarkEmitter &ORE);
-  bool hasEscapePath(const CoroBeginInst *,
-                     const SmallPtrSetImpl<BasicBlock *> &) const;
 };
 } // end anonymous namespace
 
@@ -136,14 +158,66 @@ static std::unique_ptr<raw_fd_ostream> getOrCreateLogFile() {
 }
 #endif
 
+void FunctionElideManager::collectPostSplitCoroIds() {
+  for (auto &I : instructions(this->ContainingFunction)) {
+    if (auto *CII = dyn_cast<CoroIdInst>(&I))
+      if (CII->getInfo().isPostSplit())
+        // If it is the coroutine itself, don't touch it.
+        if (CII->getCoroutine() != CII->getFunction())
+          CoroIds.push_back(CII);
+
+    // Consider case like:
+    // %0 = call i8 @llvm.coro.suspend(...)
+    // switch i8 %0, label %suspend [i8 0, label %resume
+    //                              i8 1, label %cleanup]
+    // and collect the SwitchInsts which are used by escape analysis later.
+    if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
+      if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
+        SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
+        if (SWI->getNumCases() == 2)
+          CoroSuspendSwitches.insert(SWI);
+      }
+  }
+}
+
+CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM,
+                           AAResults &AA, DominatorTree &DT,
+                           OptimizationRemarkEmitter &ORE)
+    : CoroId(CoroId), FEM(FEM), AA(AA), DT(DT), ORE(ORE) {
+  // Collect all coro.begin and coro.allocs associated with this coro.id.
+  for (User *U : CoroId->users()) {
+    if (auto *CB = dyn_cast<CoroBeginInst>(U))
+      CoroBegins.push_back(CB);
+    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
+      CoroAllocs.push_back(CA);
+  }
+
+  // Collect all coro.subfn.addrs associated with coro.begin.
+  // Note, we only devirtualize the calls if their coro.subfn.addr refers to
+  // coro.begin directly. If we run into cases where this check is too
+  // conservative, we can consider relaxing the check.
+  for (CoroBeginInst *CB : CoroBegins) {
+    for (User *U : CB->users())
+      if (auto *II = dyn_cast<CoroSubFnInst>(U))
+        switch (II->getIndex()) {
+        case CoroSubFnInst::ResumeIndex:
+          ResumeAddr.push_back(II);
+          break;
+        case CoroSubFnInst::DestroyIndex:
+          DestroyAddr[CB].push_back(II);
+          break;
+        default:
+          llvm_unreachable("unexpected coro.subfn.addr constant");
+        }
+  }
+}
+
 // To elide heap allocations we need to suppress code blocks guarded by
 // llvm.coro.alloc and llvm.coro.free instructions.
-void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
-                                   Align FrameAlign, AAResults &AA) {
-  LLVMContext &C = F->getContext();
+void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
+  LLVMContext &C = FEM.ContainingFunction->getContext();
   BasicBlock::iterator InsertPt =
-      getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction())
-          ->getIterator();
+      getFirstNonAllocaInTheEntryBlock(FEM.ContainingFunction)->getIterator();
 
   // Replacing llvm.coro.alloc with false will suppress dynamic
   // allocation as it is expected for the frontend to generate the code that
@@ -161,7 +235,7 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
   // is spilled into the coroutine frame and recreate the alignment information
   // here. Possibly we will need to do a mini SROA here and break the coroutine
   // frame into individual AllocaInst recreating the original alignment.
-  const DataLayout &DL = F->getParent()->getDataLayout();
+  const DataLayout &DL = FEM.ContainingFunction->getParent()->getDataLayout();
   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
   Frame->setAlignment(FrameAlign);
@@ -178,8 +252,8 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
   removeTailCallAttribute(Frame, AA);
 }
 
-bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
-                            const SmallPtrSetImpl<BasicBlock *> &TIs) const {
+bool CoroIdElider::CoroBeginCanEscape(
+    const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
   const auto &It = DestroyAddr.find(CB);
   assert(It != DestroyAddr.end());
 
@@ -248,7 +322,7 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
     // which means a escape path to normal terminator, it is reasonable to skip
     // it since coroutine frame doesn't change outside the coroutine body.
     if (isa<SwitchInst>(TI) &&
-        CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
+        FEM.CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
       Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1));
       Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2));
     } else
@@ -261,8 +335,7 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
   return false;
 }
 
-bool Lowerer::lifetimeEligibleForElide(CoroIdInst *CoroId,
-                                       DominatorTree &DT) const {
+bool CoroIdElider::lifetimeEligibleForElide() const {
   // If no CoroAllocs, we cannot suppress allocation, so elision is not
   // possible.
   if (CoroAllocs.empty())
@@ -312,13 +385,13 @@ bool Lowerer::lifetimeEligibleForElide(CoroIdInst *CoroId,
     if (llvm::all_of(Terminators, DominatesTerminator))
       continue;
 
-    // Otherwise hasEscapePath would decide whether there is any paths from
+    // Otherwise CoroBeginCanEscape would decide whether there is any paths from
     // coro.begin to Terminators which not pass through any of the
     // coro.destroys. This is a slower analysis.
     //
-    // hasEscapePath is relatively slow, so we avoid to run it as much as
+    // CoroBeginCanEscape is relatively slow, so we avoid to run it as much as
     // possible.
-    if (hasEscapePath(CB, Terminators))
+    if (CoroBeginCanEscape(CB, Terminators))
       return false;
   }
 
@@ -327,64 +400,7 @@ bool Lowerer::lifetimeEligibleForElide(CoroIdInst *CoroId,
   return true;
 }
 
-void Lowerer::collectPostSplitCoroIds(Function *F) {
-  CoroIds.clear();
-  CoroSuspendSwitches.clear();
-  for (auto &I : instructions(F)) {
-    if (auto *CII = dyn_cast<CoroIdInst>(&I))
-      if (CII->getInfo().isPostSplit())
-        // If it is the coroutine itself, don't touch it.
-        if (CII->getCoroutine() != CII->getFunction())
-          CoroIds.push_back(CII);
-
-    // Consider case like:
-    // %0 = call i8 @llvm.coro.suspend(...)
-    // switch i8 %0, label %suspend [i8 0, label %resume
-    //                              i8 1, label %cleanup]
-    // and collect the SwitchInsts which are used by escape analysis later.
-    if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
-      if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
-        SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
-        if (SWI->getNumCases() == 2)
-          CoroSuspendSwitches.insert(SWI);
-      }
-  }
-}
-
-bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
-                            DominatorTree &DT, OptimizationRemarkEmitter &ORE) {
-  CoroBegins.clear();
-  CoroAllocs.clear();
-  ResumeAddr.clear();
-  DestroyAddr.clear();
-
-  // Collect all coro.begin and coro.allocs associated with this coro.id.
-  for (User *U : CoroId->users()) {
-    if (auto *CB = dyn_cast<CoroBeginInst>(U))
-      CoroBegins.push_back(CB);
-    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
-      CoroAllocs.push_back(CA);
-  }
-
-  // Collect all coro.subfn.addrs associated with coro.begin.
-  // Note, we only devirtualize the calls if their coro.subfn.addr refers to
-  // coro.begin directly. If we run into cases where this check is too
-  // conservative, we can consider relaxing the check.
-  for (CoroBeginInst *CB : CoroBegins) {
-    for (User *U : CB->users())
-      if (auto *II = dyn_cast<CoroSubFnInst>(U))
-        switch (II->getIndex()) {
-        case CoroSubFnInst::ResumeIndex:
-          ResumeAddr.push_back(II);
-          break;
-        case CoroSubFnInst::DestroyIndex:
-          DestroyAddr[CB].push_back(II);
-          break;
-        default:
-          llvm_unreachable("unexpected coro.subfn.addr constant");
-        }
-  }
-
+bool CoroIdElider::attemptElide() {
   // PostSplit coro.id refers to an array of subfunctions in its Info
   // argument.
   ConstantArray *Resumers = CoroId->getInfo().Resumers;
@@ -395,7 +411,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
 
   replaceWithConstant(ResumeAddrConstant, ResumeAddr);
 
-  bool EligibleForElide = lifetimeEligibleForElide(CoroId, DT);
+  bool EligibleForElide = lifetimeEligibleForElide();
 
   auto *DestroyAddrConstant = Resumers->getAggregateElement(
       EligibleForElide ? CoroSubFnInst::CleanupIndex
@@ -407,8 +423,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
   auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
 
   if (EligibleForElide && FrameSizeAndAlign) {
-    elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first,
-                         FrameSizeAndAlign->second, AA);
+    elideHeapAllocations(FrameSizeAndAlign->first, FrameSizeAndAlign->second);
     coro::replaceCoroFree(CoroId, /*Elide=*/true);
     NumOfCoroElided++;
 
@@ -459,11 +474,9 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
   if (!declaresCoroElideIntrinsics(M))
     return PreservedAnalyses::all();
 
-  Lowerer L(M);
-  L.CoroIds.clear();
-  L.collectPostSplitCoroIds(&F);
-  // If we did not find any coro.id, there is nothing to do.
-  if (L.CoroIds.empty())
+  FunctionElideManager FEM{&F};
+  // Elide is not necessary if there's no coro.id within the function.
+  if (!FEM.isElideNecessary())
     return PreservedAnalyses::all();
 
   AAResults &AA = AM.getResult<AAManager>(F);
@@ -471,8 +484,10 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 
   bool Changed = false;
-  for (auto *CII : L.CoroIds)
-    Changed |= L.processCoroId(CII, AA, DT, ORE);
+  for (auto *CII : FEM.getCoroIds()) {
+    CoroIdElider CIE(CII, FEM, AA, DT, ORE);
+    Changed |= CIE.attemptElide();
+  }
 
   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
 }

>From 1fe2f76d422216a1c0d09969959d07a0f8543ff8 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 8 May 2024 14:44:12 -0700
Subject: [PATCH 3/4] Avoid using the confusing CoroId->getFunction()

---
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 22 +++++++++-----------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index ef40479d646e3..9944469c17ad4 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -341,8 +341,6 @@ bool CoroIdElider::lifetimeEligibleForElide() const {
   if (CoroAllocs.empty())
     return false;
 
-  auto *ContainingFunction = CoroId->getFunction();
-
   // Check that for every coro.begin there is at least one coro.destroy directly
   // referencing the SSA value of that coro.begin along each
   // non-exceptional path.
@@ -354,7 +352,7 @@ bool CoroIdElider::lifetimeEligibleForElide() const {
   // First gather all of the terminators for the function.
   // Consider the final coro.suspend as the real terminator when the current
   // function is a coroutine.
-  for (BasicBlock &B : *ContainingFunction) {
+  for (BasicBlock &B : *FEM.ContainingFunction) {
     auto *TI = B.getTerminator();
 
     if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI))
@@ -422,6 +420,9 @@ bool CoroIdElider::attemptElide() {
 
   auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
 
+  auto CallerFunctionName = FEM.ContainingFunction->getName();
+  auto CalleeCoroutineName = CoroId->getCoroutine()->getName();
+
   if (EligibleForElide && FrameSizeAndAlign) {
     elideHeapAllocations(FrameSizeAndAlign->first, FrameSizeAndAlign->second);
     coro::replaceCoroFree(CoroId, /*Elide=*/true);
@@ -429,16 +430,14 @@ bool CoroIdElider::attemptElide() {
 
 #ifndef NDEBUG
       if (!CoroElideInfoOutputFilename.empty())
-        *getOrCreateLogFile()
-            << "Elide " << CoroId->getCoroutine()->getName() << " in "
-            << CoroId->getFunction()->getName() << "\n";
+        *getOrCreateLogFile() << "Elide " << CalleeCoroutineName << " in "
+                              << FEM.ContainingFunction->getName() << "\n";
 #endif
 
       ORE.emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
+               << "'" << ore::NV("callee", CalleeCoroutineName)
+               << "' elided in '" << ore::NV("caller", CallerFunctionName)
                << "' (frame_size="
                << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
                << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
@@ -446,10 +445,9 @@ bool CoroIdElider::attemptElide() {
   } else {
     ORE.emit([&]() {
       auto Remark = OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-                    << "'"
-                    << ore::NV("callee", CoroId->getCoroutine()->getName())
+                    << "'" << ore::NV("callee", CalleeCoroutineName)
                     << "' not elided in '"
-                    << ore::NV("caller", CoroId->getFunction()->getName());
+                    << ore::NV("caller", CallerFunctionName);
 
       if (FrameSizeAndAlign)
         return Remark << "' (frame_size="

>From 407aefa9e0adcd6ae4277c93328ad85e6722c9b0 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Wed, 8 May 2024 22:42:46 -0700
Subject: [PATCH 4/4] address comments

---
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 49 ++++++++++----------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index 9944469c17ad4..bb244489e4c2c 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -33,19 +33,20 @@ static cl::opt<std::string> CoroElideInfoOutputFilename(
 
 namespace {
 // Created on demand if the coro-elide pass has work to do.
-class FunctionElideManager {
+class FunctionElideInfo {
 public:
-  FunctionElideManager(Function *F) : ContainingFunction(F) {
+  FunctionElideInfo(Function *F) : ContainingFunction(F) {
     this->collectPostSplitCoroIds();
   }
 
-  bool isElideNecessary() const { return !CoroIds.empty(); }
+  bool hasCoroIds() const { return !CoroIds.empty(); }
 
   const SmallVectorImpl<CoroIdInst *> &getCoroIds() const { return CoroIds; }
 
 private:
   Function *ContainingFunction;
   SmallVector<CoroIdInst *, 4> CoroIds;
+  // Used in canCoroBeginEscape to distinguish coro.suspend switches.
   SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
 
   void collectPostSplitCoroIds();
@@ -54,17 +55,17 @@ class FunctionElideManager {
 
 class CoroIdElider {
 public:
-  CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM, AAResults &AA,
+  CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA,
                DominatorTree &DT, OptimizationRemarkEmitter &ORE);
   void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign);
   bool lifetimeEligibleForElide() const;
   bool attemptElide();
-  bool CoroBeginCanEscape(const CoroBeginInst *,
+  bool canCoroBeginEscape(const CoroBeginInst *,
                           const SmallPtrSetImpl<BasicBlock *> &) const;
 
 private:
   CoroIdInst *CoroId;
-  FunctionElideManager &FEM;
+  FunctionElideInfo &FEI;
   AAResults &AA;
   DominatorTree &DT;
   OptimizationRemarkEmitter &ORE;
@@ -158,7 +159,7 @@ static std::unique_ptr<raw_fd_ostream> getOrCreateLogFile() {
 }
 #endif
 
-void FunctionElideManager::collectPostSplitCoroIds() {
+void FunctionElideInfo::collectPostSplitCoroIds() {
   for (auto &I : instructions(this->ContainingFunction)) {
     if (auto *CII = dyn_cast<CoroIdInst>(&I))
       if (CII->getInfo().isPostSplit())
@@ -180,10 +181,10 @@ void FunctionElideManager::collectPostSplitCoroIds() {
   }
 }
 
-CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM,
+CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI,
                            AAResults &AA, DominatorTree &DT,
                            OptimizationRemarkEmitter &ORE)
-    : CoroId(CoroId), FEM(FEM), AA(AA), DT(DT), ORE(ORE) {
+    : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), ORE(ORE) {
   // Collect all coro.begin and coro.allocs associated with this coro.id.
   for (User *U : CoroId->users()) {
     if (auto *CB = dyn_cast<CoroBeginInst>(U))
@@ -215,9 +216,9 @@ CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM,
 // To elide heap allocations we need to suppress code blocks guarded by
 // llvm.coro.alloc and llvm.coro.free instructions.
 void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
-  LLVMContext &C = FEM.ContainingFunction->getContext();
+  LLVMContext &C = FEI.ContainingFunction->getContext();
   BasicBlock::iterator InsertPt =
-      getFirstNonAllocaInTheEntryBlock(FEM.ContainingFunction)->getIterator();
+      getFirstNonAllocaInTheEntryBlock(FEI.ContainingFunction)->getIterator();
 
   // Replacing llvm.coro.alloc with false will suppress dynamic
   // allocation as it is expected for the frontend to generate the code that
@@ -235,7 +236,7 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
   // is spilled into the coroutine frame and recreate the alignment information
   // here. Possibly we will need to do a mini SROA here and break the coroutine
   // frame into individual AllocaInst recreating the original alignment.
-  const DataLayout &DL = FEM.ContainingFunction->getParent()->getDataLayout();
+  const DataLayout &DL = FEI.ContainingFunction->getParent()->getDataLayout();
   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
   Frame->setAlignment(FrameAlign);
@@ -252,7 +253,7 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
   removeTailCallAttribute(Frame, AA);
 }
 
-bool CoroIdElider::CoroBeginCanEscape(
+bool CoroIdElider::canCoroBeginEscape(
     const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
   const auto &It = DestroyAddr.find(CB);
   assert(It != DestroyAddr.end());
@@ -322,7 +323,7 @@ bool CoroIdElider::CoroBeginCanEscape(
     // which means a escape path to normal terminator, it is reasonable to skip
     // it since coroutine frame doesn't change outside the coroutine body.
     if (isa<SwitchInst>(TI) &&
-        FEM.CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
+        FEI.CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
       Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1));
       Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2));
     } else
@@ -352,7 +353,7 @@ bool CoroIdElider::lifetimeEligibleForElide() const {
   // First gather all of the terminators for the function.
   // Consider the final coro.suspend as the real terminator when the current
   // function is a coroutine.
-  for (BasicBlock &B : *FEM.ContainingFunction) {
+  for (BasicBlock &B : *FEI.ContainingFunction) {
     auto *TI = B.getTerminator();
 
     if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI))
@@ -383,13 +384,13 @@ bool CoroIdElider::lifetimeEligibleForElide() const {
     if (llvm::all_of(Terminators, DominatesTerminator))
       continue;
 
-    // Otherwise CoroBeginCanEscape would decide whether there is any paths from
+    // Otherwise canCoroBeginEscape would decide whether there is any paths from
     // coro.begin to Terminators which not pass through any of the
     // coro.destroys. This is a slower analysis.
     //
-    // CoroBeginCanEscape is relatively slow, so we avoid to run it as much as
+    // canCoroBeginEscape is relatively slow, so we avoid to run it as much as
     // possible.
-    if (CoroBeginCanEscape(CB, Terminators))
+    if (canCoroBeginEscape(CB, Terminators))
       return false;
   }
 
@@ -420,7 +421,7 @@ bool CoroIdElider::attemptElide() {
 
   auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
 
-  auto CallerFunctionName = FEM.ContainingFunction->getName();
+  auto CallerFunctionName = FEI.ContainingFunction->getName();
   auto CalleeCoroutineName = CoroId->getCoroutine()->getName();
 
   if (EligibleForElide && FrameSizeAndAlign) {
@@ -431,7 +432,7 @@ bool CoroIdElider::attemptElide() {
 #ifndef NDEBUG
       if (!CoroElideInfoOutputFilename.empty())
         *getOrCreateLogFile() << "Elide " << CalleeCoroutineName << " in "
-                              << FEM.ContainingFunction->getName() << "\n";
+                              << FEI.ContainingFunction->getName() << "\n";
 #endif
 
       ORE.emit([&]() {
@@ -472,9 +473,9 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
   if (!declaresCoroElideIntrinsics(M))
     return PreservedAnalyses::all();
 
-  FunctionElideManager FEM{&F};
+  FunctionElideInfo FEI{&F};
   // Elide is not necessary if there's no coro.id within the function.
-  if (!FEM.isElideNecessary())
+  if (!FEI.hasCoroIds())
     return PreservedAnalyses::all();
 
   AAResults &AA = AM.getResult<AAManager>(F);
@@ -482,8 +483,8 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 
   bool Changed = false;
-  for (auto *CII : FEM.getCoroIds()) {
-    CoroIdElider CIE(CII, FEM, AA, DT, ORE);
+  for (auto *CII : FEI.getCoroIds()) {
+    CoroIdElider CIE(CII, FEI, AA, DT, ORE);
     Changed |= CIE.attemptElide();
   }
 



More information about the llvm-commits mailing list