[llvm] [PGO] Instrument modules with at least a single function definition (PR #93421)
Teresa Johnson via llvm-commits
llvm-commits at lists.llvm.org
Tue May 28 09:40:44 PDT 2024
================
@@ -0,0 +1,206 @@
+//===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/ProfileData/InstrProf.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace {
+
+using namespace llvm;
+
+using testing::_;
+using ::testing::DoDefault;
+using ::testing::Invoke;
+using ::testing::IsNull;
+using ::testing::NotNull;
+using ::testing::Ref;
+using ::testing::Return;
+using ::testing::Sequence;
+using ::testing::Test;
+
+template <typename Derived> class MockAnalysisHandleBase {
+public:
+ class Analysis : public AnalysisInfoMixin<Analysis> {
+ public:
+ class Result {
+ public:
+ // Forward invalidation events to the mock handle.
+ bool invalidate(Module &M, const PreservedAnalyses &PA,
+ ModuleAnalysisManager::Invalidator &Inv) {
+ return Handle->invalidate(M, PA, Inv);
+ }
+
+ private:
+ explicit Result(Derived *Handle) : Handle(Handle) {}
+
+ friend MockAnalysisHandleBase;
+ Derived *Handle;
+ };
+
+ Result run(Module &M, ModuleAnalysisManager &AM) {
+ return Handle->run(M, AM);
+ }
+
+ private:
+ friend AnalysisInfoMixin<Analysis>;
+ friend MockAnalysisHandleBase;
+ static inline AnalysisKey Key;
+
+ Derived *Handle;
+
+ explicit Analysis(Derived *Handle) : Handle(Handle) {}
+ };
+
+ Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
+
+ typename Analysis::Result getResult() {
+ return typename Analysis::Result(static_cast<Derived *>(this));
+ }
+
+protected:
+ void setDefaults() {
+ ON_CALL(static_cast<Derived &>(*this), run(_, _))
+ .WillByDefault(Return(this->getResult()));
+ ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
+ .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
+ ModuleAnalysisManager::Invalidator &) {
+ auto PAC = PA.template getChecker<Analysis>();
+ return !PAC.preserved() &&
+ !PAC.template preservedSet<AllAnalysesOn<Module>>();
+ }));
+ }
+
+private:
+ friend Derived;
+ MockAnalysisHandleBase() = default;
+};
+
+class MockModuleAnalysisHandle
+ : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
+public:
+ MockModuleAnalysisHandle() { setDefaults(); }
+
+ MOCK_METHOD(typename Analysis::Result, run,
+ (Module &, ModuleAnalysisManager &));
+
+ MOCK_METHOD(bool, invalidate,
+ (Module &, const PreservedAnalyses &,
+ ModuleAnalysisManager::Invalidator &));
+};
+
+struct PGOInstrumentationGenTest : public Test {
+ LLVMContext Ctx;
+ ModulePassManager MPM;
+ PassBuilder PB;
+ MockModuleAnalysisHandle MMAHandle;
+ LoopAnalysisManager LAM;
+ FunctionAnalysisManager FAM;
+ CGSCCAnalysisManager CGAM;
+ ModuleAnalysisManager MAM;
+ LLVMContext Context;
+ std::unique_ptr<Module> M;
+
+ PGOInstrumentationGenTest() {
+ MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
+ PB.registerModuleAnalyses(MAM);
+ PB.registerCGSCCAnalyses(CGAM);
+ PB.registerFunctionAnalyses(FAM);
+ PB.registerLoopAnalyses(LAM);
+ PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+ MPM.addPass(
+ RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
+ MPM.addPass(PGOInstrumentationGen());
+ }
+
+ void parseAssembly(const StringRef IR) {
+ SMDiagnostic Error;
+ M = parseAssemblyString(IR, Error, Context);
+ std::string ErrMsg;
+ raw_string_ostream OS(ErrMsg);
+ Error.print("", OS);
+
+ // A failure here means that the test itself is buggy.
+ if (!M)
+ report_fatal_error(OS.str().c_str());
+ }
+};
+
+TEST_F(PGOInstrumentationGenTest, NotInstrumentedDeclarationsOnly) {
+ const StringRef NotInstrumentableCode = R"(
+ declare i32 @f(i32);
+ )";
+
+ parseAssembly(NotInstrumentableCode);
+
+ ASSERT_THAT(M, NotNull());
+
+ EXPECT_CALL(MMAHandle, run(Ref(*M), _)).WillOnce(DoDefault());
+ EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _)).Times(0);
+
+ MPM.run(*M, MAM);
+
+ const auto *IRInstrVar =
+ M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+ EXPECT_THAT(IRInstrVar, IsNull());
+}
+
+TEST_F(PGOInstrumentationGenTest, NotInstrumentedGlobals) {
+ const StringRef NotInstrumentableCode = R"(
+ @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
+ declare i32 @f(i32);
+ )";
+
+ parseAssembly(NotInstrumentableCode);
+
+ ASSERT_THAT(M, NotNull());
+
+ EXPECT_CALL(MMAHandle, run(Ref(*M), _)).WillOnce(DoDefault());
+ EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _)).Times(0);
+
+ MPM.run(*M, MAM);
+
+ const auto *IRInstrVar =
+ M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+ EXPECT_THAT(IRInstrVar, IsNull());
+}
+
+TEST_F(PGOInstrumentationGenTest, Instrumented) {
----------------
teresajohnson wrote:
Thanks for adding the unit testing. I suggest committing the infrastructure and this test (which I think is unrelated to this PR?) separately, and add in the new unit tests specific to this PR here.
https://github.com/llvm/llvm-project/pull/93421
More information about the llvm-commits
mailing list