[llvm] [coro][pgo] Don't promote pgo counters in the suspend basic block (PR #71263)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 30 08:38:30 PST 2023


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/71263

>From 4f3b2d421bd7f0a0dc28cd2c8a87f5d29f900f8c Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Fri, 3 Nov 2023 18:33:31 -0700
Subject: [PATCH 1/5] [coro][pgo] Don't promote pgo counters in the suspend
 basic block

If a suspend happens in the resume part (this can happen in the case of
chained coroutines), and that's part of a loop, the pre-split CFG has
the suspend block as an exit of that loop. PGO Counter Promotion will
then try to commit the temporary counter to the global in that "exit"
block (it also does that in the other loop exit BBs, which also includes
the "destroy" case).

We don't need to commit the counter in the suspend case - it's not
a loop exit from the perspective of the behavior of the program. The
regular loop exit, together with the "destroy" case, completely cover
any updates that may need to happen to the global counter.
---
 .../Instrumentation/InstrProfiling.cpp        |  10 +-
 ...-split-musttail-chain-pgo-counter-promo.ll | 175 ++++++++++++++++++
 2 files changed, 184 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll

diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index 480817a23d2c208..cbf5110e889e668 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -242,8 +242,16 @@ class PGOCounterPromoter {
     if (!isPromotionPossible(&L, LoopExitBlocks))
       return;
 
+    auto IsSuspendBB = [&](BasicBlock *BB) {
+      if (auto *Pred = BB->getSinglePredecessor())
+        if (auto *SW = dyn_cast<SwitchInst>(Pred->getTerminator()))
+          if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
+            return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
+                   SW->getDefaultDest() == BB;
+      return false;
+    };
     for (BasicBlock *ExitBlock : LoopExitBlocks) {
-      if (BlockSet.insert(ExitBlock).second) {
+      if (BlockSet.insert(ExitBlock).second && !IsSuspendBB(ExitBlock)) {
         ExitBlocks.push_back(ExitBlock);
         InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
       }
diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll
new file mode 100644
index 000000000000000..ddd293eed2409e9
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll
@@ -0,0 +1,175 @@
+; REQUIRES: x86-registered-target
+; RUN: opt -passes='pgo-instr-gen,instrprof,coro-split' -do-counter-promotion=true -S < %s | FileCheck %s
+
+; CHECK-LABEL: define internal fastcc void @f.resume
+; CHECK: musttail call fastcc void 
+; CHECK-NEXT: ret void
+; CHECK: musttail call fastcc void 
+; CHECK-NEXT: ret void
+; CHECK-LABEL: define internal fastcc void @f.destroy
+target triple = "x86_64-grtev4-linux-gnu"
+
+%CoroutinePromise = type { ptr, i64, [8 x i8], ptr} 
+%Awaitable.1 = type { ptr }
+%Awaitable.2 = type { ptr, ptr }
+
+declare void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1), ptr) local_unnamed_addr
+declare ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16), ptr) local_unnamed_addr
+declare void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32)) local_unnamed_addr
+declare ptr @other_coro();
+declare void @heap_delete(ptr noundef, i64 noundef, i64 noundef) local_unnamed_addr
+declare noundef nonnull ptr @heap_allocate(i64 noundef, i64 noundef) local_unnamed_addr
+
+declare void @llvm.assume(i1 noundef)
+declare i64 @llvm.coro.align.i64()
+declare i1 @llvm.coro.alloc(token)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare i1 @llvm.coro.end(ptr, i1, token)
+declare ptr @llvm.coro.free(token, ptr nocapture readonly)
+declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
+declare token @llvm.coro.save(ptr)
+declare i64 @llvm.coro.size.i64()
+declare ptr @llvm.coro.subfn.addr(ptr nocapture readonly, i8)
+declare i8 @llvm.coro.suspend(token, i1)
+declare void @llvm.instrprof.increment(ptr, i64, i32, i32)
+declare void @llvm.instrprof.value.profile(ptr, i64, i64, i32, i32)
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
+
+; Function Attrs: presplitcoroutine
+define ptr @f(i32 %0) presplitcoroutine align 32 {
+  %2 = alloca i32, align 8
+  %3 = alloca %CoroutinePromise, align 16
+  %4 = alloca %Awaitable.1, align 8
+  %5 = alloca %Awaitable.2, align 8
+  %6 = call token @llvm.coro.id(i32 8, ptr nonnull %3, ptr nonnull @f, ptr null)
+  %7 = call i1 @llvm.coro.alloc(token %6)
+  br i1 %7, label %8, label %12
+
+8:                                                ; preds = %1
+  %9 = call i64 @llvm.coro.size.i64()
+  %10 = call i64 @llvm.coro.align.i64()
+  %11 = call noalias noundef nonnull ptr @heap_allocate(i64 noundef %9, i64 noundef %10) #27
+  call void @llvm.assume(i1 true) [ "align"(ptr %11, i64 %10) ]
+  br label %12
+
+12:                                               ; preds = %8, %1
+  %13 = phi ptr [ null, %1 ], [ %11, %8 ]
+  %14 = call ptr @llvm.coro.begin(token %6, ptr %13) #28
+  call void @llvm.lifetime.start.p0(i64 32, ptr nonnull %3) #9
+  store ptr null, ptr %3, align 16
+  %15 = getelementptr inbounds {ptr, i64}, ptr %3, i64 0, i32 1
+  store i64 0, ptr %15, align 8
+  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %4) #9
+  store ptr %3, ptr %4, align 8
+  %16 = call token @llvm.coro.save(ptr null)
+  call void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1) %4, ptr %14) #9
+  %17 = call i8 @llvm.coro.suspend(token %16, i1 false)
+  switch i8 %17, label %61 [
+    i8 0, label %18
+    i8 1, label %21
+  ]
+
+18:                                               ; preds = %12
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
+  %19 = icmp slt i32 0, %0
+  br i1 %19, label %20, label %36
+
+20:                                               ; preds = %18
+  br label %22
+
+21:                                               ; preds = %12
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
+  br label %54
+
+22:                                               ; preds = %20, %31
+  %23 = phi i32 [ 0, %20 ], [ %32, %31 ]
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %5) #9
+  %24 = call ptr @other_coro()
+  store ptr %3, ptr %5, align 8
+  %25 = getelementptr inbounds { ptr, ptr }, ptr %5, i64 0, i32 1
+  store ptr %24, ptr %25, align 8
+  %26 = call token @llvm.coro.save(ptr null)
+  %27 = call ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16) %5, ptr %14)
+  %28 = call ptr @llvm.coro.subfn.addr(ptr %27, i8 0)
+  %29 = ptrtoint ptr %28 to i64
+  call fastcc void %28(ptr %27) #9
+  %30 = call i8 @llvm.coro.suspend(token %26, i1 false)
+  switch i8 %30, label %60 [
+    i8 0, label %31
+    i8 1, label %34
+  ]
+
+31:                                               ; preds = %22
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
+  %32 = add nuw nsw i32 %23, 1
+  %33 = icmp slt i32 %32, %0
+  br i1 %33, label %22, label %35, !llvm.loop !0
+
+34:                                               ; preds = %22
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
+  br label %54
+
+35:                                               ; preds = %31
+  br label %36
+
+36:                                               ; preds = %35, %18
+  %37 = call token @llvm.coro.save(ptr null)
+  %38 = getelementptr inbounds i8, ptr %14, i64 16
+  %39 = getelementptr inbounds i8, ptr %14, i64 32
+  %40 = load i64, ptr %39, align 8
+  %41 = load ptr, ptr %38, align 16
+  %42 = icmp eq ptr %41, null
+  br i1 %42, label %43, label %46
+
+43:                                               ; preds = %36
+  %44 = call ptr @llvm.coro.subfn.addr(ptr nonnull %14, i8 1)
+  %45 = ptrtoint ptr %44 to i64
+  call fastcc void %44(ptr nonnull %14) #9
+  br label %47
+
+46:                                               ; preds = %36
+  call void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32) %38) #9
+  br label %47
+
+47:                                               ; preds = %43, %46
+  %48 = inttoptr i64 %40 to ptr
+  %49 = call ptr @llvm.coro.subfn.addr(ptr %48, i8 0)
+  %50 = ptrtoint ptr %49 to i64
+  call fastcc void %49(ptr %48) #9
+  %51 = call i8 @llvm.coro.suspend(token %37, i1 true) #28
+  switch i8 %51, label %61 [
+    i8 0, label %53
+    i8 1, label %52
+  ]
+
+52:                                               ; preds = %47
+  br label %54
+
+53:                                               ; preds = %47
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %2) #9
+  unreachable
+
+54:                                               ; preds = %52, %34, %21
+  call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %3) #9
+  %55 = call ptr @llvm.coro.free(token %6, ptr %14)
+  %56 = icmp eq ptr %55, null
+  br i1 %56, label %61, label %57
+
+57:                                               ; preds = %54
+  %58 = call i64 @llvm.coro.size.i64()
+  %59 = call i64 @llvm.coro.align.i64()
+  call void @heap_delete(ptr noundef nonnull %55, i64 noundef %58, i64 noundef %59) #9
+  br label %61
+
+60:                                               ; preds = %22
+  br label %61
+
+61:                                               ; preds = %60, %57, %54, %47, %12
+  %62 = getelementptr inbounds i8, ptr %3, i64 -16
+  %63 = call i1 @llvm.coro.end(ptr null, i1 false, token none) #28
+  ret ptr %62
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.mustprogress"}

>From 1fb2beeb97a630f68ffc14fb8d6b788686dbd884 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 16 Nov 2023 13:45:28 -0800
Subject: [PATCH 2/5] Added suspend exit BB detector to BasicBlockUtils.h

---
 .../llvm/Transforms/Utils/BasicBlockUtils.h   |  2 +
 .../Instrumentation/InstrProfiling.cpp        | 12 ++---
 llvm/lib/Transforms/Utils/BasicBlockUtils.cpp | 11 +++++
 .../Transforms/Utils/BasicBlockUtilsTest.cpp  | 44 +++++++++++++++++++
 4 files changed, 60 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
index e6dde450b7df9c8..21d56c6c5848bd6 100644
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -705,6 +705,8 @@ void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);
 // Check whether the function only has simple terminator:
 // br/brcond/unreachable/ret
 bool hasOnlySimpleTerminator(const Function &F);
+
+bool isPresplitCoroSuspendExit(const BasicBlock &BB);
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_BASICBLOCKUTILS_H
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index cbf5110e889e668..ddef932f3705bfd 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -48,6 +48,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include <algorithm>
@@ -242,16 +243,9 @@ class PGOCounterPromoter {
     if (!isPromotionPossible(&L, LoopExitBlocks))
       return;
 
-    auto IsSuspendBB = [&](BasicBlock *BB) {
-      if (auto *Pred = BB->getSinglePredecessor())
-        if (auto *SW = dyn_cast<SwitchInst>(Pred->getTerminator()))
-          if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
-            return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
-                   SW->getDefaultDest() == BB;
-      return false;
-    };
     for (BasicBlock *ExitBlock : LoopExitBlocks) {
-      if (BlockSet.insert(ExitBlock).second && !IsSuspendBB(ExitBlock)) {
+      if (BlockSet.insert(ExitBlock).second &&
+          !llvm::isPresplitCoroSuspendExit(*ExitBlock)) {
         ExitBlocks.push_back(ExitBlock);
         InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
       }
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index 05ff4efb7b94471..f64083b79373f1a 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -2086,3 +2086,14 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) {
   }
   return true;
 }
+
+bool llvm::isPresplitCoroSuspendExit(const BasicBlock &BB) {
+  if (!BB.getParent()->isPresplitCoroutine())
+    return false;
+  if (auto *Pred = BB.getSinglePredecessor())
+    if (auto *SW = dyn_cast<SwitchInst>(Pred->getTerminator()))
+      if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
+        return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
+               SW->getDefaultDest() == &BB;
+  return false;
+}
diff --git a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
index 2da38c60044bebe..fec01bc31de1034 100644
--- a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
@@ -611,3 +611,47 @@ switch i32 %0, label %LD [
   EXPECT_EQ(BranchProbability::getRaw(1),
             BPI.getEdgeProbability(EntryBB, UnreachableBB));
 }
+
+TEST(BasicBlockUtils, IsPresplitCoroSuspendExitTest) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @positive_case(i32 %0) #0 {
+entry:
+  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
+  switch i8 %suspend2, label %exit [
+    i8 0, label %resume
+    i8 1, label %destroy
+  ]
+%resume:
+  ret void
+%destroy:
+  ret void
+%exit:
+  call i1 @llvm.coro.end(ptr null, i1 false, token none)
+  ret void
+}
+
+define void @notpresplit(i32 %0) {
+entry:
+  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
+  switch i8 %suspend2, label %exit [
+    i8 0, label %resume
+    i8 1, label %destroy
+  ]
+%resume:
+  ret void
+%destroy:
+  ret void
+%exit:
+  call i1 @llvm.coro.end(ptr null, i1 false, token none)
+  ret void
+}
+attributes #0 = { presplitcoroutine }
+)IR");
+
+  Function *P = M->getFunction("positive_case");
+  EXPECT_TRUE(llvm::isPresplitCoroSuspendExit(*P->begin()));
+
+  Function *N = M->getFunction("notpresplit");
+  EXPECT_FALSE(llvm::isPresplitCoroSuspendExit(*N->begin()));
+}

>From cf1071736c4375cd200280def914669959bf92f1 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 16 Nov 2023 13:59:21 -0800
Subject: [PATCH 3/5] Fixed unittest.

---
 .../Transforms/Utils/BasicBlockUtilsTest.cpp  | 37 +++++++++++++------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
index fec01bc31de1034..3b77f4330d79377 100644
--- a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
@@ -617,41 +617,54 @@ TEST(BasicBlockUtils, IsPresplitCoroSuspendExitTest) {
   std::unique_ptr<Module> M = parseIR(C, R"IR(
 define void @positive_case(i32 %0) #0 {
 entry:
-  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
-  switch i8 %suspend2, label %exit [
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
     i8 0, label %resume
     i8 1, label %destroy
   ]
-%resume:
+resume:
   ret void
-%destroy:
+destroy:
   ret void
-%exit:
+exit:
   call i1 @llvm.coro.end(ptr null, i1 false, token none)
   ret void
 }
 
 define void @notpresplit(i32 %0) {
 entry:
-  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
-  switch i8 %suspend2, label %exit [
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
     i8 0, label %resume
     i8 1, label %destroy
   ]
-%resume:
+resume:
   ret void
-%destroy:
+destroy:
   ret void
-%exit:
+exit:
   call i1 @llvm.coro.end(ptr null, i1 false, token none)
   ret void
 }
+
+declare token @llvm.coro.save(ptr)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i1 @llvm.coro.end(ptr, i1, token)
+
 attributes #0 = { presplitcoroutine }
 )IR");
 
+  auto FindExit = [](const Function &F) -> const BasicBlock * {
+    for (const auto &BB : F)
+      if (BB.getName() == "exit")
+        return &BB;
+    return nullptr;
+  };
   Function *P = M->getFunction("positive_case");
-  EXPECT_TRUE(llvm::isPresplitCoroSuspendExit(*P->begin()));
+  EXPECT_TRUE(llvm::isPresplitCoroSuspendExit(*FindExit(*P)));
 
   Function *N = M->getFunction("notpresplit");
-  EXPECT_FALSE(llvm::isPresplitCoroSuspendExit(*N->begin()));
+  EXPECT_FALSE(llvm::isPresplitCoroSuspendExit(*FindExit(*N)));
 }

>From 3ec71c6b7675f4526561f9ebb1c93a3115a182ed Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 16 Nov 2023 21:15:54 -0800
Subject: [PATCH 4/5] Reuse `isPresplitCoroSuspendExitEdge`

---
 .../llvm/Transforms/Instrumentation/CFGMST.h      | 15 +++------------
 .../llvm/Transforms/Utils/BasicBlockUtils.h       |  3 ++-
 .../Transforms/Instrumentation/InstrProfiling.cpp |  6 +++++-
 llvm/lib/Transforms/Utils/BasicBlockUtils.cpp     | 15 ++++++++-------
 .../Transforms/Utils/BasicBlockUtilsTest.cpp      |  8 ++++++--
 5 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
index cd2ae61334d0f05..682ae877f7ae393 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
@@ -100,20 +100,11 @@ template <class Edge, class BBInfo> class CFGMST {
     //   i8 0, label %await.ready
     //   i8 1, label %exit
     // ]
-    const BasicBlock *EdgeTarget = E->DestBB;
-    if (!EdgeTarget)
+    if (!E->DestBB)
       return;
     assert(E->SrcBB);
-    const Function *F = EdgeTarget->getParent();
-    if (!F->isPresplitCoroutine())
-      return;
-
-    const Instruction *TI = E->SrcBB->getTerminator();
-    if (auto *SWInst = dyn_cast<SwitchInst>(TI))
-      if (auto *Intrinsic = dyn_cast<IntrinsicInst>(SWInst->getCondition()))
-        if (Intrinsic->getIntrinsicID() == Intrinsic::coro_suspend &&
-            SWInst->getDefaultDest() == EdgeTarget)
-          E->Removed = true;
+    if (llvm::isPresplitCoroSuspendExitEdge(*E->SrcBB, *E->DestBB))
+      E->Removed = true;
   }
 
   // Traverse the CFG using a stack. Find all the edges and assign the weight.
diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
index 21d56c6c5848bd6..71f1884eefb70a8 100644
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -706,7 +706,8 @@ void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);
 // br/brcond/unreachable/ret
 bool hasOnlySimpleTerminator(const Function &F);
 
-bool isPresplitCoroSuspendExit(const BasicBlock &BB);
+bool isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+                                   const BasicBlock &Dest);
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_BASICBLOCKUTILS_H
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index ddef932f3705bfd..5033aa8f046f5e5 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/Instrumentation/InstrProfiling.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
@@ -23,6 +24,7 @@
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DIBuilder.h"
@@ -245,7 +247,9 @@ class PGOCounterPromoter {
 
     for (BasicBlock *ExitBlock : LoopExitBlocks) {
       if (BlockSet.insert(ExitBlock).second &&
-          !llvm::isPresplitCoroSuspendExit(*ExitBlock)) {
+          llvm::none_of(predecessors(ExitBlock), [&](const BasicBlock *Pred) {
+            return llvm::isPresplitCoroSuspendExitEdge(*Pred, *ExitBlock);
+          })) {
         ExitBlocks.push_back(ExitBlock);
         InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
       }
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index f64083b79373f1a..c4c5e415166e583 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -2087,13 +2087,14 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) {
   return true;
 }
 
-bool llvm::isPresplitCoroSuspendExit(const BasicBlock &BB) {
-  if (!BB.getParent()->isPresplitCoroutine())
+bool llvm::isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+                                         const BasicBlock &Dest) {
+  assert(Src.getParent() == Dest.getParent() && "Edge crosses functions");
+  if (!Src.getParent()->isPresplitCoroutine())
     return false;
-  if (auto *Pred = BB.getSinglePredecessor())
-    if (auto *SW = dyn_cast<SwitchInst>(Pred->getTerminator()))
-      if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
-        return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
-               SW->getDefaultDest() == &BB;
+  if (auto *SW = dyn_cast<SwitchInst>(Src.getTerminator()))
+    if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
+      return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
+             SW->getDefaultDest() == &Dest;
   return false;
 }
diff --git a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
index 3b77f4330d79377..5152cbe19c215c3 100644
--- a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
@@ -663,8 +663,12 @@ attributes #0 = { presplitcoroutine }
     return nullptr;
   };
   Function *P = M->getFunction("positive_case");
-  EXPECT_TRUE(llvm::isPresplitCoroSuspendExit(*FindExit(*P)));
+  const auto &ExitP = *FindExit(*P);
+  EXPECT_TRUE(llvm::isPresplitCoroSuspendExitEdge(*ExitP.getSinglePredecessor(),
+                                                  ExitP));
 
   Function *N = M->getFunction("notpresplit");
-  EXPECT_FALSE(llvm::isPresplitCoroSuspendExit(*FindExit(*N)));
+  const auto &ExitN = *FindExit(*N);
+  EXPECT_FALSE(llvm::isPresplitCoroSuspendExitEdge(
+      *ExitN.getSinglePredecessor(), ExitN));
 }

>From 2a3affc7b704c124dc9a33207deed1d7b9d44df1 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 30 Nov 2023 08:38:08 -0800
Subject: [PATCH 5/5] Comment describing `isPresplitCoroSuspendExitEdge`

---
 .../llvm/Transforms/Utils/BasicBlockUtils.h       | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
index 71f1884eefb70a8..db9e07a861e11f0 100644
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -706,6 +706,21 @@ void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);
 // br/brcond/unreachable/ret
 bool hasOnlySimpleTerminator(const Function &F);
 
+// Returns true if these basic blocks belong to a presplit coroutine and the
+// edge Src->Dest is the 'default' case of the switch in the pattern:
+//
+// %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+// switch i8 %0, label %suspend [i8 0, label %resume
+//                               i8 1, label %cleanup]
+//
+// i.e. the edge to the `%suspend` BB. This edge is special in that it will
+// be elided by coroutine lowering (coro-split), and the `%suspend` BB needs
+// to be kept as-is. It's not a real CFG edge - post-lowering, it will end
+// up being a `ret`, and it must be thus lowerable to support symmetric
+// transfer. Consequently, this edge:
+//  - is not a loop exit edge, even if encountered in a loop (ignore it);
+//  - must not be split, e.g. by PGO instrumentation.
+// Src and Dest must belong to the same function.
 bool isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
                                    const BasicBlock &Dest);
 } // end namespace llvm



More information about the llvm-commits mailing list