[llvm] [PGO] Instrument modules with at least a single function definition (PR #93421)

Pavel Samolysov via llvm-commits llvm-commits at lists.llvm.org
Sun May 26 11:00:29 PDT 2024


https://github.com/samolisov updated https://github.com/llvm/llvm-project/pull/93421

>From 7bf25a14a982e0841a0eaec5b3632d0af27c1cfc Mon Sep 17 00:00:00 2001
From: Pavel Samolysov <samolisov at gmail.com>
Date: Wed, 22 May 2024 22:10:55 +0300
Subject: [PATCH] [PGO] Instrument modules with at least a single function
 definition

It only makes sense to instrument modules that contain at least one
function definition. When a module contains only function declarations
and, optionally, global variable definitions (for example, a module for
the regular-LTO phase for thin-LTO when LTOUnit splitting is enabled),
it makes no sense to instrument such a module or, moreover, to
introduce the `__llvm_profile_raw_version` variable into it.

The patch also introduces a gmock-based unit-test infrastructure for
PGO instrumentation and adds some test cases that check whether the
instrumentation has taken place. The testing infrastructure for analysis
modules was borrowed from the `LoopPassManagerTest` unit test and
simplified a bit to handle module analysis passes only. In effect, we
test whether the result of a trivial analysis pass was invalidated by
the `PGOInstrumentationGen` pass: we exploit the fact that the pass
invalidates all analysis results after a module has been instrumented.
---
 .../Instrumentation/PGOInstrumentation.cpp    |  25 ++-
 .../PGOProfile/declarations_only.ll           |   7 +
 .../PGOProfile/global_variables_only.ll       |   5 +
 .../PGOProfile/no_global_variables.ll         |  15 ++
 llvm/unittests/Transforms/CMakeLists.txt      |   1 +
 .../Transforms/Instrumentation/CMakeLists.txt |  16 ++
 .../PGOInstrumentationTest.cpp                | 206 ++++++++++++++++++
 7 files changed, 268 insertions(+), 7 deletions(-)
 create mode 100644 llvm/test/Transforms/PGOProfile/declarations_only.ll
 create mode 100644 llvm/test/Transforms/PGOProfile/global_variables_only.ll
 create mode 100644 llvm/test/Transforms/PGOProfile/no_global_variables.ll
 create mode 100644 llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
 create mode 100644 llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp

diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 2269c2e0fffae..ff94fc98d0117 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -405,9 +405,14 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
     ProfileVersion |= VARIANT_MASK_BYTE_COVERAGE;
   if (PGOTemporalInstrumentation)
     ProfileVersion |= VARIANT_MASK_TEMPORAL_PROF;
-  auto IRLevelVersionVariable = new GlobalVariable(
+  assert(!M.global_empty() &&
+         "If a module was instrumented, there must be defined global variables "
+         "at least for the counters.");
+  auto *InsertionPoint = &*M.global_begin();
+  auto *IRLevelVersionVariable = new GlobalVariable(
       M, IntTy64, true, GlobalValue::WeakAnyLinkage,
-      Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), VarName);
+      Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), VarName,
+      InsertionPoint);
   IRLevelVersionVariable->setVisibility(GlobalValue::HiddenVisibility);
   Triple TT(M.getTargetTriple());
   if (TT.supportsCOMDAT()) {
@@ -1829,11 +1834,6 @@ static bool InstrumentAllFunctions(
     Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
     function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
-  // For the context-sensitve instrumentation, we should have a separated pass
-  // (before LTO/ThinLTO linking) to create these variables.
-  if (!IsCS && !PGOCtxProfLoweringPass::isContextualIRPGOEnabled())
-    createIRLevelProfileFlagVar(M, /*IsCS=*/false);
-
   Triple TT(M.getTargetTriple());
   LLVMContext &Ctx = M.getContext();
   if (!TT.isOSBinFormatELF() && EnableVTableValueProfiling)
@@ -1845,6 +1845,7 @@ static bool InstrumentAllFunctions(
   std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers;
   collectComdatMembers(M, ComdatMembers);
 
+  bool AnythingInstrumented = false;
   for (auto &F : M) {
     if (skipPGOGen(F))
       continue;
@@ -1852,7 +1853,17 @@ static bool InstrumentAllFunctions(
     auto *BPI = LookupBPI(F);
     auto *BFI = LookupBFI(F);
     instrumentOneFunc(F, &M, TLI, BPI, BFI, ComdatMembers, IsCS);
+    AnythingInstrumented = true;
   }
+
+  if (!AnythingInstrumented)
+    return false;
+
+  // For the context-sensitive instrumentation, we should have a separate pass
+  // (before LTO/ThinLTO linking) to create these variables.
+  if (!IsCS && !PGOCtxProfLoweringPass::isContextualIRPGOEnabled())
+    createIRLevelProfileFlagVar(M, /*IsCS=*/false);
+
   return true;
 }
 
diff --git a/llvm/test/Transforms/PGOProfile/declarations_only.ll b/llvm/test/Transforms/PGOProfile/declarations_only.ll
new file mode 100644
index 0000000000000..6d9c86cb8f7a7
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/declarations_only.ll
@@ -0,0 +1,7 @@
+; RUN: opt < %s -passes=pgo-instr-gen -S | FileCheck %s --implicit-check-not='__llvm_profile_raw_version'
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+declare i32 @test_1(i32 %i);
+
+declare i32 @test_2(i32 %i);
diff --git a/llvm/test/Transforms/PGOProfile/global_variables_only.ll b/llvm/test/Transforms/PGOProfile/global_variables_only.ll
new file mode 100644
index 0000000000000..c4b298a27bce5
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/global_variables_only.ll
@@ -0,0 +1,5 @@
+; RUN: opt < %s -passes=pgo-instr-gen -S | FileCheck %s --implicit-check-not='__llvm_profile_raw_version'
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@var = internal unnamed_addr global [35 x ptr] zeroinitializer, align 16
diff --git a/llvm/test/Transforms/PGOProfile/no_global_variables.ll b/llvm/test/Transforms/PGOProfile/no_global_variables.ll
new file mode 100644
index 0000000000000..0faa2f35274c9
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/no_global_variables.ll
@@ -0,0 +1,15 @@
+; RUN: opt < %s -passes=pgo-instr-gen -S | FileCheck %s --check-prefix=GEN-COMDAT
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+; GEN-COMDAT: $__llvm_profile_raw_version = comdat any
+; GEN-COMDAT: @__llvm_profile_raw_version = hidden constant i64 {{[0-9]+}}, comdat
+
+define i32 @test_1(i32 %i) {
+entry:
+  ret i32 0
+}
+
+define i32 @test_2(i32 %i) {
+entry:
+  ret i32 1
+}
diff --git a/llvm/unittests/Transforms/CMakeLists.txt b/llvm/unittests/Transforms/CMakeLists.txt
index 98c821acde3a5..320cdf5674149 100644
--- a/llvm/unittests/Transforms/CMakeLists.txt
+++ b/llvm/unittests/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Coroutines)
+add_subdirectory(Instrumentation)
 add_subdirectory(IPO)
 add_subdirectory(Scalar)
 add_subdirectory(Utils)
diff --git a/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
new file mode 100644
index 0000000000000..1f249b0049d06
--- /dev/null
+++ b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
@@ -0,0 +1,16 @@
+set(LLVM_LINK_COMPONENTS
+  Analysis
+  AsmParser
+  Core
+  Instrumentation
+  Passes
+  Support
+)
+
+add_llvm_unittest(InstrumentationTests
+  PGOInstrumentationTest.cpp
+  )
+
+target_link_libraries(InstrumentationTests PRIVATE LLVMTestingSupport)
+
+set_property(TARGET InstrumentationTests PROPERTY FOLDER "Tests/UnitTests/TransformTests")
diff --git a/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp b/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp
new file mode 100644
index 0000000000000..33eb3895d9db8
--- /dev/null
+++ b/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp
@@ -0,0 +1,206 @@
+//===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/ProfileData/InstrProf.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace {
+
+using namespace llvm;
+
+using testing::_;
+using ::testing::DoDefault;
+using ::testing::Invoke;
+using ::testing::IsNull;
+using ::testing::NotNull;
+using ::testing::Ref;
+using ::testing::Return;
+using ::testing::Sequence;
+using ::testing::Test;
+
+template <typename Derived> class MockAnalysisHandleBase {
+public:
+  class Analysis : public AnalysisInfoMixin<Analysis> {
+  public:
+    class Result {
+    public:
+      // Forward invalidation events to the mock handle.
+      bool invalidate(Module &M, const PreservedAnalyses &PA,
+                      ModuleAnalysisManager::Invalidator &Inv) {
+        return Handle->invalidate(M, PA, Inv);
+      }
+
+    private:
+      explicit Result(Derived *Handle) : Handle(Handle) {}
+
+      friend MockAnalysisHandleBase;
+      Derived *Handle;
+    };
+
+    Result run(Module &M, ModuleAnalysisManager &AM) {
+      return Handle->run(M, AM);
+    }
+
+  private:
+    friend AnalysisInfoMixin<Analysis>;
+    friend MockAnalysisHandleBase;
+    static inline AnalysisKey Key;
+
+    Derived *Handle;
+
+    explicit Analysis(Derived *Handle) : Handle(Handle) {}
+  };
+
+  Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
+
+  typename Analysis::Result getResult() {
+    return typename Analysis::Result(static_cast<Derived *>(this));
+  }
+
+protected:
+  void setDefaults() {
+    ON_CALL(static_cast<Derived &>(*this), run(_, _))
+        .WillByDefault(Return(this->getResult()));
+    ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
+        .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
+                                 ModuleAnalysisManager::Invalidator &) {
+          auto PAC = PA.template getChecker<Analysis>();
+          return !PAC.preserved() &&
+                 !PAC.template preservedSet<AllAnalysesOn<Module>>();
+        }));
+  }
+
+private:
+  friend Derived;
+  MockAnalysisHandleBase() = default;
+};
+
+class MockModuleAnalysisHandle
+    : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
+public:
+  MockModuleAnalysisHandle() { setDefaults(); }
+
+  MOCK_METHOD(typename Analysis::Result, run,
+              (Module &, ModuleAnalysisManager &));
+
+  MOCK_METHOD(bool, invalidate,
+              (Module &, const PreservedAnalyses &,
+               ModuleAnalysisManager::Invalidator &));
+};
+
+struct PGOInstrumentationGenTest : public Test {
+  LLVMContext Ctx;
+  ModulePassManager MPM;
+  PassBuilder PB;
+  MockModuleAnalysisHandle MMAHandle;
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+
+  PGOInstrumentationGenTest() {
+    MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
+    PB.registerModuleAnalyses(MAM);
+    PB.registerCGSCCAnalyses(CGAM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerLoopAnalyses(LAM);
+    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+    MPM.addPass(
+        RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
+    MPM.addPass(PGOInstrumentationGen());
+  }
+
+  void parseAssembly(const StringRef IR) {
+    SMDiagnostic Error;
+    M = parseAssemblyString(IR, Error, Context);
+    std::string ErrMsg;
+    raw_string_ostream OS(ErrMsg);
+    Error.print("", OS);
+
+    // A failure here means that the test itself is buggy.
+    if (!M)
+      report_fatal_error(OS.str().c_str());
+  }
+};
+
+TEST_F(PGOInstrumentationGenTest, NotInstrumentedDeclarationsOnly) {
+  const StringRef NotInstrumentableCode = R"(
+    declare i32 @f(i32);
+  )";
+
+  parseAssembly(NotInstrumentableCode);
+
+  ASSERT_THAT(M, NotNull());
+
+  EXPECT_CALL(MMAHandle, run(Ref(*M), _)).WillOnce(DoDefault());
+  EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _)).Times(0);
+
+  MPM.run(*M, MAM);
+
+  const auto *IRInstrVar =
+      M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+  EXPECT_THAT(IRInstrVar, IsNull());
+}
+
+TEST_F(PGOInstrumentationGenTest, NotInstrumentedGlobals) {
+  const StringRef NotInstrumentableCode = R"(
+    @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
+    declare i32 @f(i32);
+  )";
+
+  parseAssembly(NotInstrumentableCode);
+
+  ASSERT_THAT(M, NotNull());
+
+  EXPECT_CALL(MMAHandle, run(Ref(*M), _)).WillOnce(DoDefault());
+  EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _)).Times(0);
+
+  MPM.run(*M, MAM);
+
+  const auto *IRInstrVar =
+      M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+  EXPECT_THAT(IRInstrVar, IsNull());
+}
+
+TEST_F(PGOInstrumentationGenTest, Instrumented) {
+  const StringRef InstrumentableCode = R"(
+    define i32 @f(i32 %n) {
+    entry:
+      ret i32 0
+    }
+  )";
+
+  parseAssembly(InstrumentableCode);
+
+  ASSERT_THAT(M, NotNull());
+
+  Sequence PassSequence;
+  EXPECT_CALL(MMAHandle, run(Ref(*M), _))
+      .InSequence(PassSequence)
+      .WillOnce(DoDefault());
+  EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
+      .InSequence(PassSequence)
+      .WillOnce(DoDefault());
+
+  MPM.run(*M, MAM);
+
+  const auto *IRInstrVar =
+      M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+  EXPECT_THAT(IRInstrVar, NotNull());
+  EXPECT_FALSE(IRInstrVar->isDeclaration());
+}
+
+} // end anonymous namespace



More information about the llvm-commits mailing list