[llvm] [LTO] Add a hook to customize the optimization pipeline (PR #71268)

Igor Kudrin via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 22 12:59:09 PST 2023


https://github.com/igorkudrin updated https://github.com/llvm/llvm-project/pull/71268

>From c63c4792f08c6f9dac63ea44cdc110491ada1a2b Mon Sep 17 00:00:00 2001
From: Igor Kudrin <ikudrin at accesssoftek.com>
Date: Fri, 3 Nov 2023 20:26:13 -0700
Subject: [PATCH] [LTO] Enable adding custom pass instrumentation callbacks

The new lto::Config::PassInstrumentationHook allows the calling code to
register pass instrumentation callbacks for the LTO optimization
pipeline. In particular, a custom pass filter can be added to skip
certain passes of the default pipeline.
---
 llvm/include/llvm/LTO/Config.h    |   5 ++
 llvm/lib/LTO/LTOBackend.cpp       |   2 +
 llvm/unittests/CMakeLists.txt     |   1 +
 llvm/unittests/LTO/CMakeLists.txt |   8 ++
 llvm/unittests/LTO/LTOTest.cpp    | 118 ++++++++++++++++++++++++++++++
 5 files changed, 134 insertions(+)
 create mode 100644 llvm/unittests/LTO/CMakeLists.txt
 create mode 100644 llvm/unittests/LTO/LTOTest.cpp

diff --git a/llvm/include/llvm/LTO/Config.h b/llvm/include/llvm/LTO/Config.h
index 6fb55f1cf1686a5..912dca1752dd80f 100644
--- a/llvm/include/llvm/LTO/Config.h
+++ b/llvm/include/llvm/LTO/Config.h
@@ -257,6 +257,11 @@ struct Config {
       const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols)>;
   CombinedIndexHookFn CombinedIndexHook;
 
+  /// Called to add custom pass instrumentation callbacks to the LTO pipeline.
+  using PassInstrumentationHookFn =
+      std::function<void(PassInstrumentationCallbacks &)>;
+  PassInstrumentationHookFn PassInstrumentationHook;
+
   /// This is a convenience function that configures this Config object to write
   /// temporary files named after the given OutputFileName for each of the LTO
   /// phases to disk. A client can use this function to implement -save-temps.
diff --git a/llvm/lib/LTO/LTOBackend.cpp b/llvm/lib/LTO/LTOBackend.cpp
index ccc4276e36dacf0..671bc936467be5e 100644
--- a/llvm/lib/LTO/LTOBackend.cpp
+++ b/llvm/lib/LTO/LTOBackend.cpp
@@ -265,6 +265,8 @@ static void runNewPMPasses(const Config &Conf, Module &Mod, TargetMachine *TM,
   ModuleAnalysisManager MAM;
 
   PassInstrumentationCallbacks PIC;
+  if (Conf.PassInstrumentationHook)
+    Conf.PassInstrumentationHook(PIC);
   StandardInstrumentations SI(Mod.getContext(), Conf.DebugPassManager,
                               Conf.VerifyEach);
   SI.registerCallbacks(PIC, &MAM);
diff --git a/llvm/unittests/CMakeLists.txt b/llvm/unittests/CMakeLists.txt
index 46f30ff398e10db..32e7dd7b514e881 100644
--- a/llvm/unittests/CMakeLists.txt
+++ b/llvm/unittests/CMakeLists.txt
@@ -33,6 +33,7 @@ add_subdirectory(InterfaceStub)
 add_subdirectory(IR)
 add_subdirectory(LineEditor)
 add_subdirectory(Linker)
+add_subdirectory(LTO)
 add_subdirectory(MC)
 add_subdirectory(MI)
 add_subdirectory(MIR)
diff --git a/llvm/unittests/LTO/CMakeLists.txt b/llvm/unittests/LTO/CMakeLists.txt
new file mode 100644
index 000000000000000..d3136bd020f1cbc
--- /dev/null
+++ b/llvm/unittests/LTO/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Name every component LTOTest.cpp uses directly; transitive dependencies of
+# LTO are not guaranteed to be linked in with BUILD_SHARED_LIBS builds.
+set(LLVM_LINK_COMPONENTS
+  ${LLVM_TARGETS_TO_BUILD}
+  BitWriter
+  Core
+  LTO
+  MC
+  Support
+  Target
+  TargetParser
+  )
+
+add_llvm_unittest(LTOTests
+  LTOTest.cpp
+  )
diff --git a/llvm/unittests/LTO/LTOTest.cpp b/llvm/unittests/LTO/LTOTest.cpp
new file mode 100644
index 000000000000000..3ad7fd00758d324
--- /dev/null
+++ b/llvm/unittests/LTO/LTOTest.cpp
@@ -0,0 +1,134 @@
+//===- llvm/unittests/LTO/LTOTest.cpp - Unit tests for LTO ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/LTO/LTO.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/SmallVectorMemoryBuffer.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using ::testing::_;
+using ::testing::AtLeast;
+using ::testing::Return;
+
+namespace {
+
+/// Fixture providing a TargetMachine for the default target triple. Tests
+/// are skipped when no usable target is registered for that triple.
+class LTOTest : public ::testing::Test {
+protected:
+  LLVMContext Context;
+  std::string TripleName;
+  std::unique_ptr<TargetMachine> TM;
+
+  /// Creates a module with no contents whose target triple and data layout
+  /// match \c TM.
+  std::unique_ptr<Module> createEmptyModule();
+
+public:
+  static void SetUpTestSuite();
+  void SetUp() override;
+};
+
+void LTOTest::SetUpTestSuite() {
+  // Register every target built into LLVM so that SetUp() can find the one
+  // for the default triple, whichever it is.
+  InitializeAllTargets();
+  InitializeAllTargetMCs();
+  InitializeAllAsmPrinters();
+}
+
+void LTOTest::SetUp() {
+  TripleName = Triple::normalize(sys::getDefaultTargetTriple());
+  std::string Error;
+  const auto *TheTarget = TargetRegistry::lookupTarget(TripleName, Error);
+  if (!TheTarget)
+    GTEST_SKIP() << "No target for '" << TripleName << "': " << Error;
+  TM.reset(TheTarget->createTargetMachine(TripleName, "", "", TargetOptions(),
+                                          std::nullopt));
+  if (!TM)
+    GTEST_SKIP() << "Cannot create a TargetMachine for '" << TripleName << "'";
+}
+
+std::unique_ptr<Module> LTOTest::createEmptyModule() {
+  auto M = std::make_unique<Module>("Empty", Context);
+  M->setTargetTriple(TripleName);
+  M->setDataLayout(TM->createDataLayout());
+  return M;
+}
+
+/// Serializes \p M into a bitcode memory buffer.
+static std::unique_ptr<MemoryBuffer>
+writeBitcodeToMemoryBuffer(const Module &M) {
+  SmallString<0> Buffer;
+  raw_svector_ostream OS(Buffer);
+  WriteBitcodeToFile(M, OS);
+  return std::make_unique<SmallVectorMemoryBuffer>(std::move(Buffer));
+}
+
+/// Adds the bitcode in \p MB to \p Lto using the resolutions \p SymRes.
+/// Reports a fatal failure if the input cannot be created or added.
+static void
+addMemBufToLto(lto::LTO &Lto, MemoryBufferRef MB,
+               ArrayRef<lto::SymbolResolution> SymRes = std::nullopt) {
+  auto InputOrError = lto::InputFile::create(MB);
+  ASSERT_TRUE(!!InputOrError) << toString(InputOrError.takeError());
+  auto AddResult = Lto.add(std::move(*InputOrError), SymRes);
+  ASSERT_TRUE(!AddResult) << toString(std::move(AddResult));
+}
+
+/// Runs \p Lto, collecting the output of each task in a memory buffer.
+/// Reports a fatal failure if the run returns an error.
+static void runLto(lto::LTO &Lto) {
+  // One output buffer per task; the stream callback indexes them by task.
+  std::vector<SmallString<0>> AddStreamBufs(Lto.getMaxTasks());
+  auto AddStreamFn = [&AddStreamBufs](size_t Task,
+                                      const Twine & /*ModuleName*/) {
+    return std::make_unique<CachedFileStream>(
+        std::make_unique<raw_svector_ostream>(AddStreamBufs[Task]));
+  };
+  auto RunResult = Lto.run(AddStreamFn);
+  ASSERT_TRUE(!RunResult) << toString(std::move(RunResult));
+}
+
+/// Mock callable registered as a ShouldRunOptionalPass callback; it receives
+/// the pass name and the IR unit.
+struct MockBeforePassFunc {
+  MOCK_METHOD(bool, Op, (StringRef, Any));
+  bool operator()(StringRef Name, Any IR) { return Op(Name, IR); }
+};
+
+/// Checks that the callbacks registered by Config::PassInstrumentationHook
+/// are invoked by the LTO optimization pipeline.
+TEST_F(LTOTest, CustomizationHook) {
+  MockBeforePassFunc MBPF;
+  // Without an explicit action gMock returns false for bool, which would tell
+  // the pipeline to skip every optional pass; let the passes run instead.
+  EXPECT_CALL(MBPF, Op(_, _)).Times(AtLeast(1)).WillRepeatedly(Return(true));
+
+  lto::Config LtoConfig;
+  LtoConfig.PassInstrumentationHook = [&](PassInstrumentationCallbacks &PIC) {
+    PIC.registerShouldRunOptionalPassCallback(std::ref(MBPF));
+  };
+  lto::LTO Lto(std::move(LtoConfig));
+
+  auto M = createEmptyModule();
+  auto MemBuf = writeBitcodeToMemoryBuffer(*M);
+  ASSERT_NO_FATAL_FAILURE(addMemBufToLto(Lto, MemBuf->getMemBufferRef()));
+  ASSERT_NO_FATAL_FAILURE(runLto(Lto));
+}
+
+} // end anonymous namespace



More information about the llvm-commits mailing list