[llvm] [TRE] Adjust function entry count (PR #143987)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 15:48:06 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/143987

>From 17758ae24eaaeb7922564fc7dda024ffe627acb9 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 12 Jun 2025 15:34:46 -0700
Subject: [PATCH] [TailRecursionElim] Adjust function entry count

---
 .../Scalar/TailRecursionElimination.h         |   7 +-
 llvm/lib/Passes/PassBuilderPipelines.cpp      |  15 ++-
 .../Scalar/TailRecursionElimination.cpp       |  54 ++++++--
 .../TailCallElim/entry-count-adjustment.ll    | 116 ++++++++++++++++++
 4 files changed, 181 insertions(+), 11 deletions(-)
 create mode 100644 llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll

diff --git a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
index 57b1ed9bf4fe8..22a70cd66865a 100644
--- a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
+++ b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
@@ -58,7 +58,14 @@ namespace llvm {
 
 class Function;
 
-struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
+class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
+  const bool UpdateFunctionEntryCount;
+
+public:
+  /// If \p UpdateFunctionEntryCount is set, a function's entry count (if any)
+  /// may be lowered to account for the tail-recursive calls eliminated.
+  explicit TailCallElimPass(bool UpdateFunctionEntryCount = true)
+      : UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 }
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index a99146d5eaa34..f0a7a7f693792 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -628,7 +628,10 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
       !Level.isOptimizingForSize())
     FPM.addPass(PGOMemOPSizeOpt());
 
-  FPM.addPass(TailCallElimPass());
+  FPM.addPass(
+      TailCallElimPass(/*UpdateFunctionEntryCount=*/(
+                           PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
+                       !UseCtxProfile.empty()));
   FPM.addPass(
       SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
 
@@ -1581,7 +1584,10 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
   OptimizePM.addPass(DivRemPairsPass());
 
   // Try to annotate calls that were created during optimization.
-  OptimizePM.addPass(TailCallElimPass());
+  OptimizePM.addPass(
+      TailCallElimPass(/*UpdateFunctionEntryCount=*/(
+                           PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
+                       !UseCtxProfile.empty()));
 
   // LoopSink (and other loop passes since the last simplifyCFG) might have
   // resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2069,7 +2075,10 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
 
   // LTO provides additional opportunities for tailcall elimination due to
   // link-time inlining, and visibility of nocapture attribute.
-  FPM.addPass(TailCallElimPass());
+  FPM.addPass(
+      TailCallElimPass(/*UpdateFunctionEntryCount=*/(
+                           PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
+                       !UseCtxProfile.empty()));
 
   // Run a few AA driver optimizations here and now to cleanup the code.
   MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index c71c5a70a12fd..e0e2764709cf6 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -53,6 +53,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -75,6 +76,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
@@ -87,6 +89,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
 STATISTIC(NumRetDuped,   "Number of return duplicated");
 STATISTIC(NumAccumAdded, "Number of accumulators introduced");
 
+static cl::opt<bool> ForceDisableBFI(
+    "tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
+    cl::desc("Force disabling recomputing of function entry count, on "
+             "successful tail recursion elimination."));
+
 /// Scan the specified function for alloca instructions.
 /// If it contains any dynamic allocas, returns false.
 static bool canTRE(Function &F) {
@@ -409,6 +416,8 @@ class TailRecursionEliminator {
   AliasAnalysis *AA;
   OptimizationRemarkEmitter *ORE;
   DomTreeUpdater &DTU;
+  const uint64_t OrigEntryBBFreq;
+  DenseMap<const BasicBlock *, uint64_t> OriginalBBFreqs;
 
   // The below are shared state we want to have available when eliminating any
   // calls in the function. There values should be populated by
@@ -438,8 +447,18 @@ class TailRecursionEliminator {
 
   TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
                           AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
-                          DomTreeUpdater &DTU)
-      : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+                          DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
+      : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU),
+        OrigEntryBBFreq(
+            BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U) {
+    assert(((BFI != nullptr) == (OrigEntryBBFreq != 0)) &&
+           "If the function has an entry count, its entry basic block should "
+           "have a non-zero frequency. Pass a nullptr BFI if the function has "
+           "no entry count");
+    if (BFI)
+      for (const auto &BB : F)
+        OriginalBBFreqs.insert({&BB, BFI->getBlockFreq(&BB).getFrequency()});
+  }
 
   CallInst *findTRECandidate(BasicBlock *BB);
 
@@ -460,7 +479,7 @@ class TailRecursionEliminator {
 public:
   static bool eliminate(Function &F, const TargetTransformInfo *TTI,
                         AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
-                        DomTreeUpdater &DTU);
+                        DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
 };
 } // namespace
 
@@ -746,6 +765,25 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
   CI->eraseFromParent();   // Remove call.
   DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
   ++NumEliminated;
+  // Adjust the entry count only if block frequencies were snapshotted in the
+  // constructor (OrigEntryBBFreq != 0). A function may carry an entry count
+  // while no BFI was passed (legacy pass, UpdateFunctionEntryCount == false,
+  // -tre-disable-entrycount-recompute); OriginalBBFreqs is empty then.
+  if (auto EC = F.getEntryCount(); EC && OrigEntryBBFreq) {
+    auto It = OriginalBBFreqs.find(BB);
+    assert(It != OriginalBBFreqs.end());
+    auto RelativeBBFreq =
+        static_cast<double>(It->second) / static_cast<double>(OrigEntryBBFreq);
+    auto OldEntryCount = EC->getCount();
+    auto ToSubtract = static_cast<uint64_t>(RelativeBBFreq * OldEntryCount);
+    // NOTE(review): OldEntryCount already reflects earlier eliminations in
+    // this function, while RelativeBBFreq is relative to the original entry
+    // block; confirm whether the original entry count should be scaled here.
+    // An inconsistent profile may attribute the whole entry count to this
+    // call; keep the count unchanged instead of underflowing it.
+    if (OldEntryCount > ToSubtract)
+      F.setEntryCount(OldEntryCount - ToSubtract, EC->getType());
+  }
   return true;
 }
 
@@ -872,7 +902,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
                                         const TargetTransformInfo *TTI,
                                         AliasAnalysis *AA,
                                         OptimizationRemarkEmitter *ORE,
-                                        DomTreeUpdater &DTU) {
+                                        DomTreeUpdater &DTU,
+                                        BlockFrequencyInfo *BFI) {
   if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
     return false;
 
@@ -888,7 +919,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
     return MadeChange;
 
   // Change any tail recursive calls to loops.
-  TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
+  TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
 
   for (BasicBlock &BB : F)
     MadeChange |= TRE.processBlock(BB);
@@ -930,7 +961,8 @@ struct TailCallElim : public FunctionPass {
     return TailRecursionEliminator::eliminate(
         F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
         &getAnalysis<AAResultsWrapperPass>().getAAResults(),
-        &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
+        &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
+        nullptr);
   }
 };
 }
@@ -953,6 +985,13 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
 
   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
   AliasAnalysis &AA = AM.getResult<AAManager>(F);
+  // This must come first: computing BFI needs the DT and PDT, and may compute
+  // and cache them. Had the cached-result queries below run first and returned
+  // nullptr (likely for the PDT), DTU would miss updates to those trees.
+  auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
+               F.getEntryCount().has_value())
+                  ? &AM.getResult<BlockFrequencyAnalysis>(F)
+                  : nullptr;
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
   auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
   auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
@@ -960,7 +999,8 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
   // UpdateStrategy based on some test results. It is feasible to switch the
   // UpdateStrategy to Lazy if we find it profitable later.
   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
-  bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
+  bool Changed =
+      TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
 
   if (!Changed)
     return PreservedAnalyses::all();
diff --git a/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
new file mode 100644
index 0000000000000..94be821d92268
--- /dev/null
+++ b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
+; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s
+
+; Test that tail call elimination correctly adjusts function entry counts
+; when eliminating tail recursive calls.
+
+; Basic test: eliminate a tail call and adjust entry count
+define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
+; CHECK-LABEL: @test_basic_entry_count_adjustment(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[TAILRECURSE:%.*]]
+; CHECK:       tailrecurse:
+; CHECK-NEXT:    [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT:    br label [[TAILRECURSE]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cmp = icmp sgt i32 %n, 0
+  br i1 %cmp, label %if.then, label %if.else, !prof !1
+
+if.then:                                          ; preds = %entry
+  %sub = sub i32 %n, 1
+  %call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
+  ret i32 %call
+
+if.else:                                          ; preds = %entry
+  ret i32 0
+}
+
+; Test multiple tail calls in different blocks with different frequencies
+define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
+; CHECK-LABEL: @test_multiple_blocks_entry_count(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[TAILRECURSE:%.*]]
+; CHECK:       tailrecurse:
+; CHECK-NEXT:    [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
+; CHECK:       check.flag:
+; CHECK-NEXT:    [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
+; CHECK-NEXT:    br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
+; CHECK:       block1:
+; CHECK-NEXT:    [[SUB1]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT:    br label [[TAILRECURSE]]
+; CHECK:       block2:
+; CHECK-NEXT:    [[SUB2]] = sub i32 [[N_TR]], 2
+; CHECK-NEXT:    br label [[TAILRECURSE]]
+; CHECK:       base.case:
+; CHECK-NEXT:    ret i32 1
+;
+entry:
+  %cmp = icmp sgt i32 %n, 0
+  br i1 %cmp, label %check.flag, label %base.case
+
+check.flag:
+  %cmp.flag = icmp eq i32 %flag, 1
+  br i1 %cmp.flag, label %block1, label %block2, !prof !3
+
+block1:                                           ; preds = %check.flag
+  %sub1 = sub i32 %n, 1
+  %call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
+  ret i32 %call1
+
+block2:                                           ; preds = %check.flag
+  %sub2 = sub i32 %n, 2
+  %call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
+  ret i32 %call2
+
+base.case:                                        ; preds = %entry
+  ret i32 1
+}
+
+; Test function without entry count (should not crash)
+define i32 @test_no_entry_count(i32 %n) {
+; CHECK-LABEL: @test_no_entry_count(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[TAILRECURSE:%.*]]
+; CHECK:       tailrecurse:
+; CHECK-NEXT:    [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT:    br label [[TAILRECURSE]]
+; CHECK:       if.else:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cmp = icmp sgt i32 %n, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %sub = sub i32 %n, 1
+  %call = tail call i32 @test_no_entry_count(i32 %sub)
+  ret i32 %call
+
+if.else:                                          ; preds = %entry
+  ret i32 0
+}
+
+; Function entry count metadata
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 800, i32 200}
+!2 = !{!"function_entry_count", i64 2000}
+!3 = !{!"branch_weights", i32 100, i32 500}
+;.
+; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 201}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
+; CHECK: [[META2:![0-9]+]] = !{!"function_entry_count", i64 859}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
+;.



More information about the llvm-commits mailing list