[llvm] Reland "[AMDGPU] Graph-based Module Splitting Rewrite (#104763)" (PR #107076)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 03:12:22 PDT 2024


================
@@ -278,351 +250,1113 @@ static bool canBeIndirectlyCalled(const Function &F) {
                            /*IgnoreCastedDirectCall=*/true);
 }
 
-/// When a function or any of its callees performs an indirect call, this
-/// takes over \ref addAllDependencies and adds all potentially callable
-/// functions to \p Fns so they can be counted as dependencies of the function.
+//===----------------------------------------------------------------------===//
+// Graph-based Module Representation
+//===----------------------------------------------------------------------===//
+
+/// AMDGPUSplitModule's view of the source Module, as a graph of all components
+/// that can be split into different modules.
 ///
-/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
-/// presence of an indirect call, the function's resource usage is the same as
-/// the most expensive function in the module.
-/// \param M    The module.
-/// \param Fns[out] Resulting list of functions.
-static void addAllIndirectCallDependencies(const Module &M,
-                                           DenseSet<const Function *> &Fns) {
-  for (const auto &Fn : M) {
-    if (canBeIndirectlyCalled(Fn))
-      Fns.insert(&Fn);
+/// The most trivial instance of this graph is just the CallGraph of the module,
+/// but it is not guaranteed that the graph is strictly equal to the CG. It
+/// currently always is but it's designed in a way that would eventually allow
+/// us to create abstract nodes, or nodes for different entities such as global
+/// variables or any other meaningful constraint we must consider.
+///
+/// The graph is only mutable by this class, and is generally not modified
+/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
+/// mutate it.
+class SplitGraph {
+public:
+  class Node;
+
+  enum class EdgeKind : uint8_t {
+    /// The nodes are related through a direct call. This is a "strong" edge as
+    /// it means the Src will directly reference the Dst.
+    DirectCall,
+    /// The nodes are related through an indirect call.
+    /// This is a "weaker" edge and is only considered when traversing the graph
+    /// starting from a kernel. We need this edge for resource usage analysis.
+    ///
+    /// The reason why we have this edge in the first place is due to how
+    /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
+    /// the resource usage of the kernel containing the indirect call is the
+    /// max resource usage of all functions that can be indirectly called.
+    IndirectCall,
+  };
+
+  /// An edge between two nodes. Edges are directional, and tagged with a
+  /// "kind".
+  struct Edge {
+    Edge(Node *Src, Node *Dst, EdgeKind Kind)
+        : Src(Src), Dst(Dst), Kind(Kind) {}
+
+    Node *Src; ///< Source
+    Node *Dst; ///< Destination
+    EdgeKind Kind;
+  };
+
+  using EdgesVec = SmallVector<const Edge *, 0>;
+  using edges_iterator = EdgesVec::const_iterator;
+  using nodes_iterator = const Node *const *;
+
+  SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
+             CostType ModuleCost)
+      : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}
+
+  /// Creates the nodes and edges of the graph from the function definitions
+  /// of the module and \p CG, then marks the graph's entry points.
+  void buildGraph(CallGraph &CG);
+
+#ifndef NDEBUG
+  bool verifyGraph() const;
+#endif
+
+  bool empty() const { return Nodes.empty(); }
+  const iterator_range<nodes_iterator> nodes() const {
+    return {Nodes.begin(), Nodes.end()};
   }
-}
+  /// \returns the node whose ID is \p ID. \p ID must be lower than
+  /// getNumNodes().
+  const Node &getNode(unsigned ID) const { return *Nodes[ID]; }
+
+  unsigned getNumNodes() const { return Nodes.size(); }
+  /// \returns a BitVector with one (cleared) bit per node, indexed by node ID.
+  BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }
+
+  const Module &getModule() const { return M; }
+
+  CostType getModuleCost() const { return ModuleCost; }
+  /// \returns the cost of \p F as recorded in the cost map; \p F must have an
+  /// entry in it.
+  CostType getCost(const Function &F) const { return CostMap.at(&F); }
+
+  /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
+  /// IDs).
+  CostType calculateCost(const BitVector &BV) const;
 
-/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
-/// callee until all reachable functions have been gathered.
+private:
+  /// Retrieves the node for \p GV in \p Cache, or creates a new node for it and
+  /// updates \p Cache.
+  Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
+                const GlobalValue &GV);
+
+  /// Creates a new edge between two nodes and registers it in both the
+  /// outgoing edges of \p Src and the incoming edges of \p Dst.
+  const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);
+
+  const Module &M;
+  const FunctionsCostMap &CostMap;
+  CostType ModuleCost;
+
+  // Final list of nodes with stable ordering.
+  SmallVector<Node *> Nodes;
+
+  SpecificBumpPtrAllocator<Node> NodesPool;
+
+  // Edges are trivially destructible objects, so as a small optimization we
+  // use a BumpPtrAllocator which avoids destructor calls but also makes
+  // allocation faster.
+  static_assert(
+      std::is_trivially_destructible_v<Edge>,
+      "Edge must be trivially destructible to use the BumpPtrAllocator");
+  BumpPtrAllocator EdgesPool;
+};
+
+/// Nodes in the SplitGraph contain both incoming, and outgoing edges.
+/// Incoming edges have this node as their Dst, and Outgoing ones have this node
+/// as their Src.
 ///
-/// \param SML Log Helper
-/// \param CG Call graph for \p Fn's module.
-/// \param Fn Current function to look at.
-/// \param Fns[out] Resulting list of functions.
-/// \param OnlyDirect Whether to only consider direct callees.
-/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
-/// point, either in \p Fn or in one of the function it calls. When that
-/// happens, we fall back to adding all callable functions inside \p Fn's module
-/// to \p Fns.
-static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
-                               const Function &Fn,
-                               DenseSet<const Function *> &Fns, bool OnlyDirect,
-                               bool &HadIndirectCall) {
-  assert(!Fn.isDeclaration());
-
-  const Module &M = *Fn.getParent();
-  SmallVector<const Function *> WorkList({&Fn});
+/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
+/// feedback on how two nodes are related, and in which direction they are
+/// related, which is valuable information to make splitting decisions.
+///
+/// Nodes are fundamentally abstract, and any consumers of the graph should
+/// treat them as such. While a node will be a function most of the time, we
+/// could also create nodes for any other reason. In the future, we could have
+/// single nodes for multiple functions, or nodes for GVs, etc.
+class SplitGraph::Node {
+  friend class SplitGraph;
+
+public:
+  Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
+       bool IsNonCopyable)
+      : ID(ID), GV(GV), IndividualCost(IndividualCost),
+        IsNonCopyable(IsNonCopyable), IsEntryFnCC(false), IsGraphEntry(false) {
+    if (auto *Fn = dyn_cast<Function>(&GV))
+      IsEntryFnCC = AMDGPU::isEntryFunctionCC(Fn->getCallingConv());
+  }
+
+  /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
+  /// nodes in the graph. This ID can be used as an index in a BitVector.
+  unsigned getID() const { return ID; }
+
+  /// \returns the underlying GlobalValue as a Function. Only valid for nodes
+  /// that were created for a Function (uses cast<>).
+  const Function &getFunction() const { return cast<Function>(GV); }
+
+  /// \returns the cost to import this component into a given module, not
+  /// accounting for any dependencies that may need to be imported as well.
+  CostType getIndividualCost() const { return IndividualCost; }
+
+  bool isNonCopyable() const { return IsNonCopyable; }
+  bool isEntryFunctionCC() const { return IsEntryFnCC; }
+
+  /// \returns whether this is an entry point in the graph. Entry points are
+  /// defined as follows: if you take all entry points in the graph, and iterate
+  /// their dependencies, you are guaranteed to visit all nodes in the graph at
+  /// least once.
+  bool isGraphEntryPoint() const { return IsGraphEntry; }
+
+  StringRef getName() const { return GV.getName(); }
+
+  bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
+  bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
+    return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
+  }
+
+  bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
+  bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
+    return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
+  }
+
+  iterator_range<edges_iterator> incoming_edges() const {
+    return IncomingEdges;
+  }
+
+  iterator_range<edges_iterator> outgoing_edges() const {
+    return OutgoingEdges;
+  }
+
+  /// \returns whether dependency traversal starting at this node also follows
+  /// IndirectCall edges. Only nodes with an entry-function calling convention
+  /// do so.
+  bool shouldFollowIndirectCalls() const { return isEntryFunctionCC(); }
+
+  /// Visit all children of this node in a recursive fashion. Also visits Self.
+  /// If \ref shouldFollowIndirectCalls returns false, then this only follows
+  /// DirectCall edges.
+  ///
+  /// \param Visitor Visitor Function.
+  void visitAllDependencies(std::function<void(const Node &)> Visitor) const;
+
+  /// Adds the dependencies of this node in \p BV by setting the bit
+  /// corresponding to each node.
+  ///
+  /// Implemented using \ref visitAllDependencies, hence it follows the same
+  /// rules regarding dependencies traversal.
+  ///
+  /// \param[out] BV The bitvector where the bits should be set.
+  void getDependencies(BitVector &BV) const {
+    visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
+  }
+
+  /// Uses \ref visitAllDependencies to aggregate the individual cost of this
+  /// node and all of its dependencies.
+  ///
+  /// This is cached.
+  CostType getFullCost() const;
+
+private:
+  void markAsGraphEntry() { IsGraphEntry = true; }
+
+  unsigned ID;
+  const GlobalValue &GV;
+  CostType IndividualCost;
+  bool IsNonCopyable : 1;
+  bool IsEntryFnCC : 1;
+  bool IsGraphEntry : 1;
+
+  // TODO: Cache dependencies as well?
+  // Cached result of getFullCost(). 0 doubles as "not computed yet", so a
+  // genuinely zero full cost is recomputed on every call.
+  mutable CostType FullCost = 0;
+
+  // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
+  // together)
+  EdgesVec IncomingEdges;
+  EdgesVec OutgoingEdges;
+};
+
+// Worklist (LIFO) traversal over outgoing edges, starting at this node. The
+// Seen set guarantees each reachable node is passed to Visitor exactly once.
+void SplitGraph::Node::visitAllDependencies(
+    std::function<void(const Node &)> Visitor) const {
+  const bool FollowIndirect = shouldFollowIndirectCalls();
+  // FIXME: If this can access SplitGraph in the future, use a BitVector
+  // instead.
+  DenseSet<const Node *> Seen;
+  SmallVector<const Node *, 8> WorkList({this});
   while (!WorkList.empty()) {
-    const auto &CurFn = *WorkList.pop_back_val();
-    assert(!CurFn.isDeclaration());
+    const Node *CurN = WorkList.pop_back_val();
+    // Skip nodes already visited (cycles, or multiple paths to one node).
+    if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
+      continue;
 
-    // Scan for an indirect call. If such a call is found, we have to
-    // conservatively assume this can call all non-entrypoint functions in the
-    // module.
+    Visitor(*CurN);
 
-    for (auto &CGEntry : *CG[&CurFn]) {
+    for (const Edge *E : CurN->outgoing_edges()) {
+      // The decision is taken once for the starting node: traversals that
+      // start at a non-entry-CC node only ever follow DirectCall edges.
+      if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
+        continue;
+      WorkList.push_back(E->Dst);
+    }
   }
+}
+
+/// Sums the individual cost of this node and of every dependency reached by
+/// visitAllDependencies. The result is memoized in FullCost, where 0 means
+/// "not computed yet".
+CostType SplitGraph::Node::getFullCost() const {
+  if (FullCost != 0)
+    return FullCost;
+
+  CostType Total = 0;
+  visitAllDependencies(
+      [&Total](const Node &Dep) { Total += Dep.getIndividualCost(); });
+  FullCost = Total;
+  return FullCost;
+}
+
+/// Builds the graph in three steps:
+///  1. One node per function definition, with a DirectCall edge for each call
+///     to a defined callee.
+///  2. IndirectCall edges from every function containing an indirect call to
+///     every function that can be indirectly called.
+///  3. Entry point selection, such that every node is reachable (through
+///     getDependencies) from at least one graph entry point.
+void SplitGraph::buildGraph(CallGraph &CG) {
+  SplitModuleTimer SMT("buildGraph", "graph construction");
+  LLVM_DEBUG(
+      dbgs()
+      << "[build graph] constructing graph representation of the input\n");
+
+  // We build the graph by just iterating all functions in the module and
+  // working on their direct callees. At the end, all nodes should be linked
+  // together as expected.
+  DenseMap<const GlobalValue *, Node *> Cache;
+  SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
+  for (const Function &Fn : M) {
+    if (Fn.isDeclaration())
+      continue;
+
+    // Look at direct callees and create the necessary edges in the graph.
+    bool HasIndirectCall = false;
+    Node &N = getNode(Cache, Fn);
+    for (auto &CGEntry : *CG[&Fn]) {
       auto *CGNode = CGEntry.second;
       auto *Callee = CGNode->getFunction();
       if (!Callee) {
-        if (OnlyDirect)
-          continue;
-
-        // Functions have an edge towards CallsExternalNode if they're external
-        // declarations, or if they do an indirect call. As we only process
-        // definitions here, we know this means the function has an indirect
-        // call. We then have to conservatively assume this can call all
-        // non-entrypoint functions in the module.
-        if (CGNode != CG.getCallsExternalNode())
-          continue; // this is another function-less node we don't care about.
-
-        SML << "Indirect call detected in " << getName(CurFn)
-            << " - treating all non-entrypoint functions as "
-               "potential dependencies\n";
-
-        // TODO: Print an ORE as well ?
-        addAllIndirectCallDependencies(M, Fns);
-        HadIndirectCall = true;
+        // TODO: Don't consider inline assembly as indirect calls.
+        if (CGNode == CG.getCallsExternalNode())
+          HasIndirectCall = true;
         continue;
       }
 
-      if (Callee->isDeclaration())
-        continue;
+      if (!Callee->isDeclaration())
+        createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);
+    }
+
+    // Keep track of this function if it contains an indirect call and/or if it
+    // can be indirectly called.
+    if (HasIndirectCall) {
+      LLVM_DEBUG(dbgs() << "indirect call found in " << Fn.getName() << "\n");
+      FnsWithIndirectCalls.push_back(&Fn);
+    }
+
+    if (canBeIndirectlyCalled(Fn))
+      IndirectlyCallableFns.push_back(&Fn);
+  }
 
-      auto [It, Inserted] = Fns.insert(Callee);
-      if (Inserted)
-        WorkList.push_back(Callee);
+  // Post-process functions with indirect calls.
+  for (const Function *Fn : FnsWithIndirectCalls) {
+    // Loop-invariant: every definition already has a node from the loop above.
+    Node &Src = getNode(Cache, *Fn);
+    for (const Function *Candidate : IndirectlyCallableFns) {
+      Node &Dst = getNode(Cache, *Candidate);
+      createEdge(Src, Dst, EdgeKind::IndirectCall);
     }
   }
+
+  // Now, find all entry points.
+  SmallVector<Node *, 16> CandidateEntryPoints;
+  BitVector NodesReachableByKernels = createNodesBitVector();
+  for (Node *N : Nodes) {
+    // Functions with an Entry CC are always graph entry points too.
+    if (N->isEntryFunctionCC()) {
+      N->markAsGraphEntry();
+      N->getDependencies(NodesReachableByKernels);
+    } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
+      CandidateEntryPoints.push_back(N);
+  }
+
+  // Nodes reachable from any entry point selected so far.
+  BitVector NodesReachableByEntries = NodesReachableByKernels;
+  for (Node *N : CandidateEntryPoints) {
+    // This can be another entry point if it's not reachable by a kernel
+    // TODO: We could sort all of the possible new entries in a stable order
+    // (e.g. by cost), then consume them one by one until
+    // NodesReachableByKernels is all 1s. It'd allow us to avoid
+    // considering some nodes as non-entries in some specific cases.
+    if (!NodesReachableByKernels.test(N->getID())) {
+      N->markAsGraphEntry();
+      N->getDependencies(NodesReachableByEntries);
+    }
+  }
+
+  // Functions that are only called from within a cycle of direct calls (e.g. a
+  // self-recursive function, or mutually recursive functions without any other
+  // caller) always have an incoming DirectCall edge, so they are never
+  // candidates above. If no kernel reaches them either, no entry point reaches
+  // them, which would violate the invariant checked by verifyGraph(). Promote
+  // the remaining unreached nodes, in stable node order, to entry points; each
+  // promotion also marks everything it reaches.
+  for (Node *N : Nodes) {
+    if (NodesReachableByEntries.test(N->getID()))
+      continue;
+    N->markAsGraphEntry();
+    N->getDependencies(NodesReachableByEntries);
+  }
+
+#ifndef NDEBUG
+  assert(verifyGraph());
+#endif
 }
 
-/// Contains information about a function and its dependencies.
-/// This is a splitting root. The splitting algorithm works by
-/// assigning these to partitions.
-struct FunctionWithDependencies {
-  FunctionWithDependencies(SplitModuleLogger &SML, CallGraph &CG,
-                           const DenseMap<const Function *, CostType> &FnCosts,
-                           const Function *Fn)
-      : Fn(Fn) {
-    // When Fn is not a kernel, we don't need to collect indirect callees.
-    // Resource usage analysis is only performed on kernels, and we collect
-    // indirect callees for resource usage analysis.
-    addAllDependencies(SML, CG, *Fn, Dependencies,
-                       /*OnlyDirect*/ !isEntryPoint(Fn), HasIndirectCall);
-    TotalCost = FnCosts.at(Fn);
-    for (const auto *Dep : Dependencies) {
-      TotalCost += FnCosts.at(Dep);
-
-      // We cannot duplicate functions with external linkage, or functions that
-      // may be overriden at runtime.
-      HasNonDuplicatableDependecy |=
-          (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
+#ifndef NDEBUG
+/// Debug-only consistency check of the graph. Checks node IDs, edge
+/// symmetry, the one-node-per-function-definition mapping, and that every node
+/// is reachable from a graph entry point.
+///
+/// \returns false after printing a message to errs() on the first violation
+/// found, true otherwise.
+bool SplitGraph::verifyGraph() const {
+  unsigned ExpectedID = 0;
+  // Exceptionally using a set here in case IDs are messed up.
+  DenseSet<const Node *> SeenNodes;
+  DenseSet<const Function *> SeenFunctionNodes;
+  for (const Node *N : Nodes) {
+    if (N->getID() != (ExpectedID++)) {
+      errs() << "Node IDs are incorrect!\n";
+      return false;
+    }
+
+    if (!SeenNodes.insert(N).second) {
+      errs() << "Node seen more than once!\n";
+      return false;
+    }
+
+    if (&getNode(N->getID()) != N) {
+      errs() << "getNode doesn't return the right node\n";
+      return false;
+    }
+
+    // Every incoming edge must point at N and be listed by its source.
+    for (const Edge *E : N->IncomingEdges) {
+      if (!E->Src || !E->Dst || (E->Dst != N) ||
+          (find(E->Src->OutgoingEdges, E) == E->Src->OutgoingEdges.end())) {
+        errs() << "ill-formed incoming edges\n";
+        return false;
+      }
+    }
+
+    // Every outgoing edge must start at N and be listed by its destination.
+    for (const Edge *E : N->OutgoingEdges) {
+      if (!E->Src || !E->Dst || (E->Src != N) ||
+          (find(E->Dst->IncomingEdges, E) == E->Dst->IncomingEdges.end())) {
+        errs() << "ill-formed outgoing edges\n";
+        return false;
+      }
+    }
+
+    const Function &Fn = N->getFunction();
+    if (AMDGPU::isEntryFunctionCC(Fn.getCallingConv())) {
+      if (N->hasAnyIncomingEdges()) {
+        errs() << "Kernels cannot have incoming edges\n";
+        return false;
+      }
+    }
+
+    if (Fn.isDeclaration()) {
+      errs() << "declarations shouldn't have nodes!\n";
+      return false;
+    }
+
+    auto [It, Inserted] = SeenFunctionNodes.insert(&Fn);
+    if (!Inserted) {
+      errs() << "one function has multiple nodes!\n";
+      return false;
     }
   }
 
-  const Function *Fn = nullptr;
-  DenseSet<const Function *> Dependencies;
-  /// Whether \p Fn or any of its \ref Dependencies contains an indirect call.
-  bool HasIndirectCall = false;
-  /// Whether any of \p Fn's dependencies cannot be duplicated.
-  bool HasNonDuplicatableDependecy = false;
+  if (ExpectedID != Nodes.size()) {
+    errs() << "Node IDs out of sync!\n";
+    return false;
+  }
 
-  CostType TotalCost = 0;
+  if (createNodesBitVector().size() != getNumNodes()) {
+    errs() << "nodes bit vector doesn't have the right size!\n";
+    return false;
+  }
+
+  // Check we respect the promise of Node::isGraphEntryPoint: iterating the
+  // dependencies of all entry points must visit every node (checked below).
+  BitVector BV = createNodesBitVector();
+  for (const Node *N : nodes()) {
+    if (N->isGraphEntryPoint())
+      N->getDependencies(BV);
+  }
+
+  // Ensure each function in the module has an associated node.
+  for (const auto &Fn : M) {
+    if (!Fn.isDeclaration()) {
+      if (!SeenFunctionNodes.contains(&Fn)) {
+        errs() << "Fn has no associated node in the graph!\n";
+        return false;
+      }
+    }
+  }
+
+  if (!BV.all()) {
+    errs() << "not all nodes are reachable through the graph's entry points!\n";
+    return false;
+  }
 
-  /// \returns true if this function and its dependencies can be considered
-  /// large according to \p Threshold.
-  bool isLarge(CostType Threshold) const {
-    return TotalCost > Threshold && !Dependencies.empty();
+  return true;
+}
+#endif
+
+/// Adds up the individual cost of every node whose ID bit is set in \p BV.
+CostType SplitGraph::calculateCost(const BitVector &BV) const {
+  CostType Total = 0;
+  for (int Idx = BV.find_first(); Idx != -1; Idx = BV.find_next(Idx))
+    Total += getNode(Idx).getIndividualCost();
+  return Total;
+}
+
+/// Returns the node cached for \p GV, creating and registering a new one (with
+/// the next sequential ID) on first request. For Functions, the node's cost
+/// comes from CostMap and its copyability from isNonCopyable().
+SplitGraph::Node &
+SplitGraph::getNode(DenseMap<const GlobalValue *, Node *> &Cache,
+                    const GlobalValue &GV) {
+  // operator[] default-inserts a null slot for an unseen GV; it is filled in
+  // below once the node is created.
+  auto &N = Cache[&GV];
+  if (N)
+    return *N;
+
+  CostType Cost = 0;
+  bool NonCopyable = false;
+  if (const Function *Fn = dyn_cast<Function>(&GV)) {
+    NonCopyable = isNonCopyable(*Fn);
+    Cost = CostMap.at(Fn);
   }
+  // IDs are handed out sequentially, so a node's ID is its index in Nodes.
+  N = new (NodesPool.Allocate()) Node(Nodes.size(), GV, Cost, NonCopyable);
+  Nodes.push_back(N);
+  assert(&getNode(N->getID()) == N);
+  return *N;
+}
+
+/// Allocates an edge from \p Src to \p Dst in the edge pool and records it in
+/// both endpoints, which share the same Edge object.
+const SplitGraph::Edge &SplitGraph::createEdge(Node &Src, Node &Dst,
+                                               EdgeKind EK) {
+  void *Storage = EdgesPool.Allocate<Edge>(1);
+  const Edge &NewEdge = *new (Storage) Edge(&Src, &Dst, EK);
+  Dst.IncomingEdges.push_back(&NewEdge);
+  Src.OutgoingEdges.push_back(&NewEdge);
+  return NewEdge;
+}
+
+//===----------------------------------------------------------------------===//
+// Split Proposals
+//===----------------------------------------------------------------------===//
+
+/// Represents a module splitting proposal.
+///
+/// Proposals are made of N BitVectors, one for each partition, where each bit
+/// set indicates that the node is present and should be copied inside that
+/// partition.
+///
+/// Proposals have several metrics attached so they can be compared/sorted,
+/// which allows the driver to try multiple strategies resulting in multiple
+/// proposals and choose the best one out of them.
+class SplitProposal {
+public:
+  SplitProposal(const SplitGraph &SG, unsigned MaxPartitions) : SG(&SG) {
+    Partitions.resize(MaxPartitions, {0, SG.createNodesBitVector()});
+  }
+
+  void setName(StringRef NewName) { Name = NewName; }
+  StringRef getName() const { return Name; }
+
+  const BitVector &operator[](unsigned PID) const {
+    return Partitions[PID].second;
+  }
+
+  /// Adds every node set in \p BV to partition \p PID, then refreshes that
+  /// partition's cost and the aggregated cost.
+  void add(unsigned PID, const BitVector &BV) {
+    Partitions[PID].second |= BV;
+    updateScore(PID);
+  }
+
+  void print(raw_ostream &OS) const;
+  LLVM_DUMP_METHOD void dump() const { print(dbgs()); }
+
+  // Find the cheapest partition (lowest cost). In case of ties, always returns
+  // the highest partition number.
+  unsigned findCheapestPartition() const;
+
+  /// Calculate the CodeSize and Bottleneck scores.
+  void calculateScores();
+
+#ifndef NDEBUG
+  void verifyCompleteness() const;
+#endif
+
+  /// Only available after \ref calculateScores is called.
+  ///
+  /// The ratio between the aggregated cost of all partitions and the cost of
+  /// the source module. A node copied into several partitions is counted once
+  /// per partition, so this measures code duplication: e.g. 1.2 means this
+  /// proposal adds roughly 20% code size by duplicating some functions across
+  /// partitions.
+  ///
+  /// Value is always rounded up to 2 decimal places.
+  ///
+  /// Lower is better. Assuming the module cost is the sum of all node costs, a
+  /// proposal without any duplication scores 1.0.
+  double getCodeSizeScore() const { return CodeSizeScore; }
+
+  /// Only available after \ref calculateScores is called.
+  ///
+  /// A number between [0, 1] which indicates how big of a bottleneck is
+  /// expected from the largest partition.
+  ///
+  /// A score of 1.0 means the biggest partition is as big as the source module,
+  /// so build time will be equal to or greater than the build time of the
+  /// initial input.
+  ///
+  /// Value is always rounded up to 2 decimal places.
+  ///
+  /// This is one of the metrics used to estimate this proposal's build time.
+  double getBottleneckScore() const { return BottleneckScore; }
+
+private:
+  /// Recomputes the cost of partition \p PID and updates \ref TotalCost.
+  ///
+  /// Only \p PID needs refreshing: partitions are only modified through
+  /// \ref add, which calls this for the one partition it changed, so every
+  /// other cached cost is still accurate.
+  void updateScore(unsigned PID) {
+    assert(SG);
+    auto &[PCost, Nodes] = Partitions[PID];
+    TotalCost -= PCost;
+    PCost = SG->calculateCost(Nodes);
+    TotalCost += PCost;
+  }
+
+  /// \see getCodeSizeScore
+  double CodeSizeScore = 0.0;
+  /// \see getBottleneckScore
+  double BottleneckScore = 0.0;
+  /// Aggregated cost of all partitions
+  CostType TotalCost = 0;
+
+  const SplitGraph *SG = nullptr;
+  std::string Name;
+
+  std::vector<std::pair<CostType, BitVector>> Partitions;
 };
 
-/// Calculates how much overlap there is between \p A and \p B.
-/// \return A number between 0.0 and 1.0, where 1.0 means A == B and 0.0 means A
-/// and B have no shared elements. Kernels do not count in overlap calculation.
-static float calculateOverlap(const DenseSet<const Function *> &A,
-                              const DenseSet<const Function *> &B) {
-  DenseSet<const Function *> Total;
-  for (const auto *F : A) {
-    if (!isEntryPoint(F))
-      Total.insert(F);
+/// Prints a one-line summary (name, total cost, code size and bottleneck
+/// scores) followed by one line per partition giving its node count, its cost
+/// and that cost as a share of the module cost.
+void SplitProposal::print(raw_ostream &OS) const {
+  assert(SG);
+
+  OS << "[proposal] " << Name << ", total cost:" << TotalCost
+     << ", code size score:" << format("%0.3f", CodeSizeScore)
+     << ", bottleneck score:" << format("%0.3f", BottleneckScore) << '\n';
+  for (const auto &[PID, Part] : enumerate(Partitions)) {
+    const auto &[Cost, NodeIDs] = Part;
+    OS << "  - P" << PID << " nodes:" << NodeIDs.count() << " cost: " << Cost
+       << '|' << formatRatioOf(Cost, SG->getModuleCost()) << "%\n";
   }
+}
 
-  if (Total.empty())
-    return 0.0f;
+/// Linear scan for the partition with the lowest cost. Only strictly more
+/// expensive partitions are skipped, so among equally cheap partitions the
+/// one with the highest ID wins.
+unsigned SplitProposal::findCheapestPartition() const {
+  assert(!Partitions.empty());
+  unsigned BestPID = InvalidPID;
+  CostType BestCost = std::numeric_limits<CostType>::max();
+  for (unsigned PID = 0, E = Partitions.size(); PID != E; ++PID) {
+    const CostType PCost = Partitions[PID].first;
+    if (PCost > BestCost)
+      continue;
+    BestPID = PID;
+    BestCost = PCost;
+  }
+  assert(BestPID != InvalidPID);
+  return BestPID;
+}
 
-  unsigned NumCommon = 0;
-  for (const auto *F : B) {
-    if (isEntryPoint(F))
-      continue;
+/// Computes CodeSizeScore (aggregated partition cost / module cost) and
+/// BottleneckScore (most expensive partition cost / module cost). Both are
+/// rounded up to 2 decimal places. Does nothing when there are no partitions.
+void SplitProposal::calculateScores() {
+  if (Partitions.empty())
+    return;
 
-    auto [It, Inserted] = Total.insert(F);
-    if (!Inserted)
-      ++NumCommon;
+  assert(SG);
+  CostType LargestPCost = 0;
+  for (auto &[PCost, Nodes] : Partitions) {
+    if (PCost > LargestPCost)
+      LargestPCost = PCost;
   }
 
-  return static_cast<float>(NumCommon) / Total.size();
+  // NOTE(review): a ModuleCost of 0 makes both ratios inf/NaN (and NaN trips
+  // the assert below); confirm callers never score a zero-cost module.
+  CostType ModuleCost = SG->getModuleCost();
+  CodeSizeScore = double(TotalCost) / ModuleCost;
+  assert(CodeSizeScore >= 0.0);
+
+  BottleneckScore = double(LargestPCost) / ModuleCost;
+
+  // Round both scores up to 2 decimal places.
+  CodeSizeScore = std::ceil(CodeSizeScore * 100.0) / 100.0;
+  BottleneckScore = std::ceil(BottleneckScore * 100.0) / 100.0;
 }
 
-/// Performs all of the partitioning work on \p M.
-/// \param SML Log Helper
-/// \param M Module to partition.
-/// \param NumParts Number of partitions to create.
-/// \param ModuleCost Total cost of all functions in \p M.
-/// \param FnCosts Map of Function -> Cost
-/// \param WorkList Functions and their dependencies to process in order.
-/// \returns The created partitions (a vector of size \p NumParts )
-static std::vector<DenseSet<const Function *>>
-doPartitioning(SplitModuleLogger &SML, Module &M, unsigned NumParts,
-               CostType ModuleCost,
-               const DenseMap<const Function *, CostType> &FnCosts,
-               const SmallVector<FunctionWithDependencies> &WorkList) {
-
-  SML << "\n--Partitioning Starts--\n";
-
-  // Calculate a "large function threshold". When more than one function's total
-  // import cost exceeds this value, we will try to assign it to an existing
-  // partition to reduce the amount of duplication needed.
-  //
-  // e.g. let two functions X and Y have a import cost of ~10% of the module, we
-  // assign X to a partition as usual, but when we get to Y, we check if it's
-  // worth also putting it in Y's partition.
-  const CostType LargeFnThreshold =
-      LargeFnFactor ? CostType(((ModuleCost / NumParts) * LargeFnFactor))
-                    : std::numeric_limits<CostType>::max();
-
-  std::vector<DenseSet<const Function *>> Partitions;
-  Partitions.resize(NumParts);
-
-  // Assign functions to partitions, and try to keep the partitions more or
-  // less balanced. We do that through a priority queue sorted in reverse, so we
-  // can always look at the partition with the least content.
-  //
-  // There are some cases where we will be deliberately unbalanced though.
-  //  - Large functions: we try to merge with existing partitions to reduce code
-  //  duplication.
-  //  - Functions with indirect or external calls always go in the first
-  //  partition (P0).
-  auto ComparePartitions = [](const std::pair<PartitionID, CostType> &a,
-                              const std::pair<PartitionID, CostType> &b) {
-    // When two partitions have the same cost, assign to the one with the
-    // biggest ID first. This allows us to put things in P0 last, because P0 may
-    // have other stuff added later.
-    if (a.second == b.second)
-      return a.first < b.first;
-    return a.second > b.second;
+#ifndef NDEBUG
+/// Asserts that the union of all partitions covers every node of the graph.
+/// An empty proposal is trivially accepted.
+void SplitProposal::verifyCompleteness() const {
+  auto It = Partitions.begin(), End = Partitions.end();
+  if (It == End)
+    return;
+
+  BitVector Covered = It->second;
+  while (++It != End)
+    Covered |= It->second;
+  assert(Covered.all() && "some nodes are missing from this proposal!");
+}
+#endif
+
+//===-- RecursiveSearchStrategy -------------------------------------------===//
+
+/// Partitioning algorithm.
+///
+/// This is a recursive search algorithm that can explore multiple possibilities.
+///
+/// When a cluster of nodes can go into more than one partition, and we haven't
+/// reached maximum search depth, we recurse and explore both options and their
+/// consequences. Both branches will yield a proposal, and the driver will grade
+/// both and choose the best one.
+///
+/// If max depth is reached, we will use some heuristics to make a choice. Most
+/// of the time we will just use the least-pressured (cheapest) partition, but
+/// if a cluster is particularly big and there is a good amount of overlap with
+/// an existing partition, we will choose that partition instead.
+class RecursiveSearchSplitting {
+public:
+  using SubmitProposalFn = function_ref<void(SplitProposal)>;
+
+  RecursiveSearchSplitting(const SplitGraph &SG, unsigned NumParts,
+                           SubmitProposalFn SubmitProposal);
+
+  void run();
+
+private:
+  /// A cluster of nodes (node IDs set in \ref Cluster) that is assigned to a
+  /// partition as a unit by \ref pickPartition, plus cost metadata about it.
+  struct WorkListEntry {
+    WorkListEntry(const BitVector &BV) : Cluster(BV) {}
+
+    unsigned NumNonEntryNodes = 0;
+    CostType TotalCost = 0;
+    CostType CostExcludingGraphEntryPoints = 0;
+    BitVector Cluster;
    };
 
-  // We can't use priority_queue here because we need to be able to access any
-  // element. This makes this a bit inefficient as we need to sort it again
-  // everytime we change it, but it's a very small array anyway (likely under 64
-  // partitions) so it's a cheap operation.
-  std::vector<std::pair<PartitionID, CostType>> BalancingQueue;
-  for (unsigned I = 0; I < NumParts; ++I)
-    BalancingQueue.emplace_back(I, 0);
-
-  // Helper function to handle assigning a function to a partition. This takes
-  // care of updating the balancing queue.
-  const auto AssignToPartition = [&](PartitionID PID,
-                                     const FunctionWithDependencies &FWD) {
-    auto &FnsInPart = Partitions[PID];
-    FnsInPart.insert(FWD.Fn);
-    FnsInPart.insert(FWD.Dependencies.begin(), FWD.Dependencies.end());
-
-    SML << "assign " << getName(*FWD.Fn) << " to P" << PID << "\n  ->  ";
-    if (!FWD.Dependencies.empty()) {
-      SML << FWD.Dependencies.size() << " dependencies added\n";
-    };
-
-    // Update the balancing queue. we scan backwards because in the common case
-    // the partition is at the end.
-    for (auto &[QueuePID, Cost] : reverse(BalancingQueue)) {
-      if (QueuePID == PID) {
-        CostType NewCost = 0;
-        for (auto *Fn : Partitions[PID])
-          NewCost += FnCosts.at(Fn);
-
-        SML << "[Updating P" << PID << " Cost]:" << Cost << " -> " << NewCost;
-        if (Cost) {
-          SML << " (" << unsigned(((float(NewCost) / Cost) - 1) * 100)
-              << "% increase)";
-        }
-        SML << '\n';
+  /// Checks if the TotalCost of \p A > \p B, handling the case where the costs
+  /// are identical in a deterministic manner.
+  bool stableGreaterThan(const WorkListEntry &A, const WorkListEntry &B) const;
+
+  /// Collects the clusters of all graph entry points and sorts them so the
+  /// most expensive clusters are viewed first. This will merge clusters
+  /// together if they share a non-copyable dependency.
+  void setupWorkList();
+
+  /// Recursive function that assigns the worklist item at \p Idx into a
+  /// partition of \p SP.
+  ///
+  /// \p Depth is the current search depth. When this value is equal to
+  /// \ref MaxDepth, we can no longer recurse.
+  ///
+  /// This function only recurses if there is more than one possible assignment,
+  /// otherwise it is iterative to avoid creating a call stack that is as big as
+  /// \ref WorkList.
+  void pickPartition(unsigned Depth, unsigned Idx, SplitProposal SP);
+
+  /// \return A pair: first element is the PID of the partition that has the
+  /// most similarities with \p Entry, or \ref InvalidPID if no partition was
+  /// found with at least one element in common. The second element is the
+  /// aggregated cost of all dependencies in common between \p Entry and that
+  /// partition.
+  std::pair<unsigned, CostType>
+  findMostSimilarPartition(const WorkListEntry &Entry, const SplitProposal &SP);
+
+  const SplitGraph &SG;
+  unsigned NumParts;
+  SubmitProposalFn SubmitProposal;
+
+  // A Cluster is considered large when its cost, excluding entry points,
+  // exceeds this value.
+  CostType LargeClusterThreshold = 0;
+  unsigned NumProposalsSubmitted = 0;
+  SmallVector<WorkListEntry> WorkList;
+};
 
-        Cost = NewCost;
+/// Sets up the search over \p SG into \p NumParts partitions. Every complete
+/// proposal found is handed to \p SubmitProposal.
+RecursiveSearchSplitting::RecursiveSearchSplitting(
+    const SplitGraph &SG, unsigned NumParts, SubmitProposalFn SubmitProposal)
+    : SG(SG), NumParts(NumParts), SubmitProposal(SubmitProposal) {
+  // Hard cap on the search depth. Each level can double the number of explored
+  // proposals, so depths above ~10 are already slow; refuse anything beyond 16
+  // outright to avoid extreme resource exhaustion or unbounded run time.
+  if (MaxDepth > 16) {
+    report_fatal_error("[amdgpu-split-module] search depth of " +
+                       Twine(MaxDepth) + " is too high!");
+  }
+
+  // A factor of zero disables the "large cluster" heuristic by making the
+  // threshold unreachable.
+  if (LargeFnFactor == 0.0) {
+    LargeClusterThreshold = std::numeric_limits<CostType>::max();
+  } else {
+    const auto AvgPartitionCost = SG.getModuleCost() / NumParts;
+    LargeClusterThreshold = CostType(AvgPartitionCost * LargeFnFactor);
+  }
+
+  LLVM_DEBUG(dbgs() << "[recursive search] large cluster threshold set at "
+                    << LargeClusterThreshold << "\n");
+}
+
+/// Entry point of the search: builds the sorted worklist, then explores
+/// assignments starting from a fresh SplitProposal. Each phase runs inside its
+/// own SplitModuleTimer scope.
+void RecursiveSearchSplitting::run() {
+  {
+    SplitModuleTimer PrepareTimer("recursive_search_prepare",
+                                  "preparing worklist");
+    setupWorkList();
+  }
+
+  {
+    SplitModuleTimer PickTimer("recursive_search_pick", "partitioning");
+    SplitProposal InitialSP(SG, NumParts);
+    pickPartition(/*Depth=*/0, /*Idx=*/0, InitialSP);
+  }
+}
+
+/// Orders worklist entries from most to least expensive. When every cost
+/// metric ties, the cluster contents decide, so entries with identical costs
+/// are still ordered deterministically.
+bool RecursiveSearchSplitting::stableGreaterThan(const WorkListEntry &A,
+                                                 const WorkListEntry &B) const {
+  // Cost-based keys, in decreasing priority.
+  if (A.TotalCost != B.TotalCost)
+    return A.TotalCost > B.TotalCost;
+  if (A.CostExcludingGraphEntryPoints != B.CostExcludingGraphEntryPoints)
+    return A.CostExcludingGraphEntryPoints > B.CostExcludingGraphEntryPoints;
+  if (A.NumNonEntryNodes != B.NumNonEntryNodes)
+    return A.NumNonEntryNodes > B.NumNonEntryNodes;
+
+  // Next key: number of nodes in each cluster.
+  const auto &LHSNodes = A.Cluster;
+  const auto &RHSNodes = B.Cluster;
+  const auto LHSCount = LHSNodes.count();
+  const auto RHSCount = RHSNodes.count();
+  if (LHSCount != RHSCount)
+    return LHSCount > RHSCount;
+
+  // An entry compared with itself must yield false: some standard library
+  // debug checks (e.g. glibc's) assert that !(X < X).
+  if (LHSNodes == RHSNodes)
+    return false;
+
+  // Last resort: the cluster owning the first bit on which the two clusters
+  // disagree is the greater one.
+  BitVector Mismatch = LHSNodes;
+  Mismatch ^= RHSNodes;
+  assert(Mismatch.any());
+  const auto FirstMismatch = *Mismatch.set_bits_begin();
+  return LHSNodes.test(FirstMismatch);
+}
+
+/// Builds \ref WorkList: one entry per equivalence class of graph entry
+/// points, where entry points sharing a non-copyable dependency end up in the
+/// same class. Entries are then sorted most expensive first.
+void RecursiveSearchSplitting::setupWorkList() {
+  // e.g. if A and B are two worklist items, and they both call a non-copyable
+  // dependency C, this does:
+  //    A=C
+  //    B=C
+  // => NodeEC will create a single group (A, B, C) and we create a new
+  // WorkList entry for that group.
+
+  EquivalenceClasses<unsigned> NodeEC;
+  for (const SplitGraph::Node *N : SG.nodes()) {
+    if (!N->isGraphEntryPoint())
+      continue;
+
+    // Every entry point gets a class, even if it has no non-copyable deps.
+    NodeEC.insert(N->getID());
+    N->visitAllDependencies([&](const SplitGraph::Node &Dep) {
+      if (&Dep != N && Dep.isNonCopyable())
+        NodeEC.unionSets(N->getID(), Dep.getID());
+    });
+  }
+
+  // One cluster per class: the union of the dependencies of every entry point
+  // in it. Non-entry members were only unioned in as dependencies of those
+  // entry points, so they are skipped here.
+  for (auto I = NodeEC.begin(), E = NodeEC.end(); I != E; ++I) {
+    if (!I->isLeader())
+      continue;
+
+    BitVector Cluster = SG.createNodesBitVector();
+    for (auto MI = NodeEC.member_begin(I); MI != NodeEC.member_end(); ++MI) {
+      const SplitGraph::Node &N = SG.getNode(*MI);
+      if (N.isGraphEntryPoint())
+        N.getDependencies(Cluster);
+    }
+    WorkList.emplace_back(std::move(Cluster));
+  }
+
+  // Calculate costs and other useful information.
+  for (WorkListEntry &Entry : WorkList) {
+    for (unsigned NodeID : Entry.Cluster.set_bits()) {
+      const SplitGraph::Node &N = SG.getNode(NodeID);
+      const CostType Cost = N.getIndividualCost();
+
+      Entry.TotalCost += Cost;
+      if (!N.isGraphEntryPoint()) {
+        Entry.CostExcludingGraphEntryPoints += Cost;
+        ++Entry.NumNonEntryNodes;
       }
     }
+  }
 
-    sort(BalancingQueue, ComparePartitions);
-  };
+  // Most expensive entries first; stableGreaterThan breaks cost ties
+  // deterministically.
+  sort(WorkList, [this](const WorkListEntry &LHS, const WorkListEntry &RHS) {
+    return stableGreaterThan(LHS, RHS);
+  });
 
-  for (auto &CurFn : WorkList) {
-    // When a function has indirect calls, it must stay in the first partition
-    // alongside every reachable non-entry function. This is a nightmare case
-    // for splitting as it severely limits what we can do.
-    if (CurFn.HasIndirectCall) {
-      SML << "Function with indirect call(s): " << getName(*CurFn.Fn)
-          << " defaulting to P0\n";
-      AssignToPartition(0, CurFn);
-      continue;
+  LLVM_DEBUG({
+    dbgs() << "[recursive search] worklist:\n";
+    for (const auto &[Idx, Entry] : enumerate(WorkList)) {
+      dbgs() << "  - [" << Idx << "]: ";
+      for (unsigned NodeID : Entry.Cluster.set_bits())
+        dbgs() << NodeID << " ";
+      dbgs() << "(total_cost:" << Entry.TotalCost
+             << ", cost_excl_entries:" << Entry.CostExcludingGraphEntryPoints
+             << ")\n";
+    }
+  });
+}
+
+// Assigns WorkList[Idx], WorkList[Idx + 1], ... to partitions of SP. While
+// Depth < MaxDepth and two distinct candidates exist (cheapest partition vs.
+// most similar partition), both are explored recursively. Otherwise exactly one
+// partition is picked and the loop continues iteratively. Every complete
+// assignment is handed to SubmitProposal.
+void RecursiveSearchSplitting::pickPartition(unsigned Depth, unsigned Idx,
+                                             SplitProposal SP) {
+  while (Idx < WorkList.size()) {
+    // Step 1: Determine candidate PIDs.
+    //
+    const WorkListEntry &Entry = WorkList[Idx];
+    const BitVector &Cluster = Entry.Cluster;
+
+    // Default option is to do load-balancing, AKA assign to least pressured
+    // partition.
+    const unsigned CheapestPID = SP.findCheapestPartition();
+    assert(CheapestPID != InvalidPID);
+
+    // Explore assigning to the partition that contains the most dependencies
+    // in common.
+    const auto [MostSimilarPID, SimilarDepsCost] =
+        findMostSimilarPartition(Entry, SP);
+
+    // We can choose to explore only one path if we only have one valid path,
+    // or if we reached maximum search depth and can no longer branch out.
+    unsigned SinglePIDToTry = InvalidPID;
+    if (MostSimilarPID == InvalidPID) // no similar PID found
+      SinglePIDToTry = CheapestPID;
+    else if (MostSimilarPID == CheapestPID) // both landed on the same PID
+      SinglePIDToTry = CheapestPID;
+    else if (Depth >= MaxDepth) {
+      // We have to choose one path. Use a heuristic to guess which one will be
+      // more appropriate.
+      if (Entry.CostExcludingGraphEntryPoints > LargeClusterThreshold) {
+        // Check if the amount of code in common makes it worth it.
+        assert(SimilarDepsCost && Entry.CostExcludingGraphEntryPoints);
+        // Convert before dividing: with integral costs, integer division
+        // would truncate the ratio to either 0 or 1.
+        const double Ratio = static_cast<double>(SimilarDepsCost) /
+                             Entry.CostExcludingGraphEntryPoints;
+        assert(Ratio >= 0.0 && Ratio <= 1.0);
+        // Merge into the most similar partition only when the overlap exceeds
+        // the threshold.
+        if (Ratio > LargeFnOverlapForMerge) {
+          // For debug, just print "L", so we'll see "L3=P3" for instance, which
+          // will mean we reached max depth and chose P3 based on this
+          // heuristic.
+          LLVM_DEBUG(dbgs() << 'L');
+          SinglePIDToTry = MostSimilarPID;
+        } else {
+          // Not enough overlap to be worth merging: fall back to load
+          // balancing. A single path must be chosen here, otherwise the search
+          // would keep branching past MaxDepth.
+          SinglePIDToTry = CheapestPID;
+        }
+      } else
+        SinglePIDToTry = CheapestPID;
     }
 
-    // When a function has non duplicatable dependencies, we have to keep it in
-    // the first partition as well. This is a conservative approach, a
-    // finer-grained approach could keep track of which dependencies are
-    // non-duplicatable exactly and just make sure they're grouped together.
-    if (CurFn.HasNonDuplicatableDependecy) {
-      SML << "Function with externally visible dependency "
-          << getName(*CurFn.Fn) << " defaulting to P0\n";
-      AssignToPartition(0, CurFn);
+    // Step 2: Explore candidates.
+
+    // When we only explore one possible path, and thus branch depth doesn't
+    // increase, do not recurse, iterate instead.
+    if (SinglePIDToTry != InvalidPID) {
+      LLVM_DEBUG(dbgs() << Idx << "=P" << SinglePIDToTry << ' ');
+      // Only one path to explore, don't clone SP, don't increase depth.
+      SP.add(SinglePIDToTry, Cluster);
+      ++Idx;
       continue;
     }
 
-    // Be smart with large functions to avoid duplicating their dependencies.
-    if (CurFn.isLarge(LargeFnThreshold)) {
-      assert(LargeFnOverlapForMerge >= 0.0f && LargeFnOverlapForMerge <= 1.0f);
-      SML << "Large Function: " << getName(*CurFn.Fn)
-          << " - looking for partition with at least "
-          << format("%0.2f", LargeFnOverlapForMerge * 100) << "% overlap\n";
-
-      bool Assigned = false;
-      for (const auto &[PID, Fns] : enumerate(Partitions)) {
-        float Overlap = calculateOverlap(CurFn.Dependencies, Fns);
-        SML << "  => " << format("%0.2f", Overlap * 100) << "% overlap with P"
-            << PID << '\n';
-        if (Overlap > LargeFnOverlapForMerge) {
-          SML << "  selecting P" << PID << '\n';
-          AssignToPartition(PID, CurFn);
-          Assigned = true;
-        }
-      }
+    assert(MostSimilarPID != InvalidPID);
 
-      if (Assigned)
-        continue;
+    // We explore multiple paths: recurse at increased depth, then stop this
+    // function.
+
+    LLVM_DEBUG(dbgs() << '\n');
+
+    // lb = load balancing = put in cheapest partition
+    {
+      SplitProposal BranchSP = SP;
+      LLVM_DEBUG(dbgs().indent(Depth)
+                 << " [lb] " << Idx << "=P" << CheapestPID << "? ");
+      BranchSP.add(CheapestPID, Cluster);
+      pickPartition(Depth + 1, Idx + 1, BranchSP);
+    }
+
+    // ms = most similar = put in partition with the most in common
+    {
+      SplitProposal BranchSP = SP;
+      LLVM_DEBUG(dbgs().indent(Depth)
+                 << " [ms] " << Idx << "=P" << MostSimilarPID << "? ");
+      BranchSP.add(MostSimilarPID, Cluster);
+      pickPartition(Depth + 1, Idx + 1, BranchSP);
     }
 
-    // Normal "load-balancing", assign to partition with least pressure.
-    auto [PID, CurCost] = BalancingQueue.back();
-    AssignToPartition(PID, CurFn);
+    return;
   }
 
-  if (SML) {
-    CostType ModuleCostOr1 = ModuleCost ? ModuleCost : 1;
-    for (const auto &[Idx, Part] : enumerate(Partitions)) {
-      CostType Cost = 0;
-      for (auto *Fn : Part)
-        Cost += FnCosts.at(Fn);
-      SML << "P" << Idx << " has a total cost of " << Cost << " ("
-          << format("%0.2f", (float(Cost) / ModuleCostOr1) * 100)
-          << "% of source module)\n";
+  // Step 3: If we assigned all WorkList items, submit the proposal.
+
+  assert(Idx == WorkList.size());
+  assert(NumProposalsSubmitted <= (2u << MaxDepth) &&
+         "Search got out of bounds?");
+  SP.setName("recursive_search (depth=" + std::to_string(Depth) + ") #" +
+             std::to_string(NumProposalsSubmitted++));
+  LLVM_DEBUG(dbgs() << '\n');
+  SubmitProposal(SP);
+}
+
+// Returns the partition of SP whose already-assigned nodes overlap the most
+// (by aggregated cost) with Entry's cluster, together with that overlap cost.
+// Returns {InvalidPID, 0} when Entry has no non-entry nodes or no partition
+// shares any node with it.
+std::pair<unsigned, CostType>
+RecursiveSearchSplitting::findMostSimilarPartition(const WorkListEntry &Entry,
+                                                   const SplitProposal &SP) {
+  // A cluster made only of graph entry points has nothing to share.
+  if (!Entry.NumNonEntryNodes)
+    return {InvalidPID, 0};
+
+  // We take the partition that is the most similar using Cost as a metric.
+  // So we take the set of nodes in common, compute their aggregated cost, and
+  // pick the partition with the highest cost in common.
+  unsigned ChosenPID = InvalidPID;
+  CostType ChosenCost = 0;
+  for (unsigned PID = 0; PID < NumParts; ++PID) {
+    // Intersect a copy of the partition's node set with the cluster.
+    BitVector BV = SP[PID];
+    BV &= Entry.Cluster; // FIXME: binary & between BitVectors didn't work here.
+
+    if (BV.none())
+      continue;
+
+    const CostType Cost = SG.calculateCost(BV);
+
+    // Ties on cost go to the higher PID so the choice is deterministic.
+    if (ChosenPID == InvalidPID || ChosenCost < Cost ||
+        (ChosenCost == Cost && PID > ChosenPID)) {
+      ChosenPID = PID;
+      ChosenCost = Cost;
     }
+  }
+
+  return {ChosenPID, ChosenCost};
+}
+
+//===----------------------------------------------------------------------===//
+// DOTGraph Printing Support
+//===----------------------------------------------------------------------===//
+
+/// Projects an edge onto its destination node. GraphTraits uses this to turn
+/// an iterator over outgoing edges into an iterator over child nodes.
+const SplitGraph::Node *mapEdgeToDst(const SplitGraph::Edge *E) {
+  const SplitGraph::Node *Dst = E->Dst;
+  return Dst;
+}
+
+/// Iterates over the destination nodes of a range of SplitGraph edges.
+using SplitGraphEdgeDstIterator =
+    mapped_iterator<SplitGraph::edges_iterator, decltype(&mapEdgeToDst)>;
+
+} // namespace
+
+/// GraphTraits specialization that lets generic graph utilities (here, the DOT
+/// printing support) walk a SplitGraph. A node's children are the destinations
+/// of its outgoing edges, obtained through mapEdgeToDst.
+template <> struct GraphTraits<SplitGraph> {
+  using NodeRef = const SplitGraph::Node *;
+  using nodes_iterator = SplitGraph::nodes_iterator;
+  using ChildIteratorType = SplitGraphEdgeDstIterator;
+
+  using EdgeRef = const SplitGraph::Edge *;
+  using ChildEdgeIteratorType = SplitGraph::edges_iterator;
+
+  // Traversal rooted at a node starts at that node itself.
+  static NodeRef getEntryNode(NodeRef N) { return N; }
 
-    SML << "--Partitioning Done--\n\n";
+  // Child iteration maps each outgoing edge to its destination node.
+  static ChildIteratorType child_begin(NodeRef Ref) {
+    return {Ref->outgoing_edges().begin(), mapEdgeToDst};
+  }
+  static ChildIteratorType child_end(NodeRef Ref) {
+    return {Ref->outgoing_edges().end(), mapEdgeToDst};
   }
 
-  // Check no functions were missed.
-#ifndef NDEBUG
-  DenseSet<const Function *> AllFunctions;
-  for (const auto &Part : Partitions)
-    AllFunctions.insert(Part.begin(), Part.end());
+  // Whole-graph iteration visits every node returned by SplitGraph::nodes().
+  static nodes_iterator nodes_begin(const SplitGraph &G) {
+    return G.nodes().begin();
+  }
+  static nodes_iterator nodes_end(const SplitGraph &G) {
+    return G.nodes().end();
+  }
+};
 
-  for (auto &Fn : M) {
-    if (!Fn.isDeclaration() && !AllFunctions.contains(&Fn)) {
-      assert(AllFunctions.contains(&Fn) && "Missed a function?!");
+template <> struct DOTGraphTraits<SplitGraph> : public DefaultDOTGraphTraits {
+  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
+
+  static std::string getGraphName(const SplitGraph &SG) {
+    return SG.getModule().getName().str();
+  }
+
+  std::string getNodeLabel(const SplitGraph::Node *N, const SplitGraph &SG) {
+    return N->getName().str();
+  }
+
+  static std::string getNodeDescription(const SplitGraph::Node *N,
+                                        const SplitGraph &SG) {
+    std::string Result;
+    if (N->isEntryFunctionCC())
+      Result += "entry-fn-cc ";
+    if (N->isNonCopyable())
+      Result += "non-copyable ";
+    Result += "cost:" + std::to_string(N->getIndividualCost());
+    return Result;
+  }
+
+  static std::string getNodeAttributes(const SplitGraph::Node *N,
+                                       const SplitGraph &SG) {
+    return N->hasAnyIncomingEdges() ? "" : "color=\"red\"";
+  }
+
+  static std::string getEdgeAttributes(const SplitGraph::Node *N,
+                                       SplitGraphEdgeDstIterator EI,
+                                       const SplitGraph &SG) {
+
+    switch ((*EI.getCurrent())->Kind) {
+    case SplitGraph::EdgeKind::DirectCall:
+      return "";
+    case SplitGraph::EdgeKind::IndirectCall:
+      return "style=\"dashed\"";
     }
+    llvm_unreachable("Unknown SplitGraph::EdgeKind enum");
----------------
Pierre-vh wrote:

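Added a fix for the MSVC warning: an `llvm_unreachable` after the fully covered `switch` in `getEdgeAttributes`, so that every control path returns a value.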

https://github.com/llvm/llvm-project/pull/107076


More information about the llvm-commits mailing list