[Mlir-commits] [mlir] [mlir][CallGraph] Fix abstract edge connected to external node (PR #116177)

Haocong Lu llvmlistbot at llvm.org
Fri Nov 15 01:24:29 PST 2024


https://github.com/Luhaocong updated https://github.com/llvm/llvm-project/pull/116177

>From ae37879de9170845ad6f4d5023894b839d73f224 Mon Sep 17 00:00:00 2001
From: Lu Haocong <haocong.lu at evas.ai>
Date: Thu, 14 Nov 2024 15:11:40 +0800
Subject: [PATCH] [mlir][CallGraph] Fix abstract edge connected to external
 node

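In the CallGraph analysis, only a `CallableOpInterface` op that is also a
symbol can be referenced from the external node. This patch connects the
external node to a target callable node with an abstract edge only when that
node is a symbol with public visibility. Because private callables are no
longer reachable from the external node, the inliner now also iteratively
erases callables that become dead after inlining. This patch also adds
support for dumping ExternalCallerNode and UnknownCalleeNode.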
---
 mlir/lib/Analysis/CallGraph.cpp        | 31 ++++++++++++++++----------
 mlir/lib/Transforms/Utils/Inliner.cpp  | 25 +++++++++++++++++++++
 mlir/test/Analysis/test-callgraph.mlir | 29 +++++++++++++++++++-----
 mlir/test/Transforms/inlining-dce.mlir | 15 ++++++++++++-
 4 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 780c7caee767c1..a2ff9c99469179 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -112,7 +112,8 @@ CallGraph::CallGraph(Operation *op)
 /// Get or add a call graph node for the given region.
 CallGraphNode *CallGraph::getOrAddNode(Region *region,
                                        CallGraphNode *parentNode) {
-  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
+  Operation *parentOp = region->getParentOp();
+  assert(region && isa<CallableOpInterface>(parentOp) &&
          "expected parent operation to be callable");
   std::unique_ptr<CallGraphNode> &node = nodes[region];
   if (!node) {
@@ -122,13 +123,12 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region,
     if (parentNode) {
       parentNode->addChildEdge(node.get());
     } else {
-      // Otherwise, connect all callable nodes to the external node, this allows
-      // for conservatively including all callable nodes within the graph.
-      // FIXME This isn't correct, this is only necessary for callable nodes
-      // that *could* be called from external sources. This requires extending
-      // the interface for callables to check if they may be referenced
-      // externally.
-      externalCallerNode.addAbstractEdge(node.get());
+      // Otherwise, connect only symbol nodes with public visibility to the
+      // external node, since those are the callable nodes that may be
+      // referenced externally.
+      if (isa<SymbolOpInterface>(parentOp) &&
+          cast<SymbolOpInterface>(parentOp).isPublic())
+        externalCallerNode.addAbstractEdge(node.get());
     }
   }
   return node.get();
@@ -199,9 +199,8 @@ void CallGraph::print(raw_ostream &os) const {
       os << " : " << attrs;
   };
 
-  for (auto &nodeIt : nodes) {
-    const CallGraphNode *node = nodeIt.second.get();
-
+  // Functor used to emit the given node and its edges.
+  auto emitNodeAndEdge = [&](const CallGraphNode *node) {
     // Dump the header for this node.
     os << "// - Node : ";
     emitNodeName(node);
@@ -214,13 +213,21 @@ void CallGraph::print(raw_ostream &os) const {
         os << "Call";
       else if (edge.isChild())
         os << "Child";
+      else if (edge.isAbstract())
+        os << "Abstract";
 
       os << "-Edge : ";
       emitNodeName(edge.getTarget());
       os << "\n";
     }
     os << "//\n";
-  }
+  };
+
+  // Emit all graph nodes including ExternalCallerNode and UnknownCalleeNode.
+  for (auto &nodeIt : nodes)
+    emitNodeAndEdge(nodeIt.second.get());
+  emitNodeAndEdge(getExternalCallerNode());
+  emitNodeAndEdge(getUnknownCalleeNode());
 
   os << "// -- SCCs --\n";
 
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index 8acfc96d2b611b..978bf7f5c0b70f 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -434,6 +434,9 @@ class Inliner::Impl {
                           CGUseList &useList, CallGraphSCC &currentSCC,
                           MLIRContext *context);
 
+  void collectDeadNodeAfterInline(CallGraph &cg, CGUseList &useList,
+                                  InlinerInterfaceImpl &inlinerIface);
+
 private:
   /// Optimize the nodes within the given SCC with one of the held optimization
   /// pass pipelines. Returns failure if an error occurred during the
@@ -748,6 +751,27 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
   return true;
 }
 
+/// Erase dead callable nodes until a fixed point is reached: erasing one node
+/// can drop the last use of another (e.g. a chain of dead private callers).
+void Inliner::Impl::collectDeadNodeAfterInline(
+    CallGraph &cg, CGUseList &useList, InlinerInterfaceImpl &inlinerIface) {
+  // `isDead` may still report an erased node as dead (e.g. when it is judged
+  // by SSA uses), so track erased nodes rather than erasing one twice.
+  llvm::SmallPtrSet<CallGraphNode *, 8> erasedNodes;
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (CallGraphNode *node : cg) {
+      if (erasedNodes.contains(node) || !useList.isDead(node))
+        continue;
+      useList.eraseNode(node);
+      inlinerIface.markForDeletion(node);
+      erasedNodes.insert(node);
+      changed = true;
+    }
+  }
+}
+
 LogicalResult Inliner::doInlining() {
   Impl impl(*this);
   auto *context = op->getContext();
@@ -765,6 +789,7 @@ LogicalResult Inliner::doInlining() {
     return result;
 
   // After inlining, make sure to erase any callables proven to be dead.
+  impl.collectDeadNodeAfterInline(cg, useList, inlinerIface);
   inlinerIface.eraseDeadCallables();
   return success();
 }
diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir
index f6c9ff5006e053..8a00966bea61dd 100644
--- a/mlir/test/Analysis/test-callgraph.mlir
+++ b/mlir/test/Analysis/test-callgraph.mlir
@@ -8,24 +8,25 @@ module attributes {test.name = "simple"} {
     return
   }
 
+  // CHECK-NOT: Node{{.*}}func_b
   func.func private @func_b()
 
-  // CHECK: Node{{.*}}func_c
+  // CHECK: Node{{.*}}func_c{{.*}}private
   // CHECK-NEXT: Call-Edge{{.*}}Unknown-Callee-Node
-  func.func @func_c() {
+  func.func private @func_c() {
     call @func_b() : () -> ()
     return
   }
 
   // CHECK: Node{{.*}}func_d
-  // CHECK-NEXT: Call-Edge{{.*}}func_c
+  // CHECK-NEXT: Call-Edge{{.*}}func_c{{.*}}private
   func.func @func_d() {
     call @func_c() : () -> ()
     return
   }
 
   // CHECK: Node{{.*}}func_e
-  // CHECK-DAG: Call-Edge{{.*}}func_c
+  // CHECK-DAG: Call-Edge{{.*}}func_c{{.*}}private
   // CHECK-DAG: Call-Edge{{.*}}func_d
   // CHECK-DAG: Call-Edge{{.*}}func_e
   func.func @func_e() {
@@ -49,6 +50,16 @@ module attributes {test.name = "simple"} {
     call_indirect %fn() : () -> ()
     return
   }
+
+  // CHECK: Node{{.*}}External-Caller-Node
+  // CHECK: Edge{{.*}}func_a
+  // CHECK-NOT: Edge{{.*}}func_b
+  // CHECK-NOT: Edge{{.*}}func_c
+  // CHECK: Edge{{.*}}func_d
+  // CHECK: Edge{{.*}}func_e
+  // CHECK: Edge{{.*}}func_f
+
+  // CHECK: Node{{.*}}Unknown-Callee-Node
 }
 
 // -----
@@ -57,17 +68,23 @@ module attributes {test.name = "simple"} {
 module attributes {test.name = "nested"} {
   module @nested_module {
     // CHECK: Node{{.*}}func_a
-    func.func @func_a() {
+    func.func nested @func_a() {
       return
     }
   }
 
   // CHECK: Node{{.*}}func_b
-  // CHECK: Call-Edge{{.*}}func_a
+  // CHECK: Call-Edge{{.*}}func_a{{.*}}nested
   func.func @func_b() {
     "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> ()
     return
   }
+
+  // CHECK: Node{{.*}}External-Caller-Node
+  // CHECK: Edge{{.*}}func_b
+  // CHECK-NOT: Edge{{.*}}func_a
+
+  // CHECK: Node{{.*}}Unknown-Callee-Node
 }
 
 // -----
diff --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir
index d167c1b4baae98..45b3ebc1e01772 100644
--- a/mlir/test/Transforms/inlining-dce.mlir
+++ b/mlir/test/Transforms/inlining-dce.mlir
@@ -10,7 +10,7 @@ func.func private @dead_function() {
 
 // Function becomes dead after inlining.
 // CHECK-NOT: func private @dead_function_b
-func.func @dead_function_b() {
+func.func private @dead_function_b() {
   return
 }
 
@@ -44,6 +44,19 @@ func.func @live_function_c() {
   return
 }
 
+// A transitive example where neither function is called by a live function.
+
+// CHECK-NOT: func private @dead_function_e
+func.func private @dead_function_e() {
+  call @live_function_b() : () -> ()
+  return
+}
+// CHECK-NOT: func private @dead_function_f
+func.func private @dead_function_f() {
+  call @dead_function_e() : () -> ()
+  return
+}
+
 // Function is referenced by non-callable top-level user.
 // CHECK: func private @live_function_d
 func.func private @live_function_d() {



More information about the Mlir-commits mailing list