[llvm] [PGO] Instrument modules with at least a single function definition (PR #93421)

Teresa Johnson via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 09:40:44 PDT 2024


================
@@ -0,0 +1,206 @@
+//===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/ProfileData/InstrProf.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace {
+
+// Bring LLVM names into this file's anonymous namespace.
+using namespace llvm;
+
+// gMock helpers used by the mocks and tests below. Each name is written with
+// a leading `::` so that lookup of `testing` always starts at the global
+// namespace and cannot be affected by the using-directive above.
+using ::testing::_;
+using ::testing::DoDefault;
+using ::testing::Invoke;
+using ::testing::IsNull;
+using ::testing::NotNull;
+using ::testing::Ref;
+using ::testing::Return;
+using ::testing::Sequence;
+using ::testing::Test;
+
+/// CRTP base for a gMock-backed module analysis handle.
+///
+/// \tparam Derived the concrete handle type. It must provide `run` and
+/// `invalidate` methods (mocked in the subclass); the nested Analysis and
+/// Result types forward their calls to those methods.
+template <typename Derived> class MockAnalysisHandleBase {
+public:
+  /// Analysis type whose `run` is forwarded to the owning handle.
+  class Analysis : public AnalysisInfoMixin<Analysis> {
+  public:
+    /// Analysis result that routes invalidation queries to the owning handle.
+    class Result {
+    public:
+      /// Asks the handle whether this result has to be invalidated.
+      bool invalidate(Module &M, const PreservedAnalyses &PA,
+                      ModuleAnalysisManager::Invalidator &Inv) {
+        return Handle->invalidate(M, PA, Inv);
+      }
+
+    private:
+      friend MockAnalysisHandleBase;
+
+      // Private: results are only created by the enclosing handle base.
+      explicit Result(Derived *Owner) : Handle(Owner) {}
+
+      Derived *Handle;
+    };
+
+    /// Delegates the analysis computation to the owning handle.
+    Result run(Module &M, ModuleAnalysisManager &AM) {
+      return Handle->run(M, AM);
+    }
+
+  private:
+    friend AnalysisInfoMixin<Analysis>;
+    friend MockAnalysisHandleBase;
+
+    // Private: analyses are only created by the enclosing handle base.
+    explicit Analysis(Derived *Owner) : Handle(Owner) {}
+
+    static inline AnalysisKey Key;
+    Derived *Handle;
+  };
+
+  /// Returns an Analysis bound to this handle.
+  Analysis getAnalysis() {
+    Derived *Self = static_cast<Derived *>(this);
+    return Analysis(Self);
+  }
+
+  /// Returns a Result bound to this handle.
+  typename Analysis::Result getResult() {
+    Derived *Self = static_cast<Derived *>(this);
+    return typename Analysis::Result(Self);
+  }
+
+protected:
+  /// Installs the default actions on the derived mock:
+  ///  - `run` returns a Result bound to this handle;
+  ///  - `invalidate` returns true unless \p PA preserves either this
+  ///    analysis or the AllAnalysesOn<Module> set.
+  void setDefaults() {
+    Derived &Self = static_cast<Derived &>(*this);
+    ON_CALL(Self, run(_, _)).WillByDefault(Return(this->getResult()));
+    ON_CALL(Self, invalidate(_, _, _))
+        .WillByDefault(Invoke([](Module &, const PreservedAnalyses &PA,
+                                 ModuleAnalysisManager::Invalidator &) {
+          auto Checker = PA.template getChecker<Analysis>();
+          const bool Kept =
+              Checker.preserved() ||
+              Checker.template preservedSet<AllAnalysesOn<Module>>();
+          return !Kept;
+        }));
+  }
+
+private:
+  friend Derived;
+  // Private with `friend Derived`: only the CRTP subclass can construct this.
+  MockAnalysisHandleBase() = default;
+};
+
+/// Concrete mock handle for a module analysis.
+///
+/// `run` and `invalidate` are gMock methods; the constructor installs the
+/// default actions through setDefaults().
+class MockModuleAnalysisHandle
+    : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
+public:
+  MockModuleAnalysisHandle() { setDefaults(); }
+
+  // Mocked computation of the analysis result.
+  MOCK_METHOD(Analysis::Result, run, (Module &, ModuleAnalysisManager &));
+
+  // Mocked invalidation query.
+  MOCK_METHOD(bool, invalidate,
+              (Module &, const PreservedAnalyses &,
+               ModuleAnalysisManager::Invalidator &));
+};
+
+/// Test fixture that runs PGOInstrumentationGen over a module parsed from
+/// textual IR.
+///
+/// The module pipeline first requires the mock module analysis and then runs
+/// PGOInstrumentationGen, so tests can set expectations on the mock's `run`
+/// and `invalidate` calls.
+struct PGOInstrumentationGenTest : public Test {
+  // NOTE(review): Ctx is not referenced by any code visible in this hunk;
+  // parseAssembly() uses Context. Confirm whether this member is needed.
+  LLVMContext Ctx;
+  ModulePassManager MPM;
+  PassBuilder PB;
+  MockModuleAnalysisHandle MMAHandle;
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+
+  /// Registers the mock analysis and the PassBuilder analyses with the
+  /// analysis managers, then builds the two-pass module pipeline.
+  PGOInstrumentationGenTest() {
+    // The mock analysis is registered before the PassBuilder analyses.
+    MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
+    PB.registerModuleAnalyses(MAM);
+    PB.registerCGSCCAnalyses(CGAM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerLoopAnalyses(LAM);
+    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+    MPM.addPass(
+        RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
+    MPM.addPass(PGOInstrumentationGen());
+  }
+
+  /// Parses \p IR into M using Context.
+  ///
+  /// On a parse failure the diagnostic is rendered and passed to
+  /// report_fatal_error().
+  void parseAssembly(const StringRef IR) {
+    SMDiagnostic Error;
+    M = parseAssemblyString(IR, Error, Context);
+    if (M)
+      return;
+
+    // A failure here means that the test itself is buggy. The diagnostic is
+    // rendered only on this path; nothing reads it after a successful parse.
+    std::string ErrMsg;
+    raw_string_ostream OS(ErrMsg);
+    Error.print("", OS);
+    report_fatal_error(OS.str().c_str());
+  }
+};
+
+// A module that contains only a function declaration: the mock analysis is
+// run once, is never asked about invalidation, and the raw-version variable
+// is absent after the pipeline has run.
+TEST_F(PGOInstrumentationGenTest, NotInstrumentedDeclarationsOnly) {
+  const StringRef DeclarationOnlyIR = R"(
+    declare i32 @f(i32);
+  )";
+
+  parseAssembly(DeclarationOnlyIR);
+  ASSERT_THAT(M, NotNull());
+
+  Module &Mod = *M;
+  EXPECT_CALL(MMAHandle, run(Ref(Mod), _)).WillOnce(DoDefault());
+  EXPECT_CALL(MMAHandle, invalidate(Ref(Mod), _, _)).Times(0);
+
+  MPM.run(Mod, MAM);
+
+  const auto *VersionVar =
+      Mod.getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+  EXPECT_THAT(VersionVar, IsNull());
+}
+
+// A module with a global table referencing a declared (not defined)
+// function: the mock analysis is run once, is never asked about
+// invalidation, and the raw-version variable is absent after the pipeline
+// has run.
+TEST_F(PGOInstrumentationGenTest, NotInstrumentedGlobals) {
+  const StringRef GlobalsOnlyIR = R"(
+    @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
+    declare i32 @f(i32);
+  )";
+
+  parseAssembly(GlobalsOnlyIR);
+  ASSERT_THAT(M, NotNull());
+
+  Module &Mod = *M;
+  EXPECT_CALL(MMAHandle, run(Ref(Mod), _)).WillOnce(DoDefault());
+  EXPECT_CALL(MMAHandle, invalidate(Ref(Mod), _, _)).Times(0);
+
+  MPM.run(Mod, MAM);
+
+  const auto *VersionVar =
+      Mod.getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+  EXPECT_THAT(VersionVar, IsNull());
+}
+
+TEST_F(PGOInstrumentationGenTest, Instrumented) {
----------------
teresajohnson wrote:

Thanks for adding the unit testing. I suggest committing the infrastructure and this test (which I think is unrelated to this PR?) separately, and then adding the new unit tests specific to this PR here.

https://github.com/llvm/llvm-project/pull/93421


More information about the llvm-commits mailing list