[llvm] [AMDGPU] Graph-based Module Splitting Rewrite (PR #104763)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 21 00:09:17 PDT 2024


================
@@ -278,351 +248,1043 @@ static bool canBeIndirectlyCalled(const Function &F) {
                            /*IgnoreCastedDirectCall=*/true);
 }
 
-/// When a function or any of its callees performs an indirect call, this
-/// takes over \ref addAllDependencies and adds all potentially callable
-/// functions to \p Fns so they can be counted as dependencies of the function.
+//===----------------------------------------------------------------------===//
+// Graph-based Module Representation
+//===----------------------------------------------------------------------===//
+
+/// AMDGPUSplitModule's view of the source Module, as a graph of all components
+/// that can be split into different modules.
 ///
-/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
-/// presence of an indirect call, the function's resource usage is the same as
-/// the most expensive function in the module.
-/// \param M    The module.
-/// \param Fns[out] Resulting list of functions.
-static void addAllIndirectCallDependencies(const Module &M,
-                                           DenseSet<const Function *> &Fns) {
-  for (const auto &Fn : M) {
-    if (canBeIndirectlyCalled(Fn))
-      Fns.insert(&Fn);
+/// The most trivial instance of this graph is just the CallGraph of the module,
+/// but it is not guaranteed that the graph is strictly equal to the CG. It
+/// currently always is but it's designed in a way that would eventually allow
+/// us to create abstract nodes, or nodes for different entities such as global
+/// variables or any other meaningful constraint we must consider.
+///
+/// The graph is only mutable by this class, and is generally not modified
+/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
+/// mutate it.
+class SplitGraph {
+public:
+  class Node;
+
+  enum class EdgeKind : uint8_t {
+    /// The nodes are related through a direct call. This is a "strong" edge as
+    /// it means the Src will directly reference the Dst.
+    DirectCall,
+    /// The nodes are related through an indirect call.
+    /// This is a "weaker" edge and is only considered when traversing the graph
+    /// starting from a kernel. We need this edge for resource usage analysis.
+    ///
+    /// The reason why we have this edge in the first place is due to how
+    /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
+    /// the resource usage of the kernel containing the indirect call is the
+    /// max resource usage of all functions that can be indirectly called.
+    IndirectCall,
+  };
+
+  /// An edge between two nodes. Edges are directional, and tagged with a
+  /// "kind".
+  struct Edge {
+    Edge(Node *Src, Node *Dst, EdgeKind Kind)
+        : Src(Src), Dst(Dst), Kind(Kind) {}
+
+    Node *Src; ///< Source
+    Node *Dst; ///< Destination
+    EdgeKind Kind;
+  };
+
+  using EdgesVec = SmallVector<const Edge *, 0>;
+  using edges_iterator = EdgesVec::const_iterator;
+  using nodes_iterator = const Node *const *;
+
+  SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
+             CostType ModuleCost)
+      : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}
+
+  void buildGraph(CallGraph &CG);
+
+#ifndef NDEBUG
+  void verifyGraph() const;
+#endif
+
+  bool empty() const { return Nodes.empty(); }
+  const iterator_range<nodes_iterator> nodes() const {
+    return {Nodes.begin(), Nodes.end()};
   }
-}
+  const Node &getNode(unsigned ID) const { return *Nodes[ID]; }
+
+  unsigned getNumNodes() const { return Nodes.size(); }
+  BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }
+
+  const Module &getModule() const { return M; }
+
+  CostType getModuleCost() const { return ModuleCost; }
+  CostType getCost(const Function &F) const { return CostMap.at(&F); }
+
+  /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
+  /// IDs).
+  CostType calculateCost(const BitVector &BV) const;
+
+private:
+  /// Retrieves the node for \p GV in \p Cache, or creates a new node for it and
+  /// updates \p Cache.
+  Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
+                const GlobalValue &GV);
 
-/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
-/// callee until all reachable functions have been gathered.
+  // Create a new edge between two nodes and add it to both nodes.
+  const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);
+
+  const Module &M;
+  const FunctionsCostMap &CostMap;
+  CostType ModuleCost;
+
+  // Final list of nodes with stable ordering.
+  SmallVector<Node *> Nodes;
+
+  SpecificBumpPtrAllocator<Node> NodesPool;
+
+  // Edges are trivially destructible objects, so as a small optimization we
+  // use a BumpPtrAllocator which avoids destructor calls but also makes
+  // allocation faster.
+  static_assert(
+      std::is_trivially_destructible_v<Edge>,
+      "Edge must be trivially destructible to use the BumpPtrAllocator");
+  BumpPtrAllocator EdgesPool;
+};
+
+/// Nodes in the SplitGraph contain both incoming, and outgoing edges.
+/// Incoming edges have this node as their Dst, and Outgoing ones have this node
+/// as their Src.
+///
+/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
+/// feedback on how two nodes are related, and in which direction they are
+/// related, which is valuable information to make splitting decisions.
 ///
-/// \param SML Log Helper
-/// \param CG Call graph for \p Fn's module.
-/// \param Fn Current function to look at.
-/// \param Fns[out] Resulting list of functions.
-/// \param OnlyDirect Whether to only consider direct callees.
-/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
-/// point, either in \p Fn or in one of the function it calls. When that
-/// happens, we fall back to adding all callable functions inside \p Fn's module
-/// to \p Fns.
-static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
-                               const Function &Fn,
-                               DenseSet<const Function *> &Fns, bool OnlyDirect,
-                               bool &HadIndirectCall) {
-  assert(!Fn.isDeclaration());
-
-  const Module &M = *Fn.getParent();
-  SmallVector<const Function *> WorkList({&Fn});
+/// Nodes are fundamentally abstract, and any consumers of the graph should
+/// treat them as such. While a node will be a function most of the time, we
+/// could also create nodes for any other reason. In the future, we could have
+/// single nodes for multiple functions, or nodes for GVs, etc.
+class SplitGraph::Node {
+  friend class SplitGraph;
+
+public:
+  Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
+       bool IsNonCopyable)
+      : ID(ID), GV(GV), IndividualCost(IndividualCost),
+        IsNonCopyable(IsNonCopyable), IsGraphEntry(false) {
+    if (auto *Fn = dyn_cast<Function>(&GV))
+      IsEntryFnCC = AMDGPU::isEntryFunctionCC(Fn->getCallingConv());
+    else
+      IsEntryFnCC = false;
+  }
+
+  /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
+  /// nodes in the graph. This ID can be used as an index in a BitVector.
+  unsigned getID() const { return ID; }
+
+  const Function &getFunction() const { return cast<Function>(GV); }
+
+  /// \returns the cost to import this component into a given module, not
+  /// accounting for any dependencies that may need to be imported as well.
+  CostType getIndividualCost() const { return IndividualCost; }
+
+  bool isNonCopyable() const { return IsNonCopyable; }
+  bool isEntryFunctionCC() const { return IsEntryFnCC; }
+
+  /// \returns whether this is an entry point in the graph. Entry points are
+  /// defined as follows: if you take all entry points in the graph, and iterate
+  /// their dependencies, you are guaranteed to visit all nodes in the graph at
+  /// least once.
+  bool isGraphEntryPoint() const { return IsGraphEntry; }
+
+  StringRef getName() const { return GV.getName(); }
+
+  bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
+  bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
+    return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
+  }
+
+  bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
+  bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
+    return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
+  }
+
+  iterator_range<edges_iterator> incoming_edges() const {
+    return IncomingEdges;
+  }
+
+  iterator_range<edges_iterator> outgoing_edges() const {
+    return OutgoingEdges;
+  }
+
+  bool shouldFollowIndirectCalls() const { return isEntryFunctionCC(); }
+
+  /// Visit all children of this node in a recursive fashion. Also visits Self.
+  /// If \ref shouldFollowIndirectCalls returns false, then this only follows
+  /// DirectCall edges.
+  ///
+  /// \param Visitor Visitor Function.
+  void visitAllDependencies(std::function<void(const Node &)> Visitor) const;
+
+  /// Adds the dependencies of this node in \p BV by setting the bit
+  /// corresponding to each node.
+  ///
+  /// Implemented using \ref visitAllDependencies, hence it follows the same
+  /// rules regarding dependencies traversal.
+  ///
+  /// \param[out] BV The bitvector where the bits should be set.
+  void setDependenciesBits(BitVector &BV) const {
+    visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
+  }
+
+  /// Uses \ref visitAllDependencies to aggregate the individual cost of this
+  /// node and all of its dependencies.
+  ///
+  /// This is cached.
+  CostType getFullCost() const;
+
+private:
+  void markAsGraphEntry() { IsGraphEntry = true; }
+
+  unsigned ID;
+  const GlobalValue &GV;
+  CostType IndividualCost;
+  bool IsNonCopyable : 1;
+  bool IsEntryFnCC : 1;
+  bool IsGraphEntry : 1;
+
+  // TODO: Cache dependencies as well?
+  mutable CostType FullCost = 0;
+
+  // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
+  // together)
+  EdgesVec IncomingEdges;
+  EdgesVec OutgoingEdges;
+};
+
+void SplitGraph::Node::visitAllDependencies(
+    std::function<void(const Node &)> Visitor) const {
+  const bool FollowIndirect = shouldFollowIndirectCalls();
+  // FIXME: If this can access SplitGraph in the future, use a BitVector
+  // instead.
+  DenseSet<const Node *> Seen;
+  SmallVector<const Node *, 8> WorkList({this});
   while (!WorkList.empty()) {
-    const auto &CurFn = *WorkList.pop_back_val();
-    assert(!CurFn.isDeclaration());
+    const Node *CurN = WorkList.pop_back_val();
+    if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
+      continue;
+
+    Visitor(*CurN);
 
-    // Scan for an indirect call. If such a call is found, we have to
-    // conservatively assume this can call all non-entrypoint functions in the
-    // module.
+    for (const Edge *E : CurN->outgoing_edges()) {
+      if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
+        continue;
+      WorkList.push_back(E->Dst);
+    }
+  }
+}
 
-    for (auto &CGEntry : *CG[&CurFn]) {
+CostType SplitGraph::Node::getFullCost() const {
+  if (FullCost)
+    return FullCost;
+
+  assert(FullCost == 0);
+  visitAllDependencies(
+      [&](const Node &N) { FullCost += N.getIndividualCost(); });
+  return FullCost;
+}
+
+void SplitGraph::buildGraph(CallGraph &CG) {
+  SplitModuleTimer SMT("buildGraph", "graph construction");
+  LLVM_DEBUG(
+      dbgs()
+      << "[build graph] constructing graph representation of the input\n");
+
+  // We build the graph by just iterating all functions in the module and
+  // working on their direct callees. At the end, all nodes should be linked
+  // together as expected.
+  DenseMap<const GlobalValue *, Node *> Cache;
+  SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
+  for (const Function &Fn : M) {
+    if (Fn.isDeclaration())
+      continue;
+
+    // Look at direct callees and create the necessary edges in the graph.
+    bool HasIndirectCall = false;
+    Node &N = getNode(Cache, Fn);
+    for (auto &CGEntry : *CG[&Fn]) {
       auto *CGNode = CGEntry.second;
       auto *Callee = CGNode->getFunction();
       if (!Callee) {
-        if (OnlyDirect)
-          continue;
-
-        // Functions have an edge towards CallsExternalNode if they're external
-        // declarations, or if they do an indirect call. As we only process
-        // definitions here, we know this means the function has an indirect
-        // call. We then have to conservatively assume this can call all
-        // non-entrypoint functions in the module.
-        if (CGNode != CG.getCallsExternalNode())
-          continue; // this is another function-less node we don't care about.
-
-        SML << "Indirect call detected in " << getName(CurFn)
-            << " - treating all non-entrypoint functions as "
-               "potential dependencies\n";
-
-        // TODO: Print an ORE as well ?
-        addAllIndirectCallDependencies(M, Fns);
-        HadIndirectCall = true;
+        // TODO: Don't consider inline assembly as indirect calls.
+        if (CGNode == CG.getCallsExternalNode())
+          HasIndirectCall = true;
         continue;
       }
 
-      if (Callee->isDeclaration())
-        continue;
+      if (!Callee->isDeclaration())
+        createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);
+    }
+
+    // Keep track of this function if it contains an indirect call and/or if it
+    // can be indirectly called.
+    if (HasIndirectCall) {
+      LLVM_DEBUG(dbgs() << "indirect call found in " << Fn.getName() << "\n");
+      FnsWithIndirectCalls.push_back(&Fn);
+    }
 
-      auto [It, Inserted] = Fns.insert(Callee);
-      if (Inserted)
-        WorkList.push_back(Callee);
+    if (canBeIndirectlyCalled(Fn))
+      IndirectlyCallableFns.push_back(&Fn);
+  }
+
+  // Post-process functions with indirect calls.
+  for (const Function *Fn : FnsWithIndirectCalls) {
+    for (const Function *Candidate : IndirectlyCallableFns) {
+      Node &Src = getNode(Cache, *Fn);
+      Node &Dst = getNode(Cache, *Candidate);
+      createEdge(Src, Dst, EdgeKind::IndirectCall);
     }
   }
+
+  // Now, find all entry points.
+  SmallVector<Node *, 16> CandidateEntryPoints;
+  BitVector NodesReachableByKernels = createNodesBitVector();
+  for (Node *N : Nodes) {
+    // Functions with an Entry CC are always graph entry points too.
+    if (N->isEntryFunctionCC()) {
+      N->markAsGraphEntry();
+      N->setDependenciesBits(NodesReachableByKernels);
+    } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
+      CandidateEntryPoints.push_back(N);
+  }
+
+  for (Node *N : CandidateEntryPoints) {
+    // This can be another entry point if it's not reachable by a kernel
+    // TODO: We could sort all of the possible new entries in a stable order
+    // (e.g. by cost), then consume them one by one until
+    // NodesReachableByKernels is all 1s. It'd allow us to avoid
+    // considering some nodes as non-entries in some specific cases.
+    if (!NodesReachableByKernels.test(N->getID()))
+      N->markAsGraphEntry();
+  }
+
+#ifndef NDEBUG
+  verifyGraph();
+#endif
 }
 
-/// Contains information about a function and its dependencies.
-/// This is a splitting root. The splitting algorithm works by
-/// assigning these to partitions.
-struct FunctionWithDependencies {
-  FunctionWithDependencies(SplitModuleLogger &SML, CallGraph &CG,
-                           const DenseMap<const Function *, CostType> &FnCosts,
-                           const Function *Fn)
-      : Fn(Fn) {
-    // When Fn is not a kernel, we don't need to collect indirect callees.
-    // Resource usage analysis is only performed on kernels, and we collect
-    // indirect callees for resource usage analysis.
-    addAllDependencies(SML, CG, *Fn, Dependencies,
-                       /*OnlyDirect*/ !isEntryPoint(Fn), HasIndirectCall);
-    TotalCost = FnCosts.at(Fn);
-    for (const auto *Dep : Dependencies) {
-      TotalCost += FnCosts.at(Dep);
-
-      // We cannot duplicate functions with external linkage, or functions that
-      // may be overriden at runtime.
-      HasNonDuplicatableDependecy |=
-          (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
+#ifndef NDEBUG
+void SplitGraph::verifyGraph() const {
----------------
Pierre-vh wrote:

Should it also be enabled in release builds, or is it fine if it's guarded by NDEBUG?

https://github.com/llvm/llvm-project/pull/104763


More information about the llvm-commits mailing list