[Mlir-commits] [mlir] 82f86b8 - [mlir][CallGraph] Add special call graph node for representing unknown callees

Markus Böck llvmlistbot at llvm.org
Fri Sep 9 11:23:10 PDT 2022


Author: Markus Böck
Date: 2022-09-09T20:22:59+02:00
New Revision: 82f86b862b8ba9b1c6d91a32c88907f0dad6c3d2

URL: https://github.com/llvm/llvm-project/commit/82f86b862b8ba9b1c6d91a32c88907f0dad6c3d2
DIFF: https://github.com/llvm/llvm-project/commit/82f86b862b8ba9b1c6d91a32c88907f0dad6c3d2.diff

LOG: [mlir][CallGraph] Add special call graph node for representing unknown callees

The callgraph currently contains a special external node that serves two roles. It is the quasi-caller of every callable that may be called externally. It is also the callee of every call that could not be resolved.
This has one negative side effect, however, which is the motivation for this patch. Consider every externally callable callable that contains a call that could not be resolved (e.g. an indirect call). Such a callable is both called by and calls the same external node, which forms a cycle. As a result, all such callables end up in one giant SCC when iterating over the SCCs of the call graph.

This patch fixes that issue by creating a second special callgraph node that acts as the callee for any call that could not be resolved. This breaks the cycles produced in the callgraph, yielding proper SCCs for all direct calls.

Differential Revision: https://reviews.llvm.org/D133585

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/CallGraph.h
    mlir/lib/Analysis/CallGraph.cpp
    mlir/test/Analysis/test-callgraph.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
index fe7a6faffb4c7..631cdd1ad2290 100644
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -184,9 +184,14 @@ class CallGraph {
   /// registered.
   CallGraphNode *lookupNode(Region *region) const;
 
-  /// Return the callgraph node representing the indirect-external callee.
-  CallGraphNode *getExternalNode() const {
-    return const_cast<CallGraphNode *>(&externalNode);
+  /// Return the callgraph node representing an external caller.
+  CallGraphNode *getExternalCallerNode() const {
+    return const_cast<CallGraphNode *>(&externalCallerNode);
+  }
+
+  /// Return the callgraph node representing an indirect callee.
+  CallGraphNode *getUnknownCalleeNode() const {
+    return const_cast<CallGraphNode *>(&unknownCalleeNode);
   }
 
   /// Resolve the callable for given callee to a node in the callgraph, or the
@@ -212,8 +217,11 @@ class CallGraph {
   /// The set of nodes within the callgraph.
   NodeMapT nodes;
 
-  /// A special node used to indicate an external edges.
-  CallGraphNode externalNode;
+  /// A special node used to indicate an external caller.
+  CallGraphNode externalCallerNode;
+
+  /// A special node used to indicate an unknown callee.
+  CallGraphNode unknownCalleeNode;
 };
 
 } // namespace mlir
@@ -246,7 +254,7 @@ struct GraphTraits<const mlir::CallGraph *>
     : public GraphTraits<const mlir::CallGraphNode *> {
   /// The entry node into the graph is the external node.
   static NodeRef getEntryNode(const mlir::CallGraph *cg) {
-    return cg->getExternalNode();
+    return cg->getExternalCallerNode();
   }
 
   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph

diff  --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 0f214446bbc64..614e8e24ff980 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -92,7 +92,9 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
       computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
 }
 
-CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
+CallGraph::CallGraph(Operation *op)
+    : externalCallerNode(/*callableRegion=*/nullptr),
+      unknownCalleeNode(/*callableRegion=*/nullptr) {
   // Make two passes over the graph, one to compute the callables and one to
   // resolve the calls. We split these up as we may have nested callable objects
   // that need to be reserved before the calls.
@@ -122,7 +124,7 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region,
       // that *could* be called from external sources. This requires extending
       // the interface for callables to check if they may be referenced
       // externally.
-      externalNode.addAbstractEdge(node.get());
+      externalCallerNode.addAbstractEdge(node.get());
     }
   }
   return node.get();
@@ -136,7 +138,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
 }
 
 /// Resolve the callable for given callee to a node in the callgraph, or the
-/// external node if a valid node was not resolved.
+/// unknown callee node if a valid node was not resolved.
 CallGraphNode *
 CallGraph::resolveCallable(CallOpInterface call,
                            SymbolTableCollection &symbolTable) const {
@@ -145,8 +147,7 @@ CallGraph::resolveCallable(CallOpInterface call,
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;
 
-  // If we don't have a valid direct region, this is an external call.
-  return getExternalNode();
+  return getUnknownCalleeNode();
 }
 
 /// Erase the given node from the callgraph.
@@ -176,8 +177,12 @@ void CallGraph::print(raw_ostream &os) const {
 
   // Functor used to output the name for the given node.
   auto emitNodeName = [&](const CallGraphNode *node) {
-    if (node->isExternal()) {
-      os << "<External-Node>";
+    if (node == getExternalCallerNode()) {
+      os << "<External-Caller-Node>";
+      return;
+    }
+    if (node == getUnknownCalleeNode()) {
+      os << "<Unknown-Callee-Node>";
       return;
     }
 

diff  --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir
index a316a5c0db0b2..f6c9ff5006e05 100644
--- a/mlir/test/Analysis/test-callgraph.mlir
+++ b/mlir/test/Analysis/test-callgraph.mlir
@@ -11,7 +11,7 @@ module attributes {test.name = "simple"} {
   func.func private @func_b()
 
   // CHECK: Node{{.*}}func_c
-  // CHECK-NEXT: Call-Edge{{.*}}External-Node
+  // CHECK-NEXT: Call-Edge{{.*}}Unknown-Callee-Node
   func.func @func_c() {
     call @func_b() : () -> ()
     return
@@ -69,3 +69,30 @@ module attributes {test.name = "nested"} {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: Testing : "SCC"
+// CHECK: SCCs
+module attributes {test.name = "SCC"} {
+  // CHECK: SCC :
+  // CHECK-NEXT: Node{{.*}}Unknown-Callee-Node
+
+  // CHECK: SCC :
+  // CHECK-NEXT: Node{{.*}}foo
+  func.func @foo(%arg0 : () -> ()) {
+    call_indirect %arg0() : () -> ()
+    return
+  }
+
+  // CHECK: SCC :
+  // CHECK-NEXT: Node{{.*}}bar
+  func.func @bar(%arg1 : () -> ()) {
+    call_indirect %arg1() : () -> ()
+    return
+  }
+
+  // CHECK: SCC :
+  // CHECK-NEXT: Node{{.*}}External-Caller-Node
+}
+


        


More information about the Mlir-commits mailing list