[Mlir-commits] [mlir] [mlir][inliner] optimize self-recursive function detection [NFC] (PR #88452)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 11 16:12:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Congcong Cai (HerrCai0907)

<details>
<summary>Changes</summary>

This change helps the inliner's iterations finish as quickly as possible.
The original implementation depends on the outgoing edges of `targetNode`, which are maintained by the call graph and are not updated during inlining. This PR instead builds a `recursiveCallRegions` set from the current calls before each iteration, so that self-recursion detection stays up to date as inlining proceeds.

Here is an example:
```mlir
func.func @b0() {
  func.call @b1() : () -> ()
  return
}
func.func @b1() {
  func.call @b0() : () -> ()
  return
}
```

Before this change, the inliner's debug log looks like this:
```
* Inliner: Initial calls in SCC are: {
  0. func.call @b0() : () -> (),
  1. func.call @b1() : () -> (),
}
* Inlining call: 0. func.call @b0() : () -> ()
* new inlineHistory entry: 0. [func.call @b0() : () -> (), root]
* new call 2 {func.call @b0() : () -> ()}
   with historyID = 0, added due to inlining of
  call {func.call @b0() : () -> ()}
 with historyID = root
* Inlining call: 1. func.call @b1() : () -> ()
* new inlineHistory entry: 1. [func.call @b1() : () -> (), root]
* new call 3 {func.call @b1() : () -> ()}
   with historyID = 1, added due to inlining of
  call {func.call @b1() : () -> ()}
 with historyID = root
* Not inlining call: 2. func.call @b1() : () -> ()
* Not inlining call: 3. func.call @b1() : () -> ()
* Inliner: Initial calls in SCC are: {
  0. func.call @b1() : () -> (),
  1. func.call @b1() : () -> (),
}
* Not inlining call: 0. func.call @b1() : () -> ()
* Inlining call: 1. func.call @b1() : () -> ()
* new inlineHistory entry: 0. [func.call @b1() : () -> (), root]
* new call 2 {func.call @b1() : () -> ()}
   with historyID = 0, added due to inlining of
  call {func.call @b1() : () -> ()}
 with historyID = root
* Not inlining call: 2. func.call @b1() : () -> ()
* Inliner: Initial calls in SCC are: {
  0. func.call @b1() : () -> (),
  1. func.call @b1() : () -> (),
}
* Not inlining call: 0. func.call @b1() : () -> ()
* Inlining call: 1. func.call @b1() : () -> ()
* new inlineHistory entry: 0. [func.call @b1() : () -> (), root]
* new call 2 {func.call @b1() : () -> ()}
   with historyID = 0, added due to inlining of
  call {func.call @b1() : () -> ()}
 with historyID = root
* Not inlining call: 2. func.call @b1() : () -> ()
* Inliner: Initial calls in SCC are: {
  0. func.call @b1() : () -> (),
  1. func.call @b1() : () -> (),
}
* Not inlining call: 0. func.call @b1() : () -> ()
* Inlining call: 1. func.call @b1() : () -> ()
* new inlineHistory entry: 0. [func.call @b1() : () -> (), root]
* new call 2 {func.call @b1() : () -> ()}
   with historyID = 0, added due to inlining of
  call {func.call @b1() : () -> ()}
 with historyID = root
* Not inlining call: 2. func.call @b1() : () -> ()
* Inliner: Initial calls in SCC are: {
}
```

After this change, the log becomes:
```
* Inliner: Initial calls in SCC are: {
  0. func.call @b0() : () -> (),
  1. func.call @b1() : () -> (),
}
* Inlining call: 0. func.call @b0() : () -> ()
* new inlineHistory entry: 0. [func.call @b0() : () -> (), root]
* new call 2 {func.call @b0() : () -> ()}
   with historyID = 0, added due to inlining of
  call {func.call @b0() : () -> ()}
 with historyID = root
* Inlining call: 1. func.call @b1() : () -> ()
* new inlineHistory entry: 1. [func.call @b1() : () -> (), root]
* new call 3 {func.call @b1() : () -> ()}
   with historyID = 1, added due to inlining of
  call {func.call @b1() : () -> ()}
 with historyID = root
* Not inlining call: 2. func.call @b1() : () -> ()
* Not inlining call: 3. func.call @b1() : () -> ()
* Inliner: Initial calls in SCC are: {
  0. func.call @b1() : () -> (),
  1. func.call @b1() : () -> (),
}
* Not inlining call: 0. func.call @b1() : () -> ()
* Not inlining call: 1. func.call @b1() : () -> ()
* Inliner: Initial calls in SCC are: {
}
```

---
Full diff: https://github.com/llvm/llvm-project/pull/88452.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/Inliner.cpp (+23-8) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index 8acfc96d2b611b..fbb398cab6efd7 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -459,7 +459,9 @@ class Inliner::Impl {
                                  CGUseList &useList, CallGraphSCC &currentSCC);
 
   /// Returns true if the given call should be inlined.
-  bool shouldInline(ResolvedCall &resolvedCall);
+  bool
+  shouldInline(ResolvedCall &resolvedCall,
+               llvm::SmallPtrSet<Region *, 16U> const &recursiveCallRegions);
 
 private:
   Inliner &inliner;
@@ -621,6 +623,12 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
     llvm::dbgs() << "}\n";
   });
 
+  llvm::SmallPtrSet<Region *, 16U> recursiveCallRegions{};
+  for (ResolvedCall const &it : calls) {
+    if (it.sourceNode == it.targetNode)
+      recursiveCallRegions.insert(it.targetNode->getCallableRegion());
+  }
+
   // Try to inline each of the call operations. Don't cache the end iterator
   // here as more calls may be added during inlining.
   bool inlinedAnyCalls = false;
@@ -632,7 +640,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
     InlineHistoryT inlineHistoryID = callHistory[i];
     bool inHistory =
         inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
-    bool doInline = !inHistory && shouldInline(it);
+    bool doInline = !inHistory && shouldInline(it, recursiveCallRegions);
     CallOpInterface call = it.call;
     LLVM_DEBUG({
       if (doInline)
@@ -705,23 +713,30 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
   return success(inlinedAnyCalls);
 }
 
+static bool isSelfRecursiveFunction(CallGraphNode *node) {
+  return llvm::find_if(*node, [&](CallGraphNode::Edge const &edge) -> bool {
+           return edge.getTarget() == node;
+         }) != node->end();
+}
+
 /// Returns true if the given call should be inlined.
-bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
+bool Inliner::Impl::shouldInline(
+    ResolvedCall &resolvedCall,
+    llvm::SmallPtrSet<Region *, 16U> const &recursiveCallRegions) {
   // Don't allow inlining terminator calls. We currently don't support this
   // case.
   if (resolvedCall.call->hasTrait<OpTrait::IsTerminator>())
     return false;
 
+  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
+
   // Don't allow inlining if the target is a self-recursive function.
-  if (llvm::count_if(*resolvedCall.targetNode,
-                     [&](CallGraphNode::Edge const &edge) -> bool {
-                       return edge.getTarget() == resolvedCall.targetNode;
-                     }) > 0)
+  if (recursiveCallRegions.contains(callableRegion) ||
+      isSelfRecursiveFunction(resolvedCall.targetNode))
     return false;
 
   // Don't allow inlining if the target is an ancestor of the call. This
   // prevents inlining recursively.
-  Region *callableRegion = resolvedCall.targetNode->getCallableRegion();
   if (callableRegion->isAncestor(resolvedCall.call->getParentRegion()))
     return false;
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/88452


More information about the Mlir-commits mailing list