[llvm] 8c33b33 - [PGO] Add a unit test for the PGOInstrumentationGen pass (#93636)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 19:51:32 PDT 2024


Author: Pavel Samolysov
Date: 2024-05-30T05:51:29+03:00
New Revision: 8c33b3380b8044824f6adb48cc8d2076aecae566

URL: https://github.com/llvm/llvm-project/commit/8c33b3380b8044824f6adb48cc8d2076aecae566
DIFF: https://github.com/llvm/llvm-project/commit/8c33b3380b8044824f6adb48cc8d2076aecae566.diff

LOG: [PGO] Add a unit test for the PGOInstrumentationGen pass (#93636)

The patch introduces the gmock-based unittest infrastructure for PGO
Instrumentation and adds some test cases to check whether the
instrumentation has taken place. The testing infrastructure for analysis
modules was borrowed from the LoopPassManagerTest unittest and
simplified a bit to handle module analysis passes only. Actually, we are
testing whether the result of a trivial analysis pass was invalidated by
the PGOInstrumentationGen one: we exploit the fact that the pass invalidates
all the analysis results after a module was instrumented.

NFC.

Added: 
    llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
    llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp

Modified: 
    llvm/unittests/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/unittests/Transforms/CMakeLists.txt b/llvm/unittests/Transforms/CMakeLists.txt
index 98c821acde3a5..320cdf5674149 100644
--- a/llvm/unittests/Transforms/CMakeLists.txt
+++ b/llvm/unittests/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Coroutines)
+add_subdirectory(Instrumentation)
 add_subdirectory(IPO)
 add_subdirectory(Scalar)
 add_subdirectory(Utils)

diff  --git a/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
new file mode 100644
index 0000000000000..1f249b0049d06
--- /dev/null
+++ b/llvm/unittests/Transforms/Instrumentation/CMakeLists.txt
@@ -0,0 +1,16 @@
+set(LLVM_LINK_COMPONENTS
+  Analysis
+  AsmParser
+  Core
+  Instrumentation
+  Passes
+  Support
+)
+
+add_llvm_unittest(InstrumentationTests
+  PGOInstrumentationTest.cpp
+  )
+
+target_link_libraries(InstrumentationTests PRIVATE LLVMTestingSupport)
+
+set_property(TARGET InstrumentationTests PROPERTY FOLDER "Tests/UnitTests/TransformTests")

diff  --git a/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp b/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp
new file mode 100644
index 0000000000000..02c2df2a138b0
--- /dev/null
+++ b/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp
@@ -0,0 +1,192 @@
+//===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/ProfileData/InstrProf.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include <tuple>
+
+namespace {
+
+using namespace llvm;
+
+using testing::_;
+using ::testing::DoDefault;
+using ::testing::Invoke;
+using ::testing::IsNull;
+using ::testing::NotNull;
+using ::testing::Ref;
+using ::testing::Return;
+using ::testing::Sequence;
+using ::testing::Test;
+using ::testing::TestParamInfo;
+using ::testing::Values;
+using ::testing::WithParamInterface;
+
+template <typename Derived> class MockAnalysisHandleBase {
+public:
+  class Analysis : public AnalysisInfoMixin<Analysis> {
+  public:
+    class Result {
+    public:
+      // Forward invalidation events to the mock handle.
+      bool invalidate(Module &M, const PreservedAnalyses &PA,
+                      ModuleAnalysisManager::Invalidator &Inv) {
+        return Handle->invalidate(M, PA, Inv);
+      }
+
+    private:
+      explicit Result(Derived *Handle) : Handle(Handle) {}
+
+      friend MockAnalysisHandleBase;
+      Derived *Handle;
+    };
+
+    Result run(Module &M, ModuleAnalysisManager &AM) {
+      return Handle->run(M, AM);
+    }
+
+  private:
+    friend AnalysisInfoMixin<Analysis>;
+    friend MockAnalysisHandleBase;
+    static inline AnalysisKey Key;
+
+    Derived *Handle;
+
+    explicit Analysis(Derived *Handle) : Handle(Handle) {}
+  };
+
+  Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
+
+  typename Analysis::Result getResult() {
+    return typename Analysis::Result(static_cast<Derived *>(this));
+  }
+
+protected:
+  void setDefaults() {
+    ON_CALL(static_cast<Derived &>(*this), run(_, _))
+        .WillByDefault(Return(this->getResult()));
+    ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
+        .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
+                                 ModuleAnalysisManager::Invalidator &) {
+          auto PAC = PA.template getChecker<Analysis>();
+          return !PAC.preserved() &&
+                 !PAC.template preservedSet<AllAnalysesOn<Module>>();
+        }));
+  }
+
+private:
+  friend Derived;
+  MockAnalysisHandleBase() = default;
+};
+
+class MockModuleAnalysisHandle
+    : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
+public:
+  MockModuleAnalysisHandle() { setDefaults(); }
+
+  MOCK_METHOD(typename Analysis::Result, run,
+              (Module &, ModuleAnalysisManager &));
+
+  MOCK_METHOD(bool, invalidate,
+              (Module &, const PreservedAnalyses &,
+               ModuleAnalysisManager::Invalidator &));
+};
+
+struct PGOInstrumentationGenTest
+    : public Test,
+      WithParamInterface<std::tuple<StringRef, StringRef>> {
+  ModulePassManager MPM;
+  PassBuilder PB;
+  MockModuleAnalysisHandle MMAHandle;
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+  // Declared before M so that the context outlives the module parsed into it.
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+  // Pipeline: require the mock analysis first, then run PGOInstrumentationGen.
+  PGOInstrumentationGenTest() {
+    MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
+    PB.registerModuleAnalyses(MAM);
+    PB.registerCGSCCAnalyses(CGAM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerLoopAnalyses(LAM);
+    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+    MPM.addPass(
+        RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
+    MPM.addPass(PGOInstrumentationGen());
+  }
+
+  void parseAssembly(const StringRef IR) {
+    SMDiagnostic Error;
+    M = parseAssemblyString(IR, Error, Context);
+    std::string ErrMsg;
+    raw_string_ostream OS(ErrMsg);
+    Error.print("", OS);
+
+    // A failure here means that the test itself is buggy.
+    if (!M)
+      report_fatal_error(OS.str().c_str());
+  }
+};
+
+static constexpr StringRef CodeWithFuncDefs = R"(
+  define i32 @f(i32 %n) {
+  entry:
+    ret i32 0
+  })";
+
+static constexpr StringRef CodeWithFuncDecls = R"(
+  declare i32 @f(i32);
+)";
+
+static constexpr StringRef CodeWithGlobals = R"(
+  @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
+  declare i32 @f(i32);
+)";
+
+INSTANTIATE_TEST_SUITE_P(
+    PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest,
+    Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"),
+           std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"),
+           std::make_tuple(CodeWithGlobals, "instrument_globals")),
+    [](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) {
+      return std::get<1>(Info.param).str();
+    });
+
+TEST_P(PGOInstrumentationGenTest, Instrumented) {
+  const StringRef Code = std::get<0>(GetParam());
+  parseAssembly(Code);
+
+  ASSERT_THAT(M, NotNull());
+  // The mock analysis must be run once and then invalidated once, in order.
+  Sequence PassSequence;
+  EXPECT_CALL(MMAHandle, run(Ref(*M), _))
+      .InSequence(PassSequence)
+      .WillOnce(DoDefault());
+  EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
+      .InSequence(PassSequence)
+      .WillOnce(DoDefault());
+
+  MPM.run(*M, MAM);
+  // Fatal assertion below: the pointer is dereferenced by the final check.
+  const auto *IRInstrVar =
+      M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
+  ASSERT_THAT(IRInstrVar, NotNull());
+  EXPECT_FALSE(IRInstrVar->isDeclaration());
+}
+
+} // end anonymous namespace


        


More information about the llvm-commits mailing list