[llvm] [Coroutines] Make CoroElide actually work for (trivial) C++ coroutines (PR #92969)
Yuxuan Chen via llvm-commits
llvm-commits at lists.llvm.org
Tue May 21 15:19:10 PDT 2024
https://github.com/yuxuanchen1997 created https://github.com/llvm/llvm-project/pull/92969
## Motivation
`Task` is widely used in C++ coroutines. We have received various reports that `CoroElide` does not work in coroutines that actually `co_await` a `Task`, and further inspection revealed multiple issues.
Consider this trivial task type:
```c++
#include <coroutine>

struct Task {
  struct promise_type {
    struct FinalAwaiter {
      bool await_ready() const noexcept { return false; }
      // Symmetric transfer to the awaiting coroutine, if there is one.
      template <typename P>
      std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept {
        if (auto next = coro.promise().continuation)
          return next;
        return std::noop_coroutine();
      }
      void await_resume() noexcept {}
    };
    Task get_return_object() noexcept {
      return std::coroutine_handle<promise_type>::from_promise(*this);
    }
    std::suspend_always initial_suspend() noexcept { return {}; }
    FinalAwaiter final_suspend() noexcept { return {}; }
    void unhandled_exception() noexcept {}
    void return_value(int x) noexcept { value = x; }

    std::coroutine_handle<> continuation;
    int value;
  };

  // The handle is stored into the Task here...
  Task(std::coroutine_handle<promise_type> handle) : handle(handle) {}
  // ...and reloaded from the Task here before it is destroyed.
  ~Task() {
    if (handle)
      handle.destroy();
  }

  struct Awaiter {
    Awaiter(Task *t) : task(t) {}
    bool await_ready() const noexcept { return false; }
    // Intentionally trivial: neither records nor resumes anything.
    void await_suspend(std::coroutine_handle<void> continuation) noexcept {}
    int await_resume() noexcept { return 43; }
    Task *task;
  };
  auto operator co_await() { return Awaiter{this}; }

private:
  std::coroutine_handle<promise_type> handle;
};
```
This patch addresses the two most obvious issues in `CoroElide` that prevent elision in simple, inlined `co_await` cases:
1. The CoroElide analysis requires that a coroutine frame pointer be destroyed through `subfn.addr` (which is generally just `handle.destroy()`). However, in C++ the pointer is often first stored into the `Task` and then reloaded from the `Task` before destruction; this is what the constructor and destructor of the `Task` type compile to. This patch replaces all uses of such loads with the `coro.begin` value (RAUW) when the pointer was stored to an `alloca` and the `coro.begin` dominates the load.
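2. Strengthen `canCoroBeginEscape`: when the block containing the destroy is guarded by a null check on the handle, also treat that checking predecessor block as visited, so the path that skips destruction is no longer counted as an escape.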
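Issue 1 is essentially store-to-load forwarding of the `coro.begin` value through a stack slot. Below is a minimal sketch on a hypothetical straight-line mini-IR; `Op`, `Inst`, and `forwardHandleLoads` are invented for illustration and are not LLVM APIs. It is deliberately more conservative than the patch, which checks that the `coro.begin` dominates the load: here a load is forwarded only if its slot has exactly one store, that store writes a `coro.begin` value, and the store comes before the load.

```c++
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical mini-IR, for illustration only (not LLVM's API).
// Each instruction's index in the vector is also its value id.
enum class Op { CoroBegin, Null, Alloca, Store, Load, Destroy };

struct Inst {
  Op op;
  int a = -1; // Store: stored value; Load: slot; Destroy: handle
  int b = -1; // Store: destination slot; unused otherwise
};

// Rewrites uses of a load from a slot to the coro.begin value stored there,
// provided the slot has exactly one store, that store writes a coro.begin
// value, and the store precedes the load (straight-line code, so "precedes"
// stands in for dominance). Returns the number of loads forwarded.
std::size_t forwardHandleLoads(std::vector<Inst> &F) {
  std::map<int, int> StoreCount; // slot -> number of stores into it
  std::map<int, int> StorePos;   // slot -> index of the last store seen
  for (int I = 0; I < static_cast<int>(F.size()); ++I) {
    if (F[I].op == Op::Store) {
      ++StoreCount[F[I].b];
      StorePos[F[I].b] = I;
    }
  }

  std::map<int, int> Replacement; // load id -> coro.begin id
  for (int I = 0; I < static_cast<int>(F.size()); ++I) {
    if (F[I].op != Op::Load)
      continue;
    int Slot = F[I].a;
    if (StoreCount[Slot] != 1)
      continue; // never written, or written more than once: give up
    const Inst &St = F[StorePos[Slot]];
    if (F[St.a].op != Op::CoroBegin)
      continue; // the stored value is not a coro.begin
    if (StorePos[Slot] > I)
      continue; // the store must come before the load
    Replacement[I] = St.a;
  }

  // RAUW: rewrite value operands that refer to a forwarded load.
  for (Inst &In : F) {
    if (In.op == Op::Store || In.op == Op::Destroy) {
      auto It = Replacement.find(In.a);
      if (It != Replacement.end())
        In.a = It->second;
    }
  }
  return Replacement.size();
}
```

With `coro.begin` stored once into a slot and reloaded for the destroy, the destroy's operand becomes the `coro.begin` value directly. A second store of a null handle into the same slot, as a move constructor would write into the source `Task`, prevents forwarding, which is one possible way to address the move-constructor concern in the TODO list.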
TODO:
1. The `Task` type's storage for the coroutine handle must stay constant for every instance of `Task`. This currently does not hold for move constructors. Make the pass RAUW only if the task has not been tampered with (e.g. by a move or by `Awaiter::await_suspend`).
2. Revisit `getNullCheckerPredecessor`. It works well enough for the following pattern:
```c++
if (handle)
  handle.destroy();
```
However, we need a more robust approach than brute-force pattern matching.
3. Add unit tests.
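The pattern that `getNullCheckerPredecessor` recognizes can be stated as a small predicate: the destroy block must be entered exactly on the edge where `handle != null`. The sketch below uses a hypothetical compare-and-branch model (`Pred`, `Cmp`, `CondBr`, `kNull`, and `isNullCheckGuard` are invented names, not LLVM types) and mirrors the patch's `CheckNe` logic: `ne` when the destroy block is the true successor, `eq` when it is the false successor, with the handle and null on either side of the compare.

```c++
#include <cassert>

// Hypothetical model for illustration only; these are not LLVM types.
enum class Pred { EQ, NE, Other };

constexpr int kNull = -1; // value id standing in for the null pointer constant

struct Cmp {
  Pred pred;
  int lhs; // value ids
  int rhs;
};

struct CondBr {
  Cmp cond;
  int trueSucc; // block ids
  int falseSucc;
};

// True if `br` enters `destroyBB` only when `handle` is non-null.
bool isNullCheckGuard(const CondBr &br, int handle, int destroyBB) {
  if (br.trueSucc == br.falseSucc)
    return false; // both edges lead to the same block: not a guard
  bool destroyOnTrue = br.trueSucc == destroyBB;
  if (!destroyOnTrue && br.falseSucc != destroyBB)
    return false; // the branch does not lead to the destroy block
  // "if (h != null) destroy" or "if (h == null) skip; else destroy"
  Pred required = destroyOnTrue ? Pred::NE : Pred::EQ;
  if (br.cond.pred != required)
    return false;
  // The handle and null may appear on either side of the compare.
  return (br.cond.lhs == handle && br.cond.rhs == kNull) ||
         (br.cond.lhs == kNull && br.cond.rhs == handle);
}
```

The model only represents conditional branches. In the patch itself, a single predecessor ending in an unconditional `br` still passes the `dyn_cast_or_null<BranchInst>` check and then reaches `BrInst->getCondition()`, which asserts `isConditional()`, so an `isConditional()` check before that call looks necessary.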
From cc83350224d56525463754dbffa57db1756e8eea Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Tue, 21 May 2024 12:39:26 -0700
Subject: [PATCH] wip
---
llvm/lib/Transforms/Coroutines/CoroElide.cpp | 76 +++++++++++++++++++-
1 file changed, 74 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index bb244489e4c2c..918051be0e072 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -13,8 +13,10 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include <optional>
@@ -193,6 +195,31 @@ CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI,
CoroAllocs.push_back(CA);
}
+ DenseMap<AllocaInst *, CoroBeginInst *> Writes;
+ for (CoroBeginInst *CB : CoroBegins) {
+ for (User *U : CB->users()) {
+ if (auto *Store = dyn_cast<StoreInst>(U)) {
+ if (Store->getOperand(0) == CB) {
+ auto *Dest = Store->getOperand(1);
+ if (auto *AI = dyn_cast<AllocaInst>(Dest)) {
+ Writes[AI] = CB;
+ }
+ }
+ }
+ }
+ }
+
+ for (Instruction &Inst : instructions(FEI.ContainingFunction)) {
+ if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
+ if (auto *AI = dyn_cast<AllocaInst>(Load->getOperand(0))) {
+ auto It = Writes.find(AI);
+ if (It != Writes.end() && DT.dominates(It->second, Load)) {
+ Load->replaceAllUsesWith(It->second);
+ }
+ }
+ }
+ }
+
// Collect all coro.subfn.addrs associated with coro.begin.
// Note, we only devirtualize the calls if their coro.subfn.addr refers to
// coro.begin directly. If we run into cases where this check is too
@@ -253,6 +280,44 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
removeTailCallAttribute(Frame, AA);
}
+static BasicBlock *getNullCheckerPredecessor(Value *Handle,
+ BasicBlock *DestroyBB) {
+ auto *Predecessor = DestroyBB->getSinglePredecessor();
+ if (!Predecessor)
+ return nullptr;
+
+ auto *BrInst = dyn_cast_or_null<BranchInst>(Predecessor->getTerminator());
+ if (!BrInst)
+ return nullptr;
+
+ auto CheckNe = BrInst->getSuccessor(0) == DestroyBB;
+ auto *Cond = BrInst->getCondition();
+ assert(Cond);
+ if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
+ bool HasRequiredPredicate =
+ Cmp->getSignedPredicate() ==
+ (CheckNe ? CmpInst::Predicate::ICMP_NE : CmpInst::Predicate::ICMP_EQ);
+ bool HasNullPtrOnOneSide = false;
+ bool HasHandleOnOther = false;
+
+ for (auto I : {0, 1}) {
+ auto Operand = Cmp->getOperand(I);
+ if (auto *Const = dyn_cast<Constant>(Operand)) {
+ if (Const->isNullValue())
+ HasNullPtrOnOneSide = true;
+ } else if (Operand == Handle) {
+ HasHandleOnOther = true;
+ }
+ }
+
+ if (HasRequiredPredicate && HasNullPtrOnOneSide && HasHandleOnOther) {
+ return Predecessor;
+ }
+ }
+
+ return nullptr;
+}
+
bool CoroIdElider::canCoroBeginEscape(
const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
const auto &It = DestroyAddr.find(CB);
@@ -267,8 +332,15 @@ bool CoroIdElider::canCoroBeginEscape(
SmallPtrSet<const BasicBlock *, 32> Visited;
// Consider basicblock of coro.destroy as visited one, so that we
// skip the path pass through coro.destroy.
- for (auto *DA : It->second)
- Visited.insert(DA->getParent());
+ for (auto *DA : It->second) {
+ auto *DestroyBB = DA->getParent();
+ auto *Handle = DA->getOperand(0);
+ Visited.insert(DestroyBB);
+
+ if (auto *Pred = getNullCheckerPredecessor(Handle, DestroyBB)) {
+ Visited.insert(Pred);
+ }
+ }
SmallPtrSet<const BasicBlock *, 32> EscapingBBs;
for (auto *U : CB->users()) {
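The effect of the `canCoroBeginEscape` change can be illustrated with a plain reachability check. This is a simplified model, not the pass itself: `CFG` and `canReachExitAvoiding` are hypothetical names, the walk is assumed to start at the `coro.begin` block, and any reachable exit counts as an escape (the real function additionally tracks which blocks contain uses of `coro.begin`). For the `if (handle) handle.destroy();` shape, with blocks 0 (`coro.begin`), 1 (null check), 2 (destroy), and 3 (return), seeding only the destroy block leaves the `handle == null` edge from block 1 to the exit open; also seeding the null-check block, as the patch does, prunes that path.

```c++
#include <cassert>
#include <set>
#include <vector>

// Hypothetical CFG model for illustration (not LLVM's API).
struct CFG {
  std::vector<std::vector<int>> Succs; // Succs[B] = successor block ids of B
  std::set<int> Exits;                 // blocks that leave the function
};

// Worklist walk from `Start`. Blocks pre-seeded in `Visited` are never
// entered, mirroring how the pass pre-marks coro.destroy blocks. Returns
// true if an exit is reachable without passing through a seeded block.
bool canReachExitAvoiding(const CFG &G, int Start, std::set<int> Visited) {
  std::vector<int> Worklist{Start};
  while (!Worklist.empty()) {
    int BB = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(BB).second)
      continue; // already explored, or pruned by seeding
    if (G.Exits.count(BB))
      return true; // reached an exit: treated as a potential escape
    for (int S : G.Succs[BB])
      Worklist.push_back(S);
  }
  return false;
}
```

Seeding is coarse: marking block 1 as visited prunes both of its outgoing edges, which is why the TODO list asks for something more robust than matching this single pattern.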