[llvm] [TailRecursionElim] Adjust function entry count (PR #143987)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 13 15:17:15 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Mircea Trofin (mtrofin)
<details>
<summary>Changes</summary>
The entry count of a function needs to be updated after a callsite is elided by TRE: before elision, the entry count accounted for the recursive call at that callsite. After TRE, we need to remove that callsite's contribution.
This patch enables the update only for instrumentation-based (IR PGO) and contextual profiles, because there we know the function entry count captured all entries, including the recursive ones, before TRE ran. That is harder to guarantee for sampling-based profiles.
---
Full diff: https://github.com/llvm/llvm-project/pull/143987.diff
4 Files Affected:
- (modified) llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h (+6-1)
- (modified) llvm/lib/Passes/PassBuilderPipelines.cpp (+12-3)
- (modified) llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp (+47-7)
- (added) llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll (+116)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
index 57b1ed9bf4fe8..1f88d3ff6ab22 100644
--- a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
+++ b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
@@ -58,7 +58,14 @@ namespace llvm {
class Function;
-struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
+class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
+  // Whether to reduce the function entry count by the contribution of each
+  // eliminated recursive call site.
+  const bool UpdateFunctionEntryCount;
+
+public:
+  explicit TailCallElimPass(bool UpdateFunctionEntryCount = true)
+      : UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
}
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index a99146d5eaa34..f0a7a7f693792 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -628,7 +628,10 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
!Level.isOptimizingForSize())
FPM.addPass(PGOMemOPSizeOpt());
- FPM.addPass(TailCallElimPass());
+  FPM.addPass(
+      TailCallElimPass(/*UpdateFunctionEntryCount=*/
+                       !UseCtxProfile.empty() ||
+                       (PGOOpt && PGOOpt->Action == PGOOptions::IRUse)));
FPM.addPass(
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
@@ -1581,7 +1584,10 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
OptimizePM.addPass(DivRemPairsPass());
// Try to annotate calls that were created during optimization.
- OptimizePM.addPass(TailCallElimPass());
+  OptimizePM.addPass(
+      TailCallElimPass(/*UpdateFunctionEntryCount=*/
+                       !UseCtxProfile.empty() ||
+                       (PGOOpt && PGOOpt->Action == PGOOptions::IRUse)));
// LoopSink (and other loop passes since the last simplifyCFG) might have
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2069,7 +2075,10 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
// LTO provides additional opportunities for tailcall elimination due to
// link-time inlining, and visibility of nocapture attribute.
- FPM.addPass(TailCallElimPass());
+  FPM.addPass(
+      TailCallElimPass(/*UpdateFunctionEntryCount=*/
+                       !UseCtxProfile.empty() ||
+                       (PGOOpt && PGOOpt->Action == PGOOptions::IRUse)));
// Run a few AA driver optimizations here and now to cleanup the code.
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index c71c5a70a12fd..e0e2764709cf6 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -53,6 +53,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -75,6 +76,7 @@
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
@@ -87,6 +89,12 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
STATISTIC(NumRetDuped, "Number of return duplicated");
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
+// When set, skip BFI and leave function entry counts untouched by TRE.
+static cl::opt<bool> ForceDisableBFI(
+    "tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
+    cl::desc("Force disabling recomputing of function entry count, on "
+             "successful tail recursion elimination."));
+
/// Scan the specified function for alloca instructions.
/// If it contains any dynamic allocas, returns false.
static bool canTRE(Function &F) {
@@ -409,6 +417,11 @@ class TailRecursionEliminator {
AliasAnalysis *AA;
OptimizationRemarkEmitter *ORE;
DomTreeUpdater &DTU;
+  // Profile snapshot taken at construction, before any call is eliminated.
+  // All zero/empty when no BlockFrequencyInfo was provided.
+  const uint64_t OrigEntryBBFreq;
+  const uint64_t OrigEntryCount;
+  DenseMap<const BasicBlock *, uint64_t> OriginalBBFreqs;
// The below are shared state we want to have available when eliminating any
// calls in the function. There values should be populated by
@@ -438,8 +451,20 @@ class TailRecursionEliminator {
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU)
- : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+                          DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
+      : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU),
+        OrigEntryBBFreq(
+            BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U),
+        OrigEntryCount(BFI && F.getEntryCount() ? F.getEntryCount()->getCount()
+                                                : 0U) {
+    assert(((BFI != nullptr) == (OrigEntryBBFreq != 0)) &&
+           "If the function has an entry count, its entry basic block should "
+           "have a non-zero frequency. Pass a nullptr BFI if the function has "
+           "no entry count");
+    if (BFI)
+      for (const auto &BB : F)
+        OriginalBBFreqs.insert({&BB, BFI->getBlockFreq(&BB).getFrequency()});
+  }
CallInst *findTRECandidate(BasicBlock *BB);
@@ -460,7 +485,7 @@ class TailRecursionEliminator {
public:
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU);
+ DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
};
} // namespace
@@ -746,6 +771,38 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
CI->eraseFromParent(); // Remove call.
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
++NumEliminated;
+  // Remove this call site's contribution from the function entry count. This
+  // is only possible when a profile snapshot was taken at construction (i.e.
+  // BFI was provided). A function can have an entry count without that, e.g.
+  // with the legacy pass, when the pass is configured not to update entry
+  // counts, or with -tre-disable-entrycount-recompute.
+  if (OrigEntryBBFreq) {
+    auto EC = F.getEntryCount();
+    auto It = OriginalBBFreqs.find(BB);
+    assert(EC && "Expected the function to still have an entry count");
+    assert(It != OriginalBBFreqs.end() &&
+           "Expected the call's block to predate tail recursion elimination");
+    if (EC && It != OriginalBBFreqs.end()) {
+      // Each execution of BB re-entered the function through this call. Its
+      // share of the entry count is a fraction of the *original* count; using
+      // the already-reduced count would compound across call sites.
+      const double RelativeBBFreq = static_cast<double>(It->second) /
+                                    static_cast<double>(OrigEntryBBFreq);
+      // Round to nearest rather than truncate, so that a ratio that should be
+      // exactly 0.8 but comes out as 0.7999... doesn't bias the result.
+      const auto ToSubtract =
+          static_cast<uint64_t>(RelativeBBFreq * OrigEntryCount + 0.5);
+      const uint64_t OldEntryCount = EC->getCount();
+      // Profiles can be inconsistent; never let the count wrap around.
+      if (ToSubtract < OldEntryCount)
+        F.setEntryCount(OldEntryCount - ToSubtract, EC->getType());
+      else
+        LLVM_DEBUG(dbgs() << "TRE: not updating the entry count of "
+                          << F.getName() << ": recursive contribution "
+                          << ToSubtract << " is not below the current count "
+                          << OldEntryCount << "\n");
+    }
+  }
return true;
}
@@ -872,7 +929,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
const TargetTransformInfo *TTI,
AliasAnalysis *AA,
OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU) {
+ DomTreeUpdater &DTU,
+ BlockFrequencyInfo *BFI) {
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
return false;
@@ -888,7 +946,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
return MadeChange;
// Change any tail recursive calls to loops.
- TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
+ TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
for (BasicBlock &BB : F)
MadeChange |= TRE.processBlock(BB);
@@ -930,7 +988,8 @@ struct TailCallElim : public FunctionPass {
return TailRecursionEliminator::eliminate(
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
- &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
+ &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
+ nullptr);
}
};
}
@@ -953,6 +1012,13 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
AliasAnalysis &AA = AM.getResult<AAManager>(F);
+  // Request BFI before the cached DT/PDT lookups below: computing BFI also
+  // computes those trees, and if we looked them up first (likely null for the
+  // PDT) the DTU would miss them and leave them stale after our CFG updates.
+  auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
+               F.getEntryCount().has_value())
+                  ? &AM.getResult<BlockFrequencyAnalysis>(F)
+                  : nullptr;
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
@@ -960,7 +1026,8 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
// UpdateStrategy based on some test results. It is feasible to switch the
// UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
- bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
+  bool Changed =
+      TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
if (!Changed)
return PreservedAnalyses::all();
diff --git a/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
new file mode 100644
index 0000000000000..94be821d92268
--- /dev/null
+++ b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
+; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s
+; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - \
+; RUN:   | FileCheck %s --check-prefix=DISABLED
+
+; Test that tail call elimination correctly adjusts function entry counts
+; when eliminating tail recursive calls.
+
+; Basic test: eliminate a tail call and adjust entry count. 80% of the 1000
+; entries came from the recursive call, so 1000 - 800 = 200 remain.
+define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
+; CHECK-LABEL: @test_basic_entry_count_adjustment(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
+; CHECK: if.then:
+; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: if.else:
+; CHECK-NEXT: ret i32 0
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %if.then, label %if.else, !prof !1
+
+if.then: ; preds = %entry
+ %sub = sub i32 %n, 1
+ %call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
+ ret i32 %call
+
+if.else: ; preds = %entry
+ ret i32 0
+}
+
+; Test multiple tail calls in different blocks with different frequencies.
+; Each call site's share is taken from the original entry count (2000), so
+; eliminating one call does not shrink the share attributed to the other.
+define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
+; CHECK-LABEL: @test_multiple_blocks_entry_count(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
+; CHECK: check.flag:
+; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
+; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
+; CHECK: block1:
+; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: block2:
+; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: base.case:
+; CHECK-NEXT: ret i32 1
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %check.flag, label %base.case
+
+check.flag:
+ %cmp.flag = icmp eq i32 %flag, 1
+ br i1 %cmp.flag, label %block1, label %block2, !prof !3
+
+block1: ; preds = %check.flag
+ %sub1 = sub i32 %n, 1
+ %call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
+ ret i32 %call1
+
+block2: ; preds = %check.flag
+ %sub2 = sub i32 %n, 2
+ %call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
+ ret i32 %call2
+
+base.case: ; preds = %entry
+ ret i32 1
+}
+
+; Test function without entry count (should not crash)
+define i32 @test_no_entry_count(i32 %n) {
+; CHECK-LABEL: @test_no_entry_count(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
+; CHECK: if.then:
+; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: if.else:
+; CHECK-NEXT: ret i32 0
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then: ; preds = %entry
+ %sub = sub i32 %n, 1
+ %call = tail call i32 @test_no_entry_count(i32 %sub)
+ ret i32 %call
+
+if.else: ; preds = %entry
+ ret i32 0
+}
+
+; Function entry count metadata
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 800, i32 200}
+!2 = !{!"function_entry_count", i64 2000}
+!3 = !{!"branch_weights", i32 100, i32 500}
+;.
+; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
+; CHECK: [[META2:![0-9]+]] = !{!"function_entry_count", i64 750}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
+;.
+; With the update disabled (hand-written checks), entry counts are unchanged
+; and the pass must not crash on functions that do have an entry count.
+; DISABLED: !{!"function_entry_count", i64 1000}
+; DISABLED: !{!"function_entry_count", i64 2000}
``````````
</details>
https://github.com/llvm/llvm-project/pull/143987
More information about the llvm-commits
mailing list