[llvm] 0abd744 - [PGO] Use the sum of profile counts to fix the function entry count

Rong Xu via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 16 13:38:03 PST 2020


Author: Rong Xu
Date: 2020-12-16T13:37:43-08:00
New Revision: 0abd744597ee502b6424e5a99fb940ca0f866fe9

URL: https://github.com/llvm/llvm-project/commit/0abd744597ee502b6424e5a99fb940ca0f866fe9
DIFF: https://github.com/llvm/llvm-project/commit/0abd744597ee502b6424e5a99fb940ca0f866fe9.diff

LOG: [PGO] Use the sum of profile counts to fix the function entry count

Raw profile count values for each BB are not kept after profile
annotation. We record the function entry count and branch weights
and use them to compute the count when needed. This mechanism
works well in a perfect world, but often breaks in real programs
(because of limited numeric precision, inconsistent profiles, or
bugs in BFI). This patch uses the sum of the profile count values
to fix the function entry count, so that the BFI counts are close
to the real profile counts.

Differential Revision: https://reviews.llvm.org/D61540

Added: 
    llvm/test/Transforms/PGOProfile/Inputs/fix_bfi.proftext
    llvm/test/Transforms/PGOProfile/fix_bfi.ll

Modified: 
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
    llvm/test/Transforms/PGOProfile/bfi_verification.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index eba8d9e9c3c3..8627e8239b2e 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -252,6 +252,10 @@ static cl::opt<bool> PGOInstrumentEntry(
     "pgo-instrument-entry", cl::init(false), cl::Hidden,
     cl::desc("Force to instrument function entry basicblock."));
 
+static cl::opt<bool>
+    PGOFixEntryCount("pgo-fix-entry-count", cl::init(true), cl::Hidden,
+                     cl::desc("Fix function entry count in profile use."));
+
 static cl::opt<bool> PGOVerifyHotBFI(
     "pgo-verify-hot-bfi", cl::init(false), cl::Hidden,
     cl::desc("Print out the non-match BFI count if a hot raw profile count "
@@ -1640,6 +1644,53 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
   return PreservedAnalyses::none();
 }
 
+// Use the ratio between the sum of the profile count values and the sum of
+// the BFI count values to adjust the function entry count.
+static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
+                              BranchProbabilityInfo &NBPI) {
+  Function &F = Func.getFunc();
+  BlockFrequencyInfo NBFI(F, NBPI, LI);
+#ifndef NDEBUG
+  auto BFIEntryCount = F.getEntryCount();
+  assert(BFIEntryCount.hasValue() && (BFIEntryCount.getCount() > 0) &&
+         "Invalid BFI Entrycount");
+#endif
+  auto SumCount = APFloat::getZero(APFloat::IEEEdouble());
+  auto SumBFICount = APFloat::getZero(APFloat::IEEEdouble());
+  for (auto &BBI : F) {
+    uint64_t CountValue = 0;
+    uint64_t BFICountValue = 0;
+    if (!Func.findBBInfo(&BBI))
+      continue;
+    auto BFICount = NBFI.getBlockProfileCount(&BBI);
+    CountValue = Func.getBBInfo(&BBI).CountValue;
+    BFICountValue = BFICount.getValue();
+    SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
+    SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
+  }
+  if (SumCount.isZero())
+    return;
+
+  assert(SumBFICount.compare(APFloat(0.0)) == APFloat::cmpGreaterThan &&
+         "Incorrect sum of BFI counts");
+  if (SumBFICount.compare(SumCount) == APFloat::cmpEqual)
+    return;
+  double Scale = (SumCount / SumBFICount).convertToDouble();
+  if (Scale < 1.001 && Scale > 0.999)
+    return;
+
+  uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
+  uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
+  if (NewEntryCount == 0)
+    NewEntryCount = 1;
+  if (NewEntryCount != FuncEntryCount) {
+    F.setEntryCount(ProfileCount(NewEntryCount, Function::PCT_Real));
+    LLVM_DEBUG(dbgs() << "FixFuncEntryCount: in " << F.getName()
+                      << ", entry_count " << FuncEntryCount << " --> "
+                      << NewEntryCount << "\n");
+  }
+}
+
 // Compare the profile count values with BFI count values, and print out
 // the non-matching ones.
 static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
@@ -1842,10 +1893,15 @@ static bool annotateAllFunctions(
       }
     }
 
-    // Verify BlockFrequency information.
-    if (PGOVerifyBFI || PGOVerifyHotBFI) {
+    if (PGOVerifyBFI || PGOVerifyHotBFI || PGOFixEntryCount) {
       LoopInfo LI{DominatorTree(F)};
       BranchProbabilityInfo NBPI(F, LI);
+
+      // Fix func entry count.
+      if (PGOFixEntryCount)
+        fixFuncEntryCount(Func, LI, NBPI);
+
+      // Verify BlockFrequency information.
       uint64_t HotCountThreshold = 0, ColdCountThreshold = 0;
       if (PGOVerifyHotBFI) {
         HotCountThreshold = PSI->getOrCompHotCountThreshold();

diff  --git a/llvm/test/Transforms/PGOProfile/Inputs/fix_bfi.proftext b/llvm/test/Transforms/PGOProfile/Inputs/fix_bfi.proftext
new file mode 100644
index 000000000000..dd5c2bcd57c5
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/Inputs/fix_bfi.proftext
@@ -0,0 +1,16 @@
+# IR level Instrumentation Flag
+:ir
+sort_basket
+# Func Hash:
+948827210500800754
+# Num Counters:
+7
+# Counter Values:
+41017879
+31616738
+39637749
+32743703
+13338888
+6990942
+6013544
+

diff  --git a/llvm/test/Transforms/PGOProfile/bfi_verification.ll b/llvm/test/Transforms/PGOProfile/bfi_verification.ll
index 8386ebf0db74..029329ba3ccc 100644
--- a/llvm/test/Transforms/PGOProfile/bfi_verification.ll
+++ b/llvm/test/Transforms/PGOProfile/bfi_verification.ll
@@ -1,7 +1,7 @@
 ; Note: Verify bfi counter after loading the profile.
 ; RUN: llvm-profdata merge %S/Inputs/bfi_verification.proftext -o %t.profdata
-; RUN: opt < %s -pgo-instr-use -pgo-test-profile-file=%t.profdata -S -pgo-verify-bfi-ratio=2 -pgo-verify-bfi=true -pass-remarks-analysis=pgo 2>&1 | FileCheck %s --check-prefix=THRESHOLD-CHECK
-; RUN: opt < %s -pgo-instr-use -pgo-test-profile-file=%t.profdata -S -pgo-verify-hot-bfi=true -pass-remarks-analysis=pgo 2>&1 | FileCheck %s --check-prefix=HOTONLY-CHECK
+; RUN: opt < %s -pgo-instr-use -pgo-test-profile-file=%t.profdata -S -pgo-verify-bfi-ratio=2 -pgo-verify-bfi=true -pgo-fix-entry-count=false -pass-remarks-analysis=pgo 2>&1 | FileCheck %s --check-prefix=THRESHOLD-CHECK
+; RUN: opt < %s -pgo-instr-use -pgo-test-profile-file=%t.profdata -S -pgo-verify-hot-bfi=true -pgo-fix-entry-count=false -pass-remarks-analysis=pgo 2>&1 | FileCheck %s --check-prefix=HOTONLY-CHECK
 
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"

diff  --git a/llvm/test/Transforms/PGOProfile/fix_bfi.ll b/llvm/test/Transforms/PGOProfile/fix_bfi.ll
new file mode 100644
index 000000000000..eea8c2beb3b4
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/fix_bfi.ll
@@ -0,0 +1,101 @@
+; Note: Scale the function entry count (using the sum of the count values) so that the BFI count values are close to the raw profile count values.
+; RUN: llvm-profdata merge %S/Inputs/fix_bfi.proftext -o %t.profdata
+; RUN: opt -pgo-instr-use -pgo-test-profile-file=%t.profdata -S -pgo-fix-entry-count=true < %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.basket = type { %struct.arc*, i64, i64 }
+%struct.arc = type { i64, %struct.node*, %struct.node*, i32, %struct.arc*, %struct.arc*, i64, i64 }
+%struct.node = type { i64, i32, %struct.node*, %struct.node*, %struct.node*, %struct.node*, %struct.arc*, %struct.arc*, %struct.arc*, %struct.arc*, i64, i64, i32, i32 }
+
+@perm = internal unnamed_addr global [351 x %struct.basket*] zeroinitializer, align 16
+
+define dso_local void @sort_basket(i64 %min, i64 %max) {
+entry:
+  %add = add nsw i64 %min, %max
+  %div = sdiv i64 %add, 2
+  %arrayidx = getelementptr inbounds [351 x %struct.basket*], [351 x %struct.basket*]* @perm, i64 0, i64 %div
+  %0 = load %struct.basket*, %struct.basket** %arrayidx, align 8
+  %abs_cost = getelementptr inbounds %struct.basket, %struct.basket* %0, i64 0, i32 2
+  %1 = load i64, i64* %abs_cost, align 8
+  br label %do.body
+
+do.body:
+  %r.0 = phi i64 [ %max, %entry ], [ %r.2, %if.end ]
+  %l.0 = phi i64 [ %min, %entry ], [ %l.2, %if.end ]
+  br label %while.cond
+
+while.cond:
+  %l.1 = phi i64 [ %l.0, %do.body ], [ %inc, %while.body ]
+  %arrayidx1 = getelementptr inbounds [351 x %struct.basket*], [351 x %struct.basket*]* @perm, i64 0, i64 %l.1
+  %2 = load %struct.basket*, %struct.basket** %arrayidx1, align 8
+  %abs_cost2 = getelementptr inbounds %struct.basket, %struct.basket* %2, i64 0, i32 2
+  %3 = load i64, i64* %abs_cost2, align 8
+  %cmp = icmp sgt i64 %3, %1
+  br i1 %cmp, label %while.body, label %while.cond3
+
+while.body:
+  %inc = add nsw i64 %l.1, 1
+  br label %while.cond
+
+while.cond3:
+  %r.1 = phi i64 [ %r.0, %while.cond ], [ %dec, %while.body7 ]
+  %arrayidx4 = getelementptr inbounds [351 x %struct.basket*], [351 x %struct.basket*]* @perm, i64 0, i64 %r.1
+  %4 = load %struct.basket*, %struct.basket** %arrayidx4, align 8
+  %abs_cost5 = getelementptr inbounds %struct.basket, %struct.basket* %4, i64 0, i32 2
+  %5 = load i64, i64* %abs_cost5, align 8
+  %cmp6 = icmp sgt i64 %1, %5
+  br i1 %cmp6, label %while.body7, label %while.end8
+
+while.body7:
+  %dec = add nsw i64 %r.1, -1
+  br label %while.cond3
+
+while.end8:
+  %cmp9 = icmp slt i64 %l.1, %r.1
+  br i1 %cmp9, label %if.then, label %if.end
+
+if.then:
+  %6 = bitcast %struct.basket** %arrayidx1 to i64*
+  %7 = load i64, i64* %6, align 8
+  store %struct.basket* %4, %struct.basket** %arrayidx1, align 8
+  %8 = bitcast %struct.basket** %arrayidx4 to i64*
+  store i64 %7, i64* %8, align 8
+  br label %if.end
+
+if.end:
+  %cmp14 = icmp sgt i64 %l.1, %r.1
+  %not.cmp14 = xor i1 %cmp14, true
+  %9 = zext i1 %not.cmp14 to i64
+  %r.2 = sub i64 %r.1, %9
+  %not.cmp1457 = xor i1 %cmp14, true
+  %inc16 = zext i1 %not.cmp1457 to i64
+  %l.2 = add nsw i64 %l.1, %inc16
+  %cmp19 = icmp sgt i64 %l.2, %r.2
+  br i1 %cmp19, label %do.end, label %do.body
+
+do.end:
+  %cmp20 = icmp sgt i64 %r.2, %min
+  br i1 %cmp20, label %if.then21, label %if.end22
+
+if.then21:
+  call void @sort_basket(i64 %min, i64 %r.2)
+  br label %if.end22
+
+if.end22:
+  %cmp23 = icmp slt i64 %l.2, %max
+  %cmp24 = icmp slt i64 %l.2, 51
+  %or.cond = and i1 %cmp23, %cmp24
+  br i1 %or.cond, label %if.then25, label %if.end26
+
+if.then25:
+  call void @sort_basket(i64 %l.2, i64 %max)
+  br label %if.end26
+
+if.end26:
+  ret void
+}
+
+; CHECK: define dso_local void @sort_basket(i64 %min, i64 %max) #0 !prof [[ENTRY_COUNT:![0-9]+]]
+; CHECK: [[ENTRY_COUNT]] = !{!"function_entry_count", i64 12949310}


        


More information about the llvm-commits mailing list