[Mlir-commits] [mlir] [mlir][CallGraph] Add edges for callable symbol references in CallGraph (PR #116177)
Haocong Lu
llvmlistbot at llvm.org
Wed Dec 4 00:05:58 PST 2024
https://github.com/Luhaocong updated https://github.com/llvm/llvm-project/pull/116177
>From 53901c1425f88c7fd6481249f5f6c814ea41d69b Mon Sep 17 00:00:00 2001
From: Lu Haocong <haocong.lu at evas.ai>
Date: Thu, 14 Nov 2024 15:11:40 +0800
Subject: [PATCH] [mlir][CallGraph] Add edges for callable symbol references in
CallGraph
This patch introduces a reference edge to represent callable symbol
references in the `CallGraph` analysis. Reference edges exist whenever
an operation references a callable node. Any callable symbol that may
be referenced externally is also connected to the entry node with a
reference edge. The callgraph therefore represents both direct calls
and any potential calls that might be formed later through static
optimization (for example, a symbol taken with `func.constant` whose
indirect call is later resolved into a direct call).
This implementation borrows many design concepts from LLVM's
`LazyCallGraph` (commit bf71a34eb9e25c6080d5058a553dbcb75676ff95). The
difference is that MLIR retains the entry and exit nodes and adds a
reference counter to each callgraph node, which makes dead nodes easy
to find.
This patch also makes the printed callgraph more readable: it now dumps
every node in the callgraph, including the entry and exit nodes, along
with the number of references to each node.
---
mlir/include/mlir/Analysis/CallGraph.h | 66 ++++--
mlir/lib/Analysis/CallGraph.cpp | 143 ++++++++----
mlir/lib/Transforms/Utils/Inliner.cpp | 310 ++++---------------------
mlir/test/Analysis/test-callgraph.mlir | 70 +++++-
mlir/test/Transforms/inlining-dce.mlir | 31 ++-
5 files changed, 280 insertions(+), 340 deletions(-)
diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index 631cdd1ad22909..d5f6de2107f0ca 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -16,6 +16,7 @@
#ifndef MLIR_ANALYSIS_CALLGRAPH_H
#define MLIR_ANALYSIS_CALLGRAPH_H
+#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/MapVector.h"
@@ -42,11 +43,12 @@ class CallGraphNode {
/// This class represents a directed edge between two nodes in the callgraph.
class Edge {
enum class Kind {
- // An 'Abstract' edge represents an opaque, non-operation, reference
- // between this node and the target. Edges of this type are only valid
- // from the external node, as there is no valid connection to an operation
- // in the module.
- Abstract,
+ // A 'Ref' edge represents a reference between this node and the target.
+ // Edges of this type exist whenever any operation references a callable
+ // node. A callable node which may be referenced externally is connected
+ // with a reference edge from the entry node. All 'Call' edges and 'Child'
+ // edges are inherently reference edges in the callgraph.
+ Ref,
// A 'Call' edge represents a direct reference to the target node via a
// call-like operation within the callable region of this node.
@@ -60,8 +62,8 @@ class CallGraphNode {
};
public:
- /// Returns true if this edge represents an `Abstract` edge.
- bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; }
+ /// Returns true if this edge represents an `Ref` edge.
+ bool isRef() const { return targetAndKind.getInt() == Kind::Ref; }
/// Returns true if this edge represents a `Call` edge.
bool isCall() const { return targetAndKind.getInt() == Kind::Call; }
@@ -95,10 +97,9 @@ class CallGraphNode {
/// on non-external nodes.
Region *getCallableRegion() const;
- /// Adds an abstract reference edge to the given node. An abstract edge does
- /// not come from any observable operations, so this is only valid on the
- /// external node.
- void addAbstractEdge(CallGraphNode *node);
+ /// Adds a reference edge to the given node. A reference comes from
+ /// external node or any observable operations.
+ void addRefEdge(CallGraphNode *node);
/// Add an outgoing call edge from this node.
void addCallEdge(CallGraphNode *node);
@@ -106,6 +107,12 @@ class CallGraphNode {
/// Adds a reference edge to the given child node.
void addChildEdge(CallGraphNode *child);
+ /// Remove edges to the target node.
+ void removeEdgesTo(CallGraphNode *target);
+
+ /// Remove all edges for this callgraph node.
+ void removeAllEdges();
+
/// Iterator over the outgoing edges of this node.
using iterator = SmallVectorImpl<Edge>::const_iterator;
iterator begin() const { return edges.begin(); }
@@ -114,6 +121,13 @@ class CallGraphNode {
/// Returns true if this node has any child edges.
bool hasChildren() const;
+ /// Return true if this node is dead.
+ bool isDead() const { return !isExternal() && getNumReferences() == 0; }
+
+ /// Returns the number of edges in this callgraph that connected
+ /// to this node.
+ unsigned getNumReferences() const { return NumReferences; }
+
private:
/// DenseMap info for callgraph edges.
struct EdgeKeyInfo {
@@ -143,6 +157,17 @@ class CallGraphNode {
llvm::SmallDenseSet<Edge, 4, EdgeKeyInfo>>
edges;
+ /// The number of edges in this callgraph that connected to this
+ /// callgraph node.
+ unsigned NumReferences = 0;
+
+ /// Decrease reference counter when an edge connected to this callgraph node
+ /// is removed.
+ void dropRef() { --NumReferences; }
+
+ /// Add reference counter when a new edge is connected to this callgraph node.
+ void addRef() { ++NumReferences; }
+
// Provide access to private methods.
friend class CallGraph;
};
@@ -184,13 +209,13 @@ class CallGraph {
/// registered.
CallGraphNode *lookupNode(Region *region) const;
- /// Return the callgraph node representing an external caller.
- CallGraphNode *getExternalCallerNode() const {
+ /// Return the callgraph node representing the entry node of this callgraph.
+ CallGraphNode *getEntryNode() const {
return const_cast<CallGraphNode *>(&externalCallerNode);
}
- /// Return the callgraph node representing an indirect callee.
- CallGraphNode *getUnknownCalleeNode() const {
+ /// Return the callgraph node representing the exit node of this callgraph.
+ CallGraphNode *getExitNode() const {
return const_cast<CallGraphNode *>(&unknownCalleeNode);
}
@@ -201,6 +226,11 @@ class CallGraph {
CallGraphNode *resolveCallable(CallOpInterface call,
SymbolTableCollection &symbolTable) const;
+ /// Resolve the callable for given symbol to a node in the callgraph, or the
+ /// external node if a valid node was not resolved.
+ CallGraphNode *resolveCallable(Operation *op, SymbolRefAttr symbolRef,
+ SymbolTableCollection &symbolTable) const;
+
/// Erase the given node from the callgraph.
void eraseNode(CallGraphNode *node);
@@ -217,10 +247,10 @@ class CallGraph {
/// The set of nodes within the callgraph.
NodeMapT nodes;
- /// A special node used to indicate an external caller.
+ /// A special node used to indicate an entry callgraph node.
CallGraphNode externalCallerNode;
- /// A special node used to indicate an unknown callee.
+ /// A special node used to indicate an exit callgraph node.
CallGraphNode unknownCalleeNode;
};
@@ -254,7 +284,7 @@ struct GraphTraits<const mlir::CallGraph *>
: public GraphTraits<const mlir::CallGraphNode *> {
/// The entry node into the graph is the external node.
static NodeRef getEntryNode(const mlir::CallGraph *cg) {
- return cg->getExternalCallerNode();
+ return cg->getEntryNode();
}
// nodes_iterator/begin/end - Allow iteration over all nodes in the graph
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 780c7caee767c1..559c3c733b3df4 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -38,11 +38,9 @@ Region *CallGraphNode::getCallableRegion() const {
return callableRegion;
}
-/// Adds an reference edge to the given node. This is only valid on the
-/// external node.
-void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
- assert(isExternal() && "abstract edges are only valid on external nodes");
- addEdge(node, Edge::Kind::Abstract);
+/// Adds an reference edge to the given node.
+void CallGraphNode::addRefEdge(CallGraphNode *node) {
+ addEdge(node, Edge::Kind::Ref);
}
/// Add an outgoing call edge from this node.
@@ -60,9 +58,28 @@ bool CallGraphNode::hasChildren() const {
return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
}
+/// Remove edges to the target callgraph node.
+void CallGraphNode::removeEdgesTo(CallGraphNode *target) {
+ edges.remove_if([target](const CallGraphNode::Edge &edge) {
+ if (edge.getTarget() != target)
+ return false;
+ target->dropRef();
+ return true;
+ });
+}
+
+/// Remove all edges for this callgraph node.
+void CallGraphNode::removeAllEdges() {
+ edges.remove_if([](const CallGraphNode::Edge &edge) {
+ edge.getTarget()->dropRef();
+ return true;
+ });
+}
+
/// Add an edge to 'node' with the given kind.
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
- edges.insert({node, kind});
+ if (edges.insert({node, kind}))
+ node->addRef();
}
//===----------------------------------------------------------------------===//
@@ -73,14 +90,31 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
/// edges are placed into the given callgraph object.
static void computeCallGraph(Operation *op, CallGraph &cg,
SymbolTableCollection &symbolTable,
- CallGraphNode *parentNode, bool resolveCalls) {
- if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
- // If there is no parent node, we ignore this operation. Even if this
- // operation was a call, there would be no callgraph node to attribute it
- // to.
- if (resolveCalls && parentNode)
- parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
- return;
+ CallGraphNode *parentNode, bool resolveSymbolRef) {
+ if (resolveSymbolRef) {
+ bool skipCalleeSymbolRef = false;
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ // If there is no parent node, there would be no callgraph node to call it
+ // directly, we use a reference egde to represent this call.
+ if (!parentNode->isExternal()) {
+ parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
+ if (isa<SymbolRefAttr>(call.getCallableForCallee()))
+ skipCalleeSymbolRef = true;
+ }
+ }
+ // If an operation reference a callable symbol, create a reference edge.
+ op->getAttrDictionary().walk<WalkOrder::PreOrder>(
+ [&](SymbolRefAttr symbolRef) {
+ // Skip the first symbol reference if it is a resolved callee.
+ if (skipCalleeSymbolRef) {
+ skipCalleeSymbolRef = false;
+ return WalkResult::skip();
+ }
+ auto *node = cg.resolveCallable(op, symbolRef, symbolTable);
+ if (!node->isExternal())
+ parentNode->addRefEdge(node);
+ return WalkResult::skip();
+ });
}
// Compute the callgraph nodes and edges for each of the nested operations.
@@ -93,42 +127,41 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
for (Region ®ion : op->getRegions())
for (Operation &nested : region.getOps())
- computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
+ computeCallGraph(&nested, cg, symbolTable, parentNode, resolveSymbolRef);
}
CallGraph::CallGraph(Operation *op)
: externalCallerNode(/*callableRegion=*/nullptr),
unknownCalleeNode(/*callableRegion=*/nullptr) {
// Make two passes over the graph, one to compute the callables and one to
- // resolve the calls. We split these up as we may have nested callable objects
- // that need to be reserved before the calls.
+ // resolve the calls and symbol references. We split these up as we may have
+ // nested callable objects that need to be reserved before the calls.
SymbolTableCollection symbolTable;
- computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
- /*resolveCalls=*/false);
- computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
- /*resolveCalls=*/true);
+ computeCallGraph(op, *this, symbolTable, getEntryNode(),
+ /*resolveSymbolRef=*/false);
+ computeCallGraph(op, *this, symbolTable, getEntryNode(),
+ /*resolveSymbolRef=*/true);
}
/// Get or add a call graph node for the given region.
CallGraphNode *CallGraph::getOrAddNode(Region *region,
CallGraphNode *parentNode) {
- assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
+ Operation *parentOp = region->getParentOp();
+ assert(region && isa<CallableOpInterface>(parentOp) &&
"expected parent operation to be callable");
std::unique_ptr<CallGraphNode> &node = nodes[region];
if (!node) {
node.reset(new CallGraphNode(region));
// Add this node to the given parent node if necessary.
- if (parentNode) {
+ assert(parentNode && "expected non-empty parent node");
+ if (!parentNode->isExternal()) {
parentNode->addChildEdge(node.get());
- } else {
- // Otherwise, connect all callable nodes to the external node, this allows
- // for conservatively including all callable nodes within the graph.
- // FIXME This isn't correct, this is only necessary for callable nodes
- // that *could* be called from external sources. This requires extending
- // the interface for callables to check if they may be referenced
- // externally.
- externalCallerNode.addAbstractEdge(node.get());
+ } else if (auto symbolOp = dyn_cast<SymbolOpInterface>(parentOp)) {
+ // Otherwise, connect all symbol nodes with public visibility
+ // to the entry node, which may be referenced externally.
+ if (!symbolOp.canDiscardOnUseEmpty())
+ parentNode->addRefEdge(node.get());
}
}
return node.get();
@@ -151,7 +184,21 @@ CallGraph::resolveCallable(CallOpInterface call,
if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
- return getUnknownCalleeNode();
+ return getExitNode();
+}
+
+/// Resolve the callable for given symbol to a node in the callgraph, or the
+/// external node if a valid node was not resolved.
+CallGraphNode *
+CallGraph::resolveCallable(Operation *op, SymbolRefAttr symbolRef,
+ SymbolTableCollection &symbolTable) const {
+ auto *symbolOp = symbolTable.lookupNearestSymbolFrom(op, symbolRef);
+ if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp)) {
+ if (auto *node = lookupNode(callableOp.getCallableRegion()))
+ return node;
+ }
+
+ return getExitNode();
}
/// Erase the given node from the callgraph.
@@ -163,11 +210,12 @@ void CallGraph::eraseNode(CallGraphNode *node) {
eraseNode(edge.getTarget());
}
// Erase any edges to this node from any other nodes.
- for (auto &it : nodes) {
- it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
- return edge.getTarget() == node;
- });
- }
+ for (auto &it : nodes)
+ it.second->removeEdgesTo(node);
+ // Erase all edges from this node to any other nodes.
+ node->removeAllEdges();
+
+ assert(node->isDead() && "expected no references");
nodes.erase(node->getCallableRegion());
}
@@ -181,11 +229,11 @@ void CallGraph::print(raw_ostream &os) const {
// Functor used to output the name for the given node.
auto emitNodeName = [&](const CallGraphNode *node) {
- if (node == getExternalCallerNode()) {
+ if (node == getEntryNode()) {
os << "<External-Caller-Node>";
return;
}
- if (node == getUnknownCalleeNode()) {
+ if (node == getExitNode()) {
os << "<Unknown-Callee-Node>";
return;
}
@@ -199,13 +247,12 @@ void CallGraph::print(raw_ostream &os) const {
os << " : " << attrs;
};
- for (auto &nodeIt : nodes) {
- const CallGraphNode *node = nodeIt.second.get();
-
+ // Functor used to emit the given node and edges.
+ auto emitNodeAndEdge = [&](const CallGraphNode *node) {
// Dump the header for this node.
os << "// - Node : ";
emitNodeName(node);
- os << "\n";
+ os << " : References = " << node->getNumReferences() << "\n";
// Emit each of the edges.
for (auto &edge : *node) {
@@ -214,13 +261,21 @@ void CallGraph::print(raw_ostream &os) const {
os << "Call";
else if (edge.isChild())
os << "Child";
+ else if (edge.isRef())
+ os << "Ref";
os << "-Edge : ";
emitNodeName(edge.getTarget());
os << "\n";
}
os << "//\n";
- }
+ };
+
+ // Emit all graph nodes including entry and exit node.
+ for (auto &nodeIt : nodes)
+ emitNodeAndEdge(nodeIt.second.get());
+ emitNodeAndEdge(getEntryNode());
+ emitNodeAndEdge(getExitNode());
os << "// -- SCCs --\n";
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index 8acfc96d2b611b..86348e8ad6741f 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -31,223 +31,6 @@ using namespace mlir;
using ResolvedCall = Inliner::ResolvedCall;
-//===----------------------------------------------------------------------===//
-// Symbol Use Tracking
-//===----------------------------------------------------------------------===//
-
-/// Walk all of the used symbol callgraph nodes referenced with the given op.
-static void walkReferencedSymbolNodes(
- Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
- DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
- function_ref<void(CallGraphNode *, Operation *)> callback) {
- auto symbolUses = SymbolTable::getSymbolUses(op);
- assert(symbolUses && "expected uses to be valid");
-
- Operation *symbolTableOp = op->getParentOp();
- for (const SymbolTable::SymbolUse &use : *symbolUses) {
- auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
- CallGraphNode *&node = refIt.first->second;
-
- // If this is the first instance of this reference, try to resolve a
- // callgraph node for it.
- if (refIt.second) {
- auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
- use.getSymbolRef());
- auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
- if (!callableOp)
- continue;
- node = cg.lookupNode(callableOp.getCallableRegion());
- }
- if (node)
- callback(node, use.getUser());
- }
-}
-
-//===----------------------------------------------------------------------===//
-// CGUseList
-
-namespace {
-/// This struct tracks the uses of callgraph nodes that can be dropped when
-/// use_empty. It directly tracks and manages a use-list for all of the
-/// call-graph nodes. This is necessary because many callgraph nodes are
-/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
-/// class.
-struct CGUseList {
- /// This struct tracks the uses of callgraph nodes within a specific
- /// operation.
- struct CGUser {
- /// Any nodes referenced in the top-level attribute list of this user. We
- /// use a set here because the number of references does not matter.
- DenseSet<CallGraphNode *> topLevelUses;
-
- /// Uses of nodes referenced by nested operations.
- DenseMap<CallGraphNode *, int> innerUses;
- };
-
- CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
-
- /// Drop uses of nodes referred to by the given call operation that resides
- /// within 'userNode'.
- void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
-
- /// Remove the given node from the use list.
- void eraseNode(CallGraphNode *node);
-
- /// Returns true if the given callgraph node has no uses and can be pruned.
- bool isDead(CallGraphNode *node) const;
-
- /// Returns true if the given callgraph node has a single use and can be
- /// discarded.
- bool hasOneUseAndDiscardable(CallGraphNode *node) const;
-
- /// Recompute the uses held by the given callgraph node.
- void recomputeUses(CallGraphNode *node, CallGraph &cg);
-
- /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
- /// of 'lhs' into 'rhs'.
- void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
-
-private:
- /// Decrement the uses of discardable nodes referenced by the given user.
- void decrementDiscardableUses(CGUser &uses);
-
- /// A mapping between a discardable callgraph node (that is a symbol) and the
- /// number of uses for this node.
- DenseMap<CallGraphNode *, int> discardableSymNodeUses;
-
- /// A mapping between a callgraph node and the symbol callgraph nodes that it
- /// uses.
- DenseMap<CallGraphNode *, CGUser> nodeUses;
-
- /// A symbol table to use when resolving call lookups.
- SymbolTableCollection &symbolTable;
-};
-} // namespace
-
-CGUseList::CGUseList(Operation *op, CallGraph &cg,
- SymbolTableCollection &symbolTable)
- : symbolTable(symbolTable) {
- /// A set of callgraph nodes that are always known to be live during inlining.
- DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
-
- // Walk each of the symbol tables looking for discardable callgraph nodes.
- auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
- for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
- // If this is a callgraph operation, check to see if it is discardable.
- if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
- if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
- SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
- if (symbol && (allUsesVisible || symbol.isPrivate()) &&
- symbol.canDiscardOnUseEmpty()) {
- discardableSymNodeUses.try_emplace(node, 0);
- }
- continue;
- }
- }
- // Otherwise, check for any referenced nodes. These will be always-live.
- walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
- [](CallGraphNode *, Operation *) {});
- }
- };
- SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
- walkFn);
-
- // Drop the use information for any discardable nodes that are always live.
- for (auto &it : alwaysLiveNodes)
- discardableSymNodeUses.erase(it.second);
-
- // Compute the uses for each of the callable nodes in the graph.
- for (CallGraphNode *node : cg)
- recomputeUses(node, cg);
-}
-
-void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
- CallGraph &cg) {
- auto &userRefs = nodeUses[userNode].innerUses;
- auto walkFn = [&](CallGraphNode *node, Operation *user) {
- auto parentIt = userRefs.find(node);
- if (parentIt == userRefs.end())
- return;
- --parentIt->second;
- --discardableSymNodeUses[node];
- };
- DenseMap<Attribute, CallGraphNode *> resolvedRefs;
- walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
-}
-
-void CGUseList::eraseNode(CallGraphNode *node) {
- // Drop all child nodes.
- for (auto &edge : *node)
- if (edge.isChild())
- eraseNode(edge.getTarget());
-
- // Drop the uses held by this node and erase it.
- auto useIt = nodeUses.find(node);
- assert(useIt != nodeUses.end() && "expected node to be valid");
- decrementDiscardableUses(useIt->getSecond());
- nodeUses.erase(useIt);
- discardableSymNodeUses.erase(node);
-}
-
-bool CGUseList::isDead(CallGraphNode *node) const {
- // If the parent operation isn't a symbol, simply check normal SSA deadness.
- Operation *nodeOp = node->getCallableRegion()->getParentOp();
- if (!isa<SymbolOpInterface>(nodeOp))
- return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
-
- // Otherwise, check the number of symbol uses.
- auto symbolIt = discardableSymNodeUses.find(node);
- return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
-}
-
-bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
- // If this isn't a symbol node, check for side-effects and SSA use count.
- Operation *nodeOp = node->getCallableRegion()->getParentOp();
- if (!isa<SymbolOpInterface>(nodeOp))
- return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
-
- // Otherwise, check the number of symbol uses.
- auto symbolIt = discardableSymNodeUses.find(node);
- return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
-}
-
-void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
- Operation *parentOp = node->getCallableRegion()->getParentOp();
- CGUser &uses = nodeUses[node];
- decrementDiscardableUses(uses);
-
- // Collect the new discardable uses within this node.
- uses = CGUser();
- DenseMap<Attribute, CallGraphNode *> resolvedRefs;
- auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
- auto discardSymIt = discardableSymNodeUses.find(refNode);
- if (discardSymIt == discardableSymNodeUses.end())
- return;
-
- if (user != parentOp)
- ++uses.innerUses[refNode];
- else if (!uses.topLevelUses.insert(refNode).second)
- return;
- ++discardSymIt->second;
- };
- walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
-}
-
-void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
- auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
- for (auto &useIt : lhsUses.innerUses) {
- rhsUses.innerUses[useIt.first] += useIt.second;
- discardableSymNodeUses[useIt.first] += useIt.second;
- }
-}
-
-void CGUseList::decrementDiscardableUses(CGUser &uses) {
- for (CallGraphNode *node : uses.topLevelUses)
- --discardableSymNodeUses[node];
- for (auto &it : uses.innerUses)
- discardableSymNodeUses[it.first] -= it.second;
-}
-
//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//
@@ -396,18 +179,18 @@ struct InlinerInterfaceImpl : public InlinerInterface {
/*traverseNestedCGNodes=*/true);
}
- /// Mark the given callgraph node for deletion.
- void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
+ /// Mark the given callable for deletion.
+ void markForDeletion(Operation *callable) { deadCallables.insert(callable); }
/// This method properly disposes of callables that became dead during
/// inlining. This should not be called while iterating over the SCCs.
void eraseDeadCallables() {
- for (CallGraphNode *node : deadNodes)
- node->getCallableRegion()->getParentOp()->erase();
+ for (Operation *callable : deadCallables)
+ callable->erase();
}
/// The set of callables known to be dead.
- SmallPtrSet<CallGraphNode *, 8> deadNodes;
+ SmallPtrSet<Operation *, 8> deadCallables;
/// The current set of call instructions to consider for inlining.
SmallVector<ResolvedCall, 8> calls;
@@ -431,15 +214,16 @@ class Inliner::Impl {
/// devirtualized calls. Returns failure if there was a fatal error during
/// inlining.
LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
- CGUseList &useList, CallGraphSCC ¤tSCC,
- MLIRContext *context);
+ CallGraphSCC ¤tSCC, MLIRContext *context);
+
+ void collectDeadNodeAfterInline(Operation *topLevelOp,
+ InlinerInterfaceImpl &inlinerIface);
private:
/// Optimize the nodes within the given SCC with one of the held optimization
/// pass pipelines. Returns failure if an error occurred during the
/// optimization of the SCC, success otherwise.
- LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
- CallGraphSCC ¤tSCC, MLIRContext *context);
+ LogicalResult optimizeSCC(CallGraphSCC ¤tSCC, MLIRContext *context);
/// Optimize the nodes within the given SCC in parallel. Returns failure if an
/// error occurred during the optimization of the SCC, success otherwise.
@@ -456,7 +240,7 @@ class Inliner::Impl {
/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
- CGUseList &useList, CallGraphSCC ¤tSCC);
+ CallGraphSCC ¤tSCC);
/// Returns true if the given call should be inlined.
bool shouldInline(ResolvedCall &resolvedCall);
@@ -467,7 +251,6 @@ class Inliner::Impl {
};
LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
- CGUseList &useList,
CallGraphSCC ¤tSCC,
MLIRContext *context) {
// Continuously simplify and inline until we either reach a fixed point, or
@@ -475,20 +258,20 @@ LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
// model, and in future iterations may devirtualize new calls.
unsigned iterationCount = 0;
do {
- if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
+ if (failed(optimizeSCC(currentSCC, context)))
return failure();
- if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
+ if (failed(inlineCallsInSCC(inlinerIface, currentSCC)))
break;
} while (++iterationCount < inliner.config.getMaxInliningIterations());
return success();
}
-LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
- CallGraphSCC ¤tSCC,
+LogicalResult Inliner::Impl::optimizeSCC(CallGraphSCC ¤tSCC,
MLIRContext *context) {
// Collect the sets of nodes to simplify.
SmallVector<CallGraphNode *, 4> nodesToVisit;
for (auto *node : currentSCC) {
+ assert(!node->isDead() && "unexcepted dead node in callgraph SCC");
if (node->isExternal())
continue;
@@ -512,9 +295,6 @@ LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
if (failed(optimizeSCCAsync(nodesToVisit, context)))
return failure();
- // Recompute the uses held by each of the nodes.
- for (CallGraphNode *node : nodesToVisit)
- useList.recomputeUses(node, cg);
return success();
}
@@ -583,7 +363,7 @@ Inliner::Impl::optimizeCallable(CallGraphNode *node,
/// success if any calls were inlined, failure otherwise.
LogicalResult
Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
- CGUseList &useList, CallGraphSCC ¤tSCC) {
+ CallGraphSCC ¤tSCC) {
CallGraph &cg = inlinerIface.cg;
auto &calls = inlinerIface.calls;
@@ -594,17 +374,13 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
// don't traverse nested callgraph nodes, because they are handled separately
// likely within a different SCC.
for (CallGraphNode *node : currentSCC) {
+ assert(!node->isDead() && "unexcepted dead node in callgraph SCC");
if (node->isExternal())
continue;
- // Don't collect calls if the node is already dead.
- if (useList.isDead(node)) {
- deadNodes.insert(node);
- } else {
- collectCallOps(*node->getCallableRegion(), node, cg,
- inlinerIface.symbolTable, calls,
- /*traverseNestedCGNodes=*/false);
- }
+ collectCallOps(*node->getCallableRegion(), node, cg,
+ inlinerIface.symbolTable, calls,
+ /*traverseNestedCGNodes=*/false);
}
// When inlining a callee produces new call sites, we want to keep track of
@@ -646,14 +422,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
unsigned prevSize = calls.size();
Region *targetRegion = it.targetNode->getCallableRegion();
- // If this is the last call to the target node and the node is discardable,
- // then inline it in-place and delete the node if successful.
- bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
-
LogicalResult inlineResult =
inlineCall(inlinerIface, call,
cast<CallableOpInterface>(targetRegion->getParentOp()),
- targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
+ targetRegion, /*shouldCloneInlinedRegion=*/true);
if (failed(inlineResult)) {
LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
continue;
@@ -683,24 +455,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
<< historyToString(inlineHistoryID) << "\n");
}
- // If the inlining was successful, Merge the new uses into the source node.
- useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
- useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
-
// then erase the call.
call.erase();
-
- // If we inlined in place, mark the node for deletion.
- if (inlineInPlace) {
- useList.eraseNode(it.targetNode);
- deadNodes.insert(it.targetNode);
- }
}
- for (CallGraphNode *node : deadNodes) {
- currentSCC.remove(node);
- inlinerIface.markForDeletion(node);
- }
calls.clear();
return success(inlinedAnyCalls);
}
@@ -748,6 +506,28 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
return true;
}
+/// Iteratively clean up dead nodes until no change happened.
+void Inliner::Impl::collectDeadNodeAfterInline(
+ Operation *topLevelOp, InlinerInterfaceImpl &inlinerIface) {
+ auto newCallGraph = CallGraph(topLevelOp);
+ SmallVector<CallGraphNode *, 10> deadNodes;
+ while (1) {
+ deadNodes.clear();
+ for (CallGraphNode *node : newCallGraph) {
+ if (node->isDead())
+ deadNodes.push_back(node);
+ }
+
+ for (auto &node : deadNodes) {
+ inlinerIface.markForDeletion(node->getCallableRegion()->getParentOp());
+ newCallGraph.eraseNode(node);
+ }
+
+ if (deadNodes.empty())
+ break;
+ }
+}
+
LogicalResult Inliner::doInlining() {
Impl impl(*this);
auto *context = op->getContext();
@@ -757,14 +537,14 @@ LogicalResult Inliner::doInlining() {
- // of the Impl's methods, if the inlinerIface and useList
- // become the states of the Impl.
+ // of the Impl's methods, if the inlinerIface becomes
+ // a state of the Impl.
InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
- CGUseList useList(op, cg, symbolTable);
LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
- return impl.inlineSCC(inlinerIface, useList, scc, context);
+ return impl.inlineSCC(inlinerIface, scc, context);
});
if (failed(result))
return result;
// After inlining, make sure to erase any callables proven to be dead.
+ impl.collectDeadNodeAfterInline(op, inlinerIface);
inlinerIface.eraseDeadCallables();
return success();
}
diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir
index f6c9ff5006e053..8d47003a4b4d68 100644
--- a/mlir/test/Analysis/test-callgraph.mlir
+++ b/mlir/test/Analysis/test-callgraph.mlir
@@ -8,30 +8,32 @@ module attributes {test.name = "simple"} {
return
}
+ // CHECK-NOT: Node{{.*}}func_b
func.func private @func_b()
- // CHECK: Node{{.*}}func_c
+ // CHECK: Node{{.*}}func_c{{.*}}private
// CHECK-NEXT: Call-Edge{{.*}}Unknown-Callee-Node
- func.func @func_c() {
+ func.func private @func_c() {
call @func_b() : () -> ()
return
}
// CHECK: Node{{.*}}func_d
- // CHECK-NEXT: Call-Edge{{.*}}func_c
+ // CHECK-NEXT: Call-Edge{{.*}}func_c{{.*}}private
func.func @func_d() {
call @func_c() : () -> ()
return
}
// CHECK: Node{{.*}}func_e
- // CHECK-DAG: Call-Edge{{.*}}func_c
+ // CHECK-DAG: Call-Edge{{.*}}func_c{{.*}}private
// CHECK-DAG: Call-Edge{{.*}}func_d
// CHECK-DAG: Call-Edge{{.*}}func_e
+ // CHECK-DAG: Ref-Edge{{.*}}func_a
func.func @func_e() {
call @func_c() : () -> ()
call @func_d() : () -> ()
- call @func_e() : () -> ()
+ call @func_e() { use = @func_a } : () -> ()
return
}
@@ -49,6 +51,39 @@ module attributes {test.name = "simple"} {
call_indirect %fn() : () -> ()
return
}
+
+ // CHECK: Node{{.*}}func_g
+ // CHECK: Ref-Edge{{.*}}func_c
+ // CHECK: Call-Edge{{.*}}Unknown-Callee-Node
+ func.func @func_g() -> (() -> ()) {
+ // A private symbol may escape through the returned function value.
+ %0 = func.constant @func_c : () -> ()
+ call_indirect %0() : () -> ()
+ return %0 : () -> ()
+ }
+
+ // CHECK: Node{{.*}}func_h{{.*}}private
+ func.func private @func_h() {
+ return
+ }
+
+ // References to symbol declarations are ignored.
+ "live.user"() { use = @func_b } : () -> ()
+ // A non-callable top-level operation references a callable symbol.
+ "live.user"() { use = @func_c } : () -> ()
+ func.call @func_h() : () -> ()
+
+ // CHECK: Node{{.*}}External-Caller-Node
+ // CHECK-NEXT: Ref-Edge{{.*}}func_a
+ // CHECK-NEXT: Ref-Edge{{.*}}func_d
+ // CHECK-NEXT: Ref-Edge{{.*}}func_e
+ // CHECK-NEXT: Ref-Edge{{.*}}func_f
+ // CHECK-NEXT: Ref-Edge{{.*}}func_g
+ // CHECK-NOT: Ref-Edge{{.*}}func_b
+ // CHECK-NEXT: Ref-Edge{{.*}}func_c{{.*}}private
+ // CHECK-NEXT: Ref-Edge{{.*}}func_h{{.*}}private
+
+ // CHECK: Node{{.*}}Unknown-Callee-Node
}
// -----
@@ -56,18 +91,30 @@ module attributes {test.name = "simple"} {
// CHECK-LABEL: Testing : "nested"
module attributes {test.name = "nested"} {
module @nested_module {
- // CHECK: Node{{.*}}func_a
- func.func @func_a() {
+ // CHECK: Node{{.*}}func_a{{.*}}nested
+ func.func nested @func_a() {
+ return
+ }
+ // CHECK: Node{{.*}}func_b{{.*}}nested
+ func.func nested @func_b() {
return
}
}
- // CHECK: Node{{.*}}func_b
- // CHECK: Call-Edge{{.*}}func_a
- func.func @func_b() {
- "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> ()
+ // CHECK: Node{{.*}}func_c
+ // CHECK: Call-Edge{{.*}}func_a{{.*}}nested
+ // CHECK: Ref-Edge{{.*}}func_b{{.*}}nested
+ func.func @func_c() {
+ "test.conversion_call_op"() { use = @nested_module::@func_b, callee = @nested_module::@func_a } : () -> ()
return
}
+
+ // CHECK: Node{{.*}}External-Caller-Node
+ // CHECK-NEXT: Ref-Edge{{.*}}func_c
+ // CHECK-NOT: Ref-Edge{{.*}}func_a
+ // CHECK-NOT: Ref-Edge{{.*}}func_b
+
+ // CHECK: Node{{.*}}Unknown-Callee-Node
}
// -----
@@ -95,4 +142,3 @@ module attributes {test.name = "SCC"} {
// CHECK: SCC :
// CHECK-NEXT: Node{{.*}}External-Caller-Node
}
-
diff --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir
index d167c1b4baae98..c31b909903c3c2 100644
--- a/mlir/test/Transforms/inlining-dce.mlir
+++ b/mlir/test/Transforms/inlining-dce.mlir
@@ -10,7 +10,7 @@ func.func private @dead_function() {
// Function becomes dead after inlining.
// CHECK-NOT: func private @dead_function_b
-func.func @dead_function_b() {
+func.func private @dead_function_b() {
return
}
@@ -44,6 +44,35 @@ func.func @live_function_c() {
return
}
+// A transitive example in which no function of the chain is called by a live function.
+
+// CHECK-NOT: func private @dead_function_e
+func.func private @dead_function_e() {
+ call @live_function_b() : () -> ()
+ return
+}
+
+// CHECK-NOT: func private @dead_function_f
+func.func private @dead_function_f() {
+ call @dead_function_e() : () -> ()
+ return
+}
+
+// A function referenced only by a function constant is expected to become dead once the indirect call through it is inlined.
+
+// CHECK-NOT: func private @dead_function_h
+func.func private @dead_function_h() {
+ call @live_function_b() : () -> ()
+ return
+}
+
+// CHECK: func @live_function_f
+func.func @live_function_f() {
+ %0 = func.constant @dead_function_h : () -> ()
+ call_indirect %0() : () -> ()
+ return
+}
+
// Function is referenced by non-callable top-level user.
// CHECK: func private @live_function_d
func.func private @live_function_d() {
More information about the Mlir-commits
mailing list