[llvm] [Coroutines] Make CoroElide actually work for (trivial) C++ coroutines (PR #92969)

Yuxuan Chen via llvm-commits llvm-commits at lists.llvm.org
Tue May 21 15:19:10 PDT 2024


https://github.com/yuxuanchen1997 created https://github.com/llvm/llvm-project/pull/92969

## Motivation
`Task` types are widely used with C++ coroutines. We have received various reports that `CoroElide` does not work in coroutines that actually `co_await` a `Task`. Further inspection revealed multiple issues.

Consider this minimal task type:
```c++
struct Task {
  struct promise_type {
    struct FinalAwaiter {
      bool await_ready() const noexcept { return false; }

      template <typename P>
      std::coroutine_handle<> await_suspend(std::coroutine_handle<P> coro) noexcept {
        if (!coro)
          return std::noop_coroutine();
        return coro.promise().continuation;
      }
      void await_resume() noexcept {}
    };

    Task get_return_object() noexcept {
      return std::coroutine_handle<promise_type>::from_promise(*this);
    }

    std::suspend_always initial_suspend() noexcept { return {}; }
    FinalAwaiter final_suspend() noexcept { return {}; }
    void unhandled_exception() noexcept {}
    void return_value(int x) noexcept {
      value = x;
    }

    std::coroutine_handle<> continuation;
    int value;
  };

  Task(std::coroutine_handle<promise_type> handle) : handle(handle) {}
  ~Task() {
    if (handle)
      handle.destroy();
  }

  struct Awaiter {
    Awaiter(Task *t) : task(t) {}
    bool await_ready() const noexcept { return false; }
    void await_suspend(std::coroutine_handle<void> continuation) noexcept {}
    int await_resume() noexcept {
      return 43;
    }

    Task *task;
  };

  auto operator co_await() {
    return Awaiter{this};
  }

private:
  std::coroutine_handle<promise_type> handle;
};
```

This patch addresses the two most obvious issues in `CoroElide` that prevent elision in simple, inlined `co_await` cases:

1. The CoroElide analysis requires that a coroutine frame pointer be destroyed through `coro.subfn.addr` (which is generally just `handle.destroy()`). However, in C++ the pointer is often first stored into the `Task` object and then loaded back from it before destruction; this corresponds to the constructor and destructor of the `Task` type. This patch replaces (RAUW) all such reloads with the original `coro.begin` value.
2. Strengthen `canCoroBeginEscape` by also treating the predecessor block that null-checks the handle (the `if (handle)` in front of `handle.destroy()`) as part of the destroy path.

TODO:
1. The coroutine handle stored in a `Task` must stay constant for the lifetime of each `Task` instance. Currently this does not hold for move constructors. Make the pass RAUW only if the task has not been tampered with (e.g. by a move or by `Awaiter::await_suspend`).
2. Revisit `getNullCheckerPredecessor`. It works well enough for the following pattern:
```
if (handle)
  handle.destroy()
```
However, we need a more robust approach than brute-force pattern matching.
3. Add unit tests.


>From cc83350224d56525463754dbffa57db1756e8eea Mon Sep 17 00:00:00 2001
From: Yuxuan Chen <ych at meta.com>
Date: Tue, 21 May 2024 12:39:26 -0700
Subject: [PATCH] wip

---
 llvm/lib/Transforms/Coroutines/CoroElide.cpp | 76 +++++++++++++++++++-
 1 file changed, 74 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index bb244489e4c2c..918051be0e072 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -13,8 +13,10 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FileSystem.h"
 #include <optional>
@@ -193,6 +195,31 @@ CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI,
       CoroAllocs.push_back(CA);
   }
 
+  DenseMap<AllocaInst *, CoroBeginInst *> Writes;
+  for (CoroBeginInst *CB : CoroBegins) {
+    for (User *U : CB->users()) {
+      if (auto *Store = dyn_cast<StoreInst>(U)) {
+        if (Store->getOperand(0) == CB) {
+          auto *Dest = Store->getOperand(1);
+          if (auto *AI = dyn_cast<AllocaInst>(Dest)) {
+            Writes[AI] = CB;
+          }
+        }
+      }
+    }
+  }
+
+  for (Instruction &Inst : instructions(FEI.ContainingFunction)) {
+    if (auto *Load = dyn_cast<LoadInst>(&Inst)) {
+      if (auto *AI = dyn_cast<AllocaInst>(Load->getOperand(0))) {
+        auto It = Writes.find(AI);
+        if (It != Writes.end() && DT.dominates(It->second, Load)) {
+          Load->replaceAllUsesWith(It->second);
+        }
+      }
+    }
+  }
+
   // Collect all coro.subfn.addrs associated with coro.begin.
   // Note, we only devirtualize the calls if their coro.subfn.addr refers to
   // coro.begin directly. If we run into cases where this check is too
@@ -253,6 +280,44 @@ void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) {
   removeTailCallAttribute(Frame, AA);
 }
 
+/// Returns the sole predecessor of \p DestroyBB if it ends in a conditional
+/// branch that enters \p DestroyBB only when \p Handle is non-null, i.e. the
+/// `if (handle) handle.destroy();` pattern. Returns nullptr otherwise.
+static BasicBlock *getNullCheckerPredecessor(Value *Handle,
+                                             BasicBlock *DestroyBB) {
+  auto *Predecessor = DestroyBB->getSinglePredecessor();
+  if (!Predecessor)
+    return nullptr;
+
+  // BranchInst::getCondition() asserts on unconditional branches.
+  auto *BrInst = dyn_cast_or_null<BranchInst>(Predecessor->getTerminator());
+  if (!BrInst || !BrInst->isConditional())
+    return nullptr;
+
+  auto *Cmp = dyn_cast<ICmpInst>(BrInst->getCondition());
+  if (!Cmp)
+    return nullptr;
+
+  // DestroyBB must be the non-null edge: the true successor of `icmp ne`
+  // or the false successor of `icmp eq`.
+  bool CheckNe = BrInst->getSuccessor(0) == DestroyBB;
+  if (Cmp->getPredicate() !=
+      (CheckNe ? CmpInst::Predicate::ICMP_NE : CmpInst::Predicate::ICMP_EQ))
+    return nullptr;
+
+  // Accept either operand order: `Handle op null` or `null op Handle`.
+  auto IsNull = [](Value *V) {
+    auto *C = dyn_cast<Constant>(V);
+    return C && C->isNullValue();
+  };
+  Value *LHS = Cmp->getOperand(0);
+  Value *RHS = Cmp->getOperand(1);
+  if ((LHS == Handle && IsNull(RHS)) || (RHS == Handle && IsNull(LHS)))
+    return Predecessor;
+
+  return nullptr;
+}
+
 bool CoroIdElider::canCoroBeginEscape(
     const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const {
   const auto &It = DestroyAddr.find(CB);
@@ -267,8 +332,15 @@ bool CoroIdElider::canCoroBeginEscape(
   SmallPtrSet<const BasicBlock *, 32> Visited;
   // Consider basicblock of coro.destroy as visited one, so that we
   // skip the path pass through coro.destroy.
-  for (auto *DA : It->second)
-    Visited.insert(DA->getParent());
+  for (auto *DA : It->second) {
+    auto *DestroyBB = DA->getParent();
+    auto *Handle = DA->getOperand(0);
+    Visited.insert(DestroyBB);
+
+    if (auto *Pred = getNullCheckerPredecessor(Handle, DestroyBB)) {
+      Visited.insert(Pred);
+    }
+  }
 
   SmallPtrSet<const BasicBlock *, 32> EscapingBBs;
   for (auto *U : CB->users()) {



More information about the llvm-commits mailing list