[llvm-branch-commits] [llvm] [ctx_prof] Add support for ICP (PR #105469)
Mircea Trofin via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Aug 21 11:19:32 PDT 2024
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/105469
>From 0d7c720e67a0213565f0e7c141c4ffa1b91fc5b9 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 20 Aug 2024 21:09:16 -0700
Subject: [PATCH 1/2] [ctx_prof] API to get the instrumentation of a BB
---
llvm/include/llvm/Analysis/CtxProfAnalysis.h | 5 +++++
llvm/lib/Analysis/CtxProfAnalysis.cpp | 7 ++++++
.../Analysis/CtxProfAnalysisTest.cpp | 22 +++++++++++++++++++
3 files changed, 34 insertions(+)
diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 23abcbe2c6e9d2..0b4dd8ae3a0dc7 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -95,7 +95,12 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
PGOContextualProfile run(Module &M, ModuleAnalysisManager &MAM);
+ /// Get the instruction instrumenting a callsite, or nullptr if that cannot be
+ /// found.
static InstrProfCallsite *getCallsiteInstrumentation(CallBase &CB);
+
+ /// Get the instruction instrumenting a BB, or nullptr if not present.
+ static InstrProfIncrementInst *getBBInstrumentation(BasicBlock &BB);
};
class CtxProfAnalysisPrinterPass
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index ceebb2cf06d235..3fc1bc34afb97e 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -202,6 +202,13 @@ InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) {
return nullptr;
}
+// Walk BB front to back and return the first InstrProfIncrementInst in it;
+// nullptr means the block has no such instruction.
+InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
+  for (auto &Inst : BB) {
+    auto *Counter = dyn_cast<InstrProfIncrementInst>(&Inst);
+    if (Counter != nullptr)
+      return Counter;
+  }
+  return nullptr;
+}
+
static void
preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
function_ref<void(const PGOCtxProfContext &)> Visitor) {
diff --git a/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp
index 5f9bf3ec540eb3..fbe3a6e45109cc 100644
--- a/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp
@@ -132,4 +132,26 @@ TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) {
EXPECT_EQ(IndIns, nullptr);
}
+// Runs ctx-prof instrumentation over the module, then records for every basic
+// block of "foo" the index of its counter instrumentation (-1 when the block
+// has none) and compares that against the expected name -> index mapping.
+TEST_F(CtxProfAnalysisTest, GetBBIDTest) {
+  ModulePassManager MPM;
+  MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
+  EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
+  auto *F = M->getFunction("foo");
+  ASSERT_NE(F, nullptr);
+
+  std::map<std::string, int> BBNameAndID;
+  for (auto &BB : *F) {
+    auto *Counter = CtxProfAnalysis::getBBInstrumentation(BB);
+    const int CounterID =
+        Counter ? static_cast<int>(Counter->getIndex()->getZExtValue()) : -1;
+    BBNameAndID[BB.getName().str()] = CounterID;
+  }
+
+  EXPECT_THAT(BBNameAndID,
+              testing::UnorderedElementsAre(
+                  testing::Pair("", 0), testing::Pair("yes", 1),
+                  testing::Pair("no", -1), testing::Pair("exit", -1)));
+}
} // namespace
>From 61e37e3e1657a7e85e9df2f77feb6957c304851a Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 20 Aug 2024 21:32:23 -0700
Subject: [PATCH 2/2] [ctx_prof] Add support for ICP
---
llvm/include/llvm/Analysis/CtxProfAnalysis.h | 18 +-
llvm/include/llvm/IR/IntrinsicInst.h | 2 +
.../llvm/ProfileData/PGOCtxProfReader.h | 20 ++
.../Transforms/Utils/CallPromotionUtils.h | 4 +
llvm/lib/Analysis/CtxProfAnalysis.cpp | 79 +++++---
llvm/lib/IR/IntrinsicInst.cpp | 10 +
.../Transforms/Utils/CallPromotionUtils.cpp | 86 +++++++++
.../Utils/CallPromotionUtilsTest.cpp | 178 ++++++++++++++++++
8 files changed, 364 insertions(+), 33 deletions(-)
diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 0b4dd8ae3a0dc7..d6c2bb26a091af 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -73,6 +73,12 @@ class PGOContextualProfile {
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
}
+ using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
+ using Visitor = function_ref<void(PGOCtxProfContext &)>;
+
+ void update(Visitor, const Function *F = nullptr);
+ void visit(ConstVisitor, const Function *F = nullptr) const;
+
const CtxProfFlatProfile flatten() const;
bool invalidate(Module &, const PreservedAnalyses &PA,
@@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
class CtxProfAnalysisPrinterPass
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
- raw_ostream &OS;
-
public:
- explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
+ enum class PrintMode { Everything, JSON };
+ explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
+ PrintMode Mode = PrintMode::Everything)
+ : OS(OS), Mode(Mode) {}
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }
+
+private:
+ raw_ostream &OS;
+ const PrintMode Mode;
};
/// Assign a GUID to functions as metadata. GUID calculation takes linkage into
@@ -134,6 +145,5 @@ class AssignGUIDPass : public PassInfoMixin<AssignGUIDPass> {
// This should become GlobalValue::getGUID
static uint64_t getGUID(const Function &F);
};
-
} // namespace llvm
#endif // LLVM_ANALYSIS_CTXPROFANALYSIS_H
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 2f1e2c08c3ecec..bab41efab528e2 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1519,6 +1519,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
ConstantInt *getNumCounters() const;
// The index of the counter that this instruction acts on.
ConstantInt *getIndex() const;
+ void setIndex(uint32_t Idx);
};
/// This represents the llvm.instrprof.cover intrinsic.
@@ -1569,6 +1570,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
Value *getCallee() const;
+ void setCallee(Value *);
};
/// This represents the llvm.instrprof.timestamp intrinsic.
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index 190deaeeacd085..23dcc376508b39 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -57,9 +57,23 @@ class PGOCtxProfContext final {
GlobalValue::GUID guid() const { return GUID; }
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
+ SmallVectorImpl<uint64_t> &counters() { return Counters; }
+
+ uint64_t getEntrycount() const { return Counters[0]; }
+
const CallsiteMapTy &callsites() const { return Callsites; }
CallsiteMapTy &callsites() { return Callsites; }
+  /// Move \p Other into this context under callsite \p CSId, keyed by its
+  /// GUID. The callsite entry is created first if it does not exist yet.
+  void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
+    CallTargetMapTy &Targets =
+        callsites().try_emplace(CSId, CallTargetMapTy()).first->second;
+    const GlobalValue::GUID TargetGUID = Other.guid();
+    Targets.emplace(TargetGUID, std::move(Other));
+  }
+
+  /// Resize the counter vector to \p Size entries unless it already holds
+  /// more than that; this never shrinks it.
+  void growCounters(uint32_t Size) {
+    if (Size < Counters.size())
+      return;
+    Counters.resize(Size);
+  }
+
bool hasCallsite(uint32_t I) const {
return Callsites.find(I) != Callsites.end();
}
@@ -68,6 +82,12 @@ class PGOCtxProfContext final {
assert(hasCallsite(I) && "Callsite not found");
return Callsites.find(I)->second;
}
+
+  /// Mutable access to the call targets recorded for callsite \p I. The
+  /// callsite must exist (asserted).
+  CallTargetMapTy &callsite(uint32_t I) {
+    assert(hasCallsite(I) && "Callsite not found");
+    auto Pos = Callsites.find(I);
+    return Pos->second;
+  }
+
void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
};
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index 385831f457038d..58af26f31417b0 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -14,6 +14,7 @@
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
+#include "llvm/Analysis/CtxProfAnalysis.h"
namespace llvm {
template <typename T> class ArrayRef;
class Constant;
@@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);
+CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
+ PGOContextualProfile &CtxProf);
+
/// This is similar to `promoteCallWithIfThenElse` except that the condition to
/// promote a virtual call is that \p VPtr is the same as any of \p
/// AddressPoints.
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 3fc1bc34afb97e..2cd3f2114397e5 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
return PreservedAnalyses::all();
}
- OS << "Function Info:\n";
- for (const auto &[Guid, FuncInfo] : C.FuncInfo)
- OS << Guid << " : " << FuncInfo.Name
- << ". MaxCounterID: " << FuncInfo.NextCounterIndex
- << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
+ if (Mode == PrintMode::Everything) {
+ OS << "Function Info:\n";
+ for (const auto &[Guid, FuncInfo] : C.FuncInfo)
+ OS << Guid << " : " << FuncInfo.Name
+ << ". MaxCounterID: " << FuncInfo.NextCounterIndex
+ << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
+ }
const auto JSONed = ::llvm::json::toJSON(C.profiles());
- OS << "\nCurrent Profile:\n";
+ if (Mode == PrintMode::Everything)
+ OS << "\nCurrent Profile:\n";
OS << formatv("{0:2}", JSONed);
+ if (Mode == PrintMode::JSON)
+ return PreservedAnalyses::all();
+
OS << "\n";
OS << "\nFlat Profile:\n";
auto Flat = C.flatten();
@@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
return nullptr;
}
-static void
-preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
- function_ref<void(const PGOCtxProfContext &)> Visitor) {
- std::function<void(const PGOCtxProfContext &)> Traverser =
- [&](const auto &Ctx) {
- Visitor(Ctx);
- for (const auto &[_, SubCtxSet] : Ctx.callsites())
- for (const auto &[__, Subctx] : SubCtxSet)
- Traverser(Subctx);
- };
- for (const auto &[_, P] : Profiles)
+template <class ProfilesTy, class ProfTy>
+static void preorderVisit(ProfilesTy &Profiles,
+ function_ref<void(ProfTy &)> Visitor,
+ GlobalValue::GUID Match = 0) {
+ std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
+ if (!Match || Ctx.guid() == Match)
+ Visitor(Ctx);
+ for (auto &[_, SubCtxSet] : Ctx.callsites())
+ for (auto &[__, Subctx] : SubCtxSet)
+ Traverser(Subctx);
+ };
+ for (auto &[_, P] : Profiles)
Traverser(P);
}
+// Run V, in preorder and with mutable access, over the contexts of the
+// profile. When F is given, only contexts whose GUID is that of F reach V.
+void PGOContextualProfile::update(Visitor V, const Function *F) {
+  using ProfilesTy = PGOCtxProfContext::CallTargetMapTy;
+  GlobalValue::GUID Filter = 0U;
+  if (F)
+    Filter = getDefinedFunctionGUID(*F);
+  preorderVisit<ProfilesTy, PGOCtxProfContext>(*Profiles, V, Filter);
+}
+
+// Run V, in preorder and read-only, over the contexts of the profile. When F
+// is given, only contexts whose GUID is that of F reach V.
+void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
+  using ProfilesTy = const PGOCtxProfContext::CallTargetMapTy;
+  GlobalValue::GUID Filter = 0U;
+  if (F)
+    Filter = getDefinedFunctionGUID(*F);
+  preorderVisit<ProfilesTy, const PGOCtxProfContext>(*Profiles, V, Filter);
+}
+
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
assert(Profiles.has_value());
CtxProfFlatProfile Flat;
- preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
- auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
- if (Ins) {
- llvm::append_range(It->second, Ctx.counters());
- return;
- }
- assert(It->second.size() == Ctx.counters().size() &&
- "All contexts corresponding to a function should have the exact "
- "same number of counters.");
- for (size_t I = 0, E = It->second.size(); I < E; ++I)
- It->second[I] += Ctx.counters()[I];
- });
+ preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
+ const PGOCtxProfContext>(
+ *Profiles, [&](const PGOCtxProfContext &Ctx) {
+ auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
+ if (Ins) {
+ llvm::append_range(It->second, Ctx.counters());
+ return;
+ }
+ assert(It->second.size() == Ctx.counters().size() &&
+ "All contexts corresponding to a function should have the exact "
+ "same number of counters.");
+ for (size_t I = 0, E = It->second.size(); I < E; ++I)
+ It->second[I] += Ctx.counters()[I];
+ });
return Flat;
}
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
index db3b0196f66fd6..0eadd0f980c15b 100644
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
}
+// Overwrite the counter index, i.e. argument operand 3 (the operand getIndex()
+// reads), with the i32 constant Idx.
+void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
+  assert(isa<InstrProfCntrInstBase>(this));
+  auto *Int32Ty = Type::getInt32Ty(getContext());
+  auto *NewIndex = ConstantInt::get(Int32Ty, Idx);
+  setArgOperand(3, NewIndex);
+}
+
Value *InstrProfIncrementInst::getStep() const {
if (InstrProfIncrementInstStep::classof(this)) {
return const_cast<Value *>(getArgOperand(4));
@@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
return nullptr;
}
+// Store V as argument operand 4 of this callsite instrumentation.
+void InstrProfCallsite::setCallee(Value *V) {
+  assert(isa<InstrProfCallsite>(this));
+  constexpr unsigned CalleeArgNo = 4;
+  setArgOperand(CalleeArgNo, V);
+}
+
std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
unsigned NumOperands = arg_size();
Metadata *MD = nullptr;
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 90dc727cde16d7..0ca7524c273daa 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -13,13 +13,16 @@
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
@@ -572,6 +575,89 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}
+/// Promote the indirect call \p CB to a guarded direct call to \p Callee,
+/// instrument the two resulting basic blocks, and update \p CtxProf so that
+/// every context of the caller reflects the new counters and callsite.
+///
+/// Returns nullptr, without modifying the IR or the profile, when \p Callee is
+/// not known to \p CtxProf, when \p CB has no callsite instrumentation, or when
+/// the caller's entry block has no counter instrumentation to clone from.
+/// Otherwise returns the newly created direct call.
+CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
+                                          PGOContextualProfile &CtxProf) {
+  assert(CB.isIndirectCall());
+  if (!CtxProf.isFunctionKnown(Callee))
+    return nullptr;
+  auto &Caller = *CB.getParent()->getParent();
+  auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
+  if (!CSInstr)
+    return nullptr;
+  // The entry block's counter instruction is the template the new counters are
+  // cloned from. Look it up before touching the IR, so that a missing one is a
+  // clean bail-out instead of a null dereference after the call was versioned.
+  auto *EntryBBIns =
+      CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
+  if (!EntryBBIns)
+    return nullptr;
+  const auto CSIndex = CSInstr->getIndex()->getZExtValue();
+
+  CallBase &DirectCall = promoteCall(
+      versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
+  // The original callsite instrumentation stays right before the indirect
+  // call; the direct call gets a clone with a newly allocated callsite index.
+  CSInstr->moveBefore(&CB);
+  const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
+  auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
+  NewCSInstr->setIndex(NewCSID);
+  NewCSInstr->setCallee(&Callee);
+  NewCSInstr->insertBefore(&DirectCall);
+  auto &DirectBB = *DirectCall.getParent();
+  auto &IndirectBB = *CB.getParent();
+
+  assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
+         "The ICP direct BB is new, it shouldn't have instrumentation");
+  assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
+         "The ICP indirect BB is new, it shouldn't have instrumentation");
+
+  // Make the 2 new BBs have counters.
+  const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
+  const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
+  const uint32_t NewCountersSize = IndirectID + 1;
+  auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
+  DirectBBIns->setIndex(DirectID);
+  DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());
+
+  auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
+  IndirectBBIns->setIndex(IndirectID);
+  IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());
+
+  const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
+
+  auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
+    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
+    assert(NewCountersSize - 2 == Ctx.counters().size());
+    // Regardless of what happens next, all the contexts belonging to a
+    // function must have the same number of counters.
+    Ctx.growCounters(NewCountersSize);
+
+    // Maybe in this context, the indirect callsite wasn't observed at all.
+    if (!Ctx.hasCallsite(CSIndex))
+      return;
+    auto &CSData = Ctx.callsite(CSIndex);
+    auto It = CSData.find(CalleeGUID);
+
+    // Maybe we did notice the indirect callsite, but to other targets.
+    if (It == CSData.end())
+      return;
+
+    assert(CalleeGUID == It->second.guid());
+
+    // Counters are 64-bit; keep the arithmetic 64-bit so that large counts
+    // are not truncated.
+    const uint64_t DirectCount = It->second.getEntrycount();
+    uint64_t TotalCount = 0;
+    for (const auto &[_, V] : CSData)
+      TotalCount += V.getEntrycount();
+    assert(TotalCount >= DirectCount);
+    const uint64_t IndirectCount = TotalCount - DirectCount;
+    // The ICP's effect is as-if the direct BB would have been taken DirectCount
+    // times, and the indirect BB, IndirectCount times.
+    Ctx.counters()[DirectID] = DirectCount;
+    Ctx.counters()[IndirectID] = IndirectCount;
+
+    // This particular indirect target needs to be moved to this caller under
+    // the newly-allocated callsite index.
+    assert(Ctx.callsites().count(NewCSID) == 0);
+    Ctx.ingestContext(NewCSID, std::move(It->second));
+    CSData.erase(CalleeGUID);
+  };
+  CtxProf.update(ProfileUpdater, &Caller);
+  return &DirectCall;
+}
+
CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 2d457eb3b678aa..aff603de2a2bd5 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -14,7 +15,12 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
+#include "llvm/Support/JSON.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -456,3 +462,175 @@ declare void @_ZN5Base35func3Ev(ptr)
// 1 call instruction from the entry block.
EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
}
+
+using namespace llvm::ctx_profile;
+
+/// Test helper that owns the raw storage backing the ContextNode objects it
+/// creates; the nodes live as long as the manager does.
+class ContextManager final {
+  std::vector<std::unique_ptr<char[]>> Nodes;
+
+public:
+  /// Allocate storage sized by ContextNode::getAllocSize for \p NrCounters
+  /// counters and \p NrCallsites callsites, and construct a node for \p Guid
+  /// in it, passing \p Next through to the ContextNode constructor.
+  ContextNode *createNode(GUID Guid, uint32_t NrCounters, uint32_t NrCallsites,
+                          ContextNode *Next = nullptr) {
+    auto AllocSize = ContextNode::getAllocSize(NrCounters, NrCallsites);
+    // std::make_unique<char[]> value-initializes the array, so the buffer is
+    // already zero-filled when the node is constructed in it.
+    auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
+    return new (Mem) ContextNode(Guid, NrCounters, NrCallsites, Next);
+  }
+};
+
+TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) {
+ LLVMContext C;
+ std::unique_ptr<Module> M = parseIR(C,
+ R"IR(
+define i32 @testfunc1(ptr %d) !guid !0 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr %d)
+ %call = call i32 %d()
+ ret i32 %call
+}
+
+define i32 @f1() !guid !1 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ ret i32 2
+}
+
+define i32 @f2() !guid !2 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @f4)
+ %r = call i32 @f4()
+ ret i32 %r
+}
+
+define i32 @testfunc2(ptr %p) !guid !4 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @testfunc1)
+ %r = call i32 @testfunc1(ptr %p)
+ ret i32 %r
+}
+
+declare i32 @f3()
+
+define i32 @f4() !guid !3 {
+ ret i32 3
+}
+
+!0 = !{i64 1000}
+!1 = !{i64 1001}
+!2 = !{i64 1002}
+!3 = !{i64 1004}
+!4 = !{i64 1005}
+)IR");
+
+ // Synthesize a profile. The profile is nonsensical, but the goal is to check
+ // that new BBs are created with IDs and the right counter values.
+ ContextManager Mgr;
+ auto BuildTree = [&](const std::vector<uint32_t> &CalleeEntrycounts) {
+ auto *Entry = Mgr.createNode(1000, 1, 1);
+ // Set the entrycount to 1 so it's not 0. We don't care about it, really,
+ // for this test but we generally assume it's not 0.
+ Entry->counters()[0] = 1;
+ auto *F1 = Mgr.createNode(1001, 1, 0);
+ auto *F2 = Mgr.createNode(1002, 1, 1, F1);
+ auto *F3 = Mgr.createNode(1003, 1, 0, F2);
+ auto *F4 = Mgr.createNode(1004, 1, 0);
+
+ F1->counters()[0] = CalleeEntrycounts[0];
+ F2->counters()[0] = CalleeEntrycounts[1];
+ F3->counters()[0] = CalleeEntrycounts[2];
+ F4->counters()[0] = CalleeEntrycounts[3];
+ F2->subContexts()[0] = F4;
+ Entry->subContexts()[0] = F3; // which chains F2 and F1
+ return Entry;
+ };
+ // We'll be interested in f2. The entry counts for it are: 11 in the first
+ // context; and 102 in the second.
+ // The total number of times the indirect callsite is exercised is:
+ // 10+11+12 = 33 in the first case; and 101+102+103 = 306 in the
+ // second.
+ // This means that the direct/indirect call counters will be: 11/22 in the
+ // first case and 102/204 in the second. Meaning, the "Counters" for the
+ // GUID=1000 context will look like [1, 11, 22] and [1, 102, 204],
+ // respectively (the first "1" being the entrycount which we set to 1 above)
+ auto *Entry1 = BuildTree({10, 11, 12, 13});
+ auto *SubTree2 = BuildTree({101, 102, 103, 104});
+ auto *Entry2 = Mgr.createNode(1005, 1, 1);
+ Entry2->counters()[0] = 2;
+ Entry2->subContexts()[0] = SubTree2;
+
+ llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique*/ true);
+ {
+ std::error_code EC;
+ raw_fd_stream Out(ProfileFile.path(), EC);
+ ASSERT_FALSE(EC);
+ {
+ PGOCtxProfileWriter Writer(Out);
+ Writer.write(*Entry1);
+ Writer.write(*Entry2);
+ }
+ }
+
+ ModuleAnalysisManager MAM;
+ MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
+ MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
+ auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
+ auto *Caller = M->getFunction("testfunc1");
+ ASSERT_TRUE(!!Caller);
+ auto *Callee = M->getFunction("f2");
+ ASSERT_TRUE(!!Callee);
+ auto *IndirectCS = [&]() -> CallBase * {
+ for (auto &BB : *Caller)
+ for (auto &I : BB)
+ if (auto *CB = dyn_cast<CallBase>(&I); CB && CB->isIndirectCall())
+ return CB;
+ return nullptr;
+ }();
+ ASSERT_TRUE(!!IndirectCS);
+ promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf);
+
+ std::string Str;
+ raw_string_ostream OS(Str);
+ CtxProfAnalysisPrinterPass Printer(
+ OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
+ Printer.run(*M, MAM);
+ const char *Expected = R"(
+ [
+ {
+ "Guid": 1000,
+ "Counters": [1, 11, 22],
+ "Callsites": [
+ [{ "Guid": 1001,
+ "Counters": [10]},
+ { "Guid": 1003,
+ "Counters": [12]
+ }],
+ [{ "Guid": 1002,
+ "Counters": [11],
+ "Callsites": [
+ [{ "Guid": 1004,
+ "Counters": [13] }]]}]]
+ },
+ {
+ "Guid": 1005,
+ "Counters": [2],
+ "Callsites": [
+ [{ "Guid": 1000,
+ "Counters": [1, 102, 204],
+ "Callsites": [
+ [{ "Guid": 1001,
+ "Counters": [101]},
+ { "Guid": 1003,
+ "Counters": [103]}],
+ [{ "Guid": 1002,
+ "Counters": [102],
+ "Callsites": [
+ [{ "Guid": 1004,
+ "Counters": [104]}]]}]]}]]}
+])";
+ auto ExpectedJSON = json::parse(Expected);
+ ASSERT_TRUE(!!ExpectedJSON);
+ auto ProducedJSON = json::parse(Str);
+ ASSERT_TRUE(!!ProducedJSON);
+ EXPECT_EQ(*ProducedJSON, *ExpectedJSON);
+}
More information about the llvm-branch-commits
mailing list