[llvm] [ctx_prof] Add support for ICP (PR #105469)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 23 09:57:50 PDT 2024
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/105469
>From f1780fa803cfcd0324fb463f3a6ed98c32cc1e38 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 22 Aug 2024 13:27:44 -0700
Subject: [PATCH 1/2] [ctx_prof] Remove the dependency on the "name"
GlobalVariable
We don't need that name variable for contextual instrumentation, we just
use the function to get its GUID which we pass to the runtime, and rely
on metadata to capture it through the various optimization passes. This
change removes the need for the name global variable.
---
llvm/include/llvm/IR/IntrinsicInst.h | 14 +++++--
.../Instrumentation/PGOCtxProfLowering.cpp | 3 +-
.../Instrumentation/PGOInstrumentation.cpp | 15 +++----
.../PGOProfile/ctx-instrumentation.ll | 41 +++++++------------
.../PGOProfile/ctx-prof-use-prelink.ll | 11 ++---
5 files changed, 39 insertions(+), 45 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index c188bec631a239..b45c89cadb0fde 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1503,11 +1503,19 @@ class InstrProfInstBase : public IntrinsicInst {
return isCounterBase(*Instr) || isMCDCBitmapBase(*Instr);
return false;
}
- // The name of the instrumented function.
+
+ // The name of the instrumented function, assuming it is a global variable.
GlobalVariable *getName() const {
- return cast<GlobalVariable>(
- const_cast<Value *>(getArgOperand(0))->stripPointerCasts());
+ return cast<GlobalVariable>(getNameValue());
+ }
+
+ // The "name" operand of the profile instrumentation instruction - this is the
+ // operand that can be used to relate the instruction to the function it
+ // belonged to at instrumentation time.
+ Value *getNameValue() const {
+ return const_cast<Value *>(getArgOperand(0))->stripPointerCasts();
}
+
// The hash of the CFG for the instrumented function.
ConstantInt *getHash() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(1)));
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index 9b10cbba84075a..43bebc99316e06 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -226,7 +226,8 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
IRBuilder<> Builder(Mark);
- Guid = Builder.getInt64(AssignGUIDPass::getGUID(F));
+ Guid = Builder.getInt64(
+ AssignGUIDPass::getGUID(cast<Function>(*Mark->getNameValue())));
// The type of the context of this function is now knowable since we have
// NrCallsites and NrCounters. We delcare it here because it's more
// convenient - we have the Builder.
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 39cf94daab7d3b..aacfe39f16fbc4 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -464,7 +464,7 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
VisitMode Mode = VM_counting; // Visiting mode.
unsigned *CurCtrIdx = nullptr; // Pointer to current counter index.
unsigned TotalNumCtrs = 0; // Total number of counters
- GlobalVariable *FuncNameVar = nullptr;
+ GlobalValue *FuncNameVar = nullptr;
uint64_t FuncHash = 0;
PGOUseFunc *UseFunc = nullptr;
bool HasSingleByteCoverage;
@@ -482,7 +482,7 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> {
// Ind is a pointer to the counter index variable; \p TotalNC
// is the total number of counters; \p FNV is the pointer to the
// PGO function name var; \p FHash is the function hash.
- void instrumentSelects(unsigned *Ind, unsigned TotalNC, GlobalVariable *FNV,
+ void instrumentSelects(unsigned *Ind, unsigned TotalNC, GlobalValue *FNV,
uint64_t FHash) {
Mode = VM_instrument;
CurCtrIdx = Ind;
@@ -901,13 +901,14 @@ void FunctionInstrumenter::instrument() {
SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI);
}
+ const bool IsCtxProf = InstrumentationType == PGOInstrumentationType::CTXPROF;
FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
- F, TLI, ComdatMembers, true, BPI, BFI,
+ F, TLI, ComdatMembers, /*CreateGlobalVar=*/!IsCtxProf, BPI, BFI,
InstrumentationType == PGOInstrumentationType::CSFDO,
shouldInstrumentEntryBB(), PGOBlockCoverage);
- auto Name = FuncInfo.FuncNameVar;
- auto CFGHash =
+ auto *const Name = IsCtxProf ? cast<GlobalValue>(&F) : FuncInfo.FuncNameVar;
+ auto *const CFGHash =
ConstantInt::get(Type::getInt64Ty(M.getContext()), FuncInfo.FunctionHash);
// Make sure that pointer to global is passed in with zero addrspace
// This is relevant during GPU profiling
@@ -929,7 +930,7 @@ void FunctionInstrumenter::instrument() {
unsigned NumCounters =
InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
- if (InstrumentationType == PGOInstrumentationType::CTXPROF) {
+ if (IsCtxProf) {
auto *CSIntrinsic =
Intrinsic::getDeclaration(&M, Intrinsic::instrprof_callsite);
// We want to count the instrumentable callsites, then instrument them. This
@@ -995,7 +996,7 @@ void FunctionInstrumenter::instrument() {
}
// Now instrument select instructions:
- FuncInfo.SIVisitor.instrumentSelects(&I, NumCounters, FuncInfo.FuncNameVar,
+ FuncInfo.SIVisitor.instrumentSelects(&I, NumCounters, Name,
FuncInfo.FunctionHash);
assert(I == NumCounters);
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index df4e467567c46e..c94c2b4da57a98 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -9,19 +9,6 @@
declare void @bar()
;.
-; INSTRUMENT: @__profn_foo = private constant [3 x i8] c"foo"
-; INSTRUMENT: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
-; INSTRUMENT: @__profn_another_entrypoint_no_callees = private constant [29 x i8] c"another_entrypoint_no_callees"
-; INSTRUMENT: @__profn_simple = private constant [6 x i8] c"simple"
-; INSTRUMENT: @__profn_no_callsites = private constant [12 x i8] c"no_callsites"
-; INSTRUMENT: @__profn_no_counters = private constant [11 x i8] c"no_counters"
-;.
-; LOWERING: @__profn_foo = private constant [3 x i8] c"foo"
-; LOWERING: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
-; LOWERING: @__profn_another_entrypoint_no_callees = private constant [29 x i8] c"another_entrypoint_no_callees"
-; LOWERING: @__profn_simple = private constant [6 x i8] c"simple"
-; LOWERING: @__profn_no_callsites = private constant [12 x i8] c"no_callsites"
-; LOWERING: @__profn_no_counters = private constant [11 x i8] c"no_counters"
; LOWERING: @an_entrypoint_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
; LOWERING: @another_entrypoint_no_callees_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
; LOWERING: @__llvm_ctx_profile_callsite = external hidden thread_local global ptr
@@ -30,16 +17,16 @@ declare void @bar()
define void @foo(i32 %a, ptr %fct) {
; INSTRUMENT-LABEL: define void @foo(
; INSTRUMENT-SAME: i32 [[A:%.*]], ptr [[FCT:%.*]]) {
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 0)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @foo, i64 728453322856651412, i32 2, i32 0)
; INSTRUMENT-NEXT: [[T:%.*]] = icmp eq i32 [[A]], 0
; INSTRUMENT-NEXT: br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
; INSTRUMENT: yes:
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 1)
-; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 0, ptr [[FCT]])
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @foo, i64 728453322856651412, i32 2, i32 1)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @foo, i64 728453322856651412, i32 2, i32 0, ptr [[FCT]])
; INSTRUMENT-NEXT: call void [[FCT]](i32 [[A]])
; INSTRUMENT-NEXT: br label [[EXIT:%.*]]
; INSTRUMENT: no:
-; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 1, ptr @bar)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @foo, i64 728453322856651412, i32 2, i32 1, ptr @bar)
; INSTRUMENT-NEXT: call void @bar()
; INSTRUMENT-NEXT: br label [[EXIT]]
; INSTRUMENT: exit:
@@ -92,12 +79,12 @@ exit:
define void @an_entrypoint(i32 %a) {
; INSTRUMENT-LABEL: define void @an_entrypoint(
; INSTRUMENT-SAME: i32 [[A:%.*]]) {
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_an_entrypoint, i64 784007058953177093, i32 2, i32 0)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @an_entrypoint, i64 784007058953177093, i32 2, i32 0)
; INSTRUMENT-NEXT: [[T:%.*]] = icmp eq i32 [[A]], 0
; INSTRUMENT-NEXT: br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
; INSTRUMENT: yes:
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_an_entrypoint, i64 784007058953177093, i32 2, i32 1)
-; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @__profn_an_entrypoint, i64 784007058953177093, i32 1, i32 0, ptr @foo)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @an_entrypoint, i64 784007058953177093, i32 2, i32 1)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @an_entrypoint, i64 784007058953177093, i32 1, i32 0, ptr @foo)
; INSTRUMENT-NEXT: call void @foo(i32 1, ptr null)
; INSTRUMENT-NEXT: ret void
; INSTRUMENT: no:
@@ -144,11 +131,11 @@ no:
define void @another_entrypoint_no_callees(i32 %a) {
; INSTRUMENT-LABEL: define void @another_entrypoint_no_callees(
; INSTRUMENT-SAME: i32 [[A:%.*]]) {
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_another_entrypoint_no_callees, i64 784007058953177093, i32 2, i32 0)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @another_entrypoint_no_callees, i64 784007058953177093, i32 2, i32 0)
; INSTRUMENT-NEXT: [[T:%.*]] = icmp eq i32 [[A]], 0
; INSTRUMENT-NEXT: br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
; INSTRUMENT: yes:
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_another_entrypoint_no_callees, i64 784007058953177093, i32 2, i32 1)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @another_entrypoint_no_callees, i64 784007058953177093, i32 2, i32 1)
; INSTRUMENT-NEXT: ret void
; INSTRUMENT: no:
; INSTRUMENT-NEXT: ret void
@@ -184,7 +171,7 @@ no:
define void @simple(i32 %a) {
; INSTRUMENT-LABEL: define void @simple(
; INSTRUMENT-SAME: i32 [[A:%.*]]) {
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_simple, i64 742261418966908927, i32 1, i32 0)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @simple, i64 742261418966908927, i32 1, i32 0)
; INSTRUMENT-NEXT: ret void
;
; LOWERING-LABEL: define void @simple(
@@ -202,11 +189,11 @@ define void @simple(i32 %a) {
define i32 @no_callsites(i32 %a) {
; INSTRUMENT-LABEL: define i32 @no_callsites(
; INSTRUMENT-SAME: i32 [[A:%.*]]) {
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_no_callsites, i64 784007058953177093, i32 2, i32 0)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @no_callsites, i64 784007058953177093, i32 2, i32 0)
; INSTRUMENT-NEXT: [[C:%.*]] = icmp eq i32 [[A]], 0
; INSTRUMENT-NEXT: br i1 [[C]], label [[YES:%.*]], label [[NO:%.*]]
; INSTRUMENT: yes:
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_no_callsites, i64 784007058953177093, i32 2, i32 1)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @no_callsites, i64 784007058953177093, i32 2, i32 1)
; INSTRUMENT-NEXT: ret i32 1
; INSTRUMENT: no:
; INSTRUMENT-NEXT: ret i32 0
@@ -238,8 +225,8 @@ no:
define void @no_counters() {
; INSTRUMENT-LABEL: define void @no_counters() {
-; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @__profn_no_counters, i64 742261418966908927, i32 1, i32 0)
-; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @__profn_no_counters, i64 742261418966908927, i32 1, i32 0, ptr @bar)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.increment(ptr @no_counters, i64 742261418966908927, i32 1, i32 0)
+; INSTRUMENT-NEXT: call void @llvm.instrprof.callsite(ptr @no_counters, i64 742261418966908927, i32 1, i32 0, ptr @bar)
; INSTRUMENT-NEXT: call void @bar()
; INSTRUMENT-NEXT: ret void
;
diff --git a/llvm/test/Transforms/PGOProfile/ctx-prof-use-prelink.ll b/llvm/test/Transforms/PGOProfile/ctx-prof-use-prelink.ll
index cb8ab78dc0f414..7959e4d0760edb 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-prof-use-prelink.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-prof-use-prelink.ll
@@ -7,22 +7,19 @@
declare void @bar()
-;.
-; CHECK: @__profn_foo = private constant [3 x i8] c"foo"
-;.
define void @foo(i32 %a, ptr %fct) {
; CHECK-LABEL: define void @foo(
; CHECK-SAME: i32 [[A:%.*]], ptr [[FCT:%.*]]) local_unnamed_addr !guid [[META0:![0-9]+]] {
-; CHECK-NEXT: call void @llvm.instrprof.increment(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 0)
+; CHECK-NEXT: call void @llvm.instrprof.increment(ptr @foo, i64 728453322856651412, i32 2, i32 0)
; CHECK-NEXT: [[T:%.*]] = icmp eq i32 [[A]], 0
; CHECK-NEXT: br i1 [[T]], label %[[YES:.*]], label %[[NO:.*]]
; CHECK: [[YES]]:
-; CHECK-NEXT: call void @llvm.instrprof.increment(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 1)
-; CHECK-NEXT: call void @llvm.instrprof.callsite(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 0, ptr [[FCT]])
+; CHECK-NEXT: call void @llvm.instrprof.increment(ptr @foo, i64 728453322856651412, i32 2, i32 1)
+; CHECK-NEXT: call void @llvm.instrprof.callsite(ptr @foo, i64 728453322856651412, i32 2, i32 0, ptr [[FCT]])
; CHECK-NEXT: call void [[FCT]](i32 0)
; CHECK-NEXT: br label %[[EXIT:.*]]
; CHECK: [[NO]]:
-; CHECK-NEXT: call void @llvm.instrprof.callsite(ptr @__profn_foo, i64 728453322856651412, i32 2, i32 1, ptr @bar)
+; CHECK-NEXT: call void @llvm.instrprof.callsite(ptr @foo, i64 728453322856651412, i32 2, i32 1, ptr @bar)
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label %[[EXIT]]
; CHECK: [[EXIT]]:
>From c6e027acf5ce0ffd40fff0bc81537eb4708b2222 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 20 Aug 2024 21:32:23 -0700
Subject: [PATCH 2/2] [ctx_prof] Add support for ICP
---
llvm/include/llvm/Analysis/CtxProfAnalysis.h | 17 +-
llvm/include/llvm/IR/IntrinsicInst.h | 2 +
.../llvm/ProfileData/PGOCtxProfReader.h | 20 ++
.../Transforms/Utils/CallPromotionUtils.h | 4 +
llvm/lib/Analysis/CtxProfAnalysis.cpp | 79 +++++---
llvm/lib/IR/IntrinsicInst.cpp | 10 +
.../Transforms/Utils/CallPromotionUtils.cpp | 86 +++++++++
.../Utils/CallPromotionUtilsTest.cpp | 178 ++++++++++++++++++
8 files changed, 364 insertions(+), 32 deletions(-)
diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 0b4dd8ae3a0dc7..10aef6f6067b6f 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -73,6 +73,12 @@ class PGOContextualProfile {
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
}
+ using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
+ using Visitor = function_ref<void(PGOCtxProfContext &)>;
+
+ void update(Visitor, const Function *F = nullptr);
+ void visit(ConstVisitor, const Function *F = nullptr) const;
+
const CtxProfFlatProfile flatten() const;
bool invalidate(Module &, const PreservedAnalyses &PA,
@@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
class CtxProfAnalysisPrinterPass
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
- raw_ostream &OS;
-
public:
- explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
+ enum class PrintMode { Everything, JSON };
+ explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
+ PrintMode Mode = PrintMode::Everything)
+ : OS(OS), Mode(Mode) {}
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }
+
+private:
+ raw_ostream &OS;
+ const PrintMode Mode;
};
/// Assign a GUID to functions as metadata. GUID calculation takes linkage into
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index b45c89cadb0fde..5037e049aada57 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1535,6 +1535,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
ConstantInt *getNumCounters() const;
// The index of the counter that this instruction acts on.
ConstantInt *getIndex() const;
+ void setIndex(uint32_t Idx);
};
/// This represents the llvm.instrprof.cover intrinsic.
@@ -1585,6 +1586,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
Value *getCallee() const;
+ void setCallee(Value *);
};
/// This represents the llvm.instrprof.timestamp intrinsic.
diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
index 190deaeeacd085..23dcc376508b39 100644
--- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
+++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h
@@ -57,9 +57,23 @@ class PGOCtxProfContext final {
GlobalValue::GUID guid() const { return GUID; }
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
+ SmallVectorImpl<uint64_t> &counters() { return Counters; }
+
+ uint64_t getEntrycount() const { return Counters[0]; }
+
const CallsiteMapTy &callsites() const { return Callsites; }
CallsiteMapTy &callsites() { return Callsites; }
+ void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
+ auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy());
+ Iter->second.emplace(Other.guid(), std::move(Other));
+ }
+
+ void growCounters(uint32_t Size) {
+ if (Size >= Counters.size())
+ Counters.resize(Size);
+ }
+
bool hasCallsite(uint32_t I) const {
return Callsites.find(I) != Callsites.end();
}
@@ -68,6 +82,12 @@ class PGOCtxProfContext final {
assert(hasCallsite(I) && "Callsite not found");
return Callsites.find(I)->second;
}
+
+ CallTargetMapTy &callsite(uint32_t I) {
+ assert(hasCallsite(I) && "Callsite not found");
+ return Callsites.find(I)->second;
+ }
+
void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
};
diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index 385831f457038d..58af26f31417b0 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -14,6 +14,7 @@
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
+#include "llvm/Analysis/CtxProfAnalysis.h"
namespace llvm {
template <typename T> class ArrayRef;
class Constant;
@@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);
+CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
+ PGOContextualProfile &CtxProf);
+
/// This is similar to `promoteCallWithIfThenElse` except that the condition to
/// promote a virtual call is that \p VPtr is the same as any of \p
/// AddressPoints.
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 3fc1bc34afb97e..2cd3f2114397e5 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
return PreservedAnalyses::all();
}
- OS << "Function Info:\n";
- for (const auto &[Guid, FuncInfo] : C.FuncInfo)
- OS << Guid << " : " << FuncInfo.Name
- << ". MaxCounterID: " << FuncInfo.NextCounterIndex
- << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
+ if (Mode == PrintMode::Everything) {
+ OS << "Function Info:\n";
+ for (const auto &[Guid, FuncInfo] : C.FuncInfo)
+ OS << Guid << " : " << FuncInfo.Name
+ << ". MaxCounterID: " << FuncInfo.NextCounterIndex
+ << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
+ }
const auto JSONed = ::llvm::json::toJSON(C.profiles());
- OS << "\nCurrent Profile:\n";
+ if (Mode == PrintMode::Everything)
+ OS << "\nCurrent Profile:\n";
OS << formatv("{0:2}", JSONed);
+ if (Mode == PrintMode::JSON)
+ return PreservedAnalyses::all();
+
OS << "\n";
OS << "\nFlat Profile:\n";
auto Flat = C.flatten();
@@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
return nullptr;
}
-static void
-preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
- function_ref<void(const PGOCtxProfContext &)> Visitor) {
- std::function<void(const PGOCtxProfContext &)> Traverser =
- [&](const auto &Ctx) {
- Visitor(Ctx);
- for (const auto &[_, SubCtxSet] : Ctx.callsites())
- for (const auto &[__, Subctx] : SubCtxSet)
- Traverser(Subctx);
- };
- for (const auto &[_, P] : Profiles)
+template <class ProfilesTy, class ProfTy>
+static void preorderVisit(ProfilesTy &Profiles,
+ function_ref<void(ProfTy &)> Visitor,
+ GlobalValue::GUID Match = 0) {
+ std::function<void(ProfTy &)> Traverser = [&](auto &Ctx) {
+ if (!Match || Ctx.guid() == Match)
+ Visitor(Ctx);
+ for (auto &[_, SubCtxSet] : Ctx.callsites())
+ for (auto &[__, Subctx] : SubCtxSet)
+ Traverser(Subctx);
+ };
+ for (auto &[_, P] : Profiles)
Traverser(P);
}
+void PGOContextualProfile::update(Visitor V, const Function *F) {
+ GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
+ preorderVisit<PGOCtxProfContext::CallTargetMapTy, PGOCtxProfContext>(
+ *Profiles, V, G);
+}
+
+void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
+ GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U;
+ preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
+ const PGOCtxProfContext>(*Profiles, V, G);
+}
+
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
assert(Profiles.has_value());
CtxProfFlatProfile Flat;
- preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
- auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
- if (Ins) {
- llvm::append_range(It->second, Ctx.counters());
- return;
- }
- assert(It->second.size() == Ctx.counters().size() &&
- "All contexts corresponding to a function should have the exact "
- "same number of counters.");
- for (size_t I = 0, E = It->second.size(); I < E; ++I)
- It->second[I] += Ctx.counters()[I];
- });
+ preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
+ const PGOCtxProfContext>(
+ *Profiles, [&](const PGOCtxProfContext &Ctx) {
+ auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
+ if (Ins) {
+ llvm::append_range(It->second, Ctx.counters());
+ return;
+ }
+ assert(It->second.size() == Ctx.counters().size() &&
+ "All contexts corresponding to a function should have the exact "
+ "same number of counters.");
+ for (size_t I = 0, E = It->second.size(); I < E; ++I)
+ It->second[I] += Ctx.counters()[I];
+ });
return Flat;
}
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
index db3b0196f66fd6..0eadd0f980c15b 100644
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
}
+void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
+ assert(isa<InstrProfCntrInstBase>(this));
+ setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx));
+}
+
Value *InstrProfIncrementInst::getStep() const {
if (InstrProfIncrementInstStep::classof(this)) {
return const_cast<Value *>(getArgOperand(4));
@@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
return nullptr;
}
+void InstrProfCallsite::setCallee(Value *V) {
+ assert(isa<InstrProfCallsite>(this));
+ setArgOperand(4, V);
+}
+
std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
unsigned NumOperands = arg_size();
Metadata *MD = nullptr;
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 90dc727cde16d7..0ca7524c273daa 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -13,13 +13,16 @@
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;
@@ -572,6 +575,89 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}
+/// Promote the indirect call \p CB to a guarded direct call to \p Callee and
+/// keep the contextual profile \p CtxProf consistent with the new CFG.
+///
+/// The callsite is versioned into a "direct" block (calling \p Callee) and an
+/// "indirect" block (keeping the original call). Each of the 2 blocks gets a
+/// freshly allocated counter, and the direct call gets a freshly allocated
+/// callsite index. Every context of the caller is then updated: its counter
+/// vector grows by 2, the sub-context for \p Callee (if any) is moved under the
+/// new callsite index, and the 2 new counters are set as if the direct block
+/// had been taken "entry count of Callee" times and the indirect block the
+/// remaining number of times.
+///
+/// \returns the new direct call, or nullptr - leaving the IR and the profile
+/// untouched - if \p Callee is not known to \p CtxProf or if \p CB has no
+/// callsite instrumentation.
+CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
+                                          PGOContextualProfile &CtxProf) {
+  assert(CB.isIndirectCall());
+  if (!CtxProf.isFunctionKnown(Callee))
+    return nullptr;
+  auto &Caller = *CB.getParent()->getParent();
+  auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
+  if (!CSInstr)
+    return nullptr;
+  const auto CSIndex = CSInstr->getIndex()->getZExtValue();
+
+  CallBase &DirectCall = promoteCall(
+      versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
+  // Keep the original callsite instrumentation next to the (still) indirect
+  // call, and give the direct call its own instrumentation, under a new
+  // callsite index.
+  CSInstr->moveBefore(&CB);
+  const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
+  auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
+  NewCSInstr->setIndex(NewCSID);
+  NewCSInstr->setCallee(&Callee);
+  NewCSInstr->insertBefore(&DirectCall);
+  auto &DirectBB = *DirectCall.getParent();
+  auto &IndirectBB = *CB.getParent();
+
+  assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
+         "The ICP indirect BB is new, it shouldn't have instrumentation");
+  assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
+         "The ICP direct BB is new, it shouldn't have instrumentation");
+
+  // Make the 2 new BBs have counters.
+  const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
+  const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
+  const uint32_t NewCountersSize = IndirectID + 1;
+  // The entry block's instrumentation is used as a template for the new
+  // counters: only the counter index differs.
+  auto *EntryBBIns =
+      CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
+  assert(EntryBBIns &&
+         "The caller's entry BB is expected to have instrumentation");
+  auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
+  DirectBBIns->setIndex(DirectID);
+  DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());
+
+  auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
+  IndirectBBIns->setIndex(IndirectID);
+  IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());
+
+  const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);
+
+  auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
+    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
+    assert(NewCountersSize - 2 == Ctx.counters().size());
+    // Regardless of what happens next, all the contexts belonging to a
+    // function must have the same number of counters.
+    Ctx.growCounters(NewCountersSize);
+
+    // Maybe in this context, the indirect callsite wasn't observed at all.
+    // Both new BBs are then cold, which is what growing the counters gave us.
+    if (!Ctx.hasCallsite(CSIndex))
+      return;
+    auto &CSData = Ctx.callsite(CSIndex);
+
+    // Counters are 64-bit; accumulate in 64 bits to avoid truncation.
+    uint64_t TotalCount = 0;
+    for (const auto &[_, V] : CSData)
+      TotalCount += V.getEntrycount();
+
+    // Maybe we did notice the indirect callsite, but only to other targets. In
+    // that case the direct BB is cold, yet the indirect BB must still be
+    // attributed the whole TotalCount.
+    uint64_t DirectCount = 0;
+    if (auto It = CSData.find(CalleeGUID); It != CSData.end()) {
+      assert(CalleeGUID == It->second.guid());
+      DirectCount = It->second.getEntrycount();
+      // This particular indirect target needs to be moved to this caller under
+      // the newly-allocated callsite index.
+      assert(Ctx.callsites().count(NewCSID) == 0);
+      Ctx.ingestContext(NewCSID, std::move(It->second));
+      CSData.erase(CalleeGUID);
+    }
+
+    assert(TotalCount >= DirectCount);
+    const uint64_t IndirectCount = TotalCount - DirectCount;
+    // The ICP's effect is as-if the direct BB would have been taken DirectCount
+    // times, and the indirect BB, IndirectCount times
+    Ctx.counters()[DirectID] = DirectCount;
+    Ctx.counters()[IndirectID] = IndirectCount;
+  };
+  CtxProf.update(ProfileUpdater, &Caller);
+  return &DirectCall;
+}
+
+
CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 2d457eb3b678aa..aff603de2a2bd5 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
@@ -14,7 +15,12 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/ProfileData/PGOCtxProfReader.h"
+#include "llvm/Support/JSON.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -456,3 +462,175 @@ declare void @_ZN5Base35func3Ev(ptr)
// 1 call instruction from the entry block.
EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
}
+
+using namespace llvm::ctx_profile;
+
+class ContextManager final {
+ std::vector<std::unique_ptr<char[]>> Nodes;
+ std::map<GUID, const ContextNode *> Roots;
+
+public:
+ ContextNode *createNode(GUID Guid, uint32_t NrCounters, uint32_t NrCallsites,
+ ContextNode *Next = nullptr) {
+ auto AllocSize = ContextNode::getAllocSize(NrCounters, NrCallsites);
+ auto *Mem = Nodes.emplace_back(std::make_unique<char[]>(AllocSize)).get();
+ std::memset(Mem, 0, AllocSize);
+ auto *Ret = new (Mem) ContextNode(Guid, NrCounters, NrCallsites, Next);
+ return Ret;
+ }
+};
+
+TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) {
+ LLVMContext C;
+ std::unique_ptr<Module> M = parseIR(C,
+ R"IR(
+define i32 @testfunc1(ptr %d) !guid !0 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr %d)
+ %call = call i32 %d()
+ ret i32 %call
+}
+
+define i32 @f1() !guid !1 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ ret i32 2
+}
+
+define i32 @f2() !guid !2 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @f4)
+ %r = call i32 @f4()
+ ret i32 %r
+}
+
+define i32 @testfunc2(ptr %p) !guid !4 {
+ call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0)
+ call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @testfunc1)
+ %r = call i32 @testfunc1(ptr %p)
+ ret i32 %r
+}
+
+declare i32 @f3()
+
+define i32 @f4() !guid !3 {
+ ret i32 3
+}
+
+!0 = !{i64 1000}
+!1 = !{i64 1001}
+!2 = !{i64 1002}
+!3 = !{i64 1004}
+!4 = !{i64 1005}
+)IR");
+
+ // Synthesize a profile. The profile is nonsensical, but the goal is to check
+ // that new BBs are created with IDs and the right counter values.
+ ContextManager Mgr;
+ auto BuildTree = [&](const std::vector<uint32_t> &CalleeEntrycounts) {
+ auto *Entry = Mgr.createNode(1000, 1, 1);
+ // Set the entrycount to 1 so it's not 0. We don't care about it, really,
+ // for this test but we generally assume it's not 0.
+ Entry->counters()[0] = 1;
+ auto *F1 = Mgr.createNode(1001, 1, 0);
+ auto *F2 = Mgr.createNode(1002, 1, 1, F1);
+ auto *F3 = Mgr.createNode(1003, 1, 0, F2);
+ auto *F4 = Mgr.createNode(1004, 1, 0);
+
+ F1->counters()[0] = CalleeEntrycounts[0];
+ F2->counters()[0] = CalleeEntrycounts[1];
+ F3->counters()[0] = CalleeEntrycounts[2];
+ F4->counters()[0] = CalleeEntrycounts[3];
+ F2->subContexts()[0] = F4;
+ Entry->subContexts()[0] = F3; // which chains F2 and F1
+ return Entry;
+ };
+ // We'll be interested in f2. The entry counts for it are: 11 in the first
+ // context; and 102 in the second.
+ // The total number of times the indirect callsite is exercised is:
+ // 10+11+12 = 33 in the first case; and 101+102+103 = 306 in the
+ // second.
+ // This means that the direct/indirect call counters will be: 11/22 in the
+ // first case and 102/204 in the second. Meaning, the "Counters" for the
+ // GUID=1000 (caller) context will look like [1, 11, 22] and [1, 102, 204],
+ // respectively (the first "1" being the entrycount which we set to 1 above)
+ auto *Entry1 = BuildTree({10, 11, 12, 13});
+ auto *SubTree2 = BuildTree({101, 102, 103, 104});
+ auto *Entry2 = Mgr.createNode(1005, 1, 1);
+ Entry2->counters()[0] = 2;
+ Entry2->subContexts()[0] = SubTree2;
+
+ llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique*/ true);
+ {
+ std::error_code EC;
+ raw_fd_stream Out(ProfileFile.path(), EC);
+ ASSERT_FALSE(EC);
+ {
+ PGOCtxProfileWriter Writer(Out);
+ Writer.write(*Entry1);
+ Writer.write(*Entry2);
+ }
+ }
+
+ ModuleAnalysisManager MAM;
+ MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
+ MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
+ auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
+ auto *Caller = M->getFunction("testfunc1");
+ ASSERT_TRUE(!!Caller);
+ auto *Callee = M->getFunction("f2");
+ ASSERT_TRUE(!!Callee);
+ auto *IndirectCS = [&]() -> CallBase * {
+ for (auto &BB : *Caller)
+ for (auto &I : BB)
+ if (auto *CB = dyn_cast<CallBase>(&I); CB && CB->isIndirectCall())
+ return CB;
+ return nullptr;
+ }();
+ ASSERT_TRUE(!!IndirectCS);
+ promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf);
+
+ std::string Str;
+ raw_string_ostream OS(Str);
+ CtxProfAnalysisPrinterPass Printer(
+ OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
+ Printer.run(*M, MAM);
+ const char *Expected = R"(
+ [
+ {
+ "Guid": 1000,
+ "Counters": [1, 11, 22],
+ "Callsites": [
+ [{ "Guid": 1001,
+ "Counters": [10]},
+ { "Guid": 1003,
+ "Counters": [12]
+ }],
+ [{ "Guid": 1002,
+ "Counters": [11],
+ "Callsites": [
+ [{ "Guid": 1004,
+ "Counters": [13] }]]}]]
+ },
+ {
+ "Guid": 1005,
+ "Counters": [2],
+ "Callsites": [
+ [{ "Guid": 1000,
+ "Counters": [1, 102, 204],
+ "Callsites": [
+ [{ "Guid": 1001,
+ "Counters": [101]},
+ { "Guid": 1003,
+ "Counters": [103]}],
+ [{ "Guid": 1002,
+ "Counters": [102],
+ "Callsites": [
+ [{ "Guid": 1004,
+ "Counters": [104]}]]}]]}]]}
+])";
+ auto ExpectedJSON = json::parse(Expected);
+ ASSERT_TRUE(!!ExpectedJSON);
+ auto ProducedJSON = json::parse(Str);
+ ASSERT_TRUE(!!ProducedJSON);
+ EXPECT_EQ(*ProducedJSON, *ExpectedJSON);
+}
More information about the llvm-commits
mailing list