[llvm] [AMDGPU] Graph-based Module Splitting Rewrite (PR #104763)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 19 07:05:14 PDT 2024
================
@@ -278,351 +253,1042 @@ static bool canBeIndirectlyCalled(const Function &F) {
/*IgnoreCastedDirectCall=*/true);
}
-/// When a function or any of its callees performs an indirect call, this
-/// takes over \ref addAllDependencies and adds all potentially callable
-/// functions to \p Fns so they can be counted as dependencies of the function.
+//===----------------------------------------------------------------------===//
+// Graph-based Module Representation
+//===----------------------------------------------------------------------===//
+
+/// AMDGPUSplitModule's view of the source Module, as a graph of all components
+/// that can be split into different modules.
///
-/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
-/// presence of an indirect call, the function's resource usage is the same as
-/// the most expensive function in the module.
-/// \param M The module.
-/// \param Fns[out] Resulting list of functions.
-static void addAllIndirectCallDependencies(const Module &M,
- DenseSet<const Function *> &Fns) {
- for (const auto &Fn : M) {
- if (canBeIndirectlyCalled(Fn))
- Fns.insert(&Fn);
+/// The most trivial instance of this graph is just the CallGraph of the module,
+/// but the graph is not guaranteed to be strictly equal to the call graph. It
+/// currently always is, but it is designed to eventually allow us to create
+/// abstract nodes, or nodes for different entities such as global variables or
+/// any other meaningful constraint we must consider.
+///
+/// The graph is only mutable by this class, and is generally not modified
+/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
+/// mutate it.
+class SplitGraph {
+public:
+ class Node;
+
+ enum class EdgeKind : uint8_t {
+ /// The nodes are related through a direct call. This is a "strong" edge as
+ /// it means the Src will directly reference the Dst.
+ DirectCall,
+ /// The nodes are related through an indirect call.
+ /// This is a "weaker" edge and is only considered when traversing the graph
+ /// starting from a kernel. We need this edge for resource usage analysis.
+ ///
+ /// The reason why we have this edge in the first place is due to how
+ /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
+ /// the resource usage of the kernel containing the indirect call is the
+ /// max resource usage of all functions that can be indirectly called.
+ IndirectCall,
+ };
+
+ /// An edge between two nodes. Edges are directional, and tagged with a
+ /// "kind".
+ struct Edge {
+ Edge(Node *Src, Node *Dst, EdgeKind Kind)
+ : Src(Src), Dst(Dst), Kind(Kind) {}
+
+ Node *Src; ///< Source
+ Node *Dst; ///< Destination
+ EdgeKind Kind;
+ };
+
+ /// Edge lists only hold pointers; the Edge objects are allocated from
+ /// EdgesPool.
+ using EdgesVec = SmallVector<const Edge *, 0>;
+ using edges_iterator = EdgesVec::const_iterator;
+ using nodes_iterator = const Node *const *;
+
+ SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
+ CostType ModuleCost)
+ : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}
+
+ /// Populates the nodes and edges of the graph from \p CG. The graph is
+ /// generally not modified after this runs.
+ void buildGraph(CallGraph &CG);
+
+#ifndef NDEBUG
+ /// Graph verification; only declared when NDEBUG is not defined.
+ void verifyGraph() const;
+#endif
+
+ bool empty() const { return Nodes.empty(); }
+ iterator_range<nodes_iterator> nodes() const {
+ return {Nodes.begin(), Nodes.end()};
}
-}
+ /// \returns the node at index \p ID of the node list.
+ const Node &getNode(unsigned ID) const { return *Nodes[ID]; }
-/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
-/// callee until all reachable functions have been gathered.
+ unsigned getNumNodes() const { return Nodes.size(); }
+ /// \returns a BitVector with one bit per node, with every bit cleared.
+ BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }
+
+ const Module &getModule() const { return M; }
+
+ CostType getModuleCost() const { return ModuleCost; }
+ CostType getCost(const Function &F) const { return CostMap.at(&F); }
+
+ /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
+ /// IDs).
+ CostType calculateCost(const BitVector &BV) const;
+
+private:
+ /// Retrieves the node for \p GV in \p Cache, or creates a new node for it and
+ /// updates \p Cache.
+ Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
+ const GlobalValue &GV);
+
+ /// Create a new edge between two nodes and add it to both nodes.
+ const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);
+
+ const Module &M;
+ const FunctionsCostMap &CostMap;
+ CostType ModuleCost;
+
+ // Final list of nodes with stable ordering.
+ SmallVector<Node *> Nodes;
+
+ SpecificBumpPtrAllocator<Node> NodesPool;
+
+ // Edges are trivially destructible objects, so as a small optimization we
+ // use a BumpPtrAllocator which avoids destructor calls but also makes
+ // allocation faster.
+ static_assert(
+ std::is_trivially_destructible_v<Edge>,
+ "Edge must be trivially destructible to use the BumpPtrAllocator");
+ BumpPtrAllocator EdgesPool;
+};
+
+/// Nodes in the SplitGraph contain both incoming, and outgoing edges.
+/// Incoming edges have this node as their Dst, and Outgoing ones have this node
+/// as their Src.
///
-/// \param SML Log Helper
-/// \param CG Call graph for \p Fn's module.
-/// \param Fn Current function to look at.
-/// \param Fns[out] Resulting list of functions.
-/// \param OnlyDirect Whether to only consider direct callees.
-/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
-/// point, either in \p Fn or in one of the function it calls. When that
-/// happens, we fall back to adding all callable functions inside \p Fn's module
-/// to \p Fns.
-static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
- const Function &Fn,
- DenseSet<const Function *> &Fns, bool OnlyDirect,
- bool &HadIndirectCall) {
- assert(!Fn.isDeclaration());
-
- const Module &M = *Fn.getParent();
- SmallVector<const Function *> WorkList({&Fn});
+/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
+/// feedback on how two nodes are related, and in which direction they are
+/// related, which is valuable information to make splitting decisions.
+///
+/// Nodes are fundamentally abstract, and any consumers of the graph should
+/// treat them as such. While a node will be a function most of the time, we
+/// could also create nodes for any other reason. In the future, we could have
+/// single nodes for multiple functions, or nodes for GVs, etc.
+class SplitGraph::Node {
+ friend class SplitGraph;
+
+public:
+ Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
+ bool IsNonCopyable)
+ : ID(ID), GV(GV), IndividualCost(IndividualCost),
+ IsNonCopyable(IsNonCopyable), IsKernel(false), IsEntry(false) {
+ // Only Functions can be kernels; other GlobalValues keep IsKernel = false.
+ if (auto *Fn = dyn_cast<Function>(&GV))
+ IsKernel = ::llvm::isKernel(Fn);
+ }
+
+ /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
+ /// nodes in the graph. This ID can be used as an index in a BitVector.
+ unsigned getID() const { return ID; }
+
+ /// \returns the Function this node represents. This uses cast<>, so the
+ /// node's GlobalValue must be a Function.
+ const Function &getFunction() const { return cast<Function>(GV); }
+
+ /// \returns the cost to import this component into a given module, not
+ /// accounting for any dependencies that may need to be imported as well.
+ CostType getIndividualCost() const { return IndividualCost; }
+
+ bool isNonCopyable() const { return IsNonCopyable; }
+ bool isKernel() const { return IsKernel; }
+
+ /// \returns whether this is an entry point in the graph. Entry points are
+ /// defined as follows: if you take all entry points in the graph, and iterate
+ /// their dependencies, you are guaranteed to visit all nodes in the graph at
+ /// least once.
+ bool isGraphEntryPoint() const { return IsEntry; }
+
+ std::string getName() const { return GV.getName().str(); }
+
+ bool hasAnyIncomingEdges() const { return !IncomingEdges.empty(); }
+ bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
+ return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
+ }
+
+ bool hasAnyOutgoingEdges() const { return !OutgoingEdges.empty(); }
+ bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
+ return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
+ }
+
+ iterator_range<edges_iterator> incoming_edges() const {
+ return IncomingEdges;
+ }
+
+ iterator_range<edges_iterator> outgoing_edges() const {
+ return OutgoingEdges;
+ }
+
+ /// Only kernels follow IndirectCall edges when their dependencies are
+ /// traversed.
+ bool shouldFollowIndirectCalls() const { return isKernel(); }
+
+ /// Visit all children of this node in a recursive fashion. Also visits Self.
+ /// If \ref shouldFollowIndirectCalls returns false, then this only follows
+ /// DirectCall edges.
+ ///
+ /// \param Visitor Visitor Function.
+ void visitAllDependencies(std::function<void(const Node &)> Visitor) const;
+
+ /// Adds the dependencies of this node in \p BV by setting the bit
+ /// corresponding to each node.
+ ///
+ /// Implemented using \ref visitAllDependencies, hence it follows the same
+ /// rules regarding dependencies traversal.
+ ///
+ /// \param[out] BV The bitvector where the bits should be set.
+ void setDependenciesBits(BitVector &BV) const {
+ visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
+ }
+
+ /// Uses \ref visitAllDependencies to aggregate the individual cost of this
+ /// node and all of its dependencies.
+ ///
+ /// This is cached.
+ CostType getFullCost() const;
+
+private:
+ void markAsEntry() { IsEntry = true; }
+
+ unsigned ID;
+ const GlobalValue &GV;
+ CostType IndividualCost;
+ bool IsNonCopyable : 1;
+ bool IsKernel : 1;
+ bool IsEntry : 1;
+
+ // TODO: Cache dependencies as well?
+ // Cache for getFullCost(); presumably 0 means "not computed yet" - confirm
+ // against the getFullCost implementation.
+ mutable CostType FullCost = 0;
+
+ // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
+ // together)
+ EdgesVec IncomingEdges;
+ EdgesVec OutgoingEdges;
+};
+
+/// Iterative worklist traversal of this node and every node reachable from it
+/// through outgoing edges. Each node is passed to \p Visitor at most once.
+/// IndirectCall edges are only followed when this (root) node's
+/// \ref shouldFollowIndirectCalls returns true.
+void SplitGraph::Node::visitAllDependencies(
+ std::function<void(const Node &)> Visitor) const {
+ const bool FollowIndirect = shouldFollowIndirectCalls();
+ // FIXME: If this can access SplitGraph in the future, use a BitVector
+ // instead.
+ DenseSet<const Node *> Seen;
+ SmallVector<const Node *, 8> WorkList({this});
while (!WorkList.empty()) {
- const auto &CurFn = *WorkList.pop_back_val();
- assert(!CurFn.isDeclaration());
+ const Node *CurN = WorkList.pop_back_val();
+ // A node may be pushed several times (reachable through several paths);
+ // only the first pop visits it.
+ if (!Seen.insert(CurN).second)
+ continue;
+
+ Visitor(*CurN);
- // Scan for an indirect call. If such a call is found, we have to
- // conservatively assume this can call all non-entrypoint functions in the
- // module.
+ for (const Edge *E : CurN->outgoing_edges()) {
+ // The root node alone decides whether indirect edges are followed.
+ if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
+ continue;
+ WorkList.push_back(E->Dst);
+ }
+ }
+}
+
+CostType SplitGraph::Node::getFullCost() const {
----------------
arsenm wrote:
Missing newline
https://github.com/llvm/llvm-project/pull/104763
More information about the llvm-commits
mailing list