[llvm] [NFC] CoroElide: Refactor `Lowerer` into `CoroIdElider` (PR #91539)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 14:46:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yuxuan Chen (yuxuanchen1997)

<details>
<summary>Changes</summary>

This patch contains no functional changes. 

The main goal of this patch is to make the code clearer and to make its intentions and assumptions explicit.

One major design problem I had in the past was `Lowerer`. It inherited from `coro::LowererBase` but did not use any of `LowererBase`'s fields or methods; this is most likely an artifact left over from an earlier design of this code.

Furthermore, although a single `Lowerer` instance was created per function, it actually processed one `CoroId` instruction at a time. To make that work, it relied on fragile sequences such as `CoroBegins.clear(); DestroyAddr.clear()` to reset the per-`CoroId` state between iterations, which makes the code harder to understand.

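Worse, there are confusing calls like `elideHeapAllocations(CoroId->getFunction(), ...)`. Here `CoroId->getFunction()` (the caller that contains the `coro.id`) is easily confused with `CoroId->getCoroutine()` (the coroutine whose allocation is being elided).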

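The new structure makes it clear that we always operate on one `CoroId` at a time, and that one `CoroId` may have multiple `CoroBegin`s. Function-wide state (the post-split `coro.id`s and the `coro.suspend` switches) now lives in `FunctionElideManager`, while per-`coro.id` state lives in a fresh `CoroIdElider` for each `coro.id`. As a result, the code no longer relies on frequent `.clear()` calls, which are easy to forget.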

---
Full diff: https://github.com/llvm/llvm-project/pull/91539.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/Coroutines/CoroElide.cpp (+168-149) 


``````````diff
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index d356a6d2e5759..9944469c17ad4 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -33,24 +33,46 @@ static cl::opt<std::string> CoroElideInfoOutputFilename(
 
 namespace {
 // Created on demand if the coro-elide pass has work to do.
-struct Lowerer : coro::LowererBase {
+class FunctionElideManager {
+public:
+  FunctionElideManager(Function *F) : ContainingFunction(F) {
+    this->collectPostSplitCoroIds();
+  }
+
+  bool isElideNecessary() const { return !CoroIds.empty(); }
+
+  const SmallVectorImpl<CoroIdInst *> &getCoroIds() const { return CoroIds; }
+
+private:
+  Function *ContainingFunction;
   SmallVector<CoroIdInst *, 4> CoroIds;
+  SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
+
+  void collectPostSplitCoroIds();
+  friend class CoroIdElider;
+};
+
+class CoroIdElider {
+public:
+  CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM, AAResults &AA,
+               DominatorTree &DT, OptimizationRemarkEmitter &ORE);
+  void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign);
+  bool lifetimeEligibleForElide() const;
+  bool attemptElide();
+  bool CoroBeginCanEscape(const CoroBeginInst *,
+                          const SmallPtrSetImpl<BasicBlock *> &) const;
+
+private:
+  CoroIdInst *CoroId;
+  FunctionElideManager &FEM;
+  AAResults &AA;
+  DominatorTree &DT;
+  OptimizationRemarkEmitter &ORE;
+
   SmallVector<CoroBeginInst *, 1> CoroBegins;
   SmallVector<CoroAllocInst *, 1> CoroAllocs;
   SmallVector<CoroSubFnInst *, 4> ResumeAddr;
   DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
-  SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
-
-  Lowerer(Module &M) : LowererBase(M) {}
-
-  void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign,
-                            AAResults &AA);
-  bool shouldElide(Function *F, DominatorTree &DT) const;
-  void collectPostSplitCoroIds(Function *F);
-  bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT,
-                     OptimizationRemarkEmitter &ORE);
-  bool hasEscapePath(const CoroBeginInst *,
-                     const SmallPtrSetImpl<BasicBlock *> &) const;
 };
 } // end anonymous namespace
 
@@ -136,14 +158,66 @@ static std::unique_ptr<raw_fd_ostream> getOrCreateLogFile() {
 }
 #endif
 
+void FunctionElideManager::collectPostSplitCoroIds() {
+  for (auto &I : instructions(this->ContainingFunction)) {
+    if (auto *CII = dyn_cast<CoroIdInst>(&I))
+      if (CII->getInfo().isPostSplit())
+        // If it is the coroutine itself, don't touch it.
+        if (CII->getCoroutine() != CII->getFunction())
+          CoroIds.push_back(CII);
+
+    // Consider case like:
+    // %0 = call i8 @llvm.coro.suspend(...)
+    // switch i8 %0, label %suspend [i8 0, label %resume
+    //                              i8 1, label %cleanup]
+    // and collect the SwitchInsts which are used by escape analysis later.
+    if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
+      if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
+        SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
+        if (SWI->getNumCases() == 2)
+          CoroSuspendSwitches.insert(SWI);
+      }
+  }
+}
+
+CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideManager &FEM,
+                           AAResults &AA, DominatorTree &DT,
+                           OptimizationRemarkEmitter &ORE)
+    : CoroId(CoroId), FEM(FEM), AA(AA), DT(DT), ORE(ORE) {
+  // Collect all coro.begin and coro.allocs associated with this coro.id.
+  for (User *U : CoroId->users()) {
+    if (auto *CB = dyn_cast<CoroBeginInst>(U))
+      CoroBegins.push_back(CB);
+    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
+      CoroAllocs.push_back(CA);
+  }
+
+  // Collect all coro.subfn.addrs associated with coro.begin.
+  // Note, we only devirtualize the calls if their coro.subfn.addr refers to
+  // coro.begin directly. If we run into cases where this check is too
+  // conservative, we can consider relaxing the check.
+  for (CoroBeginInst *CB : CoroBegins) {
+    for (User *U : CB->users())
+      if (auto *II = dyn_cast<CoroSubFnInst>(U))
+        switch (II->getIndex()) {
+        case CoroSubFnInst::ResumeIndex:
+          ResumeAddr.push_back(II);
+          break;
+        case CoroSubFnInst::DestroyIndex:
+          DestroyAddr[CB].push_back(II);
+          break;
+        default:
+          llvm_unreachable("unexpected coro.subfn.addr constant");
+        }
+  }
+}
+
 // To elide heap allocations we need to suppress code blocks guarded by
 // llvm.coro.alloc and llvm.coro.free instructions.
-void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
-                                   Align FrameAlign, AAResults &AA) {
-  LLVMContext &C = F->getContext();
+void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
+  LLVMContext &C = FEM.ContainingFunction->getContext();
   BasicBlock::iterator InsertPt =
-      getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction())
-          ->getIterator();
+      getFirstNonAllocaInTheEntryBlock(FEM.ContainingFunction)->getIterator();
 
   // Replacing llvm.coro.alloc with false will suppress dynamic
   // allocation as it is expected for the frontend to generate the code that
@@ -161,7 +235,7 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
   // is spilled into the coroutine frame and recreate the alignment information
   // here. Possibly we will need to do a mini SROA here and break the coroutine
   // frame into individual AllocaInst recreating the original alignment.
-  const DataLayout &DL = F->getParent()->getDataLayout();
+  const DataLayout &DL = FEM.ContainingFunction->getParent()->getDataLayout();
   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
   Frame->setAlignment(FrameAlign);
@@ -178,8 +252,8 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
   removeTailCallAttribute(Frame, AA);
 }
 
-bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
-                            const SmallPtrSetImpl<BasicBlock *> &TIs) const {
+bool CoroIdElider::CoroBeginCanEscape(
+    const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
   const auto &It = DestroyAddr.find(CB);
   assert(It != DestroyAddr.end());
 
@@ -248,7 +322,7 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
     // which means a escape path to normal terminator, it is reasonable to skip
     // it since coroutine frame doesn't change outside the coroutine body.
     if (isa<SwitchInst>(TI) &&
-        CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
+        FEM.CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
       Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1));
       Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2));
     } else
@@ -261,7 +335,7 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
   return false;
 }
 
-bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
+bool CoroIdElider::lifetimeEligibleForElide() const {
   // If no CoroAllocs, we cannot suppress allocation, so elision is not
   // possible.
   if (CoroAllocs.empty())
@@ -270,6 +344,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
   // Check that for every coro.begin there is at least one coro.destroy directly
   // referencing the SSA value of that coro.begin along each
   // non-exceptional path.
+  //
   // If the value escaped, then coro.destroy would have been referencing a
   // memory location storing that value and not the virtual register.
 
@@ -277,7 +352,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
   // First gather all of the terminators for the function.
   // Consider the final coro.suspend as the real terminator when the current
   // function is a coroutine.
-  for (BasicBlock &B : *F) {
+  for (BasicBlock &B : *FEM.ContainingFunction) {
     auto *TI = B.getTerminator();
 
     if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI))
@@ -287,91 +362,43 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
   }
 
   // Filter out the coro.destroy that lie along exceptional paths.
-  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
-  for (const auto &It : DestroyAddr) {
+  for (const auto *CB : CoroBegins) {
+    auto It = DestroyAddr.find(CB);
+
+    // FIXME: If we have not found any destroys for this coro.begin, we
+    // disqualify this elide.
+    if (It == DestroyAddr.end())
+      return false;
+
+    const auto &CorrespondingDestroyAddrs = It->second;
+
     // If every terminators is dominated by coro.destroy, we could know the
     // corresponding coro.begin wouldn't escape.
-    //
-    // Otherwise hasEscapePath would decide whether there is any paths from
+    auto DominatesTerminator = [&](auto *TI) {
+      return llvm::any_of(CorrespondingDestroyAddrs, [&](auto *Destroy) {
+        return DT.dominates(Destroy, TI->getTerminator());
+      });
+    };
+
+    if (llvm::all_of(Terminators, DominatesTerminator))
+      continue;
+
+    // Otherwise CoroBeginCanEscape would decide whether there is any paths from
     // coro.begin to Terminators which not pass through any of the
-    // coro.destroys.
+    // coro.destroys. This is a slower analysis.
     //
-    // hasEscapePath is relatively slow, so we avoid to run it as much as
+    // CoroBeginCanEscape is relatively slow, so we avoid to run it as much as
     // possible.
-    if (llvm::all_of(Terminators,
-                     [&](auto *TI) {
-                       return llvm::any_of(It.second, [&](auto *DA) {
-                         return DT.dominates(DA, TI->getTerminator());
-                       });
-                     }) ||
-        !hasEscapePath(It.first, Terminators))
-      ReferencedCoroBegins.insert(It.first);
+    if (CoroBeginCanEscape(CB, Terminators))
+      return false;
   }
 
-  // If size of the set is the same as total number of coro.begin, that means we
-  // found a coro.free or coro.destroy referencing each coro.begin, so we can
-  // perform heap elision.
-  return ReferencedCoroBegins.size() == CoroBegins.size();
-}
-
-void Lowerer::collectPostSplitCoroIds(Function *F) {
-  CoroIds.clear();
-  CoroSuspendSwitches.clear();
-  for (auto &I : instructions(F)) {
-    if (auto *CII = dyn_cast<CoroIdInst>(&I))
-      if (CII->getInfo().isPostSplit())
-        // If it is the coroutine itself, don't touch it.
-        if (CII->getCoroutine() != CII->getFunction())
-          CoroIds.push_back(CII);
-
-    // Consider case like:
-    // %0 = call i8 @llvm.coro.suspend(...)
-    // switch i8 %0, label %suspend [i8 0, label %resume
-    //                              i8 1, label %cleanup]
-    // and collect the SwitchInsts which are used by escape analysis later.
-    if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
-      if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
-        SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
-        if (SWI->getNumCases() == 2)
-          CoroSuspendSwitches.insert(SWI);
-      }
-  }
+  // We have checked all CoroBegins and their paths to the terminators without
+  // finding disqualifying code patterns, so we can perform heap allocations.
+  return true;
 }
 
-bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
-                            DominatorTree &DT, OptimizationRemarkEmitter &ORE) {
-  CoroBegins.clear();
-  CoroAllocs.clear();
-  ResumeAddr.clear();
-  DestroyAddr.clear();
-
-  // Collect all coro.begin and coro.allocs associated with this coro.id.
-  for (User *U : CoroId->users()) {
-    if (auto *CB = dyn_cast<CoroBeginInst>(U))
-      CoroBegins.push_back(CB);
-    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
-      CoroAllocs.push_back(CA);
-  }
-
-  // Collect all coro.subfn.addrs associated with coro.begin.
-  // Note, we only devirtualize the calls if their coro.subfn.addr refers to
-  // coro.begin directly. If we run into cases where this check is too
-  // conservative, we can consider relaxing the check.
-  for (CoroBeginInst *CB : CoroBegins) {
-    for (User *U : CB->users())
-      if (auto *II = dyn_cast<CoroSubFnInst>(U))
-        switch (II->getIndex()) {
-        case CoroSubFnInst::ResumeIndex:
-          ResumeAddr.push_back(II);
-          break;
-        case CoroSubFnInst::DestroyIndex:
-          DestroyAddr[CB].push_back(II);
-          break;
-        default:
-          llvm_unreachable("unexpected coro.subfn.addr constant");
-        }
-  }
-
+bool CoroIdElider::attemptElide() {
   // PostSplit coro.id refers to an array of subfunctions in its Info
   // argument.
   ConstantArray *Resumers = CoroId->getInfo().Resumers;
@@ -382,63 +409,55 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
 
   replaceWithConstant(ResumeAddrConstant, ResumeAddr);
 
-  bool ShouldElide = shouldElide(CoroId->getFunction(), DT);
-  if (!ShouldElide)
-    ORE.emit([&]() {
-      if (auto FrameSizeAndAlign =
-              getFrameLayout(cast<Function>(ResumeAddrConstant)))
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' not elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
-               << "' (frame_size="
-               << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
-               << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
-      else
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' not elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
-               << "' (frame_size=unknown, align=unknown)";
-    });
+  bool EligibleForElide = lifetimeEligibleForElide();
 
   auto *DestroyAddrConstant = Resumers->getAggregateElement(
-      ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
+      EligibleForElide ? CoroSubFnInst::CleanupIndex
+                       : CoroSubFnInst::DestroyIndex);
 
   for (auto &It : DestroyAddr)
     replaceWithConstant(DestroyAddrConstant, It.second);
 
-  if (ShouldElide) {
-    if (auto FrameSizeAndAlign =
-            getFrameLayout(cast<Function>(ResumeAddrConstant))) {
-      elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first,
-                           FrameSizeAndAlign->second, AA);
-      coro::replaceCoroFree(CoroId, /*Elide=*/true);
-      NumOfCoroElided++;
+  auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
+
+  auto CallerFunctionName = FEM.ContainingFunction->getName();
+  auto CalleeCoroutineName = CoroId->getCoroutine()->getName();
+
+  if (EligibleForElide && FrameSizeAndAlign) {
+    elideHeapAllocations(FrameSizeAndAlign->first, FrameSizeAndAlign->second);
+    coro::replaceCoroFree(CoroId, /*Elide=*/true);
+    NumOfCoroElided++;
+
 #ifndef NDEBUG
       if (!CoroElideInfoOutputFilename.empty())
-        *getOrCreateLogFile()
-            << "Elide " << CoroId->getCoroutine()->getName() << " in "
-            << CoroId->getFunction()->getName() << "\n";
+        *getOrCreateLogFile() << "Elide " << CalleeCoroutineName << " in "
+                              << FEM.ContainingFunction->getName() << "\n";
 #endif
+
       ORE.emit([&]() {
         return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
+               << "'" << ore::NV("callee", CalleeCoroutineName)
+               << "' elided in '" << ore::NV("caller", CallerFunctionName)
                << "' (frame_size="
                << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align="
                << ore::NV("align", FrameSizeAndAlign->second.value()) << ")";
       });
-    } else {
-      ORE.emit([&]() {
-        return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
-               << "'" << ore::NV("callee", CoroId->getCoroutine()->getName())
-               << "' not elided in '"
-               << ore::NV("caller", CoroId->getFunction()->getName())
-               << "' (frame_size=unknown, align=unknown)";
-      });
-    }
+  } else {
+    ORE.emit([&]() {
+      auto Remark = OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId)
+                    << "'" << ore::NV("callee", CalleeCoroutineName)
+                    << "' not elided in '"
+                    << ore::NV("caller", CallerFunctionName);
+
+      if (FrameSizeAndAlign)
+        return Remark << "' (frame_size="
+                      << ore::NV("frame_size", FrameSizeAndAlign->first)
+                      << ", align="
+                      << ore::NV("align", FrameSizeAndAlign->second.value())
+                      << ")";
+      else
+        return Remark << "' (frame_size=unknown, align=unknown)";
+    });
   }
 
   return true;
@@ -453,11 +472,9 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
   if (!declaresCoroElideIntrinsics(M))
     return PreservedAnalyses::all();
 
-  Lowerer L(M);
-  L.CoroIds.clear();
-  L.collectPostSplitCoroIds(&F);
-  // If we did not find any coro.id, there is nothing to do.
-  if (L.CoroIds.empty())
+  FunctionElideManager FEM{&F};
+  // Elide is not necessary if there's no coro.id within the function.
+  if (!FEM.isElideNecessary())
     return PreservedAnalyses::all();
 
   AAResults &AA = AM.getResult<AAManager>(F);
@@ -465,8 +482,10 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 
   bool Changed = false;
-  for (auto *CII : L.CoroIds)
-    Changed |= L.processCoroId(CII, AA, DT, ORE);
+  for (auto *CII : FEM.getCoroIds()) {
+    CoroIdElider CIE(CII, FEM, AA, DT, ORE);
+    Changed |= CIE.attemptElide();
+  }
 
   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/91539


More information about the llvm-commits mailing list