[llvm] 414930b - [CSSPGO][llvm-profgen] Refactor to unify hashable interface for trace sample and context-sensitive counter

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 13 11:07:23 PST 2021


Author: wlei
Date: 2021-01-13T11:02:57-08:00
New Revision: 414930b91bfd4196c457120932a1dbaf26db711d

URL: https://github.com/llvm/llvm-project/commit/414930b91bfd4196c457120932a1dbaf26db711d
DIFF: https://github.com/llvm/llvm-project/commit/414930b91bfd4196c457120932a1dbaf26db711d.diff

LOG: [CSSPGO][llvm-profgen] Refactor to unify hashable interface for trace sample and context-sensitive counter

As we plan to support both CSSPGO and AutoFDO in llvm-profgen, we will have different kinds of perf samples and different kinds of sample counters (CS/non-CS, with/without pseudo probes), both of which need to be aggregated in a hash map. This change implements a hashable interface (`Hashable`) and unified base classes for them to improve extensibility and reusability.
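To make the idea concrete, here is a minimal standalone sketch of how a `Hashable`-style wrapper lets a `std::unordered_map` aggregate polymorphic keys. It is simplified from the `Hashable<T>` and `PerfSample` in this patch and does not depend on LLVM. `SampleBase`, `AddrSample`, `SampleCountMap`, and `recordSample` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Simplified stand-in for the patch's Hashable<T>: a shared_ptr to a
// polymorphic key plus the Hash/Equal functors std::unordered_map needs.
template <class T> struct Hashable {
  std::shared_ptr<T> Data;
  explicit Hashable(std::shared_ptr<T> D) : Data(std::move(D)) {}

  struct Hash {
    std::size_t operator()(const Hashable &Key) const {
      // Non-virtual call: the derived class already stored HashCode.
      return static_cast<std::size_t>(Key.Data->getHashCode());
    }
  };
  struct Equal {
    bool operator()(const Hashable &L, const Hashable &R) const {
      // Virtual call: the derived class compares its own payload.
      return L.Data->isEqual(R.Data.get());
    }
  };
};

// Hypothetical sample base class, modeled loosely on PerfSample.
struct SampleBase {
  uint64_t HashCode = 0;
  virtual ~SampleBase() = default;
  uint64_t getHashCode() const { return HashCode; }
  virtual bool isEqual(const SampleBase *Other) const = 0;
};

// Hypothetical concrete sample holding a list of addresses.
struct AddrSample : SampleBase {
  std::vector<uint64_t> Addrs;
  explicit AddrSample(std::vector<uint64_t> A) : Addrs(std::move(A)) {
    uint64_t H = 5381;
    for (uint64_t V : Addrs)
      H = ((H << 5) + H) + V; // DJB2-style combine step
    HashCode = H;             // derived class assigns the base's HashCode
  }
  bool isEqual(const SampleBase *Other) const override {
    // Assumes every key in the map is an AddrSample.
    return Addrs == static_cast<const AddrSample *>(Other)->Addrs;
  }
};

using SampleCountMap =
    std::unordered_map<Hashable<SampleBase>, uint64_t,
                       Hashable<SampleBase>::Hash, Hashable<SampleBase>::Equal>;

// Records one occurrence of a sample; equal samples share one entry.
void recordSample(SampleCountMap &Map, std::vector<uint64_t> Addrs) {
  auto S = std::make_shared<AddrSample>(std::move(Addrs));
  Map[Hashable<SampleBase>(S)]++;
}
```

Two separately allocated samples with the same addresses land in one map entry, because `Hash` reads the precomputed `HashCode` and `Equal` compares contents rather than pointers.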

Currently, the perf trace sample and the context-sensitive sample counter implement `Hashable`, and the class hierarchy looks like this:

```
| Hashable
           | PerfSample
                          | HybridSample
                          | LBRSample
           | ContextKey
                          | StringBasedCtxKey
                          | ProbeBasedCtxKey
                          | CallsiteBasedCtxKey
           | ...
```
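The hierarchy relies on LLVM-style RTTI: each base class stores a kind enum, and each derived class provides a static `classof` that `isa<>`/`dyn_cast<>` consult. The sketch below shows the pattern without LLVM headers. `ContextKeyBase`, `StringKey`, `ProbeKey`, and `dynCastSketch` are hypothetical stand-ins, and `dynCastSketch` only approximates `llvm::dyn_cast` (it does not handle null pointers or the other cases LLVM's casting machinery supports).

```cpp
#include <cassert>

// Base class carries an explicit kind tag instead of relying on C++ RTTI.
struct ContextKeyBase {
  enum Kind { CK_StringBased, CK_ProbeBased };
  explicit ContextKeyBase(Kind K) : TheKind(K) {}
  virtual ~ContextKeyBase() = default;
  Kind getKind() const { return TheKind; }

private:
  const Kind TheKind;
};

// Hypothetical derived keys; each answers "is this base one of me?".
struct StringKey : ContextKeyBase {
  StringKey() : ContextKeyBase(CK_StringBased) {}
  static bool classof(const ContextKeyBase *K) {
    return K->getKind() == CK_StringBased;
  }
};

struct ProbeKey : ContextKeyBase {
  ProbeKey() : ContextKeyBase(CK_ProbeBased) {}
  static bool classof(const ContextKeyBase *K) {
    return K->getKind() == CK_ProbeBased;
  }
};

// Hypothetical minimal dyn_cast: consult classof, then static_cast.
template <class To, class From> const To *dynCastSketch(const From *P) {
  return To::classof(P) ? static_cast<const To *>(P) : nullptr;
}
```

Where the dynamic kind is already guaranteed, as in the `isEqual` overrides in this patch, LLVM's `cast<>`, which asserts on a kind mismatch, states that intent more directly than `dyn_cast<>` followed by an unchecked dereference.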

- A class implementing `Hashable` must provide `getHashCode` and `isEqual`. `getHashCode` is deliberately non-virtual to avoid vtable overhead, so a derived class must compute its hash and assign it to the base class's `HashCode` manually. This also makes it possible to compute the hash code incrementally (like a rolling hash) during frame stack unwinding.
- `isEqual` is a virtual function, which has a performance cost. If we design a better hash function in the future, we could skip this check or switch to a non-virtual function.
- Added `PerfSample` and `ContextKey` as base classes for perf samples and counter context keys, using LLVM-style RTTI.
- Added a `StringBasedCtxKey` class extending `ContextKey` that uses a string as the context ID.
- Refactored `AggregationCounter` (now `AggregatedCounter`) to take any kind of `PerfSample` as its key.
- Replaced `ContextSampleCounters` with `ContextSampleCounterMap`, which takes any kind of `ContextKey` as its key.
- Other refactoring work:
  - Created a wrapper class `SampleCounter` that holds `RangeCounter` and `BranchCounter`.
  - Hoisted `ContextId` and `FunctionProfile` out of `populateFunctionBodySamples` and `populateFunctionBoundarySamples` so that both calls in `CSProfileGenerator::generateProfile` reuse them.
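The incremental-hashing point in the first bullet can be illustrated with the DJB2 combine step that `HybridSample::genHashCode` uses. This patch does not implement rolling hashes; the sketch below is one possible approach, assuming frames are hashed in root-to-leaf order so that unwinding only pushes or pops the last element. `djb2Combine`, `IncrementalStackHash`, and `hashFromScratch` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// DJB2-style combine step, the same formula HybridSample::genHashCode uses.
inline uint64_t djb2Combine(uint64_t H, uint64_t V) {
  return ((H << 5) + H) + V;
}

// Hypothetical helper: keeps one hash per prefix of a root-to-leaf frame
// stack, so pushing or popping the leaf updates the hash in O(1).
class IncrementalStackHash {
public:
  void pushLeaf(uint64_t Frame) {
    Prefix.push_back(djb2Combine(Prefix.back(), Frame));
  }
  void popLeaf() {
    assert(Prefix.size() > 1 && "popping an empty stack");
    Prefix.pop_back();
  }
  uint64_t hash() const { return Prefix.back(); }

private:
  std::vector<uint64_t> Prefix{5381}; // Prefix[0] is the empty-stack hash
};

// Reference implementation: rehashes the whole stack from scratch.
uint64_t hashFromScratch(const std::vector<uint64_t> &Frames) {
  uint64_t H = 5381;
  for (uint64_t F : Frames)
    H = djb2Combine(H, F);
  return H;
}
```

Keeping one hash per prefix trades a little memory for O(1) updates, so the key for the current context is available without rehashing the whole stack at every unwinding step.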

Differential Revision: https://reviews.llvm.org/D92584

Added: 
    

Modified: 
    llvm/tools/llvm-profgen/PerfReader.cpp
    llvm/tools/llvm-profgen/PerfReader.h
    llvm/tools/llvm-profgen/ProfileGenerator.cpp
    llvm/tools/llvm-profgen/ProfileGenerator.h

Removed: 
    


################################################################################
diff  --git a/llvm/tools/llvm-profgen/PerfReader.cpp b/llvm/tools/llvm-profgen/PerfReader.cpp
index 6a0d54e2de87..1ed5e2917cba 100644
--- a/llvm/tools/llvm-profgen/PerfReader.cpp
+++ b/llvm/tools/llvm-profgen/PerfReader.cpp
@@ -72,26 +72,39 @@ void VirtualUnwinder::unwindBranchWithinFrame(UnwindState &State) {
   State.InstPtr.update(Source);
 }
 
+SampleCounter &
+VirtualUnwinder::getOrCreateSampleCounter(const ProfiledBinary *Binary,
+                                          std::list<uint64_t> &CallStack) {
+  std::shared_ptr<StringBasedCtxKey> KeyStr =
+      std::make_shared<StringBasedCtxKey>();
+  KeyStr->Context = Binary->getExpandedContextStr(CallStack);
+  KeyStr->genHashCode();
+  auto Ret =
+      CtxCounterMap->emplace(Hashable<ContextKey>(KeyStr), SampleCounter());
+  return Ret.first->second;
+}
+
 void VirtualUnwinder::recordRangeCount(uint64_t Start, uint64_t End,
                                        UnwindState &State, uint64_t Repeat) {
-  std::string &&ContextId = State.getExpandedContextStr();
   uint64_t StartOffset = State.getBinary()->virtualAddrToOffset(Start);
   uint64_t EndOffset = State.getBinary()->virtualAddrToOffset(End);
-  SampleCounters->recordRangeCount(ContextId, StartOffset, EndOffset, Repeat);
+  SampleCounter &SCounter =
+      getOrCreateSampleCounter(State.getBinary(), State.CallStack);
+  SCounter.recordRangeCount(StartOffset, EndOffset, Repeat);
 }
 
 void VirtualUnwinder::recordBranchCount(const LBREntry &Branch,
                                         UnwindState &State, uint64_t Repeat) {
   if (Branch.IsArtificial)
     return;
-  std::string &&ContextId = State.getExpandedContextStr();
   uint64_t SourceOffset = State.getBinary()->virtualAddrToOffset(Branch.Source);
   uint64_t TargetOffset = State.getBinary()->virtualAddrToOffset(Branch.Target);
-  SampleCounters->recordBranchCount(ContextId, SourceOffset, TargetOffset,
-                                    Repeat);
+  SampleCounter &SCounter =
+      getOrCreateSampleCounter(State.getBinary(), State.CallStack);
+  SCounter.recordBranchCount(SourceOffset, TargetOffset, Repeat);
 }
 
-bool VirtualUnwinder::unwind(const HybridSample &Sample, uint64_t Repeat) {
+bool VirtualUnwinder::unwind(const HybridSample *Sample, uint64_t Repeat) {
   // Capture initial state as starting point for unwinding.
   UnwindState State(Sample);
 
@@ -198,10 +211,10 @@ ProfiledBinary *PerfReader::getBinary(uint64_t Address) {
   return Iter->second;
 }
 
-static void printSampleCounter(ContextRangeCounter &Counter) {
-  // Use ordered map to make the output deterministic
-  std::map<std::string, RangeSample> OrderedCounter(Counter.begin(),
-                                                    Counter.end());
+// Use ordered map to make the output deterministic
+using OrderedCounterForPrint = std::map<StringRef, RangeSample>;
+
+static void printSampleCounter(OrderedCounterForPrint &OrderedCounter) {
   for (auto Range : OrderedCounter) {
     outs() << Range.first << "\n";
     for (auto I : Range.second) {
@@ -211,20 +224,40 @@ static void printSampleCounter(ContextRangeCounter &Counter) {
   }
 }
 
+static void printRangeCounter(ContextSampleCounterMap &Counter) {
+  OrderedCounterForPrint OrderedCounter;
+  for (auto &CI : Counter) {
+    const StringBasedCtxKey *CtxKey =
+        dyn_cast<StringBasedCtxKey>(CI.first.getPtr());
+    OrderedCounter[CtxKey->Context] = CI.second.RangeCounter;
+  }
+  printSampleCounter(OrderedCounter);
+}
+
+static void printBranchCounter(ContextSampleCounterMap &Counter) {
+  OrderedCounterForPrint OrderedCounter;
+  for (auto &CI : Counter) {
+    const StringBasedCtxKey *CtxKey =
+        dyn_cast<StringBasedCtxKey>(CI.first.getPtr());
+    OrderedCounter[CtxKey->Context] = CI.second.BranchCounter;
+  }
+  printSampleCounter(OrderedCounter);
+}
+
 void PerfReader::printUnwinderOutput() {
   for (auto I : BinarySampleCounters) {
     const ProfiledBinary *Binary = I.first;
     outs() << "Binary(" << Binary->getName().str() << ")'s Range Counter:\n";
-    printSampleCounter(I.second.RangeCounter);
+    printRangeCounter(I.second);
     outs() << "\nBinary(" << Binary->getName().str() << ")'s Branch Counter:\n";
-    printSampleCounter(I.second.BranchCounter);
+    printBranchCounter(I.second);
   }
 }
 
 void PerfReader::unwindSamples() {
   for (const auto &Item : AggregatedSamples) {
-    const HybridSample &Sample = Item.first;
-    VirtualUnwinder Unwinder(&BinarySampleCounters[Sample.Binary]);
+    const HybridSample *Sample = dyn_cast<HybridSample>(Item.first.getPtr());
+    VirtualUnwinder Unwinder(&BinarySampleCounters[Sample->Binary]);
     Unwinder.unwind(Sample, Item.second);
   }
 
@@ -366,26 +399,27 @@ void PerfReader::parseHybridSample(TraceStream &TraceIt) {
   // 0x4005c8/0x4005dc/P/-/-/0   0x40062f/0x4005b0/P/-/-/0 ...
   //          ... 0x4005c8/0x4005dc/P/-/-/0    # LBR Entries
   //
-  HybridSample Sample;
+  std::shared_ptr<HybridSample> Sample = std::make_shared<HybridSample>();
 
   // Parsing call stack and populate into HybridSample.CallStack
-  if (!extractCallstack(TraceIt, Sample.CallStack)) {
+  if (!extractCallstack(TraceIt, Sample->CallStack)) {
     // Skip the next LBR line matched current call stack
     if (!TraceIt.isAtEoF() && TraceIt.getCurrentLine().startswith(" 0x"))
       TraceIt.advance();
     return;
   }
   // Set the binary current sample belongs to
-  Sample.Binary = getBinary(Sample.CallStack.front());
+  Sample->Binary = getBinary(Sample->CallStack.front());
 
   if (!TraceIt.isAtEoF() && TraceIt.getCurrentLine().startswith(" 0x")) {
     // Parsing LBR stack and populate into HybridSample.LBRStack
-    if (extractLBRStack(TraceIt, Sample.LBRStack, Sample.Binary)) {
+    if (extractLBRStack(TraceIt, Sample->LBRStack, Sample->Binary)) {
       // Canonicalize stack leaf to avoid 'random' IP from leaf frame skew LBR
       // ranges
-      Sample.CallStack.front() = Sample.LBRStack[0].Target;
+      Sample->CallStack.front() = Sample->LBRStack[0].Target;
       // Record samples by aggregation
-      AggregatedSamples[Sample]++;
+      Sample->genHashCode();
+      AggregatedSamples[Hashable<PerfSample>(Sample)]++;
     }
   } else {
     // LBR sample is encoded in single line after stack sample

diff  --git a/llvm/tools/llvm-profgen/PerfReader.h b/llvm/tools/llvm-profgen/PerfReader.h
index 9883ba4b37a1..5c2159daae39 100644
--- a/llvm/tools/llvm-profgen/PerfReader.h
+++ b/llvm/tools/llvm-profgen/PerfReader.h
@@ -10,6 +10,7 @@
 #define LLVM_TOOLS_LLVM_PROFGEN_PERFREADER_H
 #include "ErrorHandling.h"
 #include "ProfiledBinary.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Regex.h"
 #include <fstream>
@@ -75,8 +76,60 @@ struct LBREntry {
       : Source(S), Target(T), IsArtificial(I) {}
 };
 
+// Hash interface for generic data of type T
+// Data should implement a \fn getHashCode and a \fn isEqual
+// Currently getHashCode is non-virtual to avoid the overhead of calling vtable,
+// i.e we explicitly calculate hash of derived class, assign to base class's
+// HashCode. This also provides the flexibility for calculating the hash code
+// incrementally(like rolling hash) during frame stack unwinding since unwinding
+// only changes the leaf of frame stack. \fn isEqual is a virtual function,
+// which will have perf overhead. In the future, if we redesign a better hash
+// function, then we can just skip this or switch to non-virtual function(like
+// just ignore comparision if hash conflicts probabilities is low)
+template <class T> class Hashable {
+public:
+  std::shared_ptr<T> Data;
+  Hashable(const std::shared_ptr<T> &D) : Data(D) {}
+
+  // Hash code generation
+  struct Hash {
+    uint64_t operator()(const Hashable<T> &Key) const {
+      // Don't make it virtual for getHashCode
+      assert(Key.Data->getHashCode() && "Should generate HashCode for it!");
+      return Key.Data->getHashCode();
+    }
+  };
+
+  // Hash equal
+  struct Equal {
+    bool operator()(const Hashable<T> &LHS, const Hashable<T> &RHS) const {
+      // Precisely compare the data, vtable will have overhead.
+      return LHS.Data->isEqual(RHS.Data.get());
+    }
+  };
+
+  T *getPtr() const { return Data.get(); }
+};
+
+// Base class to extend for all types of perf sample
+struct PerfSample {
+  uint64_t HashCode = 0;
+
+  virtual ~PerfSample() = default;
+  uint64_t getHashCode() const { return HashCode; }
+  virtual bool isEqual(const PerfSample *K) const {
+    return HashCode == K->HashCode;
+  };
+
+  // Utilities for LLVM-style RTTI
+  enum PerfKind { PK_HybridSample };
+  const PerfKind Kind;
+  PerfKind getKind() const { return Kind; }
+  PerfSample(PerfKind K) : Kind(K){};
+};
+
 // The parsed hybrid sample including call stack and LBR stack.
-struct HybridSample {
+struct HybridSample : public PerfSample {
   // Profiled binary that current frame address belongs to
   ProfiledBinary *Binary;
   // Call stack recorded in FILO(leaf to root) order
@@ -84,12 +137,18 @@ struct HybridSample {
   // LBR stack recorded in FIFO order
   SmallVector<LBREntry, 16> LBRStack;
 
+  HybridSample() : PerfSample(PK_HybridSample){};
+  static bool classof(const PerfSample *K) {
+    return K->getKind() == PK_HybridSample;
+  }
+
   // Used for sample aggregation
-  bool operator==(const HybridSample &Other) const {
-    if (Other.Binary != Binary)
+  bool isEqual(const PerfSample *K) const override {
+    const HybridSample *Other = dyn_cast<HybridSample>(K);
+    if (Other->Binary != Binary)
       return false;
-    const std::list<uint64_t> &OtherCallStack = Other.CallStack;
-    const SmallVector<LBREntry, 16> &OtherLBRStack = Other.LBRStack;
+    const std::list<uint64_t> &OtherCallStack = Other->CallStack;
+    const SmallVector<LBREntry, 16> &OtherLBRStack = Other->LBRStack;
 
     if (CallStack.size() != OtherCallStack.size() ||
         LBRStack.size() != OtherLBRStack.size())
@@ -108,8 +167,32 @@ struct HybridSample {
     }
     return true;
   }
+
+  void genHashCode() {
+    // Use simple DJB2 hash
+    auto HashCombine = [](uint64_t H, uint64_t V) {
+      return ((H << 5) + H) + V;
+    };
+    uint64_t Hash = 5381;
+    Hash = HashCombine(Hash, reinterpret_cast<uint64_t>(Binary));
+    for (const auto &Value : CallStack) {
+      Hash = HashCombine(Hash, Value);
+    }
+    for (const auto &Entry : LBRStack) {
+      Hash = HashCombine(Hash, Entry.Source);
+      Hash = HashCombine(Hash, Entry.Target);
+    }
+    HashCode = Hash;
+  }
 };
 
+// After parsing the sample, we record the samples by aggregating them
+// into this counter. The key stores the sample data and the value is
+// the sample repeat times.
+using AggregatedCounter =
+    std::unordered_map<Hashable<PerfSample>, uint64_t,
+                       Hashable<PerfSample>::Hash, Hashable<PerfSample>::Equal>;
+
 // The state for the unwinder, it doesn't hold the data but only keep the
 // pointer/index of the data, While unwinding, the CallStack is changed
 // dynamicially and will be recorded as the context of the sample
@@ -124,10 +207,10 @@ struct UnwindState {
   const SmallVector<LBREntry, 16> &LBRStack;
   // Used to iterate the address range
   InstructionPointer InstPtr;
-  UnwindState(const HybridSample &Sample)
-      : Binary(Sample.Binary), CallStack(Sample.CallStack),
-        LBRStack(Sample.LBRStack),
-        InstPtr(Sample.Binary, Sample.CallStack.front()) {}
+  UnwindState(const HybridSample *Sample)
+      : Binary(Sample->Binary), CallStack(Sample->CallStack),
+        LBRStack(Sample->LBRStack),
+        InstPtr(Sample->Binary, Sample->CallStack.front()) {}
 
   bool validateInitialState() {
     uint64_t LBRLeaf = LBRStack[LBRIndex].Target;
@@ -160,56 +243,61 @@ struct UnwindState {
   void advanceLBR() { LBRIndex++; }
 };
 
+// Base class for sample counter key with context
+struct ContextKey {
+  uint64_t HashCode = 0;
+  virtual ~ContextKey() = default;
+  uint64_t getHashCode() const { return HashCode; }
+  virtual bool isEqual(const ContextKey *K) const {
+    return HashCode == K->HashCode;
+  };
+
+  // Utilities for LLVM-style RTTI
+  enum ContextKind { CK_StringBased };
+  const ContextKind Kind;
+  ContextKind getKind() const { return Kind; }
+  ContextKey(ContextKind K) : Kind(K){};
+};
+
+// String based context id
+struct StringBasedCtxKey : public ContextKey {
+  std::string Context;
+  StringBasedCtxKey() : ContextKey(CK_StringBased){};
+  static bool classof(const ContextKey *K) {
+    return K->getKind() == CK_StringBased;
+  }
+
+  bool isEqual(const ContextKey *K) const override {
+    const StringBasedCtxKey *Other = dyn_cast<StringBasedCtxKey>(K);
+    return Context == Other->Context;
+  }
+
+  void genHashCode() { HashCode = hash_value(Context); }
+};
+
 // The counter of branch samples for one function indexed by the branch,
 // which is represented as the source and target offset pair.
 using BranchSample = std::map<std::pair<uint64_t, uint64_t>, uint64_t>;
 // The counter of range samples for one function indexed by the range,
 // which is represented as the start and end offset pair.
 using RangeSample = std::map<std::pair<uint64_t, uint64_t>, uint64_t>;
-// Range sample counters indexed by the context string
-using ContextRangeCounter = std::unordered_map<std::string, RangeSample>;
-// Branch sample counters indexed by the context string
-using ContextBranchCounter = std::unordered_map<std::string, BranchSample>;
-
-// For Hybrid sample counters
-struct ContextSampleCounters {
-  ContextRangeCounter RangeCounter;
-  ContextBranchCounter BranchCounter;
-
-  void recordRangeCount(std::string &ContextId, uint64_t Start, uint64_t End,
-                        uint64_t Repeat) {
-    RangeCounter[ContextId][{Start, End}] += Repeat;
-  }
-  void recordBranchCount(std::string &ContextId, uint64_t Source,
-                         uint64_t Target, uint64_t Repeat) {
-    BranchCounter[ContextId][{Source, Target}] += Repeat;
-  }
-};
+// Wrapper for sample counters including range counter and branch counter
+struct SampleCounter {
+  RangeSample RangeCounter;
+  BranchSample BranchCounter;
 
-struct HybridSampleHash {
-  uint64_t hashCombine(uint64_t Hash, uint64_t Value) const {
-    // Simple DJB2 hash
-    return ((Hash << 5) + Hash) + Value;
+  void recordRangeCount(uint64_t Start, uint64_t End, uint64_t Repeat) {
+    RangeCounter[{Start, End}] += Repeat;
   }
-
-  uint64_t operator()(const HybridSample &Sample) const {
-    uint64_t Hash = 5381;
-    Hash = hashCombine(Hash, reinterpret_cast<uint64_t>(Sample.Binary));
-    for (const auto &Value : Sample.CallStack) {
-      Hash = hashCombine(Hash, Value);
-    }
-    for (const auto &Entry : Sample.LBRStack) {
-      Hash = hashCombine(Hash, Entry.Source);
-      Hash = hashCombine(Hash, Entry.Target);
-    }
-    return Hash;
+  void recordBranchCount(uint64_t Source, uint64_t Target, uint64_t Repeat) {
+    BranchCounter[{Source, Target}] += Repeat;
   }
 };
 
-// After parsing the sample, we record the samples by aggregating them
-// into this structure and the value is the sample counter.
-using AggregationCounter =
-    std::unordered_map<HybridSample, uint64_t, HybridSampleHash>;
+// Sample counter with context to support context-sensitive profile
+using ContextSampleCounterMap =
+    std::unordered_map<Hashable<ContextKey>, SampleCounter,
+                       Hashable<ContextKey>::Hash, Hashable<ContextKey>::Equal>;
 
 /*
 As in hybrid sample we have a group of LBRs and the most recent sampling call
@@ -232,7 +320,7 @@ range as sample counter for further CS profile generation.
 */
 class VirtualUnwinder {
 public:
-  VirtualUnwinder(ContextSampleCounters *Counters) : SampleCounters(Counters) {}
+  VirtualUnwinder(ContextSampleCounterMap *Counter) : CtxCounterMap(Counter) {}
 
   bool isCallState(UnwindState &State) const {
     // The tail call frame is always missing here in stack sample, we will
@@ -250,14 +338,16 @@ class VirtualUnwinder {
   void unwindLinear(UnwindState &State, uint64_t Repeat);
   void unwindReturn(UnwindState &State);
   void unwindBranchWithinFrame(UnwindState &State);
-  bool unwind(const HybridSample &Sample, uint64_t Repeat);
+  bool unwind(const HybridSample *Sample, uint64_t Repeat);
   void recordRangeCount(uint64_t Start, uint64_t End, UnwindState &State,
                         uint64_t Repeat);
   void recordBranchCount(const LBREntry &Branch, UnwindState &State,
                          uint64_t Repeat);
+  SampleCounter &getOrCreateSampleCounter(const ProfiledBinary *Binary,
+                                          std::list<uint64_t> &CallStack);
 
 private:
-  ContextSampleCounters *SampleCounters;
+  ContextSampleCounterMap *CtxCounterMap;
 };
 
 // Filename to binary map
@@ -268,7 +358,7 @@ using AddressBinaryMap = std::map<uint64_t, ProfiledBinary *>;
 // same binary loaded at different addresses, they should share the same sample
 // counter
 using BinarySampleCounterMap =
-    std::unordered_map<ProfiledBinary *, ContextSampleCounters>;
+    std::unordered_map<ProfiledBinary *, ContextSampleCounterMap>;
 
 // Load binaries and read perf trace to parse the events and samples
 class PerfReader {
@@ -344,7 +434,7 @@ class PerfReader {
 private:
   BinarySampleCounterMap BinarySampleCounters;
   // Samples with the repeating time generated by the perf reader
-  AggregationCounter AggregatedSamples;
+  AggregatedCounter AggregatedSamples;
   PerfScriptType PerfType;
 };
 

diff  --git a/llvm/tools/llvm-profgen/ProfileGenerator.cpp b/llvm/tools/llvm-profgen/ProfileGenerator.cpp
index a92236ca6909..0a4978b67105 100644
--- a/llvm/tools/llvm-profgen/ProfileGenerator.cpp
+++ b/llvm/tools/llvm-profgen/ProfileGenerator.cpp
@@ -178,95 +178,76 @@ void CSProfileGenerator::updateBodySamplesforFunctionProfile(
   }
 }
 
-void CSProfileGenerator::populateFunctionBodySamples() {
-  for (const auto &BI : BinarySampleCounters) {
-    ProfiledBinary *Binary = BI.first;
-    for (const auto &CI : BI.second.RangeCounter) {
-      StringRef ContextId(CI.first);
-      // Get or create function profile for the range
-      FunctionSamples &FunctionProfile =
-          getFunctionProfileForContext(ContextId);
-      // Compute disjoint ranges first, so we can use MAX
-      // for calculating count for each location.
-      RangeSample Ranges;
-      findDisjointRanges(Ranges, CI.second);
-
-      for (auto Range : Ranges) {
-        uint64_t RangeBegin = Binary->offsetToVirtualAddr(Range.first.first);
-        uint64_t RangeEnd = Binary->offsetToVirtualAddr(Range.first.second);
-        uint64_t Count = Range.second;
-        // Disjoint ranges have introduce zero-filled gap that
-        // doesn't belong to current context, filter them out.
-        if (Count == 0)
-          continue;
-
-        InstructionPointer IP(Binary, RangeBegin, true);
-
-        // Disjoint ranges may have range in the middle of two instr,
-        // e.g. If Instr1 at Addr1, and Instr2 at Addr2, disjoint range
-        // can be Addr1+1 to Addr2-1. We should ignore such range.
-        if (IP.Address > RangeEnd)
-          continue;
-
-        while (IP.Address <= RangeEnd) {
-          uint64_t Offset = Binary->virtualAddrToOffset(IP.Address);
-          const FrameLocation &LeafLoc = Binary->getInlineLeafFrameLoc(Offset);
-          // Recording body sample for this specific context
-          updateBodySamplesforFunctionProfile(FunctionProfile, LeafLoc, Count);
-          // Move to next IP within the range
-          IP.advance();
-        }
-      }
+void CSProfileGenerator::populateFunctionBodySamples(
+    FunctionSamples &FunctionProfile, const RangeSample &RangeCounter,
+    ProfiledBinary *Binary) {
+  // Compute disjoint ranges first, so we can use MAX
+  // for calculating count for each location.
+  RangeSample Ranges;
+  findDisjointRanges(Ranges, RangeCounter);
+  for (auto Range : Ranges) {
+    uint64_t RangeBegin = Binary->offsetToVirtualAddr(Range.first.first);
+    uint64_t RangeEnd = Binary->offsetToVirtualAddr(Range.first.second);
+    uint64_t Count = Range.second;
+    // Disjoint ranges have introduce zero-filled gap that
+    // doesn't belong to current context, filter them out.
+    if (Count == 0)
+      continue;
+
+    InstructionPointer IP(Binary, RangeBegin, true);
+
+    // Disjoint ranges may have range in the middle of two instr,
+    // e.g. If Instr1 at Addr1, and Instr2 at Addr2, disjoint range
+    // can be Addr1+1 to Addr2-1. We should ignore such range.
+    if (IP.Address > RangeEnd)
+      continue;
+
+    while (IP.Address <= RangeEnd) {
+      uint64_t Offset = Binary->virtualAddrToOffset(IP.Address);
+      const FrameLocation &LeafLoc = Binary->getInlineLeafFrameLoc(Offset);
+      // Recording body sample for this specific context
+      updateBodySamplesforFunctionProfile(FunctionProfile, LeafLoc, Count);
+      // Move to next IP within the range
+      IP.advance();
     }
   }
 }
 
-void CSProfileGenerator::populateFunctionBoundarySamples() {
-  for (const auto &BI : BinarySampleCounters) {
-    ProfiledBinary *Binary = BI.first;
-    for (const auto &CI : BI.second.BranchCounter) {
-      StringRef ContextId(CI.first);
-      // Get or create function profile for branch Source
-      FunctionSamples &FunctionProfile =
-          getFunctionProfileForContext(ContextId);
-
-      for (auto Entry : CI.second) {
-        uint64_t SourceOffset = Entry.first.first;
-        uint64_t TargetOffset = Entry.first.second;
-        uint64_t Count = Entry.second;
-        // Get the callee name by branch target if it's a call branch
-        StringRef CalleeName = FunctionSamples::getCanonicalFnName(
-            Binary->getFuncFromStartOffset(TargetOffset));
-        if (CalleeName.size() == 0)
-          continue;
-
-        // Record called target sample and its count
-        const FrameLocation &LeafLoc =
-            Binary->getInlineLeafFrameLoc(SourceOffset);
-
-        FunctionProfile.addCalledTargetSamples(LeafLoc.second.LineOffset,
-                                               LeafLoc.second.Discriminator,
-                                               CalleeName, Count);
-        FunctionProfile.addTotalSamples(Count);
-
-        // Record head sample for called target(callee)
-        // TODO: Cleanup ' @ '
-        std::string CalleeContextId =
-            getCallSite(LeafLoc) + " @ " + CalleeName.str();
-        if (ContextId.find(" @ ") != StringRef::npos) {
-          CalleeContextId =
-              ContextId.rsplit(" @ ").first.str() + " @ " + CalleeContextId;
-        }
-
-        if (ProfileMap.find(CalleeContextId) != ProfileMap.end()) {
-          FunctionSamples &CalleeProfile = ProfileMap[CalleeContextId];
-          assert(Count != 0 && "Unexpected zero weight branch");
-          if (CalleeProfile.getName().size()) {
-            CalleeProfile.addHeadSamples(Count);
-          }
-        }
-      }
+void CSProfileGenerator::populateFunctionBoundarySamples(
+    StringRef ContextId, FunctionSamples &FunctionProfile,
+    const BranchSample &BranchCounters, ProfiledBinary *Binary) {
+
+  for (auto Entry : BranchCounters) {
+    uint64_t SourceOffset = Entry.first.first;
+    uint64_t TargetOffset = Entry.first.second;
+    uint64_t Count = Entry.second;
+    // Get the callee name by branch target if it's a call branch
+    StringRef CalleeName = FunctionSamples::getCanonicalFnName(
+        Binary->getFuncFromStartOffset(TargetOffset));
+    if (CalleeName.size() == 0)
+      continue;
+
+    // Record called target sample and its count
+    const FrameLocation &LeafLoc = Binary->getInlineLeafFrameLoc(SourceOffset);
+
+    FunctionProfile.addCalledTargetSamples(LeafLoc.second.LineOffset,
+                                           LeafLoc.second.Discriminator,
+                                           CalleeName, Count);
+    FunctionProfile.addTotalSamples(Count);
+
+    // Record head sample for called target(callee)
+    // TODO: Cleanup ' @ '
+    std::string CalleeContextId =
+        getCallSite(LeafLoc) + " @ " + CalleeName.str();
+    if (ContextId.find(" @ ") != StringRef::npos) {
+      CalleeContextId =
+          ContextId.rsplit(" @ ").first.str() + " @ " + CalleeContextId;
     }
+
+    FunctionSamples &CalleeProfile =
+        getFunctionProfileForContext(CalleeContextId);
+    assert(Count != 0 && "Unexpected zero weight branch");
+    CalleeProfile.addHeadSamples(Count);
   }
 }
 

diff  --git a/llvm/tools/llvm-profgen/ProfileGenerator.h b/llvm/tools/llvm-profgen/ProfileGenerator.h
index f447118e78e0..0ce2465aaa8f 100644
--- a/llvm/tools/llvm-profgen/ProfileGenerator.h
+++ b/llvm/tools/llvm-profgen/ProfileGenerator.h
@@ -64,12 +64,24 @@ class CSProfileGenerator : public ProfileGenerator {
 
 public:
   void generateProfile() override {
-    // Fill in function body samples
-    populateFunctionBodySamples();
-
-    // Fill in boundary sample counts as well as call site samples for calls
-    populateFunctionBoundarySamples();
-
+    for (const auto &BI : BinarySampleCounters) {
+      ProfiledBinary *Binary = BI.first;
+      for (const auto &CI : BI.second) {
+        const StringBasedCtxKey *CtxKey =
+            dyn_cast<StringBasedCtxKey>(CI.first.getPtr());
+        StringRef ContextId(CtxKey->Context);
+        // Get or create function profile for the range
+        FunctionSamples &FunctionProfile =
+            getFunctionProfileForContext(ContextId);
+
+        // Fill in function body samples
+        populateFunctionBodySamples(FunctionProfile, CI.second.RangeCounter,
+                                    Binary);
+        // Fill in boundary sample counts as well as call site samples for calls
+        populateFunctionBoundarySamples(ContextId, FunctionProfile,
+                                        CI.second.BranchCounter, Binary);
+      }
+    }
     // Fill in call site value sample for inlined calls and also use context to
     // infer missing samples. Since we don't have call count for inlined
     // functions, we estimate it from inlinee's profile using the entry of the
@@ -85,8 +97,13 @@ class CSProfileGenerator : public ProfileGenerator {
                                            uint64_t Count);
   // Lookup or create FunctionSamples for the context
   FunctionSamples &getFunctionProfileForContext(StringRef ContextId);
-  void populateFunctionBodySamples();
-  void populateFunctionBoundarySamples();
+  void populateFunctionBodySamples(FunctionSamples &FunctionProfile,
+                                   const RangeSample &RangeCounters,
+                                   ProfiledBinary *Binary);
+  void populateFunctionBoundarySamples(StringRef ContextId,
+                                       FunctionSamples &FunctionProfile,
+                                       const BranchSample &BranchCounters,
+                                       ProfiledBinary *Binary);
   void populateInferredFunctionSamples();
 };
 


        

