[llvm] [Coroutines] Make CoroElide actually work for (trivial) C++ coroutines (PR #92969)
Yuxuan Chen via llvm-commits
llvm-commits at lists.llvm.org
Tue May 21 16:05:13 PDT 2024
https://github.com/yuxuanchen1997 updated https://github.com/llvm/llvm-project/pull/92969
>From 3760552ebef1d77ee760a3a2fae11e1b34030a19 Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Tue, 21 May 2024 12:39:26 -0700
Subject: [PATCH] wip
---
llvm/lib/Transforms/Coroutines/CoroElide.cpp | 101 ++++++++++++++++---
1 file changed, 88 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index bb244489e4c2c..96b9c1e4af9ab 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -13,8 +13,10 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include <optional>
@@ -193,23 +195,51 @@ CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI,
CoroAllocs.push_back(CA);
}
+ DenseMap<AllocaInst *, CoroBeginInst *> Writes;
+ for (CoroBeginInst *CB : CoroBegins) {
+ for (User *U : CB->users()) {
+ if (auto *Store = dyn_cast<StoreInst>(U)) {
+ if (Store->getOperand(0) == CB) {
+ auto *Dest = Store->getOperand(1);
+ if (auto *AI = dyn_cast<AllocaInst>(Dest)) {
+ Writes[AI] = CB;
+ }
+ }
+ }
+ }
+ }
+
+ auto CheckCoroBeginUser = [&](CoroBeginInst *CB, User *U) {
+ if (auto *II = dyn_cast<CoroSubFnInst>(U))
+ switch (II->getIndex()) {
+ case CoroSubFnInst::ResumeIndex:
+ ResumeAddr.push_back(II);
+ break;
+ case CoroSubFnInst::DestroyIndex:
+ DestroyAddr[CB].push_back(II);
+ break;
+ default:
+ llvm_unreachable("unexpected coro.subfn.addr constant");
+ }
+ };
+
+ for (const auto &[Alloca, CB] : Writes) {
+ for (User *AU : Alloca->users()) {
+ auto *Load = dyn_cast<LoadInst>(AU);
+ if (Load && Load->getOperand(0) == Alloca) {
+ for (User *LU : Load->users())
+ CheckCoroBeginUser(CB, LU);
+ }
+ }
+ }
+
// Collect all coro.subfn.addrs associated with coro.begin.
// Note, we only devirtualize the calls if their coro.subfn.addr refers to
// coro.begin directly. If we run into cases where this check is too
// conservative, we can consider relaxing the check.
for (CoroBeginInst *CB : CoroBegins) {
for (User *U : CB->users())
- if (auto *II = dyn_cast<CoroSubFnInst>(U))
- switch (II->getIndex()) {
- case CoroSubFnInst::ResumeIndex:
- ResumeAddr.push_back(II);
- break;
- case CoroSubFnInst::DestroyIndex:
- DestroyAddr[CB].push_back(II);
- break;
- default:
- llvm_unreachable("unexpected coro.subfn.addr constant");
- }
+ CheckCoroBeginUser(CB, U);
}
}
@@ -253,6 +283,44 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
removeTailCallAttribute(Frame, AA);
}
+/// If \p DestroyBB is reached only from a block whose conditional branch
+/// tests \p Handle against null, return that predecessor block.
+///
+/// Accepted shapes (operands of the icmp may appear in either order):
+///   br (icmp ne Handle, null), DestroyBB, Other
+///   br (icmp eq Handle, null), Other, DestroyBB
+/// Returns nullptr in every other case, including when DestroyBB has more
+/// than one predecessor edge or the predecessor ends in an unconditional
+/// branch.
+static BasicBlock *getNullCheckerPredecessor(Value *Handle,
+                                             BasicBlock *DestroyBB) {
+  auto *Predecessor = DestroyBB->getSinglePredecessor();
+  if (!Predecessor)
+    return nullptr;
+
+  // Only a conditional branch has a condition; BranchInst::getCondition()
+  // asserts (and reads a nonexistent operand in release builds) otherwise.
+  auto *BrInst = dyn_cast_or_null<BranchInst>(Predecessor->getTerminator());
+  if (!BrInst || !BrInst->isConditional())
+    return nullptr;
+
+  // DestroyBB as the "true" successor requires `Handle != null`; as the
+  // "false" successor it requires `Handle == null`.
+  bool CheckNe = BrInst->getSuccessor(0) == DestroyBB;
+  auto *Cmp = dyn_cast<ICmpInst>(BrInst->getCondition());
+  if (!Cmp)
+    return nullptr;
+
+  CmpInst::Predicate RequiredPred =
+      CheckNe ? CmpInst::Predicate::ICMP_NE : CmpInst::Predicate::ICMP_EQ;
+  if (Cmp->getPredicate() != RequiredPred)
+    return nullptr;
+
+  // One operand must be a null constant and the other must be Handle.
+  bool HasNullPtrOnOneSide = false;
+  bool HasHandleOnOther = false;
+  for (Value *Operand : Cmp->operands()) {
+    if (auto *Const = dyn_cast<Constant>(Operand)) {
+      if (Const->isNullValue())
+        HasNullPtrOnOneSide = true;
+    } else if (Operand == Handle) {
+      HasHandleOnOther = true;
+    }
+  }
+
+  return (HasNullPtrOnOneSide && HasHandleOnOther) ? Predecessor : nullptr;
+}
+
bool CoroIdElider::canCoroBeginEscape(
const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
const auto &It = DestroyAddr.find(CB);
@@ -267,8 +335,15 @@ bool CoroIdElider::canCoroBeginEscape(
SmallPtrSet<const BasicBlock *, 32> Visited;
// Consider basicblock of coro.destroy as visited one, so that we
// skip the path pass through coro.destroy.
- for (auto *DA : It->second)
- Visited.insert(DA->getParent());
+ for (auto *DA : It->second) {
+ auto *DestroyBB = DA->getParent();
+ auto *Handle = DA->getOperand(0);
+ Visited.insert(DestroyBB);
+
+ if (auto *Pred = getNullCheckerPredecessor(Handle, DestroyBB)) {
+ Visited.insert(Pred);
+ }
+ }
SmallPtrSet<const BasicBlock *, 32> EscapingBBs;
for (auto *U : CB->users()) {
More information about the llvm-commits
mailing list