[llvm] 284da04 - [coro][pgo] Don't promote pgo counters in the suspend basic block (#71263)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 30 11:58:33 PST 2023


Author: Mircea Trofin
Date: 2023-11-30T11:58:26-08:00
New Revision: 284da049f5feb62b40f5abc41dda7895e3d81d72

URL: https://github.com/llvm/llvm-project/commit/284da049f5feb62b40f5abc41dda7895e3d81d72
DIFF: https://github.com/llvm/llvm-project/commit/284da049f5feb62b40f5abc41dda7895e3d81d72.diff

LOG: [coro][pgo] Don't promote pgo counters in the suspend basic block (#71263)

If a suspend happens in the resume part of a coroutine (this can happen in the case of chained coroutines) and that suspend is part of a loop, the pre-split CFG has the suspend block as an exit of that loop. PGO counter promotion will then try to commit the temporary counter to the global counter in that "exit" block (it also does that in the other loop exit BBs, which include the "destroy" case). This interferes with symmetric transfer, because after coroutine lowering the suspend block must end up as a tail call followed by a `ret`.

We don't need to commit the counter in the suspend case — it's not a loop exit from the perspective of the behavior of the program. The regular loop exit, together with the "destroy" case, completely covers any updates that may need to happen to the global counter.

Added: 
    llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll

Modified: 
    llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
    llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
    llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
    llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
    llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
index cd2ae61334d0f05..682ae877f7ae393 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
@@ -100,20 +100,11 @@ template <class Edge, class BBInfo> class CFGMST {
     //   i8 0, label %await.ready
     //   i8 1, label %exit
     // ]
-    const BasicBlock *EdgeTarget = E->DestBB;
-    if (!EdgeTarget)
+    if (!E->DestBB)
       return;
     assert(E->SrcBB);
-    const Function *F = EdgeTarget->getParent();
-    if (!F->isPresplitCoroutine())
-      return;
-
-    const Instruction *TI = E->SrcBB->getTerminator();
-    if (auto *SWInst = dyn_cast<SwitchInst>(TI))
-      if (auto *Intrinsic = dyn_cast<IntrinsicInst>(SWInst->getCondition()))
-        if (Intrinsic->getIntrinsicID() == Intrinsic::coro_suspend &&
-            SWInst->getDefaultDest() == EdgeTarget)
-          E->Removed = true;
+    if (llvm::isPresplitCoroSuspendExitEdge(*E->SrcBB, *E->DestBB))
+      E->Removed = true;
   }
 
   // Traverse the CFG using a stack. Find all the edges and assign the weight.

diff  --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
index e6dde450b7df9c8..e650ac80efbcdc7 100644
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -705,6 +705,25 @@ void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);
 // Check whether the function only has simple terminator:
 // br/brcond/unreachable/ret
 bool hasOnlySimpleTerminator(const Function &F);
+
+// Returns true if these basic blocks belong to a presplit coroutine and the
+// edge corresponds to the 'default' case in the switch statement in the
+// pattern:
+//
+// %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+// switch i8 %0, label %suspend [i8 0, label %resume
+//                               i8 1, label %cleanup]
+//
+// i.e. the edge to the `%suspend` BB. This edge is special in that it will
+// be elided by coroutine lowering (coro-split), and the `%suspend` BB needs
+// to be kept as-is. It's not a real CFG edge: post-lowering, it will end
+// up being a `ret`, and it must thus remain lowerable to support symmetric
+// transfer. Consequently:
+//  - this edge is not a loop exit edge if encountered in a loop (and should
+//    be ignored as one, e.g. by PGO counter promotion);
+//  - this edge must not be split for PGO instrumentation.
+bool isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+                                   const BasicBlock &Dest);
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_BASICBLOCKUTILS_H

diff  --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index 73a7116f74e1180..10258e254679133 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/Instrumentation/InstrProfiling.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
@@ -23,6 +24,7 @@
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DIBuilder.h"
@@ -48,6 +50,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include <algorithm>
@@ -243,7 +246,10 @@ class PGOCounterPromoter {
       return;
 
     for (BasicBlock *ExitBlock : LoopExitBlocks) {
-      if (BlockSet.insert(ExitBlock).second) {
+      if (BlockSet.insert(ExitBlock).second &&
+          llvm::none_of(predecessors(ExitBlock), [&](const BasicBlock *Pred) {
+            return llvm::isPresplitCoroSuspendExitEdge(*Pred, *ExitBlock);
+          })) {
         ExitBlocks.push_back(ExitBlock);
         InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
       }

diff  --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index 168998fbee114ab..ccee97025e3c3f9 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -2149,3 +2149,15 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) {
   }
   return true;
 }
+
+bool llvm::isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
+                                         const BasicBlock &Dest) {
+  assert(Src.getParent() == Dest.getParent());
+  if (!Src.getParent()->isPresplitCoroutine())
+    return false;
+  if (auto *SW = dyn_cast<SwitchInst>(Src.getTerminator()))
+    if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
+      return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
+             SW->getDefaultDest() == &Dest;
+  return false;
+}

diff  --git a/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll
new file mode 100644
index 000000000000000..ddd293eed2409e9
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail-chain-pgo-counter-promo.ll
@@ -0,0 +1,175 @@
+; REQUIRES: x86-registered-target
+; RUN: opt -passes='pgo-instr-gen,instrprof,coro-split' -do-counter-promotion=true -S < %s | FileCheck %s
+
+; CHECK-LABEL: define internal fastcc void @f.resume
+; CHECK: musttail call fastcc void 
+; CHECK-NEXT: ret void
+; CHECK: musttail call fastcc void 
+; CHECK-NEXT: ret void
+; CHECK-LABEL: define internal fastcc void @f.destroy
+target triple = "x86_64-grtev4-linux-gnu"
+
+%CoroutinePromise = type { ptr, i64, [8 x i8], ptr} 
+%Awaitable.1 = type { ptr }
+%Awaitable.2 = type { ptr, ptr }
+
+declare void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1), ptr) local_unnamed_addr
+declare ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16), ptr) local_unnamed_addr
+declare void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32)) local_unnamed_addr
+declare ptr @other_coro();
+declare void @heap_delete(ptr noundef, i64 noundef, i64 noundef) local_unnamed_addr
+declare noundef nonnull ptr @heap_allocate(i64 noundef, i64 noundef) local_unnamed_addr
+
+declare void @llvm.assume(i1 noundef)
+declare i64 @llvm.coro.align.i64()
+declare i1 @llvm.coro.alloc(token)
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare i1 @llvm.coro.end(ptr, i1, token)
+declare ptr @llvm.coro.free(token, ptr nocapture readonly)
+declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
+declare token @llvm.coro.save(ptr)
+declare i64 @llvm.coro.size.i64()
+declare ptr @llvm.coro.subfn.addr(ptr nocapture readonly, i8)
+declare i8 @llvm.coro.suspend(token, i1)
+declare void @llvm.instrprof.increment(ptr, i64, i32, i32)
+declare void @llvm.instrprof.value.profile(ptr, i64, i64, i32, i32)
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
+
+; Function Attrs: noinline nounwind presplitcoroutine uwtable
+define ptr @f(i32 %0) presplitcoroutine align 32 {
+  %2 = alloca i32, align 8
+  %3 = alloca %CoroutinePromise, align 16
+  %4 = alloca %Awaitable.1, align 8
+  %5 = alloca %Awaitable.2, align 8
+  %6 = call token @llvm.coro.id(i32 8, ptr nonnull %3, ptr nonnull @f, ptr null)
+  %7 = call i1 @llvm.coro.alloc(token %6)
+  br i1 %7, label %8, label %12
+
+8:                                                ; preds = %1
+  %9 = call i64 @llvm.coro.size.i64()
+  %10 = call i64 @llvm.coro.align.i64()
+  %11 = call noalias noundef nonnull ptr @heap_allocate(i64 noundef %9, i64 noundef %10) #27
+  call void @llvm.assume(i1 true) [ "align"(ptr %11, i64 %10) ]
+  br label %12
+
+12:                                               ; preds = %8, %1
+  %13 = phi ptr [ null, %1 ], [ %11, %8 ]
+  %14 = call ptr @llvm.coro.begin(token %6, ptr %13) #28
+  call void @llvm.lifetime.start.p0(i64 32, ptr nonnull %3) #9
+  store ptr null, ptr %3, align 16
+  %15 = getelementptr inbounds {ptr, i64}, ptr %3, i64 0, i32 1
+  store i64 0, ptr %15, align 8
+  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %4) #9
+  store ptr %3, ptr %4, align 8
+  %16 = call token @llvm.coro.save(ptr null)
+  call void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1) %4, ptr %14) #9
+  %17 = call i8 @llvm.coro.suspend(token %16, i1 false)
+  switch i8 %17, label %61 [
+    i8 0, label %18
+    i8 1, label %21
+  ]
+
+18:                                               ; preds = %12
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
+  %19 = icmp slt i32 0, %0
+  br i1 %19, label %20, label %36
+
+20:                                               ; preds = %18
+  br label %22
+
+21:                                               ; preds = %12
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
+  br label %54
+
+22:                                               ; preds = %20, %31
+  %23 = phi i32 [ 0, %20 ], [ %32, %31 ]
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %5) #9
+  %24 = call ptr @other_coro()
+  store ptr %3, ptr %5, align 8
+  %25 = getelementptr inbounds { ptr, ptr }, ptr %5, i64 0, i32 1
+  store ptr %24, ptr %25, align 8
+  %26 = call token @llvm.coro.save(ptr null)
+  %27 = call ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16) %5, ptr %14)
+  %28 = call ptr @llvm.coro.subfn.addr(ptr %27, i8 0)
+  %29 = ptrtoint ptr %28 to i64
+  call fastcc void %28(ptr %27) #9
+  %30 = call i8 @llvm.coro.suspend(token %26, i1 false)
+  switch i8 %30, label %60 [
+    i8 0, label %31
+    i8 1, label %34
+  ]
+
+31:                                               ; preds = %22
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
+  %32 = add nuw nsw i32 %23, 1
+  %33 = icmp slt i32 %32, %0
+  br i1 %33, label %22, label %35, !llvm.loop !0
+
+34:                                               ; preds = %22
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
+  br label %54
+
+35:                                               ; preds = %31
+  br label %36
+
+36:                                               ; preds = %35, %18
+  %37 = call token @llvm.coro.save(ptr null)
+  %38 = getelementptr inbounds i8, ptr %14, i64 16
+  %39 = getelementptr inbounds i8, ptr %14, i64 32
+  %40 = load i64, ptr %39, align 8
+  %41 = load ptr, ptr %38, align 16
+  %42 = icmp eq ptr %41, null
+  br i1 %42, label %43, label %46
+
+43:                                               ; preds = %36
+  %44 = call ptr @llvm.coro.subfn.addr(ptr nonnull %14, i8 1)
+  %45 = ptrtoint ptr %44 to i64
+  call fastcc void %44(ptr nonnull %14) #9
+  br label %47
+
+46:                                               ; preds = %36
+  call void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32) %38) #9
+  br label %47
+
+47:                                               ; preds = %43, %46
+  %48 = inttoptr i64 %40 to ptr
+  %49 = call ptr @llvm.coro.subfn.addr(ptr %48, i8 0)
+  %50 = ptrtoint ptr %49 to i64
+  call fastcc void %49(ptr %48) #9
+  %51 = call i8 @llvm.coro.suspend(token %37, i1 true) #28
+  switch i8 %51, label %61 [
+    i8 0, label %53
+    i8 1, label %52
+  ]
+
+52:                                               ; preds = %47
+  br label %54
+
+53:                                               ; preds = %47
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %2) #9
+  unreachable
+
+54:                                               ; preds = %52, %34, %21
+  call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %3) #9
+  %55 = call ptr @llvm.coro.free(token %6, ptr %14)
+  %56 = icmp eq ptr %55, null
+  br i1 %56, label %61, label %57
+
+57:                                               ; preds = %54
+  %58 = call i64 @llvm.coro.size.i64()
+  %59 = call i64 @llvm.coro.align.i64()
+  call void @heap_delete(ptr noundef nonnull %55, i64 noundef %58, i64 noundef %59) #9
+  br label %61
+
+60:                                               ; preds = %22
+  br label %61
+
+61:                                               ; preds = %60, %57, %54, %47, %12
+  %62 = getelementptr inbounds i8, ptr %3, i64 -16
+  %63 = call i1 @llvm.coro.end(ptr null, i1 false, token none) #28
+  ret ptr %62
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.mustprogress"}

diff  --git a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
index 2da38c60044bebe..5152cbe19c215c3 100644
--- a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
@@ -611,3 +611,64 @@ switch i32 %0, label %LD [
   EXPECT_EQ(BranchProbability::getRaw(1),
             BPI.getEdgeProbability(EntryBB, UnreachableBB));
 }
+
+TEST(BasicBlockUtils, IsPresplitCoroSuspendExitTest) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @positive_case(i32 %0) #0 {
+entry:
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %resume
+    i8 1, label %destroy
+  ]
+resume:
+  ret void
+destroy:
+  ret void
+exit:
+  call i1 @llvm.coro.end(ptr null, i1 false, token none)
+  ret void
+}
+
+define void @notpresplit(i32 %0) {
+entry:
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %resume
+    i8 1, label %destroy
+  ]
+resume:
+  ret void
+destroy:
+  ret void
+exit:
+  call i1 @llvm.coro.end(ptr null, i1 false, token none)
+  ret void
+}
+
+declare token @llvm.coro.save(ptr)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i1 @llvm.coro.end(ptr, i1, token)
+
+attributes #0 = { presplitcoroutine }
+)IR");
+
+  auto FindExit = [](const Function &F) -> const BasicBlock * {
+    for (const auto &BB : F)
+      if (BB.getName() == "exit")
+        return &BB;
+    return nullptr;
+  };
+  Function *P = M->getFunction("positive_case");
+  const auto &ExitP = *FindExit(*P);
+  EXPECT_TRUE(llvm::isPresplitCoroSuspendExitEdge(*ExitP.getSinglePredecessor(),
+                                                  ExitP));
+
+  Function *N = M->getFunction("notpresplit");
+  const auto &ExitN = *FindExit(*N);
+  EXPECT_FALSE(llvm::isPresplitCoroSuspendExitEdge(
+      *ExitN.getSinglePredecessor(), ExitN));
+}


        


More information about the llvm-commits mailing list