[llvm] [CGData] Outlined Hash Tree (PR #89792)

Kyungwoo Lee via llvm-commits llvm-commits at lists.llvm.org
Sat May 4 22:21:17 PDT 2024


https://github.com/kyulee-com updated https://github.com/llvm/llvm-project/pull/89792

>From fda8026a3e1c720d2c8fee90a51e8e4f7182bded Mon Sep 17 00:00:00 2001
From: Kyungwoo Lee <kyulee at meta.com>
Date: Mon, 22 Apr 2024 15:29:25 -0700
Subject: [PATCH 1/2] [CGData] Outlined Hash Tree

This defines the OutlinedHashTree class.
It contains sequences of stable hash values of instructions that have been outlined.
This OutlinedHashTree can be used to track the outlined instruction sequences across modules.
It is implemented as a trie, so sequences with a common prefix share the nodes for that prefix.
---
 .../llvm/CodeGenData/OutlinedHashTree.h       | 107 +++++++++++
 .../llvm/CodeGenData/OutlinedHashTreeRecord.h |  67 +++++++
 llvm/lib/CMakeLists.txt                       |   1 +
 llvm/lib/CodeGenData/CMakeLists.txt           |  17 ++
 llvm/lib/CodeGenData/OutlinedHashTree.cpp     | 131 ++++++++++++++
 .../CodeGenData/OutlinedHashTreeRecord.cpp    | 168 ++++++++++++++++++
 llvm/unittests/CMakeLists.txt                 |   1 +
 llvm/unittests/CodeGenData/CMakeLists.txt     |  14 ++
 .../OutlinedHashTreeRecordTest.cpp            | 118 ++++++++++++
 .../CodeGenData/OutlinedHashTreeTest.cpp      |  81 +++++++++
 10 files changed, 705 insertions(+)
 create mode 100644 llvm/include/llvm/CodeGenData/OutlinedHashTree.h
 create mode 100644 llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
 create mode 100644 llvm/lib/CodeGenData/CMakeLists.txt
 create mode 100644 llvm/lib/CodeGenData/OutlinedHashTree.cpp
 create mode 100644 llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
 create mode 100644 llvm/unittests/CodeGenData/CMakeLists.txt
 create mode 100644 llvm/unittests/CodeGenData/OutlinedHashTreeRecordTest.cpp
 create mode 100644 llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp

diff --git a/llvm/include/llvm/CodeGenData/OutlinedHashTree.h b/llvm/include/llvm/CodeGenData/OutlinedHashTree.h
new file mode 100644
index 00000000000000..875e1a78bb4010
--- /dev/null
+++ b/llvm/include/llvm/CodeGenData/OutlinedHashTree.h
@@ -0,0 +1,107 @@
+//===- OutlinedHashTree.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This defines the OutlinedHashTree class. It contains sequences of stable
+// hash values of instructions that have been outlined. This OutlinedHashTree
+// can be used to track the outlined instruction sequences across modules.
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGENDATA_OUTLINEDHASHTREE_H
+#define LLVM_CODEGENDATA_OUTLINEDHASHTREE_H
+
+#include "llvm/ADT/StableHashing.h"
+#include "llvm/ObjectYAML/YAML.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <unordered_map>
+#include <vector>
+
+namespace llvm {
+
+/// A HashNode is an entry in an OutlinedHashTree: one stable hash value plus
+/// its Successors (child HashNodes keyed by their Hash). A nonzero
+/// Terminals count (Terminals > 0) means this node is the end of
+/// a hash sequence with that occurrence count.
+struct HashNode {
+  /// The hash value of the node.
+  stable_hash Hash;
+  /// The number of terminals in the sequence ending at this node.
+  unsigned Terminals;
+  /// The successors of this node.
+  std::unordered_map<stable_hash, std::unique_ptr<HashNode>> Successors;
+};
+
+/// HashNodeStable is the serialized, stable, and compact representation
+/// of a HashNode.
+struct HashNodeStable {
+  llvm::yaml::Hex64 Hash;
+  unsigned Terminals;
+  std::vector<unsigned> SuccessorIds;
+};
+
+class OutlinedHashTree {
+
+  using EdgeCallbackFn =
+      std::function<void(const HashNode *, const HashNode *)>;
+  using NodeCallbackFn = std::function<void(const HashNode *)>;
+
+  using HashSequence = std::vector<stable_hash>;
+  using HashSequencePair = std::pair<std::vector<stable_hash>, unsigned>;
+
+public:
+  /// Walks every edge and node in the OutlinedHashTree and calls CallbackEdge
+  /// for the edges and CallbackNode for the nodes; CallbackEdge receives the
+  /// source and sink HashNode of each edge. The walk is a depth-first preorder
+  /// traversal; with \p SortedWalk the visiting order depends only on the tree
+  /// contents, not on hash-map iteration order (e.g. for stable output).
+  void walkGraph(NodeCallbackFn CallbackNode,
+                 EdgeCallbackFn CallbackEdge = nullptr,
+                 bool SortedWalk = false) const;
+
+  /// Release all hash nodes except the root hash node.
+  void clear() {
+    assert(getRoot()->Hash == 0 && getRoot()->Terminals == 0);
+    getRoot()->Successors.clear();
+  }
+
+  /// \returns true if the hash tree has only the root node.
+  bool empty() const { return size() == 1; }
+
+  /// \returns the size of an OutlinedHashTree by traversing it. If
+  /// \p GetTerminalCountOnly is true, it only counts the terminal nodes
+  /// (meaning it returns the number of distinct hash sequences in the
+  /// OutlinedHashTree).
+  size_t size(bool GetTerminalCountOnly = false) const;
+
+  /// \returns the depth of an OutlinedHashTree by traversing it.
+  size_t depth() const;
+
+  /// \returns the root hash node of a OutlinedHashTree.
+  const HashNode *getRoot() const { return Root.get(); }
+  HashNode *getRoot() { return Root.get(); }
+
+  /// Inserts a \p Sequence into the this tree. The last node in the sequence
+  /// will increase Terminals.
+  void insert(const HashSequencePair &SequencePair);
+
+  /// Merge \p OtherTree into this tree, adding up the Terminals counts.
+  void merge(const OutlinedHashTree *OtherTree);
+
+  /// \returns the matching count if \p Sequence exists in the OutlinedHashTree.
+  unsigned find(const HashSequence &Sequence) const;
+
+  OutlinedHashTree() { Root = std::make_unique<HashNode>(); }
+
+private:
+  std::unique_ptr<HashNode> Root;
+};
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h b/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
new file mode 100644
index 00000000000000..ccd2ad26dd0871
--- /dev/null
+++ b/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
@@ -0,0 +1,67 @@
+//===- OutlinedHashTreeRecord.h --------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This defines the OutlinedHashTreeRecord class. This class holds the outlined
+// hash tree for both serialization and deserialization processes. It utilizes
+// two data formats for serialization: raw binary data and YAML.
+// These two formats can be used interchangeably.
+//
+//===---------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H
+#define LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H
+
+#include "llvm/CodeGenData/OutlinedHashTree.h"
+
+namespace llvm {
+
+using IdHashNodeStableMapTy = std::map<unsigned, HashNodeStable>;
+using IdHashNodeMapTy = std::map<unsigned, HashNode *>;
+using HashNodeIdMapTy = std::unordered_map<const HashNode *, unsigned>;
+
+struct OutlinedHashTreeRecord {
+  std::unique_ptr<OutlinedHashTree> HashTree;
+
+  OutlinedHashTreeRecord() { HashTree = std::make_unique<OutlinedHashTree>(); }
+  OutlinedHashTreeRecord(std::unique_ptr<OutlinedHashTree> HashTree)
+      : HashTree(std::move(HashTree)) {}
+
+  /// Serialize the outlined hash tree to \p OS in little-endian binary form.
+  void serialize(raw_ostream &OS) const;
+  /// Deserialize the outlined hash tree from \p Ptr, advancing Ptr past it.
+  void deserialize(const unsigned char *&Ptr);
+  /// Serialize the outlined hash tree to a YAML stream.
+  void serializeYAML(yaml::Output &YOS) const;
+  /// Deserialize the outlined hash tree from a YAML stream.
+  void deserializeYAML(yaml::Input &YIS);
+
+  /// Merge the other outlined hash tree into this one.
+  void merge(const OutlinedHashTreeRecord &Other) {
+    HashTree->merge(Other.HashTree.get());
+  }
+
+  /// \returns true if the outlined hash tree is empty.
+  bool empty() const { return HashTree->empty(); }
+
+  /// Print the outlined hash tree in a YAML format.
+  void print(raw_ostream &OS = llvm::errs()) const {
+    yaml::Output YOS(OS);
+    serializeYAML(YOS);
+  }
+
+private:
+  /// Convert the outlined hash tree to stable data keyed by node ID.
+  void convertToStableData(IdHashNodeStableMapTy &IdNodeStableMap) const;
+
+  /// Rebuild the outlined hash tree (which must be empty) from stable data.
+  void convertFromStableData(const IdHashNodeStableMapTy &IdNodeStableMap);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H
diff --git a/llvm/lib/CMakeLists.txt b/llvm/lib/CMakeLists.txt
index 74e2d03c07953d..2ac0b0dc026e16 100644
--- a/llvm/lib/CMakeLists.txt
+++ b/llvm/lib/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(InterfaceStub)
 add_subdirectory(IRPrinter)
 add_subdirectory(IRReader)
 add_subdirectory(CodeGen)
+add_subdirectory(CodeGenData)
 add_subdirectory(CodeGenTypes)
 add_subdirectory(BinaryFormat)
 add_subdirectory(Bitcode)
diff --git a/llvm/lib/CodeGenData/CMakeLists.txt b/llvm/lib/CodeGenData/CMakeLists.txt
new file mode 100644
index 00000000000000..1156d53afb2e0f
--- /dev/null
+++ b/llvm/lib/CodeGenData/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_llvm_component_library(LLVMCodeGenData
+  # NOTE: CodeGenData.cpp, CodeGenDataReader.cpp and CodeGenDataWriter.cpp are
+  # not created by this patch; listing missing sources makes CMake fail with
+  # "Cannot find source file". Add them along with their implementations.
+  OutlinedHashTree.cpp
+  OutlinedHashTreeRecord.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${LLVM_MAIN_INCLUDE_DIR}/llvm/CodeGenData
+
+  DEPENDS
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  Core
+  Support
+  )
diff --git a/llvm/lib/CodeGenData/OutlinedHashTree.cpp b/llvm/lib/CodeGenData/OutlinedHashTree.cpp
new file mode 100644
index 00000000000000..032993ded60ead
--- /dev/null
+++ b/llvm/lib/CodeGenData/OutlinedHashTree.cpp
@@ -0,0 +1,131 @@
+//===-- OutlinedHashTree.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// An OutlinedHashTree is a Trie that contains sequences of stable hash values
+// of instructions that have been outlined. This OutlinedHashTree can be used
+// to understand the outlined instruction sequences collected across modules.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGenData/OutlinedHashTree.h"
+
+#include <stack>
+#include <tuple>
+
+#define DEBUG_TYPE "outlined-hash-tree"
+
+using namespace llvm;
+
+void OutlinedHashTree::walkGraph(NodeCallbackFn CallbackNode,
+                                 EdgeCallbackFn CallbackEdge,
+                                 bool SortedWalk) const {
+  std::stack<const HashNode *> Stack;
+  Stack.push(getRoot());
+
+  while (!Stack.empty()) {
+    const auto *Current = Stack.top();
+    Stack.pop();
+    if (CallbackNode)
+      CallbackNode(Current);
+
+    auto HandleNext = [&](const HashNode *Next) {
+      if (CallbackEdge)
+        CallbackEdge(Current, Next);
+      Stack.push(Next);
+    };
+    if (SortedWalk) {
+      std::map<stable_hash, const HashNode *> SortedSuccessors;
+      for (const auto &P : Current->Successors)
+        SortedSuccessors[P.first] = P.second.get();
+      for (const auto &P : SortedSuccessors)
+        HandleNext(P.second);
+    } else {
+      for (const auto &P : Current->Successors)
+        HandleNext(P.second.get());
+    }
+  }
+}
+
+size_t OutlinedHashTree::size(bool GetTerminalCountOnly) const {
+  size_t Size = 0;
+  walkGraph([&Size, GetTerminalCountOnly](const HashNode *N) {
+    Size += (N && (!GetTerminalCountOnly || N->Terminals));
+  });
+  return Size;
+}
+
+size_t OutlinedHashTree::depth() const {
+  size_t Size = 0;
+  std::unordered_map<const HashNode *, size_t> DepthMap;
+  walkGraph([&Size, &DepthMap](
+                const HashNode *N) { Size = std::max(Size, DepthMap[N]); },
+            [&DepthMap](const HashNode *Src, const HashNode *Dst) {
+              size_t Depth = DepthMap[Src];
+              DepthMap[Dst] = Depth + 1;
+            });
+  return Size;
+}
+
+void OutlinedHashTree::insert(const HashSequencePair &SequencePair) {
+  const auto &Sequence = SequencePair.first;
+  unsigned Count = SequencePair.second;
+  HashNode *Current = getRoot();
+
+  for (stable_hash StableHash : Sequence) {
+    auto I = Current->Successors.find(StableHash);
+    if (I == Current->Successors.end()) {
+      std::unique_ptr<HashNode> Next = std::make_unique<HashNode>();
+      HashNode *NextPtr = Next.get();
+      NextPtr->Hash = StableHash;
+      Current->Successors.emplace(StableHash, std::move(Next));
+      Current = NextPtr;
+    } else
+      Current = I->second.get();
+  }
+  Current->Terminals += Count;
+}
+
+void OutlinedHashTree::merge(const OutlinedHashTree *Tree) {
+  HashNode *Dst = getRoot();
+  const HashNode *Src = Tree->getRoot();
+  std::stack<std::pair<HashNode *, const HashNode *>> Stack;
+  Stack.push({Dst, Src});
+
+  while (!Stack.empty()) {
+    auto [DstNode, SrcNode] = Stack.top();
+    Stack.pop();
+    if (!SrcNode)
+      continue;
+    DstNode->Terminals += SrcNode->Terminals;
+
+    for (auto &[Hash, NextSrcNode] : SrcNode->Successors) {
+      HashNode *NextDstNode;
+      auto I = DstNode->Successors.find(Hash);
+      if (I == DstNode->Successors.end()) {
+        auto NextDst = std::make_unique<HashNode>();
+        NextDstNode = NextDst.get();
+        NextDstNode->Hash = Hash;
+        DstNode->Successors.emplace(Hash, std::move(NextDst));
+      } else
+        NextDstNode = I->second.get();
+
+      Stack.push({NextDstNode, NextSrcNode.get()});
+    }
+  }
+}
+
+unsigned OutlinedHashTree::find(const HashSequence &Sequence) const {
+  const HashNode *Current = getRoot();
+  for (stable_hash StableHash : Sequence) {
+    const auto I = Current->Successors.find(StableHash);
+    if (I == Current->Successors.end())
+      return 0;
+    Current = I->second.get();
+  }
+  return Current->Terminals;
+}
diff --git a/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp b/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
new file mode 100644
index 00000000000000..0d5dd864c89c55
--- /dev/null
+++ b/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
@@ -0,0 +1,168 @@
+//===-- OutlinedHashTreeRecord.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This defines the OutlinedHashTreeRecord class. This class holds the outlined
+// hash tree for both serialization and deserialization processes. It utilizes
+// two data formats for serialization: raw binary data and YAML.
+// These two formats can be used interchangeably.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGenData/OutlinedHashTreeRecord.h"
+#include "llvm/CodeGenData/OutlinedHashTree.h"
+#include "llvm/ObjectYAML/YAML.h"
+#include "llvm/Support/Endian.h"
+#include "llvm/Support/EndianStream.h"
+
+#define DEBUG_TYPE "outlined-hash-tree"
+
+using namespace llvm;
+using namespace llvm::support;
+
+namespace llvm {
+namespace yaml {
+
+template <> struct MappingTraits<HashNodeStable> {
+  static void mapping(IO &io, HashNodeStable &res) {
+    io.mapRequired("Hash", res.Hash);
+    io.mapRequired("Terminals", res.Terminals);
+    io.mapRequired("SuccessorIds", res.SuccessorIds);
+  }
+};
+
+template <> struct CustomMappingTraits<IdHashNodeStableMapTy> {
+  static void inputOne(IO &io, StringRef Key, IdHashNodeStableMapTy &V) {
+    HashNodeStable NodeStable;
+    io.mapRequired(Key.str().c_str(), NodeStable);
+    unsigned Id;
+    if (Key.getAsInteger(0, Id)) {
+      io.setError("Id not an integer");
+      return;
+    }
+    V.insert({Id, NodeStable});
+  }
+
+  static void output(IO &io, IdHashNodeStableMapTy &V) {
+    for (auto Iter = V.begin(); Iter != V.end(); ++Iter)
+      io.mapRequired(utostr(Iter->first).c_str(), Iter->second);
+  }
+};
+
+} // namespace yaml
+} // namespace llvm
+
+void OutlinedHashTreeRecord::serialize(raw_ostream &OS) const {
+  IdHashNodeStableMapTy IdNodeStableMap;
+  convertToStableData(IdNodeStableMap);
+  support::endian::Writer Writer(OS, endianness::little);
+  Writer.write<uint32_t>(IdNodeStableMap.size());
+
+  for (const auto &[Id, NodeStable] : IdNodeStableMap) {
+    Writer.write<uint32_t>(Id);
+    Writer.write<uint64_t>(NodeStable.Hash);
+    Writer.write<uint32_t>(NodeStable.Terminals);
+    Writer.write<uint32_t>(NodeStable.SuccessorIds.size());
+    for (auto SuccessorId : NodeStable.SuccessorIds)
+      Writer.write<uint32_t>(SuccessorId);
+  }
+}
+
+void OutlinedHashTreeRecord::deserialize(const unsigned char *&Ptr) {
+  IdHashNodeStableMapTy IdNodeStableMap;
+  auto NumIdNodeStableMap =
+      endian::readNext<uint32_t, endianness::little, unaligned>(Ptr);
+
+  for (unsigned I = 0; I < NumIdNodeStableMap; ++I) {
+    auto Id = endian::readNext<uint32_t, endianness::little, unaligned>(Ptr);
+    HashNodeStable NodeStable;
+    NodeStable.Hash =
+        endian::readNext<uint64_t, endianness::little, unaligned>(Ptr);
+    NodeStable.Terminals =
+        endian::readNext<uint32_t, endianness::little, unaligned>(Ptr);
+    auto NumSuccessorIds =
+        endian::readNext<uint32_t, endianness::little, unaligned>(Ptr);
+    for (unsigned J = 0; J < NumSuccessorIds; ++J)
+      NodeStable.SuccessorIds.push_back(
+          endian::readNext<uint32_t, endianness::little, unaligned>(Ptr));
+
+    IdNodeStableMap[Id] = std::move(NodeStable);
+  }
+
+  convertFromStableData(IdNodeStableMap);
+}
+
+void OutlinedHashTreeRecord::serializeYAML(yaml::Output &YOS) const {
+  IdHashNodeStableMapTy IdNodeStableMap;
+  convertToStableData(IdNodeStableMap);
+
+  YOS << IdNodeStableMap;
+}
+
+void OutlinedHashTreeRecord::deserializeYAML(yaml::Input &YIS) {
+  IdHashNodeStableMapTy IdNodeStableMap;
+
+  YIS >> IdNodeStableMap;
+  YIS.nextDocument();
+
+  convertFromStableData(IdNodeStableMap);
+}
+
+void OutlinedHashTreeRecord::convertToStableData(
+    IdHashNodeStableMapTy &IdNodeStableMap) const {
+  // Assign IDs in sorted preorder so they are deterministic; the root gets 0.
+  HashNodeIdMapTy NodeIdMap;
+  HashTree->walkGraph(
+      [&NodeIdMap](const HashNode *Current) {
+        size_t Index = NodeIdMap.size();
+        NodeIdMap[Current] = Index;
+        assert(NodeIdMap.size() == Index + 1 &&
+               "Expected size of NodeMap to increment by 1");
+      },
+      /*CallbackEdge=*/nullptr, /*SortedWalk=*/true);
+
+  // Convert NodeIdMap to NodeStableMap
+  for (auto &P : NodeIdMap) {
+    auto *Node = P.first;
+    auto Id = P.second;
+    HashNodeStable NodeStable;
+    NodeStable.Hash = Node->Hash;
+    NodeStable.Terminals = Node->Terminals;
+    for (auto &P : Node->Successors)
+      NodeStable.SuccessorIds.push_back(NodeIdMap[P.second.get()]);
+    IdNodeStableMap[Id] = NodeStable;
+  }
+
+  // Sort the Successors so that they come out in the same order as in the map.
+  for (auto &P : IdNodeStableMap)
+    std::sort(P.second.SuccessorIds.begin(), P.second.SuccessorIds.end());
+}
+
+void OutlinedHashTreeRecord::convertFromStableData(
+    const IdHashNodeStableMapTy &IdNodeStableMap) {
+  IdHashNodeMapTy IdNodeMap;
+  // ID 0 is the root; preorder IDs put each parent before its successors.
+  IdNodeMap[0] = HashTree->getRoot();
+  assert(IdNodeMap[0]->Successors.empty());
+
+  for (auto &P : IdNodeStableMap) {
+    auto Id = P.first;
+    const HashNodeStable &NodeStable = P.second;
+    assert(IdNodeMap.count(Id));
+    HashNode *Curr = IdNodeMap[Id];
+    Curr->Hash = NodeStable.Hash;
+    Curr->Terminals = NodeStable.Terminals;
+    auto &Successors = Curr->Successors;
+    assert(Successors.empty());
+    for (auto SuccessorId : NodeStable.SuccessorIds) {
+      auto Sucessor = std::make_unique<HashNode>();
+      IdNodeMap[SuccessorId] = Sucessor.get();
+      auto Hash = IdNodeStableMap.at(SuccessorId).Hash;
+      Successors[Hash] = std::move(Sucessor);
+    }
+  }
+}
diff --git a/llvm/unittests/CMakeLists.txt b/llvm/unittests/CMakeLists.txt
index 46f30ff398e10d..cb4b8513e6d02e 100644
--- a/llvm/unittests/CMakeLists.txt
+++ b/llvm/unittests/CMakeLists.txt
@@ -21,6 +21,7 @@ add_subdirectory(BinaryFormat)
 add_subdirectory(Bitcode)
 add_subdirectory(Bitstream)
 add_subdirectory(CodeGen)
+add_subdirectory(CodeGenData)
 add_subdirectory(DebugInfo)
 add_subdirectory(Debuginfod)
 add_subdirectory(Demangle)
diff --git a/llvm/unittests/CodeGenData/CMakeLists.txt b/llvm/unittests/CodeGenData/CMakeLists.txt
new file mode 100644
index 00000000000000..3d821b87e29d8c
--- /dev/null
+++ b/llvm/unittests/CodeGenData/CMakeLists.txt
@@ -0,0 +1,12 @@
+set(LLVM_LINK_COMPONENTS
+  CodeGenData
+  Core
+  Support
+  )
+
+add_llvm_unittest(CodeGenDataTests
+  OutlinedHashTreeRecordTest.cpp
+  OutlinedHashTreeTest.cpp
+  )
+
+target_link_libraries(CodeGenDataTests PRIVATE LLVMTestingSupport)
diff --git a/llvm/unittests/CodeGenData/OutlinedHashTreeRecordTest.cpp b/llvm/unittests/CodeGenData/OutlinedHashTreeRecordTest.cpp
new file mode 100644
index 00000000000000..aa7ad4a33754ff
--- /dev/null
+++ b/llvm/unittests/CodeGenData/OutlinedHashTreeRecordTest.cpp
@@ -0,0 +1,121 @@
+//===- OutlinedHashTreeRecordTest.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGenData/OutlinedHashTreeRecord.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(OutlinedHashTreeRecordTest, Empty) {
+  OutlinedHashTreeRecord HashTreeRecord;
+  ASSERT_TRUE(HashTreeRecord.empty());
+}
+
+TEST(OutlinedHashTreeRecordTest, Print) {
+  OutlinedHashTreeRecord HashTreeRecord;
+  HashTreeRecord.HashTree->insert({{1, 2}, 3});
+
+  const char *ExpectedTreeStr = R"(---
+0:
+  Hash:            0x0
+  Terminals:       0
+  SuccessorIds:    [ 1 ]
+1:
+  Hash:            0x1
+  Terminals:       0
+  SuccessorIds:    [ 2 ]
+2:
+  Hash:            0x2
+  Terminals:       3
+  SuccessorIds:    [  ]
+...
+)";
+  std::string TreeDump;
+  raw_string_ostream OS(TreeDump);
+  HashTreeRecord.print(OS);
+  EXPECT_EQ(ExpectedTreeStr, TreeDump);
+}
+
+TEST(OutlinedHashTreeRecordTest, Stable) {
+  OutlinedHashTreeRecord HashTreeRecord1;
+  HashTreeRecord1.HashTree->insert({{1, 2}, 4});
+  HashTreeRecord1.HashTree->insert({{1, 3}, 5});
+
+  OutlinedHashTreeRecord HashTreeRecord2;
+  HashTreeRecord2.HashTree->insert({{1, 3}, 5});
+  HashTreeRecord2.HashTree->insert({{1, 2}, 4});
+
+  // Output is stable regardless of insertion order.
+  std::string TreeDump1;
+  raw_string_ostream OS1(TreeDump1);
+  HashTreeRecord1.print(OS1);
+  std::string TreeDump2;
+  raw_string_ostream OS2(TreeDump2);
+  HashTreeRecord2.print(OS2);
+
+  EXPECT_EQ(TreeDump1, TreeDump2);
+}
+
+TEST(OutlinedHashTreeRecordTest, Serialize) {
+  OutlinedHashTreeRecord HashTreeRecord1;
+  HashTreeRecord1.HashTree->insert({{1, 2}, 4});
+  HashTreeRecord1.HashTree->insert({{1, 3}, 5});
+
+  // Serialize and deserialize the tree.
+  SmallVector<char> Out;
+  raw_svector_ostream OS(Out);
+  HashTreeRecord1.serialize(OS);
+
+  OutlinedHashTreeRecord HashTreeRecord2;
+  const uint8_t *Begin = reinterpret_cast<const uint8_t *>(Out.data());
+  const uint8_t *Data = Begin;
+  HashTreeRecord2.deserialize(Data);
+  // deserialize() must consume exactly the bytes serialize() produced.
+  EXPECT_EQ(static_cast<size_t>(Data - Begin), Out.size());
+
+  // Two trees should be identical.
+  std::string TreeDump1;
+  raw_string_ostream OS1(TreeDump1);
+  HashTreeRecord1.print(OS1);
+  std::string TreeDump2;
+  raw_string_ostream OS2(TreeDump2);
+  HashTreeRecord2.print(OS2);
+
+  EXPECT_EQ(TreeDump1, TreeDump2);
+}
+
+TEST(OutlinedHashTreeRecordTest, SerializeYAML) {
+  OutlinedHashTreeRecord HashTreeRecord1;
+  HashTreeRecord1.HashTree->insert({{1, 2}, 4});
+  HashTreeRecord1.HashTree->insert({{1, 3}, 5});
+
+  // Serialize and deserialize the tree in a YAML format.
+  std::string Out;
+  raw_string_ostream OS(Out);
+  yaml::Output YOS(OS);
+  HashTreeRecord1.serializeYAML(YOS);
+
+  OutlinedHashTreeRecord HashTreeRecord2;
+  yaml::Input YIS(StringRef(Out.data(), Out.size()));
+  HashTreeRecord2.deserializeYAML(YIS);
+
+  // Two trees should be identical.
+  std::string TreeDump1;
+  raw_string_ostream OS1(TreeDump1);
+  HashTreeRecord1.print(OS1);
+  std::string TreeDump2;
+  raw_string_ostream OS2(TreeDump2);
+  HashTreeRecord2.print(OS2);
+
+  EXPECT_EQ(TreeDump1, TreeDump2);
+}
+
+} // end namespace
diff --git a/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp b/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp
new file mode 100644
index 00000000000000..d11618cf8e4fae
--- /dev/null
+++ b/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp
@@ -0,0 +1,81 @@
+//===- OutlinedHashTreeTest.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGenData/OutlinedHashTree.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(OutlinedHashTreeTest, Empty) {
+  OutlinedHashTree HashTree;
+  ASSERT_TRUE(HashTree.empty());
+  // The root node is always present.
+  ASSERT_TRUE(HashTree.size() == 1);
+  ASSERT_TRUE(HashTree.depth() == 0);
+}
+
+TEST(OutlinedHashTreeTest, Insert) {
+  OutlinedHashTree HashTree;
+  HashTree.insert({{1, 2, 3}, 1});
+  // The node count is 4 (including the root node).
+  ASSERT_TRUE(HashTree.size() == 4);
+  // The terminal count is 1.
+  ASSERT_TRUE(HashTree.size(/*GetTerminalCountOnly=*/true) == 1);
+  // The depth is 3.
+  ASSERT_TRUE(HashTree.depth() == 3);
+
+  HashTree.clear();
+  ASSERT_TRUE(HashTree.empty());
+
+  HashTree.insert({{1, 2, 3}, 1});
+  HashTree.insert({{1, 2, 4}, 2});
+  // The nodes of 1 and 2 are shared with the same prefix.
+  // The nodes are root, 1, 2, 3 and 4, so the node count is 5.
+  ASSERT_TRUE(HashTree.size() == 5);
+}
+
+TEST(OutlinedHashTreeTest, Find) {
+  OutlinedHashTree HashTree;
+  HashTree.insert({{1, 2, 3}, 1});
+  HashTree.insert({{1, 2, 3}, 2});
+
+  // The node count does not change as the same sequences are added.
+  ASSERT_TRUE(HashTree.size() == 4);
+  // The terminal counts are accumulated from two same sequences.
+  ASSERT_TRUE(HashTree.find({1, 2, 3}) == 3);
+  ASSERT_TRUE(HashTree.find({1, 2}) == 0);
+}
+
+TEST(OutlinedHashTreeTest, Merge) {
+  // Build HashTree1 inserting 2 sequences.
+  OutlinedHashTree HashTree1;
+
+  HashTree1.insert({{1, 2}, 20});
+  HashTree1.insert({{1, 4}, 30});
+
+  // Build HashTree2 and HashTree3 with one of those sequences each.
+  OutlinedHashTree HashTree2;
+  HashTree2.insert({{1, 2}, 20});
+  OutlinedHashTree HashTree3;
+  HashTree3.insert({{1, 4}, 30});
+
+  // Merge HashTree3 into HashTree2.
+  HashTree2.merge(&HashTree3);
+
+  // Compare HashTree1 and HashTree2.
+  EXPECT_EQ(HashTree1.size(), HashTree2.size());
+  EXPECT_EQ(HashTree1.depth(), HashTree2.depth());
+  EXPECT_EQ(HashTree1.find({1, 2}), HashTree2.find({1, 2}));
+  EXPECT_EQ(HashTree1.find({1, 4}), HashTree2.find({1, 4}));
+  EXPECT_EQ(HashTree1.find({1, 3}), HashTree2.find({1, 3}));
+}
+
+} // end namespace

>From 5205a52af0165556ae2d4d7f1957fa1ff2ded4e2 Mon Sep 17 00:00:00 2001
From: Kyungwoo Lee <kyulee at meta.com>
Date: Sat, 4 May 2024 17:24:33 -0700
Subject: [PATCH 2/2] Address comments from Ellis

---
 .../llvm/CodeGenData/OutlinedHashTree.h       | 29 ++++++-----------
 .../llvm/CodeGenData/OutlinedHashTreeRecord.h | 13 ++++++--
 llvm/lib/CodeGenData/OutlinedHashTree.cpp     | 32 +++++++++----------
 .../CodeGenData/OutlinedHashTreeRecord.cpp    |  7 ++--
 .../CodeGenData/OutlinedHashTreeTest.cpp      |  5 +--
 5 files changed, 44 insertions(+), 42 deletions(-)

diff --git a/llvm/include/llvm/CodeGenData/OutlinedHashTree.h b/llvm/include/llvm/CodeGenData/OutlinedHashTree.h
index 875e1a78bb4010..c40038cd8c5174 100644
--- a/llvm/include/llvm/CodeGenData/OutlinedHashTree.h
+++ b/llvm/include/llvm/CodeGenData/OutlinedHashTree.h
@@ -30,29 +30,22 @@ namespace llvm {
 /// a hash sequence with that occurrence count.
 struct HashNode {
   /// The hash value of the node.
-  stable_hash Hash;
+  stable_hash Hash = 0;
   /// The number of terminals in the sequence ending at this node.
-  unsigned Terminals;
+  std::optional<unsigned> Terminals;
   /// The successors of this node.
+  /// Not DenseMap: it reserves empty/tombstone keys a stable_hash may equal.
   std::unordered_map<stable_hash, std::unique_ptr<HashNode>> Successors;
 };
 
-/// HashNodeStable is the serialized, stable, and compact representation
-/// of a HashNode.
-struct HashNodeStable {
-  llvm::yaml::Hex64 Hash;
-  unsigned Terminals;
-  std::vector<unsigned> SuccessorIds;
-};
-
 class OutlinedHashTree {
 
   using EdgeCallbackFn =
       std::function<void(const HashNode *, const HashNode *)>;
   using NodeCallbackFn = std::function<void(const HashNode *)>;
 
-  using HashSequence = std::vector<stable_hash>;
-  using HashSequencePair = std::pair<std::vector<stable_hash>, unsigned>;
+  using HashSequence = SmallVector<stable_hash>;
+  using HashSequencePair = std::pair<HashSequence, unsigned>;
 
 public:
   /// Walks every edge and node in the OutlinedHashTree and calls CallbackEdge
@@ -66,7 +59,7 @@ class OutlinedHashTree {
 
   /// Release all hash nodes except the root hash node.
   void clear() {
-    assert(getRoot()->Hash == 0 && getRoot()->Terminals == 0);
+    assert(getRoot()->Hash == 0 && !getRoot()->Terminals);
     getRoot()->Successors.clear();
   }
 
@@ -83,8 +76,8 @@ class OutlinedHashTree {
   size_t depth() const;
 
   /// \returns the root hash node of a OutlinedHashTree.
-  const HashNode *getRoot() const { return Root.get(); }
-  HashNode *getRoot() { return Root.get(); }
+  const HashNode *getRoot() const { return &Root; }
+  HashNode *getRoot() { return &Root; }
 
   /// Inserts a \p Sequence into the this tree. The last node in the sequence
   /// will increase Terminals.
@@ -94,12 +87,10 @@ class OutlinedHashTree {
   void merge(const OutlinedHashTree *OtherTree);
 
   /// \returns the matching count if \p Sequence exists in the OutlinedHashTree.
-  unsigned find(const HashSequence &Sequence) const;
-
-  OutlinedHashTree() { Root = std::make_unique<HashNode>(); }
+  std::optional<unsigned> find(const HashSequence &Sequence) const;
 
 private:
-  std::unique_ptr<HashNode> Root;
+  HashNode Root;
 };
 
 } // namespace llvm
diff --git a/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h b/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
index ccd2ad26dd0871..2960e319604489 100644
--- a/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
+++ b/llvm/include/llvm/CodeGenData/OutlinedHashTreeRecord.h
@@ -16,13 +16,22 @@
 #ifndef LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H
 #define LLVM_CODEGENDATA_OUTLINEDHASHTREERECORD_H
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/CodeGenData/OutlinedHashTree.h"
 
 namespace llvm {
 
+/// HashNodeStable is the serialized, stable, and compact representation
+/// of a HashNode; successors are referenced by node id instead of pointer.
+struct HashNodeStable {
+  llvm::yaml::Hex64 Hash; ///< Copy of the HashNode's stable hash value.
+  unsigned Terminals = 0; ///< Terminal count; 0 means no terminals.
+  std::vector<unsigned> SuccessorIds; ///< Successor node ids, sorted ascending.
+};
+
 using IdHashNodeStableMapTy = std::map<unsigned, HashNodeStable>;
-using IdHashNodeMapTy = std::map<unsigned, HashNode *>;
-using HashNodeIdMapTy = std::unordered_map<const HashNode *, unsigned>;
+using IdHashNodeMapTy = DenseMap<unsigned, HashNode *>;
+using HashNodeIdMapTy = DenseMap<const HashNode *, unsigned>;
 
 struct OutlinedHashTreeRecord {
   std::unique_ptr<OutlinedHashTree> HashTree;
diff --git a/llvm/lib/CodeGenData/OutlinedHashTree.cpp b/llvm/lib/CodeGenData/OutlinedHashTree.cpp
index 032993ded60ead..cb985aa87afcfb 100644
--- a/llvm/lib/CodeGenData/OutlinedHashTree.cpp
+++ b/llvm/lib/CodeGenData/OutlinedHashTree.cpp
@@ -24,19 +24,18 @@ using namespace llvm;
 void OutlinedHashTree::walkGraph(NodeCallbackFn CallbackNode,
                                  EdgeCallbackFn CallbackEdge,
                                  bool SortedWalk) const {
-  std::stack<const HashNode *> Stack;
-  Stack.push(getRoot());
+  SmallVector<const HashNode *> Stack;
+  Stack.emplace_back(getRoot());
 
   while (!Stack.empty()) {
-    const auto *Current = Stack.top();
-    Stack.pop();
+    const auto *Current = Stack.pop_back_val();
     if (CallbackNode)
       CallbackNode(Current);
 
     auto HandleNext = [&](const HashNode *Next) {
       if (CallbackEdge)
         CallbackEdge(Current, Next);
-      Stack.push(Next);
+      Stack.emplace_back(Next);
     };
     if (SortedWalk) {
       std::map<stable_hash, const HashNode *> SortedSuccessors;
@@ -72,8 +71,7 @@ size_t OutlinedHashTree::depth() const {
 }
 
 void OutlinedHashTree::insert(const HashSequencePair &SequencePair) {
-  const auto &Sequence = SequencePair.first;
-  unsigned Count = SequencePair.second;
+  auto &[Sequence, Count] = SequencePair;
   HashNode *Current = getRoot();
 
   for (stable_hash StableHash : Sequence) {
@@ -87,22 +85,23 @@ void OutlinedHashTree::insert(const HashSequencePair &SequencePair) {
     } else
       Current = I->second.get();
   }
-  Current->Terminals += Count;
+  if (Count)
+    Current->Terminals = (Current->Terminals ? *Current->Terminals : 0) + Count;
 }
 
 void OutlinedHashTree::merge(const OutlinedHashTree *Tree) {
   HashNode *Dst = getRoot();
   const HashNode *Src = Tree->getRoot();
-  std::stack<std::pair<HashNode *, const HashNode *>> Stack;
-  Stack.push({Dst, Src});
+  SmallVector<std::pair<HashNode *, const HashNode *>> Stack;
+  Stack.emplace_back(Dst, Src);
 
   while (!Stack.empty()) {
-    auto [DstNode, SrcNode] = Stack.top();
-    Stack.pop();
+    auto [DstNode, SrcNode] = Stack.pop_back_val();
     if (!SrcNode)
       continue;
-    DstNode->Terminals += SrcNode->Terminals;
-
+    if (SrcNode->Terminals)
+      DstNode->Terminals =
+          (DstNode->Terminals ? *DstNode->Terminals : 0) + *SrcNode->Terminals;
     for (auto &[Hash, NextSrcNode] : SrcNode->Successors) {
       HashNode *NextDstNode;
       auto I = DstNode->Successors.find(Hash);
@@ -114,12 +113,13 @@ void OutlinedHashTree::merge(const OutlinedHashTree *Tree) {
       } else
         NextDstNode = I->second.get();
 
-      Stack.push({NextDstNode, NextSrcNode.get()});
+      Stack.emplace_back(NextDstNode, NextSrcNode.get());
     }
   }
 }
 
-unsigned OutlinedHashTree::find(const HashSequence &Sequence) const {
+std::optional<unsigned>
+OutlinedHashTree::find(const HashSequence &Sequence) const {
   const HashNode *Current = getRoot();
   for (stable_hash StableHash : Sequence) {
     const auto I = Current->Successors.find(StableHash);
diff --git a/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp b/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
index 0d5dd864c89c55..da4db7e9e69f11 100644
--- a/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
+++ b/llvm/lib/CodeGenData/OutlinedHashTreeRecord.cpp
@@ -131,7 +131,7 @@ void OutlinedHashTreeRecord::convertToStableData(
     auto Id = P.second;
     HashNodeStable NodeStable;
     NodeStable.Hash = Node->Hash;
-    NodeStable.Terminals = Node->Terminals;
+    NodeStable.Terminals = Node->Terminals ? *Node->Terminals : 0;
     for (auto &P : Node->Successors)
       NodeStable.SuccessorIds.push_back(NodeIdMap[P.second.get()]);
     IdNodeStableMap[Id] = NodeStable;
@@ -139,7 +139,7 @@ void OutlinedHashTreeRecord::convertToStableData(
 
   // Sort the Successors so that they come out in the same order as in the map.
   for (auto &P : IdNodeStableMap)
-    std::sort(P.second.SuccessorIds.begin(), P.second.SuccessorIds.end());
+    llvm::sort(P.second.SuccessorIds);
 }
 
 void OutlinedHashTreeRecord::convertFromStableData(
@@ -155,7 +155,8 @@ void OutlinedHashTreeRecord::convertFromStableData(
     assert(IdNodeMap.count(Id));
     HashNode *Curr = IdNodeMap[Id];
     Curr->Hash = NodeStable.Hash;
-    Curr->Terminals = NodeStable.Terminals;
+    if (NodeStable.Terminals)
+      Curr->Terminals = NodeStable.Terminals;
     auto &Successors = Curr->Successors;
     assert(Successors.empty());
     for (auto SuccessorId : NodeStable.SuccessorIds) {
diff --git a/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp b/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp
index d11618cf8e4fae..5fdfa60673b7fc 100644
--- a/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp
+++ b/llvm/unittests/CodeGenData/OutlinedHashTreeTest.cpp
@@ -50,8 +50,9 @@ TEST(OutlinedHashTreeTest, Find) {
   // The node count does not change as the same sequences are added.
   ASSERT_TRUE(HashTree.size() == 4);
   // The terminal counts are accumulated from two same sequences.
-  ASSERT_TRUE(HashTree.find({1, 2, 3}) == 3);
-  ASSERT_TRUE(HashTree.find({1, 2}) == 0);
+  ASSERT_TRUE(HashTree.find({1, 2, 3}));
+  ASSERT_TRUE(HashTree.find({1, 2, 3}).value() == 3);
+  ASSERT_FALSE(HashTree.find({1, 2}));
 }
 
 TEST(OutlinedHashTreeTest, Merge) {



More information about the llvm-commits mailing list