[llvm] [BOLT] Support pre-aggregated returns (PR #143296)

Amir Ayupov via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 20 03:04:29 PDT 2025


https://github.com/aaupov updated https://github.com/llvm/llvm-project/pull/143296

>From 3a58be3e4f3fa04c7ca63f69a31fb730ce5b2f56 Mon Sep 17 00:00:00 2001
From: Amir Ayupov <aaupov at fb.com>
Date: Sat, 7 Jun 2025 21:24:15 -0700
Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20change?=
 =?UTF-8?q?s=20to=20main=20this=20commit=20is=20based=20on?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.4

[skip ci]
---
 bolt/include/bolt/Core/BinaryFunction.h       |  12 +
 .../bolt/Profile/BoltAddressTranslation.h     |   5 +-
 bolt/include/bolt/Profile/DataAggregator.h    |  66 +++--
 bolt/include/bolt/Profile/DataReader.h        |  15 +-
 .../include/bolt/Profile/ProfileYAMLMapping.h |   2 +
 bolt/lib/Core/BinaryFunction.cpp              |   2 +
 bolt/lib/Passes/ProfileQualityStats.cpp       |   3 +
 bolt/lib/Profile/BoltAddressTranslation.cpp   |  14 +-
 bolt/lib/Profile/DataAggregator.cpp           | 248 ++++++------------
 bolt/lib/Profile/DataReader.cpp               |   6 +
 bolt/lib/Profile/YAMLProfileReader.cpp        |   1 +
 bolt/lib/Profile/YAMLProfileWriter.cpp        |   1 +
 bolt/test/X86/pre-aggregated-perf.test        |   6 +-
 bolt/test/X86/shrinkwrapping.test             |   2 +
 14 files changed, 176 insertions(+), 207 deletions(-)

diff --git a/bolt/include/bolt/Core/BinaryFunction.h b/bolt/include/bolt/Core/BinaryFunction.h
index 14957cba50174..ca8b786f4ab69 100644
--- a/bolt/include/bolt/Core/BinaryFunction.h
+++ b/bolt/include/bolt/Core/BinaryFunction.h
@@ -388,6 +388,10 @@ class BinaryFunction {
   /// The profile data for the number of times the function was executed.
   uint64_t ExecutionCount{COUNT_NO_PROFILE};
 
+  /// Profile data for the number of times this function was entered from
+  /// external code (DSO, JIT, etc.).
+  uint64_t ExternEntryCount{0};
+
   /// Profile match ratio.
   float ProfileMatchRatio{0.0f};
 
@@ -1877,6 +1881,10 @@ class BinaryFunction {
     return *this;
   }
 
+  /// Set the profile data for the number of times the function was entered from
+  /// external code (DSO/JIT).
+  void setExternEntryCount(uint64_t Count) { ExternEntryCount = Count; }
+
   /// Adjust execution count for the function by a given \p Count. The value
   /// \p Count will be subtracted from the current function count.
   ///
@@ -1904,6 +1912,10 @@ class BinaryFunction {
   /// Return COUNT_NO_PROFILE if there's no profile info.
   uint64_t getExecutionCount() const { return ExecutionCount; }
 
+  /// Return the profile information about the number of times the function was
+  /// entered from external code (DSO/JIT).
+  uint64_t getExternEntryCount() const { return ExternEntryCount; }
+
   /// Return the raw profile information about the number of branch
   /// executions corresponding to this function.
   uint64_t getRawSampleCount() const { return RawSampleCount; }
diff --git a/bolt/include/bolt/Profile/BoltAddressTranslation.h b/bolt/include/bolt/Profile/BoltAddressTranslation.h
index fcc578f35e322..bb1e28f02ea19 100644
--- a/bolt/include/bolt/Profile/BoltAddressTranslation.h
+++ b/bolt/include/bolt/Profile/BoltAddressTranslation.h
@@ -98,12 +98,13 @@ class BoltAddressTranslation {
                      bool IsBranchSrc) const;
 
   /// Use the map keys containing basic block addresses to infer fall-throughs
-  /// taken in the path started at FirstLBR.To and ending at SecondLBR.From.
+  /// taken in the path starting at \p From and ending at \p To.
   /// Return std::nullopt if trace is invalid or the list of fall-throughs
   /// otherwise.
   std::optional<FallthroughListTy> getFallthroughsInTrace(uint64_t FuncAddress,
                                                           uint64_t From,
-                                                          uint64_t To) const;
+                                                          uint64_t To,
+                                                          bool IsReturn) const;
 
   /// If available, fetch the address of the hot part linked to the cold part
   /// at \p Address. Return 0 otherwise.
diff --git a/bolt/include/bolt/Profile/DataAggregator.h b/bolt/include/bolt/Profile/DataAggregator.h
index cb8e81b829a09..098a934fd5c88 100644
--- a/bolt/include/bolt/Profile/DataAggregator.h
+++ b/bolt/include/bolt/Profile/DataAggregator.h
@@ -78,6 +78,13 @@ class DataAggregator : public DataReader {
   static bool checkPerfDataMagic(StringRef FileName);
 
 private:
+  struct LBREntry {
+    uint64_t From;
+    uint64_t To;
+    bool Mispred;
+  };
+  friend raw_ostream &operator<<(raw_ostream &OS, const LBREntry &);
+
   struct PerfBranchSample {
     SmallVector<LBREntry, 32> LBR;
   };
@@ -92,26 +99,31 @@ class DataAggregator : public DataReader {
     uint64_t Addr;
   };
 
+  /// Container for the unit of branch data.
+  /// Backwards compatible with legacy use for branches and fall-throughs:
+  /// - if \p Branch is FT_ONLY or FT_EXTERNAL_ORIGIN, the trace only contains
+  ///   fall-through data,
+  /// - if \p To is EXTERNAL, the trace only contains branch data.
   struct Trace {
+    static constexpr const uint64_t EXTERNAL = 0ULL;
+    static constexpr const uint64_t FT_ONLY = -1ULL;
+    static constexpr const uint64_t FT_EXTERNAL_ORIGIN = -2ULL;
+
+    uint64_t Branch;
     uint64_t From;
     uint64_t To;
-    Trace(uint64_t From, uint64_t To) : From(From), To(To) {}
     bool operator==(const Trace &Other) const {
-      return From == Other.From && To == Other.To;
+      return Branch == Other.Branch && From == Other.From && To == Other.To;
     }
   };
+  friend raw_ostream &operator<<(raw_ostream &OS, const Trace &);
 
   struct TraceHash {
     size_t operator()(const Trace &L) const {
-      return std::hash<uint64_t>()(L.From << 32 | L.To);
+      return llvm::hash_combine(L.Branch, L.From, L.To);
     }
   };
 
-  struct FTInfo {
-    uint64_t InternCount{0};
-    uint64_t ExternCount{0};
-  };
-
   struct TakenBranchInfo {
     uint64_t TakenCount{0};
     uint64_t MispredCount{0};
@@ -119,8 +131,11 @@ class DataAggregator : public DataReader {
 
   /// Intermediate storage for profile data. We save the results of parsing
   /// and use them later for processing and assigning profile.
-  std::unordered_map<Trace, TakenBranchInfo, TraceHash> BranchLBRs;
-  std::unordered_map<Trace, FTInfo, TraceHash> FallthroughLBRs;
+  std::unordered_map<Trace, TakenBranchInfo, TraceHash> TraceMap;
+  std::vector<std::pair<Trace, TakenBranchInfo>> Traces;
+  /// Pre-populated addresses of returns, coming from pre-aggregated data or
+  /// disassembly. Used to disambiguate call-continuation fall-throughs.
+  std::unordered_set<uint64_t> Returns;
   std::unordered_map<uint64_t, uint64_t> BasicSamples;
   std::vector<PerfMemSample> MemSamples;
 
@@ -193,8 +208,8 @@ class DataAggregator : public DataReader {
   /// Return a vector of offsets corresponding to a trace in a function
   /// if the trace is valid, std::nullopt otherwise.
   std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
-  getFallthroughsInTrace(BinaryFunction &BF, const LBREntry &First,
-                         const LBREntry &Second, uint64_t Count = 1) const;
+  getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace, uint64_t Count,
+                         bool IsReturn) const;
 
   /// Record external entry into the function \p BF.
   ///
@@ -255,11 +270,10 @@ class DataAggregator : public DataReader {
                      uint64_t Mispreds);
 
   /// Register a \p Branch.
-  bool doBranch(uint64_t From, uint64_t To, uint64_t Count, uint64_t Mispreds);
+  bool doBranch(const Trace &Trace, uint64_t Count, uint64_t Mispreds);
 
   /// Register a trace between two LBR entries supplied in execution order.
-  bool doTrace(const LBREntry &First, const LBREntry &Second,
-               uint64_t Count = 1);
+  bool doTrace(const Trace &Trace, uint64_t Count);
 
   /// Parser helpers
   /// Return false if we exhausted our parser buffer and finished parsing
@@ -476,7 +490,6 @@ class DataAggregator : public DataReader {
 
   /// Debugging dump methods
   void dump() const;
-  void dump(const LBREntry &LBR) const;
   void dump(const PerfBranchSample &Sample) const;
   void dump(const PerfMemSample &Sample) const;
 
@@ -504,6 +517,27 @@ class DataAggregator : public DataReader {
 
   friend class YAMLProfileWriter;
 };
+
+inline raw_ostream &operator<<(raw_ostream &OS,
+                               const DataAggregator::LBREntry &L) {
+  OS << formatv("{0:x} -> {1:x}/{2}", L.From, L.To, L.Mispred ? 'M' : 'P');
+  return OS;
+}
+
+inline raw_ostream &operator<<(raw_ostream &OS,
+                               const DataAggregator::Trace &T) {
+  switch (T.Branch) {
+  case DataAggregator::Trace::FT_ONLY:
+  case DataAggregator::Trace::FT_EXTERNAL_ORIGIN:
+    break;
+  default:
+    OS << Twine::utohexstr(T.Branch) << " -> ";
+  }
+  OS << Twine::utohexstr(T.From);
+  if (T.To)
+    OS << " ... " << Twine::utohexstr(T.To);
+  return OS;
+}
 } // namespace bolt
 } // namespace llvm
 
diff --git a/bolt/include/bolt/Profile/DataReader.h b/bolt/include/bolt/Profile/DataReader.h
index 5df1b5a8f4a00..6f527ba3931d4 100644
--- a/bolt/include/bolt/Profile/DataReader.h
+++ b/bolt/include/bolt/Profile/DataReader.h
@@ -32,18 +32,6 @@ namespace bolt {
 
 class BinaryFunction;
 
-struct LBREntry {
-  uint64_t From;
-  uint64_t To;
-  bool Mispred;
-};
-
-inline raw_ostream &operator<<(raw_ostream &OS, const LBREntry &LBR) {
-  OS << "0x" << Twine::utohexstr(LBR.From) << " -> 0x"
-     << Twine::utohexstr(LBR.To);
-  return OS;
-}
-
 struct Location {
   bool IsSymbol;
   StringRef Name;
@@ -109,6 +97,9 @@ struct FuncBranchData {
   /// Total execution count for the function.
   int64_t ExecutionCount{0};
 
+  /// Total entry count from external code for the function.
+  uint64_t ExternEntryCount{0};
+
   /// Indicate if the data was used.
   bool Used{false};
 
diff --git a/bolt/include/bolt/Profile/ProfileYAMLMapping.h b/bolt/include/bolt/Profile/ProfileYAMLMapping.h
index a8d9a15311d94..41e2bd1651efd 100644
--- a/bolt/include/bolt/Profile/ProfileYAMLMapping.h
+++ b/bolt/include/bolt/Profile/ProfileYAMLMapping.h
@@ -206,6 +206,7 @@ struct BinaryFunctionProfile {
   uint32_t Id{0};
   llvm::yaml::Hex64 Hash{0};
   uint64_t ExecCount{0};
+  uint64_t ExternEntryCount{0};
   std::vector<BinaryBasicBlockProfile> Blocks;
   std::vector<InlineTreeNode> InlineTree;
   bool Used{false};
@@ -218,6 +219,7 @@ template <> struct MappingTraits<bolt::BinaryFunctionProfile> {
     YamlIO.mapRequired("fid", BFP.Id);
     YamlIO.mapRequired("hash", BFP.Hash);
     YamlIO.mapRequired("exec", BFP.ExecCount);
+    YamlIO.mapOptional("extern", BFP.ExternEntryCount, 0);
     YamlIO.mapRequired("nblocks", BFP.NumBasicBlocks);
     YamlIO.mapOptional("blocks", BFP.Blocks,
                        std::vector<bolt::BinaryBasicBlockProfile>());
diff --git a/bolt/lib/Core/BinaryFunction.cpp b/bolt/lib/Core/BinaryFunction.cpp
index 6d1969f5c6c30..b998d7160aae7 100644
--- a/bolt/lib/Core/BinaryFunction.cpp
+++ b/bolt/lib/Core/BinaryFunction.cpp
@@ -471,6 +471,8 @@ void BinaryFunction::print(raw_ostream &OS, std::string Annotation) {
     OS << "\n  Sample Count: " << RawSampleCount;
     OS << "\n  Profile Acc : " << format("%.1f%%", ProfileMatchRatio * 100.0f);
   }
+  if (ExternEntryCount)
+    OS << "\n  Extern Entry Count: " << ExternEntryCount;
 
   if (opts::PrintDynoStats && !getLayout().block_empty()) {
     OS << '\n';
diff --git a/bolt/lib/Passes/ProfileQualityStats.cpp b/bolt/lib/Passes/ProfileQualityStats.cpp
index dfd74d3dd5719..64cc662c3ab29 100644
--- a/bolt/lib/Passes/ProfileQualityStats.cpp
+++ b/bolt/lib/Passes/ProfileQualityStats.cpp
@@ -532,6 +532,9 @@ void computeFlowMappings(const BinaryContext &BC, FlowInfo &TotalFlowMap) {
     std::vector<uint64_t> &MaxCountMap = TotalMaxCountMaps[FunctionNum];
     std::vector<uint64_t> &MinCountMap = TotalMinCountMaps[FunctionNum];
 
+    // Record external entry count into CallGraphIncomingFlows
+    CallGraphIncomingFlows[FunctionNum] += Function->getExternEntryCount();
+
     // Update MaxCountMap, MinCountMap, and CallGraphIncomingFlows
     auto recordCall = [&](const BinaryBasicBlock *SourceBB,
                           const MCSymbol *DestSymbol, uint64_t Count,
diff --git a/bolt/lib/Profile/BoltAddressTranslation.cpp b/bolt/lib/Profile/BoltAddressTranslation.cpp
index a253522e4fb15..e1b046da237e4 100644
--- a/bolt/lib/Profile/BoltAddressTranslation.cpp
+++ b/bolt/lib/Profile/BoltAddressTranslation.cpp
@@ -511,8 +511,8 @@ uint64_t BoltAddressTranslation::translate(uint64_t FuncAddress,
 
 std::optional<BoltAddressTranslation::FallthroughListTy>
 BoltAddressTranslation::getFallthroughsInTrace(uint64_t FuncAddress,
-                                               uint64_t From,
-                                               uint64_t To) const {
+                                               uint64_t From, uint64_t To,
+                                               bool IsReturn) const {
   SmallVector<std::pair<uint64_t, uint64_t>, 16> Res;
 
   // Filter out trivial case
@@ -530,6 +530,12 @@ BoltAddressTranslation::getFallthroughsInTrace(uint64_t FuncAddress,
   auto FromIter = Map.upper_bound(From);
   if (FromIter == Map.begin())
     return Res;
+
+  // For fall-throughs originating at returns, go back one entry to cover
+  // the call site.
+  if (IsReturn)
+    --FromIter;
+
   // Skip instruction entries, to create fallthroughs we are only interested in
   // BB boundaries
   do {
@@ -546,7 +552,7 @@ BoltAddressTranslation::getFallthroughsInTrace(uint64_t FuncAddress,
     return Res;
 
   for (auto Iter = FromIter; Iter != ToIter;) {
-    const uint32_t Src = Iter->first;
+    const uint32_t Src = Iter->second >> 1;
     if (Iter->second & BRANCHENTRY) {
       ++Iter;
       continue;
@@ -557,7 +563,7 @@ BoltAddressTranslation::getFallthroughsInTrace(uint64_t FuncAddress,
       ++Iter;
     if (Iter->second & BRANCHENTRY)
       break;
-    Res.emplace_back(Src, Iter->first);
+    Res.emplace_back(Src, Iter->second >> 1);
   }
 
   return Res;
diff --git a/bolt/lib/Profile/DataAggregator.cpp b/bolt/lib/Profile/DataAggregator.cpp
index 2527b5bfe38d2..b64953f21f1ce 100644
--- a/bolt/lib/Profile/DataAggregator.cpp
+++ b/bolt/lib/Profile/DataAggregator.cpp
@@ -587,8 +587,7 @@ void DataAggregator::processProfile(BinaryContext &BC) {
     llvm::stable_sort(MemEvents.second.Data);
 
   // Release intermediate storage.
-  clear(BranchLBRs);
-  clear(FallthroughLBRs);
+  clear(Traces);
   clear(BasicSamples);
   clear(MemSamples);
 }
@@ -716,14 +715,20 @@ bool DataAggregator::doInterBranch(BinaryFunction *FromFunc,
   return true;
 }
 
-bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
+bool DataAggregator::doBranch(const Trace &Trace, uint64_t Count,
                               uint64_t Mispreds) {
+  uint64_t From = Trace.Branch, To = Trace.From;
   // Returns whether \p Offset in \p Func contains a return instruction.
   auto checkReturn = [&](const BinaryFunction &Func, const uint64_t Offset) {
     auto isReturn = [&](auto MI) { return MI && BC->MIB->isReturn(*MI); };
-    return Func.hasInstructions()
-               ? isReturn(Func.getInstructionAtOffset(Offset))
-               : isReturn(Func.disassembleInstructionAtOffset(Offset));
+    const uint64_t Addr = Func.getAddress() + Offset;
+    if (llvm::is_contained(Returns, Addr))
+      return true;
+    if (Func.hasInstructions()
+            ? isReturn(Func.getInstructionAtOffset(Offset))
+            : isReturn(Func.disassembleInstructionAtOffset(Offset)))
+      return Returns.emplace(Addr).second;
+    return false;
   };
 
   // Mutates \p Addr to an offset into the containing function, performing BAT
@@ -733,8 +738,10 @@ bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
   // corresponds to a return (if \p IsFrom) or a call continuation (otherwise).
   auto handleAddress = [&](uint64_t &Addr, bool IsFrom) {
     BinaryFunction *Func = getBinaryFunctionContainingAddress(Addr);
-    if (!Func)
+    if (!Func) {
+      Addr = 0;
       return std::pair{Func, false};
+    }
 
     Addr -= Func->getAddress();
 
@@ -767,89 +774,59 @@ bool DataAggregator::doBranch(uint64_t From, uint64_t To, uint64_t Count,
   return doInterBranch(FromFunc, ToFunc, From, To, Count, Mispreds);
 }
 
-bool DataAggregator::doTrace(const LBREntry &First, const LBREntry &Second,
-                             uint64_t Count) {
-  BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(First.To);
-  BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(Second.From);
+bool DataAggregator::doTrace(const Trace &Trace, uint64_t Count) {
+  const uint64_t Branch = Trace.Branch, From = Trace.From, To = Trace.To;
+  BinaryFunction *FromFunc = getBinaryFunctionContainingAddress(From);
+  BinaryFunction *ToFunc = getBinaryFunctionContainingAddress(To);
   if (!FromFunc || !ToFunc) {
-    LLVM_DEBUG({
-      dbgs() << "Out of range trace starting in ";
-      if (FromFunc)
-        dbgs() << formatv("{0} @ {1:x}", *FromFunc,
-                          First.To - FromFunc->getAddress());
-      else
-        dbgs() << Twine::utohexstr(First.To);
-      dbgs() << " and ending in ";
-      if (ToFunc)
-        dbgs() << formatv("{0} @ {1:x}", *ToFunc,
-                          Second.From - ToFunc->getAddress());
-      else
-        dbgs() << Twine::utohexstr(Second.From);
-      dbgs() << '\n';
-    });
+    LLVM_DEBUG(dbgs() << "Out of range trace " << Trace << '\n');
     NumLongRangeTraces += Count;
     return false;
   }
   if (FromFunc != ToFunc) {
     NumInvalidTraces += Count;
-    LLVM_DEBUG({
-      dbgs() << "Invalid trace starting in " << FromFunc->getPrintName()
-             << formatv(" @ {0:x}", First.To - FromFunc->getAddress())
-             << " and ending in " << ToFunc->getPrintName()
-             << formatv(" @ {0:x}\n", Second.From - ToFunc->getAddress());
-    });
+    LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
     return false;
   }
 
+  // All branches are checked in doBranch except external addresses.
+  bool IsReturn = llvm::is_contained(Returns, Branch);
+
   // Set ParentFunc to BAT parent function or FromFunc itself.
   BinaryFunction *ParentFunc = getBATParentFunction(*FromFunc);
   if (!ParentFunc)
     ParentFunc = FromFunc;
-  ParentFunc->SampleCountInBytes += Count * (Second.From - First.To);
+  ParentFunc->SampleCountInBytes += Count * (To - From);
 
   const uint64_t FuncAddress = FromFunc->getAddress();
   std::optional<BoltAddressTranslation::FallthroughListTy> FTs =
       BAT && BAT->isBATFunction(FuncAddress)
-          ? BAT->getFallthroughsInTrace(FuncAddress, First.To, Second.From)
-          : getFallthroughsInTrace(*FromFunc, First, Second, Count);
+          ? BAT->getFallthroughsInTrace(FuncAddress, From, To, IsReturn)
+          : getFallthroughsInTrace(*FromFunc, Trace, Count, IsReturn);
   if (!FTs) {
-    LLVM_DEBUG(
-        dbgs() << "Invalid trace starting in " << FromFunc->getPrintName()
-               << " @ " << Twine::utohexstr(First.To - FromFunc->getAddress())
-               << " and ending in " << ToFunc->getPrintName() << " @ "
-               << ToFunc->getPrintName() << " @ "
-               << Twine::utohexstr(Second.From - ToFunc->getAddress()) << '\n');
+    LLVM_DEBUG(dbgs() << "Invalid trace " << Trace << '\n');
     NumInvalidTraces += Count;
     return false;
   }
 
   LLVM_DEBUG(dbgs() << "Processing " << FTs->size() << " fallthroughs for "
-                    << FromFunc->getPrintName() << ":"
-                    << Twine::utohexstr(First.To) << " to "
-                    << Twine::utohexstr(Second.From) << ".\n");
-  for (auto [From, To] : *FTs) {
-    if (BAT) {
-      From = BAT->translate(FromFunc->getAddress(), From, /*IsBranchSrc=*/true);
-      To = BAT->translate(FromFunc->getAddress(), To, /*IsBranchSrc=*/false);
-    }
+                    << FromFunc->getPrintName() << ":" << Trace << '\n');
+  for (auto [From, To] : *FTs)
     doIntraBranch(*ParentFunc, From, To, Count, false);
-  }
 
   return true;
 }
 
 std::optional<SmallVector<std::pair<uint64_t, uint64_t>, 16>>
-DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
-                                       const LBREntry &FirstLBR,
-                                       const LBREntry &SecondLBR,
-                                       uint64_t Count) const {
+DataAggregator::getFallthroughsInTrace(BinaryFunction &BF, const Trace &Trace,
+                                       uint64_t Count, bool IsReturn) const {
   SmallVector<std::pair<uint64_t, uint64_t>, 16> Branches;
 
   BinaryContext &BC = BF.getBinaryContext();
 
   // Offsets of the trace within this function.
-  const uint64_t From = FirstLBR.To - BF.getAddress();
-  const uint64_t To = SecondLBR.From - BF.getAddress();
+  const uint64_t From = Trace.From - BF.getAddress();
+  const uint64_t To = Trace.To - BF.getAddress();
 
   if (From > To)
     return std::nullopt;
@@ -876,8 +853,13 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
 
   // Adjust FromBB if the first LBR is a return from the last instruction in
   // the previous block (that instruction should be a call).
-  if (From == FromBB->getOffset() && !BF.containsAddress(FirstLBR.From) &&
-      !FromBB->isEntryPoint() && !FromBB->isLandingPad()) {
+  if (IsReturn) {
+    if (From)
+      FromBB = BF.getBasicBlockContainingOffset(From - 1);
+    else
+      LLVM_DEBUG(dbgs() << "return to the function start: " << Trace << '\n');
+  } else if (Trace.Branch == Trace::EXTERNAL && From == FromBB->getOffset() &&
+             !FromBB->isEntryPoint() && !FromBB->isLandingPad()) {
     const BinaryBasicBlock *PrevBB =
         BF.getLayout().getBlock(FromBB->getIndex() - 1);
     if (PrevBB->getSuccessor(FromBB->getLabel())) {
@@ -885,10 +867,9 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
       if (Instr && BC.MIB->isCall(*Instr))
         FromBB = PrevBB;
       else
-        LLVM_DEBUG(dbgs() << "invalid incoming LBR (no call): " << FirstLBR
-                          << '\n');
+        LLVM_DEBUG(dbgs() << "invalid trace (no call): " << Trace << '\n');
     } else {
-      LLVM_DEBUG(dbgs() << "invalid incoming LBR: " << FirstLBR << '\n');
+      LLVM_DEBUG(dbgs() << "invalid trace: " << Trace << '\n');
     }
   }
 
@@ -907,9 +888,7 @@ DataAggregator::getFallthroughsInTrace(BinaryFunction &BF,
 
     // Check for bad LBRs.
     if (!BB->getSuccessor(NextBB->getLabel())) {
-      LLVM_DEBUG(dbgs() << "no fall-through for the trace:\n"
-                        << "  " << FirstLBR << '\n'
-                        << "  " << SecondLBR << '\n');
+      LLVM_DEBUG(dbgs() << "no fall-through for the trace: " << Trace << '\n');
       return std::nullopt;
     }
 
@@ -972,7 +951,7 @@ bool DataAggregator::recordExit(BinaryFunction &BF, uint64_t From, bool Mispred,
   return true;
 }
 
-ErrorOr<LBREntry> DataAggregator::parseLBREntry() {
+ErrorOr<DataAggregator::LBREntry> DataAggregator::parseLBREntry() {
   LBREntry Res;
   ErrorOr<StringRef> FromStrRes = parseString('/');
   if (std::error_code EC = FromStrRes.getError())
@@ -1303,29 +1282,24 @@ std::error_code DataAggregator::parseAggregatedLBREntry() {
   if (ToFunc)
     ToFunc->setHasProfileAvailable();
 
-  Trace Trace(FromOffset, ToOffset);
-  // Taken trace
-  if (Type == TRACE || Type == BRANCH) {
-    TakenBranchInfo &Info = BranchLBRs[Trace];
-    Info.TakenCount += Count;
-    Info.MispredCount += Mispreds;
-
-    NumTotalSamples += Count;
+  if (Type == FT || Type == FT_EXTERNAL_ORIGIN) {
+    Addr[2] = Location(Addr[1]->Offset);
+    Addr[1] = Location(Addr[0]->Offset);
+    Addr[0] = Location(Type == FT ? Trace::FT_ONLY : Trace::FT_EXTERNAL_ORIGIN);
   }
-  // Construct fallthrough part of the trace
-  if (Type == TRACE) {
-    const uint64_t TraceFtEndOffset = Addr[2]->Offset;
-    Trace.From = ToOffset;
-    Trace.To = TraceFtEndOffset;
-    Type = FromFunc == ToFunc ? FT : FT_EXTERNAL_ORIGIN;
+
+  if (Type == BRANCH) {
+    Addr[2] = Location(Trace::EXTERNAL);
   }
-  // Add fallthrough trace
-  if (Type != BRANCH) {
-    FTInfo &Info = FallthroughLBRs[Trace];
-    (Type == FT ? Info.InternCount : Info.ExternCount) += Count;
 
+  Trace T{Addr[0]->Offset, Addr[1]->Offset, Addr[2]->Offset};
+  TakenBranchInfo TI{(uint64_t)Count, (uint64_t)Mispreds};
+
+  Traces.emplace_back(T, TI);
+
+  if (Addr[2]->Offset)
     NumTraces += Count;
-  }
+  NumTotalSamples += Count;
 
   return std::error_code();
 }
@@ -1373,12 +1347,9 @@ std::error_code DataAggregator::printLBRHeatMap() {
   // Register basic samples and perf LBR addresses not covered by fallthroughs.
   for (const auto &[PC, Hits] : BasicSamples)
     HM.registerAddress(PC, Hits);
-  for (const auto &LBR : FallthroughLBRs) {
-    const Trace &Trace = LBR.first;
-    const FTInfo &Info = LBR.second;
-    HM.registerAddressRange(Trace.From, Trace.To,
-                            Info.InternCount + Info.ExternCount);
-  }
+  for (const auto &[Trace, Info] : Traces)
+    if (Trace.To)
+      HM.registerAddressRange(Trace.From, Trace.To, Info.TakenCount);
 
   if (HM.getNumInvalidRanges())
     outs() << "HEATMAP: invalid traces: " << HM.getNumInvalidRanges() << '\n';
@@ -1424,62 +1395,11 @@ void DataAggregator::parseLBRSample(const PerfBranchSample &Sample,
     // chronological order)
     if (NeedsSkylakeFix && NumEntry <= 2)
       continue;
-    if (NextLBR) {
-      // Record fall-through trace.
-      const uint64_t TraceFrom = LBR.To;
-      const uint64_t TraceTo = NextLBR->From;
-      const BinaryFunction *TraceBF =
-          getBinaryFunctionContainingAddress(TraceFrom);
-      if (opts::HeatmapMode == opts::HeatmapModeKind::HM_Exclusive) {
-        FTInfo &Info = FallthroughLBRs[Trace(TraceFrom, TraceTo)];
-        ++Info.InternCount;
-      } else if (TraceBF && TraceBF->containsAddress(TraceTo)) {
-        FTInfo &Info = FallthroughLBRs[Trace(TraceFrom, TraceTo)];
-        if (TraceBF->containsAddress(LBR.From))
-          ++Info.InternCount;
-        else
-          ++Info.ExternCount;
-      } else {
-        const BinaryFunction *ToFunc =
-            getBinaryFunctionContainingAddress(TraceTo);
-        if (TraceBF && ToFunc) {
-          LLVM_DEBUG({
-            dbgs() << "Invalid trace starting in " << TraceBF->getPrintName()
-                   << formatv(" @ {0:x}", TraceFrom - TraceBF->getAddress())
-                   << formatv(" and ending @ {0:x}\n", TraceTo);
-          });
-          ++NumInvalidTraces;
-        } else {
-          LLVM_DEBUG({
-            dbgs() << "Out of range trace starting in "
-                   << (TraceBF ? TraceBF->getPrintName() : "None")
-                   << formatv(" @ {0:x}",
-                              TraceFrom - (TraceBF ? TraceBF->getAddress() : 0))
-                   << " and ending in "
-                   << (ToFunc ? ToFunc->getPrintName() : "None")
-                   << formatv(" @ {0:x}\n",
-                              TraceTo - (ToFunc ? ToFunc->getAddress() : 0));
-          });
-          ++NumLongRangeTraces;
-        }
-      }
-      ++NumTraces;
-    }
-    NextLBR = &LBR;
-
-    // Record branches outside binary functions for heatmap.
-    if (opts::HeatmapMode == opts::HeatmapModeKind::HM_Exclusive) {
-      TakenBranchInfo &Info = BranchLBRs[Trace(LBR.From, LBR.To)];
-      ++Info.TakenCount;
-      continue;
-    }
-    uint64_t From = getBinaryFunctionContainingAddress(LBR.From) ? LBR.From : 0;
-    uint64_t To = getBinaryFunctionContainingAddress(LBR.To) ? LBR.To : 0;
-    if (!From && !To)
-      continue;
-    TakenBranchInfo &Info = BranchLBRs[Trace(From, To)];
+    TakenBranchInfo &Info = TraceMap[Trace{
+        LBR.From, LBR.To, NextLBR ? NextLBR->From : Trace::EXTERNAL}];
     ++Info.TakenCount;
     Info.MispredCount += LBR.Mispred;
+    NextLBR = &LBR;
   }
   // Record LBR addresses not covered by fallthroughs (bottom-of-stack source
   // and top-of-stack target) as basic samples for heatmap.
@@ -1588,10 +1508,14 @@ std::error_code DataAggregator::parseBranchEvents() {
     parseLBRSample(Sample, NeedsSkylakeFix);
   }
 
-  for (const Trace &Trace : llvm::make_first_range(BranchLBRs))
-    for (const uint64_t Addr : {Trace.From, Trace.To})
+  Traces.reserve(TraceMap.size());
+  for (const auto &[Trace, Info] : TraceMap) {
+    Traces.emplace_back(Trace, Info);
+    for (const uint64_t Addr : {Trace.Branch, Trace.From})
       if (BinaryFunction *BF = getBinaryFunctionContainingAddress(Addr))
         BF->setHasProfileAvailable();
+  }
+  clear(TraceMap);
 
   outs() << "PERF2BOLT: read " << NumSamples << " samples and " << NumEntries
          << " LBR entries\n";
@@ -1616,23 +1540,12 @@ void DataAggregator::processBranchEvents() {
   NamedRegionTimer T("processBranch", "Processing branch events",
                      TimerGroupName, TimerGroupDesc, opts::TimeAggregator);
 
-  for (const auto &AggrLBR : FallthroughLBRs) {
-    const Trace &Loc = AggrLBR.first;
-    const FTInfo &Info = AggrLBR.second;
-    LBREntry First{Loc.From, Loc.From, false};
-    LBREntry Second{Loc.To, Loc.To, false};
-    if (Info.InternCount)
-      doTrace(First, Second, Info.InternCount);
-    if (Info.ExternCount) {
-      First.From = 0;
-      doTrace(First, Second, Info.ExternCount);
-    }
-  }
-
-  for (const auto &AggrLBR : BranchLBRs) {
-    const Trace &Loc = AggrLBR.first;
-    const TakenBranchInfo &Info = AggrLBR.second;
-    doBranch(Loc.From, Loc.To, Info.TakenCount, Info.MispredCount);
+  for (const auto &[Trace, Info] : Traces) {
+    if (Trace.Branch != Trace::FT_ONLY &&
+        Trace.Branch != Trace::FT_EXTERNAL_ORIGIN)
+      doBranch(Trace, Info.TakenCount, Info.MispredCount);
+    if (Trace.To)
+      doTrace(Trace, Info.TakenCount);
   }
   printBranchSamplesDiagnostics();
 }
@@ -2289,6 +2202,7 @@ std::error_code DataAggregator::writeBATYAML(BinaryContext &BC,
       YamlBF.Id = BF->getFunctionNumber();
       YamlBF.Hash = BAT->getBFHash(FuncAddress);
       YamlBF.ExecCount = BF->getKnownExecutionCount();
+      YamlBF.ExternEntryCount = BF->getExternEntryCount();
       YamlBF.NumBasicBlocks = BAT->getNumBasicBlocks(FuncAddress);
       const BoltAddressTranslation::BBHashMapTy &BlockMap =
           BAT->getBBHashMap(FuncAddress);
@@ -2398,16 +2312,10 @@ std::error_code DataAggregator::writeBATYAML(BinaryContext &BC,
 
 void DataAggregator::dump() const { DataReader::dump(); }
 
-void DataAggregator::dump(const LBREntry &LBR) const {
-  Diag << "From: " << Twine::utohexstr(LBR.From)
-       << " To: " << Twine::utohexstr(LBR.To) << " Mispred? " << LBR.Mispred
-       << "\n";
-}
-
 void DataAggregator::dump(const PerfBranchSample &Sample) const {
   Diag << "Sample LBR entries: " << Sample.LBR.size() << "\n";
   for (const LBREntry &LBR : Sample.LBR)
-    dump(LBR);
+    Diag << LBR << '\n';
 }
 
 void DataAggregator::dump(const PerfMemSample &Sample) const {
diff --git a/bolt/lib/Profile/DataReader.cpp b/bolt/lib/Profile/DataReader.cpp
index c512394f26a3b..afe24216d7f5d 100644
--- a/bolt/lib/Profile/DataReader.cpp
+++ b/bolt/lib/Profile/DataReader.cpp
@@ -85,6 +85,7 @@ void FuncBranchData::appendFrom(const FuncBranchData &FBD, uint64_t Offset) {
   }
   llvm::stable_sort(Data);
   ExecutionCount += FBD.ExecutionCount;
+  ExternEntryCount += FBD.ExternEntryCount;
   for (auto I = FBD.EntryData.begin(), E = FBD.EntryData.end(); I != E; ++I) {
     assert(I->To.Name == FBD.Name);
     auto NewElmt = EntryData.insert(EntryData.end(), *I);
@@ -269,6 +270,7 @@ Error DataReader::preprocessProfile(BinaryContext &BC) {
     if (FuncBranchData *FuncData = getBranchDataForNames(Function.getNames())) {
       setBranchData(Function, FuncData);
       Function.ExecutionCount = FuncData->ExecutionCount;
+      Function.ExternEntryCount = FuncData->ExternEntryCount;
       FuncData->Used = true;
     }
   }
@@ -419,6 +421,7 @@ void DataReader::matchProfileData(BinaryFunction &BF) {
       if (fetchProfileForOtherEntryPoints(BF)) {
         BF.ProfileMatchRatio = evaluateProfileData(BF, *FBD);
         BF.ExecutionCount = FBD->ExecutionCount;
+        BF.ExternEntryCount = FBD->ExternEntryCount;
         BF.RawSampleCount = FBD->getNumExecutedBranches();
       }
       return;
@@ -449,6 +452,7 @@ void DataReader::matchProfileData(BinaryFunction &BF) {
     setBranchData(BF, NewBranchData);
     NewBranchData->Used = true;
     BF.ExecutionCount = NewBranchData->ExecutionCount;
+    BF.ExternEntryCount = NewBranchData->ExternEntryCount;
     BF.ProfileMatchRatio = 1.0f;
     break;
   }
@@ -1190,6 +1194,8 @@ std::error_code DataReader::parse() {
     if (BI.To.IsSymbol && BI.To.Offset == 0) {
       I = GetOrCreateFuncEntry(BI.To.Name);
       I->second.ExecutionCount += BI.Branches;
+      if (!BI.From.IsSymbol)
+        I->second.ExternEntryCount += BI.Branches;
     }
   }
 
diff --git a/bolt/lib/Profile/YAMLProfileReader.cpp b/bolt/lib/Profile/YAMLProfileReader.cpp
index 33ce40ac2eeec..086e47b661e10 100644
--- a/bolt/lib/Profile/YAMLProfileReader.cpp
+++ b/bolt/lib/Profile/YAMLProfileReader.cpp
@@ -176,6 +176,7 @@ bool YAMLProfileReader::parseFunctionProfile(
   uint64_t FunctionExecutionCount = 0;
 
   BF.setExecutionCount(YamlBF.ExecCount);
+  BF.setExternEntryCount(YamlBF.ExternEntryCount);
 
   uint64_t FuncRawBranchCount = 0;
   for (const yaml::bolt::BinaryBasicBlockProfile &YamlBB : YamlBF.Blocks)
diff --git a/bolt/lib/Profile/YAMLProfileWriter.cpp b/bolt/lib/Profile/YAMLProfileWriter.cpp
index f1fe45f21a0f6..f4308d6fc1992 100644
--- a/bolt/lib/Profile/YAMLProfileWriter.cpp
+++ b/bolt/lib/Profile/YAMLProfileWriter.cpp
@@ -226,6 +226,7 @@ YAMLProfileWriter::convert(const BinaryFunction &BF, bool UseDFS,
   YamlBF.Hash = BF.getHash();
   YamlBF.NumBasicBlocks = BF.size();
   YamlBF.ExecCount = BF.getKnownExecutionCount();
+  YamlBF.ExternEntryCount = BF.getExternEntryCount();
   DenseMap<const MCDecodedPseudoProbeInlineTree *, uint32_t> InlineTreeNodeId;
   if (PseudoProbeDecoder && BF.getGUID()) {
     std::tie(YamlBF.InlineTree, InlineTreeNodeId) =
diff --git a/bolt/test/X86/pre-aggregated-perf.test b/bolt/test/X86/pre-aggregated-perf.test
index 92e093c238e00..cc79cbd339505 100644
--- a/bolt/test/X86/pre-aggregated-perf.test
+++ b/bolt/test/X86/pre-aggregated-perf.test
@@ -67,10 +67,10 @@ BASIC-ERROR: BOLT-INFO: 0 out of 7 functions in the binary (0.0%) have non-empty
 BASIC-SUCCESS: BOLT-INFO: 4 out of 7 functions in the binary (57.1%) have non-empty execution profile
 CHECK-BASIC-NL: no_lbr cycles
 
-PERF2BOLT: 0 [unknown] 7f36d18d60c0 1 main 53c 0 2
+PERF2BOLT: 0 [unknown] 0 1 main 53c 0 2
 PERF2BOLT: 1 main 451 1 SolveCubic 0 0 2
-PERF2BOLT: 1 main 490 0 [unknown] 4005f0 0 1
-PERF2BOLT: 1 main 537 0 [unknown] 400610 0 1
+PERF2BOLT: 1 main 490 0 [unknown] 0 0 1
+PERF2BOLT: 1 main 537 0 [unknown] 0 0 1
 PERF2BOLT: 1 usqrt 30 1 usqrt 32 0 22
 PERF2BOLT: 1 usqrt 30 1 usqrt 39 4 33
 PERF2BOLT: 1 usqrt 35 1 usqrt 39 0 22
diff --git a/bolt/test/X86/shrinkwrapping.test b/bolt/test/X86/shrinkwrapping.test
index 8581d7e0c0f7b..521b4561b3ba6 100644
--- a/bolt/test/X86/shrinkwrapping.test
+++ b/bolt/test/X86/shrinkwrapping.test
@@ -8,6 +8,7 @@ REQUIRES: shell
 
 RUN: %clangxx %cxxflags -no-pie %S/Inputs/exc4sw.S -o %t.exe -Wl,-q
 RUN: llvm-bolt %t.exe -o %t --relocs --frame-opt=all \
+RUN:   --print-only=main --print-cfg \
 RUN:   --data=%p/Inputs/exc4sw.fdata --reorder-blocks=cache 2>&1 | \
 RUN:   FileCheck %s --check-prefix=CHECK-BOLT
 
@@ -19,6 +20,7 @@ RUN: llvm-objdump --dwarf=frames %t | grep -A20 -e \
 RUN:   `llvm-nm --numeric-sort %t | grep main | tail -n 1 | cut -f1 -d' ' | \
 RUN:    tail -c9` 2>&1 | FileCheck %s --check-prefix=CHECK-OUTPUT
 
+CHECK-BOLT: Extern Entry Count: 100
 CHECK-BOLT: Shrink wrapping moved 2 spills inserting load/stores and 0 spills inserting push/pops
 
 CHECK-INPUT:  DW_CFA_advance_loc: 2



More information about the llvm-commits mailing list