[llvm] [LTO] Add a hook to customize the optimization pipeline (PR #71268)
Igor Kudrin via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 22 12:59:09 PST 2023
https://github.com/igorkudrin updated https://github.com/llvm/llvm-project/pull/71268
>From c63c4792f08c6f9dac63ea44cdc110491ada1a2b Mon Sep 17 00:00:00 2001
From: Igor Kudrin <ikudrin at accesssoftek.com>
Date: Fri, 3 Nov 2023 20:26:13 -0700
Subject: [PATCH] [LTO] Enable adding custom pass instrumentation callbacks
The new lto::Config::PassInstrumentationHook allows the calling code to
register pass instrumentation callbacks for the LTO optimization pipeline.
In particular, a custom pass filter can be added to skip certain passes
of the default pipeline.
---
llvm/include/llvm/LTO/Config.h | 5 ++
llvm/lib/LTO/LTOBackend.cpp | 2 +
llvm/unittests/CMakeLists.txt | 1 +
llvm/unittests/LTO/CMakeLists.txt | 8 ++
llvm/unittests/LTO/LTOTest.cpp | 118 ++++++++++++++++++++++++++++++
5 files changed, 134 insertions(+)
create mode 100644 llvm/unittests/LTO/CMakeLists.txt
create mode 100644 llvm/unittests/LTO/LTOTest.cpp
diff --git a/llvm/include/llvm/LTO/Config.h b/llvm/include/llvm/LTO/Config.h
index 6fb55f1cf1686a5..912dca1752dd80f 100644
--- a/llvm/include/llvm/LTO/Config.h
+++ b/llvm/include/llvm/LTO/Config.h
@@ -257,6 +257,11 @@ struct Config {
const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols)>;
CombinedIndexHookFn CombinedIndexHook;
+ /// This hook is called when the optimization pipeline is being built.
+ using PassInstrumentationHookFn =
+ std::function<void(PassInstrumentationCallbacks &)>;
+ PassInstrumentationHookFn PassInstrumentationHook;
+
/// This is a convenience function that configures this Config object to write
/// temporary files named after the given OutputFileName for each of the LTO
/// phases to disk. A client can use this function to implement -save-temps.
diff --git a/llvm/lib/LTO/LTOBackend.cpp b/llvm/lib/LTO/LTOBackend.cpp
index ccc4276e36dacf0..671bc936467be5e 100644
--- a/llvm/lib/LTO/LTOBackend.cpp
+++ b/llvm/lib/LTO/LTOBackend.cpp
@@ -265,6 +265,8 @@ static void runNewPMPasses(const Config &Conf, Module &Mod, TargetMachine *TM,
ModuleAnalysisManager MAM;
PassInstrumentationCallbacks PIC;
+ if (Conf.PassInstrumentationHook)
+ Conf.PassInstrumentationHook(PIC);
StandardInstrumentations SI(Mod.getContext(), Conf.DebugPassManager,
Conf.VerifyEach);
SI.registerCallbacks(PIC, &MAM);
diff --git a/llvm/unittests/CMakeLists.txt b/llvm/unittests/CMakeLists.txt
index 46f30ff398e10db..32e7dd7b514e881 100644
--- a/llvm/unittests/CMakeLists.txt
+++ b/llvm/unittests/CMakeLists.txt
@@ -33,6 +33,7 @@ add_subdirectory(InterfaceStub)
add_subdirectory(IR)
add_subdirectory(LineEditor)
add_subdirectory(Linker)
+add_subdirectory(LTO)
add_subdirectory(MC)
add_subdirectory(MI)
add_subdirectory(MIR)
diff --git a/llvm/unittests/LTO/CMakeLists.txt b/llvm/unittests/LTO/CMakeLists.txt
new file mode 100644
index 000000000000000..d3136bd020f1cbc
--- /dev/null
+++ b/llvm/unittests/LTO/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(LLVM_LINK_COMPONENTS
+ ${LLVM_TARGETS_TO_BUILD}
+ LTO
+ )
+
+add_llvm_unittest(LTOTests
+ LTOTest.cpp
+ )
diff --git a/llvm/unittests/LTO/LTOTest.cpp b/llvm/unittests/LTO/LTOTest.cpp
new file mode 100644
index 000000000000000..3ad7fd00758d324
--- /dev/null
+++ b/llvm/unittests/LTO/LTOTest.cpp
@@ -0,0 +1,118 @@
+//===- llvm/unittest/LTO/LTOTest.cpp - Unit tests for LTO -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/LTO/LTO.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/SmallVectorMemoryBuffer.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using ::testing::_;
+using ::testing::AtLeast;
+using ::testing::Return;
+
+namespace {
+/// Fixture giving each test a TargetMachine for the default target triple.
+class LTOTest : public ::testing::Test {
+protected:
+ LLVMContext Context; // Context for every module the tests build.
+ std::string TripleName; // Normalized default target triple.
+ std::unique_ptr<TargetMachine> TM; // Non-null whenever a test body runs.
+
+ std::unique_ptr<Module> createEmptyModule();
+
+public:
+ static void SetUpTestSuite();
+ void SetUp() override;
+};
+/// Registers all targets, target MCs and asm printers once per test suite.
+void LTOTest::SetUpTestSuite() {
+ InitializeAllTargets();
+ InitializeAllTargetMCs();
+ InitializeAllAsmPrinters();
+}
+/// Creates TM for the default target triple; skips the test if that fails.
+void LTOTest::SetUp() {
+ TripleName = Triple::normalize(sys::getDefaultTargetTriple());
+ std::string Error;
+ const auto *TheTarget = TargetRegistry::lookupTarget(TripleName, Error);
+ if (!TheTarget)
+ GTEST_SKIP() << "No target for " << TripleName << ": " << Error;
+ TM.reset(TheTarget->createTargetMachine(TripleName, "", "", TargetOptions(),
+ std::nullopt));
+ if (!TM)
+ GTEST_SKIP() << "Failed to create a TargetMachine for " << TripleName;
+}
+/// Returns an empty module whose triple and data layout match TM.
+std::unique_ptr<Module> LTOTest::createEmptyModule() {
+ auto M = std::make_unique<Module>("Empty", Context);
+ M->setTargetTriple(TripleName);
+ M->setDataLayout(TM->createDataLayout());
+ return M;
+}
+/// Serializes \p M to bitcode held in a newly allocated memory buffer.
+static std::unique_ptr<MemoryBuffer>
+writeBitcodeToMemoryBuffer(const Module &M) {
+ SmallString<0> Buffer;
+ raw_svector_ostream OS(Buffer);
+ WriteBitcodeToFile(M, OS);
+ return std::make_unique<SmallVectorMemoryBuffer>(std::move(Buffer));
+}
+/// Adds \p MB to \p Lto (call via ASSERT_NO_FATAL_FAILURE to stop on error).
+static void
+addMemBufToLto(lto::LTO &Lto, MemoryBufferRef MB,
+ ArrayRef<lto::SymbolResolution> SymRes = std::nullopt) {
+ auto InputOrError = lto::InputFile::create(MB);
+ ASSERT_TRUE(!!InputOrError) << toString(InputOrError.takeError());
+ auto AddResult = Lto.add(std::move(*InputOrError), SymRes);
+ ASSERT_TRUE(!AddResult) << toString(std::move(AddResult));
+}
+/// Runs \p Lto into in-memory streams (call via ASSERT_NO_FATAL_FAILURE).
+static void runLto(lto::LTO &Lto) {
+ std::vector<SmallString<0>> AddStreamBufs;
+ auto AddStreamFn = [&AddStreamBufs](size_t task,
+ const Twine & /*moduleName*/) {
+ return std::make_unique<CachedFileStream>(
+ std::make_unique<raw_svector_ostream>(AddStreamBufs[task]));
+ };
+ AddStreamBufs.resize(Lto.getMaxTasks());
+ auto RunResult = Lto.run(AddStreamFn);
+ ASSERT_TRUE(!RunResult) << toString(std::move(RunResult));
+}
+/// gmock stand-in for a ShouldRunOptionalPass callback (pass name, IR unit).
+struct MockBeforePassFunc {
+ MOCK_METHOD(bool, Op, (StringRef, Any));
+ bool operator()(StringRef Name, Any IR) { return Op(Name, IR); }
+};
+/// Checks that Config::PassInstrumentationHook reaches the LTO opt pipeline.
+TEST_F(LTOTest, CustomizationHook) {
+ MockBeforePassFunc MBPF;
+ EXPECT_CALL(MBPF, Op(_, _)).Times(AtLeast(1)).WillRepeatedly(Return(true));
+ // Returning true leaves the default pipeline intact; only calls are counted.
+ lto::Config LtoConfig;
+ LtoConfig.PassInstrumentationHook = [&](PassInstrumentationCallbacks &PIC) {
+ PIC.registerShouldRunOptionalPassCallback(std::ref(MBPF));
+ };
+ lto::LTO LtoTest(std::move(LtoConfig));
+
+ auto Module = createEmptyModule();
+ auto ModuleMemBuf = writeBitcodeToMemoryBuffer(*Module);
+ ASSERT_NO_FATAL_FAILURE(
+ addMemBufToLto(LtoTest, ModuleMemBuf->getMemBufferRef()));
+ ASSERT_NO_FATAL_FAILURE(runLto(LtoTest));
+}
+
+} // end anonymous namespace
More information about the llvm-commits
mailing list