[Mlir-commits] [mlir] [mlir][inliner] optimize self-recursive function detection [NFC] (PR #88452)

Congcong Cai llvmlistbot at llvm.org
Fri Apr 12 08:32:25 PDT 2024


https://github.com/HerrCai0907 updated https://github.com/llvm/llvm-project/pull/88452

>From 81b4420c3ca75c1f9588ab2e5df8425bbfe506ad Mon Sep 17 00:00:00 2001
From: Congcong Cai <congcongcai0907 at 163.com>
Date: Fri, 12 Apr 2024 06:48:13 +0800
Subject: [PATCH 1/2] [mlir][inliner] optimize self-recursive function
 detection

---
 mlir/lib/Transforms/Utils/Inliner.cpp | 31 ++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index 8acfc96d2b611b..fbb398cab6efd7 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -459,7 +459,9 @@ class Inliner::Impl {
                                  CGUseList &useList, CallGraphSCC &currentSCC);
 
   /// Returns true if the given call should be inlined.
-  bool shouldInline(ResolvedCall &resolvedCall);
+  bool
+  shouldInline(ResolvedCall &resolvedCall,
+               llvm::SmallPtrSet<Region *, 16U> const &recursiveCallRegions);
 
 private:
   Inliner &inliner;
@@ -621,6 +623,12 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
     llvm::dbgs() << "}\n";
   });
 
+  llvm::SmallPtrSet<Region *, 16U> recursiveCallRegions{};
+  for (ResolvedCall const &it : calls) {
+    if (it.sourceNode == it.targetNode)
+      recursiveCallRegions.insert(it.targetNode->getCallableRegion());
+  }
+
   // Try to inline each of the call operations. Don't cache the end iterator
   // here as more calls may be added during inlining.
   bool inlinedAnyCalls = false;
@@ -632,7 +640,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
     InlineHistoryT inlineHistoryID = callHistory[i];
     bool inHistory =
         inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
-    bool doInline = !inHistory && shouldInline(it);
+    bool doInline = !inHistory && shouldInline(it, recursiveCallRegions);
     CallOpInterface call = it.call;
     LLVM_DEBUG({
       if (doInline)
@@ -705,23 +713,30 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
   return success(inlinedAnyCalls);
 }
 
+static bool isSelfRecursiveFunction(CallGraphNode *node) {
+  return llvm::find_if(*node, [&](CallGraphNode::Edge const &edge) -> bool {
+           return edge.getTarget() == node;
+         }) != node->end();
+}
+
 /// Returns true if the given call should be inlined.
-bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
+bool Inliner::Impl::shouldInline(
+    ResolvedCall &resolvedCall,
+    llvm::SmallPtrSet<Region *, 16U> const &recursiveCallRegions) {
   // Don't allow inlining terminator calls. We currently don't support this
   // case.
   if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
     return false;
 
+  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
+
   // Don't allow inlining if the target is a self-recursive function.
-  if (llvm::count_if(*resolvedCall.targetNode,
-                     [&](CallGraphNode::Edge const &edge) -> bool {
-                       return edge.getTarget() == resolvedCall.targetNode;
-                     }) > 0)
+  if (recursiveCallRegions.contains(callableRegion) ||
+      isSelfRecursiveFunction(resolvedCall.targetNode))
     return false;
 
   // Don't allow inlining if the target is an ancestor of the call. This
   // prevents inlining recursively.
-  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
   if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
     return false;
 

>From d2f6e7cde8788b0e8458f4b455119a99c368f32b Mon Sep 17 00:00:00 2001
From: Congcong Cai <congcongcai0907 at 163.com>
Date: Fri, 12 Apr 2024 23:26:50 +0800
Subject: [PATCH 2/2] comments

---
 mlir/lib/Transforms/Utils/Inliner.cpp | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index fbb398cab6efd7..6f75849aa14af6 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -461,7 +461,7 @@ class Inliner::Impl {
   /// Returns true if the given call should be inlined.
   bool
   shouldInline(ResolvedCall &resolvedCall,
-               llvm::SmallPtrSet<Region *, 16U> const &recursiveCallRegions);
+               const llvm::SmallPtrSetImpl<Region *> &recursiveCallRegions);
 
 private:
   Inliner &inliner;
@@ -623,6 +623,8 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
     llvm::dbgs() << "}\n";
   });
 
+  // The call graph changes dynamically during inliner iterations, so record
+  // the self-recursive callees up front to avoid inlining them.
   llvm::SmallPtrSet<Region *, 16U> recursiveCallRegions{};
   for (ResolvedCall const &it : calls) {
     if (it.sourceNode == it.targetNode)
@@ -714,7 +716,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
 }
 
 static bool isSelfRecursiveFunction(CallGraphNode *node) {
-  return llvm::find_if(*node, [&](CallGraphNode::Edge const &edge) -> bool {
+  return llvm::find_if(*node, [&](const CallGraphNode::Edge &edge) -> bool {
            return edge.getTarget() == node;
          }) != node->end();
 }
@@ -722,7 +724,7 @@ static bool isSelfRecursiveFunction(CallGraphNode *node) {
 /// Returns true if the given call should be inlined.
 bool Inliner::Impl::shouldInline(
     ResolvedCall &resolvedCall,
-    llvm::SmallPtrSet<Region *, 16U> const &recursiveCallRegions) {
+    const llvm::SmallPtrSetImpl<Region *> &recursiveCallRegions) {
   // Don't allow inlining terminator calls. We currently don't support this
   // case.
   if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
@@ -730,14 +732,13 @@ bool Inliner::Impl::shouldInline(
 
   Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
 
-  // Don't allow inlining if the target is a self-recursive function.
-  if (recursiveCallRegions.contains(callableRegion) ||
-      isSelfRecursiveFunction(resolvedCall.targetNode))
-    return false;
-
-  // Don't allow inlining if the target is an ancestor of the call. This
-  // prevents inlining recursively.
-  if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
+  // Don't inline in the following cases, to prevent recursive inlining:
+  // 1. The target has an edge back to itself in the original call graph.
+  // 2. The target contains a call to itself introduced by previous inlining.
+  // 3. The target is an ancestor of the call.
+  if (isSelfRecursiveFunction(resolvedCall.targetNode) ||
+      recursiveCallRegions.contains(callableRegion) ||
+      callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
     return false;
 
   // Don't allow inlining if the callee has multiple blocks (unstructured



More information about the Mlir-commits mailing list