[llvm] c8cf393 - [mlgo][inliner] Handle recursive cases when skipping non-cold functions (#164099)
    via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Mon Oct 20 07:28:54 PDT 2025
    
    
  
Author: Mircea Trofin
Date: 2025-10-20T07:28:50-07:00
New Revision: c8cf3937e7cb017be6e07512ccd63c126c639d61
URL: https://github.com/llvm/llvm-project/commit/c8cf3937e7cb017be6e07512ccd63c126c639d61
DIFF: https://github.com/llvm/llvm-project/commit/c8cf3937e7cb017be6e07512ccd63c126c639d61.diff
LOG: [mlgo][inliner] Handle recursive cases when skipping non-cold functions (#164099)
The `MLInlineAdvisor` currently skips over recursive call sites. However, when it delegates to the default policy for non-cold functions, that policy may still allow such inlining. The code that updates the advisor's internal state afterwards therefore needs to handle the recursive case, where the caller and the callee are the same function.
Fix for https://issues.chromium.org/issues/369637577#comment14
Added: 
    llvm/test/Transforms/Inline/ML/recursive.ll
Modified: 
    llvm/lib/Analysis/MLInlineAdvisor.cpp
Removed: 
    
################################################################################
diff  --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 1d1a5560be478..9a5ae2ae26799 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -324,32 +324,44 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
     FAM.invalidate(*Caller, PA);
   }
   Advice.updateCachedCallerFPI(FAM);
-  int64_t IRSizeAfter =
-      getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
-  CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+  if (Caller == Callee) {
+    assert(!CalleeWasDeleted);
+    // We double-counted CallerAndCalleeEdges - since the caller and callee
+    // would be the same
+    assert(Advice.CallerAndCalleeEdges % 2 == 0);
+    CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize;
+    EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions -
+                 Advice.CallerAndCalleeEdges / 2;
+    // The NodeCount would stay the same.
+  } else {
+    int64_t IRSizeAfter =
+        getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
+    CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+
+    // We can delta-update module-wide features. We know the inlining only
+    // changed the caller, and maybe the callee (by deleting the latter). Nodes
+    // are simple to update. For edges, we 'forget' the edges that the caller
+    // and callee used to have before inlining, and add back what they currently
+    // have together.
+    int64_t NewCallerAndCalleeEdges =
+        getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
+
+    // A dead function's node is not actually removed from the call graph until
+    // the end of the call graph walk, but the node no longer belongs to any
+    // valid SCC.
+    if (CalleeWasDeleted) {
+      --NodeCount;
+      NodesInLastSCC.erase(CG.lookup(*Callee));
+      DeadFunctions.insert(Callee);
+    } else {
+      NewCallerAndCalleeEdges +=
+          getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
+    }
+    EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
+  }
   if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
     ForceStop = true;
 
-  // We can delta-update module-wide features. We know the inlining only changed
-  // the caller, and maybe the callee (by deleting the latter).
-  // Nodes are simple to update.
-  // For edges, we 'forget' the edges that the caller and callee used to have
-  // before inlining, and add back what they currently have together.
-  int64_t NewCallerAndCalleeEdges =
-      getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
-
-  // A dead function's node is not actually removed from the call graph until
-  // the end of the call graph walk, but the node no longer belongs to any valid
-  // SCC.
-  if (CalleeWasDeleted) {
-    --NodeCount;
-    NodesInLastSCC.erase(CG.lookup(*Callee));
-    DeadFunctions.insert(Callee);
-  } else {
-    NewCallerAndCalleeEdges +=
-        getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
-  }
-  EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
   assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
 }
 
diff  --git a/llvm/test/Transforms/Inline/ML/recursive.ll b/llvm/test/Transforms/Inline/ML/recursive.ll
new file mode 100644
index 0000000000000..2d9240a12a713
--- /dev/null
+++ b/llvm/test/Transforms/Inline/ML/recursive.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 6
+; REQUIRES: llvm_inliner_model_autogenerated
+; RUN: opt -S %s -o - -passes='inliner-ml-advisor-release' -ml-inliner-skip-policy=if-caller-not-cold | FileCheck %s
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32"
+target triple = "aarch64-unknown-linux-android29"
+
+define i32 @a_func(ptr %this, i32 %color_id, i1 %dark_mode) local_unnamed_addr {
+; CHECK-LABEL: define i32 @a_func(
+; CHECK-SAME: ptr [[THIS:%.*]], i32 [[COLOR_ID:%.*]], i1 [[DARK_MODE:%.*]]) local_unnamed_addr {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    br i1 [[DARK_MODE]], label %[[SW_BB97:.*]], label %[[COMMON_RET:.*]]
+; CHECK:       [[COMMON_RET]]:
+; CHECK-NEXT:    ret i32 0
+; CHECK:       [[SW_BB97]]:
+; CHECK-NEXT:    br label %[[COMMON_RET]]
+;
+entry:
+  br i1 %dark_mode, label %sw.bb97, label %common.ret
+
+common.ret:                                       ; preds = %sw.bb97, %entry
+  ret i32 0
+
+sw.bb97:                                          ; preds = %entry
+  %call.i = tail call i32 @a_func(ptr null, i32 0, i1 false)
+  br label %common.ret
+}
+
+; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: read)
+declare ptr @llvm.load.relative.i32(ptr %0, i32 %1) #0
+
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: read) }
+;.
+; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind willreturn memory(argmem: read) }
+;.
        
    
    
More information about the llvm-commits mailing list