[llvm] [WPD] set the function entry count (PR #155657)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 27 10:14:07 PDT 2025


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/155657

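Branch funnels synthesized by WholeProgramDevirt now get a function entry count: the sum, over the call sites redirected to the funnel, of each caller's entry count scaled by the call block's frequency relative to the caller's entry block. Funnels for which no count can be derived are marked with an explicitly "unknown" !prof, which the verifier now accepts on functions.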

>From 59546ec4583ad9b64392ddb853ed0fcd9fbf1beb Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 26 Aug 2025 17:28:55 +0000
Subject: [PATCH] [WPD] set the function entry count

---
 llvm/include/llvm/IR/ProfDataUtils.h          |  1 +
 llvm/lib/IR/ProfDataUtils.cpp                 |  8 ++++
 llvm/lib/IR/Verifier.cpp                      |  3 --
 .../lib/Transforms/IPO/WholeProgramDevirt.cpp | 48 +++++++++++++++----
 .../WholeProgramDevirt/branch-funnel.ll       | 35 ++++++++------
 5 files changed, 68 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..0a1efb81cb13d 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -179,6 +179,7 @@ inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
 /// bugs where the pass forgets to transfer over or otherwise specify profile
 /// info.
 LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I);
+LLVM_ABI void setExplicitlyUnknownFunctionEntryCount(Function &F);
 
 LLVM_ABI bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD);
 LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index d24263f8b3bda..a63d731916600 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -250,6 +250,14 @@ void setExplicitlyUnknownBranchWeights(Instruction &I) {
                   MDB.createString(MDProfLabels::UnknownBranchWeightsMarker)));
 }
 
+void setExplicitlyUnknownFunctionEntryCount(Function &F) {
+  LLVMContext &Ctx = F.getContext();
+  MDBuilder MDB(Ctx);
+  MDString *Marker =
+      MDB.createString(MDProfLabels::UnknownBranchWeightsMarker);
+  F.setMetadata(LLVMContext::MD_prof, MDNode::get(Ctx, Marker));
+}
+
 bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) {
   if (MD.getNumOperands() != 1)
     return false;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 9fda08645e118..1eeaf630ffcfb 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2527,9 +2527,6 @@ void Verifier::verifyFunctionMetadata(
     if (Pair.first == LLVMContext::MD_prof) {
       MDNode *MD = Pair.second;
       if (isExplicitlyUnknownBranchWeightsMetadata(*MD)) {
-        CheckFailed("'unknown' !prof metadata should appear only on "
-                    "instructions supporting the 'branch_weights' metadata",
-                    MD);
         continue;
       }
       Check(MD->getNumOperands() >= 2,
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index cb98ed838f5d7..22baa0fcf4c5f 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -60,6 +60,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/Bitcode/BitcodeReader.h"
@@ -84,6 +85,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ModuleSummaryIndexYAML.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Errc.h"
@@ -97,6 +99,7 @@
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/Transforms/Utils/Evaluator.h"
 #include <algorithm>
+#include <cmath>
 #include <cstddef>
 #include <map>
 #include <set>
@@ -169,6 +172,8 @@ static cl::list<std::string>
                       cl::desc("Prevent function(s) from being devirtualized"),
                       cl::Hidden, cl::CommaSeparated);
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 /// With Clang, a pure virtual class's deleting destructor is emitted as a
 /// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the
 /// context of whole program devirtualization, the deleting destructor of a pure
@@ -656,7 +661,7 @@ struct DevirtModule {
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res);
 
-  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
+  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Function &JT,
                               bool &IsExported);
   void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                             VTableSlotInfo &SlotInfo,
@@ -1453,7 +1458,7 @@ void DevirtModule::tryICallBranchFunnel(
 
   FunctionType *FT =
       FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
-  Function *JT;
+  Function *JT = nullptr;
   if (isa<MDString>(Slot.TypeID)) {
     JT = Function::Create(FT, Function::ExternalLinkage,
                           M.getDataLayout().getProgramAddressSpace(),
@@ -1482,13 +1487,18 @@ void DevirtModule::tryICallBranchFunnel(
   ReturnInst::Create(M.getContext(), nullptr, BB);
 
   bool IsExported = false;
-  applyICallBranchFunnel(SlotInfo, JT, IsExported);
+  applyICallBranchFunnel(SlotInfo, *JT, IsExported);
   if (IsExported)
     Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
+
+  // No entry count was derived from the callers: mark it explicitly unknown.
+  if (!JT->getEntryCount())
+    setExplicitlyUnknownFunctionEntryCount(*JT);
 }
 
 void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
-                                          Constant *JT, bool &IsExported) {
+                                          Function &JT, bool &IsExported) {
+  DenseMap<Function *, double> FunctionEntryCounts;
   auto Apply = [&](CallSiteInfo &CSInfo) {
     if (CSInfo.isExported())
       IsExported = true;
@@ -1517,7 +1527,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
       NumBranchFunnel++;
       if (RemarksEnabled)
         VCallSite.emitRemark("branch-funnel",
-                             JT->stripPointerCasts()->getName(), OREGetter);
+                             JT.stripPointerCasts()->getName(), OREGetter);
 
       // Pass the address of the vtable in the nest register, which is r10 on
       // x86_64.
@@ -1533,11 +1543,26 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
       llvm::append_range(Args, CB.args());
 
       CallBase *NewCS = nullptr;
+      if (!JT.isDeclaration() && !ProfcheckDisableMetadataFixes) {
+        auto &F = *CB.getCaller();
+        auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
+        auto EC = BFI.getBlockFreq(&F.getEntryBlock());
+        auto CC = F.getEntryCount(/*AllowSynthetic=*/true);
+        // Only callers with an entry count contribute. If none does, the
+        // funnel is left without one and is later marked explicitly unknown.
+        if (CC && EC.getFrequency() != 0) {
+          double CallFreq =
+              static_cast<double>(
+                  BFI.getBlockFreq(CB.getParent()).getFrequency()) /
+              EC.getFrequency();
+          FunctionEntryCounts[&JT] += CallFreq * CC->getCount();
+        }
+      }
       if (isa<CallInst>(CB))
-        NewCS = IRB.CreateCall(NewFT, JT, Args);
+        NewCS = IRB.CreateCall(NewFT, &JT, Args);
       else
         NewCS =
-            IRB.CreateInvoke(NewFT, JT, cast<InvokeInst>(CB).getNormalDest(),
+            IRB.CreateInvoke(NewFT, &JT, cast<InvokeInst>(CB).getNormalDest(),
                              cast<InvokeInst>(CB).getUnwindDest(), Args);
       NewCS->setCallingConv(CB.getCallingConv());
 
@@ -1571,6 +1596,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
   Apply(SlotInfo.CSInfo);
   for (auto &P : SlotInfo.ConstCSInfo)
     Apply(P.second);
+  for (auto &[F, C] : FunctionEntryCounts) {
+    assert(!F->getEntryCount(/*AllowSynthetic=*/true) &&
+           "Unexpected entry count for funnel that was freshly synthesized");
+    F->setEntryCount(static_cast<uint64_t>(std::round(C)));
+  }
 }
 
 bool DevirtModule::tryEvaluateFunctionsWithArgs(
@@ -2244,12 +2274,12 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
   if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
     // The type of the function is irrelevant, because it's bitcast at calls
     // anyhow.
-    Constant *JT = cast<Constant>(
+    auto *JT = cast<Function>(
         M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
                               Type::getVoidTy(M.getContext()))
             .getCallee());
     bool IsExported = false;
-    applyICallBranchFunnel(SlotInfo, JT, IsExported);
+    applyICallBranchFunnel(SlotInfo, *JT, IsExported);
     assert(!IsExported);
   }
 }
diff --git a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
index 0b1023eee2732..5756cfd22f266 100644
--- a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
+++ b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll
@@ -3,7 +3,7 @@
 
 ; RUN: opt -passes=wholeprogramdevirt -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=CHECK,RETP %s
 
-; RUN: opt -passes='wholeprogramdevirt,default<O3>' -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t  -S -o - %s | FileCheck --check-prefixes=CHECK %s
+; RUN: opt -passes='wholeprogramdevirt,default<O3>' -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t  -S -o - %s | FileCheck --check-prefixes=CHECK %s
 
 ; RUN: FileCheck --check-prefix=SUMMARY %s < %t
 
@@ -159,7 +159,7 @@ declare ptr @llvm.load.relative.i32(ptr, i32)
 
 ; CHECK-LABEL: define i32 @fn1
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn1(ptr %obj) #0 {
+define i32 @fn1(ptr %obj) #0 !prof !10 {
   %vtable = load ptr, ptr %obj
   %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1")
   call void @llvm.assume(i1 %p)
@@ -172,7 +172,7 @@ define i32 @fn1(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn1_rv
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn1_rv(ptr %obj) #0 {
+define i32 @fn1_rv(ptr %obj) #0 !prof !10 {
   %vtable = load ptr, ptr %obj
   %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1_rv")
   call void @llvm.assume(i1 %p)
@@ -185,7 +185,7 @@ define i32 @fn1_rv(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn2
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn2(ptr %obj) #0 {
+define i32 @fn2(ptr %obj) #0 !prof !10 {
   %vtable = load ptr, ptr %obj
   %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
   call void @llvm.assume(i1 %p)
@@ -197,7 +197,7 @@ define i32 @fn2(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn2_rv
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn2_rv(ptr %obj) #0 {
+define i32 @fn2_rv(ptr %obj) #0 !prof !10 {
   %vtable = load ptr, ptr %obj
   %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2_rv")
   call void @llvm.assume(i1 %p)
@@ -209,7 +209,7 @@ define i32 @fn2_rv(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn3
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn3(ptr %obj) #0 {
+define i32 @fn3(ptr %obj) #0 !prof !10 {
   %vtable = load ptr, ptr %obj
   %p = call i1 @llvm.type.test(ptr %vtable, metadata !4)
   call void @llvm.assume(i1 %p)
@@ -222,7 +222,7 @@ define i32 @fn3(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn3_rv
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn3_rv(ptr %obj) #0 {
+define i32 @fn3_rv(ptr %obj) #0 !prof !10 {
   %vtable = load ptr, ptr %obj
   %p = call i1 @llvm.type.test(ptr %vtable, metadata !9)
   call void @llvm.assume(i1 %p)
@@ -235,7 +235,7 @@ define i32 @fn3_rv(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn4
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn4(ptr %obj) #0 {
+define i32 @fn4(ptr %obj) #0 !prof !10 {
   %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
   call void @llvm.assume(i1 %p)
   %fptr = load ptr, ptr @vt1_1
@@ -247,7 +247,7 @@ define i32 @fn4(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn4_cpy
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn4_cpy(ptr %obj) #0 {
+define i32 @fn4_cpy(ptr %obj) #0 !prof !10 {
   %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
   call void @llvm.assume(i1 %p)
   %fptr = load ptr, ptr @vt1_1
@@ -259,7 +259,7 @@ define i32 @fn4_cpy(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn4_rv
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn4_rv(ptr %obj) #0 {
+define i32 @fn4_rv(ptr %obj) #0 !prof !10 {
   %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
   call void @llvm.assume(i1 %p)
   %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
@@ -271,7 +271,7 @@ define i32 @fn4_rv(ptr %obj) #0 {
 
 ; CHECK-LABEL: define i32 @fn4_rv_cpy
 ; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
-define i32 @fn4_rv_cpy(ptr %obj) #0 {
+define i32 @fn4_rv_cpy(ptr %obj) #0 !prof !10 {
   %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
   call void @llvm.assume(i1 %p)
   %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
@@ -281,14 +281,18 @@ define i32 @fn4_rv_cpy(ptr %obj) #0 {
   ret i32 %result
 }
 
-; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...)
+; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) !prof !11
 ; CHECK-NEXT: musttail call void (...) @llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2, ptr {{(nonnull )?}}@vf1_2, ...)
 
-; CHECK-LABEL: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...)
+; CHECK-LABEL: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) !prof !11
 ; CHECK-NEXT: musttail call void (...) @llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1_rv, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2_rv, ptr {{(nonnull )?}}@vf1_2, ...)
 
-; CHECK: define internal void @branch_funnel(ptr
-; CHECK: define internal void @branch_funnel.1(ptr
+; CHECK: define internal void @branch_funnel(ptr {{.*}})
+; RETP-SAME: !prof !10
+; NORETP-SAME: !prof !11
+; CHECK: define internal void @branch_funnel.1(ptr {{.*}})
+; RETP-SAME: !prof !10
+; NORETP-SAME: !prof !11
 
 declare i1 @llvm.type.test(ptr, metadata)
 declare void @llvm.assume(i1)
@@ -303,5 +307,6 @@ declare void @llvm.assume(i1)
 !7 = !{i32 0, !"typeid3_rv"}
 !8 = !{i32 0, !9}
 !9 = distinct !{}
+!10 = !{!"function_entry_count", i64 1000}
 
 attributes #0 = { "target-features"="+retpoline" }



More information about the llvm-commits mailing list