[llvm] [ctx_prof] Flattened profile lowering pass (PR #107329)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 16:43:34 PDT 2024


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/107329

None

>From 0d9ac07b2b62114c9073af4d18ad8342753c433a Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 3 Sep 2024 21:28:05 -0700
Subject: [PATCH] [ctx_prof] Flattened profile lowering pass

---
 .../llvm/Analysis/ProfileSummaryInfo.h        |  10 +-
 llvm/include/llvm/ProfileData/ProfileCommon.h |   6 +-
 .../Instrumentation/PGOCtxProfFlattening.h    |  24 ++
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassRegistry.def              |   1 +
 .../Transforms/Instrumentation/CMakeLists.txt |   1 +
 .../Instrumentation/PGOCtxProfFlattening.cpp  | 278 ++++++++++++++++++
 .../CtxProfAnalysis/flatten-and-annotate.ll   |  79 +++++
 8 files changed, 396 insertions(+), 4 deletions(-)
 create mode 100644 llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h
 create mode 100644 llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
 create mode 100644 llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll

diff --git a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
index ceae3e8a0ddb95..7e67a5d79052bc 100644
--- a/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
+++ b/llvm/include/llvm/Analysis/ProfileSummaryInfo.h
@@ -42,7 +42,7 @@ class ProfileSummaryInfo {
 private:
   const Module *M;
   std::unique_ptr<ProfileSummary> Summary;
-  void computeThresholds();
+
   // Count thresholds to answer isHotCount and isColdCount queries.
   std::optional<uint64_t> HotCountThreshold, ColdCountThreshold;
   // True if the working set size of the code is considered huge,
@@ -63,6 +63,14 @@ class ProfileSummaryInfo {
   ProfileSummaryInfo(const Module &M) : M(&M) { refresh(); }
   ProfileSummaryInfo(ProfileSummaryInfo &&Arg) = default;
 
+  /// Replace the summary with the provided one.
+  void overrideSummary(std::unique_ptr<ProfileSummary> NewSummary) {
+    Summary.swap(NewSummary);
+  }
+
+  /// Compute the hot and cold thresholds.
+  void computeThresholds();
+
   /// If no summary is present, attempt to refresh.
   void refresh();
 
diff --git a/llvm/include/llvm/ProfileData/ProfileCommon.h b/llvm/include/llvm/ProfileData/ProfileCommon.h
index eaab59484c947a..0bc7e5fd2cd81c 100644
--- a/llvm/include/llvm/ProfileData/ProfileCommon.h
+++ b/llvm/include/llvm/ProfileData/ProfileCommon.h
@@ -79,13 +79,13 @@ class ProfileSummaryBuilder {
 class InstrProfSummaryBuilder final : public ProfileSummaryBuilder {
   uint64_t MaxInternalBlockCount = 0;
 
-  inline void addEntryCount(uint64_t Count);
-  inline void addInternalCount(uint64_t Count);
-
 public:
   InstrProfSummaryBuilder(std::vector<uint32_t> Cutoffs)
       : ProfileSummaryBuilder(std::move(Cutoffs)) {}
 
+  inline void addEntryCount(uint64_t Count);
+  inline void addInternalCount(uint64_t Count);
+
   void addRecord(const InstrProfRecord &);
   std::unique_ptr<ProfileSummary> getSummary();
 };
diff --git a/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h
new file mode 100644
index 00000000000000..56876740264379
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h
@@ -0,0 +1,24 @@
+//===-- PGOCtxProfFlattening.h - Contextual Instr. Flattening ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the PGOCtxProfFlattening class.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFFLATTENING_H
+#define LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFFLATTENING_H
+
+#include "llvm/IR/PassManager.h"
+namespace llvm {
+
+class PGOCtxProfFlattening : public PassInfoMixin<PGOCtxProfFlattening> {
+public:
+  explicit PGOCtxProfFlattening() = default;
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+};
+} // namespace llvm
+#endif
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 1df1449fce597c..a8827d1d909f6f 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -197,6 +197,7 @@
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
 #include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
 #include "llvm/Transforms/Instrumentation/NumericalStabilitySanitizer.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index d6067089c6b5c1..923a799cda53e7 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -58,6 +58,7 @@ MODULE_PASS("coro-early", CoroEarlyPass())
 MODULE_PASS("cross-dso-cfi", CrossDSOCFIPass())
 MODULE_PASS("ctx-instr-gen",
             PGOInstrumentationGen(PGOInstrumentationType::CTXPROF))
+MODULE_PASS("ctx-prof-flatten", PGOCtxProfFlattening())
 MODULE_PASS("deadargelim", DeadArgumentEliminationPass())
 MODULE_PASS("debugify", NewPMDebugifyPass())
 MODULE_PASS("dfsan", DataFlowSanitizerPass())
diff --git a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
index deab37801ff1df..d45b07447d09da 100644
--- a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
+++ b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
@@ -15,6 +15,7 @@ add_llvm_component_library(LLVMInstrumentation
   InstrProfiling.cpp
   KCFI.cpp
   LowerAllowCheckPass.cpp
+  PGOCtxProfFlattening.cpp
   PGOCtxProfLowering.cpp
   PGOForceFunctionAttrs.cpp
   PGOInstrumentation.cpp
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
new file mode 100644
index 00000000000000..4685a0965e927b
--- /dev/null
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
@@ -0,0 +1,278 @@
+//===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Flattens the contextual profile and lowers it to MD_prof.
+// This should happen after all IPO (which is assumed to have maintained the
+// contextual profile) happened. Flattening consists of summing the values at
+// the same index of the counters belonging to all the contexts of a function.
+// The lowering consists of materializing the counter values to function
+// entrypoint counts and branch probabilities.
+//
+// This pass also removes contextual instrumentation, which has been kept around
+// to facilitate its functionality.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfileSummary.h"
+#include "llvm/ProfileData/ProfileCommon.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/Transforms/Scalar/DCE.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+namespace {
+
+class Solver final {
+  struct BBInfo;
+  struct EdgeInfo {
+    BBInfo *const Src;
+    BBInfo *const Dest;
+    std::optional<uint64_t> Count;
+
+    explicit EdgeInfo(BBInfo *Src, BBInfo *Dest) : Src(Src), Dest(Dest) {}
+  };
+
+  struct BBInfo {
+    std::optional<uint64_t> Count;
+    SmallVector<EdgeInfo *> OutEdges;
+    SmallVector<EdgeInfo *> InEdges;
+    size_t UnknownCountOutEdges = 0;
+    size_t UnknownCountInEdges = 0;
+
+    uint64_t getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
+                        bool AssumeAllKnown) const {
+      uint64_t Sum = 0;
+      for (const auto *E : Edges)
+        if (E)
+          Sum += AssumeAllKnown ? *E->Count : E->Count.value_or(0U);
+      return Sum;
+    }
+
+    void takeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
+      assert(!Count.has_value());
+      Count = getEdgeSum(Edges, true);
+    }
+  };
+
+  Function &F;
+  std::map<const BasicBlock *, BBInfo> BBInfos;
+  std::vector<EdgeInfo> EdgeInfos;
+  InstrProfSummaryBuilder &PB;
+
+  /// Assign \p Value to the one edge in \p Edges whose count is still unknown,
+  /// and update the unknown-edge counters of the BBs at both of its ends.
+  void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges,
+                                 uint64_t Value) {
+    EdgeInfo *E = nullptr;
+    for (auto *I : Edges)
+      if (I && !I->Count.has_value()) {
+        // Check before overwriting E; after `E = I` this could never fire.
+        assert(!E && "Expected exactly one edge to have an unknown count, "
+                     "found a second one");
+        E = I;
+#ifdef NDEBUG
+        break;
+#endif
+      }
+    assert(E && "Expected exactly one edge to have an unknown count");
+    assert(!E->Count.has_value());
+    E->Count = Value;
+    assert(E->Src->UnknownCountOutEdges > 0);
+    assert(E->Dest->UnknownCountInEdges > 0);
+    --E->Src->UnknownCountOutEdges;
+    --E->Dest->UnknownCountInEdges;
+  }
+
+  void solve(const SmallVectorImpl<uint64_t> &Counters) {
+    for (const auto &BB : F) {
+      if (auto *Ins = CtxProfAnalysis::getBBInstrumentation(
+              const_cast<BasicBlock &>(BB)))
+        BBInfos.find(&BB)->second.Count =
+            Counters[Ins->getIndex()->getZExtValue()];
+    }
+    bool KeepGoing = true;
+    while (KeepGoing) {
+      KeepGoing = false;
+      for (const auto &BB : reverse(F)) {
+        auto &Info = BBInfos.find(&BB)->second;
+        if (!Info.Count) {
+          if (!succ_empty(&BB) && !Info.UnknownCountOutEdges) {
+            Info.takeCountFrom(Info.OutEdges);
+            KeepGoing = true;
+          } else if (!BB.isEntryBlock() && !Info.UnknownCountInEdges) {
+            Info.takeCountFrom(Info.InEdges);
+            KeepGoing = true;
+          }
+        }
+        if (Info.Count.has_value()) {
+          if (Info.UnknownCountOutEdges == 1) {
+            uint64_t KnownSum = Info.getEdgeSum(Info.OutEdges, false);
+            uint64_t EdgeVal =
+                *Info.Count > KnownSum ? *Info.Count - KnownSum : 0U;
+            setSingleUnknownEdgeCount(Info.OutEdges, EdgeVal);
+            KeepGoing = true;
+          }
+          if (Info.UnknownCountInEdges == 1) {
+            uint64_t KnownSum = Info.getEdgeSum(Info.InEdges, false);
+            uint64_t EdgeVal =
+                *Info.Count > KnownSum ? *Info.Count - KnownSum : 0U;
+            setSingleUnknownEdgeCount(Info.InEdges, EdgeVal);
+            KeepGoing = true;
+          }
+        }
+      }
+    }
+  }
+  // The only criterion for exclusion is faux suspend -> exit edges in presplit
+  // coroutines. This wrapper currently exists only for readability.
+  bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
+    return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
+  }
+
+public:
+  Solver(Function &F, InstrProfSummaryBuilder &PB) : F(F), PB(PB) {
+    assert(!F.isDeclaration());
+    size_t NrEdges = 0;
+    for (const auto &BB : F) {
+      auto [It, Ins] = BBInfos.insert({&BB, {}});
+      (void)Ins;
+      assert(Ins);
+      NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
+        return !shouldExcludeEdge(BB, *Succ);
+      });
+      It->second.InEdges.reserve(pred_size(&BB));
+      It->second.OutEdges.resize(succ_size(&BB));
+    }
+    EdgeInfos.reserve(NrEdges);
+    for (const auto &BB : F) {
+      auto &Info = BBInfos.find(&BB)->second;
+      for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
+        const auto *Succ = BB.getTerminator()->getSuccessor(I);
+        if (!shouldExcludeEdge(BB, *Succ)) {
+          auto &EI = EdgeInfos.emplace_back(&BBInfos.find(&BB)->second,
+                                            &BBInfos.find(Succ)->second);
+          Info.OutEdges[I] = &EI;
+          ++Info.UnknownCountOutEdges;
+          BBInfos.find(Succ)->second.InEdges.push_back(&EI);
+          ++BBInfos.find(Succ)->second.UnknownCountInEdges;
+        }
+      }
+    }
+  }
+
+  void assignProfData(const SmallVectorImpl<uint64_t> &Counters) {
+    assert(!Counters.empty());
+    solve(Counters);
+    F.setEntryCount(Counters[0]);
+    PB.addEntryCount(Counters[0]);
+
+    for (auto &BB : F) {
+      if (succ_size(&BB) < 2)
+        continue;
+      auto *Term = BB.getTerminator();
+      SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0);
+      uint64_t MaxCount = 0;
+      const auto &BBInfo = BBInfos.find(&BB)->second;
+      for (unsigned SuccIdx = 0, Size = BBInfo.OutEdges.size(); SuccIdx < Size;
+           ++SuccIdx) {
+        const auto *E = BBInfo.OutEdges[SuccIdx];
+        if (!E)
+          continue;
+        uint64_t EdgeCount = *E->Count;
+        if (EdgeCount > MaxCount)
+          MaxCount = EdgeCount;
+        EdgeCounts[SuccIdx] = EdgeCount;
+        PB.addInternalCount(EdgeCount);
+      }
+
+      if (MaxCount == 0)
+        F.getContext().emitError(
+            "[ctx-prof] Encountered a BB with more than one successor, where "
+            "all outgoing edges have a 0 count. This occurs in non-exiting "
+            "functions (message pumps, usually) which are not supported in the "
+            "contextual profiling case");
+      setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount);
+    }
+  }
+};
+
+// True iff the dominator tree of \p F reports every basic block as reachable
+// from the entry block.
+bool areAllBBsReachable(const Function &F, FunctionAnalysisManager &FAM) {
+  auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));
+  return llvm::all_of(
+      F, [&DT](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); });
+}
+
+void clearColdFunctionProfile(Function &F) {
+  for (auto &BB : F)
+    BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr);
+  F.setEntryCount(0U);
+}
+
+void removeInstrumentation(Function &F) {
+  for (auto &BB : F)
+    for (auto &I : llvm::make_early_inc_range(BB))
+      if (isa<InstrProfCntrInstBase>(I))
+        I.eraseFromParent();
+}
+
+} // namespace
+
+/// Lower the flattened contextual profile to MD_prof and rebuild the summary.
+PreservedAnalyses PGOCtxProfFlattening::run(Module &M,
+                                            ModuleAnalysisManager &MAM) {
+  auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
+  if (!CtxProf)
+    return PreservedAnalyses::all();
+
+  const auto FlattenedProfile = CtxProf.flatten();
+  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+
+  InstrProfSummaryBuilder PB(ProfileSummaryBuilder::DefaultCutoffs);
+  for (auto &F : M) {
+    if (F.isDeclaration())
+      continue;
+
+    if (!areAllBBsReachable(F, FAM)) {
+      M.getContext().emitError(
+          "[ctx-prof] Function has unreachable basic blocks: " + F.getName());
+      continue;
+    }
+
+    const auto &FlatProfile =
+        FlattenedProfile.lookup(AssignGUIDPass::getGUID(F));
+    if (FlatProfile.empty())
+      clearColdFunctionProfile(F);
+    else {
+      Solver S(F, PB);
+      S.assignProfData(FlatProfile);
+    }
+    removeInstrumentation(F);
+  }
+
+  // Replace the module's summary with one built from the flattened counts.
+  auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
+  PSI.overrideSummary(PB.getSummary());
+  PSI.computeThresholds();
+
+  return PreservedAnalyses::none();
+}
\ No newline at end of file
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
new file mode 100644
index 00000000000000..1c3dd830e5f226
--- /dev/null
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-and-annotate.ll
@@ -0,0 +1,79 @@
+; REQUIRES: x86_64-linux
+;
+; RUN: rm -rf %t
+; RUN: split-file %s %t
+; RUN: llvm-ctxprof-util fromJSON --input=%t/profile.json --output=%t/profile.ctxprofdata
+; RUN: opt -module-summary -passes='thinlto-pre-link<O2>' -use-ctx-profile=%t/profile.ctxprofdata \
+; RUN:   %t/example.ll -S -o %t/prelink.ll
+; RUN: FileCheck --input-file %t/prelink.ll %s --check-prefix=PRELINK
+; RUN: opt -passes='ctx-prof-flatten' -use-ctx-profile=%t/profile.ctxprofdata %t/prelink.ll -S  | FileCheck %s
+;
+;
+; Check that instrumentation occurs where expected: the "no" block for foo, and
+; the "yes" block for an_entrypoint - which explains the subsequent branch weights
+;
+; PRELINK-LABEL: @foo
+; PRELINK-LABEL: no:
+; PRELINK:         call void @llvm.instrprof.increment(ptr @foo, i64 [[#]], i32 2, i32 1)
+
+; PRELINK-LABEL: @an_entrypoint
+; PRELINK-LABEL: yes:
+; PRELINK:         call void @llvm.instrprof.increment(ptr @an_entrypoint, i64 [[#]], i32 2, i32 1)
+
+; CHECK-NOT:   call void @llvm.instrprof
+;
+; CHECK-LABEL: @foo
+; CHECK-SAME:    !prof !0
+; CHECK:          br i1 %t, label %yes, label %no, !prof !2
+; CHECK-LABEL: @an_entrypoint
+; CHECK-SAME:    !prof !3
+; CHECK:          br i1 %t, label %yes, label %common.ret, !prof !5
+; CHECK:       !0 = !{!"function_entry_count", i64 40} 
+; CHECK:       !2 = !{!"branch_weights", i32 30, i32 10} 
+; CHECK:       !5 = !{!"branch_weights", i32 40, i32 60} 
+
+;--- profile.json
+[
+  {
+    "Guid": 4909520559318251808,
+    "Counters": [100, 40],
+    "Callsites": [
+      [
+        {
+          "Guid": 11872291593386833696,
+          "Counters": [ 40, 10 ]
+        }
+      ]
+    ]
+  }
+]
+;--- example.ll
+declare void @bar()
+
+define void @foo(i32 %a, ptr %fct) #0 !guid !0 {
+  %t = icmp sgt i32 %a, 7
+  br i1 %t, label %yes, label %no
+yes:
+  call void %fct(i32 %a)
+  br label %exit
+no:
+  call void @bar()
+  br label %exit
+exit:
+  ret void
+}
+
+define void @an_entrypoint(i32 %a) !guid !1 {
+  %t = icmp sgt i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  call void @foo(i32 1, ptr null)
+  ret void
+no:
+  ret void
+}
+
+attributes #0 = { noinline }
+!0 = !{ i64 11872291593386833696 }
+!1 = !{i64 4909520559318251808}



More information about the llvm-commits mailing list