[llvm] [PGO] Supporting code for always instrumenting loop entries (PR #116789)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 06:23:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-pgo

@llvm/pr-subscribers-llvm-transforms

Author: None (ronryvchin)

<details>
<summary>Changes</summary>

This patch extends the PGO infrastructure with an option to prefer the instrumentation of loop entry blocks.
This option is a generalization of https://github.com/llvm/llvm-project/commit/19fb5b467bb97f95eace1f3637d2d1041cebd3ce,
and helps to cover cases where the loop exit is never executed.
An example where this can occur is an event-handling loop.

Note that this change does NOT alter the default behavior.

---
Full diff: https://github.com/llvm/llvm-project/pull/116789.diff


4 Files Affected:

- (modified) llvm/include/llvm/Transforms/Instrumentation/CFGMST.h (+17-3) 
- (modified) llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp (+2-1) 
- (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+48-21) 
- (added) llvm/test/Transforms/PGOProfile/loop3.ll (+62) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
index 35b3d615e3844a..8540f3103c23f5 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/CFGMST.h
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/BranchProbability.h"
@@ -52,10 +53,14 @@ template <class Edge, class BBInfo> class CFGMST {
 
   BranchProbabilityInfo *const BPI;
   BlockFrequencyInfo *const BFI;
+  LoopInfo *const LI;
 
   // If function entry will be always instrumented.
   const bool InstrumentFuncEntry;
 
+  // If true loop entries will be always instrumented.
+  const bool InstrumentLoopEntries;
+
   // Find the root group of the G and compress the path from G to the root.
   BBInfo *findAndCompressGroup(BBInfo *G) {
     if (G->Group != G)
@@ -154,6 +159,11 @@ template <class Edge, class BBInfo> class CFGMST {
           }
           if (BPI != nullptr)
             Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
+          // If InstrumentLoopEntries is on and TargetBB is a loop head (i.e.,
+          // the current edge leads to a loop), set Weight to be minimal, so
+          // that the edge won't be chosen for the MST and will be instrumented.
+          if (InstrumentLoopEntries && LI->isLoopHeader(TargetBB))
+            Weight = 0;
           if (Weight == 0)
             Weight++;
           auto *E = &addEdge(&BB, TargetBB, Weight);
@@ -291,10 +301,14 @@ template <class Edge, class BBInfo> class CFGMST {
     return *AllEdges.back();
   }
 
-  CFGMST(Function &Func, bool InstrumentFuncEntry,
+  CFGMST(Function &Func, bool InstrumentFuncEntry, bool InstrumentLoopEntries,
          BranchProbabilityInfo *BPI = nullptr,
-         BlockFrequencyInfo *BFI = nullptr)
-      : F(Func), BPI(BPI), BFI(BFI), InstrumentFuncEntry(InstrumentFuncEntry) {
+         BlockFrequencyInfo *BFI = nullptr, LoopInfo *LI = nullptr)
+      : F(Func), BPI(BPI), BFI(BFI), LI(LI),
+        InstrumentFuncEntry(InstrumentFuncEntry),
+        InstrumentLoopEntries(InstrumentLoopEntries) {
+    assert(!(InstrumentLoopEntries && !LI) &&
+           "expected a LoopInfo to instrumenting loop entries");
     buildEdges();
     sortEdgesByWeight();
     computeMinimumSpanningTree();
diff --git a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
index 2ea89be40a3d46..f9be7f933d31e4 100644
--- a/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp
@@ -820,7 +820,8 @@ bool GCOVProfiler::emitProfileNotes(
       SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI,
                                    BFI);
 
-      CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry_=*/false, BPI, BFI);
+      CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry=*/false,
+                               /*InstrumentLoopEntries=*/false, BPI, BFI);
 
       // getInstrBB can split basic blocks and push elements to AllEdges.
       for (size_t I : llvm::seq<size_t>(0, MST.numEdges())) {
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 4d8141431a0c19..ad68dcb7370507 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -259,6 +259,11 @@ static cl::opt<bool> PGOInstrumentEntry(
     "pgo-instrument-entry", cl::init(false), cl::Hidden,
     cl::desc("Force to instrument function entry basicblock."));
 
+static cl::opt<bool>
+    PGOInstrumentLoopEntries("pgo-instrument-loop-entries", cl::init(false),
+                             cl::Hidden,
+                             cl::desc("Force to instrument loop entries."));
+
 static cl::opt<bool> PGOFunctionEntryCoverage(
     "pgo-function-entry-coverage", cl::Hidden,
     cl::desc(
@@ -359,6 +364,7 @@ class FunctionInstrumenter final {
   std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
   BranchProbabilityInfo *const BPI;
   BlockFrequencyInfo *const BFI;
+  LoopInfo *const LI;
 
   const PGOInstrumentationType InstrumentationType;
 
@@ -376,14 +382,17 @@ class FunctionInstrumenter final {
            InstrumentationType == PGOInstrumentationType::CTXPROF;
   }
 
+  bool shouldInstrumentLoopEntries() const { return PGOInstrumentLoopEntries; }
+
 public:
   FunctionInstrumenter(
       Module &M, Function &F, TargetLibraryInfo &TLI,
       std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
       BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr,
+      LoopInfo *LI = nullptr,
       PGOInstrumentationType InstrumentationType = PGOInstrumentationType::FDO)
       : M(M), F(F), TLI(TLI), ComdatMembers(ComdatMembers), BPI(BPI), BFI(BFI),
-        InstrumentationType(InstrumentationType) {}
+        LI(LI), InstrumentationType(InstrumentationType) {}
 
   void instrument();
 };
@@ -625,12 +634,13 @@ template <class Edge, class BBInfo> class FuncPGOInstrumentation {
       Function &Func, TargetLibraryInfo &TLI,
       std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
       bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
-      BlockFrequencyInfo *BFI = nullptr, bool IsCS = false,
-      bool InstrumentFuncEntry = true, bool HasSingleByteCoverage = false)
+      BlockFrequencyInfo *BFI = nullptr, LoopInfo *LI = nullptr,
+      bool IsCS = false, bool InstrumentFuncEntry = true,
+      bool InstrumentLoopEntries = false, bool HasSingleByteCoverage = false)
       : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), VPC(Func, TLI),
         TLI(TLI), ValueSites(IPVK_Last + 1),
         SIVisitor(Func, HasSingleByteCoverage),
-        MST(F, InstrumentFuncEntry, BPI, BFI),
+        MST(F, InstrumentFuncEntry, InstrumentLoopEntries, BPI, BFI, LI),
         BCI(constructBCI(Func, HasSingleByteCoverage, InstrumentFuncEntry)) {
     if (BCI && PGOViewBlockCoverageGraph)
       BCI->viewBlockCoverageGraph();
@@ -916,9 +926,10 @@ void FunctionInstrumenter::instrument() {
 
   const bool IsCtxProf = InstrumentationType == PGOInstrumentationType::CTXPROF;
   FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
-      F, TLI, ComdatMembers, /*CreateGlobalVar=*/!IsCtxProf, BPI, BFI,
+      F, TLI, ComdatMembers, /*CreateGlobalVar=*/!IsCtxProf, BPI, BFI, LI,
       InstrumentationType == PGOInstrumentationType::CSFDO,
-      shouldInstrumentEntryBB(), PGOBlockCoverage);
+      shouldInstrumentEntryBB(), shouldInstrumentLoopEntries(),
+      PGOBlockCoverage);
 
   auto *const Name = IsCtxProf ? cast<GlobalValue>(&F) : FuncInfo.FuncNameVar;
   auto *const CFGHash =
@@ -1136,11 +1147,13 @@ class PGOUseFunc {
   PGOUseFunc(Function &Func, Module *Modu, TargetLibraryInfo &TLI,
              std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
              BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFIin,
-             ProfileSummaryInfo *PSI, bool IsCS, bool InstrumentFuncEntry,
+             LoopInfo *LI, ProfileSummaryInfo *PSI, bool IsCS,
+             bool InstrumentFuncEntry, bool InstrumentLoopEntries,
              bool HasSingleByteCoverage)
       : F(Func), M(Modu), BFI(BFIin), PSI(PSI),
-        FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, IsCS,
-                 InstrumentFuncEntry, HasSingleByteCoverage),
+        FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, LI, IsCS,
+                 InstrumentFuncEntry, InstrumentLoopEntries,
+                 HasSingleByteCoverage),
         FreqAttr(FFA_Normal), IsCS(IsCS), VPC(Func, TLI) {}
 
   void handleInstrProfError(Error Err, uint64_t MismatchedFuncSum);
@@ -1923,6 +1936,7 @@ static bool InstrumentAllFunctions(
     Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+    function_ref<LoopInfo *(Function &)> LookupLI,
     PGOInstrumentationType InstrumentationType) {
   // For the context-sensitve instrumentation, we should have a separated pass
   // (before LTO/ThinLTO linking) to create these variables.
@@ -1943,10 +1957,11 @@ static bool InstrumentAllFunctions(
   for (auto &F : M) {
     if (skipPGOGen(F))
       continue;
-    auto &TLI = LookupTLI(F);
-    auto *BPI = LookupBPI(F);
-    auto *BFI = LookupBFI(F);
-    FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI,
+    TargetLibraryInfo &TLI = LookupTLI(F);
+    BranchProbabilityInfo *BPI = LookupBPI(F);
+    BlockFrequencyInfo *BFI = LookupBFI(F);
+    LoopInfo *LI = LookupLI(F);
+    FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI, LI,
                             InstrumentationType);
     FI.instrument();
   }
@@ -1980,8 +1995,11 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
   auto LookupBFI = [&FAM](Function &F) {
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
+  auto LookupLI = [&FAM](Function &F) {
+    return &FAM.getResult<LoopAnalysis>(F);
+  };
 
-  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI,
+  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, LookupLI,
                               InstrumentationType))
     return PreservedAnalyses::all();
 
@@ -2116,7 +2134,8 @@ static bool annotateAllFunctions(
     function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
-    ProfileSummaryInfo *PSI, bool IsCS) {
+    function_ref<LoopInfo *(Function &)> LookupLI, ProfileSummaryInfo *PSI,
+    bool IsCS) {
   LLVM_DEBUG(dbgs() << "Read in profile counters: ");
   auto &Ctx = M.getContext();
   // Read the counter array from file.
@@ -2181,22 +2200,26 @@ static bool annotateAllFunctions(
   bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
   if (PGOInstrumentEntry.getNumOccurrences() > 0)
     InstrumentFuncEntry = PGOInstrumentEntry;
+  bool InstrumentLoopEntries =
+      (PGOInstrumentLoopEntries.getNumOccurrences() > 0);
 
   bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
   for (auto &F : M) {
     if (skipPGOUse(F))
       continue;
-    auto &TLI = LookupTLI(F);
-    auto *BPI = LookupBPI(F);
-    auto *BFI = LookupBFI(F);
+    TargetLibraryInfo &TLI = LookupTLI(F);
+    BranchProbabilityInfo *BPI = LookupBPI(F);
+    BlockFrequencyInfo *BFI = LookupBFI(F);
+    LoopInfo *LI = LookupLI(F);
     if (!HasSingleByteCoverage) {
       // Split indirectbr critical edges here before computing the MST rather
       // than later in getInstrBB() to avoid invalidating it.
       SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI,
                                    BFI);
     }
-    PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, PSI, IsCS,
-                    InstrumentFuncEntry, HasSingleByteCoverage);
+    PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, LI, PSI, IsCS,
+                    InstrumentFuncEntry, InstrumentLoopEntries,
+                    HasSingleByteCoverage);
     if (HasSingleByteCoverage) {
       Func.populateCoverage(PGOReader.get());
       continue;
@@ -2335,10 +2358,14 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M,
   auto LookupBFI = [&FAM](Function &F) {
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
+  auto LookupLI = [&FAM](Function &F) {
+    return &FAM.getResult<LoopAnalysis>(F);
+  };
 
   auto *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M);
   if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, *FS,
-                            LookupTLI, LookupBPI, LookupBFI, PSI, IsCS))
+                            LookupTLI, LookupBPI, LookupBFI, LookupLI, PSI,
+                            IsCS))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
diff --git a/llvm/test/Transforms/PGOProfile/loop3.ll b/llvm/test/Transforms/PGOProfile/loop3.ll
new file mode 100644
index 00000000000000..7ffc3450d55ecb
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/loop3.ll
@@ -0,0 +1,62 @@
+; RUN: opt %s -passes=pgo-instr-gen -pgo-instrument-loop-entries=false -S | FileCheck %s --check-prefixes=GEN,NOTLOOPENTRIES
+; RUN: opt %s -passes=pgo-instr-gen -pgo-instrument-loop-entries=true -S | FileCheck %s --check-prefixes=GEN,LOOPENTRIES
+; RUN: opt %s -passes=pgo-instr-gen -pgo-instrument-entry=true -S | FileCheck %s --check-prefixes=GEN,FUNCTIONENTRY
+
+; GEN: $__llvm_profile_raw_version = comdat any
+; GEN: @__llvm_profile_raw_version = hidden constant i64 {{[0-9]+}}, comdat
+; GEN: @__profn_test_simple_for_with_bypass = private constant [27 x i8] c"test_simple_for_with_bypass"
+
+define i32 @test_simple_for_with_bypass(i32 %n) {
+entry:
+; GEN: entry:
+; NOTLOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 1)
+; LOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 0)
+; FUNCTIONENTRY: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 0)
+  %mask = and i32 %n, 65535
+  %skip = icmp eq i32 %mask, 0
+  br i1 %skip, label %end, label %for.entry
+
+for.entry:
+; GEN: for.entry:
+; LOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 1)
+; NOTLOOPENTRIES-NOT: call void @llvm.instrprof.increment
+; FUNCTIONENTRY-NOT: call void @llvm.instrprof.increment
+  br label %for.cond
+
+for.cond:
+; GEN: for.cond:
+; GEN-NOT: call void @llvm.instrprof.increment
+  %i = phi i32 [ 0, %for.entry ], [ %inc1, %for.inc ]
+  %sum = phi i32 [ 1, %for.entry ], [ %inc, %for.inc ]
+  %cmp = icmp slt i32 %i, %n
+  br i1 %cmp, label %for.body, label %for.end, !prof !1
+
+for.body:
+; GEN: for.body:
+; GEN-NOT: call void @llvm.instrprof.increment
+  %inc = add nsw i32 %sum, 1
+  br label %for.inc
+
+for.inc:
+; GEN: for.inc:
+; NOTLOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 0)
+; LOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 2)
+; FUNCTIONENTRY: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 1)
+  %inc1 = add nsw i32 %i, 1
+  br label %for.cond
+
+for.end:
+; GEN: for.end:
+; NOTLOOPENTRIES: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 2)
+; FUNCTIONENTRY: call void @llvm.instrprof.increment(ptr @__profn_test_simple_for_with_bypass, i64 {{[0-9]+}}, i32 3, i32 2)
+; LOOPENTRIES-NOT: call void @llvm.instrprof.increment
+  br label %end
+
+end:
+; GEN: end:
+; GEN-NOT: call void @llvm.instrprof.increment
+  %final_sum = phi i32 [ %sum, %for.end ], [ 0, %entry ]
+  ret i32 %final_sum
+}
+
+!1 = !{!"branch_weights", i32 100000, i32 80}

``````````

</details>


https://github.com/llvm/llvm-project/pull/116789


More information about the llvm-commits mailing list