[llvm-branch-commits] [llvm] [NFC][MemProf] Move Radix tree methods to their own header and cpp. (PR #140501)

Snehasish Kumar via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon May 19 00:09:09 PDT 2025


https://github.com/snehasish created https://github.com/llvm/llvm-project/pull/140501

None

>From c8e520c48fe9e64f9e2ac389498d0e27797bf362 Mon Sep 17 00:00:00 2001
From: Snehasish Kumar <snehasishk at google.com>
Date: Fri, 16 May 2025 18:54:05 -0700
Subject: [PATCH] [NFC][MemProf] Move Radix tree methods to their own header
 and cpp.

---
 llvm/include/llvm/ProfileData/MemProf.h       | 336 ----------------
 .../llvm/ProfileData/MemProfRadixTree.h       | 358 ++++++++++++++++++
 llvm/include/llvm/ProfileData/MemProfReader.h |   2 +-
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |   1 +
 llvm/lib/ProfileData/CMakeLists.txt           |   1 +
 llvm/lib/ProfileData/IndexedMemProfData.cpp   |   1 +
 llvm/lib/ProfileData/InstrProfReader.cpp      |   3 +-
 llvm/lib/ProfileData/MemProf.cpp              | 235 ------------
 llvm/lib/ProfileData/MemProfRadixTree.cpp     | 253 +++++++++++++
 llvm/unittests/ProfileData/InstrProfTest.cpp  |   1 +
 llvm/unittests/ProfileData/MemProfTest.cpp    |   3 +-
 11 files changed, 620 insertions(+), 574 deletions(-)
 create mode 100644 llvm/include/llvm/ProfileData/MemProfRadixTree.h
 create mode 100644 llvm/lib/ProfileData/MemProfRadixTree.cpp

diff --git a/llvm/include/llvm/ProfileData/MemProf.h b/llvm/include/llvm/ProfileData/MemProf.h
index e713c3807611b..0bc1432f7d198 100644
--- a/llvm/include/llvm/ProfileData/MemProf.h
+++ b/llvm/include/llvm/ProfileData/MemProf.h
@@ -818,133 +818,6 @@ class CallStackLookupTrait {
   }
 };
 
-namespace detail {
-// "Dereference" the iterator from DenseMap or OnDiskChainedHashTable.  We have
-// to do so in one of two different ways depending on the type of the hash
-// table.
-template <typename value_type, typename IterTy>
-value_type DerefIterator(IterTy Iter) {
-  using deref_type = llvm::remove_cvref_t<decltype(*Iter)>;
-  if constexpr (std::is_same_v<deref_type, value_type>)
-    return *Iter;
-  else
-    return Iter->second;
-}
-} // namespace detail
-
-// A function object that returns a frame for a given FrameId.
-template <typename MapTy> struct FrameIdConverter {
-  std::optional<FrameId> LastUnmappedId;
-  MapTy ⤅
-
-  FrameIdConverter() = delete;
-  FrameIdConverter(MapTy &Map) : Map(Map) {}
-
-  // Delete the copy constructor and copy assignment operator to avoid a
-  // situation where a copy of FrameIdConverter gets an error in LastUnmappedId
-  // while the original instance doesn't.
-  FrameIdConverter(const FrameIdConverter &) = delete;
-  FrameIdConverter &operator=(const FrameIdConverter &) = delete;
-
-  Frame operator()(FrameId Id) {
-    auto Iter = Map.find(Id);
-    if (Iter == Map.end()) {
-      LastUnmappedId = Id;
-      return Frame();
-    }
-    return detail::DerefIterator<Frame>(Iter);
-  }
-};
-
-// A function object that returns a call stack for a given CallStackId.
-template <typename MapTy> struct CallStackIdConverter {
-  std::optional<CallStackId> LastUnmappedId;
-  MapTy ⤅
-  llvm::function_ref<Frame(FrameId)> FrameIdToFrame;
-
-  CallStackIdConverter() = delete;
-  CallStackIdConverter(MapTy &Map,
-                       llvm::function_ref<Frame(FrameId)> FrameIdToFrame)
-      : Map(Map), FrameIdToFrame(FrameIdToFrame) {}
-
-  // Delete the copy constructor and copy assignment operator to avoid a
-  // situation where a copy of CallStackIdConverter gets an error in
-  // LastUnmappedId while the original instance doesn't.
-  CallStackIdConverter(const CallStackIdConverter &) = delete;
-  CallStackIdConverter &operator=(const CallStackIdConverter &) = delete;
-
-  std::vector<Frame> operator()(CallStackId CSId) {
-    std::vector<Frame> Frames;
-    auto CSIter = Map.find(CSId);
-    if (CSIter == Map.end()) {
-      LastUnmappedId = CSId;
-    } else {
-      llvm::SmallVector<FrameId> CS =
-          detail::DerefIterator<llvm::SmallVector<FrameId>>(CSIter);
-      Frames.reserve(CS.size());
-      for (FrameId Id : CS)
-        Frames.push_back(FrameIdToFrame(Id));
-    }
-    return Frames;
-  }
-};
-
-// A function object that returns a Frame stored at a given index into the Frame
-// array in the profile.
-struct LinearFrameIdConverter {
-  const unsigned char *FrameBase;
-
-  LinearFrameIdConverter() = delete;
-  LinearFrameIdConverter(const unsigned char *FrameBase)
-      : FrameBase(FrameBase) {}
-
-  Frame operator()(LinearFrameId LinearId) {
-    uint64_t Offset = static_cast<uint64_t>(LinearId) * Frame::serializedSize();
-    return Frame::deserialize(FrameBase + Offset);
-  }
-};
-
-// A function object that returns a call stack stored at a given index into the
-// call stack array in the profile.
-struct LinearCallStackIdConverter {
-  const unsigned char *CallStackBase;
-  llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame;
-
-  LinearCallStackIdConverter() = delete;
-  LinearCallStackIdConverter(
-      const unsigned char *CallStackBase,
-      llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame)
-      : CallStackBase(CallStackBase), FrameIdToFrame(FrameIdToFrame) {}
-
-  std::vector<Frame> operator()(LinearCallStackId LinearCSId) {
-    std::vector<Frame> Frames;
-
-    const unsigned char *Ptr =
-        CallStackBase +
-        static_cast<uint64_t>(LinearCSId) * sizeof(LinearFrameId);
-    uint32_t NumFrames =
-        support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
-    Frames.reserve(NumFrames);
-    for (; NumFrames; --NumFrames) {
-      LinearFrameId Elem =
-          support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
-      // Follow a pointer to the parent, if any.  See comments below on
-      // CallStackRadixTreeBuilder for the description of the radix tree format.
-      if (static_cast<std::make_signed_t<LinearFrameId>>(Elem) < 0) {
-        Ptr += (-Elem) * sizeof(LinearFrameId);
-        Elem =
-            support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
-      }
-      // We shouldn't encounter another pointer.
-      assert(static_cast<std::make_signed_t<LinearFrameId>>(Elem) >= 0);
-      Frames.push_back(FrameIdToFrame(Elem));
-      Ptr += sizeof(LinearFrameId);
-    }
-
-    return Frames;
-  }
-};
-
 struct LineLocation {
   LineLocation(uint32_t L, uint32_t D) : LineOffset(L), Column(D) {}
 
@@ -970,73 +843,6 @@ struct LineLocation {
 // A pair of a call site location and its corresponding callee GUID.
 using CallEdgeTy = std::pair<LineLocation, uint64_t>;
 
-// Used to extract caller-callee pairs from the call stack array.  The leaf
-// frame is assumed to call a heap allocation function with GUID 0.  The
-// resulting pairs are accumulated in CallerCalleePairs.  Users can take it
-// with:
-//
-//   auto Pairs = std::move(Extractor.CallerCalleePairs);
-struct CallerCalleePairExtractor {
-  // The base address of the radix tree array.
-  const unsigned char *CallStackBase;
-  // A functor to convert a linear FrameId to a Frame.
-  llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame;
-  // A map from caller GUIDs to lists of call sites in respective callers.
-  DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>> CallerCalleePairs;
-
-  // The set of linear call stack IDs that we've visited.
-  BitVector Visited;
-
-  CallerCalleePairExtractor() = delete;
-  CallerCalleePairExtractor(
-      const unsigned char *CallStackBase,
-      llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame,
-      unsigned RadixTreeSize)
-      : CallStackBase(CallStackBase), FrameIdToFrame(FrameIdToFrame),
-        Visited(RadixTreeSize) {}
-
-  void operator()(LinearCallStackId LinearCSId) {
-    const unsigned char *Ptr =
-        CallStackBase +
-        static_cast<uint64_t>(LinearCSId) * sizeof(LinearFrameId);
-    uint32_t NumFrames =
-        support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
-    // The leaf frame calls a function with GUID 0.
-    uint64_t CalleeGUID = 0;
-    for (; NumFrames; --NumFrames) {
-      LinearFrameId Elem =
-          support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
-      // Follow a pointer to the parent, if any.  See comments below on
-      // CallStackRadixTreeBuilder for the description of the radix tree format.
-      if (static_cast<std::make_signed_t<LinearFrameId>>(Elem) < 0) {
-        Ptr += (-Elem) * sizeof(LinearFrameId);
-        Elem =
-            support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
-      }
-      // We shouldn't encounter another pointer.
-      assert(static_cast<std::make_signed_t<LinearFrameId>>(Elem) >= 0);
-
-      // Add a new caller-callee pair.
-      Frame F = FrameIdToFrame(Elem);
-      uint64_t CallerGUID = F.Function;
-      LineLocation Loc(F.LineOffset, F.Column);
-      CallerCalleePairs[CallerGUID].emplace_back(Loc, CalleeGUID);
-
-      // Keep track of the indices we've visited.  If we've already visited the
-      // current one, terminate the traversal.  We will not discover any new
-      // caller-callee pair by continuing the traversal.
-      unsigned Offset =
-          std::distance(CallStackBase, Ptr) / sizeof(LinearFrameId);
-      if (Visited.test(Offset))
-        break;
-      Visited.set(Offset);
-
-      Ptr += sizeof(LinearFrameId);
-      CalleeGUID = CallerGUID;
-    }
-  }
-};
-
 struct IndexedMemProfData {
   // A map to hold memprof data per function. The lower 64 bits obtained from
   // the md5 hash of the function name is used to index into the map.
@@ -1087,148 +893,6 @@ struct IndexedMemProfData {
   // Compute a CallStackId for a given call stack.
   CallStackId hashCallStack(ArrayRef<FrameId> CS) const;
 };
-
-// A convenience wrapper around FrameIdConverter and CallStackIdConverter for
-// tests.
-struct IndexedCallstackIdConverter {
-  IndexedCallstackIdConverter() = delete;
-  IndexedCallstackIdConverter(IndexedMemProfData &MemProfData)
-      : FrameIdConv(MemProfData.Frames),
-        CSIdConv(MemProfData.CallStacks, FrameIdConv) {}
-
-  // Delete the copy constructor and copy assignment operator to avoid a
-  // situation where a copy of IndexedCallstackIdConverter gets an error in
-  // LastUnmappedId while the original instance doesn't.
-  IndexedCallstackIdConverter(const IndexedCallstackIdConverter &) = delete;
-  IndexedCallstackIdConverter &
-  operator=(const IndexedCallstackIdConverter &) = delete;
-
-  std::vector<Frame> operator()(CallStackId CSId) { return CSIdConv(CSId); }
-
-  FrameIdConverter<decltype(IndexedMemProfData::Frames)> FrameIdConv;
-  CallStackIdConverter<decltype(IndexedMemProfData::CallStacks)> CSIdConv;
-};
-
-struct FrameStat {
-  // The number of occurrences of a given FrameId.
-  uint64_t Count = 0;
-  // The sum of indexes where a given FrameId shows up.
-  uint64_t PositionSum = 0;
-};
-
-// Compute a histogram of Frames in call stacks.
-template <typename FrameIdTy>
-llvm::DenseMap<FrameIdTy, FrameStat>
-computeFrameHistogram(llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
-                          &MemProfCallStackData);
-
-// Construct a radix tree of call stacks.
-//
-// A set of call stacks might look like:
-//
-// CallStackId 1:  f1 -> f2 -> f3
-// CallStackId 2:  f1 -> f2 -> f4 -> f5
-// CallStackId 3:  f1 -> f2 -> f4 -> f6
-// CallStackId 4:  f7 -> f8 -> f9
-//
-// where each fn refers to a stack frame.
-//
-// Since we expect a lot of common prefixes, we can compress the call stacks
-// into a radix tree like:
-//
-// CallStackId 1:  f1 -> f2 -> f3
-//                       |
-// CallStackId 2:        +---> f4 -> f5
-//                             |
-// CallStackId 3:              +---> f6
-//
-// CallStackId 4:  f7 -> f8 -> f9
-//
-// Now, we are interested in retrieving call stacks for a given CallStackId, so
-// we just need a pointer from a given call stack to its parent.  For example,
-// CallStackId 2 would point to CallStackId 1 as a parent.
-//
-// We serialize the radix tree above into a single array along with the length
-// of each call stack and pointers to the parent call stacks.
-//
-// Index:              0  1  2  3  4  5  6  7  8  9 10 11 12 13 14
-// Array:             L3 f9 f8 f7 L4 f6 J3 L4 f5 f4 J3 L3 f3 f2 f1
-//                     ^           ^        ^           ^
-//                     |           |        |           |
-// CallStackId 4:  0 --+           |        |           |
-// CallStackId 3:  4 --------------+        |           |
-// CallStackId 2:  7 -----------------------+           |
-// CallStackId 1: 11 -----------------------------------+
-//
-// - LN indicates the length of a call stack, encoded as ordinary integer N.
-//
-// - JN indicates a pointer to the parent, encoded as -N.
-//
-// The radix tree allows us to reconstruct call stacks in the leaf-to-root
-// order as we scan the array from left ro right while following pointers to
-// parents along the way.
-//
-// For example, if we are decoding CallStackId 2, we start a forward traversal
-// at Index 7, noting the call stack length of 4 and obtaining f5 and f4.  When
-// we see J3 at Index 10, we resume a forward traversal at Index 13 = 10 + 3,
-// picking up f2 and f1.  We are done after collecting 4 frames as indicated at
-// the beginning of the traversal.
-//
-// On-disk IndexedMemProfRecord will refer to call stacks by their indexes into
-// the radix tree array, so we do not explicitly encode mappings like:
-// "CallStackId 1 -> 11".
-template <typename FrameIdTy> class CallStackRadixTreeBuilder {
-  // The radix tree array.
-  std::vector<LinearFrameId> RadixArray;
-
-  // Mapping from CallStackIds to indexes into RadixArray.
-  llvm::DenseMap<CallStackId, LinearCallStackId> CallStackPos;
-
-  // In build, we partition a given call stack into two parts -- the prefix
-  // that's common with the previously encoded call stack and the frames beyond
-  // the common prefix -- the unique portion.  Then we want to find out where
-  // the common prefix is stored in RadixArray so that we can link the unique
-  // portion to the common prefix.  Indexes, declared below, helps with our
-  // needs.  Intuitively, Indexes tells us where each of the previously encoded
-  // call stack is stored in RadixArray.  More formally, Indexes satisfies:
-  //
-  //   RadixArray[Indexes[I]] == Prev[I]
-  //
-  // for every I, where Prev is the the call stack in the root-to-leaf order
-  // previously encoded by build.  (Note that Prev, as passed to
-  // encodeCallStack, is in the leaf-to-root order.)
-  //
-  // For example, if the call stack being encoded shares 5 frames at the root of
-  // the call stack with the previously encoded call stack,
-  // RadixArray[Indexes[0]] is the root frame of the common prefix.
-  // RadixArray[Indexes[5 - 1]] is the last frame of the common prefix.
-  std::vector<LinearCallStackId> Indexes;
-
-  using CSIdPair = std::pair<CallStackId, llvm::SmallVector<FrameIdTy>>;
-
-  // Encode a call stack into RadixArray.  Return the starting index within
-  // RadixArray.
-  LinearCallStackId encodeCallStack(
-      const llvm::SmallVector<FrameIdTy> *CallStack,
-      const llvm::SmallVector<FrameIdTy> *Prev,
-      const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes);
-
-public:
-  CallStackRadixTreeBuilder() = default;
-
-  // Build a radix tree array.
-  void
-  build(llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
-            &&MemProfCallStackData,
-        const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes,
-        llvm::DenseMap<FrameIdTy, FrameStat> &FrameHistogram);
-
-  ArrayRef<LinearFrameId> getRadixArray() const { return RadixArray; }
-
-  llvm::DenseMap<CallStackId, LinearCallStackId> takeCallStackPos() {
-    return std::move(CallStackPos);
-  }
-};
 } // namespace memprof
 } // namespace llvm
 
diff --git a/llvm/include/llvm/ProfileData/MemProfRadixTree.h b/llvm/include/llvm/ProfileData/MemProfRadixTree.h
new file mode 100644
index 0000000000000..9abf9f7a8774b
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/MemProfRadixTree.h
@@ -0,0 +1,358 @@
+//===- MemProfRadixTree.h - MemProf radix tree support ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A custom radix tree builder for MemProf call stack data, optimized for space.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_MEMPROFRADIXTREE_H
+#define LLVM_PROFILEDATA_MEMPROFRADIXTREE_H
+
+#include "llvm/ProfileData/MemProf.h"
+
+namespace llvm {
+namespace memprof {
+namespace detail {
+// "Dereference" the iterator from DenseMap or OnDiskChainedHashTable.  We have
+// to do so in one of two different ways depending on the type of the hash
+// table.
+template <typename value_type, typename IterTy>
+value_type DerefIterator(IterTy Iter) {
+  using deref_type = llvm::remove_cvref_t<decltype(*Iter)>;
+  if constexpr (std::is_same_v<deref_type, value_type>)
+    return *Iter;
+  else
+    return Iter->second;
+}
+} // namespace detail
+
+// A function object that returns a frame for a given FrameId.
+template <typename MapTy> struct FrameIdConverter {
+  std::optional<FrameId> LastUnmappedId;
+  MapTy ⤅
+
+  FrameIdConverter() = delete;
+  FrameIdConverter(MapTy &Map) : Map(Map) {}
+
+  // Delete the copy constructor and copy assignment operator to avoid a
+  // situation where a copy of FrameIdConverter gets an error in LastUnmappedId
+  // while the original instance doesn't.
+  FrameIdConverter(const FrameIdConverter &) = delete;
+  FrameIdConverter &operator=(const FrameIdConverter &) = delete;
+
+  Frame operator()(FrameId Id) {
+    auto Iter = Map.find(Id);
+    if (Iter == Map.end()) {
+      LastUnmappedId = Id;
+      return Frame();
+    }
+    return detail::DerefIterator<Frame>(Iter);
+  }
+};
+
+// A function object that returns a call stack for a given CallStackId.
+template <typename MapTy> struct CallStackIdConverter {
+  std::optional<CallStackId> LastUnmappedId;
+  MapTy ⤅
+  llvm::function_ref<Frame(FrameId)> FrameIdToFrame;
+
+  CallStackIdConverter() = delete;
+  CallStackIdConverter(MapTy &Map,
+                       llvm::function_ref<Frame(FrameId)> FrameIdToFrame)
+      : Map(Map), FrameIdToFrame(FrameIdToFrame) {}
+
+  // Delete the copy constructor and copy assignment operator to avoid a
+  // situation where a copy of CallStackIdConverter gets an error in
+  // LastUnmappedId while the original instance doesn't.
+  CallStackIdConverter(const CallStackIdConverter &) = delete;
+  CallStackIdConverter &operator=(const CallStackIdConverter &) = delete;
+
+  std::vector<Frame> operator()(CallStackId CSId) {
+    std::vector<Frame> Frames;
+    auto CSIter = Map.find(CSId);
+    if (CSIter == Map.end()) {
+      LastUnmappedId = CSId;
+    } else {
+      llvm::SmallVector<FrameId> CS =
+          detail::DerefIterator<llvm::SmallVector<FrameId>>(CSIter);
+      Frames.reserve(CS.size());
+      for (FrameId Id : CS)
+        Frames.push_back(FrameIdToFrame(Id));
+    }
+    return Frames;
+  }
+};
+
+// A function object that returns a Frame stored at a given index into the Frame
+// array in the profile.
+struct LinearFrameIdConverter {
+  const unsigned char *FrameBase;
+
+  LinearFrameIdConverter() = delete;
+  LinearFrameIdConverter(const unsigned char *FrameBase)
+      : FrameBase(FrameBase) {}
+
+  Frame operator()(LinearFrameId LinearId) {
+    uint64_t Offset = static_cast<uint64_t>(LinearId) * Frame::serializedSize();
+    return Frame::deserialize(FrameBase + Offset);
+  }
+};
+
+// A function object that returns a call stack stored at a given index into the
+// call stack array in the profile.
+struct LinearCallStackIdConverter {
+  const unsigned char *CallStackBase;
+  llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame;
+
+  LinearCallStackIdConverter() = delete;
+  LinearCallStackIdConverter(
+      const unsigned char *CallStackBase,
+      llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame)
+      : CallStackBase(CallStackBase), FrameIdToFrame(FrameIdToFrame) {}
+
+  std::vector<Frame> operator()(LinearCallStackId LinearCSId) {
+    std::vector<Frame> Frames;
+
+    const unsigned char *Ptr =
+        CallStackBase +
+        static_cast<uint64_t>(LinearCSId) * sizeof(LinearFrameId);
+    uint32_t NumFrames =
+        support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
+    Frames.reserve(NumFrames);
+    for (; NumFrames; --NumFrames) {
+      LinearFrameId Elem =
+          support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
+      // Follow a pointer to the parent, if any.  See comments below on
+      // CallStackRadixTreeBuilder for the description of the radix tree format.
+      if (static_cast<std::make_signed_t<LinearFrameId>>(Elem) < 0) {
+        Ptr += (-Elem) * sizeof(LinearFrameId);
+        Elem =
+            support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
+      }
+      // We shouldn't encounter another pointer.
+      assert(static_cast<std::make_signed_t<LinearFrameId>>(Elem) >= 0);
+      Frames.push_back(FrameIdToFrame(Elem));
+      Ptr += sizeof(LinearFrameId);
+    }
+
+    return Frames;
+  }
+};
+
+// Used to extract caller-callee pairs from the call stack array.  The leaf
+// frame is assumed to call a heap allocation function with GUID 0.  The
+// resulting pairs are accumulated in CallerCalleePairs.  Users can take it
+// with:
+//
+//   auto Pairs = std::move(Extractor.CallerCalleePairs);
+struct CallerCalleePairExtractor {
+  // The base address of the radix tree array.
+  const unsigned char *CallStackBase;
+  // A functor to convert a linear FrameId to a Frame.
+  llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame;
+  // A map from caller GUIDs to lists of call sites in respective callers.
+  DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>> CallerCalleePairs;
+
+  // The set of radix array indices (frame positions) that we've visited.
+  BitVector Visited;
+
+  CallerCalleePairExtractor() = delete;
+  CallerCalleePairExtractor(
+      const unsigned char *CallStackBase,
+      llvm::function_ref<Frame(LinearFrameId)> FrameIdToFrame,
+      unsigned RadixTreeSize)
+      : CallStackBase(CallStackBase), FrameIdToFrame(FrameIdToFrame),
+        Visited(RadixTreeSize) {}
+
+  void operator()(LinearCallStackId LinearCSId) {
+    const unsigned char *Ptr =
+        CallStackBase +
+        static_cast<uint64_t>(LinearCSId) * sizeof(LinearFrameId);
+    uint32_t NumFrames =
+        support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
+    // The leaf frame calls a function with GUID 0.
+    uint64_t CalleeGUID = 0;
+    for (; NumFrames; --NumFrames) {
+      LinearFrameId Elem =
+          support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
+      // Follow a pointer to the parent, if any.  See comments below on
+      // CallStackRadixTreeBuilder for the description of the radix tree format.
+      if (static_cast<std::make_signed_t<LinearFrameId>>(Elem) < 0) {
+        Ptr += (-Elem) * sizeof(LinearFrameId);
+        Elem =
+            support::endian::read<LinearFrameId, llvm::endianness::little>(Ptr);
+      }
+      // We shouldn't encounter another pointer.
+      assert(static_cast<std::make_signed_t<LinearFrameId>>(Elem) >= 0);
+
+      // Add a new caller-callee pair.
+      Frame F = FrameIdToFrame(Elem);
+      uint64_t CallerGUID = F.Function;
+      LineLocation Loc(F.LineOffset, F.Column);
+      CallerCalleePairs[CallerGUID].emplace_back(Loc, CalleeGUID);
+
+      // Keep track of the indices we've visited.  If we've already visited the
+      // current one, terminate the traversal.  We will not discover any new
+      // caller-callee pair by continuing the traversal.
+      unsigned Offset =
+          std::distance(CallStackBase, Ptr) / sizeof(LinearFrameId);
+      if (Visited.test(Offset))
+        break;
+      Visited.set(Offset);
+
+      Ptr += sizeof(LinearFrameId);
+      CalleeGUID = CallerGUID;
+    }
+  }
+};
+
+
+// A convenience wrapper around FrameIdConverter and CallStackIdConverter for
+// tests.
+struct IndexedCallstackIdConverter {
+  IndexedCallstackIdConverter() = delete;
+  IndexedCallstackIdConverter(IndexedMemProfData &MemProfData)
+      : FrameIdConv(MemProfData.Frames),
+        CSIdConv(MemProfData.CallStacks, FrameIdConv) {}
+
+  // Delete the copy constructor and copy assignment operator to avoid a
+  // situation where a copy of IndexedCallstackIdConverter gets an error in
+  // LastUnmappedId while the original instance doesn't.
+  IndexedCallstackIdConverter(const IndexedCallstackIdConverter &) = delete;
+  IndexedCallstackIdConverter &
+  operator=(const IndexedCallstackIdConverter &) = delete;
+
+  std::vector<Frame> operator()(CallStackId CSId) { return CSIdConv(CSId); }
+
+  FrameIdConverter<decltype(IndexedMemProfData::Frames)> FrameIdConv;
+  CallStackIdConverter<decltype(IndexedMemProfData::CallStacks)> CSIdConv;
+};
+
+struct FrameStat {
+  // The number of occurrences of a given FrameId.
+  uint64_t Count = 0;
+  // The sum of indexes where a given FrameId shows up.
+  uint64_t PositionSum = 0;
+};
+
+// Compute a histogram of Frames in call stacks.
+template <typename FrameIdTy>
+llvm::DenseMap<FrameIdTy, FrameStat>
+computeFrameHistogram(llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
+                          &MemProfCallStackData);
+
+// Construct a radix tree of call stacks.
+//
+// A set of call stacks might look like:
+//
+// CallStackId 1:  f1 -> f2 -> f3
+// CallStackId 2:  f1 -> f2 -> f4 -> f5
+// CallStackId 3:  f1 -> f2 -> f4 -> f6
+// CallStackId 4:  f7 -> f8 -> f9
+//
+// where each fn refers to a stack frame.
+//
+// Since we expect a lot of common prefixes, we can compress the call stacks
+// into a radix tree like:
+//
+// CallStackId 1:  f1 -> f2 -> f3
+//                       |
+// CallStackId 2:        +---> f4 -> f5
+//                             |
+// CallStackId 3:              +---> f6
+//
+// CallStackId 4:  f7 -> f8 -> f9
+//
+// Now, we are interested in retrieving call stacks for a given CallStackId, so
+// we just need a pointer from a given call stack to its parent.  For example,
+// CallStackId 2 would point to CallStackId 1 as a parent.
+//
+// We serialize the radix tree above into a single array along with the length
+// of each call stack and pointers to the parent call stacks.
+//
+// Index:              0  1  2  3  4  5  6  7  8  9 10 11 12 13 14
+// Array:             L3 f9 f8 f7 L4 f6 J3 L4 f5 f4 J3 L3 f3 f2 f1
+//                     ^           ^        ^           ^
+//                     |           |        |           |
+// CallStackId 4:  0 --+           |        |           |
+// CallStackId 3:  4 --------------+        |           |
+// CallStackId 2:  7 -----------------------+           |
+// CallStackId 1: 11 -----------------------------------+
+//
+// - LN indicates the length of a call stack, encoded as ordinary integer N.
+//
+// - JN indicates a pointer to the parent, encoded as -N.
+//
+// The radix tree allows us to reconstruct call stacks in the leaf-to-root
+// order as we scan the array from left to right while following pointers to
+// parents along the way.
+//
+// For example, if we are decoding CallStackId 2, we start a forward traversal
+// at Index 7, noting the call stack length of 4 and obtaining f5 and f4.  When
+// we see J3 at Index 10, we resume a forward traversal at Index 13 = 10 + 3,
+// picking up f2 and f1.  We are done after collecting 4 frames as indicated at
+// the beginning of the traversal.
+//
+// On-disk IndexedMemProfRecord will refer to call stacks by their indexes into
+// the radix tree array, so we do not explicitly encode mappings like:
+// "CallStackId 1 -> 11".
+template <typename FrameIdTy> class CallStackRadixTreeBuilder {
+  // The radix tree array.
+  std::vector<LinearFrameId> RadixArray;
+
+  // Mapping from CallStackIds to indexes into RadixArray.
+  llvm::DenseMap<CallStackId, LinearCallStackId> CallStackPos;
+
+  // In build, we partition a given call stack into two parts -- the prefix
+  // that's common with the previously encoded call stack and the frames beyond
+  // the common prefix -- the unique portion.  Then we want to find out where
+  // the common prefix is stored in RadixArray so that we can link the unique
+  // portion to the common prefix.  Indexes, declared below, helps with our
+  // needs.  Intuitively, Indexes tells us where each frame of the previously
+  // encoded call stack is stored in RadixArray.  Formally, Indexes satisfies:
+  //
+  //   RadixArray[Indexes[I]] == Prev[I]
+  //
+  // for every I, where Prev is the call stack in the root-to-leaf order
+  // previously encoded by build.  (Note that Prev, as passed to
+  // encodeCallStack, is in the leaf-to-root order.)
+  //
+  // For example, if the call stack being encoded shares 5 frames at the root of
+  // the call stack with the previously encoded call stack,
+  // RadixArray[Indexes[0]] is the root frame of the common prefix.
+  // RadixArray[Indexes[5 - 1]] is the last frame of the common prefix.
+  std::vector<LinearCallStackId> Indexes;
+
+  using CSIdPair = std::pair<CallStackId, llvm::SmallVector<FrameIdTy>>;
+
+  // Encode a call stack into RadixArray.  Return the starting index within
+  // RadixArray.
+  LinearCallStackId encodeCallStack(
+      const llvm::SmallVector<FrameIdTy> *CallStack,
+      const llvm::SmallVector<FrameIdTy> *Prev,
+      const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes);
+
+public:
+  CallStackRadixTreeBuilder() = default;
+
+  // Build a radix tree array.
+  void
+  build(llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
+            &&MemProfCallStackData,
+        const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes,
+        llvm::DenseMap<FrameIdTy, FrameStat> &FrameHistogram);
+
+  ArrayRef<LinearFrameId> getRadixArray() const { return RadixArray; }
+
+  llvm::DenseMap<CallStackId, LinearCallStackId> takeCallStackPos() {
+    return std::move(CallStackPos);
+  }
+};
+} // namespace memprof
+} // namespace llvm
+#endif // LLVM_PROFILEDATA_MEMPROFRADIXTREE_H
diff --git a/llvm/include/llvm/ProfileData/MemProfReader.h b/llvm/include/llvm/ProfileData/MemProfReader.h
index 29d9e57cae3e3..9aa55554fdf72 100644
--- a/llvm/include/llvm/ProfileData/MemProfReader.h
+++ b/llvm/include/llvm/ProfileData/MemProfReader.h
@@ -22,8 +22,8 @@
 #include "llvm/Object/Binary.h"
 #include "llvm/Object/ObjectFile.h"
 #include "llvm/ProfileData/InstrProfReader.h"
-#include "llvm/ProfileData/MemProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
+#include "llvm/ProfileData/MemProfRadixTree.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBuffer.h"
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 1a15c5120d3fd..f8748babb1625 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -62,6 +62,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Object/IRSymtab.h"
 #include "llvm/ProfileData/MemProf.h"
+#include "llvm/ProfileData/MemProfRadixTree.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt
index 67a69d7761b2c..ca9ea3205ee1d 100644
--- a/llvm/lib/ProfileData/CMakeLists.txt
+++ b/llvm/lib/ProfileData/CMakeLists.txt
@@ -9,6 +9,7 @@ add_llvm_component_library(LLVMProfileData
   ItaniumManglingCanonicalizer.cpp
   MemProf.cpp
+  MemProfRadixTree.cpp
   MemProfReader.cpp
   PGOCtxProfReader.cpp
   PGOCtxProfWriter.cpp
   ProfileSummaryBuilder.cpp
diff --git a/llvm/lib/ProfileData/IndexedMemProfData.cpp b/llvm/lib/ProfileData/IndexedMemProfData.cpp
index 3d20f7a7a5778..59e59720179af 100644
--- a/llvm/lib/ProfileData/IndexedMemProfData.cpp
+++ b/llvm/lib/ProfileData/IndexedMemProfData.cpp
@@ -13,6 +13,7 @@
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/ProfileData/MemProf.h"
+#include "llvm/ProfileData/MemProfRadixTree.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/OnDiskHashTable.h"
 
diff --git a/llvm/lib/ProfileData/InstrProfReader.cpp b/llvm/lib/ProfileData/InstrProfReader.cpp
index e6c83430cd8e9..65b8eb514784b 100644
--- a/llvm/lib/ProfileData/InstrProfReader.cpp
+++ b/llvm/lib/ProfileData/InstrProfReader.cpp
@@ -18,7 +18,8 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/ProfileSummary.h"
 #include "llvm/ProfileData/InstrProf.h"
 #include "llvm/ProfileData/MemProf.h"
+#include "llvm/ProfileData/MemProfRadixTree.h"
 #include "llvm/ProfileData/ProfileCommon.h"
 #include "llvm/ProfileData/SymbolRemappingReader.h"
 #include "llvm/Support/Endian.h"
diff --git a/llvm/lib/ProfileData/MemProf.cpp b/llvm/lib/ProfileData/MemProf.cpp
index e497bbff67d2e..a9c5ee09a6daf 100644
--- a/llvm/lib/ProfileData/MemProf.cpp
+++ b/llvm/lib/ProfileData/MemProf.cpp
@@ -395,240 +395,5 @@ CallStackId IndexedMemProfData::hashCallStack(ArrayRef<FrameId> CS) const {
   std::memcpy(&CSId, Hash.data(), sizeof(Hash));
   return CSId;
 }
-
-// Encode a call stack into RadixArray.  Return the starting index within
-// RadixArray.  For each call stack we encode, we emit two or three components
-// into RadixArray.  If a given call stack doesn't have a common prefix relative
-// to the previous one, we emit:
-//
-// - the frames in the given call stack in the root-to-leaf order
-//
-// - the length of the given call stack
-//
-// If a given call stack has a non-empty common prefix relative to the previous
-// one, we emit:
-//
-// - the relative location of the common prefix, encoded as a negative number.
-//
-// - a portion of the given call stack that's beyond the common prefix
-//
-// - the length of the given call stack, including the length of the common
-//   prefix.
-//
-// The resulting RadixArray requires a somewhat unintuitive backward traversal
-// to reconstruct a call stack -- read the call stack length and scan backward
-// while collecting frames in the leaf to root order.  build, the caller of this
-// function, reverses RadixArray in place so that we can reconstruct a call
-// stack as if we were deserializing an array in a typical way -- the call stack
-// length followed by the frames in the leaf-to-root order except that we need
-// to handle pointers to parents along the way.
-//
-// To quickly determine the location of the common prefix within RadixArray,
-// Indexes caches the indexes of the previous call stack's frames within
-// RadixArray.
-template <typename FrameIdTy>
-LinearCallStackId CallStackRadixTreeBuilder<FrameIdTy>::encodeCallStack(
-    const llvm::SmallVector<FrameIdTy> *CallStack,
-    const llvm::SmallVector<FrameIdTy> *Prev,
-    const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes) {
-  // Compute the length of the common root prefix between Prev and CallStack.
-  uint32_t CommonLen = 0;
-  if (Prev) {
-    auto Pos = std::mismatch(Prev->rbegin(), Prev->rend(), CallStack->rbegin(),
-                             CallStack->rend());
-    CommonLen = std::distance(CallStack->rbegin(), Pos.second);
-  }
-
-  // Drop the portion beyond CommonLen.
-  assert(CommonLen <= Indexes.size());
-  Indexes.resize(CommonLen);
-
-  // Append a pointer to the parent.
-  if (CommonLen) {
-    uint32_t CurrentIndex = RadixArray.size();
-    uint32_t ParentIndex = Indexes.back();
-    // The offset to the parent must be negative because we are pointing to an
-    // element we've already added to RadixArray.
-    assert(ParentIndex < CurrentIndex);
-    RadixArray.push_back(ParentIndex - CurrentIndex);
-  }
-
-  // Copy the part of the call stack beyond the common prefix to RadixArray.
-  assert(CommonLen <= CallStack->size());
-  for (FrameIdTy F : llvm::drop_begin(llvm::reverse(*CallStack), CommonLen)) {
-    // Remember the index of F in RadixArray.
-    Indexes.push_back(RadixArray.size());
-    RadixArray.push_back(
-        MemProfFrameIndexes ? MemProfFrameIndexes->find(F)->second : F);
-  }
-  assert(CallStack->size() == Indexes.size());
-
-  // End with the call stack length.
-  RadixArray.push_back(CallStack->size());
-
-  // Return the index within RadixArray where we can start reconstructing a
-  // given call stack from.
-  return RadixArray.size() - 1;
-}
-
-template <typename FrameIdTy>
-void CallStackRadixTreeBuilder<FrameIdTy>::build(
-    llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
-        &&MemProfCallStackData,
-    const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes,
-    llvm::DenseMap<FrameIdTy, FrameStat> &FrameHistogram) {
-  // Take the vector portion of MemProfCallStackData.  The vector is exactly
-  // what we need to sort.  Also, we no longer need its lookup capability.
-  llvm::SmallVector<CSIdPair, 0> CallStacks = MemProfCallStackData.takeVector();
-
-  // Return early if we have no work to do.
-  if (CallStacks.empty()) {
-    RadixArray.clear();
-    CallStackPos.clear();
-    return;
-  }
-
-  // Sorting the list of call stacks in the dictionary order is sufficient to
-  // maximize the length of the common prefix between two adjacent call stacks
-  // and thus minimize the length of RadixArray.  However, we go one step
-  // further and try to reduce the number of times we follow pointers to parents
-  // during deserilization.  Consider a poorly encoded radix tree:
-  //
-  // CallStackId 1:  f1 -> f2 -> f3
-  //                  |
-  // CallStackId 2:   +--- f4 -> f5
-  //                        |
-  // CallStackId 3:         +--> f6
-  //
-  // Here, f2 and f4 appear once and twice, respectively, in the call stacks.
-  // Once we encode CallStackId 1 into RadixArray, every other call stack with
-  // common prefix f1 ends up pointing to CallStackId 1.  Since CallStackId 3
-  // share "f1 f4" with CallStackId 2, CallStackId 3 needs to follow pointers to
-  // parents twice.
-  //
-  // We try to alleviate the situation by sorting the list of call stacks by
-  // comparing the popularity of frames rather than the integer values of
-  // FrameIds.  In the example above, f4 is more popular than f2, so we sort the
-  // call stacks and encode them as:
-  //
-  // CallStackId 2:  f1 -- f4 -> f5
-  //                  |     |
-  // CallStackId 3:   |     +--> f6
-  //                  |
-  // CallStackId 1:   +--> f2 -> f3
-  //
-  // Notice that CallStackId 3 follows a pointer to a parent only once.
-  //
-  // All this is a quick-n-dirty trick to reduce the number of jumps.  The
-  // proper way would be to compute the weight of each radix tree node -- how
-  // many call stacks use a given radix tree node, and encode a radix tree from
-  // the heaviest node first.  We do not do so because that's a lot of work.
-  llvm::sort(CallStacks, [&](const CSIdPair &L, const CSIdPair &R) {
-    // Call stacks are stored from leaf to root.  Perform comparisons from the
-    // root.
-    return std::lexicographical_compare(
-        L.second.rbegin(), L.second.rend(), R.second.rbegin(), R.second.rend(),
-        [&](FrameIdTy F1, FrameIdTy F2) {
-          uint64_t H1 = FrameHistogram[F1].Count;
-          uint64_t H2 = FrameHistogram[F2].Count;
-          // Popular frames should come later because we encode call stacks from
-          // the last one in the list.
-          if (H1 != H2)
-            return H1 < H2;
-          // For sort stability.
-          return F1 < F2;
-        });
-  });
-
-  // Reserve some reasonable amount of storage.
-  RadixArray.clear();
-  RadixArray.reserve(CallStacks.size() * 8);
-
-  // Indexes will grow as long as the longest call stack.
-  Indexes.clear();
-  Indexes.reserve(512);
-
-  // CallStackPos will grow to exactly CallStacks.size() entries.
-  CallStackPos.clear();
-  CallStackPos.reserve(CallStacks.size());
-
-  // Compute the radix array.  We encode one call stack at a time, computing the
-  // longest prefix that's shared with the previous call stack we encode.  For
-  // each call stack we encode, we remember a mapping from CallStackId to its
-  // position within RadixArray.
-  //
-  // As an optimization, we encode from the last call stack in CallStacks to
-  // reduce the number of times we follow pointers to the parents.  Consider the
-  // list of call stacks that has been sorted in the dictionary order:
-  //
-  // Call Stack 1: F1
-  // Call Stack 2: F1 -> F2
-  // Call Stack 3: F1 -> F2 -> F3
-  //
-  // If we traversed CallStacks in the forward order, we would end up with a
-  // radix tree like:
-  //
-  // Call Stack 1:  F1
-  //                |
-  // Call Stack 2:  +---> F2
-  //                      |
-  // Call Stack 3:        +---> F3
-  //
-  // Notice that each call stack jumps to the previous one.  However, if we
-  // traverse CallStacks in the reverse order, then Call Stack 3 has the
-  // complete call stack encoded without any pointers.  Call Stack 1 and 2 point
-  // to appropriate prefixes of Call Stack 3.
-  const llvm::SmallVector<FrameIdTy> *Prev = nullptr;
-  for (const auto &[CSId, CallStack] : llvm::reverse(CallStacks)) {
-    LinearCallStackId Pos =
-        encodeCallStack(&CallStack, Prev, MemProfFrameIndexes);
-    CallStackPos.insert({CSId, Pos});
-    Prev = &CallStack;
-  }
-
-  // "RadixArray.size() - 1" below is problematic if RadixArray is empty.
-  assert(!RadixArray.empty());
-
-  // Reverse the radix array in place.  We do so mostly for intuitive
-  // deserialization where we would read the length field and then the call
-  // stack frames proper just like any other array deserialization, except
-  // that we have occasional jumps to take advantage of prefixes.
-  for (size_t I = 0, J = RadixArray.size() - 1; I < J; ++I, --J)
-    std::swap(RadixArray[I], RadixArray[J]);
-
-  // "Reverse" the indexes stored in CallStackPos.
-  for (auto &[K, V] : CallStackPos)
-    V = RadixArray.size() - 1 - V;
-}
-
-// Explicitly instantiate class with the utilized FrameIdTy.
-template class CallStackRadixTreeBuilder<FrameId>;
-template class CallStackRadixTreeBuilder<LinearFrameId>;
-
-template <typename FrameIdTy>
-llvm::DenseMap<FrameIdTy, FrameStat>
-computeFrameHistogram(llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
-                          &MemProfCallStackData) {
-  llvm::DenseMap<FrameIdTy, FrameStat> Histogram;
-
-  for (const auto &KV : MemProfCallStackData) {
-    const auto &CS = KV.second;
-    for (unsigned I = 0, E = CS.size(); I != E; ++I) {
-      auto &S = Histogram[CS[I]];
-      ++S.Count;
-      S.PositionSum += I;
-    }
-  }
-  return Histogram;
-}
-
-// Explicitly instantiate function with the utilized FrameIdTy.
-template llvm::DenseMap<FrameId, FrameStat> computeFrameHistogram<FrameId>(
-    llvm::MapVector<CallStackId, llvm::SmallVector<FrameId>>
-        &MemProfCallStackData);
-template llvm::DenseMap<LinearFrameId, FrameStat>
-computeFrameHistogram<LinearFrameId>(
-    llvm::MapVector<CallStackId, llvm::SmallVector<LinearFrameId>>
-        &MemProfCallStackData);
 } // namespace memprof
 } // namespace llvm
diff --git a/llvm/lib/ProfileData/MemProfRadixTree.cpp b/llvm/lib/ProfileData/MemProfRadixTree.cpp
new file mode 100644
index 0000000000000..5ef357efdeffd
--- /dev/null
+++ b/llvm/lib/ProfileData/MemProfRadixTree.cpp
@@ -0,0 +1,253 @@
+//===- MemProfRadixTree.cpp - Radix tree encoded callstacks ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file contains logic that implements a space-efficient radix tree
+// encoding for callstacks used by MemProf.
+//
+//===----------------------------------------------------------------------===//
+
+
+#include "llvm/ProfileData/MemProfRadixTree.h"
+
+namespace llvm {
+namespace memprof {
+// Encode a call stack into RadixArray.  Return the starting index within
+// RadixArray.  For each call stack we encode, we emit two or three components
+// into RadixArray.  If a given call stack doesn't have a common prefix relative
+// to the previous one, we emit:
+//
+// - the frames in the given call stack in the root-to-leaf order
+//
+// - the length of the given call stack
+//
+// If a given call stack has a non-empty common prefix relative to the previous
+// one, we emit:
+//
+// - the relative location of the common prefix, encoded as a negative number.
+//
+// - a portion of the given call stack that's beyond the common prefix
+//
+// - the length of the given call stack, including the length of the common
+//   prefix.
+//
+// The resulting RadixArray requires a somewhat unintuitive backward traversal
+// to reconstruct a call stack -- read the call stack length and scan backward
+// while collecting frames in the leaf to root order.  build, the caller of this
+// function, reverses RadixArray in place so that we can reconstruct a call
+// stack as if we were deserializing an array in a typical way -- the call stack
+// length followed by the frames in the leaf-to-root order except that we need
+// to handle pointers to parents along the way.
+//
+// To quickly determine the location of the common prefix within RadixArray,
+// Indexes caches the indexes of the previous call stack's frames within
+// RadixArray.
+template <typename FrameIdTy>
+LinearCallStackId CallStackRadixTreeBuilder<FrameIdTy>::encodeCallStack(
+    const llvm::SmallVector<FrameIdTy> *CallStack,
+    const llvm::SmallVector<FrameIdTy> *Prev,
+    const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes) {
+  // Compute the length of the common root prefix between Prev and CallStack.
+  uint32_t CommonLen = 0;
+  if (Prev) {
+    auto Pos = std::mismatch(Prev->rbegin(), Prev->rend(), CallStack->rbegin(),
+                             CallStack->rend());
+    CommonLen = std::distance(CallStack->rbegin(), Pos.second);
+  }
+
+  // Drop the portion beyond CommonLen.
+  assert(CommonLen <= Indexes.size());
+  Indexes.resize(CommonLen);
+
+  // Append a pointer to the parent.
+  if (CommonLen) {
+    uint32_t CurrentIndex = RadixArray.size();
+    uint32_t ParentIndex = Indexes.back();
+    // The offset to the parent must be negative because we are pointing to an
+    // element we've already added to RadixArray.
+    assert(ParentIndex < CurrentIndex);
+    RadixArray.push_back(ParentIndex - CurrentIndex);
+  }
+
+  // Copy the part of the call stack beyond the common prefix to RadixArray.
+  assert(CommonLen <= CallStack->size());
+  for (FrameIdTy F : llvm::drop_begin(llvm::reverse(*CallStack), CommonLen)) {
+    // Remember the index of F in RadixArray.
+    Indexes.push_back(RadixArray.size());
+    RadixArray.push_back(
+        MemProfFrameIndexes ? MemProfFrameIndexes->find(F)->second : F);
+  }
+  assert(CallStack->size() == Indexes.size());
+
+  // End with the call stack length.
+  RadixArray.push_back(CallStack->size());
+
+  // Return the index within RadixArray where we can start reconstructing a
+  // given call stack from.
+  return RadixArray.size() - 1;
+}
+
+template <typename FrameIdTy>
+void CallStackRadixTreeBuilder<FrameIdTy>::build(
+    llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
+        &&MemProfCallStackData,
+    const llvm::DenseMap<FrameIdTy, LinearFrameId> *MemProfFrameIndexes,
+    llvm::DenseMap<FrameIdTy, FrameStat> &FrameHistogram) {
+  // Take the vector portion of MemProfCallStackData.  The vector is exactly
+  // what we need to sort.  Also, we no longer need its lookup capability.
+  llvm::SmallVector<CSIdPair, 0> CallStacks = MemProfCallStackData.takeVector();
+
+  // Return early if we have no work to do.
+  if (CallStacks.empty()) {
+    RadixArray.clear();
+    CallStackPos.clear();
+    return;
+  }
+
+  // Sorting the list of call stacks in the dictionary order is sufficient to
+  // maximize the length of the common prefix between two adjacent call stacks
+  // and thus minimize the length of RadixArray.  However, we go one step
+  // further and try to reduce the number of times we follow pointers to parents
+  // during deserialization.  Consider a poorly encoded radix tree:
+  //
+  // CallStackId 1:  f1 -> f2 -> f3
+  //                  |
+  // CallStackId 2:   +--- f4 -> f5
+  //                        |
+  // CallStackId 3:         +--> f6
+  //
+  // Here, f2 and f4 appear once and twice, respectively, in the call stacks.
+  // Once we encode CallStackId 1 into RadixArray, every other call stack with
+  // common prefix f1 ends up pointing to CallStackId 1.  Since CallStackId 3
+  // shares "f1 f4" with CallStackId 2, CallStackId 3 needs to follow pointers
+  // to parents twice.
+  //
+  // We try to alleviate the situation by sorting the list of call stacks by
+  // comparing the popularity of frames rather than the integer values of
+  // FrameIds.  In the example above, f4 is more popular than f2, so we sort the
+  // call stacks and encode them as:
+  //
+  // CallStackId 2:  f1 -- f4 -> f5
+  //                  |     |
+  // CallStackId 3:   |     +--> f6
+  //                  |
+  // CallStackId 1:   +--> f2 -> f3
+  //
+  // Notice that CallStackId 3 follows a pointer to a parent only once.
+  //
+  // All this is a quick-n-dirty trick to reduce the number of jumps.  The
+  // proper way would be to compute the weight of each radix tree node -- how
+  // many call stacks use a given radix tree node, and encode a radix tree from
+  // the heaviest node first.  We do not do so because that's a lot of work.
+  llvm::sort(CallStacks, [&](const CSIdPair &L, const CSIdPair &R) {
+    // Call stacks are stored from leaf to root.  Perform comparisons from the
+    // root.
+    return std::lexicographical_compare(
+        L.second.rbegin(), L.second.rend(), R.second.rbegin(), R.second.rend(),
+        [&](FrameIdTy F1, FrameIdTy F2) {
+          uint64_t H1 = FrameHistogram[F1].Count;
+          uint64_t H2 = FrameHistogram[F2].Count;
+          // Popular frames should come later because we encode call stacks from
+          // the last one in the list.
+          if (H1 != H2)
+            return H1 < H2;
+          // For sort stability.
+          return F1 < F2;
+        });
+  });
+
+  // Reserve some reasonable amount of storage.
+  RadixArray.clear();
+  RadixArray.reserve(CallStacks.size() * 8);
+
+  // Indexes will grow as long as the longest call stack.
+  Indexes.clear();
+  Indexes.reserve(512);
+
+  // CallStackPos will grow to exactly CallStacks.size() entries.
+  CallStackPos.clear();
+  CallStackPos.reserve(CallStacks.size());
+
+  // Compute the radix array.  We encode one call stack at a time, computing the
+  // longest prefix that's shared with the previous call stack we encode.  For
+  // each call stack we encode, we remember a mapping from CallStackId to its
+  // position within RadixArray.
+  //
+  // As an optimization, we encode from the last call stack in CallStacks to
+  // reduce the number of times we follow pointers to the parents.  Consider the
+  // list of call stacks that has been sorted in the dictionary order:
+  //
+  // Call Stack 1: F1
+  // Call Stack 2: F1 -> F2
+  // Call Stack 3: F1 -> F2 -> F3
+  //
+  // If we traversed CallStacks in the forward order, we would end up with a
+  // radix tree like:
+  //
+  // Call Stack 1:  F1
+  //                |
+  // Call Stack 2:  +---> F2
+  //                      |
+  // Call Stack 3:        +---> F3
+  //
+  // Notice that each call stack jumps to the previous one.  However, if we
+  // traverse CallStacks in the reverse order, then Call Stack 3 has the
+  // complete call stack encoded without any pointers.  Call Stack 1 and 2 point
+  // to appropriate prefixes of Call Stack 3.
+  const llvm::SmallVector<FrameIdTy> *Prev = nullptr;
+  for (const auto &[CSId, CallStack] : llvm::reverse(CallStacks)) {
+    LinearCallStackId Pos =
+        encodeCallStack(&CallStack, Prev, MemProfFrameIndexes);
+    CallStackPos.insert({CSId, Pos});
+    Prev = &CallStack;
+  }
+
+  // "RadixArray.size() - 1" below is problematic if RadixArray is empty.
+  assert(!RadixArray.empty());
+
+  // Reverse the radix array in place.  We do so mostly for intuitive
+  // deserialization where we would read the length field and then the call
+  // stack frames proper just like any other array deserialization, except
+  // that we have occasional jumps to take advantage of prefixes.
+  for (size_t I = 0, J = RadixArray.size() - 1; I < J; ++I, --J)
+    std::swap(RadixArray[I], RadixArray[J]);
+
+  // "Reverse" the indexes stored in CallStackPos.
+  for (auto &[K, V] : CallStackPos)
+    V = RadixArray.size() - 1 - V;
+}
+
+// Explicitly instantiate class with the utilized FrameIdTy.
+template class CallStackRadixTreeBuilder<FrameId>;
+template class CallStackRadixTreeBuilder<LinearFrameId>;
+
+template <typename FrameIdTy>
+llvm::DenseMap<FrameIdTy, FrameStat>
+computeFrameHistogram(llvm::MapVector<CallStackId, llvm::SmallVector<FrameIdTy>>
+                          &MemProfCallStackData) {
+  llvm::DenseMap<FrameIdTy, FrameStat> Histogram;
+
+  for (const auto &KV : MemProfCallStackData) {
+    const auto &CS = KV.second;
+    for (unsigned I = 0, E = CS.size(); I != E; ++I) {
+      auto &S = Histogram[CS[I]];
+      ++S.Count;
+      S.PositionSum += I;
+    }
+  }
+  return Histogram;
+}
+
+// Explicitly instantiate function with the utilized FrameIdTy.
+template llvm::DenseMap<FrameId, FrameStat> computeFrameHistogram<FrameId>(
+    llvm::MapVector<CallStackId, llvm::SmallVector<FrameId>>
+        &MemProfCallStackData);
+template llvm::DenseMap<LinearFrameId, FrameStat>
+computeFrameHistogram<LinearFrameId>(
+    llvm::MapVector<CallStackId, llvm::SmallVector<LinearFrameId>>
+        &MemProfCallStackData);
+} // namespace memprof
+} // namespace llvm
diff --git a/llvm/unittests/ProfileData/InstrProfTest.cpp b/llvm/unittests/ProfileData/InstrProfTest.cpp
index a0bd41bccf928..9f1caae296cca 100644
--- a/llvm/unittests/ProfileData/InstrProfTest.cpp
+++ b/llvm/unittests/ProfileData/InstrProfTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/ProfileData/InstrProfWriter.h"
 #include "llvm/ProfileData/MemProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
+#include "llvm/ProfileData/MemProfRadixTree.h"
 #include "llvm/Support/Compression.h"
 #include "llvm/Support/raw_ostream.h"
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index b1937992ea7f4..a072dee26d9a0 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/ProfileData/MemProf.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLForwardCompat.h"
@@ -14,8 +13,10 @@
 #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Object/ObjectFile.h"
+#include "llvm/ProfileData/MemProf.h"
 #include "llvm/ProfileData/MemProfData.inc"
+#include "llvm/ProfileData/MemProfRadixTree.h"
 #include "llvm/ProfileData/MemProfReader.h"
 #include "llvm/ProfileData/MemProfYAML.h"
 #include "llvm/Support/raw_ostream.h"
 #include "gmock/gmock.h"



More information about the llvm-branch-commits mailing list