[llvm] [Coroutines] Make CoroElide actually work for (trivial) C++ coroutines (PR #92969)

Yuxuan Chen via llvm-commits llvm-commits at lists.llvm.org
Tue May 21 16:05:13 PDT 2024


https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/92969

>From 3760552ebef1d77ee760a3a2fae11e1b34030a19 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Tue, 21 May 2024 12:39:26 -0700
Subject: [PATCH] wip

---
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 101 ++++++++++++++++---
 1 file changed, 88 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index bb244489e4c2c..96b9c1e4af9ab 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -13,8 +13,10 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FileSystem.h"
 #include <optional>
@@ -193,23 +195,51 @@ CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI,
       CoroAllocs.push_back(CA);
   }
 
+  DenseMap<AllocaInst *, CoroBeginInst *> Writes;
+  for (CoroBeginInst *CB : CoroBegins) {
+    for (User *U : CB->users()) {
+      if (auto *Store = dyn_cast<StoreInst>(U)) {
+        if (Store->getOperand(0) == CB) {
+          auto *Dest = Store->getOperand(1);
+          if (auto *AI = dyn_cast<AllocaInst>(Dest)) {
+            Writes[AI] = CB;
+          }
+        }
+      }
+    }
+  }
+
+  auto CheckCoroBeginUser = [&](CoroBeginInst *CB, User *U) {
+    if (auto *II = dyn_cast<CoroSubFnInst>(U))
+      switch (II->getIndex()) {
+      case CoroSubFnInst::ResumeIndex:
+        ResumeAddr.push_back(II);
+        break;
+      case CoroSubFnInst::DestroyIndex:
+        DestroyAddr[CB].push_back(II);
+        break;
+      default:
+        llvm_unreachable("unexpected coro.subfn.addr constant");
+      }
+  };
+
+  for (const auto &[Alloca, CB] : Writes) {
+    for (User *AU : Alloca->users()) {
+      auto *Load = dyn_cast<LoadInst>(AU);
+      if (Load && Load->getOperand(0) == Alloca) {
+        for (User *LU : Load->users())
+          CheckCoroBeginUser(CB, LU);
+      }
+    }
+  }
+
   // Collect all coro.subfn.addrs associated with coro.begin.
   // Note, we only devirtualize the calls if their coro.subfn.addr refers to
   // coro.begin directly. If we run into cases where this check is too
   // conservative, we can consider relaxing the check.
   for (CoroBeginInst *CB : CoroBegins) {
     for (User *U : CB->users())
-      if (auto *II = dyn_cast<CoroSubFnInst>(U))
-        switch (II->getIndex()) {
-        case CoroSubFnInst::ResumeIndex:
-          ResumeAddr.push_back(II);
-          break;
-        case CoroSubFnInst::DestroyIndex:
-          DestroyAddr[CB].push_back(II);
-          break;
-        default:
-          llvm_unreachable("unexpected coro.subfn.addr constant");
-        }
+      CheckCoroBeginUser(CB, U);
   }
 }
 
@@ -253,6 +283,44 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
   removeTailCallAttribute(Frame, AA);
 }
 
+/// Returns the sole predecessor of \p DestroyBB when that predecessor ends in
+/// a conditional branch that enters \p DestroyBB only if \p Handle compared
+/// non-null (`icmp ne` on the true edge, `icmp eq` on the false edge).
+/// Returns nullptr when no such guarding block is found.
+static BasicBlock *getNullCheckerPredecessor(Value *Handle,
+                                             BasicBlock *DestroyBB) {
+  auto *Predecessor = DestroyBB->getSinglePredecessor();
+  if (!Predecessor)
+    return nullptr;
+
+  // getCondition() is only valid on a conditional branch, so an unconditional
+  // fall-through into DestroyBB must be rejected before querying it.
+  auto *BrInst = dyn_cast_or_null<BranchInst>(Predecessor->getTerminator());
+  if (!BrInst || !BrInst->isConditional())
+    return nullptr;
+
+  auto *Cmp = dyn_cast<ICmpInst>(BrInst->getCondition());
+  if (!Cmp)
+    return nullptr;
+
+  // DestroyBB on the true edge needs `Handle != null`; on the false edge the
+  // comparison must be `Handle == null`.
+  bool CheckNe = BrInst->getSuccessor(0) == DestroyBB;
+  if (Cmp->getPredicate() != (CheckNe ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ))
+    return nullptr;
+
+  // Accept the operands in either order: one null constant, one Handle.
+  bool HasNullPtrOnOneSide = false;
+  bool HasHandleOnOther = false;
+  for (Value *Operand : Cmp->operands()) {
+    if (auto *Const = dyn_cast<Constant>(Operand))
+      HasNullPtrOnOneSide |= Const->isNullValue();
+    else if (Operand == Handle)
+      HasHandleOnOther = true;
+  }
+  return HasNullPtrOnOneSide && HasHandleOnOther ? Predecessor : nullptr;
+}
+
 bool CoroIdElider::canCoroBeginEscape(
     const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
   const auto &It = DestroyAddr.find(CB);
@@ -267,8 +335,15 @@ bool CoroIdElider::canCoroBeginEscape(
   SmallPtrSet<const BasicBlock *, 32> Visited;
   // Consider basicblock of coro.destroy as visited one, so that we
   // skip the path pass through coro.destroy.
-  for (auto *DA : It->second)
-    Visited.insert(DA->getParent());
+  for (auto *DA : It->second) {
+    auto *DestroyBB = DA->getParent();
+    auto *Handle = DA->getOperand(0);
+    Visited.insert(DestroyBB);
+
+    if (auto *Pred = getNullCheckerPredecessor(Handle, DestroyBB)) {
+      Visited.insert(Pred);
+    }
+  }
 
   SmallPtrSet<const BasicBlock *, 32> EscapingBBs;
   for (auto *U : CB->users()) {



More information about the llvm-commits mailing list