[llvm] [ctx_prof] Add Inlining support (PR #106154)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 14:13:48 PDT 2024


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/106154

>From a6d5e9f0312b96cec13b76a1aab11af00750c3dd Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 22 Aug 2024 18:03:56 -0700
Subject: [PATCH] [ctx_prof] Add Inlining support

---
 llvm/include/llvm/Analysis/CtxProfAnalysis.h  |  19 +-
 llvm/include/llvm/IR/IntrinsicInst.h          |   2 +
 .../llvm/ProfileData/PGOCtxProfReader.h       |   7 +
 llvm/include/llvm/Transforms/Utils/Cloning.h  |  12 +
 llvm/lib/Analysis/CtxProfAnalysis.cpp         |  29 ++-
 llvm/lib/Transforms/IPO/ModuleInliner.cpp     |   5 +-
 llvm/lib/Transforms/Utils/InlineFunction.cpp  | 237 ++++++++++++++++++
 .../Analysis/CtxProfAnalysis/full-cycle.ll    |   2 +-
 llvm/test/Analysis/CtxProfAnalysis/inline.ll  | 109 ++++++++
 .../Analysis/CtxProfAnalysis/json_equals.py   |  15 ++
 llvm/test/Analysis/CtxProfAnalysis/load.ll    |   9 +-
 .../Utils/CallPromotionUtilsTest.cpp          |   3 +-
 12 files changed, 431 insertions(+), 18 deletions(-)
 create mode 100644 llvm/test/Analysis/CtxProfAnalysis/inline.ll
 create mode 100644 llvm/test/Analysis/CtxProfAnalysis/json_equals.py

diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 10aef6f6067b6f..80edd19ea8f8f8 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -15,6 +15,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/ProfileData/PGOCtxProfReader.h"
+#include <optional>
 
 namespace llvm {
 
@@ -63,6 +64,16 @@ class PGOContextualProfile {
     return getDefinedFunctionGUID(F) != 0;
   }
 
+  uint32_t getNumCounters(const Function &F) const {
+    assert(isFunctionKnown(F));
+    return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex;
+  }
+
+  uint32_t getNumCallsites(const Function &F) const {
+    assert(isFunctionKnown(F));
+    return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex;
+  }
+
   uint32_t allocateNextCounterIndex(const Function &F) {
     assert(isFunctionKnown(F));
     return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex++;
@@ -91,11 +102,11 @@ class PGOContextualProfile {
 };
 
 class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
-  StringRef Profile;
+  const std::optional<StringRef> Profile;
 
 public:
   static AnalysisKey Key;
-  explicit CtxProfAnalysis(StringRef Profile = "");
+  explicit CtxProfAnalysis(std::optional<StringRef> Profile = std::nullopt);
 
   using Result = PGOContextualProfile;
 
@@ -113,9 +124,7 @@ class CtxProfAnalysisPrinterPass
     : public PassInfoMixin<CtxProfAnalysisPrinterPass> {
 public:
   enum class PrintMode { Everything, JSON };
-  explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
-                                      PrintMode Mode = PrintMode::Everything)
-      : OS(OS), Mode(Mode) {}
+  explicit CtxProfAnalysisPrinterPass(raw_ostream &OS);
 
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
   static bool isRequired() { return true; }
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 71a96e0671c2f1..fc8d1b3d1947e3 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1516,6 +1516,8 @@ class InstrProfInstBase : public IntrinsicInst {
     return const_cast<Value *>(getArgOperand(0))->stripPointerCasts();
   }
 
+  void setNameValue(Value *V) { setArgOperand(0, V); }
+
   // The hash of the CFG for the instrumented function.
   ConstantInt *getHash() const {
     return cast<ConstantInt>(const_cast<Value *>(getArgOperand(1)));
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index f7f88966f7573f..e03481916dd48a 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -74,6 +74,13 @@ class PGOCtxProfContext final {
     Iter->second.emplace(Other.guid(), std::move(Other));
   }
 
+  void ingestAllContexts(uint32_t CSId, CallTargetMapTy &&Other) {
+    auto [_, Inserted] = callsites().try_emplace(CSId, std::move(Other));
+    (void)Inserted;
+    assert(Inserted &&
+           "CSId was expected to be newly created as result of e.g. inlining");
+  }
+
   void resizeCounters(uint32_t Size) { Counters.resize(Size); }
 
   bool hasCallsite(uint32_t I) const {
diff --git a/llvm/include/llvm/Transforms/Utils/Cloning.h b/llvm/include/llvm/Transforms/Utils/Cloning.h
index 6226062dd713f6..2ddcfeb1501e28 100644
--- a/llvm/include/llvm/Transforms/Utils/Cloning.h
+++ b/llvm/include/llvm/Transforms/Utils/Cloning.h
@@ -20,6 +20,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/ValueHandle.h"
@@ -270,6 +271,17 @@ InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
                             bool InsertLifetime = true,
                             Function *ForwardVarArgsTo = nullptr);
 
+/// Same as above, but it will update the contextual profile. If the contextual
+/// profile is invalid (i.e. not loaded because it is not present), it defaults
+/// to the behavior of the non-contextual profile updating variant above. This
+/// makes it easy to drop-in replace uses of the non-contextual overload.
+InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
+                            CtxProfAnalysis::Result &CtxProf,
+                            bool MergeAttributes = false,
+                            AAResults *CalleeAAR = nullptr,
+                            bool InsertLifetime = true,
+                            Function *ForwardVarArgsTo = nullptr);
+
 /// Clones a loop \p OrigLoop.  Returns the loop and the blocks in \p
 /// Blocks.
 ///
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 2cd3f2114397e5..457a4dcc796847 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -29,6 +29,15 @@ cl::opt<std::string>
     UseCtxProfile("use-ctx-profile", cl::init(""), cl::Hidden,
                   cl::desc("Use the specified contextual profile file"));
 
+static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
+    "ctx-profile-printer-level",
+    cl::init(CtxProfAnalysisPrinterPass::PrintMode::JSON), cl::Hidden,
+    cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
+                          "everything", "print everything - most verbose"),
+               clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::JSON, "json",
+                          "just the json representation of the profile")),
+    cl::desc("Verbosity level of the contextual profile printer pass."));
+
 namespace llvm {
 namespace json {
 Value toJSON(const PGOCtxProfContext &P) {
@@ -96,12 +105,20 @@ GlobalValue::GUID AssignGUIDPass::getGUID(const Function &F) {
 }
 AnalysisKey CtxProfAnalysis::Key;
 
-CtxProfAnalysis::CtxProfAnalysis(StringRef Profile)
-    : Profile(Profile.empty() ? UseCtxProfile : Profile) {}
+CtxProfAnalysis::CtxProfAnalysis(std::optional<StringRef> Profile)
+    : Profile([&]() -> std::optional<StringRef> {
+        if (Profile)
+          return *Profile;
+        if (UseCtxProfile.getNumOccurrences())
+          return UseCtxProfile;
+        return std::nullopt;
+      }()) {}
 
 PGOContextualProfile CtxProfAnalysis::run(Module &M,
                                           ModuleAnalysisManager &MAM) {
-  ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(Profile);
+  if (!Profile)
+    return {};
+  ErrorOr<std::unique_ptr<MemoryBuffer>> MB = MemoryBuffer::getFile(*Profile);
   if (auto EC = MB.getError()) {
     M.getContext().emitError("could not open contextual profile file: " +
                              EC.message());
@@ -150,7 +167,6 @@ PGOContextualProfile CtxProfAnalysis::run(Module &M,
   // If we made it this far, the Result is valid - which we mark by setting
   // .Profiles.
   // Trim first the roots that aren't in this module.
-  DenseSet<GlobalValue::GUID> ProfiledGUIDs;
   for (auto &[RootGuid, _] : llvm::make_early_inc_range(*MaybeCtx))
     if (!Result.FuncInfo.contains(RootGuid))
       MaybeCtx->erase(RootGuid);
@@ -165,11 +181,14 @@ PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
   return 0;
 }
 
+CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
+    : OS(OS), Mode(PrintLevel) {}
+
 PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
                                                   ModuleAnalysisManager &MAM) {
   CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(M);
   if (!C) {
-    M.getContext().emitError("Invalid CtxProfAnalysis");
+    OS << "No contextual profile was provided.\n";
     return PreservedAnalyses::all();
   }
 
diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
index 5e91ab80d7505f..b7e4531c8e390d 100644
--- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp
+++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/InlineAdvisor.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/InlineOrder.h"
@@ -113,6 +114,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
     return PreservedAnalyses::all();
   }
 
+  auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
+
   bool Changed = false;
 
   ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M);
@@ -213,7 +216,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
         &FAM.getResult<BlockFrequencyAnalysis>(Callee));
 
     InlineResult IR =
-        InlineFunction(*CB, IFI, /*MergeAttributes=*/true,
+        InlineFunction(*CB, IFI, CtxProf, /*MergeAttributes=*/true,
                        &FAM.getResult<AAManager>(*CB->getCaller()));
     if (!IR.isSuccess()) {
       Advice->recordUnsuccessfulInlining(IR);
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 94e87656a192c7..498208951397dd 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -23,6 +23,7 @@
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/IndirectCallVisitor.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryProfileInfo.h"
@@ -46,6 +47,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/EHPersonalities.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/InstrTypes.h"
@@ -71,6 +73,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
+#include <deque>
 #include <iterator>
 #include <limits>
 #include <optional>
@@ -2116,6 +2119,240 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind,
   }
 }
 
+// In contextual profiling, when an inline succeeds, we want to remap the
+// indices of the callee into the index space of the caller. We can't just leave
+// them as-is because the same callee may appear in other places in this caller
+// (other callsites), and its (callee's) counters and sub-contextual profile
+// tree would be potentially different.
+// Not all BBs of the callee may survive the opportunistic DCE InlineFunction
+// does (same goes for callsites in the callee).
+// We will return a pair of vectors, one for basic block IDs and one for
+// callsites. For such a vector V, V[Idx] will be -1 if the callee
+// instrumentation with index Idx did not survive inlining, and a new value
+// otherwise.
+// This function will update the caller's instrumentation intrinsics
+// accordingly, mapping indices as described above. We also replace the "name"
+// operand because we use it to distinguish between "own" instrumentation and
+// "from callee" instrumentation when performing the traversal of the CFG of the
+// caller. We traverse breadth-first from the callsite's BB, up to the point
+// we hit BBs owned by the caller.
+// The return values will be then used to update the contextual
+// profile. Note: we only update the "name" and "index" operands in the
+// instrumentation intrinsics, we leave the hash and total nr of indices as-is,
+// it's not worth updating those.
+static const std::pair<std::vector<int64_t>, std::vector<int64_t>>
+remapIndices(Function &Caller, BasicBlock *StartBB,
+             CtxProfAnalysis::Result &CtxProf, uint32_t CalleeCounters,
+             uint32_t CalleeCallsites) {
+  // We'll allocate a new ID to imported callsite counters and callsites. We're
+  // using -1 to indicate a counter we delete. Most likely the entry ID, for
+  // example, will be deleted - we don't want 2 IDs in the same BB, and the
+  // entry would have been cloned in the callsite's old BB.
+  std::vector<int64_t> CalleeCounterMap;
+  std::vector<int64_t> CalleeCallsiteMap;
+  CalleeCounterMap.resize(CalleeCounters, -1);
+  CalleeCallsiteMap.resize(CalleeCallsites, -1);
+
+  auto RewriteInstrIfNeeded = [&](InstrProfIncrementInst &Ins) -> bool {
+    if (Ins.getNameValue() == &Caller)
+      return false;
+    const auto OldID = static_cast<uint32_t>(Ins.getIndex()->getZExtValue());
+    if (CalleeCounterMap[OldID] == -1)
+      CalleeCounterMap[OldID] = CtxProf.allocateNextCounterIndex(Caller);
+    const auto NewID = static_cast<uint32_t>(CalleeCounterMap[OldID]);
+
+    Ins.setNameValue(&Caller);
+    Ins.setIndex(NewID);
+    return true;
+  };
+
+  auto RewriteCallsiteInsIfNeeded = [&](InstrProfCallsite &Ins) -> bool {
+    if (Ins.getNameValue() == &Caller)
+      return false;
+    const auto OldID = static_cast<uint32_t>(Ins.getIndex()->getZExtValue());
+    if (CalleeCallsiteMap[OldID] == -1)
+      CalleeCallsiteMap[OldID] = CtxProf.allocateNextCallsiteIndex(Caller);
+    const auto NewID = static_cast<uint32_t>(CalleeCallsiteMap[OldID]);
+
+    Ins.setNameValue(&Caller);
+    Ins.setIndex(NewID);
+    return true;
+  };
+
+  std::deque<BasicBlock *> Worklist;
+  DenseSet<const BasicBlock *> Seen;
+  // We will traverse the BBs starting from the callsite BB. The callsite BB
+  // will have at least a BB ID - maybe its own, and in any case the one coming
+  // from the cloned function's entry BB. The other BBs we'll start seeing from
+  // there on may or may not have BB IDs. BBs with IDs belonging to our caller
+  // are definitely not coming from the imported function and form a boundary
+  // past which we don't need to traverse anymore. BBs may have no
+  // instrumentation (because we originally inserted instrumentation as per
+  // MST), in which case we'll traverse past them. An invariant we'll keep is
+  // that a BB will have at most 1 BB ID. For example, in the callsite BB, we
+  // will delete the callee BB's instrumentation. This doesn't result in
+  // information loss: the entry BB of the callee will have the same count as
+  // the callsite's BB. At the end of this traversal, all the callee's
+  // instrumentation would be mapped into the caller's instrumentation index
+  // space. Some of the callee's counters may be deleted (as mentioned, this
+  // should result in no loss of information).
+  Worklist.push_back(StartBB);
+  while (!Worklist.empty()) {
+    auto *BB = Worklist.front();
+    Worklist.pop_front();
+    bool Changed = false;
+    auto *BBID = CtxProfAnalysis::getBBInstrumentation(*BB);
+    if (BBID) {
+      Changed |= RewriteInstrIfNeeded(*BBID);
+      // this may be the entryblock from the inlined callee, coming into a BB
+      // that didn't have instrumentation because of MST decisions. Let's make
+      // sure it's placed accordingly. This is a noop elsewhere.
+      BBID->moveBefore(&*BB->getFirstInsertionPt());
+    }
+    for (auto &I : llvm::make_early_inc_range(*BB)) {
+      if (auto *Inc = dyn_cast<InstrProfIncrementInst>(&I)) {
+        if (Inc != BBID) {
+          // If we're here it means that the BB had more than 1 ID, presumably
+          // some coming from the callee. We "made up our mind" to keep the
+          // first one (which may or may not have been originally the caller's).
+          // All the others are superfluous and we delete them.
+          Inc->eraseFromParent();
+          Changed = true;
+        }
+      } else if (auto *CS = dyn_cast<InstrProfCallsite>(&I)) {
+        Changed |= RewriteCallsiteInsIfNeeded(*CS);
+      }
+    }
+    if (!BBID || Changed)
+      for (auto *Succ : successors(BB))
+        if (Seen.insert(Succ).second)
+          Worklist.push_back(Succ);
+  }
+
+  assert(llvm::count_if(CalleeCounterMap,
+                        [](const auto &V) { return V == 0; }) == 0 &&
+         "Counter index mapping should be either to -1 or to non-zero index, "
+         "because the 0 "
+         "index corresponds to the entry BB of the caller");
+  assert(llvm::count_if(CalleeCallsiteMap,
+                        [](const auto &V) { return V == 0; }) == 0 &&
+         "Callsite index mapping should be either to -1 or to non-zero index, "
+         "because there should have been at least a callsite - the inlined one "
+         "- which would have had a 0 index.");
+
+  return {std::move(CalleeCounterMap), std::move(CalleeCallsiteMap)};
+}
+
+// Inline. If successful, update the contextual profile (if a valid one is
+// given).
+// The contextual profile data is organized in trees, as follows:
+//  - each node corresponds to a function
+//  - the root of each tree corresponds to an "entrypoint" - e.g.
+//    RPC handler for server side
+//  - the path from the root to a node is a particular call path
+//  - the counters stored in a node are counter values observed in that
+//    particular call path ("context")
+//  - the edges between nodes are annotated with callsite IDs.
+//
+// Updating the contextual profile after an inlining means, at a high level,
+// copying over the data of the callee, **intentionally without any value
+// scaling**, and copying over the callees of the inlined callee.
+llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
+                                        CtxProfAnalysis::Result &CtxProf,
+                                        bool MergeAttributes,
+                                        AAResults *CalleeAAR,
+                                        bool InsertLifetime,
+                                        Function *ForwardVarArgsTo) {
+  if (!CtxProf)
+    return InlineFunction(CB, IFI, MergeAttributes, CalleeAAR, InsertLifetime,
+                          ForwardVarArgsTo);
+
+  auto &Caller = *CB.getCaller();
+  auto &Callee = *CB.getCalledFunction();
+  auto *StartBB = CB.getParent();
+
+  // Get some preliminary data about the callsite before it might get inlined.
+  // Inlining shouldn't delete the callee, but it's cleaner (and low-cost) to
+  // get this data upfront and rely less on InlineFunction's behavior.
+  const auto CalleeGUID = AssignGUIDPass::getGUID(Callee);
+  auto *CallsiteIDIns = CtxProfAnalysis::getCallsiteInstrumentation(CB);
+  const auto CallsiteID =
+      static_cast<uint32_t>(CallsiteIDIns->getIndex()->getZExtValue());
+
+  const auto NumCalleeCounters = CtxProf.getNumCounters(Callee);
+  const auto NumCalleeCallsites = CtxProf.getNumCallsites(Callee);
+
+  auto Ret = InlineFunction(CB, IFI, MergeAttributes, CalleeAAR, InsertLifetime,
+                            ForwardVarArgsTo);
+  if (!Ret.isSuccess())
+    return Ret;
+
+  // Inlining succeeded, we don't need the instrumentation of the inlined
+  // callsite.
+  CallsiteIDIns->eraseFromParent();
+
+  // Assigning Maps and then capturing references into it in the lambda because
+  // captured structured bindings are a C++20 extension. We do also need a
+  // capture here, though.
+  const auto IndicesMaps = remapIndices(Caller, StartBB, CtxProf,
+                                        NumCalleeCounters, NumCalleeCallsites);
+  const uint32_t NewCountersSize = CtxProf.getNumCounters(Caller);
+
+  auto Updater = [&](PGOCtxProfContext &Ctx) {
+    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
+    const auto &[CalleeCounterMap, CalleeCallsiteMap] = IndicesMaps;
+    assert(
+        (Ctx.counters().size() +
+             llvm::count_if(CalleeCounterMap, [](auto V) { return V != -1; }) ==
+         NewCountersSize) &&
+        "The caller's counters size should have grown by the number of new "
+        "distinct counters inherited from the inlined callee.");
+    Ctx.resizeCounters(NewCountersSize);
+    // If the callsite wasn't exercised in this context, the value of the
+    // counters coming from it is 0 - which it is right now, after resizing them
+    // - and so we're done.
+    auto CSIt = Ctx.callsites().find(CallsiteID);
+    if (CSIt == Ctx.callsites().end())
+      return;
+    auto CalleeCtxIt = CSIt->second.find(CalleeGUID);
+    // The callsite was exercised, but not with this callee (so presumably this
+    // is an indirect callsite). Again, we're done here.
+    if (CalleeCtxIt == CSIt->second.end())
+      return;
+
+    // Let's pull in the counter values and the subcontexts coming from the
+    // inlined callee.
+    auto &CalleeCtx = CalleeCtxIt->second;
+    assert(CalleeCtx.guid() == CalleeGUID);
+
+    for (auto I = 0U; I < CalleeCtx.counters().size(); ++I) {
+      const int64_t NewIndex = CalleeCounterMap[I];
+      if (NewIndex >= 0) {
+        assert(NewIndex != 0 && "counter index mapping shouldn't happen to a 0 "
+                                "index, that's the caller's entry BB");
+        Ctx.counters()[NewIndex] = CalleeCtx.counters()[I];
+      }
+    }
+    for (auto &[I, OtherSet] : CalleeCtx.callsites()) {
+      const int64_t NewCSIdx = CalleeCallsiteMap[I];
+      if (NewCSIdx >= 0) {
+        assert(NewCSIdx != 0 &&
+               "callsite index mapping shouldn't happen to a 0 index, the "
+               "caller must've had at least one callsite (with such an index)");
+        Ctx.ingestAllContexts(NewCSIdx, std::move(OtherSet));
+      }
+    }
+    // We know the traversal is preorder, so it wouldn't have yet looked at the
+    // sub-contexts of this context that it's currently visiting. Meaning, the
+    // erase below invalidates no iterators.
+    auto Deleted = Ctx.callsites().erase(CallsiteID);
+    assert(Deleted);
+    (void)Deleted;
+  };
+  CtxProf.update(Updater, &Caller);
+  return Ret;
+}
+
 /// This function inlines the called function into the basic block of the
 /// caller. This returns false if it is not possible to inline this call.
 /// The program is still in a well defined state if this occurs though.
diff --git a/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll b/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll
index 06ba8b3542f7d5..5284f3a3c7c4e2 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/full-cycle.ll
@@ -24,7 +24,7 @@
 ; RUN:  -r %t/m2.bc,f1 \
 ; RUN:  -r %t/m2.bc,f3 \
 ; RUN:  -r %t/m2.bc,entrypoint,plx
-; RUN: opt --passes='function-import,require<ctx-prof-analysis>,print<ctx-prof-analysis>' \
+; RUN: opt --passes='function-import,require<ctx-prof-analysis>,print<ctx-prof-analysis>' -ctx-profile-printer-level=everything \
 ; RUN:  -summary-file=%t/m2.bc.thinlto.bc -use-ctx-profile=%t/profile.ctxprofdata %t/m2.bc \
 ; RUN:  -S -o %t/m2.post.ll 2> %t/profile.txt
 ; RUN: diff %t/expected.txt %t/profile.txt
diff --git a/llvm/test/Analysis/CtxProfAnalysis/inline.ll b/llvm/test/Analysis/CtxProfAnalysis/inline.ll
new file mode 100644
index 00000000000000..875bc4938653b9
--- /dev/null
+++ b/llvm/test/Analysis/CtxProfAnalysis/inline.ll
@@ -0,0 +1,109 @@
+; RUN: rm -rf %t
+; RUN: split-file %s %t
+; RUN: llvm-ctxprof-util fromJSON --input=%t/profile.json --output=%t/profile.ctxprofdata
+
+; RUN: opt -passes='module-inline,print<ctx-prof-analysis>' -ctx-profile-printer-level=everything %t/module.ll -S \
+; RUN:   -use-ctx-profile=%t/profile.ctxprofdata -ctx-profile-printer-level=json \
+; RUN:   -o - 2> %t/profile-final.txt | FileCheck %s
+; RUN: %python %S/json_equals.py %t/profile-final.txt %t/expected.json
+
+; There are 2 calls to @a from @entrypoint. We only inline the one callsite
+; marked as alwaysinline, the rest are blocked (marked noinline). After the inline,
+; the updated contextual profile should still have the same tree for the non-inlined case.
+; For the inlined case, we should observe, for the @entrypoint context:
+;  - an empty callsite where the inlined one was (first one, i.e. 0)
+;  - more counters appended to the old counter list (because we ingested the
+;    ones from @a). The values are copied.
+;  - a new callsite to @b
+; CHECK-LABEL: @entrypoint
+; CHECK-LABEL: yes:
+; CHECK:         call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 1)
+; CHECK-NEXT:    br label %loop.i
+; CHECK-LABEL:  loop.i:
+; CHECK-NEXT:    %indvar.i = phi i32 [ %indvar.next.i, %loop.i ], [ 0, %yes ]
+; CHECK-NEXT:    call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 2, i32 3)
+; CHECK-NEXT:    %b.i = add i32 %x, %indvar.i
+; CHECK-NEXT:    call void @llvm.instrprof.callsite(ptr @entrypoint, i64 0, i32 1, i32 2, ptr @b)
+; CHECK-NEXT:    %call3.i = call i32 @b() #1
+; CHECK-LABEL: no:
+; CHECK-NEXT:    call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 2)
+; CHECK-NEXT:    call void @llvm.instrprof.callsite(ptr @entrypoint, i64 0, i32 2, i32 1, ptr @a)
+; CHECK-NEXT:    %call2 = call i32 @a(i32 %x) #1
+; CHECK-NEXT:    br label %exit
+
+
+;--- module.ll
+define i32 @entrypoint(i32 %x) !guid !0 {
+  call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 0)
+  %t = icmp eq i32 %x, 0
+  br i1 %t, label %yes, label %no
+yes:
+  call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 1)
+  call void @llvm.instrprof.callsite(ptr @entrypoint, i64 0, i32 2, i32 0, ptr @a)
+  %call1 = call i32 @a(i32 %x) alwaysinline
+  br label %exit
+no:
+  call void @llvm.instrprof.increment(ptr @entrypoint, i64 0, i32 3, i32 2)
+  call void @llvm.instrprof.callsite(ptr @entrypoint, i64 0, i32 2, i32 1, ptr @a)
+  %call2 = call i32 @a(i32 %x) noinline
+  br label %exit
+exit:
+  %ret = phi i32 [%call1, %yes], [%call2, %no]
+  ret i32 %ret
+}
+
+define i32 @a(i32 %x) !guid !1 {
+entry:
+  call void @llvm.instrprof.increment(ptr @a, i64 0, i32 2, i32 0)
+  br label %loop
+loop:
+  %indvar = phi i32 [%indvar.next, %loop], [0, %entry]
+  call void @llvm.instrprof.increment(ptr @a, i64 0, i32 2, i32 1)
+  %b = add i32 %x, %indvar
+  call void @llvm.instrprof.callsite(ptr @a, i64 0, i32 1, i32 0, ptr @b)
+  %call3 = call i32 @b() noinline
+  %indvar.next = add i32 %indvar, %call3
+  %cond = icmp slt i32 %indvar.next, %x
+  br i1 %cond, label %loop, label %exit
+exit:
+  ret i32 8
+}
+
+define i32 @b() !guid !2 {
+  call void @llvm.instrprof.increment(ptr @b, i64 0, i32 1, i32 0)
+  ret i32 1
+}
+
+!0 = !{i64 1000}
+!1 = !{i64 1001}
+!2 = !{i64 1002}
+;--- profile.json
+[
+  { "Guid": 1000,
+    "Counters": [10, 2, 8],
+    "Callsites": [
+      [ { "Guid": 1001,
+          "Counters": [2, 100],
+          "Callsites": [[{"Guid": 1002, "Counters": [100]}]]}
+      ],
+      [ { "Guid": 1001,
+          "Counters": [8, 500],
+          "Callsites": [[{"Guid": 1002, "Counters": [500]}]]}
+      ]
+    ]
+  }
+]
+;--- expected.json
+[
+  { "Guid": 1000,
+    "Counters": [10, 2, 8, 100],
+    "Callsites": [
+      [],
+      [ { "Guid": 1001,
+          "Counters": [8, 500],
+          "Callsites": [[{"Guid": 1002, "Counters": [500]}]]}
+      ],
+      [{ "Guid": 1002, "Counters": [100]}]
+    ]
+  }
+]
diff --git a/llvm/test/Analysis/CtxProfAnalysis/json_equals.py b/llvm/test/Analysis/CtxProfAnalysis/json_equals.py
new file mode 100644
index 00000000000000..8b94dda5528c5b
--- /dev/null
+++ b/llvm/test/Analysis/CtxProfAnalysis/json_equals.py
@@ -0,0 +1,15 @@
+import json
+import sys
+
+
+def to_json(fname: str):
+    with open(fname) as f:
+        return json.load(f)
+
+
+a = to_json(sys.argv[1])
+b = to_json(sys.argv[2])
+
+if a == b:
+    exit(0)
+exit(1)
diff --git a/llvm/test/Analysis/CtxProfAnalysis/load.ll b/llvm/test/Analysis/CtxProfAnalysis/load.ll
index fa09474f433151..7d92f9678e7c3c 100644
--- a/llvm/test/Analysis/CtxProfAnalysis/load.ll
+++ b/llvm/test/Analysis/CtxProfAnalysis/load.ll
@@ -3,10 +3,10 @@
 ; RUN: rm -rf %t
 ; RUN: split-file %s %t
 ; RUN: llvm-ctxprof-util fromJSON --input=%t/profile.json --output=%t/profile.ctxprofdata
-; RUN: not opt -passes='require<ctx-prof-analysis>,print<ctx-prof-analysis>' \
-; RUN:   %t/example.ll -S 2>&1 | FileCheck %s --check-prefix=NO-FILE
+; RUN: opt -passes='require<ctx-prof-analysis>,print<ctx-prof-analysis>' -ctx-profile-printer-level=everything \
+; RUN:   %t/example.ll -S 2>&1 | FileCheck %s --check-prefix=NO-CTX
 
-; RUN: not opt -passes='require<ctx-prof-analysis>,print<ctx-prof-analysis>' \
+; RUN: not opt -passes='require<ctx-prof-analysis>,print<ctx-prof-analysis>' -ctx-profile-printer-level=everything \
 ; RUN:   -use-ctx-profile=does_not_exist.ctxprofdata %t/example.ll -S 2>&1 | FileCheck %s --check-prefix=NO-FILE
 
 ; RUN: opt -module-summary -passes='thinlto-pre-link<O2>' \
@@ -14,11 +14,12 @@
 
 ; RUN: opt -module-summary -passes='thinlto-pre-link<O2>' -use-ctx-profile=%t/profile.ctxprofdata \
 ; RUN:  %t/example.ll -S -o %t/prelink.ll
-; RUN: opt -passes='require<ctx-prof-analysis>,print<ctx-prof-analysis>' \
+; RUN: opt -passes='require<ctx-prof-analysis>,print<ctx-prof-analysis>' -ctx-profile-printer-level=everything \
 ; RUN:   -use-ctx-profile=%t/profile.ctxprofdata %t/prelink.ll -S 2> %t/output.txt
 ; RUN: diff %t/expected-profile-output.txt %t/output.txt
 
 ; NO-FILE: error: could not open contextual profile file
+; NO-CTX: No contextual profile was provided
 ;
 ; This is the reference profile, laid out in the format the json formatter will
 ; output it from opt.
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 36c64b9f333d7c..dcb1c10433ccf4 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -570,8 +570,7 @@ define i32 @f4() !guid !3 {
 
   std::string Str;
   raw_string_ostream OS(Str);
-  CtxProfAnalysisPrinterPass Printer(
-      OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
+  CtxProfAnalysisPrinterPass Printer(OS);
   Printer.run(*M, MAM);
   const char *Expected = R"json(
   [



More information about the llvm-commits mailing list