[llvm] [PGO] Supporting code for always instrumenting loop entries (PR #116789)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 05:31:44 PST 2024


https://github.com/ronryvchin updated https://github.com/llvm/llvm-project/pull/116789

>From bfe17ba60d8b357de40c5e0b47ff44b57039b735 Mon Sep 17 00:00:00 2001
From: Ron Ryvchin <ron.ryvchin at nextsilicon.com>
Date: Sun, 17 Nov 2024 11:11:44 +0200
Subject: [PATCH] [PGO] Supporting code for always instrumenting loop entries

This patch extends the PGO infrastructure with an option to prefer the instrumentation of loop entry blocks.
This option is a generalization of https://github.com/llvm/llvm-project/commit/19fb5b467bb97f95eace1f3637d2d1041cebd3ce,
and helps to cover cases where the loop exit is never executed.
An example where this can occur is an event-handling loop.

Note that this change does NOT change the default behavior.
---
 .../llvm/Transforms/Instrumentation/CFGMST.h  | 19 +++++-
 .../Instrumentation/GCOVProfiling.cpp         |  3 +-
 .../Instrumentation/PGOInstrumentation.cpp    | 67 +++++++++++++------
 llvm/test/Transforms/PGOProfile/loop3.ll      | 62 +++++++++++++++++
 4 files changed, 127 insertions(+), 24 deletions(-)
 create mode 100644 llvm/test/Transforms/PGOProfile/loop3.ll

diff --git a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
index 35b3d615e3844a..f3e96a370bf72f 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/BranchProbability.h"
@@ -52,10 +53,14 @@ template <class Edge, class BBInfo> class CFGMST {
 
   BranchProbabilityInfo *const BPI;
   BlockFrequencyInfo *const BFI;
+  LoopInfo *const LI;
 
   // If function entry will be always instrumented.
   const bool InstrumentFuncEntry;
 
+  // If true, loop entries will always be instrumented.
+  const bool InstrumentLoopEntries;
+
   // Find the root group of the G and compress the path from G to the root.
   BBInfo *findAndCompressGroup(BBInfo *G) {
     if (G->Group != G)
@@ -154,6 +159,12 @@ template <class Edge, class BBInfo> class CFGMST {
           }
           if (BPI != nullptr)
             Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
+          // With InstrumentLoopEntries, give edges entering a loop from outside
+          // (not backedges) minimum weight so the MST tends to instrument them.
+          if (InstrumentLoopEntries && LI != nullptr &&
+              LI->isLoopHeader(TargetBB) &&
+              !LI->getLoopFor(TargetBB)->contains(&BB))
+            Weight = 0;
           if (Weight == 0)
             Weight++;
           auto *E = &addEdge(&BB, TargetBB, Weight);
@@ -291,10 +302,12 @@ template <class Edge, class BBInfo> class CFGMST {
     return *AllEdges.back();
   }
 
-  CFGMST(Function &Func, bool InstrumentFuncEntry,
+  CFGMST(Function &Func, bool InstrumentFuncEntry, bool InstrumentLoopEntries,
          BranchProbabilityInfo *BPI = nullptr,
-         BlockFrequencyInfo *BFI = nullptr)
-      : F(Func), BPI(BPI), BFI(BFI), InstrumentFuncEntry(InstrumentFuncEntry) {
+         BlockFrequencyInfo *BFI = nullptr, LoopInfo *LI = nullptr)
+      : F(Func), BPI(BPI), BFI(BFI), LI(LI),
+        InstrumentFuncEntry(InstrumentFuncEntry),
+        InstrumentLoopEntries(InstrumentLoopEntries) {
     buildEdges();
     sortEdgesByWeight();
     computeMinimumSpanningTree();
diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
index 2ea89be40a3d46..f9be7f933d31e4 100644
--- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
@@ -820,7 +820,8 @@ bool GCOVProfiler::emitProfileNotes(
       SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI,
                                    BFI);
 
-      CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry_=*/false, BPI, BFI);
+      CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry=*/false,
+                               /*InstrumentLoopEntries=*/false, BPI, BFI);
 
       // getInstrBB can split basic blocks and push elements to AllEdges.
       for (size_t I : llvm::seq<size_t>(0, MST.numEdges())) {
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 4d8141431a0c19..f5908dcdf67cab 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -259,6 +259,11 @@ static cl::opt<bool> PGOInstrumentEntry(
     "pgo-instrument-entry", cl::init(false), cl::Hidden,
     cl::desc("Force to instrument function entry basicblock."));
 
+static cl::opt<bool>
+    PGOInstrumentLoopEntries("pgo-instrument-loop-entries", cl::init(false),
+                             cl::Hidden,
+                             cl::desc("Force to instrument loop entries."));
+
 static cl::opt<bool> PGOFunctionEntryCoverage(
     "pgo-function-entry-coverage", cl::Hidden,
     cl::desc(
@@ -359,6 +364,7 @@ class FunctionInstrumenter final {
   std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
   BranchProbabilityInfo *const BPI;
   BlockFrequencyInfo *const BFI;
+  LoopInfo *const LI;
 
   const PGOInstrumentationType InstrumentationType;
 
@@ -376,14 +382,17 @@ class FunctionInstrumenter final {
            InstrumentationType == PGOInstrumentationType::CTXPROF;
   }
 
+  bool shouldInstrumentLoopEntries() const { return PGOInstrumentLoopEntries; }
+
 public:
   FunctionInstrumenter(
       Module &M, Function &F, TargetLibraryInfo &TLI,
       std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
       BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr,
+      LoopInfo *LI = nullptr,
       PGOInstrumentationType InstrumentationType = PGOInstrumentationType::FDO)
       : M(M), F(F), TLI(TLI), ComdatMembers(ComdatMembers), BPI(BPI), BFI(BFI),
-        InstrumentationType(InstrumentationType) {}
+        LI(LI), InstrumentationType(InstrumentationType) {}
 
   void instrument();
 };
@@ -625,12 +634,13 @@ template <class Edge, class BBInfo> class FuncPGOInstrumentation {
       Function &Func, TargetLibraryInfo &TLI,
       std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
       bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
-      BlockFrequencyInfo *BFI = nullptr, bool IsCS = false,
-      bool InstrumentFuncEntry = true, bool HasSingleByteCoverage = false)
+      BlockFrequencyInfo *BFI = nullptr, LoopInfo *LI = nullptr,
+      bool IsCS = false, bool InstrumentFuncEntry = true,
+      bool InstrumentLoopEntries = false, bool HasSingleByteCoverage = false)
       : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func, TLI),
         TLI(TLI), ValueSites(IPVK_Last + 1),
         SIVisitor(Func, HasSingleByteCoverage),
-        MST(F, InstrumentFuncEntry, BPI, BFI),
+        MST(F, InstrumentFuncEntry, InstrumentLoopEntries, BPI, BFI, LI),
         BCI(constructBCI(Func, HasSingleByteCoverage, InstrumentFuncEntry)) {
     if (BCI && PGOViewBlockCoverageGraph)
       BCI->viewBlockCoverageGraph();
@@ -916,9 +926,10 @@ void FunctionInstrumenter::instrument() {
 
   const bool IsCtxProf = InstrumentationType == PGOInstrumentationType::CTXPROF;
   FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
-      F, TLI, ComdatMembers, /*CreateGlobalVar=*/!IsCtxProf, BPI, BFI,
+      F, TLI, ComdatMembers, /*CreateGlobalVar=*/!IsCtxProf, BPI, BFI, LI,
       InstrumentationType == PGOInstrumentationType::CSFDO,
-      shouldInstrumentEntryBB(), PGOBlockCoverage);
+      shouldInstrumentEntryBB(), shouldInstrumentLoopEntries(),
+      PGOBlockCoverage);
 
   auto *const Name = IsCtxProf ? cast<GlobalValue>(&F) : FuncInfo.FuncNameVar;
   auto *const CFGHash =
@@ -1136,11 +1147,13 @@ class PGOUseFunc {
   PGOUseFunc(Function &Func, Module *Modu, TargetLibraryInfo &TLI,
              std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
              BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFIin,
-             ProfileSummaryInfo *PSI, bool IsCS, bool InstrumentFuncEntry,
+             LoopInfo *LI, ProfileSummaryInfo *PSI, bool IsCS,
+             bool InstrumentFuncEntry, bool InstrumentLoopEntries,
              bool HasSingleByteCoverage)
       : F(Func), M(Modu), BFI(BFIin), PSI(PSI),
-        FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, IsCS,
-                 InstrumentFuncEntry, HasSingleByteCoverage),
+        FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, LI, IsCS,
+                 InstrumentFuncEntry, InstrumentLoopEntries,
+                 HasSingleByteCoverage),
         FreqAttr(FFA_Normal), IsCS(IsCS), VPC(Func, TLI) {}
 
   void handleInstrProfError(Error Err, uint64_t MismatchedFuncSum);
@@ -1923,6 +1936,7 @@ static bool InstrumentAllFunctions(
     Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+    function_ref<LoopInfo *(Function &)> LookupLI,
     PGOInstrumentationType InstrumentationType) {
   // For the context-sensitve instrumentation, we should have a separated pass
   // (before LTO/ThinLTO linking) to create these variables.
@@ -1943,10 +1957,11 @@ static bool InstrumentAllFunctions(
   for (auto &F : M) {
     if (skipPGOGen(F))
       continue;
-    auto &TLI = LookupTLI(F);
-    auto *BPI = LookupBPI(F);
-    auto *BFI = LookupBFI(F);
-    FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI,
+    TargetLibraryInfo &TLI = LookupTLI(F);
+    BranchProbabilityInfo *BPI = LookupBPI(F);
+    BlockFrequencyInfo *BFI = LookupBFI(F);
+    LoopInfo *LI = LookupLI(F);
+    FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI, LI,
                             InstrumentationType);
     FI.instrument();
   }
@@ -1980,8 +1995,11 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
   auto LookupBFI = [&FAM](Function &F) {
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
+  auto LookupLI = [&FAM](Function &F) {
+    return &FAM.getResult<LoopAnalysis>(F);
+  };
 
-  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI,
+  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, LookupLI,
                               InstrumentationType))
     return PreservedAnalyses::all();
 
@@ -2116,6 +2134,7 @@ static bool annotateAllFunctions(
     function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+    function_ref<LoopInfo *(Function &)> LookupLI,
     ProfileSummaryInfo *PSI, bool IsCS) {
   LLVM_DEBUG(dbgs() << "Read in profile counters: ");
   auto &Ctx = M.getContext();
@@ -2181,22 +2200,26 @@ static bool annotateAllFunctions(
   bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
   if (PGOInstrumentEntry.getNumOccurrences() > 0)
     InstrumentFuncEntry = PGOInstrumentEntry;
+  // Use the flag's value: an explicit "=false" must not turn this on.
+  bool InstrumentLoopEntries = PGOInstrumentLoopEntries;
 
   bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
   for (auto &F : M) {
     if (skipPGOUse(F))
       continue;
-    auto &TLI = LookupTLI(F);
-    auto *BPI = LookupBPI(F);
-    auto *BFI = LookupBFI(F);
+    TargetLibraryInfo &TLI = LookupTLI(F);
+    BranchProbabilityInfo *BPI = LookupBPI(F);
+    BlockFrequencyInfo *BFI = LookupBFI(F);
+    LoopInfo *LI = LookupLI(F);
     if (!HasSingleByteCoverage) {
       // Split indirectbr critical edges here before computing the MST rather
       // than later in getInstrBB() to avoid invalidating it.
       SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI,
                                    BFI);
     }
-    PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, PSI, IsCS,
-                    InstrumentFuncEntry, HasSingleByteCoverage);
+    PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, LI, PSI, IsCS,
+                    InstrumentFuncEntry, InstrumentLoopEntries,
+                    HasSingleByteCoverage);
     if (HasSingleByteCoverage) {
       Func.populateCoverage(PGOReader.get());
       continue;
@@ -2335,10 +2358,14 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M,
   auto LookupBFI = [&FAM](Function &F) {
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
+  auto LookupLI = [&FAM](Function &F) {
+    return &FAM.getResult<LoopAnalysis>(F);
+  };
 
   auto *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M);
   if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, *FS,
-                            LookupTLI, LookupBPI, LookupBFI, PSI, IsCS))
+                            LookupTLI, LookupBPI, LookupBFI, LookupLI, PSI,
+                            IsCS))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
diff --git a/llvm/test/Transforms/PGOProfile/loop3.ll b/llvm/test/Transforms/PGOProfile/loop3.ll
new file mode 100644
index 00000000000000..7ffc3450d55ecb
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/loop3.ll
@@ -0,0 +1,62 @@
+; RUN: opt %s -passes=pgo-instr-gen -pgo-instrument-loop-entries=false -S | FileCheck %s --check-prefixes=GEN,NOTLOOPENTRIES
+; RUN: opt %s -passes=pgo-instr-gen -pgo-instrument-loop-entries=true -S | FileCheck %s --check-prefixes=GEN,LOOPENTRIES
+; RUN: opt %s -passes=pgo-instr-gen -pgo-instrument-entry=true -S | FileCheck %s --check-prefixes=GEN,FUNCTIONENTRY
+
+; GEN: $__llvm_profile_raw_version = comdat any
+; GEN: @__llvm_profile_raw_version = hidden constant i64 {{[0-9]+}}, comdat
+; GEN: @__profn_test_simple_for_with_bypass = private constant [27 x i8] c"test_simple_for_with_bypass"
+
+define i32 @test_simple_for_with_bypass(i32 %n) { ; single loop whose entry can be bypassed
+entry: ; branches to %end (skipping the loop) when the low 16 bits of %n are zero
+; GEN: entry:
+; NOTLOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 1)
+; LOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 0)
+; FUNCTIONENTRY: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 0)
+  %mask = and i32 %n, 65535 ; low 16 bits of %n
+  %skip = icmp eq i32 %mask, 0
+  br i1 %skip, label %end, label %for.entry
+
+for.entry: ; only block entering the loop; its edge to %for.cond is the loop-entry edge
+; GEN: for.entry:
+; LOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 1)
+; NOTLOOPENTRIES-NOT: call void @llvm.instrprof.increment
+; FUNCTIONENTRY-NOT: call void @llvm.instrprof.increment
+  br label %for.cond
+
+for.cond: ; loop header, reached from %for.entry and from the %for.inc backedge
+; GEN: for.cond:
+; GEN-NOT: call void @llvm.instrprof.increment
+  %i = phi i32 [ 0, %for.entry ], [ %inc1, %for.inc ] ; counter, starts at 0
+  %sum = phi i32 [ 1, %for.entry ], [ %inc, %for.inc ]
+  %cmp = icmp slt i32 %i, %n
+  br i1 %cmp, label %for.body, label %for.end, !prof !1 ; weights make the body path hot
+
+for.body:
+; GEN: for.body:
+; GEN-NOT: call void @llvm.instrprof.increment
+  %inc = add nsw i32 %sum, 1
+  br label %for.inc
+
+for.inc: ; latch: branches back to the header %for.cond
+; GEN: for.inc:
+; NOTLOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 0)
+; LOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 2)
+; FUNCTIONENTRY: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 1)
+  %inc1 = add nsw i32 %i, 1 ; next value of %i
+  br label %for.cond
+
+for.end: ; loop exit block
+; GEN: for.end:
+; NOTLOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 2)
+; FUNCTIONENTRY: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 2)
+; LOOPENTRIES-NOT: call void @llvm.instrprof.increment
+  br label %end
+
+end: ; join of the bypass path and the loop-exit path
+; GEN: end:
+; GEN-NOT: call void @llvm.instrprof.increment
+  %final_sum = phi i32 [ %sum, %for.end ], [ 0, %entry ] ; 0 when the loop was bypassed
+  ret i32 %final_sum
+}
+
+!1 = !{!"branch_weights", i32 100000, i32 80}



More information about the llvm-commits mailing list