[llvm-branch-commits] [llvm] [ctx_prof] Add analysis utility to fetch ID of a callsite (PR #104491)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 15 12:47:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Mircea Trofin (mtrofin)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/104491.diff


4 Files Affected:

- (modified) llvm/include/llvm/Analysis/CtxProfAnalysis.h (+4) 
- (modified) llvm/lib/Analysis/CtxProfAnalysis.cpp (+7) 
- (modified) llvm/unittests/Analysis/CMakeLists.txt (+1) 
- (added) llvm/unittests/Analysis/CtxProfAnalysisTest.cpp (+145) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
index f0e2aeb0f92f74..483a6e557126e0 100644
--- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h
+++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h
@@ -11,6 +11,8 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/ProfileData/PGOCtxProfReader.h"
 
@@ -84,6 +86,8 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {
   using Result = PGOContextualProfile;
 
   PGOContextualProfile run(Module &M, ModuleAnalysisManager &MAM);
+
+  static InstrProfCallsite *getCallsiteInstrumentation(CallBase &CB);
 };
 
 class CtxProfAnalysisPrinterPass
diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp
index 7b4666b29a1936..3e4c1f8c3df3c0 100644
--- a/llvm/lib/Analysis/CtxProfAnalysis.cpp
+++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp
@@ -183,3 +183,10 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
   OS << "\n";
   return PreservedAnalyses::all();
 }
+
+InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) {
+  for (auto *I = CB.getPrevNode(); I; I = I->getPrevNode()) // walk backwards
+    if (auto *IPC = dyn_cast<InstrProfCallsite>(I))
+      return IPC;
+  return nullptr; // no InstrProfCallsite found before CB
+}
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 3cba630867a83b..958d8f9a72fd66 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -22,6 +22,7 @@ set(ANALYSIS_TEST_SOURCES
   CFGTest.cpp
   CGSCCPassManagerTest.cpp
   ConstraintSystemTest.cpp
+  CtxProfAnalysisTest.cpp
   DDGTest.cpp
   DomTreeUpdaterTest.cpp
   DXILResourceTest.cpp
diff --git a/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp
new file mode 100644
index 00000000000000..8f57b828bf5a15
--- /dev/null
+++ b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp
@@ -0,0 +1,145 @@
+//===--- CtxProfAnalysisTest.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CtxProfAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+class CtxProfAnalysisTest : public testing::Test {
+  static constexpr const char *IR = R"IR(
+declare void @bar()
+
+define private void @foo(i32 %a, ptr %fct) #0 !guid !0 {
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+yes:
+  call void %fct(i32 %a)
+  br label %exit
+no:
+  call void @bar()
+  br label %exit
+exit:
+  ret void
+}
+
+define void @an_entrypoint(i32 %a) {
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  call void @foo(i32 1, ptr null)
+  ret void
+no:
+  ret void
+}
+
+define void @another_entrypoint_no_callees(i32 %a) {
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  ret void
+no:
+  ret void
+}
+
+attributes #0 = { noinline }
+!0 = !{ i64 11872291593386833696 }
+)IR";
+
+protected:
+  LLVMContext C; // context the test module is parsed into
+  PassBuilder PB;
+  ModuleAnalysisManager MAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  LoopAnalysisManager LAM;
+  std::unique_ptr<Module> M; // set by SetUp(); tests assert it is non-null
+
+  void SetUp() override {
+    SMDiagnostic Diag;
+    M = parseAssemblyString(IR, Diag, C);
+    if (M == nullptr) // surface the parser diagnostic on failure
+      Diag.print("CtxProfAnalysisTest", errs());
+  }
+
+public:
+  CtxProfAnalysisTest() {
+    PB.registerModuleAnalyses(MAM);
+    PB.registerCGSCCAnalyses(CGAM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerLoopAnalyses(LAM);
+    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); // link the four managers
+  }
+};
+
+TEST_F(CtxProfAnalysisTest, GetCallsiteIDTest) {
+  ASSERT_TRUE(!!M);
+  ModulePassManager MPM;
+  MPM.addPass(PGOInstrumentationGen(/*IsCS=*/false, /*IsCtxProf=*/true));
+  EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved()); // pass should report changes
+  auto *F = M->getFunction("foo");
+  ASSERT_NE(F, nullptr);
+  CallBase *IndCall = nullptr;
+  CallBase *DirCall = nullptr;
+  for (auto &BB : *F)
+    for (auto &I : BB)
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        if (CB->isIndirectCall()) {
+          EXPECT_EQ(IndCall, nullptr); // no earlier indirect call seen
+          IndCall = CB;
+        } else if (!CB->getCalledFunction()->isIntrinsic()) {
+          EXPECT_EQ(DirCall, nullptr); // no earlier direct call seen
+          DirCall = CB;
+        }
+      }
+  ASSERT_NE(IndCall, nullptr); // dereferenced below, so stop here if missing
+  ASSERT_NE(DirCall, nullptr); // dereferenced below, so stop here if missing
+  auto *IndIns = CtxProfAnalysis::getCallsiteInstrumentation(*IndCall);
+  ASSERT_NE(IndIns, nullptr);
+  EXPECT_EQ(IndIns->getIndex()->getZExtValue(), 0U);
+  auto *DirIns = CtxProfAnalysis::getCallsiteInstrumentation(*DirCall);
+  ASSERT_NE(DirIns, nullptr);
+  EXPECT_EQ(DirIns->getIndex()->getZExtValue(), 1U);
+}
+
+TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) {
+  ASSERT_TRUE(!!M);
+  auto *F = M->getFunction("foo");
+  ASSERT_NE(F, nullptr);
+  CallBase *FirstCall = nullptr;
+  for (auto &BB : *F)
+    for (auto &I : BB)
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        // Keep only the first match; later blocks must not overwrite it.
+        if (!FirstCall &&
+            (CB->isIndirectCall() || !CB->getCalledFunction()->isIntrinsic()))
+          FirstCall = CB;
+      }
+  ASSERT_NE(FirstCall, nullptr); // dereferenced below, so stop here if missing
+  auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*FirstCall);
+  ASSERT_EQ(Ins, nullptr); // this test never runs the instrumentation pass
+}
+
+} // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/104491


More information about the llvm-branch-commits mailing list