[llvm] [ctxprof][nfc] Move profile annotator to Analysis (PR #135871)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 16 08:09:21 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/135871

>From babbffbacdb54079726016f2785fa554512d9be7 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 15 Apr 2025 13:20:16 -0700
Subject: [PATCH] [ctxprof][nfc] Move profile annotator to Analysis

---
 llvm/include/llvm/Analysis/CtxProfAnalysis.h  |  28 ++
 llvm/lib/Analysis/CtxProfAnalysis.cpp         | 372 +++++++++++++++++
 .../Instrumentation/PGOCtxProfFlattening.cpp  | 375 ++----------------
 3 files changed, 424 insertions(+), 351 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 6f1c3696ca78c..aa582cfef1ad1 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -157,6 +157,34 @@ class CtxProfAnalysisPrinterPass
   const PrintMode Mode;
 };
 
+/// Utility that propagates counter values to each basic block and to each edge
+/// when a basic block has more than one outgoing edge, using an adaptation of
+/// PGOUseFunc::populateCounters. RawCounters must outlive the annotator.
+// FIXME(mtrofin): look into factoring the code to share one implementation.
+class ProfileAnnotatorImpl;
+class ProfileAnnotator {
+  std::unique_ptr<ProfileAnnotatorImpl> PImpl;
+
+public:
+  ProfileAnnotator(const Function &F, ArrayRef<uint64_t> RawCounters);
+  // Returns the count of BB, which must be a block of the annotated function.
+  uint64_t getBBCount(const BasicBlock &BB) const;
+  // Finds the true and false counts for the given select instruction. Returns
+  // false if the select doesn't have instrumentation or if the count of the
+  // parent BB is 0; in both cases TrueCount and FalseCount are set to 0.
+  bool getSelectInstrProfile(SelectInst &SI, uint64_t &TrueCount,
+                             uint64_t &FalseCount) const;
+  // Clears Profile and populates it with the edge weights, in the same order as
+  // they need to appear in the MD_prof metadata. Also computes the max of those
+  // weights and returns it in MaxCount. Returns false if:
+  //   - the BB has fewer than 2 successors
+  //   - all the counts are 0
+  bool getOutgoingBranchWeights(BasicBlock &BB,
+                                SmallVectorImpl<uint64_t> &Profile,
+                                uint64_t &MaxCount) const;
+  ~ProfileAnnotator();
+};
+
 /// Assign a GUID to functions as metadata. GUID calculation takes linkage into
 /// account, which may change especially through and after thinlto. By
 /// pre-computing and assigning as metadata, this mechanism is resilient to such
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index d203e277546ea..391631e15aa89 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -14,7 +14,9 @@
 #include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CFG.h"
 #include "llvm/IR/Analysis.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
@@ -22,6 +24,8 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/Path.h"
+#include <deque>
+#include <memory>
 
 #define DEBUG_TYPE "ctx_prof"
 
@@ -46,6 +50,374 @@ static cl::opt<bool> ForceIsInSpecializedModule(
 
 const char *AssignGUIDPass::GUIDMetadataName = "guid";
 
+namespace llvm {
+class ProfileAnnotatorImpl final {
+  friend class ProfileAnnotator;
+  class BBInfo;
+  struct EdgeInfo {
+    BBInfo *const Src;
+    BBInfo *const Dest;
+    std::optional<uint64_t> Count;
+
+    explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
+  };
+
+  class BBInfo {
+    std::optional<uint64_t> Count;
+    // OutEdges is dimensioned to match the number of terminator operands.
+    // Entries in the vector match the index in the terminator operand list. In
+    // some cases - see `shouldExcludeEdge` and its implementation - an entry
+    // will be nullptr.
+    // InEdges doesn't have the above constraint.
+    SmallVector<EdgeInfo *> OutEdges;
+    SmallVector<EdgeInfo *> InEdges;
+    size_t UnknownCountOutEdges = 0;
+    size_t UnknownCountInEdges = 0;
+
+    // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
+    // because all the edge counters must be known.
+    // Return std::nullopt if there were no edges to sum. The user can decide
+    // how to interpret that.
+    std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
+                                       bool AssumeAllKnown) const {
+      std::optional<uint64_t> Sum;
+      for (const auto *E : Edges) {
+        // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
+        if (E) {
+          if (!Sum.has_value())
+            Sum = 0;
+          *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U));
+        }
+      }
+      return Sum;
+    }
+
+    bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
+      assert(!Count.has_value());
+      Count = getEdgeSum(Edges, true);
+      return Count.has_value();
+    }
+
+    void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
+      uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U);
+      uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
+      EdgeInfo *E = nullptr;
+      for (auto *I : Edges)
+        if (I && !I->Count.has_value()) {
+          // Check before assigning E, so a second unknown edge is caught.
+          assert(!E &&
+                 "Expected exactly one edge to have an unknown count, "
+                 "found a second one");
+          E = I;
+#ifdef NDEBUG
+          // Without asserts, trust the invariant and stop at the first hit.
+          break;
+#endif
+        }
+      assert(E && "Expected exactly one edge to have an unknown count");
+      assert(!E->Count.has_value());
+      E->Count = EdgeVal;
+      assert(E->Src->UnknownCountOutEdges > 0);
+      assert(E->Dest->UnknownCountInEdges > 0);
+      --E->Src->UnknownCountOutEdges;
+      --E->Dest->UnknownCountInEdges;
+    }
+
+  public:
+    BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
+        : Count(Count) {
+      // For in edges, we just want to pre-allocate enough space, since we know
+      // it at this stage. For out edges, we will insert edges at the indices
+      // corresponding to positions in this BB's terminator instruction, so we
+      // construct a default (nullptr values)-initialized vector. A nullptr edge
+      // corresponds to those that are excluded (see shouldExcludeEdge).
+      InEdges.reserve(NumInEdges);
+      OutEdges.resize(NumOutEdges);
+    }
+
+    bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
+      if (!UnknownCountOutEdges) {
+        return computeCountFrom(OutEdges);
+      }
+      return false;
+    }
+
+    bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
+      if (!UnknownCountInEdges) {
+        return computeCountFrom(InEdges);
+      }
+      return false;
+    }
+
+    void addInEdge(EdgeInfo &Info) {
+      InEdges.push_back(&Info);
+      ++UnknownCountInEdges;
+    }
+
+    // For the out edges, we care about the position we place them in, which is
+    // the position in terminator instruction's list (at construction). Later,
+    // we build branch_weights metadata with edge frequency values matching
+    // these positions.
+    void addOutEdge(size_t Index, EdgeInfo &Info) {
+      OutEdges[Index] = &Info;
+      ++UnknownCountOutEdges;
+    }
+
+    bool hasCount() const { return Count.has_value(); }
+
+    uint64_t getCount() const { return *Count; }
+
+    bool trySetSingleUnknownInEdgeCount() {
+      if (UnknownCountInEdges == 1) {
+        setSingleUnknownEdgeCount(InEdges);
+        return true;
+      }
+      return false;
+    }
+
+    bool trySetSingleUnknownOutEdgeCount() {
+      if (UnknownCountOutEdges == 1) {
+        setSingleUnknownEdgeCount(OutEdges);
+        return true;
+      }
+      return false;
+    }
+    size_t getNumOutEdges() const { return OutEdges.size(); }
+
+    uint64_t getEdgeCount(size_t Index) const {
+      if (auto *E = OutEdges[Index])
+        return *E->Count;
+      return 0U;
+    }
+  };
+
+  const Function &F;
+  ArrayRef<uint64_t> Counters;
+  // To be accessed through getBBInfo() after construction.
+  std::map<const BasicBlock *, BBInfo> BBInfos;
+  std::vector<EdgeInfo> EdgeInfos;
+
+  // The only criterion for exclusion is faux suspend -> exit edges in presplit
+  // coroutines. The API serves for readability, currently.
+  bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
+    return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
+  }
+
+  BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; }
+
+  const BBInfo &getBBInfo(const BasicBlock &BB) const {
+    return BBInfos.find(&BB)->second;
+  }
+
+  // validation function after we propagate the counters: all BBs and edges'
+  // counters must have a value.
+  bool allCountersAreAssigned() const {
+    for (const auto &BBInfo : BBInfos)
+      if (!BBInfo.second.hasCount())
+        return false;
+    for (const auto &EdgeInfo : EdgeInfos)
+      if (!EdgeInfo.Count.has_value())
+        return false;
+    return true;
+  }
+
+  /// Check that all paths from the entry basic block that use edges with
+  /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
+  bool allTakenPathsExit() const {
+    std::deque<const BasicBlock *> Worklist;
+    DenseSet<const BasicBlock *> Visited;
+    Worklist.push_back(&F.getEntryBlock());
+    bool HitExit = false;
+    while (!Worklist.empty()) {
+      const auto *BB = Worklist.front();
+      Worklist.pop_front();
+      if (!Visited.insert(BB).second)
+        continue;
+      if (succ_size(BB) == 0) {
+        if (isa<UnreachableInst>(BB->getTerminator()))
+          return false;
+        HitExit = true;
+        continue;
+      }
+      if (succ_size(BB) == 1) {
+        Worklist.push_back(BB->getUniqueSuccessor());
+        continue;
+      }
+      const auto &BBInfo = getBBInfo(*BB);
+      bool HasAWayOut = false;
+      for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
+        const auto *Succ = BB->getTerminator()->getSuccessor(I);
+        if (!shouldExcludeEdge(*BB, *Succ)) {
+          if (BBInfo.getEdgeCount(I) > 0) {
+            HasAWayOut = true;
+            Worklist.push_back(Succ);
+          }
+        }
+      }
+      if (!HasAWayOut)
+        return false;
+    }
+    return HitExit;
+  }
+
+  // In a BB with a non-zero count, every select must end up with a profile:
+  // either it is instrumented (getSelectInstrProfile then yields one, even
+  // when its true count is 0) or it already carries MD_prof metadata.
+  bool allNonColdSelectsHaveProfile() const {
+    for (const auto &BB : F) {
+      if (getBBInfo(BB).getCount() == 0)
+        continue;
+      for (const auto &I : BB) {
+        const auto *SI = dyn_cast<SelectInst>(&I);
+        if (!SI || SI->getMetadata(LLVMContext::MD_prof))
+          continue;
+        if (!CtxProfAnalysis::getSelectInstrumentation(
+                *const_cast<SelectInst *>(SI)))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  // This is an adaptation of PGOUseFunc::populateCounters.
+  // FIXME(mtrofin): look into factoring the code to share one implementation.
+  void propagateCounterValues() {
+    bool KeepGoing = true;
+    while (KeepGoing) {
+      KeepGoing = false;
+      for (const auto &BB : F) {
+        auto &Info = getBBInfo(BB);
+        if (!Info.hasCount())
+          KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
+                       Info.tryTakeCountFromKnownInEdges(BB);
+        if (Info.hasCount()) {
+          KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
+          KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
+        }
+      }
+    }
+    assert(allCountersAreAssigned() &&
+           "[ctx-prof] Expected all counters have been assigned.");
+    assert(allTakenPathsExit() &&
+           "[ctx-prof] Encountered a BB with more than one successor, where "
+           "all outgoing edges have a 0 count. This occurs in non-exiting "
+           "functions (message pumps, usually) which are not supported in the "
+           "contextual profiling case");
+    assert(allNonColdSelectsHaveProfile() &&
+           "[ctx-prof] All non-cold select instructions were expected to have "
+           "a profile.");
+  }
+
+public:
+  ProfileAnnotatorImpl(const Function &F, ArrayRef<uint64_t> Counters)
+      : F(F), Counters(Counters) {
+    assert(!F.isDeclaration());
+    assert(!Counters.empty());
+    size_t NrEdges = 0;
+    for (const auto &BB : F) {
+      std::optional<uint64_t> Count;
+      if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
+              const_cast<BasicBlock &>(BB))) {
+        auto Index = Ins->getIndex()->getZExtValue();
+        assert(Index < Counters.size() &&
+               "The index must be inside the counters vector by construction - "
+               "tripping this assertion indicates a bug in how the contextual "
+               "profile is managed by IPO transforms");
+        (void)Index;
+        Count = Counters[Ins->getIndex()->getZExtValue()];
+      } else if (isa<UnreachableInst>(BB.getTerminator())) {
+        // The program presumably didn't crash.
+        Count = 0;
+      }
+      auto [It, Ins] =
+          BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}});
+      (void)Ins;
+      assert(Ins && "We iterate through the function's BBs, no reason to "
+                    "insert one more than once");
+      NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
+        return !shouldExcludeEdge(BB, *Succ);
+      });
+    }
+    // Pre-allocate the vector, we want references to its contents to be stable.
+    EdgeInfos.reserve(NrEdges);
+    for (const auto &BB : F) {
+      auto &Info = getBBInfo(BB);
+      for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
+        const auto *Succ = BB.getTerminator()->getSuccessor(I);
+        if (!shouldExcludeEdge(BB, *Succ)) {
+          auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ));
+          Info.addOutEdge(I, EI);
+          getBBInfo(*Succ).addInEdge(EI);
+        }
+      }
+    }
+    assert(EdgeInfos.capacity() == NrEdges &&
+           "The capacity of EdgeInfos should have stayed unchanged it was "
+           "populated, because we need pointers to its contents to be stable");
+    propagateCounterValues();
+  }
+
+  uint64_t getBBCount(const BasicBlock &BB) { return getBBInfo(BB).getCount(); }
+};
+
+} // namespace llvm
+
+ProfileAnnotator::ProfileAnnotator(const Function &F,
+                                   ArrayRef<uint64_t> RawCounters)
+    : PImpl(std::make_unique<ProfileAnnotatorImpl>(F, RawCounters)) {}
+
+ProfileAnnotator::~ProfileAnnotator() = default; // Needs complete Impl type.
+
+uint64_t ProfileAnnotator::getBBCount(const BasicBlock &BB) const {
+  return PImpl->getBBInfo(BB).getCount();
+}
+
+bool ProfileAnnotator::getSelectInstrProfile(SelectInst &SI,
+                                             uint64_t &TrueCount,
+                                             uint64_t &FalseCount) const {
+  TrueCount = 0;
+  FalseCount = 0;
+  const uint64_t BBCount = PImpl->getBBInfo(*SI.getParent()).getCount();
+  if (BBCount == 0)
+    return false;
+  auto *Step = CtxProfAnalysis::getSelectInstrumentation(SI);
+  if (!Step)
+    return false;
+  const uint64_t Index = Step->getIndex()->getZExtValue();
+  assert(Index < PImpl->Counters.size() &&
+         "The index of the step instruction must be inside the "
+         "counters vector by "
+         "construction - tripping this assertion indicates a bug in "
+         "how the contextual profile is managed by IPO transforms");
+  // The step counter is the true count; the rest of the BB count is false.
+  TrueCount = PImpl->Counters[Index];
+  FalseCount = BBCount > TrueCount ? BBCount - TrueCount : 0U;
+  return true;
+}
+
+bool ProfileAnnotator::getOutgoingBranchWeights(
+    BasicBlock &BB, SmallVectorImpl<uint64_t> &Profile,
+    uint64_t &MaxCount) const {
+  // Reset both outputs up front, so they are well-defined on every path.
+  Profile.clear();
+  MaxCount = 0;
+  if (succ_size(&BB) < 2)
+    return false;
+
+  // Weights are indexed like the terminator's successors; excluded edges
+  // (nullptr entries in the out-edge list) contribute 0.
+  const auto &BBInfo = PImpl->getBBInfo(BB);
+  Profile.resize(BB.getTerminator()->getNumSuccessors());
+  for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
+       ++SuccIdx) {
+    const uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
+    Profile[SuccIdx] = EdgeCount;
+    MaxCount = std::max(MaxCount, EdgeCount);
+  }
+  // All-zero weights carry no information, so report "no profile".
+  return MaxCount > 0;
+}
+
 PreservedAnalyses AssignGUIDPass::run(Module &M, ModuleAnalysisManager &MAM) {
   for (auto &F : M.functions()) {
     if (F.isDeclaration())
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
index 508a41684ed20..e47c9ab75ffe1 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
@@ -45,358 +45,33 @@ using namespace llvm;
 
 namespace {
 
-class ProfileAnnotator final {
-  class BBInfo;
-  struct EdgeInfo {
-    BBInfo *const Src;
-    BBInfo *const Dest;
-    std::optional<uint64_t> Count;
+/// Assign the function entry count and branch weights (on multi-successor
+/// terminators and instrumented selects), as computed by ProfileAnnotator.
+void assignProfileData(Function &F, ArrayRef<uint64_t> RawCounters) {
+  assert(!RawCounters.empty());
+  ProfileAnnotator PA(F, RawCounters);

-    explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
-  };
+  F.setEntryCount(RawCounters[0]);
+  SmallVector<uint64_t, 2> ProfileHolder;

-  class BBInfo {
-    std::optional<uint64_t> Count;
-    // OutEdges is dimensioned to match the number of terminator operands.
-    // Entries in the vector match the index in the terminator operand list. In
-    // some cases - see `shouldExcludeEdge` and its implementation - an entry
-    // will be nullptr.
-    // InEdges doesn't have the above constraint.
-    SmallVector<EdgeInfo *> OutEdges;
-    SmallVector<EdgeInfo *> InEdges;
-    size_t UnknownCountOutEdges = 0;
-    size_t UnknownCountInEdges = 0;
-
-    // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
-    // because all the edge counters must be known.
-    // Return std::nullopt if there were no edges to sum. The user can decide
-    // how to interpret that.
-    std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
-                                       bool AssumeAllKnown) const {
-      std::optional<uint64_t> Sum;
-      for (const auto *E : Edges) {
-        // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
-        if (E) {
-          if (!Sum.has_value())
-            Sum = 0;
-          *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U));
-        }
-      }
-      return Sum;
-    }
-
-    bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
-      assert(!Count.has_value());
-      Count = getEdgeSum(Edges, true);
-      return Count.has_value();
-    }
-
-    void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
-      uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U);
-      uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
-      EdgeInfo *E = nullptr;
-      for (auto *I : Edges)
-        if (I && !I->Count.has_value()) {
-          E = I;
-#ifdef NDEBUG
-          break;
-#else
-          assert((!E || E == I) &&
-                 "Expected exactly one edge to have an unknown count, "
-                 "found a second one");
-          continue;
-#endif
-        }
-      assert(E && "Expected exactly one edge to have an unknown count");
-      assert(!E->Count.has_value());
-      E->Count = EdgeVal;
-      assert(E->Src->UnknownCountOutEdges > 0);
-      assert(E->Dest->UnknownCountInEdges > 0);
-      --E->Src->UnknownCountOutEdges;
-      --E->Dest->UnknownCountInEdges;
-    }
-
-  public:
-    BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
-        : Count(Count) {
-      // For in edges, we just want to pre-allocate enough space, since we know
-      // it at this stage. For out edges, we will insert edges at the indices
-      // corresponding to positions in this BB's terminator instruction, so we
-      // construct a default (nullptr values)-initialized vector. A nullptr edge
-      // corresponds to those that are excluded (see shouldExcludeEdge).
-      InEdges.reserve(NumInEdges);
-      OutEdges.resize(NumOutEdges);
-    }
-
-    bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
-      if (!UnknownCountOutEdges) {
-        return computeCountFrom(OutEdges);
-      }
-      return false;
-    }
-
-    bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
-      if (!UnknownCountInEdges) {
-        return computeCountFrom(InEdges);
-      }
-      return false;
-    }
-
-    void addInEdge(EdgeInfo &Info) {
-      InEdges.push_back(&Info);
-      ++UnknownCountInEdges;
-    }
-
-    // For the out edges, we care about the position we place them in, which is
-    // the position in terminator instruction's list (at construction). Later,
-    // we build branch_weights metadata with edge frequency values matching
-    // these positions.
-    void addOutEdge(size_t Index, EdgeInfo &Info) {
-      OutEdges[Index] = &Info;
-      ++UnknownCountOutEdges;
-    }
-
-    bool hasCount() const { return Count.has_value(); }
-
-    uint64_t getCount() const { return *Count; }
-
-    bool trySetSingleUnknownInEdgeCount() {
-      if (UnknownCountInEdges == 1) {
-        setSingleUnknownEdgeCount(InEdges);
-        return true;
-      }
-      return false;
-    }
-
-    bool trySetSingleUnknownOutEdgeCount() {
-      if (UnknownCountOutEdges == 1) {
-        setSingleUnknownEdgeCount(OutEdges);
-        return true;
-      }
-      return false;
-    }
-    size_t getNumOutEdges() const { return OutEdges.size(); }
-
-    uint64_t getEdgeCount(size_t Index) const {
-      if (auto *E = OutEdges[Index])
-        return *E->Count;
-      return 0U;
-    }
-  };
-
-  Function &F;
-  const SmallVectorImpl<uint64_t> &Counters;
-  // To be accessed through getBBInfo() after construction.
-  std::map<const BasicBlock *, BBInfo> BBInfos;
-  std::vector<EdgeInfo> EdgeInfos;
-
-  // This is an adaptation of PGOUseFunc::populateCounters.
-  // FIXME(mtrofin): look into factoring the code to share one implementation.
-  void propagateCounterValues(const SmallVectorImpl<uint64_t> &Counters) {
-    bool KeepGoing = true;
-    while (KeepGoing) {
-      KeepGoing = false;
-      for (const auto &BB : F) {
-        auto &Info = getBBInfo(BB);
-        if (!Info.hasCount())
-          KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
-                       Info.tryTakeCountFromKnownInEdges(BB);
-        if (Info.hasCount()) {
-          KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
-          KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
-        }
-      }
-    }
-  }
-  // The only criteria for exclusion is faux suspend -> exit edges in presplit
-  // coroutines. The API serves for readability, currently.
-  bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
-    return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
-  }
-
-  BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; }
-
-  const BBInfo &getBBInfo(const BasicBlock &BB) const {
-    return BBInfos.find(&BB)->second;
-  }
-
-  // validation function after we propagate the counters: all BBs and edges'
-  // counters must have a value.
-  bool allCountersAreAssigned() const {
-    for (const auto &BBInfo : BBInfos)
-      if (!BBInfo.second.hasCount())
-        return false;
-    for (const auto &EdgeInfo : EdgeInfos)
-      if (!EdgeInfo.Count.has_value())
-        return false;
-    return true;
-  }
-
-  /// Check that all paths from the entry basic block that use edges with
-  /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
-  bool allTakenPathsExit() const {
-    std::deque<const BasicBlock *> Worklist;
-    DenseSet<const BasicBlock *> Visited;
-    Worklist.push_back(&F.getEntryBlock());
-    bool HitExit = false;
-    while (!Worklist.empty()) {
-      const auto *BB = Worklist.front();
-      Worklist.pop_front();
-      if (!Visited.insert(BB).second)
-        continue;
-      if (succ_size(BB) == 0) {
-        if (isa<UnreachableInst>(BB->getTerminator()))
-          return false;
-        HitExit = true;
-        continue;
-      }
-      if (succ_size(BB) == 1) {
-        Worklist.push_back(BB->getUniqueSuccessor());
-        continue;
-      }
-      const auto &BBInfo = getBBInfo(*BB);
-      bool HasAWayOut = false;
-      for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
-        const auto *Succ = BB->getTerminator()->getSuccessor(I);
-        if (!shouldExcludeEdge(*BB, *Succ)) {
-          if (BBInfo.getEdgeCount(I) > 0) {
-            HasAWayOut = true;
-            Worklist.push_back(Succ);
-          }
-        }
-      }
-      if (!HasAWayOut)
-        return false;
-    }
-    return HitExit;
-  }
-
-  bool allNonColdSelectsHaveProfile() const {
-    for (const auto &BB : F) {
-      if (getBBInfo(BB).getCount() > 0) {
-        for (const auto &I : BB) {
-          if (const auto *SI = dyn_cast<SelectInst>(&I)) {
-            if (!SI->getMetadata(LLVMContext::MD_prof)) {
-              return false;
-            }
-          }
-        }
-      }
-    }
-    return true;
-  }
-
-public:
-  ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters)
-      : F(F), Counters(Counters) {
-    assert(!F.isDeclaration());
-    assert(!Counters.empty());
-    size_t NrEdges = 0;
-    for (const auto &BB : F) {
-      std::optional<uint64_t> Count;
-      if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
-              const_cast<BasicBlock &>(BB))) {
-        auto Index = Ins->getIndex()->getZExtValue();
-        assert(Index < Counters.size() &&
-               "The index must be inside the counters vector by construction - "
-               "tripping this assertion indicates a bug in how the contextual "
-               "profile is managed by IPO transforms");
-        (void)Index;
-        Count = Counters[Ins->getIndex()->getZExtValue()];
-      } else if (isa<UnreachableInst>(BB.getTerminator())) {
-        // The program presumably didn't crash.
-        Count = 0;
-      }
-      auto [It, Ins] =
-          BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}});
-      (void)Ins;
-      assert(Ins && "We iterate through the function's BBs, no reason to "
-                    "insert one more than once");
-      NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
-        return !shouldExcludeEdge(BB, *Succ);
-      });
-    }
-    // Pre-allocate the vector, we want references to its contents to be stable.
-    EdgeInfos.reserve(NrEdges);
-    for (const auto &BB : F) {
-      auto &Info = getBBInfo(BB);
-      for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
-        const auto *Succ = BB.getTerminator()->getSuccessor(I);
-        if (!shouldExcludeEdge(BB, *Succ)) {
-          auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ));
-          Info.addOutEdge(I, EI);
-          getBBInfo(*Succ).addInEdge(EI);
-        }
-      }
-    }
-    assert(EdgeInfos.capacity() == NrEdges &&
-           "The capacity of EdgeInfos should have stayed unchanged it was "
-           "populated, because we need pointers to its contents to be stable");
-  }
-
-  void setProfileForSelectInstructions(BasicBlock &BB, const BBInfo &BBInfo) {
-    if (BBInfo.getCount() == 0)
-      return;
-
-    for (auto &I : BB) {
+  for (auto &BB : F) {
+    for (auto &I : BB)
       if (auto *SI = dyn_cast<SelectInst>(&I)) {
-        if (auto *Step = CtxProfAnalysis::getSelectInstrumentation(*SI)) {
-          auto Index = Step->getIndex()->getZExtValue();
-          assert(Index < Counters.size() &&
-                 "The index of the step instruction must be inside the "
-                 "counters vector by "
-                 "construction - tripping this assertion indicates a bug in "
-                 "how the contextual profile is managed by IPO transforms");
-          auto TotalCount = BBInfo.getCount();
-          auto TrueCount = Counters[Index];
-          auto FalseCount =
-              (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
-          setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
-                          std::max(TrueCount, FalseCount));
-        }
-      }
-    }
-  }
-
-  /// Assign branch weights and function entry count. Also update the PSI
-  /// builder.
-  void assignProfileData() {
-    assert(!Counters.empty());
-    propagateCounterValues(Counters);
-    F.setEntryCount(Counters[0]);
-
-    for (auto &BB : F) {
-      const auto &BBInfo = getBBInfo(BB);
-      setProfileForSelectInstructions(BB, BBInfo);
-      if (succ_size(&BB) < 2)
-        continue;
-      auto *Term = BB.getTerminator();
-      SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0);
-      uint64_t MaxCount = 0;
-
-      for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
-           ++SuccIdx) {
-        uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
-        if (EdgeCount > MaxCount)
-          MaxCount = EdgeCount;
-        EdgeCounts[SuccIdx] = EdgeCount;
+        uint64_t TrueCount = 0, FalseCount = 0;
+        if (!PA.getSelectInstrProfile(*SI, TrueCount, FalseCount))
+          continue;
+        setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
+                        std::max(TrueCount, FalseCount));
       }
-
-      if (MaxCount != 0)
-        setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount);
-    }
-    assert(allCountersAreAssigned() &&
-           "[ctx-prof] Expected all counters have been assigned.");
-    assert(allTakenPathsExit() &&
-           "[ctx-prof] Encountered a BB with more than one successor, where "
-           "all outgoing edges have a 0 count. This occurs in non-exiting "
-           "functions (message pumps, usually) which are not supported in the "
-           "contextual profiling case");
-    assert(allNonColdSelectsHaveProfile() &&
-           "[ctx-prof] All non-cold select instructions were expected to have "
-           "a profile.");
+    if (succ_size(&BB) < 2)
+      continue;
+    uint64_t MaxCount = 0;
+    if (!PA.getOutgoingBranchWeights(BB, ProfileHolder, MaxCount))
+      continue;
+    assert(MaxCount > 0);
+    setProfMetadata(F.getParent(), BB.getTerminator(), ProfileHolder, MaxCount);
   }
-};
+}
 
 [[maybe_unused]] bool areAllBBsReachable(const Function &F,
                                          FunctionAnalysisManager &FAM) {
@@ -510,10 +185,8 @@ PreservedAnalyses PGOCtxProfFlatteningPass::run(Module &M,
     // If this function didn't appear in the contextual profile, it's cold.
     if (It == FlattenedProfile.end())
       clearColdFunctionProfile(F);
-    else {
-      ProfileAnnotator S(F, It->second);
-      S.assignProfileData();
-    }
+    else
+      assignProfileData(F, It->second);
   }
   InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
   // use here the flat profiles just so the importer doesn't complain about



More information about the llvm-commits mailing list