[llvm] [JTS] Propagate profile info (PR #153305)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 14 07:44:38 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/153305

>From 94c57f347294051d85e30f281d7d8c644cb2ceff Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 12 Aug 2025 14:30:50 -0700
Subject: [PATCH] [JTS] Propagate profile info

---
 llvm/include/llvm/ProfileData/InstrProf.h     |  5 ++
 .../Transforms/Scalar/JumpTableToSwitch.cpp   | 87 +++++++++++++++++--
 .../Transforms/JumpTableToSwitch/basic.ll     |  9 ++-
 3 files changed, 93 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h
index bab1963dba22e..85a9efe73855b 100644
--- a/llvm/include/llvm/ProfileData/InstrProf.h
+++ b/llvm/include/llvm/ProfileData/InstrProf.h
@@ -665,6 +665,11 @@ class InstrProfSymtab {
     return Error::success();
   }
 
+  /// Return the (MD5 hash, Function *) pairs of the functions in the symtab.
+  const std::vector<std::pair<uint64_t, Function *>> &getMD5FuncMap() const {
+    return MD5FuncMap;
+  }
+
   const StringSet<> &getVTableNames() const { return VTableNames; }
 
   /// Map a function address to its name's MD5 hash. This interface
diff --git a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
index 7f99cd2060a9d..6719ce64b96b6 100644
--- a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp
@@ -7,14 +7,23 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ProfileData/InstrProf.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <limits>

 using namespace llvm;
 
@@ -33,6 +42,8 @@ static cl::opt<unsigned> FunctionSizeThreshold(
              "or equal than this threshold."),
     cl::init(50));
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 #define DEBUG_TYPE "jump-table-to-switch"

 namespace {
@@ -90,9 +101,14 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
   return JumpTable;
 }
 
-static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
-                                  DomTreeUpdater &DTU,
-                                  OptimizationRemarkEmitter &ORE) {
+/// Replace the indirect call \p CB through the jump table \p JT with a switch
+/// of direct calls. \p GetGuidForFunction maps a target to the GUID used by
+/// \p CB's value profile (0 if unknown); it is only called if \p CB has one.
+static BasicBlock *
+expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
+               OptimizationRemarkEmitter &ORE,
+               llvm::function_ref<GlobalValue::GUID(const Function &)>
+                   GetGuidForFunction) {
   const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
 
   SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
@@ -115,7 +131,33 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
   IRBuilder<> BuilderTail(CB);
   PHINode *PHI =
       IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
-
+  const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
+
+  // With a value profile on the indirect call, its per-target counts become
+  // the switch's branch weights: the default case first, then one per case.
+  SmallVector<uint64_t> BranchWeights;
+  DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
+  const bool HadProfile = isValueProfileMD(ProfMD);
+  if (HadProfile) {
+    // The assumptions, coming in, are that the functions in JT.Funcs are
+    // defined in this module (from parseJumpTable).
+    assert(llvm::all_of(
+        JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
+    BranchWeights.reserve(JT.Funcs.size() + 1);
+    // The first is the default target, which is the unreachable block created
+    // above.
+    BranchWeights.push_back(0U);
+    uint64_t TotalCount = 0;
+    auto Targets = getValueProfDataFromInst(
+        *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
+        std::numeric_limits<uint32_t>::max(), TotalCount);
+
+    for (const auto &[G, C] : Targets) {
+      auto It = GuidToCounter.insert({G, C});
+      assert(It.second);
+      (void)It;
+    }
+  }
   for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
     BasicBlock *B = BasicBlock::Create(Func->getContext(),
                                        "call." + Twine(Index), &F, Tail);
@@ -127,6 +169,12 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
     Call->insertInto(B, B->end());
     Switch->addCase(
         cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
+    if (HadProfile) {
+      // Targets with an unknown GUID (0) or no profile entry get zero weight.
+      GlobalValue::GUID FctID = GetGuidForFunction(*Func);
+      BranchWeights.push_back(FctID == 0U ? 0U
+                                          : GuidToCounter.lookup_or(FctID, 0U));
+    }
     BranchInst::Create(Tail, B);
     if (PHI)
       PHI->addIncoming(Call, B);
@@ -136,6 +184,14 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
     return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
            << "expanded indirect call into switch";
   });
+  // Propagate the value profile to the switch. A stale profile may credit no
+  // jump table entry; the weights are then unknown rather than all zero.
+  const uint64_t MaxWeight =
+      HadProfile ? *llvm::max_element(BranchWeights) : 0U;
+  if (MaxWeight != 0 && !ProfcheckDisableMetadataFixes)
+    setProfMetadata(F.getParent(), Switch, BranchWeights, MaxWeight);
+  else
+    setExplicitlyUnknownBranchWeights(*Switch);
   if (PHI)
     CB->replaceAllUsesWith(PHI);
   CB->eraseFromParent();
@@ -150,6 +206,26 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
   PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
   bool Changed = false;
+  // Function -> GUID for targets without !guid metadata. Creating the
+  // InstrProfSymtab processes the whole module, so it is deferred until the
+  // first query that needs it and then reused for the rest of this run.
+  std::optional<DenseMap<const Function *, GlobalValue::GUID>> FToGuid;
+  auto GetGuidForFunction = [&](const Function &Fct) -> GlobalValue::GUID {
+    if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName))
+      return AssignGUIDPass::getGUID(Fct);
+    if (!FToGuid) {
+      FToGuid.emplace();
+      InstrProfSymtab Symtab;
+      if (auto E = Symtab.create(*F.getParent()))
+        F.getContext().emitError(
+            "Could not create indirect call table, likely corrupted IR: " +
+            toString(std::move(E)));
+      for (const auto &[G, FPtr] : Symtab.getMD5FuncMap())
+        FToGuid->insert({FPtr, G});
+    }
+    return FToGuid->lookup_or(&Fct, 0U);
+  };
+
   for (BasicBlock &BB : make_early_inc_range(F)) {
     BasicBlock *CurrentBB = &BB;
     while (CurrentBB) {
@@ -170,7 +246,8 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
         std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
         if (!JumpTable)
           continue;
-        SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
+        SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE,
+                                         GetGuidForFunction);
         Changed = true;
         break;
       }
diff --git a/llvm/test/Transforms/JumpTableToSwitch/basic.ll b/llvm/test/Transforms/JumpTableToSwitch/basic.ll
index 321f837077ab6..577c2adaf5afa 100644
--- a/llvm/test/Transforms/JumpTableToSwitch/basic.ll
+++ b/llvm/test/Transforms/JumpTableToSwitch/basic.ll
@@ -4,11 +4,11 @@
 
 @func_array = constant [2 x ptr] [ptr @func0, ptr @func1]
 
-define i32 @func0() {
+define i32 @func0() !guid !0 {
   ret i32 1
 }
 
-define i32 @func1() {
+define i32 @func1() !guid !1 {
   ret i32 2
 }
 
@@ -42,7 +42,7 @@ define i32 @function_with_jump_table(i32 %index) {
 ;
   %gep = getelementptr inbounds [2 x ptr], ptr @func_array, i32 0, i32 %index
   %func_ptr = load ptr, ptr %gep
-  %result = call i32 %func_ptr()
+  %result = call i32 %func_ptr(), !prof !2
   ret i32 %result
 }
 
@@ -226,3 +226,6 @@ define i32 @function_with_jump_table_addrspace_42(i32 %index) addrspace(42) {
   ret i32 %result
 }
 
+!0 = !{i64 5678}
+!1 = !{i64 5555}
+!2 = !{!"VP", i32 0, i64 25, i64 5678, i64 20, i64 5555, i64 5}



More information about the llvm-commits mailing list