[llvm] daa2a58 - [TRE] Adjust function entry count when using instrumented profiles (#143987)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 23 08:07:35 PDT 2025
Author: Mircea Trofin
Date: 2025-06-23T08:07:31-07:00
New Revision: daa2a587cc01c5656deecda7f768fed0afc1e515
URL: https://github.com/llvm/llvm-project/commit/daa2a587cc01c5656deecda7f768fed0afc1e515
DIFF: https://github.com/llvm/llvm-project/commit/daa2a587cc01c5656deecda7f768fed0afc1e515.diff
LOG: [TRE] Adjust function entry count when using instrumented profiles (#143987)
The entry count of a function needs to be updated after a callsite is elided by TRE: before elision, the entry count accounted for the recursive call at that callsite. After TRE, we need to remove that callsite's contribution.
This patch enables the adjustment for instrumented profiles because, there, we know the function entry count captured entries before TRE. We cannot currently do this for sample-based profiles, because we don't know whether the function was already TRE-ed in the binary that donated the samples.
Added:
llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
Modified:
llvm/include/llvm/Passes/PassBuilder.h
llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
llvm/lib/Passes/PassBuilderPipelines.cpp
llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index f13b5c678a894..9cdb7ca7dbc9b 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -773,6 +773,8 @@ class PassBuilder {
IntrusiveRefCntPtr<vfs::FileSystem> FS);
void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level);
+ bool isInstrumentedPGOUse() const;
+
// Extension Point callbacks
SmallVector<std::function<void(FunctionPassManager &, OptimizationLevel)>, 2>
PeepholeEPCallbacks;
diff --git a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
index 57b1ed9bf4fe8..22a70cd66865a 100644
--- a/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
+++ b/llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h
@@ -58,7 +58,12 @@ namespace llvm {
class Function;
-struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
+class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
+ const bool UpdateFunctionEntryCount;
+
+public:
+ TailCallElimPass(bool UpdateFunctionEntryCount = true)
+ : UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
}
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index b0cdd1b94e565..c83d2dc1f1514 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -625,7 +625,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
!Level.isOptimizingForSize())
FPM.addPass(PGOMemOPSizeOpt());
- FPM.addPass(TailCallElimPass());
+ FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/
+ isInstrumentedPGOUse()));
FPM.addPass(
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
@@ -1578,7 +1579,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
OptimizePM.addPass(DivRemPairsPass());
// Try to annotate calls that were created during optimization.
- OptimizePM.addPass(TailCallElimPass());
+ OptimizePM.addPass(
+ TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
// LoopSink (and other loop passes since the last simplifyCFG) might have
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2066,7 +2068,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
// LTO provides additional opportunities for tailcall elimination due to
// link-time inlining, and visibility of nocapture attribute.
- FPM.addPass(TailCallElimPass());
+ FPM.addPass(
+ TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
// Run a few AA driver optimizations here and now to cleanup the code.
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
@@ -2347,3 +2350,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() {
return AA;
}
+
+bool PassBuilder::isInstrumentedPGOUse() const {
+ return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
+ !UseCtxProfile.empty();
+}
\ No newline at end of file
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index e7d989a43840d..7828571123bca 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -53,6 +53,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -75,10 +76,12 @@
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <cmath>
using namespace llvm;
#define DEBUG_TYPE "tailcallelim"
@@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
STATISTIC(NumRetDuped, "Number of return duplicated");
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
+static cl::opt<bool> ForceDisableBFI(
+ "tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
+ cl::desc("Force disabling recomputing of function entry count, on "
+ "successful tail recursion elimination."));
+
/// Scan the specified function for alloca instructions.
/// If it contains any dynamic allocas, returns false.
static bool canTRE(Function &F) {
@@ -399,6 +407,9 @@ class TailRecursionEliminator {
AliasAnalysis *AA;
OptimizationRemarkEmitter *ORE;
DomTreeUpdater &DTU;
+ BlockFrequencyInfo *const BFI;
+ const uint64_t OrigEntryBBFreq;
+ const uint64_t OrigEntryCount;
// The below are shared state we want to have available when eliminating any
// calls in the function. There values should be populated by
@@ -428,8 +439,19 @@ class TailRecursionEliminator {
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU)
- : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+ DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
+ : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
+ OrigEntryBBFreq(
+ BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U),
+ OrigEntryCount(F.getEntryCount() ? F.getEntryCount()->getCount() : 0) {
+ if (BFI) {
+ // The assert is meant as API documentation for the caller.
+ assert((OrigEntryCount != 0 && OrigEntryBBFreq != 0) &&
+ "If a BFI was provided, the function should have both an entry "
+ "count that is non-zero and an entry basic block with a non-zero "
+ "frequency.");
+ }
+ }
CallInst *findTRECandidate(BasicBlock *BB);
@@ -450,7 +472,7 @@ class TailRecursionEliminator {
public:
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU);
+ DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
};
} // namespace
@@ -735,6 +757,28 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
CI->eraseFromParent(); // Remove call.
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
++NumEliminated;
+ if (OrigEntryBBFreq) {
+ assert(F.getEntryCount().has_value());
+ // This pass is not expected to remove BBs, only add an entry BB. For that
+ // reason, and because the BB here isn't the new entry BB, the BFI lookup is
+ // expected to succeed.
+ assert(&F.getEntryBlock() != BB);
+ auto RelativeBBFreq =
+ static_cast<double>(BFI->getBlockFreq(BB).getFrequency()) /
+ static_cast<double>(OrigEntryBBFreq);
+ auto ToSubtract =
+ static_cast<uint64_t>(std::round(RelativeBBFreq * OrigEntryCount));
+ auto OldEntryCount = F.getEntryCount()->getCount();
+ if (OldEntryCount <= ToSubtract) {
+ LLVM_DEBUG(
+ errs() << "[TRE] The entrycount attributable to the recursive call, "
+ << ToSubtract
+ << ", should be strictly lower than the function entry count, "
+ << OldEntryCount << "\n");
+ } else {
+ F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType());
+ }
+ }
return true;
}
@@ -861,7 +905,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
const TargetTransformInfo *TTI,
AliasAnalysis *AA,
OptimizationRemarkEmitter *ORE,
- DomTreeUpdater &DTU) {
+ DomTreeUpdater &DTU,
+ BlockFrequencyInfo *BFI) {
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
return false;
@@ -877,7 +922,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
return MadeChange;
// Change any tail recursive calls to loops.
- TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
+ TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
for (BasicBlock &BB : F)
MadeChange |= TRE.processBlock(BB);
@@ -919,7 +964,8 @@ struct TailCallElim : public FunctionPass {
return TailRecursionEliminator::eliminate(
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
- &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
+ &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
+ /*BFI=*/nullptr);
}
};
}
@@ -942,6 +988,13 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
AliasAnalysis &AA = AM.getResult<AAManager>(F);
+ // This must come first: computing BFI also computes the DT and PDT. If it
+ // came after the getCachedResult calls below, those could return nullptr
+ // (likely, for the PDT) and the DTU would miss updates to those trees.
+ auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
+ F.getEntryCount().has_value() && F.getEntryCount()->getCount())
+ ? &AM.getResult<BlockFrequencyAnalysis>(F)
+ : nullptr;
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
@@ -949,7 +1002,8 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
// UpdateStrategy based on some test results. It is feasible to switch the
// UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
- bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
+ bool Changed =
+ TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
if (!Changed)
return PreservedAnalyses::all();
diff --git a/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
new file mode 100644
index 0000000000000..6001e6040a741
--- /dev/null
+++ b/llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
@@ -0,0 +1,120 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
+; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s --check-prefixes=CHECK,ENABLED
+; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - | FileCheck %s --check-prefixes=CHECK,DISABLED
+
+; Test that tail call elimination correctly adjusts function entry counts
+; when eliminating tail recursive calls.
+
+; Basic test: eliminate a tail call and adjust entry count
+define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
+; CHECK-LABEL: @test_basic_entry_count_adjustment(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
+; CHECK: if.then:
+; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: if.else:
+; CHECK-NEXT: ret i32 0
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %if.then, label %if.else, !prof !1
+
+if.then: ; preds = %entry
+ %sub = sub i32 %n, 1
+ %call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
+ ret i32 %call
+
+if.else: ; preds = %entry
+ ret i32 0
+}
+
+; Test multiple tail calls in different blocks with different frequencies
+define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
+; CHECK-LABEL: @test_multiple_blocks_entry_count(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]], !prof [[PROF3:![0-9]+]]
+; CHECK: check.flag:
+; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
+; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF4:![0-9]+]]
+; CHECK: block1:
+; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: block2:
+; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: base.case:
+; CHECK-NEXT: ret i32 1
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %check.flag, label %base.case, !prof !3
+check.flag:
+ %cmp.flag = icmp eq i32 %flag, 1
+ br i1 %cmp.flag, label %block1, label %block2, !prof !4
+block1: ; preds = %check.flag
+ %sub1 = sub i32 %n, 1
+ %call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
+ ret i32 %call1
+block2: ; preds = %check.flag
+ %sub2 = sub i32 %n, 2
+ %call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
+ ret i32 %call2
+base.case: ; preds = %entry
+ ret i32 1
+}
+
+define i32 @test_no_entry_count(i32 %n) {
+; CHECK-LABEL: @test_no_entry_count(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
+; CHECK: tailrecurse:
+; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
+; CHECK: if.then:
+; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
+; CHECK-NEXT: br label [[TAILRECURSE]]
+; CHECK: if.else:
+; CHECK-NEXT: ret i32 0
+;
+entry:
+ %cmp = icmp sgt i32 %n, 0
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then: ; preds = %entry
+ %sub = sub i32 %n, 1
+ %call = tail call i32 @test_no_entry_count(i32 %sub)
+ ret i32 %call
+
+if.else: ; preds = %entry
+ ret i32 0
+}
+
+; Profile metadata: function entry counts and branch weights
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 800, i32 200}
+!2 = !{!"function_entry_count", i64 2000}
+!3 = !{!"branch_weights", i32 3, i32 1}
+!4 = !{!"branch_weights", i32 100, i32 900}
+;.
+; ENABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
+; ENABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
+; ENABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 500}
+; ENABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1}
+; ENABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900}
+;.
+; DISABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000}
+; DISABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
+; DISABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 2000}
+; DISABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1}
+; DISABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900}
+;.
More information about the llvm-commits
mailing list