[llvm-branch-commits] [llvm] [ctx_prof] Handle `select` (PR #109185)

Mircea Trofin via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Sep 18 12:49:19 PDT 2024


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/109185

None

>From 09642a4889da1d0e10f54b17b84e32dae5c8557e Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 17 Sep 2024 22:00:42 -0700
Subject: [PATCH] [ctx_prof] Handle `select`

---
 llvm/include/llvm/Analysis/CtxProfAnalysis.h  |  3 +
 llvm/lib/Analysis/CtxProfAnalysis.cpp         |  9 +++
 .../Instrumentation/PGOCtxProfFlattening.cpp  | 45 ++++++++++-
 llvm/lib/Transforms/Utils/InlineFunction.cpp  | 10 ++-
 .../Analysis/CtxProfAnalysis/handle-select.ll | 76 +++++++++++++++++++
 5 files changed, 140 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/Analysis/CtxProfAnalysis/handle-select.ll

diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index b3e64b26ee543c..0a5beb92fcbcc0 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -117,6 +117,9 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {

   /// Get the instruction instrumenting a BB, or nullptr if not present.
   static InstrProfIncrementInst *getBBInstrumentation(BasicBlock &BB);
+
+  /// Get the step instrumentation for a `select`, or nullptr if not present.
+  static InstrProfIncrementInstStep *getSelectInstrumentation(SelectInst &SI);
 };

 class CtxProfAnalysisPrinterPass
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 3df72983862d98..7517011395a7d6 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -254,6 +254,21 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
   return nullptr;
 }

+// The step instrumenting a `select` is emitted before it, in the same BB.
+// Walk back from `SI` to the nearest step, but stop at an earlier `select`:
+// a step above that `select` instruments it, not `SI`.
+InstrProfIncrementInstStep *
+CtxProfAnalysis::getSelectInstrumentation(SelectInst &SI) {
+  Instruction *Prev = &SI;
+  while ((Prev = Prev->getPrevNode())) {
+    if (auto *Step = dyn_cast<InstrProfIncrementInstStep>(Prev))
+      return Step;
+    if (isa<SelectInst>(Prev))
+      return nullptr;
+  }
+  return nullptr;
+}
+
 template <class ProfilesTy, class ProfTy>
 static void preorderVisit(ProfilesTy &Profiles,
                           function_ref<void(ProfTy &)> Visitor,
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
index 91f950e2ba4c3e..30bb251364fdef 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfFlattening.cpp
@@ -154,6 +154,12 @@ class ProfileAnnotator final {

     bool hasCount() const { return Count.has_value(); }

+    /// The BB's count. Only valid once the count is known (hasCount()).
+    uint64_t getCount() const {
+      assert(hasCount() && "Expected the BB count to have been computed");
+      return *Count;
+    }
+
     bool trySetSingleUnknownInEdgeCount() {
       if (UnknownCountInEdges == 1) {
         setSingleUnknownEdgeCount(InEdges);
@@ -266,6 +272,24 @@ class ProfileAnnotator final {
     return HitExit;
   }

+  /// Check that every `select` in a BB with a non-zero count has profile
+  /// metadata. Selects with a vector condition are skipped: PGO's select
+  /// instrumentation (SelectInstVisitor) does not handle them, so they never
+  /// get a step and thus never get a profile here.
+  bool allNonColdSelectsHaveProfile() const {
+    for (const auto &BB : F) {
+      if (getBBInfo(BB).getCount() == 0)
+        continue;
+      for (const auto &I : BB) {
+        const auto *SI = dyn_cast<SelectInst>(&I);
+        if (SI && !SI->getCondition()->getType()->isVectorTy() &&
+            !SI->getMetadata(LLVMContext::MD_prof))
+          return false;
+      }
+    }
+    return true;
+  }
+
 public:
   ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters,
                    InstrProfSummaryBuilder &PB)
@@ -324,12 +348,40 @@ class ProfileAnnotator final {
     PB.addEntryCount(Counters[0]);

     for (auto &BB : F) {
+      const auto &BBInfo = getBBInfo(BB);
+      if (BBInfo.getCount() > 0) {
+        for (auto &I : BB) {
+          auto *SI = dyn_cast<SelectInst>(&I);
+          // Selects with a vector condition are not instrumented (see
+          // SelectInstVisitor in PGOInstrumentation), so they have no step.
+          if (!SI || SI->getCondition()->getType()->isVectorTy())
+            continue;
+          auto *Step = CtxProfAnalysis::getSelectInstrumentation(*SI);
+          if (!Step)
+            continue;
+          auto Index = Step->getIndex()->getZExtValue();
+          assert(Index < Counters.size() &&
+                 "The index of the step instruction must be inside the "
+                 "counters vector by construction - tripping this assertion "
+                 "indicates a bug in how the contextual profile is managed by "
+                 "IPO transforms");
+          // The step increments by the zero-extended condition, so its
+          // counter is how often the condition was true, while the BB count
+          // is how often the select ran. Clamp the difference at 0 in case
+          // the two are inconsistent.
+          auto TotalCount = BBInfo.getCount();
+          auto TrueCount = Counters[Index];
+          auto FalseCount =
+              (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
+          setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
+                          std::max(TrueCount, FalseCount));
+        }
+      }
       if (succ_size(&BB) < 2)
         continue;
       auto *Term = BB.getTerminator();
       SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0);
       uint64_t MaxCount = 0;
-      const auto &BBInfo = getBBInfo(BB);
       for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
            ++SuccIdx) {
         uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
@@ -343,12 +395,15 @@ class ProfileAnnotator final {
         setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount);
     }
     assert(allCountersAreAssigned() &&
-           "Expected all counters have been assigned.");
+           "[ctx-prof] Expected all counters have been assigned.");
     assert(allTakenPathsExit() &&
            "[ctx-prof] Encountered a BB with more than one successor, where "
            "all outgoing edges have a 0 count. This occurs in non-exiting "
            "functions (message pumps, usually) which are not supported in the "
            "contextual profiling case");
+    assert(allNonColdSelectsHaveProfile() &&
+           "[ctx-prof] All non-cold select instructions were expected to have "
+           "a profile.");
   }
 };

diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 2e05fa80464b8d..257850f7e6d2af 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -2211,7 +2211,22 @@ remapIndices(Function &Caller, BasicBlock *StartBB,
     }
     for (auto &I : llvm::make_early_inc_range(*BB)) {
       if (auto *Inc = dyn_cast<InstrProfIncrementInst>(&I)) {
-        if (Inc != BBID) {
+        if (isa<InstrProfIncrementInstStep>(Inc)) {
+          // A step instruments the `select` right after it. A constant step
+          // means the `select` was simplified away while inlining (e.g. its
+          // condition became constant), so there is nothing left to profile:
+          // drop the step. Otherwise keep it and hand it to
+          // RewriteInstrIfNeeded (presumably remaps it for the caller).
+          if (isa<Constant>(Inc->getStep())) {
+            assert(!isa_and_nonnull<SelectInst>(Inc->getNextNode()) &&
+                   "A constant step should no longer precede a select");
+            Inc->eraseFromParent();
+          } else {
+            assert(isa_and_nonnull<SelectInst>(Inc->getNextNode()) &&
+                   "A non-constant step should immediately precede a select");
+            RewriteInstrIfNeeded(*Inc);
+          }
+        } else if (Inc != BBID) {
           // If we're here it means that the BB had more than 1 IDs, presumably
           // some coming from the callee. We "made up our mind" to keep the
           // first one (which may or may not have been originally the caller's).
diff --git a/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll b/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll
new file mode 100644
index 00000000000000..e740466a03f3e9
--- /dev/null
+++ b/llvm/test/Analysis/CtxProfAnalysis/handle-select.ll
@@ -0,0 +1,76 @@
+; Check that we handle `step` instrumentations. These adorn `select`s.
+; We don't want to confuse the `step` with normal increments, the latter of which
+; we use for BB ID-ing: we want to keep the `step`s after inlining, except if
+; the `select` is elided.
+;
+; RUN: split-file %s %t
+; RUN: llvm-ctxprof-util fromJSON --input=%t/profile.json --output=%t/profile.ctxprofdata
+;
+; RUN: opt -passes=ctx-instr-gen %t/example.ll -use-ctx-profile=%t/profile.ctxprofdata -S -o - | FileCheck %s --check-prefix=INSTR
+; RUN: opt -passes=ctx-instr-gen,module-inline %t/example.ll -use-ctx-profile=%t/profile.ctxprofdata -S -o - | FileCheck %s --check-prefix=POST-INL
+; RUN: opt -passes=ctx-instr-gen,module-inline,ctx-prof-flatten %t/example.ll -use-ctx-profile=%t/profile.ctxprofdata -S -o - | FileCheck %s --check-prefix=FLATTEN
+
+; INSTR-LABEL: yes:
+; INSTR-NEXT:   call void @llvm.instrprof.increment(ptr @foo, i64 [[#]], i32 2, i32 1)
+; INSTR-NEXT:   call void @llvm.instrprof.callsite(ptr @foo, i64 [[#]], i32 2, i32 0, ptr @bar)
+
+; INSTR-LABEL: no:
+; INSTR-NEXT:   call void @llvm.instrprof.callsite(ptr @foo, i64 [[#]], i32 2, i32 1, ptr @bar)
+
+; INSTR-LABEL: define i32 @bar
+; INSTR-NEXT:   call void @llvm.instrprof.increment(ptr @bar, i64 [[#]], i32 2, i32 0)
+; INSTR-NEXT:   %inc =
+; INSTR:        %test = icmp eq i32 %t, 0
+; INSTR-NEXT:   %1 = zext i1 %test to i64
+; INSTR-NEXT:   call void @llvm.instrprof.increment.step(ptr @bar, i64 [[#]], i32 2, i32 1, i64 %1)
+; INSTR-NEXT:   %res = select
+
+; POST-INL-LABEL: yes:
+; POST-INL-NEXT:   call void @llvm.instrprof.increment
+; POST-INL:        call void @llvm.instrprof.increment.step
+; POST-INL-NEXT:   %res.i = select
+
+; POST-INL-LABEL: no:
+; POST-INL-NEXT:   call void @llvm.instrprof.increment
+; POST-INL-NEXT:   br label
+
+; POST-INL-LABEL: exit:
+; POST-INL-NEXT:   %res = phi i32 [ %res.i, %yes ], [ 1, %no ]
+
+; FLATTEN-LABEL: yes:
+; FLATTEN:          %res.i = select i1 %test.i, i32 %inc.i, i32 %dec.i, !prof ![[SELPROF:[0-9]+]]
+; FLATTEN-LABEL: no:
+;
+; See the profile: in the "yes" case, @bar's step counter is 3, and the total
+; entry count of that BB is 4, so the select's weights are 3 (true), 1 (false).
+; FLATTEN: ![[SELPROF]] = !{!"branch_weights", i32 3, i32 1}
+
+;--- example.ll
+define i32 @foo(i32 %t) !guid !0 {
+  %test = icmp slt i32 %t, 0
+  br i1 %test, label %yes, label %no
+yes:
+  %res1 = call i32 @bar(i32 %t) alwaysinline
+  br label %exit
+no:
+  ; this will result in eliding the select in @bar, when inlined.
+  %res2 = call i32 @bar(i32 0) alwaysinline
+  br label %exit
+exit:
+  %res = phi i32 [%res1, %yes], [%res2, %no]
+  ret i32 %res
+}
+
+define i32 @bar(i32 %t) !guid !1 {
+  %inc = add i32 %t, 1
+  %dec = sub i32 %t, 1
+  %test = icmp eq i32 %t, 0
+  %res = select i1 %test, i32 %inc, i32 %dec
+  ret i32 %res
+}
+
+!0 = !{i64 1234}
+!1 = !{i64 5678}
+
+;--- profile.json
+[{"Guid":1234, "Counters":[10, 4], "Callsites":[[{"Guid": 5678, "Counters":[4,3]}],[{"Guid": 5678, "Counters":[6,6]}]]}]



More information about the llvm-branch-commits mailing list