[clang] 9d1cb18 - [Coroutines] Ignore instructions more aggressively in addMustTailToCoroResumes() (#85271)

via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 20 06:51:50 PDT 2024


Author: Hans
Date: 2024-03-20T14:51:45+01:00
New Revision: 9d1cb18d19862fc0627e4a56e1e491a498e84c71

URL: https://github.com/llvm/llvm-project/commit/9d1cb18d19862fc0627e4a56e1e491a498e84c71
DIFF: https://github.com/llvm/llvm-project/commit/9d1cb18d19862fc0627e4a56e1e491a498e84c71.diff

LOG: [Coroutines] Ignore instructions more aggressively in addMustTailToCoroResumes() (#85271)

The old code used isInstructionTriviallyDead() and removed instructions
when walking the path from a resume call to function return to check if
the call is in tail position.

However, since the code was walking forwards it was not able to get past
instructions such as:

  %gep = getelementptr inbounds i64, ptr %alloc.var, i32 0
  %foo = ptrtoint ptr %gep to i64

This patch instead ignores such instructions as long as their values are
not needed. This enables the code to emit tail calls in more situations.

Added: 
    clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp

Modified: 
    llvm/lib/Transforms/Coroutines/CoroSplit.cpp
    llvm/test/Transforms/Coroutines/coro-split-musttail7.ll

Removed: 
    


################################################################################
diff --git a/clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp
new file mode 100644
index 00000000000000..cf9170d7e71105
--- /dev/null
+++ b/clang/test/CodeGenCoroutines/coro-symmetric-transfer-04.cpp
@@ -0,0 +1,66 @@
+// This tests that the symmetric transfer at the final suspend point could happen successfully.
+// Based on https://github.com/llvm/llvm-project/pull/85271#issuecomment-2007554532
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O2 -emit-llvm %s -o - | FileCheck %s
+
+#include "Inputs/coroutine.h"
+
+struct Task {
+  struct promise_type {
+    struct FinalAwaiter {
+      bool await_ready() const noexcept { return false; }
+      template <typename PromiseType>
+      std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
+        return h.promise().continuation;
+      }
+      void await_resume() noexcept {}
+    };
+    Task get_return_object() noexcept {
+      return std::coroutine_handle<promise_type>::from_promise(*this);
+    }
+    std::suspend_always initial_suspend() noexcept { return {}; }
+    FinalAwaiter final_suspend() noexcept { return {}; }
+    void unhandled_exception() noexcept {}
+    void return_value(int x) noexcept {
+      _value = x;
+    }
+    std::coroutine_handle<> continuation;
+    int _value;
+  };
+
+  Task(std::coroutine_handle<promise_type> handle) : handle(handle), stuff(123) {}
+
+  struct Awaiter {
+    std::coroutine_handle<promise_type> handle;
+    Awaiter(std::coroutine_handle<promise_type> handle) : handle(handle) {}
+    bool await_ready() const noexcept { return false; }
+    std::coroutine_handle<void> await_suspend(std::coroutine_handle<void> continuation) noexcept {
+      handle.promise().continuation = continuation;
+      return handle;
+    }
+    int await_resume() noexcept {
+      int ret = handle.promise()._value;
+      handle.destroy();
+      return ret;
+    }
+  };
+
+  auto operator co_await() {
+    auto handle_ = handle;
+    handle = nullptr;
+    return Awaiter(handle_);
+  }
+
+private:
+  std::coroutine_handle<promise_type> handle;
+  int stuff;
+};
+
+Task task0() {
+  co_return 43;
+}
+
+// CHECK-LABEL: define{{.*}} void @_Z5task0v.resume
+// This checks we are still in the scope of the current function.
+// CHECK-NOT: {{^}}}
+// CHECK: musttail call fastcc void
+// CHECK-NEXT: ret void

diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index 3f3d81474fafe4..3a43b1edcaba37 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1198,22 +1198,6 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
   assert(InitialInst->getModule());
   const DataLayout &DL = InitialInst->getModule()->getDataLayout();
 
-  auto GetFirstValidInstruction = [](Instruction *I) {
-    while (I) {
-      // BitCastInst wouldn't generate actual code so that we could skip it.
-      if (isa<BitCastInst>(I) || I->isDebugOrPseudoInst() ||
-          I->isLifetimeStartOrEnd())
-        I = I->getNextNode();
-      else if (isInstructionTriviallyDead(I))
-        // Duing we are in the middle of the transformation, we need to erase
-        // the dead instruction manually.
-        I = &*I->eraseFromParent();
-      else
-        break;
-    }
-    return I;
-  };
-
   auto TryResolveConstant = [&ResolvedValues](Value *V) {
     auto It = ResolvedValues.find(V);
     if (It != ResolvedValues.end())
@@ -1222,8 +1206,9 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
   };
 
   Instruction *I = InitialInst;
-  while (I->isTerminator() || isa<CmpInst>(I)) {
+  while (true) {
     if (isa<ReturnInst>(I)) {
+      assert(!cast<ReturnInst>(I)->getReturnValue());
       ReplaceInstWithInst(InitialInst, I->clone());
       return true;
     }
@@ -1247,40 +1232,26 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
 
       BasicBlock *Succ = BR->getSuccessor(SuccIndex);
       scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
-      I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime());
-
+      I = Succ->getFirstNonPHIOrDbgOrLifetime();
       continue;
     }
 
-    if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
+    if (auto *Cmp = dyn_cast<CmpInst>(I)) {
       // If the case number of suspended switch instruction is reduced to
       // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
-      auto *BR = dyn_cast<BranchInst>(
-          GetFirstValidInstruction(CondCmp->getNextNode()));
-      if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
-        return false;
-
-      // And the comparsion looks like : %cond = icmp eq i8 %V, constant.
-      // So we try to resolve constant for the first operand only since the
-      // second operand should be literal constant by design.
-      ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
-      auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
-      if (!Cond0 || !Cond1)
-        return false;
-
-      // Both operands of the CmpInst are Constant. So that we could evaluate
-      // it immediately to get the destination.
-      auto *ConstResult =
-          dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
-              CondCmp->getPredicate(), Cond0, Cond1, DL));
-      if (!ConstResult)
-        return false;
-
-      ResolvedValues[BR->getCondition()] = ConstResult;
-
-      // Handle this branch in next iteration.
-      I = BR;
-      continue;
+      // Try to constant fold it.
+      ConstantInt *Cond0 = TryResolveConstant(Cmp->getOperand(0));
+      ConstantInt *Cond1 = TryResolveConstant(Cmp->getOperand(1));
+      if (Cond0 && Cond1) {
+        ConstantInt *Result =
+            dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
+                Cmp->getPredicate(), Cond0, Cond1, DL));
+        if (Result) {
+          ResolvedValues[Cmp] = Result;
+          I = I->getNextNode();
+          continue;
+        }
+      }
     }
 
     if (auto *SI = dyn_cast<SwitchInst>(I)) {
@@ -1288,13 +1259,21 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
       if (!Cond)
         return false;
 
-      BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
-      scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-      I = GetFirstValidInstruction(BB->getFirstNonPHIOrDbgOrLifetime());
+      BasicBlock *Succ = SI->findCaseValue(Cond)->getCaseSuccessor();
+      scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
+      I = Succ->getFirstNonPHIOrDbgOrLifetime();
+      continue;
+    }
+
+    if (I->isDebugOrPseudoInst() || I->isLifetimeStartOrEnd() ||
+        wouldInstructionBeTriviallyDead(I)) {
+      // We can skip instructions without side effects. If their values are
+      // needed, we'll notice later, e.g. when hitting a conditional branch.
+      I = I->getNextNode();
       continue;
     }
 
-    return false;
+    break;
   }
 
   return false;

diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll
index 2257d5aee473ee..d0d5005587bda6 100644
--- a/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail7.ll
@@ -1,6 +1,6 @@
 ; Tests that sinked lifetime markers wouldn't provent optimization
 ; to convert a resuming call to a musttail call.
-; The difference between this and coro-split-musttail5.ll and coro-split-musttail5.ll
+; The difference between this and coro-split-musttail5.ll and coro-split-musttail6.ll
 ; is that this contains dead instruction generated during the transformation,
 ; which makes the optimization harder.
 ; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
@@ -8,7 +8,7 @@
 
 declare void @fakeresume1(ptr align 8)
 
-define void @g() #0 {
+define i64 @g() #0 {
 entry:
   %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
   %alloc = call ptr @malloc(i64 16) #3
@@ -27,6 +27,11 @@ await.suspend:
   %save2 = call token @llvm.coro.save(ptr null)
   call fastcc void @fakeresume1(ptr align 8 null)
   %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
+
+  ; These (non-trivially) dead instructions are in the way.
+  %gep = getelementptr inbounds i64, ptr %alloc.var, i32 0
+  %foo = ptrtoint ptr %gep to i64
+
   switch i8 %suspend2, label %exit [
     i8 0, label %await.ready
     i8 1, label %exit
@@ -36,8 +41,9 @@ await.ready:
   call void @llvm.lifetime.end.p0(i64 1, ptr %alloc.var)
   br label %exit
 exit:
+  %result = phi i64 [0, %entry], [0, %entry], [%foo, %await.suspend], [%foo, %await.suspend], [%foo, %await.ready]
   call i1 @llvm.coro.end(ptr null, i1 false, token none)
-  ret void
+  ret i64 %result
 }
 
 ; Verify that in the resume part resume call is marked with musttail.


        


More information about the cfe-commits mailing list