[llvm] [ctx_prof] Simple ICP criteria during module inliner (PR #109881)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 16:08:03 PDT 2024


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/109881

None

>From 2059c96567c7b766db1c1eb3d9d6b32e64c269ab Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Mon, 23 Sep 2024 15:20:18 -0700
Subject: [PATCH] [ctx_prof] Simple ICP criteria during module inliner

---
 llvm/include/llvm/Analysis/CtxProfAnalysis.h  | 12 ++++
 llvm/lib/Analysis/CtxProfAnalysis.cpp         | 23 ++++++++
 llvm/lib/Transforms/IPO/ModuleInliner.cpp     | 43 ++++++++++++---
 .../Transforms/Utils/CallPromotionUtils.cpp   | 29 +++++-----
 .../Analysis/CtxProfAnalysis/flatten-icp.ll   | 55 +++++++++++++++++++
 5 files changed, 139 insertions(+), 23 deletions(-)
 create mode 100644 llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll

diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index 0a5beb92fcbcc0..ffcecac079243c 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_ANALYSIS_CTXPROFANALYSIS_H
 #define LLVM_ANALYSIS_CTXPROFANALYSIS_H
 
+#include "llvm/ADT/SetVector.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/IntrinsicInst.h"
@@ -63,6 +64,12 @@ class PGOContextualProfile {
     return getDefinedFunctionGUID(F) != 0;
   }
 
+  StringRef getFunctionName(GlobalValue::GUID GUID) const {
+    // Unknown GUIDs yield the empty name.
+    const auto Pos = FuncInfo.find(GUID);
+    return Pos == FuncInfo.end() ? StringRef("") : StringRef(Pos->second.Name);
+  }
+
   uint32_t getNumCounters(const Function &F) const {
     assert(isFunctionKnown(F));
     return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex;
@@ -120,6 +127,11 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
 
   /// Get the step instrumentation associated with a `select`
   static InstrProfIncrementInstStep *getSelectInstrumentation(SelectInst &SI);
+
+  // FIXME: refactor to an advisor model, and separate
+  static void collectIndirectCallPromotionList(
+      CallBase &IC, Result &Profile,
+      SetVector<std::pair<CallBase *, Function *>> &Candidates);
 };
 
 class CtxProfAnalysisPrinterPass
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 7517011395a7d6..873277cf51d6b9 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -21,6 +21,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/JSON.h"
 #include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Transforms/Utils/CallPromotionUtils.h"
 
 #define DEBUG_TYPE "ctx_prof"
 
@@ -309,3 +310,25 @@ const CtxProfFlatProfile PGOContextualProfile::flatten() const {
       });
   return Flat;
 }
+
+void CtxProfAnalysis::collectIndirectCallPromotionList(
+    CallBase &IC, Result &Profile,
+    SetVector<std::pair<CallBase *, Function *>> &Candidates) {
+  const auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(IC);
+  if (!CSInstr)
+    return;
+  Module &Mod = *IC.getParent()->getModule();
+  const uint32_t CSIndex = CSInstr->getIndex()->getZExtValue();
+  // Visitor: record alwaysinline module functions profiled at this callsite.
+  auto CollectFromCtx = [&](const PGOCtxProfContext &Ctx) {
+    const auto &TargetsIt = Ctx.callsites().find(CSIndex);
+    if (TargetsIt == Ctx.callsites().end())
+      return;
+    for (const auto &[Guid, _] : TargetsIt->second)
+      if (auto Name = Profile.getFunctionName(Guid); !Name.empty())
+        if (auto *Target = Mod.getFunction(Name))
+          if (Target->hasFnAttribute(Attribute::AlwaysInline))
+            Candidates.insert({&IC, Target});
+  };
+  Profile.visit(CollectFromCtx, IC.getCaller());
+}
diff --git a/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
index 542c319b880747..cf2b34b5a6367b 100644
--- a/llvm/lib/Transforms/IPO/ModuleInliner.cpp
+++ b/llvm/lib/Transforms/IPO/ModuleInliner.cpp
@@ -49,6 +49,13 @@ using namespace llvm;
 STATISTIC(NumInlined, "Number of functions inlined");
 STATISTIC(NumDeleted, "Number of functions deleted because all callers found");
 
+cl::opt<bool> CtxProfPromoteAlwaysInline(
+    "ctx-prof-promote-alwaysinline", cl::init(false), cl::Hidden,
+    cl::desc("If using a contextual profile in this module, and an indirect "
+             "call target is marked as alwaysinline, perform indirect call "
+             "promotion for that target. If multiple targets for an indirect "
+             "call site fit this description, they are all promoted."));
+
 /// Return true if the specified inline history ID
 /// indicates an inline history that includes the specified function.
 static bool inlineHistoryIncludes(
@@ -145,10 +152,11 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
   assert(Calls != nullptr && "Expected an initialized InlineOrder");
 
   // Populate the initial list of calls in this module.
+  SetVector<std::pair<CallBase *, Function *>> ICPCandidates;
   for (Function &F : M) {
     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
-    for (Instruction &I : instructions(F))
-      if (auto *CB = dyn_cast<CallBase>(&I))
+    for (Instruction &I : instructions(F)) {
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
         if (Function *Callee = CB->getCalledFunction()) {
           if (!Callee->isDeclaration())
             Calls->push({CB, -1});
@@ -163,7 +171,17 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
                      << setIsVerbose();
             });
           }
+        } else if (CtxProfPromoteAlwaysInline && CtxProf &&
+                   CB->isIndirectCall()) {
+          CtxProfAnalysis::collectIndirectCallPromotionList(*CB, CtxProf,
+                                                            ICPCandidates);
         }
+      }
+    }
+  }
+  for (auto &[CB, Target] : ICPCandidates) {
+    if (auto *DirectCB = promoteCallWithIfThenElse(*CB, *Target, CtxProf))
+      Calls->push({DirectCB, -1});
   }
   if (Calls->empty())
     return PreservedAnalyses::all();
@@ -242,13 +260,22 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
           // iteration because the next iteration may not happen and we may
           // miss inlining it.
           // FIXME: enable for ctxprof.
-          if (!CtxProf)
-            if (tryPromoteCall(*ICB))
-              NewCallee = ICB->getCalledFunction();
+          if (CtxProfPromoteAlwaysInline && CtxProf) {
+            SetVector<std::pair<CallBase *, Function *>> Candidates;
+            CtxProfAnalysis::collectIndirectCallPromotionList(*ICB, CtxProf,
+                                                              Candidates);
+            for (auto &[IC, Target] : Candidates) { // promote, then enqueue
+              assert(!Target->isDeclaration() &&
+                     "CtxProf promotes calls to defined targets only");
+              // The collected candidate is still the indirect call; only the
+              // promoted direct call may be queued for inlining.
+              if (auto *DC = promoteCallWithIfThenElse(*IC, *Target, CtxProf))
+                Calls->push({DC, NewHistoryID});
+            }
+          } else if (tryPromoteCall(*ICB)) {
+            NewCallee = ICB->getCalledFunction();
+            if (NewCallee && !NewCallee->isDeclaration())
+              Calls->push({ICB, NewHistoryID});
+          }
         }
-        if (NewCallee)
-          if (!NewCallee->isDeclaration())
-            Calls->push({ICB, NewHistoryID});
       }
     }
 
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 5f872c352429c1..f216be26cfa6fc 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -627,30 +627,29 @@ CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
     if (!Ctx.hasCallsite(CSIndex))
       return;
     auto &CSData = Ctx.callsite(CSIndex);
-    auto It = CSData.find(CalleeGUID);
 
-    // Maybe we did notice the indirect callsite, but to other targets.
-    if (It == CSData.end())
-      return;
-
-    assert(CalleeGUID == It->second.guid());
-
-    uint32_t DirectCount = It->second.getEntrycount();
-    uint32_t TotalCount = 0;
+    uint64_t TotalCount = 0;
     for (const auto &[_, V] : CSData)
       TotalCount += V.getEntrycount();
+    // Maybe we did notice the indirect callsite, but to other targets.
+    uint64_t DirectCount = 0;
+    if (auto It = CSData.find(CalleeGUID); It != CSData.end()) {
+      assert(CalleeGUID == It->second.guid());
+      DirectCount = It->second.getEntrycount();
+      // This particular indirect target needs to be moved to this caller under
+      // the newly-allocated callsite index.
+      assert(Ctx.callsites().count(NewCSID) == 0);
+      Ctx.ingestContext(NewCSID, std::move(It->second));
+      CSData.erase(CalleeGUID);
+    }
+
     assert(TotalCount >= DirectCount);
-    uint32_t IndirectCount = TotalCount - DirectCount;
+    uint64_t IndirectCount = TotalCount - DirectCount;
     // The ICP's effect is as-if the direct BB would have been taken DirectCount
     // times, and the indirect BB, IndirectCount times
     Ctx.counters()[DirectID] = DirectCount;
     Ctx.counters()[IndirectID] = IndirectCount;
 
-    // This particular indirect target needs to be moved to this caller under
-    // the newly-allocated callsite index.
-    assert(Ctx.callsites().count(NewCSID) == 0);
-    Ctx.ingestContext(NewCSID, std::move(It->second));
-    CSData.erase(CalleeGUID);
   };
   CtxProf.update(ProfileUpdater, &Caller);
   return &DirectCall;
diff --git a/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll b/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll
new file mode 100644
index 00000000000000..f7529432d4251d
--- /dev/null
+++ b/llvm/test/Analysis/CtxProfAnalysis/flatten-icp.ll
@@ -0,0 +1,55 @@
+; RUN: split-file %s %t
+; RUN: llvm-ctxprof-util fromJSON --input %t/profile.json --output %t/profile.ctxprofdata
+;
+; In the given profile, in one of the contexts where the indirect call is taken,
+; the target we're trying to ICP - GUID:1000 - doesn't appear at all. That should
+; contribute to the count of the "indirect call BB".
+; RUN: opt %t/test.ll -S -passes='require<ctx-prof-analysis>,module-inline,ctx-prof-flatten' -use-ctx-profile=%t/profile.ctxprofdata -ctx-prof-promote-alwaysinline | FileCheck %s
+
+; CHECK-LABEL: define i32 @caller(ptr %c)
+; CHECK-NEXT:     %[[CND:[0-9]+]] = icmp eq ptr %c, @one
+; CHECK-NEXT:     br i1 %[[CND]], label %{{.*}}, label %{{.*}}, !prof ![[BW:[0-9]+]]
+
+; CHECK: ![[BW]] = !{!"branch_weights", i32 10, i32 10}
+
+;--- test.ll
+declare i32 @external(i32 %x)
+define i32 @one() #0 !guid !0 {
+  call void @llvm.instrprof.increment(ptr @one, i64 123, i32 1, i32 0)
+  call void @llvm.instrprof.callsite(ptr @one, i64 123, i32 1, i32 0, ptr @external)
+  %ret = call i32 @external(i32 1)
+  ret i32 %ret
+}
+
+define i32 @caller(ptr %c) #1 !guid !1 {
+  call void @llvm.instrprof.increment(ptr @caller, i64 567, i32 1, i32 0)
+  call void @llvm.instrprof.callsite(ptr @caller, i64 567, i32 1, i32 0, ptr %c)
+  %ret = call i32 %c()
+  ret i32 %ret
+}
+
+define i32 @root(ptr %c) !guid !2 {
+  call void @llvm.instrprof.increment(ptr @root, i64 432, i32 1, i32 0)
+  call void @llvm.instrprof.callsite(ptr @root, i64 432, i32 2, i32 0, ptr @caller)
+  %a = call i32 @caller(ptr %c)
+  call void @llvm.instrprof.callsite(ptr @root, i64 432, i32 2, i32 1, ptr @caller)
+  %b = call i32 @caller(ptr %c)
+  %ret = add i32 %a, %b
+  ret i32 %ret
+
+}
+
+attributes #0 = { alwaysinline }
+attributes #1 = { noinline }
+!0 = !{i64 1000}
+!1 = !{i64 3000}
+!2 = !{i64 4000}
+
+;--- profile.json
+[ {
+  "Guid": 4000, "Counters":[10], "Callsites": [
+    [{"Guid":3000, "Counters":[10], "Callsites":[[{"Guid":1000, "Counters":[10]}]]}],
+    [{"Guid":3000, "Counters":[10], "Callsites":[[{"Guid":9000, "Counters":[10]}]]}]
+  ]
+}
+]
\ No newline at end of file



More information about the llvm-commits mailing list