[llvm] [ctx_prof] Add support for ICP (PR #105469)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 26 13:21:26 PDT 2024


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/105469

>From 9bfbc4563db17b5ad9b6c0699fce127b58de81d4 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Mon, 26 Aug 2024 13:20:26 -0700
Subject: [PATCH 1/2] [ctx_prof] Move the "from json" logic more centrally to
 reuse it from test.

---
 .../llvm/ProfileData/PGOCtxProfWriter.h       |  1 +
 llvm/lib/ProfileData/PGOCtxProfWriter.cpp     | 82 +++++++++++++++++++
 .../llvm-ctxprof-util/llvm-ctxprof-util.cpp   | 82 ++-----------------
 3 files changed, 90 insertions(+), 75 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h b/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
index db9a0fd77f8351..b370fdd9ba5a1c 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfWriter.h
@@ -81,5 +81,10 @@ class PGOCtxProfileWriter final {
   static constexpr StringRef ContainerMagic = "CTXP";
 };
 
+/// Parse \p Profile, a contextual profile serialized as JSON, and write it to
+/// \p Out in the bitstream format produced by PGOCtxProfileWriter. Returns an
+/// error if \p Profile isn't valid JSON or lacks the expected structure.
+Error createCtxProfFromJSON(StringRef Profile, raw_ostream &Out);
+
 } // namespace llvm
 #endif
diff --git a/llvm/lib/ProfileData/PGOCtxProfWriter.cpp b/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
index 74cd8763cc769d..99b5b2b3d05811 100644
--- a/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
+++ b/llvm/lib/ProfileData/PGOCtxProfWriter.cpp
@@ -12,6 +12,8 @@
 
 #include "llvm/ProfileData/PGOCtxProfWriter.h"
 #include "llvm/Bitstream/BitCodeEnums.h"
+#include "llvm/ProfileData/CtxInstrContextNode.h"
+#include "llvm/Support/JSON.h"
 
 using namespace llvm;
 using namespace llvm::ctx_profile;
@@ -81,3 +83,82 @@ void PGOCtxProfileWriter::writeImpl(std::optional<uint32_t> CallerIndex,
 void PGOCtxProfileWriter::write(const ContextNode &RootNode) {
   writeImpl(std::nullopt, RootNode);
 }
+
+namespace {
+// A structural representation of the JSON input.
+struct DeserializableCtx {
+  ctx_profile::GUID Guid = 0;
+  std::vector<uint64_t> Counters;
+  std::vector<std::vector<DeserializableCtx>> Callsites;
+};
+
+ctx_profile::ContextNode *
+createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
+           const std::vector<DeserializableCtx> &DCList);
+
+// Convert a DeserializableCtx into a ContextNode, potentially linking it to
+// its sibling (e.g. callee at same callsite) "Next". The memory backing the
+// new node, and that of its sub-contexts, is owned by Nodes.
+ctx_profile::ContextNode *
+createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
+           const DeserializableCtx &DC,
+           ctx_profile::ContextNode *Next = nullptr) {
+  auto AllocSize = ctx_profile::ContextNode::getAllocSize(DC.Counters.size(),
+                                                          DC.Callsites.size());
+  auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
+  std::memset(Mem, 0, AllocSize);
+  auto *Ret = new (Mem) ctx_profile::ContextNode(DC.Guid, DC.Counters.size(),
+                                                 DC.Callsites.size(), Next);
+  std::memcpy(Ret->counters(), DC.Counters.data(),
+              sizeof(uint64_t) * DC.Counters.size());
+  for (const auto &[I, DCList] : llvm::enumerate(DC.Callsites))
+    Ret->subContexts()[I] = createNode(Nodes, DCList);
+  return Ret;
+}
+
+// Convert a list of DeserializableCtx into a linked list of ContextNodes.
+// Each node is prepended, so the list is in the reverse order of DCList.
+ctx_profile::ContextNode *
+createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
+           const std::vector<DeserializableCtx> &DCList) {
+  ctx_profile::ContextNode *List = nullptr;
+  for (const auto &DC : DCList)
+    List = createNode(Nodes, DC, List);
+  return List;
+}
+} // namespace
+
+namespace llvm {
+namespace json {
+// Hook into the JSON deserialization.
+bool fromJSON(const Value &E, DeserializableCtx &R, Path P) {
+  json::ObjectMapper Mapper(E, P);
+  return Mapper && Mapper.map("Guid", R.Guid) &&
+         Mapper.map("Counters", R.Counters) &&
+         Mapper.mapOptional("Callsites", R.Callsites);
+}
+} // namespace json
+} // namespace llvm
+
+Error llvm::createCtxProfFromJSON(StringRef Profile, raw_ostream &Out) {
+  auto P = json::parse(Profile);
+  if (!P)
+    return P.takeError();
+
+  json::Path::Root R("");
+  std::vector<DeserializableCtx> DCList;
+  if (!fromJSON(*P, DCList, R))
+    return R.getError();
+  // Nodes provides memory backing for the ContextNodes.
+  std::vector<std::unique_ptr<char[]>> Nodes;
+  PGOCtxProfileWriter Writer(Out);
+  for (const auto &DC : DCList) {
+    auto *TopList = createNode(Nodes, DC);
+    if (!TopList)
+      return createStringError(
+          "Unexpected error converting internal structure to ctx profile");
+    Writer.write(*TopList);
+  }
+  // Checking Out for I/O errors, if any, is left to the caller.
+  return Error::success();
+}
diff --git a/llvm/tools/llvm-ctxprof-util/llvm-ctxprof-util.cpp b/llvm/tools/llvm-ctxprof-util/llvm-ctxprof-util.cpp
index 3bb7681e33a871..0fad4ee4360ddf 100644
--- a/llvm/tools/llvm-ctxprof-util/llvm-ctxprof-util.cpp
+++ b/llvm/tools/llvm-ctxprof-util/llvm-ctxprof-util.cpp
@@ -46,90 +46,22 @@ static cl::opt<std::string> OutputFilename("output", cl::value_desc("output"),
                                            cl::desc("Output file"),
                                            cl::sub(FromJSON));
 
-namespace {
-// A structural representation of the JSON input.
-struct DeserializableCtx {
-  GlobalValue::GUID Guid = 0;
-  std::vector<uint64_t> Counters;
-  std::vector<std::vector<DeserializableCtx>> Callsites;
-};
-
-ctx_profile::ContextNode *
-createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
-           const std::vector<DeserializableCtx> &DCList);
-
-// Convert a DeserializableCtx into a ContextNode, potentially linking it to
-// its sibling (e.g. callee at same callsite) "Next".
-ctx_profile::ContextNode *
-createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
-           const DeserializableCtx &DC,
-           ctx_profile::ContextNode *Next = nullptr) {
-  auto AllocSize = ctx_profile::ContextNode::getAllocSize(DC.Counters.size(),
-                                                          DC.Callsites.size());
-  auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
-  std::memset(Mem, 0, AllocSize);
-  auto *Ret = new (Mem) ctx_profile::ContextNode(DC.Guid, DC.Counters.size(),
-                                                 DC.Callsites.size(), Next);
-  std::memcpy(Ret->counters(), DC.Counters.data(),
-              sizeof(uint64_t) * DC.Counters.size());
-  for (const auto &[I, DCList] : llvm::enumerate(DC.Callsites))
-    Ret->subContexts()[I] = createNode(Nodes, DCList);
-  return Ret;
-}
-
-// Convert a list of DeserializableCtx into a linked list of ContextNodes.
-ctx_profile::ContextNode *
-createNode(std::vector<std::unique_ptr<char[]>> &Nodes,
-           const std::vector<DeserializableCtx> &DCList) {
-  ctx_profile::ContextNode *List = nullptr;
-  for (const auto &DC : DCList)
-    List = createNode(Nodes, DC, List);
-  return List;
-}
-} // namespace
-
-namespace llvm {
-namespace json {
-// Hook into the JSON deserialization.
-bool fromJSON(const Value &E, DeserializableCtx &R, Path P) {
-  json::ObjectMapper Mapper(E, P);
-  return Mapper && Mapper.map("Guid", R.Guid) &&
-         Mapper.map("Counters", R.Counters) &&
-         Mapper.mapOptional("Callsites", R.Callsites);
-}
-} // namespace json
-} // namespace llvm
-
 // Save the bitstream profile from the JSON representation.
 Error convertFromJSON() {
   auto BufOrError = MemoryBuffer::getFileOrSTDIN(InputFilename);
   if (!BufOrError)
     return createFileError(InputFilename, BufOrError.getError());
-  auto P = json::parse(BufOrError.get()->getBuffer());
-  if (!P)
-    return P.takeError();
 
-  std::vector<DeserializableCtx> DCList;
-  json::Path::Root R("");
-  if (!fromJSON(*P, DCList, R))
-    return R.getError();
-  // Nodes provides memory backing for the ContextualNodes.
-  std::vector<std::unique_ptr<char[]>> Nodes;
   std::error_code EC;
-  raw_fd_stream Out(OutputFilename, EC);
+  // Use a raw_fd_ostream rather than a raw_fd_stream. The latter would be more
+  // efficient, as the bitstream writer can incrementally flush to it, but the
+  // JSON path is only used for testing, where output file size scalability
+  // isn't a concern.
+  raw_fd_ostream Out(OutputFilename, EC);
   if (EC)
     return createStringError(EC, "failed to open output");
-  PGOCtxProfileWriter Writer(Out);
-  for (const auto &DC : DCList) {
-    auto *TopList = createNode(Nodes, DC);
-    if (!TopList)
-      return createStringError(
-          "Unexpected error converting internal structure to ctx profile");
-    Writer.write(*TopList);
-  }
-  if (EC)
-    return createStringError(EC, "failed to write output");
-  return Error::success();
+
+  return llvm::createCtxProfFromJSON(BufOrError.get()->getBuffer(), Out);
 }
 
 int main(int argc, const char **argv) {

>From 41dfe6bf1418fd82ed2f5b11767ddbe1da05c19d Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 20 Aug 2024 21:32:23 -0700
Subject: [PATCH 2/2] [ctx_prof] Add support for ICP

---
 llvm/include/llvm/Analysis/CtxProfAnalysis.h  |  17 +-
 llvm/include/llvm/IR/IntrinsicInst.h          |   2 +
 .../llvm/ProfileData/PGOCtxProfReader.h       |  20 ++
 .../Transforms/Utils/CallPromotionUtils.h     |   4 +
 llvm/lib/Analysis/CtxProfAnalysis.cpp         |  79 +++++---
 llvm/lib/IR/IntrinsicInst.cpp                 |  10 +
 .../Transforms/Utils/CallPromotionUtils.cpp   |  86 +++++++++
 .../Utils/CallPromotionUtilsTest.cpp          | 178 ++++++++++++++++++
 8 files changed, 364 insertions(+), 32 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 0b4dd8ae3a0dc7..10aef6f6067b6f 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -73,6 +73,16 @@ class PGOContextualProfile {
     return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
   }
 
+  using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
+  using Visitor = function_ref<void(PGOCtxProfContext &)>;
+
+  /// Traverse, in preorder, all the contexts in the profile, passing to the
+  /// visitor those belonging to \p F - or all of them, if \p F is null. The
+  /// visitor may modify the contexts it is given.
+  void update(Visitor, const Function *F = nullptr);
+  /// Same as update, but the visitor gets read-only access to the contexts.
+  void visit(ConstVisitor, const Function *F = nullptr) const;
+
   const CtxProfFlatProfile flatten() const;
 
   bool invalidate(Module &, const PreservedAnalyses &PA,
@@ -105,13 +115,20 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
 
 class CtxProfAnalysisPrinterPass
     : public PassInfoMixin<CtxProfAnalysisPrinterPass> {
-  raw_ostream &OS;
-
 public:
+  /// Everything: function info, the contextual profile (as JSON) and the flat
+  /// profile. JSON: only the contextual profile, as JSON.
+  enum class PrintMode { Everything, JSON };
+  explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
+                                      PrintMode Mode = PrintMode::Everything)
+      : OS(OS), Mode(Mode) {}
 
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
   static bool isRequired() { return true; }
+
+private:
+  raw_ostream &OS;
+  const PrintMode Mode;
 };
 
 /// Assign a GUID to functions as metadata. GUID calculation takes linkage into
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index b45c89cadb0fde..5037e049aada57 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1535,6 +1535,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
   ConstantInt *getNumCounters() const;
   // The index of the counter that this instruction acts on.
   ConstantInt *getIndex() const;
+  void setIndex(uint32_t Idx);
 };
 
 /// This represents the llvm.instrprof.cover intrinsic.
@@ -1585,6 +1586,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
   }
   Value *getCallee() const;
+  void setCallee(Value *);
 };
 
 /// This represents the llvm.instrprof.timestamp intrinsic.
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index 190deaeeacd085..23dcc376508b39 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -57,9 +57,33 @@ class PGOCtxProfContext final {
 
   GlobalValue::GUID guid() const { return GUID; }
   const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
+  SmallVectorImpl<uint64_t> &counters() { return Counters; }
+
+  /// The entry count of this context, i.e. its first counter.
+  uint64_t getEntrycount() const {
+    assert(!Counters.empty() && "Context without an entry counter");
+    return Counters[0];
+  }
+
   const CallsiteMapTy &callsites() const { return Callsites; }
   CallsiteMapTy &callsites() { return Callsites; }
 
+  /// Add \p Other as the context of its callee at callsite \p CSId, creating
+  /// the callsite if needed. There must be no context for that callee at that
+  /// callsite yet.
+  void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
+    auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy());
+    [[maybe_unused]] const bool Inserted =
+        Iter->second.emplace(Other.guid(), std::move(Other)).second;
+    assert(Inserted && "A context for this callee already exists");
+  }
+
+  /// Grow (never shrink) the counters to \p Size entries; new ones are 0.
+  void growCounters(uint32_t Size) {
+    if (Size > Counters.size())
+      Counters.resize(Size);
+  }
+
   bool hasCallsite(uint32_t I) const {
     return Callsites.find(I) != Callsites.end();
   }
@@ -68,6 +92,12 @@ class PGOCtxProfContext final {
     assert(hasCallsite(I) && "Callsite not found");
     return Callsites.find(I)->second;
   }
+
+  CallTargetMapTy &callsite(uint32_t I) {
+    assert(hasCallsite(I) && "Callsite not found");
+    return Callsites.find(I)->second;
+  }
+
   void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
 };
 
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index 385831f457038d..58af26f31417b0 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
 #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
 
+#include "llvm/Analysis/CtxProfAnalysis.h"
 namespace llvm {
 template <typename T> class ArrayRef;
 class Constant;
@@ -56,6 +57,14 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
 CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
                                     MDNode *BranchWeights = nullptr);
 
+/// Promote the indirect call \p CB to a direct call to \p Callee, guarded by
+/// an if-then-else, and update \p CtxProf to account for the new direct
+/// callsite and the counters of the two new basic blocks. Returns the new
+/// direct call, or nullptr, leaving the IR unchanged, if the contextual
+/// profile or instrumentation needed to do so is missing.
+CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
+                                    PGOContextualProfile &CtxProf);
+
 /// This is similar to `promoteCallWithIfThenElse` except that the condition to
 /// promote a virtual call is that \p VPtr is the same as any of \p
 /// AddressPoints.
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 3fc1bc34afb97e..2cd3f2114397e5 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
     return PreservedAnalyses::all();
   }
 
-  OS << "Function Info:\n";
-  for (const auto &[Guid, FuncInfo] : C.FuncInfo)
-    OS << Guid << " : " << FuncInfo.Name
-       << ". MaxCounterID: " << FuncInfo.NextCounterIndex
-       << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
+  if (Mode == PrintMode::Everything) {
+    OS << "Function Info:\n";
+    for (const auto &[Guid, FuncInfo] : C.FuncInfo)
+      OS << Guid << " : " << FuncInfo.Name
+         << ". MaxCounterID: " << FuncInfo.NextCounterIndex
+         << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
+  }
 
   const auto JSONed = ::llvm::json::toJSON(C.profiles());
 
-  OS << "\nCurrent Profile:\n";
+  if (Mode == PrintMode::Everything)
+    OS << "\nCurrent Profile:\n";
   OS << formatv("{0:2}", JSONed);
+  if (Mode == PrintMode::JSON)
+    return PreservedAnalyses::all();
+
   OS << "\n";
   OS << "\nFlat Profile:\n";
   auto Flat = C.flatten();
@@ -209,34 +215,54 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
   return nullptr;
 }
 
-static void
-preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
-              function_ref<void(const PGOCtxProfContext &)> Visitor) {
-  std::function<void(const PGOCtxProfContext &)> Traverser =
-      [&](const auto &Ctx) {
-        Visitor(Ctx);
-        for (const auto &[_, SubCtxSet] : Ctx.callsites())
-          for (const auto &[__, Subctx] : SubCtxSet)
-            Traverser(Subctx);
-      };
-  for (const auto &[_, P] : Profiles)
+// Traverse, in preorder, all the contexts reachable from \p Profiles, calling
+// \p Visitor on those whose GUID is \p Match - or on all of them, if \p Match
+// is 0. Instantiate with const types for a read-only traversal.
+template <class ProfilesTy, class ProfTy>
+static void preorderVisit(ProfilesTy &Profiles,
+                          function_ref<void(ProfTy &)> Visitor,
+                          GlobalValue::GUID Match = 0) {
+  std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
+    if (!Match || Ctx.guid() == Match)
+      Visitor(Ctx);
+    for (auto &[_, SubCtxSet] : Ctx.callsites())
+      for (auto &[__, Subctx] : SubCtxSet)
+        Traverser(Subctx);
+  };
+  for (auto &[_, P] : Profiles)
     Traverser(P);
 }
 
+void PGOContextualProfile::update(Visitor V, const Function *F) {
+  assert(Profiles.has_value());
+  GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
+  preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
+      *Profiles, V, G);
+}
+
+void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
+  assert(Profiles.has_value());
+  GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
+  preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
+                const PGOCtxProfContext>(*Profiles, V, G);
+}
+
 const CtxProfFlatProfile PGOContextualProfile::flatten() const {
   assert(Profiles.has_value());
   CtxProfFlatProfile Flat;
-  preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
-    auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
-    if (Ins) {
-      llvm::append_range(It->second, Ctx.counters());
-      return;
-    }
-    assert(It->second.size() == Ctx.counters().size() &&
-           "All contexts corresponding to a function should have the exact "
-           "same number of counters.");
-    for (size_t I = 0, E = It->second.size(); I < E; ++I)
-      It->second[I] += Ctx.counters()[I];
-  });
+  preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
+                const PGOCtxProfContext>(
+      *Profiles, [&](const PGOCtxProfContext &Ctx) {
+        auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
+        if (Ins) {
+          llvm::append_range(It->second, Ctx.counters());
+          return;
+        }
+        assert(It->second.size() == Ctx.counters().size() &&
+               "All contexts corresponding to a function should have the exact "
+               "same number of counters.");
+        for (size_t I = 0, E = It->second.size(); I < E; ++I)
+          It->second[I] += Ctx.counters()[I];
+      });
   return Flat;
 }
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
index db3b0196f66fd6..0eadd0f980c15b 100644
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
   return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
 }
 
+void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
+  // The counter index is argument 3, the same operand getIndex() reads.
+  setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx));
+}
+
 Value *InstrProfIncrementInst::getStep() const {
   if (InstrProfIncrementInstStep::classof(this)) {
     return const_cast<Value *>(getArgOperand(4));
@@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
   return nullptr;
 }
 
+void InstrProfCallsite::setCallee(Value *V) {
+  // The callee is the intrinsic's fifth argument (operand index 4).
+  setArgOperand(4, V);
+}
+
 std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
   unsigned NumOperands = arg_size();
   Metadata *MD = nullptr;
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 90dc727cde16d7..0ca7524c273daa 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -13,13 +13,16 @@
 
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/AttributeMask.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
@@ -572,6 +575,102 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
   return promoteCall(NewInst, Callee);
 }
 
+// Promote the indirect call \p CB to a direct call to \p Callee, guarded by
+// an if-then-else, and keep \p CtxProf in sync: the direct call gets a new
+// callsite ID, the two new BBs get new counters, and, in each context of the
+// caller, the subcontext of \p Callee at the indirect callsite is moved under
+// the new callsite ID. Returns the direct call, or nullptr (without changing
+// the IR) if the required profile information or instrumentation is missing.
+CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
+                                          PGOContextualProfile &CtxProf) {
+  assert(CB.isIndirectCall());
+  if (!CtxProf.isFunctionKnown(Callee))
+    return nullptr;
+  auto &Caller = *CB.getParent()->getParent();
+  auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
+  if (!CSInstr)
+    return nullptr;
+  // The new BBs' counter increments are cloned from the entry BB's one; find
+  // it before changing any IR, so we can bail out cleanly if it's missing.
+  auto *EntryBBIns =
+      CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
+  if (!EntryBBIns)
+    return nullptr;
+  const auto CSIndex = CSInstr->getIndex()->getZExtValue();
+
+  CallBase &DirectCall = promoteCall(
+      versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
+  CSInstr->moveBefore(&CB);
+  const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
+  auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
+  NewCSInstr->setIndex(NewCSID);
+  NewCSInstr->setCallee(&Callee);
+  NewCSInstr->insertBefore(&DirectCall);
+  auto &DirectBB = *DirectCall.getParent();
+  auto &IndirectBB = *CB.getParent();
+
+  assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
+         "The ICP indirect BB is new, it shouldn't have instrumentation");
+  assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
+         "The ICP direct BB is new, it shouldn't have instrumentation");
+
+  // Make the 2 new BBs have counters.
+  const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
+  const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
+  const uint32_t NewCountersSize = IndirectID + 1;
+  auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
+  DirectBBIns->setIndex(DirectID);
+  DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());
+
+  auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
+  IndirectBBIns->setIndex(IndirectID);
+  IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());
+
+  const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
+
+  auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
+    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
+    assert(NewCountersSize - 2 == Ctx.counters().size());
+    // Regardless of what comes next, all the contexts belonging to a function
+    // must have the same number of counters.
+    Ctx.growCounters(NewCountersSize);
+
+    // Maybe in this context, the indirect callsite wasn't observed at all.
+    if (!Ctx.hasCallsite(CSIndex))
+      return;
+    auto &CSData = Ctx.callsite(CSIndex);
+    auto It = CSData.find(CalleeGUID);
+
+    // Maybe we did notice the indirect callsite, but to other targets.
+    if (It == CSData.end())
+      return;
+
+    assert(CalleeGUID == It->second.guid());
+
+    // Entry counts are 64-bit; keep the arithmetic 64-bit to avoid truncation.
+    const uint64_t DirectCount = It->second.getEntrycount();
+    uint64_t TotalCount = 0;
+    for (const auto &[_, V] : CSData)
+      TotalCount += V.getEntrycount();
+    assert(TotalCount >= DirectCount);
+    const uint64_t IndirectCount = TotalCount - DirectCount;
+    // The ICP's effect is as-if the direct BB would have been taken DirectCount
+    // times, and the indirect BB, IndirectCount times.
+    Ctx.counters()[DirectID] = DirectCount;
+    Ctx.counters()[IndirectID] = IndirectCount;
+
+    // This particular indirect target needs to be moved to this caller under
+    // the newly-allocated callsite index. Detach it from CSData first: CSData
+    // refers into Ctx.callsites(), which ingestContext inserts into.
+    assert(Ctx.callsites().count(NewCSID) == 0);
+    PGOCtxProfContext CalleeCtx = std::move(It->second);
+    CSData.erase(It);
+    Ctx.ingestContext(NewCSID, std::move(CalleeCtx));
+  };
+  CtxProf.update(ProfileUpdater, &Caller);
+  return &DirectCall;
+}
+
 CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
                                          Function *Callee,
                                          ArrayRef<Constant *> AddressPoints,
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 2d457eb3b678aa..aff603de2a2bd5 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
@@ -14,7 +15,12 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/NoFolder.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
+#include "llvm/Support/JSON.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -456,3 +462,175 @@ declare void @_ZN5Base35func3Ev(ptr)
   // 1 call instruction from the entry block.
   EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
 }
+
+using namespace llvm::ctx_profile;
+
+// Owns the memory backing the ContextNodes it creates.
+class ContextManager final {
+  std::vector<std::unique_ptr<char[]>> Nodes;
+
+public:
+  ContextNode *createNode(GUID Guid, uint32_t NrCounters, uint32_t NrCallsites,
+                          ContextNode *Next = nullptr) {
+    auto AllocSize = ContextNode::getAllocSize(NrCounters, NrCallsites);
+    auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
+    std::memset(Mem, 0, AllocSize);
+    auto *Ret = new (Mem) ContextNode(Guid, NrCounters, NrCallsites, Next);
+    return Ret;
+  }
+};
+
+TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+define i32 @testfunc1(ptr %d) !guid !0 {
+  call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+  call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr %d)
+  %call = call i32 %d()
+  ret i32 %call
+}
+
+define i32 @f1() !guid !1 {
+  call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+  ret i32 2
+}
+
+define i32 @f2() !guid !2 {
+  call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+  call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @f4)
+  %r = call i32 @f4()
+  ret i32 %r
+}
+
+define i32 @testfunc2(ptr %p) !guid !4 {
+  call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+  call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @testfunc1)
+  %r = call i32 @testfunc1(ptr %p)
+  ret i32 %r
+}
+
+declare i32 @f3()
+
+define i32 @f4() !guid !3 {
+  ret i32 3
+}
+
+!0 = !{i64 1000}
+!1 = !{i64 1001}
+!2 = !{i64 1002}
+!3 = !{i64 1004}
+!4 = !{i64 1005}
+)IR");
+
+  // Synthesize a profile. The profile is nonsensical, but the goal is to check
+  // that new BBs are created with IDs and the right counter values.
+  ContextManager Mgr;
+  auto BuildTree = [&](const std::vector<uint32_t> &CalleeEntrycounts) {
+    auto *Entry = Mgr.createNode(1000, 1, 1);
+    // Set the entrycount to 1 so it's not 0. We don't care about it, really,
+    // for this test but we generally assume it's not 0.
+    Entry->counters()[0] = 1;
+    auto *F1 = Mgr.createNode(1001, 1, 0);
+    auto *F2 = Mgr.createNode(1002, 1, 1, F1);
+    auto *F3 = Mgr.createNode(1003, 1, 0, F2);
+    auto *F4 = Mgr.createNode(1004, 1, 0);
+
+    F1->counters()[0] = CalleeEntrycounts[0];
+    F2->counters()[0] = CalleeEntrycounts[1];
+    F3->counters()[0] = CalleeEntrycounts[2];
+    F4->counters()[0] = CalleeEntrycounts[3];
+    F2->subContexts()[0] = F4;
+    Entry->subContexts()[0] = F3; // which chains F2 and F1
+    return Entry;
+  };
+  // We're interested in f2 (GUID 1002). Its entry counts are 11 in the first
+  // context and 102 in the second.
+  // The total number of times the indirect callsite is exercised is:
+  // 10+11+12 = 33 in the first case; and 101+102+103 = 306 in the
+  // second.
+  // This means that the direct/indirect call counters will be: 11/22 in the
+  // first case and 102/204 in the second. Meaning, the "Counters" for the
+  // GUID=1000 context will look like [1, 11, 22] and [1, 102, 204],
+  // respectively (the first "1" being the entrycount which we set to 1 above)
+  auto *Entry1 = BuildTree({10, 11, 12, 13});
+  auto *SubTree2 = BuildTree({101, 102, 103, 104});
+  auto *Entry2 = Mgr.createNode(1005, 1, 1);
+  Entry2->counters()[0] = 2;
+  Entry2->subContexts()[0] = SubTree2;
+
+  llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique*/ true);
+  {
+    std::error_code EC;
+    raw_fd_stream Out(ProfileFile.path(), EC);
+    ASSERT_FALSE(EC);
+    {
+      PGOCtxProfileWriter Writer(Out);
+      Writer.write(*Entry1);
+      Writer.write(*Entry2);
+    }
+  }
+
+  ModuleAnalysisManager MAM;
+  MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
+  MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
+  auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
+  auto *Caller = M->getFunction("testfunc1");
+  ASSERT_TRUE(!!Caller);
+  auto *Callee = M->getFunction("f2");
+  ASSERT_TRUE(!!Callee);
+  auto *IndirectCS = [&]() -> CallBase * {
+    for (auto &BB : *Caller)
+      for (auto &I : BB)
+        if (auto *CB = dyn_cast<CallBase>(&I); CB && CB->isIndirectCall())
+          return CB;
+    return nullptr;
+  }();
+  ASSERT_TRUE(!!IndirectCS);
+  ASSERT_NE(promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf), nullptr);
+
+  std::string Str;
+  raw_string_ostream OS(Str);
+  CtxProfAnalysisPrinterPass Printer(
+      OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
+  Printer.run(*M, MAM);
+  const char *Expected = R"(
+  [
+  {
+    "Guid": 1000,
+    "Counters": [1, 11, 22],
+    "Callsites": [
+      [{ "Guid": 1001,
+          "Counters": [10]},
+        { "Guid": 1003,
+          "Counters": [12]
+        }],
+        [{ "Guid": 1002,
+          "Counters": [11],
+          "Callsites": [
+          [{ "Guid": 1004,
+            "Counters": [13] }]]}]]
+  },
+  {
+    "Guid": 1005,
+    "Counters": [2],
+    "Callsites": [
+      [{ "Guid": 1000,
+         "Counters": [1, 102, 204],
+         "Callsites": [
+            [{ "Guid": 1001,
+               "Counters": [101]},
+             { "Guid": 1003,
+               "Counters": [103]}],
+            [{ "Guid": 1002,
+               "Counters": [102],
+               "Callsites": [
+            [{ "Guid": 1004,
+               "Counters": [104]}]]}]]}]]}
+])";
+  auto ExpectedJSON = json::parse(Expected);
+  ASSERT_TRUE(!!ExpectedJSON);
+  auto ProducedJSON = json::parse(Str);
+  ASSERT_TRUE(!!ProducedJSON);
+  EXPECT_EQ(*ProducedJSON, *ExpectedJSON);
+}



More information about the llvm-commits mailing list