[llvm] [TRE] Adjust function entry count when using instrumented profiles (PR #143987)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 19 07:53:01 PDT 2025
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/143987
>From 51cf14ad42a03c71df968706d776015551bd4a2d Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 12 Jun 2025 15:34:46 -0700
Subject: [PATCH] [TailRecursionElim] Adjust function entry count
---
llvm/include/llvm/Passes/PassBuilder.h | 2 +
.../Scalar/TailRecursionElimination.h | 7 +-
llvm/lib/Passes/PassBuilderPipelines.cpp | 14 +-
.../Scalar/TailRecursionElimination.cpp | 67 ++++++-
.../TailCallElim/entry-count-adjustment.ll | 167 ++++++++++++++++++
5 files changed, 246 insertions(+), 11 deletions(-)
create mode 100644 llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index f13b5c678a894..9cdb7ca7dbc9b 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -773,6 +773,8 @@ class PassBuilder {
IntrusiveRefCntPtr<vfs::FileSystem> FS);
void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level);
+ bool isInstrumentedPGOUse() const; // True for IR-instrumented (IRUse) or contextual profile use.
+
// Extension Point callbacks
SmallVector<std::function<void(FunctionPassManager &, OptimizationLevel)>, 2>
PeepholeEPCallbacks;
diff --git a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
index 57b1ed9bf4fe8..22a70cd66865a 100644
--- a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
+++ b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
@@ -58,7 +58,12 @@ namespace llvm {
class Function;
-struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
+class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
+ const bool UpdateFunctionEntryCount; // If set, run() may use BFI to lower the entry count after TRE.
+
+public:
+ explicit TailCallElimPass(bool UpdateFunctionEntryCount = true)
+ : UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
}
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index a99146d5eaa34..b638ce4803a79 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -628,7 +628,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
!Level.isOptimizingForSize())
FPM.addPass(PGOMemOPSizeOpt());
- FPM.addPass(TailCallElimPass());
+ FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/
+ isInstrumentedPGOUse()));
FPM.addPass(
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
@@ -1581,7 +1582,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
OptimizePM.addPass(DivRemPairsPass());
// Try to annotate calls that were created during optimization.
- OptimizePM.addPass(TailCallElimPass());
+ OptimizePM.addPass(
+ TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
// LoopSink (and other loop passes since the last simplifyCFG) might have
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2069,7 +2071,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
// LTO provides additional opportunities for tailcall elimination due to
// link-time inlining, and visibility of nocapture attribute.
- FPM.addPass(TailCallElimPass());
+ FPM.addPass(
+ TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
// Run a few AA driver optimizations here and now to cleanup the code.
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
@@ -2350,3 +2353,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() {
return AA;
}
+
+bool PassBuilder::isInstrumentedPGOUse() const {
+ return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
+ !UseCtxProfile.empty(); // A contextual profile also counts as instrumented PGO use.
+}
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index e7d989a43840d..54a09093812be 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -53,6 +53,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -75,10 +76,12 @@
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <cmath>
using namespace llvm;
#define DEBUG_TYPE "tailcallelim"
@@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
STATISTIC(NumRetDuped, "Number of return duplicated");
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
+static cl::opt<bool> ForceDisableBFI(
+ "tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
+ cl::desc("Force disabling recomputing of function entry count, on "
+ "successful tail recursion elimination.")); // When set, TailCallElimPass::run does not request BFI, so entry counts stay unchanged.
+
/// Scan the specified function for alloca instructions.
/// If it contains any dynamic allocas, returns false.
static bool canTRE(Function &F) {
@@ -399,6 +407,9 @@ class TailRecursionEliminator {
AliasAnalysis *AA;
OptimizationRemarkEmitter *ORE;
DomTreeUpdater &DTU;
+ BlockFrequencyInfo *const BFI;
+ const uint64_t OrigEntryBBFreq;
+ const uint64_t OrigEntryCount;
// The below are shared state we want to have available when eliminating any
// calls in the function. There values should be populated by
@@ -428,8 +439,20 @@ class TailRecursionEliminator {
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU)
- : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+ DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
+ : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
+ OrigEntryBBFreq(
+ BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U),
+ OrigEntryCount(F.getEntryCount() ? F.getEntryCount()->getCount() : 0U) {
+ if (BFI) {
+ auto EC = F.getEntryCount();
+ (void)EC;
+ assert((EC.has_value() && EC->getCount() != 0 && OrigEntryBBFreq) &&
+ "If a BFI was provided, the function should have both an entry "
+ "count that is non-zero and an entry basic block with a non-zero "
+ "frequency.");
+ }
+ }
CallInst *findTRECandidate(BasicBlock *BB);
@@ -450,7 +473,7 @@ class TailRecursionEliminator {
public:
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU);
+ DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
};
} // namespace
@@ -735,6 +758,33 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
CI->eraseFromParent(); // Remove call.
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
++NumEliminated;
+ if (OrigEntryBBFreq) {
+ assert(F.getEntryCount().has_value());
+ // This pass is not expected to remove BBs, only add an entry BB. For that
+ // reason, and because the BB here isn't the new entry BB, the BFI lookup is
+ // expected to succeed.
+ assert(&F.getEntryBlock() != BB);
+ // RelativeBBFreq is the share of the original invocations reaching BB, so
+ // scale the original entry count: the current one may already have been
+ // lowered by an earlier elimination in this function.
+ auto RelativeBBFreq =
+ static_cast<double>(BFI->getBlockFreq(BB).getFrequency()) /
+ static_cast<double>(OrigEntryBBFreq);
+ auto ToSubtract =
+ static_cast<uint64_t>(std::round(RelativeBBFreq * OrigEntryCount));
+ auto OldEntryCount = F.getEntryCount()->getCount();
+ if (OldEntryCount <= ToSubtract) {
+ // Subtracting would leave a zero or wrapped-around (unsigned) entry count;
+ // keep the current value instead.
+ LLVM_DEBUG(
+ dbgs() << "[TRE] The entrycount attributable to the recursive call, "
+ << ToSubtract
+ << ", should be strictly lower than the function entry count, "
+ << OldEntryCount << "\n");
+ } else {
+ F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType());
+ }
+ }
return true;
}
@@ -861,7 +911,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
const TargetTransformInfo *TTI,
AliasAnalysis *AA,
OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU) {
+ DomTreeUpdater &DTU,
+ BlockFrequencyInfo *BFI) {
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
return false;
@@ -877,7 +928,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
return MadeChange;
// Change any tail recursive calls to loops.
- TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
+ TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
for (BasicBlock &BB : F)
MadeChange |= TRE.processBlock(BB);
@@ -919,7 +970,8 @@ struct TailCallElim : public FunctionPass {
return TailRecursionEliminator::eliminate(
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
- &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
+ &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
+ nullptr);
}
};
}
@@ -942,6 +994,13 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
AliasAnalysis &AA = AM.getResult<AAManager>(F);
+ // This must come first: BFI needs the DT and the PDT, so requesting it
+ // computes them. If it came after the getCachedResult calls below, those
+ // could return nullptr (likely for the PDT) and tree updates would be missed.
+ auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
+ F.getEntryCount().has_value() && F.getEntryCount()->getCount())
+ ? &AM.getResult<BlockFrequencyAnalysis>(F)
+ : nullptr;
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
@@ -949,7 +1008,8 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
// UpdateStrategy based on some test results. It is feasible to switch the
// UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
- bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
+ bool Changed =
+ TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
if (!Changed)
return PreservedAnalyses::all();
diff --git a/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
new file mode 100644
index 0000000000000..6537ac21aef29
--- /dev/null
+++ b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
@@ -0,0 +1,167 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
+; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s
+; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - | FileCheck %s --check-prefix=DISABLED
+
+; Test that tail call elimination correctly adjusts function entry counts
+; when eliminating tail recursive calls.
+
+; Basic test: eliminate a tail call and adjust entry count
+define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
+; CHECK-LABEL: @test_basic_entry_count_adjustment(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
+; CHECK: if.then:
+; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: if.else:
+; CHECK-NEXT: ret i32 0
+;
+; DISABLED-LABEL: @test_basic_entry_count_adjustment(
+; DISABLED-NEXT: entry:
+; DISABLED-NEXT: br label [[TAILRECURSE:%.*]]
+; DISABLED: tailrecurse:
+; DISABLED-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; DISABLED-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; DISABLED-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
+; DISABLED: if.then:
+; DISABLED-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; DISABLED-NEXT: br label [[TAILRECURSE]]
+; DISABLED: if.else:
+; DISABLED-NEXT: ret i32 0
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %if.then, label %if.else, !prof !1
+
+if.then: ; preds = %entry
+ %sub = sub i32 %n, 1
+ %call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
+ ret i32 %call
+
+if.else: ; preds = %entry
+ ret i32 0
+}
+
+; Test multiple tail calls in different blocks with different frequencies
+define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
+; CHECK-LABEL: @test_multiple_blocks_entry_count(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
+; CHECK: check.flag:
+; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
+; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
+; CHECK: block1:
+; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: block2:
+; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: base.case:
+; CHECK-NEXT: ret i32 1
+;
+; DISABLED-LABEL: @test_multiple_blocks_entry_count(
+; DISABLED-NEXT: entry:
+; DISABLED-NEXT: br label [[TAILRECURSE:%.*]]
+; DISABLED: tailrecurse:
+; DISABLED-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
+; DISABLED-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; DISABLED-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
+; DISABLED: check.flag:
+; DISABLED-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
+; DISABLED-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
+; DISABLED: block1:
+; DISABLED-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
+; DISABLED-NEXT: br label [[TAILRECURSE]]
+; DISABLED: block2:
+; DISABLED-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
+; DISABLED-NEXT: br label [[TAILRECURSE]]
+; DISABLED: base.case:
+; DISABLED-NEXT: ret i32 1
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %check.flag, label %base.case
+
+check.flag:
+ %cmp.flag = icmp eq i32 %flag, 1
+ br i1 %cmp.flag, label %block1, label %block2, !prof !3
+
+block1: ; preds = %check.flag
+ %sub1 = sub i32 %n, 1
+ %call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
+ ret i32 %call1
+
+block2: ; preds = %check.flag
+ %sub2 = sub i32 %n, 2
+ %call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
+ ret i32 %call2
+
+base.case: ; preds = %entry
+ ret i32 1
+}
+
+; Test function without entry count (should not crash)
+define i32 @test_no_entry_count(i32 %n) {
+; CHECK-LABEL: @test_no_entry_count(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
+; CHECK: if.then:
+; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: if.else:
+; CHECK-NEXT: ret i32 0
+;
+; DISABLED-LABEL: @test_no_entry_count(
+; DISABLED-NEXT: entry:
+; DISABLED-NEXT: br label [[TAILRECURSE:%.*]]
+; DISABLED: tailrecurse:
+; DISABLED-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; DISABLED-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; DISABLED-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
+; DISABLED: if.then:
+; DISABLED-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; DISABLED-NEXT: br label [[TAILRECURSE]]
+; DISABLED: if.else:
+; DISABLED-NEXT: ret i32 0
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then: ; preds = %entry
+ %sub = sub i32 %n, 1
+ %call = tail call i32 @test_no_entry_count(i32 %sub)
+ ret i32 %call
+
+if.else: ; preds = %entry
+ ret i32 0
+}
+
+; Function entry count metadata
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 800, i32 200}
+!2 = !{!"function_entry_count", i64 2000}
+!3 = !{!"branch_weights", i32 100, i32 500}
+;.
+; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
+; CHECK: [[META2:![0-9]+]] = !{!"function_entry_count", i64 750}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
+;.
+; DISABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000}
+; DISABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
+; DISABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 2000}
+; DISABLED: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
+;.
More information about the llvm-commits
mailing list