[Mlir-commits] [mlir] [mlir][CallGraph] Fix abstract edge connected to external node (PR #116177)

Haocong Lu llvmlistbot at llvm.org
Thu Nov 14 00:02:44 PST 2024


https://github.com/Luhaocong created https://github.com/llvm/llvm-project/pull/116177

In the `CallGraph` analysis, presumably only a `CallableOpInterface` operation that is also a symbol can be referenced by the external caller node. This patch therefore connects the external caller node to a callable node via an abstract edge only when the node's parent operation is a symbol with public visibility. It also adds the ExternalCallerNode and UnknownCalleeNode to the call graph dump.

>From a75893471e10f66e6465ddaa6a647e50224319c6 Mon Sep 17 00:00:00 2001
From: Lu Haocong <haocong.lu at evas.ai>
Date: Thu, 14 Nov 2024 15:11:40 +0800
Subject: [PATCH] [mlir][CallGraph] Fix abstract edge connected to external
 node

In the `CallGraph` analysis, presumably only a `CallableOpInterface`
operation that is also a symbol can be referenced by the external
caller node. This patch therefore connects the external caller node to
a callable node via an abstract edge only when the node's parent
operation is a symbol with public visibility. It also adds the
ExternalCallerNode and UnknownCalleeNode to the call graph dump.
---
 mlir/lib/Analysis/CallGraph.cpp        | 29 +++++++++++++++-----------
 mlir/test/Analysis/test-callgraph.mlir | 29 ++++++++++++++++++++------
 2 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 780c7caee767c1..560072570149b8 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -112,7 +112,8 @@ CallGraph::CallGraph(Operation *op)
 /// Get or add a call graph node for the given region.
 CallGraphNode *CallGraph::getOrAddNode(Region *region,
                                        CallGraphNode *parentNode) {
-  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
+  Operation *parentOp = region->getParentOp();
+  assert(region && isa<CallableOpInterface>(parentOp) &&
          "expected parent operation to be callable");
   std::unique_ptr<CallGraphNode> &node = nodes[region];
   if (!node) {
@@ -122,13 +123,12 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region,
     if (parentNode) {
       parentNode->addChildEdge(node.get());
     } else {
-      // Otherwise, connect all callable nodes to the external node, this allows
-      // for conservatively including all callable nodes within the graph.
-      // FIXME This isn't correct, this is only necessary for callable nodes
-      // that *could* be called from external sources. This requires extending
-      // the interface for callables to check if they may be referenced
-      // externally.
-      externalCallerNode.addAbstractEdge(node.get());
+      // Otherwise, connect all symbol nodes with public visibility
+      // to the external node, which is a set including callable nodes
+      // may be referenced externally.
+      if (isa<SymbolOpInterface>(parentOp) &&
+          cast<SymbolOpInterface>(parentOp).isPublic())
+        externalCallerNode.addAbstractEdge(node.get());
     }
   }
   return node.get();
@@ -199,9 +199,8 @@ void CallGraph::print(raw_ostream &os) const {
       os << " : " << attrs;
   };
 
-  for (auto &nodeIt : nodes) {
-    const CallGraphNode *node = nodeIt.second.get();
-
+  // Functor used to emit the given node and edges.
+  auto emitNodeAndEdge = [&](const CallGraphNode *node) {
     // Dump the header for this node.
     os << "// - Node : ";
     emitNodeName(node);
@@ -220,7 +219,13 @@ void CallGraph::print(raw_ostream &os) const {
       os << "\n";
     }
     os << "//\n";
-  }
+  };
+
+  // Emit all graph nodes including ExternalCallerNode and UnknownCalleeNode.
+  for (auto &nodeIt : nodes)
+    emitNodeAndEdge(nodeIt.second.get());
+  emitNodeAndEdge(getExternalCallerNode());
+  emitNodeAndEdge(getUnknownCalleeNode());
 
   os << "// -- SCCs --\n";
 
diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir
index f6c9ff5006e053..8a00966bea61dd 100644
--- a/mlir/test/Analysis/test-callgraph.mlir
+++ b/mlir/test/Analysis/test-callgraph.mlir
@@ -8,24 +8,25 @@ module attributes {test.name = "simple"} {
     return
   }
 
+  // CHECK-NOT: Node{{.*}}func_b
   func.func private @func_b()
 
-  // CHECK: Node{{.*}}func_c
+  // CHECK: Node{{.*}}func_c{{.*}}private
   // CHECK-NEXT: Call-Edge{{.*}}Unknown-Callee-Node
-  func.func @func_c() {
+  func.func private @func_c() {
     call @func_b() : () -> ()
     return
   }
 
   // CHECK: Node{{.*}}func_d
-  // CHECK-NEXT: Call-Edge{{.*}}func_c
+  // CHECK-NEXT: Call-Edge{{.*}}func_c{{.*}}private
   func.func @func_d() {
     call @func_c() : () -> ()
     return
   }
 
   // CHECK: Node{{.*}}func_e
-  // CHECK-DAG: Call-Edge{{.*}}func_c
+  // CHECK-DAG: Call-Edge{{.*}}func_c{{.*}}private
   // CHECK-DAG: Call-Edge{{.*}}func_d
   // CHECK-DAG: Call-Edge{{.*}}func_e
   func.func @func_e() {
@@ -49,6 +50,16 @@ module attributes {test.name = "simple"} {
     call_indirect %fn() : () -> ()
     return
   }
+
+  // CHECK: Node{{.*}}External-Caller-Node
+  // CHECK: Edge{{.*}}func_a
+  // CHECK-NOT: Edge{{.*}}func_b
+  // CHECK-NOT: Edge{{.*}}func_c
+  // CHECK: Edge{{.*}}func_d
+  // CHECK: Edge{{.*}}func_e
+  // CHECK: Edge{{.*}}func_f
+
+  // CHECK: Node{{.*}}Unknown-Callee-Node
 }
 
 // -----
@@ -57,17 +68,23 @@ module attributes {test.name = "simple"} {
 module attributes {test.name = "nested"} {
   module @nested_module {
     // CHECK: Node{{.*}}func_a
-    func.func @func_a() {
+    func.func nested @func_a() {
       return
     }
   }
 
   // CHECK: Node{{.*}}func_b
-  // CHECK: Call-Edge{{.*}}func_a
+  // CHECK: Call-Edge{{.*}}func_a{{.*}}nested
   func.func @func_b() {
     "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> ()
     return
   }
+
+  // CHECK: Node{{.*}}External-Caller-Node
+  // CHECK: Edge{{.*}}func_b
+  // CHECK-NOT: Edge{{.*}}func_a
+
+  // CHECK: Node{{.*}}Unknown-Callee-Node
 }
 
 // -----



More information about the Mlir-commits mailing list