[llvm] [JTS] Propagate profile info (PR #153305)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 14 07:44:38 PDT 2025
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/153305
From 94c57f347294051d85e30f281d7d8c644cb2ceff Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 12 Aug 2025 14:30:50 -0700
Subject: [PATCH] [JTS] Propagate profile info
---
llvm/include/llvm/ProfileData/InstrProf.h | 4 +
.../Transforms/Scalar/JumpTableToSwitch.cpp | 74 +++++++++++++++++--
.../Transforms/JumpTableToSwitch/basic.ll | 9 ++-
3 files changed, 79 insertions(+), 8 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index bab1963dba22e..85a9efe73855b 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -665,6 +665,10 @@ class InstrProfSymtab {
return Error::success();
}
+ const std::vector<std::pair<uint64_t, Function *>> &getIDToNameMap() const {
+ return MD5FuncMap;
+ }
+
const StringSet<> &getVTableNames() const { return VTableNames; }
/// Map a function address to its name's MD5 hash. This interface
diff --git a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
index 7f99cd2060a9d..6719ce64b96b6 100644
--- a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
@@ -7,14 +7,24 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <limits>
using namespace llvm;
@@ -33,6 +43,8 @@ static cl::opt<unsigned> FunctionSizeThreshold(
"or equal than this threshold."),
cl::init(50));
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
#define DEBUG_TYPE "jump-table-to-switch"
namespace {
@@ -90,9 +102,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
return JumpTable;
}
-static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
- DomTreeUpdater &DTU,
- OptimizationRemarkEmitter &ORE) {
+static BasicBlock *
+expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
+ OptimizationRemarkEmitter &ORE,
+ llvm::function_ref<GlobalValue::GUID(const Function &)>
+ GetGuidForFunction) {
const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
@@ -115,7 +129,31 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
IRBuilder<> BuilderTail(CB);
PHINode *PHI =
IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
-
+ const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
+
+ SmallVector<uint64_t> BranchWeights;
+ DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
+ const bool HadProfile = isValueProfileMD(ProfMD);
+ if (HadProfile) {
+ // The assumptions, coming in, are that the functions in JT.Funcs are
+ // defined in this module (from parseJumpTable).
+ assert(llvm::all_of(
+ JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
+ BranchWeights.reserve(JT.Funcs.size() + 1);
+ // The first is the default target, which is the unreachable block created
+ // above.
+ BranchWeights.push_back(0U);
+ uint64_t TotalCount = 0;
+ auto Targets = getValueProfDataFromInst(
+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
+ std::numeric_limits<uint32_t>::max(), TotalCount);
+
+ for (const auto &[G, C] : Targets) {
+ auto It = GuidToCounter.insert({G, C});
+ assert(It.second);
+ (void)It;
+ }
+ }
for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
BasicBlock *B = BasicBlock::Create(Func->getContext(),
"call." + Twine(Index), &F, Tail);
@@ -127,6 +165,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
Call->insertInto(B, B->end());
Switch->addCase(
cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
+ GlobalValue::GUID FctID = GetGuidForFunction(*Func);
+ // A target may legitimately be missing from GuidToCounter: the profile
+ // may have recorded only some of the targets as taken; those get weight 0.
+ BranchWeights.push_back(FctID == 0U ? 0U
+ : GuidToCounter.lookup_or(FctID, 0U));
BranchInst::Create(Tail, B);
if (PHI)
PHI->addIncoming(Call, B);
@@ -136,6 +179,13 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
<< "expanded indirect call into switch";
});
+ if (HadProfile && !ProfcheckDisableMetadataFixes) {
+ // At least one of the targets must've been taken.
+ assert(llvm::any_of(BranchWeights, [](uint64_t V) { return V != 0; }));
+ setProfMetadata(F.getParent(), Switch, BranchWeights,
+ *llvm::max_element(BranchWeights));
+ } else
+ setExplicitlyUnknownBranchWeights(*Switch);
if (PHI)
CB->replaceAllUsesWith(PHI);
CB->eraseFromParent();
@@ -150,6 +200,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
bool Changed = false;
+ InstrProfSymtab Symtab;
+ if (auto E = Symtab.create(*F.getParent()))
+ F.getContext().emitError(
+ "Could not create indirect call table, likely corrupted IR" +
+ toString(std::move(E)));
+ DenseMap<const Function *, GlobalValue::GUID> FToGuid;
+ for (const auto &[G, FPtr] : Symtab.getIDToNameMap())
+ FToGuid.insert({FPtr, G});
+
for (BasicBlock &BB : make_early_inc_range(F)) {
BasicBlock *CurrentBB = &BB;
while (CurrentBB) {
@@ -170,7 +229,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
if (!JumpTable)
continue;
- SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
+ SplittedOutTail = expandToSwitch(
+ Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
+ if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName))
+ return AssignGUIDPass::getGUID(Fct);
+ return FToGuid.lookup_or(&Fct, 0U);
+ });
Changed = true;
break;
}
diff --git a/llvm/test/Transforms/JumpTableToSwitch/basic.ll b/llvm/test/Transforms/JumpTableToSwitch/basic.ll
index 321f837077ab6..577c2adaf5afa 100644
--- a/llvm/test/Transforms/JumpTableToSwitch/basic.ll
+++ b/llvm/test/Transforms/JumpTableToSwitch/basic.ll
@@ -4,11 +4,11 @@
@func_array = constant [2 x ptr] [ptr @func0, ptr @func1]
-define i32 @func0() {
+define i32 @func0() !guid !0 {
ret i32 1
}
-define i32 @func1() {
+define i32 @func1() !guid !1 {
ret i32 2
}
@@ -42,7 +42,7 @@ define i32 @function_with_jump_table(i32 %index) {
;
%gep = getelementptr inbounds [2 x ptr], ptr @func_array, i32 0, i32 %index
%func_ptr = load ptr, ptr %gep
- %result = call i32 %func_ptr()
+ %result = call i32 %func_ptr(), !prof !2
ret i32 %result
}
@@ -226,3 +226,6 @@ define i32 @function_with_jump_table_addrspace_42(i32 %index) addrspace(42) {
ret i32 %result
}
+!0 = !{i64 5678}
+!1 = !{i64 5555}
+!2 = !{!"VP", i32 0, i64 25, i64 5678, i64 20, i64 5555, i64 5}
\ No newline at end of file
More information about the llvm-commits mailing list