[llvm] [AMDGPU] Graph-based Module Splitting Rewrite (PR #104763)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 19 07:05:16 PDT 2024
================
@@ -278,351 +253,1042 @@ static bool canBeIndirectlyCalled(const Function &F) {
/*IgnoreCastedDirectCall=*/true);
}
-/// When a function or any of its callees performs an indirect call, this
-/// takes over \ref addAllDependencies and adds all potentially callable
-/// functions to \p Fns so they can be counted as dependencies of the function.
+//===----------------------------------------------------------------------===//
+// Graph-based Module Representation
+//===----------------------------------------------------------------------===//
+
+/// AMDGPUSplitModule's view of the source Module, as a graph of all components
+/// that can be split into different modules.
///
-/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
-/// presence of an indirect call, the function's resource usage is the same as
-/// the most expensive function in the module.
-/// \param M The module.
-/// \param Fns[out] Resulting list of functions.
-static void addAllIndirectCallDependencies(const Module &M,
- DenseSet<const Function *> &Fns) {
- for (const auto &Fn : M) {
- if (canBeIndirectlyCalled(Fn))
- Fns.insert(&Fn);
+/// The most trivial instance of this graph is just the CallGraph of the module,
+/// but it is not guaranteed that the graph is strictly equal to the CG. It
+/// currently always is but it's designed in a way that would eventually allow
+/// us to create abstract nodes, or nodes for different entities such as global
+/// variables or any other meaningful constraint we must consider.
+///
+/// The graph is only mutable by this class, and is generally not modified
+/// after \ref SplitGraph::buildGraph runs. No consumers of the graph can
+/// mutate it.
+class SplitGraph {
+public:
+ class Node;
+
+ enum class EdgeKind : uint8_t {
+ /// The nodes are related through a direct call. This is a "strong" edge as
+ /// it means the Src will directly reference the Dst.
+ DirectCall,
+ /// The nodes are related through an indirect call.
+ /// This is a "weaker" edge and is only considered when traversing the graph
+ /// starting from a kernel. We need this edge for resource usage analysis.
+ ///
+ /// The reason why we have this edge in the first place is due to how
+ /// AMDGPUResourceUsageAnalysis works. In the presence of an indirect call,
+ /// the resource usage of the kernel containing the indirect call is the
+ /// max resource usage of all functions that can be indirectly called.
+ IndirectCall,
+ };
+
+ /// An edge between two nodes. Edges are directional, and tagged with a
+ /// "kind".
+ struct Edge {
+ Edge(Node *Src, Node *Dst, EdgeKind Kind)
+ : Src(Src), Dst(Dst), Kind(Kind) {}
+
+ Node *Src; ///< Source
+ Node *Dst; ///< Destination
+ EdgeKind Kind;
+ };
+
+ using EdgesVec = SmallVector<const Edge *, 0>;
+ using edges_iterator = EdgesVec::const_iterator;
+ using nodes_iterator = const Node *const *;
+
+ SplitGraph(const Module &M, const FunctionsCostMap &CostMap,
+ CostType ModuleCost)
+ : M(M), CostMap(CostMap), ModuleCost(ModuleCost) {}
+
+ void buildGraph(CallGraph &CG);
+
+#ifndef NDEBUG
+ void verifyGraph() const;
+#endif
+
+ bool empty() const { return Nodes.empty(); }
+ const iterator_range<nodes_iterator> nodes() const {
+ return {Nodes.begin(), Nodes.end()};
}
-}
+ const Node &getNode(unsigned ID) const { return *Nodes[ID]; }
-/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
-/// callee until all reachable functions have been gathered.
+ unsigned getNumNodes() const { return Nodes.size(); }
+ BitVector createNodesBitVector() const { return BitVector(Nodes.size()); }
+
+ const Module &getModule() const { return M; }
+
+ CostType getModuleCost() const { return ModuleCost; }
+ CostType getCost(const Function &F) const { return CostMap.at(&F); }
+
+ /// \returns the aggregated cost of all nodes in \p BV (bits set to 1 = node
+ /// IDs).
+ CostType calculateCost(const BitVector &BV) const;
+
+private:
+ /// Retrieves the node for \p GV in \p Cache, or creates a new node for it and
+ /// updates \p Cache.
+ Node &getNode(DenseMap<const GlobalValue *, Node *> &Cache,
+ const GlobalValue &GV);
+
+ // Create a new edge between two nodes and add it to both nodes.
+ const Edge &createEdge(Node &Src, Node &Dst, EdgeKind EK);
+
+ const Module &M;
+ const FunctionsCostMap &CostMap;
+ CostType ModuleCost;
+
+ // Final list of nodes with stable ordering.
+ SmallVector<Node *> Nodes;
+
+ SpecificBumpPtrAllocator<Node> NodesPool;
+
+ // Edges are trivially destructible objects, so as a small optimization we
+ // use a BumpPtrAllocator which avoids destructor calls but also makes
+ // allocation faster.
+ static_assert(
+ std::is_trivially_destructible_v<Edge>,
+ "Edge must be trivially destructible to use the BumpPtrAllocator");
+ BumpPtrAllocator EdgesPool;
+};
+
+/// Nodes in the SplitGraph contain both incoming, and outgoing edges.
+/// Incoming edges have this node as their Dst, and Outgoing ones have this node
+/// as their Src.
///
-/// \param SML Log Helper
-/// \param CG Call graph for \p Fn's module.
-/// \param Fn Current function to look at.
-/// \param Fns[out] Resulting list of functions.
-/// \param OnlyDirect Whether to only consider direct callees.
-/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
-/// point, either in \p Fn or in one of the function it calls. When that
-/// happens, we fall back to adding all callable functions inside \p Fn's module
-/// to \p Fns.
-static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
- const Function &Fn,
- DenseSet<const Function *> &Fns, bool OnlyDirect,
- bool &HadIndirectCall) {
- assert(!Fn.isDeclaration());
-
- const Module &M = *Fn.getParent();
- SmallVector<const Function *> WorkList({&Fn});
+/// Edge objects are shared by both nodes in Src/Dst. They provide immediate
+/// feedback on how two nodes are related, and in which direction they are
+/// related, which is valuable information to make splitting decisions.
+///
+/// Nodes are fundamentally abstract, and any consumers of the graph should
+/// treat them as such. While a node will be a function most of the time, we
+/// could also create nodes for any other reason. In the future, we could have
+/// single nodes for multiple functions, or nodes for GVs, etc.
+class SplitGraph::Node {
+ friend class SplitGraph;
+
+public:
+ Node(unsigned ID, const GlobalValue &GV, CostType IndividualCost,
+ bool IsNonCopyable)
+ : ID(ID), GV(GV), IndividualCost(IndividualCost),
+ IsNonCopyable(IsNonCopyable), IsEntry(false) {
+ if (auto *Fn = dyn_cast<Function>(&GV))
+ IsKernel = ::llvm::isKernel(Fn);
+ else
+ IsKernel = false;
+ }
+
+ /// A 0-indexed ID for the node. The maximum ID (exclusive) is the number of
+ /// nodes in the graph. This ID can be used as an index in a BitVector.
+ unsigned getID() const { return ID; }
+
+ const Function &getFunction() const { return cast<Function>(GV); }
+
+ /// \returns the cost to import this component into a given module, not
+ /// accounting for any dependencies that may need to be imported as well.
+ CostType getIndividualCost() const { return IndividualCost; }
+
+ bool isNonCopyable() const { return IsNonCopyable; }
+ bool isKernel() const { return IsKernel; }
+
+ /// \returns whether this is an entry point in the graph. Entry points are
+ /// defined as follows: if you take all entry points in the graph, and iterate
+ /// their dependencies, you are guaranteed to visit all nodes in the graph at
+ /// least once.
+ bool isGraphEntryPoint() const { return IsEntry; }
+
+ std::string getName() const { return GV.getName().str(); }
+
+ bool hasAnyIncomingEdges() const { return IncomingEdges.size(); }
+ bool hasAnyIncomingEdgesOfKind(EdgeKind EK) const {
+ return any_of(IncomingEdges, [&](const auto *E) { return E->Kind == EK; });
+ }
+
+ bool hasAnyOutgoingEdges() const { return OutgoingEdges.size(); }
+ bool hasAnyOutgoingEdgesOfKind(EdgeKind EK) const {
+ return any_of(OutgoingEdges, [&](const auto *E) { return E->Kind == EK; });
+ }
+
+ iterator_range<edges_iterator> incoming_edges() const {
+ return IncomingEdges;
+ }
+
+ iterator_range<edges_iterator> outgoing_edges() const {
+ return OutgoingEdges;
+ }
+
+ bool shouldFollowIndirectCalls() const { return isKernel(); }
+
+ /// Visit all children of this node in a recursive fashion. Also visits Self.
+ /// If \ref shouldFollowIndirectCalls returns false, then this only follows
+ /// DirectCall edges.
+ ///
+ /// \param Visitor Visitor Function.
+ void visitAllDependencies(std::function<void(const Node &)> Visitor) const;
+
+ /// Adds the dependencies of this node in \p BV by setting the bit
+ /// corresponding to each node.
+ ///
+ /// Implemented using \ref visitAllDependencies, hence it follows the same
+ /// rules regarding dependencies traversal.
+ ///
+ /// \param[out] BV The bitvector where the bits should be set.
+ void setDependenciesBits(BitVector &BV) const {
+ visitAllDependencies([&](const Node &N) { BV.set(N.getID()); });
+ }
+
+ /// Uses \ref visitAllDependencies to aggregate the individual cost of this
+ /// node and all of its dependencies.
+ ///
+ /// This is cached.
+ CostType getFullCost() const;
+
+private:
+ void markAsEntry() { IsEntry = true; }
+
+ unsigned ID;
+ const GlobalValue &GV;
+ CostType IndividualCost;
+ bool IsNonCopyable : 1;
+ bool IsKernel : 1;
+ bool IsEntry : 1;
+
+ // TODO: Cache dependencies as well?
+ mutable CostType FullCost = 0;
+
+ // TODO: Use a single sorted vector (with all incoming/outgoing edges grouped
+ // together)
+ EdgesVec IncomingEdges;
+ EdgesVec OutgoingEdges;
+};
+
+void SplitGraph::Node::visitAllDependencies(
+ std::function<void(const Node &)> Visitor) const {
+ const bool FollowIndirect = shouldFollowIndirectCalls();
+ // FIXME: If this can access SplitGraph in the future, use a BitVector
+ // instead.
+ DenseSet<const Node *> Seen;
+ SmallVector<const Node *, 8> WorkList({this});
while (!WorkList.empty()) {
- const auto &CurFn = *WorkList.pop_back_val();
- assert(!CurFn.isDeclaration());
+ const Node *CurN = WorkList.pop_back_val();
+ if (auto [It, Inserted] = Seen.insert(CurN); !Inserted)
+ continue;
+
+ Visitor(*CurN);
- // Scan for an indirect call. If such a call is found, we have to
- // conservatively assume this can call all non-entrypoint functions in the
- // module.
+ for (const Edge *E : CurN->outgoing_edges()) {
+ if (!FollowIndirect && E->Kind == EdgeKind::IndirectCall)
+ continue;
+ WorkList.push_back(E->Dst);
+ }
+ }
+}
+CostType SplitGraph::Node::getFullCost() const {
+ if (FullCost)
+ return FullCost;
+
+ assert(FullCost == 0);
+ visitAllDependencies(
+ [&](const Node &N) { FullCost += N.getIndividualCost(); });
+ return FullCost;
+}
+
+void SplitGraph::buildGraph(CallGraph &CG) {
+ SplitModuleTimer SMT("buildGraph", "graph construction");
+ LLVM_DEBUG(
+ dbgs()
+ << "[build graph] constructing graph representation of the input\n");
+
+ // We build the graph by just iterating all functions in the module and
+ // working on their direct callees. At the end, all nodes should be linked
+ // together as expected.
+ DenseMap<const GlobalValue *, Node *> Cache;
+ SmallVector<const Function *> FnsWithIndirectCalls, IndirectlyCallableFns;
+ for (const Function &Fn : M) {
+ if (Fn.isDeclaration())
+ continue;
- for (auto &CGEntry : *CG[&CurFn]) {
+ // Look at direct callees and create the necessary edges in the graph.
+ bool HasIndirectCall = false;
+ Node &N = getNode(Cache, Fn);
+ for (auto &CGEntry : *CG[&Fn]) {
auto *CGNode = CGEntry.second;
auto *Callee = CGNode->getFunction();
if (!Callee) {
- if (OnlyDirect)
- continue;
-
- // Functions have an edge towards CallsExternalNode if they're external
- // declarations, or if they do an indirect call. As we only process
- // definitions here, we know this means the function has an indirect
- // call. We then have to conservatively assume this can call all
- // non-entrypoint functions in the module.
- if (CGNode != CG.getCallsExternalNode())
- continue; // this is another function-less node we don't care about.
-
- SML << "Indirect call detected in " << getName(CurFn)
- << " - treating all non-entrypoint functions as "
- "potential dependencies\n";
-
- // TODO: Print an ORE as well ?
- addAllIndirectCallDependencies(M, Fns);
- HadIndirectCall = true;
+ // TODO: Don't consider inline assembly as indirect calls.
+ if (CGNode == CG.getCallsExternalNode())
+ HasIndirectCall = true;
continue;
}
- if (Callee->isDeclaration())
- continue;
+ if (!Callee->isDeclaration())
+ createEdge(N, getNode(Cache, *Callee), EdgeKind::DirectCall);
+ }
+
+ // Keep track of this function if it contains an indirect call and/or if it
+ // can be indirectly called.
+ if (HasIndirectCall) {
+ LLVM_DEBUG(dbgs() << "indirect call found in " << Fn.getName() << "\n");
+ FnsWithIndirectCalls.push_back(&Fn);
+ }
+
+ if (canBeIndirectlyCalled(Fn))
+ IndirectlyCallableFns.push_back(&Fn);
+ }
- auto [It, Inserted] = Fns.insert(Callee);
- if (Inserted)
- WorkList.push_back(Callee);
+ // Post-process functions with indirect calls.
+ for (const Function *Fn : FnsWithIndirectCalls) {
+ for (const Function *Candidate : IndirectlyCallableFns) {
+ Node &Src = getNode(Cache, *Fn);
+ Node &Dst = getNode(Cache, *Candidate);
+ createEdge(Src, Dst, EdgeKind::IndirectCall);
}
}
+
+ // Now, find all entry points.
+ SmallVector<Node *, 16> CandidateEntryPoints;
+ BitVector NodesReachableByKernels = createNodesBitVector();
+ for (Node *N : Nodes) {
+ // Kernels are always entry points.
+ if (N->isKernel()) {
+ N->markAsEntry();
+ N->setDependenciesBits(NodesReachableByKernels);
+ } else if (!N->hasAnyIncomingEdgesOfKind(EdgeKind::DirectCall))
+ CandidateEntryPoints.push_back(N);
+ }
+
+ for (Node *N : CandidateEntryPoints) {
+ // This can be another entry point if it's not reachable by a kernel
+ // TODO: We could sort all of the possible new entries in a stable order
+ // (e.g. by cost), then consume them one by one until
+ // NodesReachableByKernels is all 1s. It'd allow us to avoid
+ // considering some nodes as non-entries in some specific cases.
+ if (!NodesReachableByKernels.test(N->getID()))
+ N->markAsEntry();
+ }
+
+#ifndef NDEBUG
+ verifyGraph();
+#endif
}
-/// Contains information about a function and its dependencies.
-/// This is a splitting root. The splitting algorithm works by
-/// assigning these to partitions.
-struct FunctionWithDependencies {
- FunctionWithDependencies(SplitModuleLogger &SML, CallGraph &CG,
- const DenseMap<const Function *, CostType> &FnCosts,
- const Function *Fn)
- : Fn(Fn) {
- // When Fn is not a kernel, we don't need to collect indirect callees.
- // Resource usage analysis is only performed on kernels, and we collect
- // indirect callees for resource usage analysis.
- addAllDependencies(SML, CG, *Fn, Dependencies,
- /*OnlyDirect*/ !isEntryPoint(Fn), HasIndirectCall);
- TotalCost = FnCosts.at(Fn);
- for (const auto *Dep : Dependencies) {
- TotalCost += FnCosts.at(Dep);
-
- // We cannot duplicate functions with external linkage, or functions that
- // may be overriden at runtime.
- HasNonDuplicatableDependecy |=
- (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
+#ifndef NDEBUG
+void SplitGraph::verifyGraph() const {
+ unsigned ExpectedID = 0;
+ // Exceptionally using a set here in case IDs are messed up.
+ DenseSet<const Node *> SeenNodes;
+ DenseSet<const Function *> SeenFunctionNodes;
+ for (const Node *N : Nodes) {
+ assert(N->getID() == (ExpectedID++) && "Node IDs are incorrect!");
+ assert(SeenNodes.insert(N).second && "Node seen more than once!");
+ assert(&getNode(N->getID()) == N);
+
+ for (const Edge *E : N->IncomingEdges) {
+ assert(E->Src && E->Dst);
+ assert(E->Dst == N);
+ assert(find(E->Src->OutgoingEdges, E) != E->Src->OutgoingEdges.end());
}
+
+ for (const Edge *E : N->OutgoingEdges) {
+ assert(E->Src && E->Dst);
+ assert(E->Src == N);
+ assert(find(E->Dst->IncomingEdges, E) != E->Dst->IncomingEdges.end());
+ }
+
+ const Function &Fn = N->getFunction();
+ if (isKernel(&Fn)) {
+ assert(!N->hasAnyIncomingEdges() && "Kernels cannot have incoming edges");
+ }
+ assert(!Fn.isDeclaration() && "declarations shouldn't have nodes!");
+
+ auto [It, Inserted] = SeenFunctionNodes.insert(&Fn);
+ assert(Inserted && "one function has multiple nodes!");
}
+ assert(ExpectedID == Nodes.size() && "Node IDs out of sync!");
- const Function *Fn = nullptr;
- DenseSet<const Function *> Dependencies;
- /// Whether \p Fn or any of its \ref Dependencies contains an indirect call.
- bool HasIndirectCall = false;
- /// Whether any of \p Fn's dependencies cannot be duplicated.
- bool HasNonDuplicatableDependecy = false;
+ assert(createNodesBitVector().size() == getNumNodes());
- CostType TotalCost = 0;
+ // Check we respect the promise of Node::isGraphEntryPoint
+ BitVector BV = createNodesBitVector();
+ for (const Node *N : nodes()) {
+ if (N->isGraphEntryPoint())
+ N->setDependenciesBits(BV);
+ }
- /// \returns true if this function and its dependencies can be considered
- /// large according to \p Threshold.
- bool isLarge(CostType Threshold) const {
- return TotalCost > Threshold && !Dependencies.empty();
+ // Ensure each function in the module has an associated node.
+ for (const auto &Fn : M) {
+ if (!Fn.isDeclaration())
+ assert(SeenFunctionNodes.contains(&Fn) &&
+ "Fn has no associated node in the graph!");
}
+
+ assert(BV.all() &&
+ "not all nodes are reachable through the graph's entry points!");
+}
+#endif
+
+CostType SplitGraph::calculateCost(const BitVector &BV) const {
+ CostType Cost = 0;
+ for (unsigned NodeID : BV.set_bits())
+ Cost += getNode(NodeID).getIndividualCost();
+ return Cost;
+}
+
+SplitGraph::Node &
+SplitGraph::getNode(DenseMap<const GlobalValue *, Node *> &Cache,
+ const GlobalValue &GV) {
+ auto &N = Cache[&GV];
+ if (N)
+ return *N;
+
+ CostType Cost = 0;
+ bool NonCopyable = false;
+ if (const Function *Fn = dyn_cast<Function>(&GV)) {
+ NonCopyable = isNonCopyable(*Fn);
+ Cost = CostMap.at(Fn);
+ }
+ N = new (NodesPool.Allocate()) Node(Nodes.size(), GV, Cost, NonCopyable);
+ Nodes.push_back(N);
+ assert(&getNode(N->getID()) == N);
+ return *N;
+}
+
+const SplitGraph::Edge &SplitGraph::createEdge(Node &Src, Node &Dst,
+ EdgeKind EK) {
+ const Edge *E = new (EdgesPool.Allocate<Edge>(1)) Edge(&Src, &Dst, EK);
+ Src.OutgoingEdges.push_back(E);
+ Dst.IncomingEdges.push_back(E);
+ return *E;
+}
+
+//===----------------------------------------------------------------------===//
+// Split Proposals
+//===----------------------------------------------------------------------===//
+
+/// Represents a module splitting proposal.
+///
+/// Proposals are made of N BitVectors, one for each partition, where each bit
+/// set indicates that the node is present and should be copied inside that
+/// partition.
+///
+/// Proposals have several metrics attached so they can be compared/sorted,
+/// which allows the driver to try multiple strategies resulting in multiple
+/// proposals and choose the best one out of them.
+class SplitProposal {
+public:
+ SplitProposal(const SplitGraph &SG, unsigned MaxPartitions) : SG(&SG) {
+ Partitions.resize(MaxPartitions, {0, SG.createNodesBitVector()});
+ }
+
+ void setName(StringRef NewName) { Name = NewName; }
+ StringRef getName() const { return Name; }
+
+ const BitVector &operator[](unsigned PID) const {
+ return Partitions[PID].second;
+ }
+
+ void add(unsigned PID, const BitVector &BV) {
+ Partitions[PID].second |= BV;
+ updateScore(PID);
+ }
+
+ void print(raw_ostream &OS) const;
+ LLVM_DUMP_METHOD void dump() const { print(dbgs()); }
+
+ // Find the cheapest partition (lowest cost). In case of ties, always returns
+ // the highest partition number.
+ unsigned findCheapestPartition() const;
+
+ /// Calculate the CodeSize and Bottleneck scores.
+ void calculateScores();
+
+#ifndef NDEBUG
+ void verifyCompleteness() const;
+#endif
+
+ /// Only available after \ref calculateScores is called.
+ ///
+ /// A positive number indicating the % of code duplication that this proposal
+ /// creates. e.g. 0.2 means this proposal adds roughly 20% code size by
+ /// duplicating some functions across partitions.
+ ///
+ /// Value is always rounded up to 2 decimal places.
+ ///
+ /// A perfect score would be 0.0, and anything approaching 1.0 is very bad.
+ double getCodeSizeScore() const { return CodeSizeScore; }
+
+ /// Only available after \ref calculateScores is called.
+ ///
+ /// A number between [0, 1] which indicates how big of a bottleneck is
+ /// expected from the largest partition.
+ ///
+ /// A score of 1.0 means the biggest partition is as big as the source module,
+ /// so build time will be equal to or greater than the build time of the
+ /// initial input.
+ ///
+ /// Value is always rounded up to 2 decimal places.
+ ///
+ /// This is one of the metrics used to estimate this proposal's build time.
+ double getBottleneckScore() const { return BottleneckScore; }
+
+private:
+ void updateScore(unsigned PID) {
+ assert(SG);
+ for (auto &[PCost, Nodes] : Partitions) {
+ TotalCost -= PCost;
+ PCost = SG->calculateCost(Nodes);
+ TotalCost += PCost;
+ }
+ }
+
+ /// \see getCodeSizeScore
+ double CodeSizeScore = 0.0;
+ /// \see getBottleneckScore
+ double BottleneckScore = 0.0;
+ /// Aggregated cost of all partitions
+ CostType TotalCost = 0;
+
+ const SplitGraph *SG = nullptr;
+ std::string Name;
+
+ std::vector<std::pair<CostType, BitVector>> Partitions;
};
-/// Calculates how much overlap there is between \p A and \p B.
-/// \return A number between 0.0 and 1.0, where 1.0 means A == B and 0.0 means A
-/// and B have no shared elements. Kernels do not count in overlap calculation.
-static float calculateOverlap(const DenseSet<const Function *> &A,
- const DenseSet<const Function *> &B) {
- DenseSet<const Function *> Total;
- for (const auto *F : A) {
- if (!isEntryPoint(F))
- Total.insert(F);
+void SplitProposal::print(raw_ostream &OS) const {
+ assert(SG);
+
+ OS << "[proposal] " << Name << ", total cost:" << TotalCost
+ << ", code size score:" << format("%0.3f", CodeSizeScore)
+ << ", bottleneck score:" << format("%0.3f", BottleneckScore) << "\n";
+ for (const auto &[PID, Part] : enumerate(Partitions)) {
+ const auto &[Cost, NodeIDs] = Part;
+ OS << " - P" << PID << " nodes:" << NodeIDs.count() << " cost: " << Cost
+ << "|" << formatRatioOf(Cost, SG->getModuleCost()) << "%\n";
}
+}
- if (Total.empty())
- return 0.0f;
+unsigned SplitProposal::findCheapestPartition() const {
+ assert(!Partitions.empty());
+ CostType CurCost = std::numeric_limits<CostType>::max();
+ unsigned CurPID = InvalidPID;
+ for (const auto &[Idx, Part] : enumerate(Partitions)) {
+ if (Part.first <= CurCost) {
+ CurPID = Idx;
+ CurCost = Part.first;
+ }
+ }
+ assert(CurPID != InvalidPID);
+ return CurPID;
+}
- unsigned NumCommon = 0;
- for (const auto *F : B) {
- if (isEntryPoint(F))
- continue;
+void SplitProposal::calculateScores() {
+ if (Partitions.empty())
+ return;
- auto [It, Inserted] = Total.insert(F);
- if (!Inserted)
- ++NumCommon;
+ assert(SG);
+ CostType LargestPCost = 0;
+ for (auto &[PCost, Nodes] : Partitions) {
+ if (PCost > LargestPCost)
+ LargestPCost = PCost;
}
- return static_cast<float>(NumCommon) / Total.size();
+ CostType ModuleCost = SG->getModuleCost();
+ CodeSizeScore = double(TotalCost) / ModuleCost;
+ assert(CodeSizeScore >= 0.0);
+
+ BottleneckScore = double(LargestPCost) / ModuleCost;
+
+ CodeSizeScore = std::ceil(CodeSizeScore * 100.0) / 100.0;
+ BottleneckScore = std::ceil(BottleneckScore * 100.0) / 100.0;
}
-/// Performs all of the partitioning work on \p M.
-/// \param SML Log Helper
-/// \param M Module to partition.
-/// \param NumParts Number of partitions to create.
-/// \param ModuleCost Total cost of all functions in \p M.
-/// \param FnCosts Map of Function -> Cost
-/// \param WorkList Functions and their dependencies to process in order.
-/// \returns The created partitions (a vector of size \p NumParts )
-static std::vector<DenseSet<const Function *>>
-doPartitioning(SplitModuleLogger &SML, Module &M, unsigned NumParts,
- CostType ModuleCost,
- const DenseMap<const Function *, CostType> &FnCosts,
- const SmallVector<FunctionWithDependencies> &WorkList) {
-
- SML << "\n--Partitioning Starts--\n";
-
- // Calculate a "large function threshold". When more than one function's total
- // import cost exceeds this value, we will try to assign it to an existing
- // partition to reduce the amount of duplication needed.
- //
- // e.g. let two functions X and Y have a import cost of ~10% of the module, we
- // assign X to a partition as usual, but when we get to Y, we check if it's
- // worth also putting it in Y's partition.
- const CostType LargeFnThreshold =
- LargeFnFactor ? CostType(((ModuleCost / NumParts) * LargeFnFactor))
- : std::numeric_limits<CostType>::max();
-
- std::vector<DenseSet<const Function *>> Partitions;
- Partitions.resize(NumParts);
-
- // Assign functions to partitions, and try to keep the partitions more or
- // less balanced. We do that through a priority queue sorted in reverse, so we
- // can always look at the partition with the least content.
- //
- // There are some cases where we will be deliberately unbalanced though.
- // - Large functions: we try to merge with existing partitions to reduce code
- // duplication.
- // - Functions with indirect or external calls always go in the first
- // partition (P0).
- auto ComparePartitions = [](const std::pair<PartitionID, CostType> &a,
- const std::pair<PartitionID, CostType> &b) {
- // When two partitions have the same cost, assign to the one with the
- // biggest ID first. This allows us to put things in P0 last, because P0 may
- // have other stuff added later.
- if (a.second == b.second)
- return a.first < b.first;
- return a.second > b.second;
+#ifndef NDEBUG
+void SplitProposal::verifyCompleteness() const {
+ if (Partitions.empty())
+ return;
+
+ BitVector Result = Partitions[0].second;
+ for (const auto &P : drop_begin(Partitions))
+ Result |= P.second;
+ assert(Result.all() && "some nodes are missing from this proposal!");
+}
+#endif
+
+//===-- RecursiveSearchStrategy -------------------------------------------===//
+
+/// Partitioning algorithm.
+///
+/// This is a recursive search algorithm that can explore multiple possibilities.
+///
+/// When a cluster of nodes can go into more than one partition, and we haven't
+/// reached maximum search depth, we recurse and explore both options and their
+/// consequences. Both branches will yield a proposal, and the driver will grade
+/// both and choose the best one.
+///
+/// If max depth is reached, we will use some heuristics to make a choice. Most
+/// of the time we will just use the least-pressured (cheapest) partition, but
+/// if a cluster is particularly big and there is a good amount of overlap with
+/// an existing partition, we will choose that partition instead.
+class RecursiveSearchSplitting {
+public:
+ using SubmitProposalFn = function_ref<void(SplitProposal)>;
+
+ RecursiveSearchSplitting(const SplitGraph &SG, unsigned NumParts,
+ SubmitProposalFn SubmitProposal);
+
+ void run();
+
+private:
+ struct WorkListEntry {
+ WorkListEntry(const BitVector &BV) : Cluster(BV) {}
+
+ unsigned NumNonEntryNodes = 0;
+ CostType TotalCost = 0;
+ CostType CostExcludingGraphEntryPoints = 0;
+ BitVector Cluster;
};
- // We can't use priority_queue here because we need to be able to access any
- // element. This makes this a bit inefficient as we need to sort it again
- // everytime we change it, but it's a very small array anyway (likely under 64
- // partitions) so it's a cheap operation.
- std::vector<std::pair<PartitionID, CostType>> BalancingQueue;
- for (unsigned I = 0; I < NumParts; ++I)
- BalancingQueue.emplace_back(I, 0);
-
- // Helper function to handle assigning a function to a partition. This takes
- // care of updating the balancing queue.
- const auto AssignToPartition = [&](PartitionID PID,
- const FunctionWithDependencies &FWD) {
- auto &FnsInPart = Partitions[PID];
- FnsInPart.insert(FWD.Fn);
- FnsInPart.insert(FWD.Dependencies.begin(), FWD.Dependencies.end());
-
- SML << "assign " << getName(*FWD.Fn) << " to P" << PID << "\n -> ";
- if (!FWD.Dependencies.empty()) {
- SML << FWD.Dependencies.size() << " dependencies added\n";
- };
-
- // Update the balancing queue. we scan backwards because in the common case
- // the partition is at the end.
- for (auto &[QueuePID, Cost] : reverse(BalancingQueue)) {
- if (QueuePID == PID) {
- CostType NewCost = 0;
- for (auto *Fn : Partitions[PID])
- NewCost += FnCosts.at(Fn);
-
- SML << "[Updating P" << PID << " Cost]:" << Cost << " -> " << NewCost;
- if (Cost) {
- SML << " (" << unsigned(((float(NewCost) / Cost) - 1) * 100)
- << "% increase)";
- }
- SML << '\n';
+ /// Collects all graph entry points' clusters and sorts them so the most
+ /// expensive clusters are viewed first. This will merge clusters together if
+ /// they share a non-copyable dependency.
+ void setupWorkList();
+
+ /// Recursive function that assigns the worklist item at \p Idx into a
+ /// partition of \p SP.
+ ///
+ /// \p Depth is the current search depth. When this value is equal to
+ /// \ref MaxDepth, we can no longer recurse.
+ ///
+ /// This function only recurses if there is more than one possible assignment,
+ /// otherwise it is iterative to avoid creating a call stack that is as big as
+ /// \ref WorkList.
+ void pickPartition(unsigned Depth, unsigned Idx, SplitProposal SP);
+
+ /// \return A pair: first element is the PID of the partition that has the
+ /// most similarities with \p Entry, or \ref InvalidPID if no partition was
+ /// found with at least one element in common. The second element is the
+ /// aggregated cost of all dependencies in common between \p Entry and that
+ /// partition.
+ std::pair<unsigned, CostType>
+ findMostSimilarPartition(const WorkListEntry &Entry, const SplitProposal &SP);
+
+ const SplitGraph &SG;
+ unsigned NumParts;
+ SubmitProposalFn SubmitProposal;
+
+ // A Cluster is considered large when its cost, excluding entry points,
+ // exceeds this value.
+ CostType LargeClusterThreshold = 0;
+ unsigned NumProposalsSubmitted = 0;
+ SmallVector<WorkListEntry> WorkList;
+};
+
+RecursiveSearchSplitting::RecursiveSearchSplitting(
+ const SplitGraph &SG, unsigned NumParts, SubmitProposalFn SubmitProposal)
+ : SG(SG), NumParts(NumParts), SubmitProposal(SubmitProposal) {
+ // arbitrary max value as a safeguard. Anything above 10 will already be
+ // slow, this is just a max value to prevent extreme resource exhaustion or
+ // unbounded run time.
+ if (MaxDepth > 16)
+ report_fatal_error("[amdgpu-split-module] search depth of " +
+ Twine(MaxDepth) + " is too high!");
+ LargeClusterThreshold =
+ (LargeFnFactor != 0.0)
+ ? CostType(((SG.getModuleCost() / NumParts) * LargeFnFactor))
+ : std::numeric_limits<CostType>::max();
+ LLVM_DEBUG(dbgs() << "[recursive search] large cluster threshold set at "
+ << LargeClusterThreshold << "\n");
+}
+
+void RecursiveSearchSplitting::run() {
+ {
+ SplitModuleTimer SMT("recursive_search_prepare", "preparing worklist");
+ setupWorkList();
+ }
+
+ {
+ SplitModuleTimer SMT("recursive_search_pick", "partitioning");
+ SplitProposal SP(SG, NumParts);
+ pickPartition(/*BranchDepth=*/0, /*Idx=*/0, SP);
+ }
+}
- Cost = NewCost;
+void RecursiveSearchSplitting::setupWorkList() {
+ // e.g. if A and B are two worklist items, and they both call a non copyable
+ // dependency C, this does:
+ // A=C
+ // B=C
+ // => NodeEC will create a single group (A, B, C) and we create a new
+ // WorkList entry for that group.
+
+ EquivalenceClasses<unsigned> NodeEC;
+ for (const SplitGraph::Node *N : SG.nodes()) {
+ if (!N->isGraphEntryPoint())
+ continue;
+
+ NodeEC.insert(N->getID());
+ N->visitAllDependencies([&](const SplitGraph::Node &Dep) {
+ if (&Dep != N && Dep.isNonCopyable())
+ NodeEC.unionSets(N->getID(), Dep.getID());
+ });
+ }
+
+ for (auto I = NodeEC.begin(), E = NodeEC.end(); I != E; ++I) {
+ if (!I->isLeader())
+ continue;
+
+ BitVector Cluster = SG.createNodesBitVector();
+ for (auto MI = NodeEC.member_begin(I); MI != NodeEC.member_end(); ++MI) {
+ const SplitGraph::Node &N = SG.getNode(*MI);
+ if (N.isGraphEntryPoint())
+ N.setDependenciesBits(Cluster);
+ }
+ WorkList.emplace_back(std::move(Cluster));
+ }
+
+ // Calculate costs and other useful information.
+ for (WorkListEntry &Entry : WorkList) {
+ for (unsigned NodeID : Entry.Cluster.set_bits()) {
+ const SplitGraph::Node &N = SG.getNode(NodeID);
+ const CostType Cost = N.getIndividualCost();
+
+ Entry.TotalCost += Cost;
+ if (!N.isGraphEntryPoint()) {
+ Entry.CostExcludingGraphEntryPoints += Cost;
+ ++Entry.NumNonEntryNodes;
}
}
+ }
- sort(BalancingQueue, ComparePartitions);
- };
+ sort(WorkList, [](const WorkListEntry &LHS, const WorkListEntry &RHS) {
+ return LHS.TotalCost > RHS.TotalCost;
+ });
- for (auto &CurFn : WorkList) {
- // When a function has indirect calls, it must stay in the first partition
- // alongside every reachable non-entry function. This is a nightmare case
- // for splitting as it severely limits what we can do.
- if (CurFn.HasIndirectCall) {
- SML << "Function with indirect call(s): " << getName(*CurFn.Fn)
- << " defaulting to P0\n";
- AssignToPartition(0, CurFn);
- continue;
+ LLVM_DEBUG({
+ dbgs() << "[recursive search] worklist:\n";
+ for (const auto &[Idx, Entry] : enumerate(WorkList)) {
+ dbgs() << " - [" << Idx << "]: ";
+ for (unsigned NodeID : Entry.Cluster.set_bits())
+ dbgs() << NodeID << " ";
+ dbgs() << "(total_cost:" << Entry.TotalCost
+ << ", cost_excl_entries:" << Entry.CostExcludingGraphEntryPoints
+ << ")\n";
}
+ });
+}
- // When a function has non duplicatable dependencies, we have to keep it in
- // the first partition as well. This is a conservative approach, a
- // finer-grained approach could keep track of which dependencies are
- // non-duplicatable exactly and just make sure they're grouped together.
- if (CurFn.HasNonDuplicatableDependecy) {
- SML << "Function with externally visible dependency "
- << getName(*CurFn.Fn) << " defaulting to P0\n";
- AssignToPartition(0, CurFn);
+void RecursiveSearchSplitting::pickPartition(unsigned Depth, unsigned Idx,
+ SplitProposal SP) {
+ while (Idx < WorkList.size()) {
+ // Step 1: Determine candidate PIDs.
+ //
+ const WorkListEntry &Entry = WorkList[Idx];
+ const BitVector &Cluster = Entry.Cluster;
+
+ // Default option is to do load-balancing, AKA assign to least pressured
+ // partition.
+ const unsigned CheapestPID = SP.findCheapestPartition();
+ assert(CheapestPID != InvalidPID);
+
+ // Explore assigning to the kernel that contains the most dependencies in
+ // common.
+ const auto [MostSimilarPID, SimilarDepsCost] =
+ findMostSimilarPartition(Entry, SP);
+
+ // We can choose to explore only one path if we only have one valid path, or
+ // if we reached maximum search depth and can no longer branch out.
+ unsigned SinglePIDToTry = InvalidPID;
+ if (MostSimilarPID == InvalidPID) // no similar PID found
+ SinglePIDToTry = CheapestPID;
+ else if (MostSimilarPID == CheapestPID) // both landed on the same PID
+ SinglePIDToTry = CheapestPID;
+ else if (Depth >= MaxDepth) {
+ // We have to choose one path. Use a heuristic to guess which one will be
+ // more appropriate.
+ if (Entry.CostExcludingGraphEntryPoints > LargeClusterThreshold) {
+ // Check if the amount of code in common makes it worth it.
+ assert(SimilarDepsCost && Entry.CostExcludingGraphEntryPoints);
+ const double Ratio =
+ SimilarDepsCost / Entry.CostExcludingGraphEntryPoints;
+ assert(Ratio >= 0.0 && Ratio <= 1.0);
+ if (LargeFnOverlapForMerge > Ratio) {
+ // For debug, just print "L", so we'll see "L3=P3" for instance, which
+ // will mean we reached max depth and chose P3 based on this
+ // heuristic.
+ LLVM_DEBUG(dbgs() << "L");
----------------
arsenm wrote:
```suggestion
LLVM_DEBUG(dbgs() << 'L');
```
https://github.com/llvm/llvm-project/pull/104763
More information about the llvm-commits
mailing list