[Mlir-commits] [mlir] [mlir][CallGraph] Add edges for callable symbol references in CallGraph (PR #116177)

Haocong Lu llvmlistbot at llvm.org
Wed Dec 4 00:05:58 PST 2024


https://github.com/Luhaocong updated https://github.com/llvm/llvm-project/pull/116177

>From 53901c1425f88c7fd6481249f5f6c814ea41d69b Mon Sep 17 00:00:00 2001
From: Lu Haocong <haocong.lu at evas.ai>
Date: Thu, 14 Nov 2024 15:11:40 +0800
Subject: [PATCH] [mlir][CallGraph] Add edges for callable symbol references in
 CallGraph

This patch introduces a reference edge to represent callable symbol
references in the `CallGraph` analysis. Reference edges exist whenever
any operation references a callable node. Any callable symbol that might
be referenced externally is also connected with a reference edge from
the entry node. The callgraph thus represents both direct calls and any
potential calls that might be formed through static optimization.

This implementation borrows many design concepts from `LazyCallGraph`
in LLVM: bf71a34eb9e25c6080d5058a553dbcb75676ff95. The difference is
that MLIR retains the entry and exit nodes, and adds a reference counter
to each callgraph node, which makes it easy to find dead nodes.

This patch also improves the readability of the printed callgraph: it
now dumps all nodes in the callgraph (including the entry and exit
nodes), along with the number of references to each node.
---
 mlir/include/mlir/Analysis/CallGraph.h |  66 ++++--
 mlir/lib/Analysis/CallGraph.cpp        | 143 ++++++++----
 mlir/lib/Transforms/Utils/Inliner.cpp  | 310 ++++---------------------
 mlir/test/Analysis/test-callgraph.mlir |  70 +++++-
 mlir/test/Transforms/inlining-dce.mlir |  31 ++-
 5 files changed, 280 insertions(+), 340 deletions(-)

diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index 631cdd1ad22909..d5f6de2107f0ca 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -16,6 +16,7 @@
 #ifndef MLIR_ANALYSIS_CALLGRAPH_H
 #define MLIR_ANALYSIS_CALLGRAPH_H
 
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/MapVector.h"
@@ -42,11 +43,12 @@ class CallGraphNode {
   /// This class represents a directed edge between two nodes in the callgraph.
   class Edge {
     enum class Kind {
-      // An 'Abstract' edge represents an opaque, non-operation, reference
-      // between this node and the target. Edges of this type are only valid
-      // from the external node, as there is no valid connection to an operation
-      // in the module.
-      Abstract,
+      // A 'Ref' edge represents a reference between this node and the target.
+      // Edges of this type exist whenever any operation references a callable
+      // node. A callable node which may be referenced externally is connected
+      // with a reference edge from the entry node. All 'Call' edges and 'Child'
+      // edges are inherently reference edges in the callgraph.
+      Ref,
 
       // A 'Call' edge represents a direct reference to the target node via a
       // call-like operation within the callable region of this node.
@@ -60,8 +62,8 @@ class CallGraphNode {
     };
 
   public:
-    /// Returns true if this edge represents an `Abstract` edge.
-    bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; }
+    /// Returns true if this edge represents a `Ref` edge.
+    bool isRef() const { return targetAndKind.getInt() == Kind::Ref; }
 
     /// Returns true if this edge represents a `Call` edge.
     bool isCall() const { return targetAndKind.getInt() == Kind::Call; }
@@ -95,10 +97,9 @@ class CallGraphNode {
   /// on non-external nodes.
   Region *getCallableRegion() const;
 
-  /// Adds an abstract reference edge to the given node. An abstract edge does
-  /// not come from any observable operations, so this is only valid on the
-  /// external node.
-  void addAbstractEdge(CallGraphNode *node);
+  /// Adds a reference edge to the given node. A reference comes from
+  /// external node or any observable operations.
+  void addRefEdge(CallGraphNode *node);
 
   /// Add an outgoing call edge from this node.
   void addCallEdge(CallGraphNode *node);
@@ -106,6 +107,12 @@ class CallGraphNode {
   /// Adds a reference edge to the given child node.
   void addChildEdge(CallGraphNode *child);
 
+  /// Remove edges to the target node.
+  void removeEdgesTo(CallGraphNode *target);
+
+  /// Remove all edges for this callgraph node.
+  void removeAllEdges();
+
   /// Iterator over the outgoing edges of this node.
   using iterator = SmallVectorImpl<Edge>::const_iterator;
   iterator begin() const { return edges.begin(); }
@@ -114,6 +121,13 @@ class CallGraphNode {
   /// Returns true if this node has any child edges.
   bool hasChildren() const;
 
+  /// Return true if this node is dead.
+  bool isDead() const { return !isExternal() && getNumReferences() == 0; }
+
+  /// Returns the number of edges in this callgraph that are connected
+  /// to this node.
+  unsigned getNumReferences() const { return NumReferences; }
+
 private:
   /// DenseMap info for callgraph edges.
   struct EdgeKeyInfo {
@@ -143,6 +157,17 @@ class CallGraphNode {
             llvm::SmallDenseSet<Edge, 4, EdgeKeyInfo>>
       edges;
 
+  /// The number of edges in this callgraph that connected to this
+  /// callgraph node.
+  unsigned NumReferences = 0;
+
+  /// Decrease reference counter when an edge connected to this callgraph node
+  /// is removed.
+  void dropRef() { --NumReferences; }
+
+  /// Add reference counter when a new edge is connected to this callgraph node.
+  void addRef() { ++NumReferences; }
+
   // Provide access to private methods.
   friend class CallGraph;
 };
@@ -184,13 +209,13 @@ class CallGraph {
   /// registered.
   CallGraphNode *lookupNode(Region *region) const;
 
-  /// Return the callgraph node representing an external caller.
-  CallGraphNode *getExternalCallerNode() const {
+  /// Return the callgraph node representing the entry node of this callgraph.
+  CallGraphNode *getEntryNode() const {
     return const_cast<CallGraphNode *>(&externalCallerNode);
   }
 
-  /// Return the callgraph node representing an indirect callee.
-  CallGraphNode *getUnknownCalleeNode() const {
+  /// Return the callgraph node representing the exit node of this callgraph.
+  CallGraphNode *getExitNode() const {
     return const_cast<CallGraphNode *>(&unknownCalleeNode);
   }
 
@@ -201,6 +226,11 @@ class CallGraph {
   CallGraphNode *resolveCallable(CallOpInterface call,
                                  SymbolTableCollection &symbolTable) const;
 
+  /// Resolve the callable for given symbol to a node in the callgraph, or the
+  /// external node if a valid node was not resolved.
+  CallGraphNode *resolveCallable(Operation *op, SymbolRefAttr symbolRef,
+                                 SymbolTableCollection &symbolTable) const;
+
   /// Erase the given node from the callgraph.
   void eraseNode(CallGraphNode *node);
 
@@ -217,10 +247,10 @@ class CallGraph {
   /// The set of nodes within the callgraph.
   NodeMapT nodes;
 
-  /// A special node used to indicate an external caller.
+  /// A special node used to indicate an entry callgraph node.
   CallGraphNode externalCallerNode;
 
-  /// A special node used to indicate an unknown callee.
+  /// A special node used to indicate an exit callgraph node.
   CallGraphNode unknownCalleeNode;
 };
 
@@ -254,7 +284,7 @@ struct GraphTraits<const mlir::CallGraph *>
     : public GraphTraits<const mlir::CallGraphNode *> {
   /// The entry node into the graph is the external node.
   static NodeRef getEntryNode(const mlir::CallGraph *cg) {
-    return cg->getExternalCallerNode();
+    return cg->getEntryNode();
   }
 
   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 780c7caee767c1..559c3c733b3df4 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -38,11 +38,9 @@ Region *CallGraphNode::getCallableRegion() const {
   return callableRegion;
 }
 
-/// Adds an reference edge to the given node. This is only valid on the
-/// external node.
-void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
-  assert(isExternal() && "abstract edges are only valid on external nodes");
-  addEdge(node, Edge::Kind::Abstract);
+/// Adds a reference edge to the given node.
+void CallGraphNode::addRefEdge(CallGraphNode *node) {
+  addEdge(node, Edge::Kind::Ref);
 }
 
 /// Add an outgoing call edge from this node.
@@ -60,9 +58,28 @@ bool CallGraphNode::hasChildren() const {
   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
 }
 
+/// Remove edges to the target callgraph node.
+void CallGraphNode::removeEdgesTo(CallGraphNode *target) {
+  edges.remove_if([target](const CallGraphNode::Edge &edge) {
+    if (edge.getTarget() != target)
+      return false;
+    target->dropRef();
+    return true;
+  });
+}
+
+/// Remove all edges for this callgraph node.
+void CallGraphNode::removeAllEdges() {
+  edges.remove_if([](const CallGraphNode::Edge &edge) {
+    edge.getTarget()->dropRef();
+    return true;
+  });
+}
+
 /// Add an edge to 'node' with the given kind.
 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
-  edges.insert({node, kind});
+  if (edges.insert({node, kind}))
+    node->addRef();
 }
 
 //===----------------------------------------------------------------------===//
@@ -73,14 +90,31 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
 /// edges are placed into the given callgraph object.
 static void computeCallGraph(Operation *op, CallGraph &cg,
                              SymbolTableCollection &symbolTable,
-                             CallGraphNode *parentNode, bool resolveCalls) {
-  if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
-    // If there is no parent node, we ignore this operation. Even if this
-    // operation was a call, there would be no callgraph node to attribute it
-    // to.
-    if (resolveCalls && parentNode)
-      parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
-    return;
+                             CallGraphNode *parentNode, bool resolveSymbolRef) {
+  if (resolveSymbolRef) {
+    bool skipCalleeSymbolRef = false;
+    if (auto call = dyn_cast<CallOpInterface>(op)) {
+      // If there is no parent node, there would be no callgraph node to call it
+      // directly, we use a reference edge to represent this call.
+      if (!parentNode->isExternal()) {
+        parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
+        if (isa<SymbolRefAttr>(call.getCallableForCallee()))
+          skipCalleeSymbolRef = true;
+      }
+    }
+    // If an operation references a callable symbol, create a reference edge.
+    op->getAttrDictionary().walk<WalkOrder::PreOrder>(
+        [&](SymbolRefAttr symbolRef) {
+          // Skip the first symbol reference if it is a resolved callee.
+          if (skipCalleeSymbolRef) {
+            skipCalleeSymbolRef = false;
+            return WalkResult::skip();
+          }
+          auto *node = cg.resolveCallable(op, symbolRef, symbolTable);
+          if (!node->isExternal())
+            parentNode->addRefEdge(node);
+          return WalkResult::skip();
+        });
   }
 
   // Compute the callgraph nodes and edges for each of the nested operations.
@@ -93,42 +127,41 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
 
   for (Region &region : op->getRegions())
     for (Operation &nested : region.getOps())
-      computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
+      computeCallGraph(&nested, cg, symbolTable, parentNode, resolveSymbolRef);
 }
 
 CallGraph::CallGraph(Operation *op)
     : externalCallerNode(/*callableRegion=*/nullptr),
       unknownCalleeNode(/*callableRegion=*/nullptr) {
   // Make two passes over the graph, one to compute the callables and one to
-  // resolve the calls. We split these up as we may have nested callable objects
-  // that need to be reserved before the calls.
+  // resolve the calls and symbol references. We split these up as we may have
+  // nested callable objects that need to be reserved before the calls.
   SymbolTableCollection symbolTable;
-  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
-                   /*resolveCalls=*/false);
-  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
-                   /*resolveCalls=*/true);
+  computeCallGraph(op, *this, symbolTable, getEntryNode(),
+                   /*resolveSymbolRef=*/false);
+  computeCallGraph(op, *this, symbolTable, getEntryNode(),
+                   /*resolveSymbolRef=*/true);
 }
 
 /// Get or add a call graph node for the given region.
 CallGraphNode *CallGraph::getOrAddNode(Region *region,
                                        CallGraphNode *parentNode) {
-  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
+  Operation *parentOp = region->getParentOp();
+  assert(region && isa<CallableOpInterface>(parentOp) &&
          "expected parent operation to be callable");
   std::unique_ptr<CallGraphNode> &node = nodes[region];
   if (!node) {
     node.reset(new CallGraphNode(region));
 
     // Add this node to the given parent node if necessary.
-    if (parentNode) {
+    assert(parentNode && "expected non-empty parent node");
+    if (!parentNode->isExternal()) {
       parentNode->addChildEdge(node.get());
-    } else {
-      // Otherwise, connect all callable nodes to the external node, this allows
-      // for conservatively including all callable nodes within the graph.
-      // FIXME This isn't correct, this is only necessary for callable nodes
-      // that *could* be called from external sources. This requires extending
-      // the interface for callables to check if they may be referenced
-      // externally.
-      externalCallerNode.addAbstractEdge(node.get());
+    } else if (auto symbolOp = dyn_cast<SymbolOpInterface>(parentOp)) {
+      // Otherwise, connect all symbol nodes with public visibility
+      // to the entry node, which may be referenced externally.
+      if (!symbolOp.canDiscardOnUseEmpty())
+        parentNode->addRefEdge(node.get());
     }
   }
   return node.get();
@@ -151,7 +184,21 @@ CallGraph::resolveCallable(CallOpInterface call,
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;
 
-  return getUnknownCalleeNode();
+  return getExitNode();
+}
+
+/// Resolve the callable for given symbol to a node in the callgraph, or the
+/// external node if a valid node was not resolved.
+CallGraphNode *
+CallGraph::resolveCallable(Operation *op, SymbolRefAttr symbolRef,
+                           SymbolTableCollection &symbolTable) const {
+  auto *symbolOp = symbolTable.lookupNearestSymbolFrom(op, symbolRef);
+  if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp)) {
+    if (auto *node = lookupNode(callableOp.getCallableRegion()))
+      return node;
+  }
+
+  return getExitNode();
 }
 
 /// Erase the given node from the callgraph.
@@ -163,11 +210,12 @@ void CallGraph::eraseNode(CallGraphNode *node) {
         eraseNode(edge.getTarget());
   }
   // Erase any edges to this node from any other nodes.
-  for (auto &it : nodes) {
-    it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
-      return edge.getTarget() == node;
-    });
-  }
+  for (auto &it : nodes)
+    it.second->removeEdgesTo(node);
+  // Erase all edges from this node to any other nodes.
+  node->removeAllEdges();
+
+  assert(node->isDead() && "expected no references");
   nodes.erase(node->getCallableRegion());
 }
 
@@ -181,11 +229,11 @@ void CallGraph::print(raw_ostream &os) const {
 
   // Functor used to output the name for the given node.
   auto emitNodeName = [&](const CallGraphNode *node) {
-    if (node == getExternalCallerNode()) {
+    if (node == getEntryNode()) {
       os << "<External-Caller-Node>";
       return;
     }
-    if (node == getUnknownCalleeNode()) {
+    if (node == getExitNode()) {
       os << "<Unknown-Callee-Node>";
       return;
     }
@@ -199,13 +247,12 @@ void CallGraph::print(raw_ostream &os) const {
       os << " : " << attrs;
   };
 
-  for (auto &nodeIt : nodes) {
-    const CallGraphNode *node = nodeIt.second.get();
-
+  // Functor used to emit the given node and edges.
+  auto emitNodeAndEdge = [&](const CallGraphNode *node) {
     // Dump the header for this node.
     os << "// - Node : ";
     emitNodeName(node);
-    os << "\n";
+    os << " : References = " << node->getNumReferences() << "\n";
 
     // Emit each of the edges.
     for (auto &edge : *node) {
@@ -214,13 +261,21 @@ void CallGraph::print(raw_ostream &os) const {
         os << "Call";
       else if (edge.isChild())
         os << "Child";
+      else if (edge.isRef())
+        os << "Ref";
 
       os << "-Edge : ";
       emitNodeName(edge.getTarget());
       os << "\n";
     }
     os << "//\n";
-  }
+  };
+
+  // Emit all graph nodes including entry and exit node.
+  for (auto &nodeIt : nodes)
+    emitNodeAndEdge(nodeIt.second.get());
+  emitNodeAndEdge(getEntryNode());
+  emitNodeAndEdge(getExitNode());
 
   os << "// -- SCCs --\n";
 
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index 8acfc96d2b611b..86348e8ad6741f 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -31,223 +31,6 @@ using namespace mlir;
 
 using ResolvedCall = Inliner::ResolvedCall;
 
-//===----------------------------------------------------------------------===//
-// Symbol Use Tracking
-//===----------------------------------------------------------------------===//
-
-/// Walk all of the used symbol callgraph nodes referenced with the given op.
-static void walkReferencedSymbolNodes(
-    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
-    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
-    function_ref<void(CallGraphNode *, Operation *)> callback) {
-  auto symbolUses = SymbolTable::getSymbolUses(op);
-  assert(symbolUses && "expected uses to be valid");
-
-  Operation *symbolTableOp = op->getParentOp();
-  for (const SymbolTable::SymbolUse &use : *symbolUses) {
-    auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
-    CallGraphNode *&node = refIt.first->second;
-
-    // If this is the first instance of this reference, try to resolve a
-    // callgraph node for it.
-    if (refIt.second) {
-      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
-                                                           use.getSymbolRef());
-      auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
-      if (!callableOp)
-        continue;
-      node = cg.lookupNode(callableOp.getCallableRegion());
-    }
-    if (node)
-      callback(node, use.getUser());
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// CGUseList
-
-namespace {
-/// This struct tracks the uses of callgraph nodes that can be dropped when
-/// use_empty. It directly tracks and manages a use-list for all of the
-/// call-graph nodes. This is necessary because many callgraph nodes are
-/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
-/// class.
-struct CGUseList {
-  /// This struct tracks the uses of callgraph nodes within a specific
-  /// operation.
-  struct CGUser {
-    /// Any nodes referenced in the top-level attribute list of this user. We
-    /// use a set here because the number of references does not matter.
-    DenseSet<CallGraphNode *> topLevelUses;
-
-    /// Uses of nodes referenced by nested operations.
-    DenseMap<CallGraphNode *, int> innerUses;
-  };
-
-  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
-
-  /// Drop uses of nodes referred to by the given call operation that resides
-  /// within 'userNode'.
-  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
-
-  /// Remove the given node from the use list.
-  void eraseNode(CallGraphNode *node);
-
-  /// Returns true if the given callgraph node has no uses and can be pruned.
-  bool isDead(CallGraphNode *node) const;
-
-  /// Returns true if the given callgraph node has a single use and can be
-  /// discarded.
-  bool hasOneUseAndDiscardable(CallGraphNode *node) const;
-
-  /// Recompute the uses held by the given callgraph node.
-  void recomputeUses(CallGraphNode *node, CallGraph &cg);
-
-  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
-  /// of 'lhs' into 'rhs'.
-  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
-
-private:
-  /// Decrement the uses of discardable nodes referenced by the given user.
-  void decrementDiscardableUses(CGUser &uses);
-
-  /// A mapping between a discardable callgraph node (that is a symbol) and the
-  /// number of uses for this node.
-  DenseMap<CallGraphNode *, int> discardableSymNodeUses;
-
-  /// A mapping between a callgraph node and the symbol callgraph nodes that it
-  /// uses.
-  DenseMap<CallGraphNode *, CGUser> nodeUses;
-
-  /// A symbol table to use when resolving call lookups.
-  SymbolTableCollection &symbolTable;
-};
-} // namespace
-
-CGUseList::CGUseList(Operation *op, CallGraph &cg,
-                     SymbolTableCollection &symbolTable)
-    : symbolTable(symbolTable) {
-  /// A set of callgraph nodes that are always known to be live during inlining.
-  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
-
-  // Walk each of the symbol tables looking for discardable callgraph nodes.
-  auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
-    for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
-      // If this is a callgraph operation, check to see if it is discardable.
-      if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
-        if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
-          SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
-          if (symbol && (allUsesVisible || symbol.isPrivate()) &&
-              symbol.canDiscardOnUseEmpty()) {
-            discardableSymNodeUses.try_emplace(node, 0);
-          }
-          continue;
-        }
-      }
-      // Otherwise, check for any referenced nodes. These will be always-live.
-      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
-                                [](CallGraphNode *, Operation *) {});
-    }
-  };
-  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
-                                walkFn);
-
-  // Drop the use information for any discardable nodes that are always live.
-  for (auto &it : alwaysLiveNodes)
-    discardableSymNodeUses.erase(it.second);
-
-  // Compute the uses for each of the callable nodes in the graph.
-  for (CallGraphNode *node : cg)
-    recomputeUses(node, cg);
-}
-
-void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
-                             CallGraph &cg) {
-  auto &userRefs = nodeUses[userNode].innerUses;
-  auto walkFn = [&](CallGraphNode *node, Operation *user) {
-    auto parentIt = userRefs.find(node);
-    if (parentIt == userRefs.end())
-      return;
-    --parentIt->second;
-    --discardableSymNodeUses[node];
-  };
-  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
-  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
-}
-
-void CGUseList::eraseNode(CallGraphNode *node) {
-  // Drop all child nodes.
-  for (auto &edge : *node)
-    if (edge.isChild())
-      eraseNode(edge.getTarget());
-
-  // Drop the uses held by this node and erase it.
-  auto useIt = nodeUses.find(node);
-  assert(useIt != nodeUses.end() && "expected node to be valid");
-  decrementDiscardableUses(useIt->getSecond());
-  nodeUses.erase(useIt);
-  discardableSymNodeUses.erase(node);
-}
-
-bool CGUseList::isDead(CallGraphNode *node) const {
-  // If the parent operation isn't a symbol, simply check normal SSA deadness.
-  Operation *nodeOp = node->getCallableRegion()->getParentOp();
-  if (!isa<SymbolOpInterface>(nodeOp))
-    return isMemoryEffectFree(nodeOp) && nodeOp->use_empty();
-
-  // Otherwise, check the number of symbol uses.
-  auto symbolIt = discardableSymNodeUses.find(node);
-  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
-}
-
-bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
-  // If this isn't a symbol node, check for side-effects and SSA use count.
-  Operation *nodeOp = node->getCallableRegion()->getParentOp();
-  if (!isa<SymbolOpInterface>(nodeOp))
-    return isMemoryEffectFree(nodeOp) && nodeOp->hasOneUse();
-
-  // Otherwise, check the number of symbol uses.
-  auto symbolIt = discardableSymNodeUses.find(node);
-  return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
-}
-
-void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
-  Operation *parentOp = node->getCallableRegion()->getParentOp();
-  CGUser &uses = nodeUses[node];
-  decrementDiscardableUses(uses);
-
-  // Collect the new discardable uses within this node.
-  uses = CGUser();
-  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
-  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
-    auto discardSymIt = discardableSymNodeUses.find(refNode);
-    if (discardSymIt == discardableSymNodeUses.end())
-      return;
-
-    if (user != parentOp)
-      ++uses.innerUses[refNode];
-    else if (!uses.topLevelUses.insert(refNode).second)
-      return;
-    ++discardSymIt->second;
-  };
-  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
-}
-
-void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
-  auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
-  for (auto &useIt : lhsUses.innerUses) {
-    rhsUses.innerUses[useIt.first] += useIt.second;
-    discardableSymNodeUses[useIt.first] += useIt.second;
-  }
-}
-
-void CGUseList::decrementDiscardableUses(CGUser &uses) {
-  for (CallGraphNode *node : uses.topLevelUses)
-    --discardableSymNodeUses[node];
-  for (auto &it : uses.innerUses)
-    discardableSymNodeUses[it.first] -= it.second;
-}
-
 //===----------------------------------------------------------------------===//
 // CallGraph traversal
 //===----------------------------------------------------------------------===//
@@ -396,18 +179,18 @@ struct InlinerInterfaceImpl : public InlinerInterface {
                    /*traverseNestedCGNodes=*/true);
   }
 
-  /// Mark the given callgraph node for deletion.
-  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
+  /// Mark the given callable for deletion.
+  void markForDeletion(Operation *callable) { deadCallables.insert(callable); }
 
   /// This method properly disposes of callables that became dead during
   /// inlining. This should not be called while iterating over the SCCs.
   void eraseDeadCallables() {
-    for (CallGraphNode *node : deadNodes)
-      node->getCallableRegion()->getParentOp()->erase();
+    for (Operation *callable : deadCallables)
+      callable->erase();
   }
 
   /// The set of callables known to be dead.
-  SmallPtrSet<CallGraphNode *, 8> deadNodes;
+  SmallPtrSet<Operation *, 8> deadCallables;
 
   /// The current set of call instructions to consider for inlining.
   SmallVector<ResolvedCall, 8> calls;
@@ -431,15 +214,16 @@ class Inliner::Impl {
   /// devirtualized calls. Returns failure if there was a fatal error during
   /// inlining.
   LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
-                          CGUseList &useList, CallGraphSCC &currentSCC,
-                          MLIRContext *context);
+                          CallGraphSCC &currentSCC, MLIRContext *context);
+
+  void collectDeadNodeAfterInline(Operation *topLevelOp,
+                                  InlinerInterfaceImpl &inlinerIface);
 
 private:
   /// Optimize the nodes within the given SCC with one of the held optimization
   /// pass pipelines. Returns failure if an error occurred during the
   /// optimization of the SCC, success otherwise.
-  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
-                            CallGraphSCC &currentSCC, MLIRContext *context);
+  LogicalResult optimizeSCC(CallGraphSCC &currentSCC, MLIRContext *context);
 
   /// Optimize the nodes within the given SCC in parallel. Returns failure if an
   /// error occurred during the optimization of the SCC, success otherwise.
@@ -456,7 +240,7 @@ class Inliner::Impl {
   /// Attempt to inline calls within the given scc. This function returns
   /// success if any calls were inlined, failure otherwise.
   LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
-                                 CGUseList &useList, CallGraphSCC &currentSCC);
+                                 CallGraphSCC &currentSCC);
 
   /// Returns true if the given call should be inlined.
   bool shouldInline(ResolvedCall &resolvedCall);
@@ -467,7 +251,6 @@ class Inliner::Impl {
 };
 
 LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
-                                       CGUseList &useList,
                                        CallGraphSCC &currentSCC,
                                        MLIRContext *context) {
   // Continuously simplify and inline until we either reach a fixed point, or
@@ -475,20 +258,20 @@ LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
   // model, and in future iterations may devirtualize new calls.
   unsigned iterationCount = 0;
   do {
-    if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
+    if (failed(optimizeSCC(currentSCC, context)))
       return failure();
-    if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
+    if (failed(inlineCallsInSCC(inlinerIface, currentSCC)))
       break;
   } while (++iterationCount < inliner.config.getMaxInliningIterations());
   return success();
 }
 
-LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
-                                         CallGraphSCC &currentSCC,
+LogicalResult Inliner::Impl::optimizeSCC(CallGraphSCC &currentSCC,
                                          MLIRContext *context) {
   // Collect the sets of nodes to simplify.
   SmallVector<CallGraphNode *, 4> nodesToVisit;
   for (auto *node : currentSCC) {
+    assert(!node->isDead() && "unexcepted dead node in callgraph SCC");
     if (node->isExternal())
       continue;
 
@@ -512,9 +295,6 @@ LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
   if (failed(optimizeSCCAsync(nodesToVisit, context)))
     return failure();
 
-  // Recompute the uses held by each of the nodes.
-  for (CallGraphNode *node : nodesToVisit)
-    useList.recomputeUses(node, cg);
   return success();
 }
 
@@ -583,7 +363,7 @@ Inliner::Impl::optimizeCallable(CallGraphNode *node,
 /// success if any calls were inlined, failure otherwise.
 LogicalResult
 Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
-                                CGUseList &useList, CallGraphSCC &currentSCC) {
+                                CallGraphSCC &currentSCC) {
   CallGraph &cg = inlinerIface.cg;
   auto &calls = inlinerIface.calls;
 
@@ -594,17 +374,13 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
   // don't traverse nested callgraph nodes, because they are handled separately
   // likely within a different SCC.
   for (CallGraphNode *node : currentSCC) {
+    assert(!node->isDead() && "unexcepted dead node in callgraph SCC");
     if (node->isExternal())
       continue;
 
-    // Don't collect calls if the node is already dead.
-    if (useList.isDead(node)) {
-      deadNodes.insert(node);
-    } else {
-      collectCallOps(*node->getCallableRegion(), node, cg,
-                     inlinerIface.symbolTable, calls,
-                     /*traverseNestedCGNodes=*/false);
-    }
+    collectCallOps(*node->getCallableRegion(), node, cg,
+                   inlinerIface.symbolTable, calls,
+                   /*traverseNestedCGNodes=*/false);
   }
 
   // When inlining a callee produces new call sites, we want to keep track of
@@ -646,14 +422,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
     unsigned prevSize = calls.size();
     Region *targetRegion = it.targetNode->getCallableRegion();
 
-    // If this is the last call to the target node and the node is discardable,
-    // then inline it in-place and delete the node if successful.
-    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
-
     LogicalResult inlineResult =
         inlineCall(inlinerIface, call,
                    cast<CallableOpInterface>(targetRegion->getParentOp()),
-                   targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
+                   targetRegion, /*shouldCloneInlinedRegion=*/true);
     if (failed(inlineResult)) {
       LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
       continue;
@@ -683,24 +455,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                               << historyToString(inlineHistoryID) << "\n");
     }
 
-    // If the inlining was successful, Merge the new uses into the source node.
-    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
-    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
-
     // then erase the call.
     call.erase();
-
-    // If we inlined in place, mark the node for deletion.
-    if (inlineInPlace) {
-      useList.eraseNode(it.targetNode);
-      deadNodes.insert(it.targetNode);
-    }
   }
 
-  for (CallGraphNode *node : deadNodes) {
-    currentSCC.remove(node);
-    inlinerIface.markForDeletion(node);
-  }
   calls.clear();
   return success(inlinedAnyCalls);
 }
@@ -748,6 +506,28 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
   return true;
 }
 
+/// Iteratively clean up dead nodes until no more dead nodes are found.
+void Inliner::Impl::collectDeadNodeAfterInline(
+    Operation *topLevelOp, InlinerInterfaceImpl &inlinerIface) {
+  auto newCallGraph = CallGraph(topLevelOp);
+  SmallVector<CallGraphNode *, 10> deadNodes;
+  while (1) {
+    deadNodes.clear();
+    for (CallGraphNode *node : newCallGraph) {
+      if (node->isDead())
+        deadNodes.push_back(node);
+    }
+
+    for (auto &node : deadNodes) {
+      inlinerIface.markForDeletion(node->getCallableRegion()->getParentOp());
+      newCallGraph.eraseNode(node);
+    }
+
+    if (deadNodes.empty())
+      break;
+  }
+}
+
 LogicalResult Inliner::doInlining() {
   Impl impl(*this);
   auto *context = op->getContext();
@@ -757,14 +537,14 @@ LogicalResult Inliner::doInlining() {
   // of the Impl's methods, if the inlinerIface and useList
   // become the states of the Impl.
   InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
-  CGUseList useList(op, cg, symbolTable);
   LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
-    return impl.inlineSCC(inlinerIface, useList, scc, context);
+    return impl.inlineSCC(inlinerIface, scc, context);
   });
   if (failed(result))
     return result;
 
   // After inlining, make sure to erase any callables proven to be dead.
+  impl.collectDeadNodeAfterInline(op, inlinerIface);
   inlinerIface.eraseDeadCallables();
   return success();
 }
diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir
index f6c9ff5006e053..8d47003a4b4d68 100644
--- a/mlir/test/Analysis/test-callgraph.mlir
+++ b/mlir/test/Analysis/test-callgraph.mlir
@@ -8,30 +8,32 @@ module attributes {test.name = "simple"} {
     return
   }
 
+  // CHECK-NOT: Node{{.*}}func_b
   func.func private @func_b()
 
-  // CHECK: Node{{.*}}func_c
+  // CHECK: Node{{.*}}func_c{{.*}}private
   // CHECK-NEXT: Call-Edge{{.*}}Unknown-Callee-Node
-  func.func @func_c() {
+  func.func private @func_c() {
     call @func_b() : () -> ()
     return
   }
 
   // CHECK: Node{{.*}}func_d
-  // CHECK-NEXT: Call-Edge{{.*}}func_c
+  // CHECK-NEXT: Call-Edge{{.*}}func_c{{.*}}private
   func.func @func_d() {
     call @func_c() : () -> ()
     return
   }
 
   // CHECK: Node{{.*}}func_e
-  // CHECK-DAG: Call-Edge{{.*}}func_c
+  // CHECK-DAG: Call-Edge{{.*}}func_c{{.*}}private
   // CHECK-DAG: Call-Edge{{.*}}func_d
   // CHECK-DAG: Call-Edge{{.*}}func_e
+  // CHECK-DAG: Ref-Edge{{.*}}func_a
   func.func @func_e() {
     call @func_c() : () -> ()
     call @func_d() : () -> ()
-    call @func_e() : () -> ()
+    call @func_e() { use = @func_a } : () -> ()
     return
   }
 
@@ -49,6 +51,39 @@ module attributes {test.name = "simple"} {
     call_indirect %fn() : () -> ()
     return
   }
+
+  // CHECK: Node{{.*}}func_g
+  // CHECK: Ref-Edge{{.*}}func_c
+  // CHECK: Call-Edge{{.*}}Unknown-Callee-Node
+  func.func @func_g() -> (() -> ()) {
+    // A private symbol may escape.
+    %0 = func.constant @func_c : () -> ()
+    call_indirect %0() : () -> ()
+    return %0 : () -> ()
+  }
+
+  // CHECK: Node{{.*}}func_h{{.*}}private
+  func.func private @func_h() {
+    return
+  }
+
+  // Referenced symbol declarations are ignored.
+  "live.user"() { use = @func_b } : () -> ()
+  // A non-callable top-level operation references a callable symbol.
+  "live.user"() { use = @func_c } : () -> ()
+  func.call @func_h() : () -> ()
+
+  // CHECK: Node{{.*}}External-Caller-Node
+  // CHECK-NEXT: Ref-Edge{{.*}}func_a
+  // CHECK-NEXT: Ref-Edge{{.*}}func_d
+  // CHECK-NEXT: Ref-Edge{{.*}}func_e
+  // CHECK-NEXT: Ref-Edge{{.*}}func_f
+  // CHECK-NEXT: Ref-Edge{{.*}}func_g
+  // CHECK-NOT: Ref-Edge{{.*}}func_b
+  // CHECK-NEXT: Ref-Edge{{.*}}func_c{{.*}}private
+  // CHECK-NEXT: Ref-Edge{{.*}}func_h{{.*}}private
+
+  // CHECK: Node{{.*}}Unknown-Callee-Node
 }
 
 // -----
@@ -56,18 +91,30 @@ module attributes {test.name = "simple"} {
 // CHECK-LABEL: Testing : "nested"
 module attributes {test.name = "nested"} {
   module @nested_module {
-    // CHECK: Node{{.*}}func_a
-    func.func @func_a() {
+    // CHECK: Node{{.*}}func_a{{.*}}nested
+    func.func nested @func_a() {
+      return
+    }
+    // CHECK: Node{{.*}}func_b{{.*}}nested
+    func.func nested @func_b() {
       return
     }
   }
 
-  // CHECK: Node{{.*}}func_b
-  // CHECK: Call-Edge{{.*}}func_a
-  func.func @func_b() {
-    "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> ()
+  // CHECK: Node{{.*}}func_c
+  // CHECK: Call-Edge{{.*}}func_a{{.*}}nested
+  // CHECK: Ref-Edge{{.*}}func_b{{.*}}nested
+  func.func @func_c() {
+    "test.conversion_call_op"() { use = @nested_module::@func_b, callee = @nested_module::@func_a } : () -> ()
     return
   }
+
+  // CHECK: Node{{.*}}External-Caller-Node
+  // CHECK-NEXT: Ref-Edge{{.*}}func_c
+  // CHECK-NOT: Ref-Edge{{.*}}func_a
+  // CHECK-NOT: Ref-Edge{{.*}}func_b
+
+  // CHECK: Node{{.*}}Unknown-Callee-Node
 }
 
 // -----
@@ -95,4 +142,3 @@ module attributes {test.name = "SCC"} {
   // CHECK: SCC :
   // CHECK-NEXT: Node{{.*}}External-Caller-Node
 }
-
diff --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir
index d167c1b4baae98..c31b909903c3c2 100644
--- a/mlir/test/Transforms/inlining-dce.mlir
+++ b/mlir/test/Transforms/inlining-dce.mlir
@@ -10,7 +10,7 @@ func.func private @dead_function() {
 
 // Function becomes dead after inlining.
 // CHECK-NOT: func private @dead_function_b
-func.func @dead_function_b() {
+func.func private @dead_function_b() {
   return
 }
 
@@ -44,6 +44,35 @@ func.func @live_function_c() {
   return
 }
 
+// A transitive example in which neither function is called by a live function.
+
+// CHECK-NOT: func private @dead_function_e
+func.func private @dead_function_e() {
+  call @live_function_b() : () -> ()
+  return
+}
+
+// CHECK-NOT: func private @dead_function_f
+func.func private @dead_function_f() {
+  call @dead_function_e() : () -> ()
+  return
+}
+
+// A function referenced through a function constant is inlined after optimization.
+
+// CHECK-NOT: func private @dead_function_h
+func.func private @dead_function_h() {
+  call @live_function_b() : () -> ()
+  return
+}
+
+// CHECK: func @live_function_f
+func.func @live_function_f() {
+  %0 = func.constant @dead_function_h : () -> ()
+  call_indirect %0() : () -> ()
+  return
+}
+
 // Function is referenced by non-callable top-level user.
 // CHECK: func private @live_function_d
 func.func private @live_function_d() {



More information about the Mlir-commits mailing list