[llvm] [FunctionSpecialization] Preserve call counts of specialized functions (PR #157768)

Alan Zhao via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 9 17:55:09 PDT 2025


https://github.com/alanzhao1 updated https://github.com/llvm/llvm-project/pull/157768

>From ceae28df78e423ace826e4102628c0ef5f03b0e0 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Tue, 9 Sep 2025 16:28:57 -0700
Subject: [PATCH 1/4] [FunctionSpecialization] Preserve call counts of
 specialized functions

A function that has been specialized will have its function entry counts
preserved as follows:

* Each specialization's count is the sum of each call site's basic
  block's number of entries as computed by `BlockFrequencyInfo`.
* The original function's count will be decreased by the counts of its
  specializations.

Tracking issue: #147390
---
 .../Transforms/IPO/FunctionSpecialization.cpp | 21 +++++++-
 .../FunctionSpecialization/profile-counts.ll  | 52 +++++++++++++++++++
 2 files changed, 72 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Transforms/FunctionSpecialization/profile-counts.ll

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index a459a9eddbcfc..324723c7942ab 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -784,9 +784,25 @@ bool FunctionSpecializer::run() {
 
     // Update the known call sites to call the clone.
     for (CallBase *Call : S.CallSites) {
+      Function *Clone = S.Clone;
       LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call
-                        << " to call " << S.Clone->getName() << "\n");
+                        << " to call " << Clone->getName() << "\n");
       Call->setCalledFunction(S.Clone);
+      if (std::optional<uint64_t> Count =
+              GetBFI(*Call->getFunction())
+                  .getBlockProfileCount(Call->getParent())) {
+        uint64_t CallCount = *Count + Clone->getEntryCount()->getCount();
+        Clone->setEntryCount(CallCount);
+        if (std::optional<llvm::Function::ProfileCount> MaybeOriginalCount =
+                S.F->getEntryCount()) {
+          uint64_t OriginalCount = MaybeOriginalCount->getCount();
+          if (OriginalCount > CallCount) {
+            S.F->setEntryCount(OriginalCount - CallCount);
+          } else {
+            S.F->setEntryCount(0);
+          }
+        }
+      }
     }
 
     Clones.push_back(S.Clone);
@@ -1043,6 +1059,9 @@ Function *FunctionSpecializer::createSpecialization(Function *F,
   // clone must.
   Clone->setLinkage(GlobalValue::InternalLinkage);
 
+  if (F->getEntryCount())
+    Clone->setEntryCount(0);
+
   // Initialize the lattice state of the arguments of the function clone,
   // marking the argument on which we specialized the function constant
   // with the given value.
diff --git a/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll b/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
new file mode 100644
index 0000000000000..4a2ad4ff9fe90
--- /dev/null
+++ b/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
@@ -0,0 +1,52 @@
+; RUN: opt -passes="ipsccp<func-spec>" -force-specialization -S < %s | FileCheck %s
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+
+@A = external dso_local constant i32, align 4
+@B = external dso_local constant i32, align 4
+
+; CHECK: define dso_local i32 @bar(i32 %x, i32 %y, ptr %z) !prof ![[BAR_PROF:[0-9]]] {
+define dso_local i32 @bar(i32 %x, i32 %y, ptr %z) !prof !0 {
+entry:
+  %tobool = icmp ne i32 %x, 0
+; CHECK: br i1 %tobool, label %if.then, label %if.else, !prof ![[BRANCH_PROF:[0-9]]]
+  br i1 %tobool, label %if.then, label %if.else, !prof !1
+
+if.then:
+; CHECK: if.then:
+; CHECK: call i32 @foo.specialized.1(i32 %x, ptr @A)
+  %call = call i32 @foo(i32 %x, ptr @A)
+  br label %return
+
+if.else:
+; CHECK: if.else:
+; CHECK: call i32 @foo.specialized.2(i32 %y, ptr @B)
+  %call1 = call i32 @foo(i32 %y, ptr @B)
+  br label %return
+
+; CHECK: return:
+; CHECK: %call2 = call i32 @foo(i32 %x, ptr %z)
+return:
+  %retval.0 = phi i32 [ %call, %if.then ], [ %call1, %if.else ]
+  %call2 = call i32 @foo(i32 %x, ptr %z);
+  %add = add i32 %retval.0, %call2
+  ret i32 %add
+}
+
+; CHECK: define internal i32 @foo(i32 %x, ptr %b) !prof ![[FOO_UNSPEC_PROF:[0-9]]]
+; CHECK: define internal i32 @foo.specialized.1(i32 %x, ptr %b) !prof ![[FOO_SPEC_1_PROF:[0-9]]]
+; CHECK: define internal i32 @foo.specialized.2(i32 %x, ptr %b) !prof ![[FOO_SPEC_2_PROF:[0-9]]]
+define internal i32 @foo(i32 %x, ptr %b) !prof !2 {
+entry:
+  %0 = load i32, ptr %b, align 4
+  %add = add nsw i32 %x, %0
+  ret i32 %add
+}
+
+; CHECK: ![[BAR_PROF]] = !{!"function_entry_count", i64 1000}
+; CHECK: ![[BRANCH_PROF]] = !{!"branch_weights", i32 1, i32 3}
+; CHECK: ![[FOO_UNSPEC_PROF]] =  !{!"function_entry_count", i64 234}
+; CHECK: ![[FOO_SPEC_1_PROF]] = !{!"function_entry_count", i64 250}
+; CHECK: ![[FOO_SPEC_2_PROF]] = !{!"function_entry_count", i64 750}
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 1, i32 3}
+!2 = !{!"function_entry_count", i64 1234}

>From 4ca46c342c1d74f32b814f5b87e4975ac50aac5d Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Tue, 9 Sep 2025 16:40:08 -0700
Subject: [PATCH 2/4] make test expectations consistent

---
 llvm/test/Transforms/FunctionSpecialization/profile-counts.ll | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll b/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
index 4a2ad4ff9fe90..d5b2e35feb118 100644
--- a/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
+++ b/llvm/test/Transforms/FunctionSpecialization/profile-counts.ll
@@ -11,15 +11,15 @@ entry:
 ; CHECK: br i1 %tobool, label %if.then, label %if.else, !prof ![[BRANCH_PROF:[0-9]]]
   br i1 %tobool, label %if.then, label %if.else, !prof !1
 
-if.then:
 ; CHECK: if.then:
 ; CHECK: call i32 @foo.specialized.1(i32 %x, ptr @A)
+if.then:
   %call = call i32 @foo(i32 %x, ptr @A)
   br label %return
 
-if.else:
 ; CHECK: if.else:
 ; CHECK: call i32 @foo.specialized.2(i32 %y, ptr @B)
+if.else:
   %call1 = call i32 @foo(i32 %y, ptr @B)
   br label %return
 

>From 6acdd6cdb350509a9758a6d5f65a5d3527130495 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Tue, 9 Sep 2025 17:23:38 -0700
Subject: [PATCH 3/4] code review comments

---
 .../lib/Transforms/IPO/FunctionSpecialization.cpp | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 324723c7942ab..78975c95789d8 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -788,18 +788,23 @@ bool FunctionSpecializer::run() {
       LLVM_DEBUG(dbgs() << "FnSpecialization: Redirecting " << *Call
                         << " to call " << Clone->getName() << "\n");
       Call->setCalledFunction(S.Clone);
+      auto &BFI = GetBFI(*Call->getFunction());
       if (std::optional<uint64_t> Count =
-              GetBFI(*Call->getFunction())
-                  .getBlockProfileCount(Call->getParent())) {
-        uint64_t CallCount = *Count + Clone->getEntryCount()->getCount();
+              BFI.getBlockProfileCount(Call->getParent())) {
+        std::optional<llvm::Function::ProfileCount> MaybeCloneCount =
+            Clone->getEntryCount();
+        assert(MaybeCloneCount && "Clone entry count was not set!");
+        uint64_t CallCount = *Count + MaybeCloneCount->getCount();
         Clone->setEntryCount(CallCount);
         if (std::optional<llvm::Function::ProfileCount> MaybeOriginalCount =
                 S.F->getEntryCount()) {
           uint64_t OriginalCount = MaybeOriginalCount->getCount();
-          if (OriginalCount > CallCount) {
+          if (OriginalCount >= CallCount) {
             S.F->setEntryCount(OriginalCount - CallCount);
           } else {
-            S.F->setEntryCount(0);
+            // This should generally not happen as that would mean there are
+            // more computed calls to the function than what was recorded.
+            LLVM_DEBUG(S.F->setEntryCount(0));
           }
         }
       }

>From a610b7003ec64ce731a175b54fb4a3957a910b29 Mon Sep 17 00:00:00 2001
From: Alan Zhao <ayzhao at google.com>
Date: Tue, 9 Sep 2025 17:54:51 -0700
Subject: [PATCH 4/4] Add flag for profcheck studies

---
 llvm/lib/Transforms/IPO/FunctionSpecialization.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 78975c95789d8..30459caee1609 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -89,6 +89,8 @@ static cl::opt<bool> SpecializeLiteralConstant(
         "Enable specialization of functions that take a literal constant as an "
         "argument"));
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 bool InstCostVisitor::canEliminateSuccessor(BasicBlock *BB,
                                             BasicBlock *Succ) const {
   unsigned I = 0;
@@ -789,8 +791,9 @@ bool FunctionSpecializer::run() {
                         << " to call " << Clone->getName() << "\n");
       Call->setCalledFunction(S.Clone);
       auto &BFI = GetBFI(*Call->getFunction());
-      if (std::optional<uint64_t> Count =
-              BFI.getBlockProfileCount(Call->getParent())) {
+      std::optional<uint64_t> Count =
+          BFI.getBlockProfileCount(Call->getParent());
+      if (Count && !ProfcheckDisableMetadataFixes) {
         std::optional<llvm::Function::ProfileCount> MaybeCloneCount =
             Clone->getEntryCount();
         assert(MaybeCloneCount && "Clone entry count was not set!");
@@ -1064,7 +1067,7 @@ Function *FunctionSpecializer::createSpecialization(Function *F,
   // clone must.
   Clone->setLinkage(GlobalValue::InternalLinkage);
 
-  if (F->getEntryCount())
+  if (F->getEntryCount() && !ProfcheckDisableMetadataFixes)
     Clone->setEntryCount(0);
 
   // Initialize the lattice state of the arguments of the function clone,



More information about the llvm-commits mailing list