[llvm] [CodeGen] Allow `CodeGenPassBuilder` to add module pass after function pass (PR #77084)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 5 03:53:05 PST 2024


https://github.com/paperchalice created https://github.com/llvm/llvm-project/pull/77084

Several backends (e.g. AArch64 and AMDGPU) add module passes after function passes, but `CodeGenPassBuilder` asserts that a module pass cannot be added once a function pass has been added. This patch removes that constraint and also adds a simple unit test for `CodeGenPassBuilder`.

>From d62a7bf2572dfbc14a6651542395fee8e63d14c0 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Fri, 5 Jan 2024 19:35:43 +0800
Subject: [PATCH] [CodeGen] Allow `CodeGenPassBuilder` to add module pass after
 function pass

Several backends (e.g. AArch64 and AMDGPU) add module passes after function passes, but `CodeGenPassBuilder` asserted that module passes could not follow function passes. This patch removes that constraint.
It also adds a simple unit test for `CodeGenPassBuilder`.
---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h |  50 +++----
 llvm/unittests/CodeGen/CMakeLists.txt         |   4 +
 .../CodeGen/CodeGenPassBuilderTest.cpp        | 130 ++++++++++++++++++
 3 files changed, 153 insertions(+), 31 deletions(-)
 create mode 100644 llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index fc3da34914120e..13a556651e9652 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -172,45 +172,38 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen IR passes.
   class AddIRPass {
   public:
-    AddIRPass(ModulePassManager &MPM, bool DebugPM, bool Check = true)
-        : MPM(MPM) {
-      if (Check)
-        AddingFunctionPasses = false;
-    }
+    AddIRPass(ModulePassManager &MPM, bool DebugPM) : MPM(MPM) {}
     ~AddIRPass() {
-      MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
-    }
-
-    // Add Function Pass
-    template <typename PassT>
-    std::enable_if_t<is_detected<is_function_pass_t, PassT>::value>
-    operator()(PassT &&Pass) {
-      if (AddingFunctionPasses && !*AddingFunctionPasses)
-        AddingFunctionPasses = true;
-      FPM.addPass(std::forward<PassT>(Pass));
+      // Flush function passes queued after the last module pass, if any.
+      if (!FPM.isEmpty())
+        MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
     }

-    // Add Module Pass
-    template <typename PassT>
-    std::enable_if_t<is_detected<is_module_pass_t, PassT>::value &&
-                     !is_detected<is_function_pass_t, PassT>::value>
-    operator()(PassT &&Pass) {
-      assert((!AddingFunctionPasses || !*AddingFunctionPasses) &&
-             "could not add module pass after adding function pass");
-      MPM.addPass(std::forward<PassT>(Pass));
+    template <typename PassT> void operator()(PassT &&Pass) {
+      // Anything that is neither a function nor a module pass is a usage
+      // error; reject it at compile time instead of silently dropping it.
+      static_assert(is_detected<is_function_pass_t, PassT>::value ||
+                        is_detected<is_module_pass_t, PassT>::value,
+                    "Only module passes and function passes are supported.");
+
+      if constexpr (is_detected<is_function_pass_t, PassT>::value) {
+        // Queue it; consecutive function passes share a single adaptor.
+        FPM.addPass(std::forward<PassT>(Pass));
+      } else {
+        // Function passes queued so far must run before this module pass,
+        // so flush them into MPM first to preserve the requested order.
+        if (!FPM.isEmpty()) {
+          MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+          FPM = FunctionPassManager();
+        }
+        MPM.addPass(std::forward<PassT>(Pass));
+      }
     }

   private:
     ModulePassManager &MPM;
+    // Function passes added since the last module pass; not yet in MPM.
     FunctionPassManager FPM;
-    // The codegen IR pipeline are mostly function passes with the exceptions of
-    // a few loop and module passes. `AddingFunctionPasses` make sures that
-    // we could only add module passes at the beginning of the pipeline. Once
-    // we begin adding function passes, we could no longer add module passes.
-    // This special-casing introduces less adaptor passes. If we have the need
-    // of adding module passes after function passes, we could change the
-    // implementation to accommodate that.
-    std::optional<bool> AddingFunctionPasses;
   };

   // Function object to maintain state while adding codegen machine passes.
diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index fa6c9cf7c5aebf..c78cbfcc281939 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -7,13 +7,16 @@ set(LLVM_LINK_COMPONENTS
   CodeGenTypes
   Core
   FileCheck
+  IRPrinter
   MC
   MIRParser
   Passes
+  ScalarOpts
   SelectionDAG
   Support
   Target
   TargetParser
+  TransformUtils
   )
 
 add_llvm_unittest(CodeGenTests
@@ -22,6 +25,7 @@ add_llvm_unittest(CodeGenTests
   AMDGPUMetadataTest.cpp
   AsmPrinterDwarfTest.cpp
   CCStateTest.cpp
+  CodeGenPassBuilderTest.cpp
   DIEHashTest.cpp
   DIETest.cpp
   DwarfStringPoolEntryRefTest.cpp
diff --git a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
new file mode 100644
index 00000000000000..e3450ff1428fe0
--- /dev/null
+++ b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
@@ -0,0 +1,155 @@
+//===- llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/CodeGenPassBuilder.h"
+#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/MC/MCStreamer.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Host.h"
+#include "gtest/gtest.h"
+#include <memory>
+#include <string>
+
+namespace llvm {
+
+extern cl::opt<bool> PrintPipelinePasses;
+
+} // namespace llvm
+
+using namespace llvm;
+
+namespace {
+
+/// Module pass that does nothing; lets the test place module passes at
+/// chosen points of the IR pipeline.
+struct NoOpModulePass : PassInfoMixin<NoOpModulePass> {
+  PreservedAnalyses run(Module &, ModuleAnalysisManager &) {
+    return PreservedAnalyses::all();
+  }
+
+  static StringRef name() { return "NoOpModulePass"; }
+};
+
+/// Function pass that does nothing; the function-pass counterpart of
+/// NoOpModulePass.
+struct NoOpFunctionPass : PassInfoMixin<NoOpFunctionPass> {
+  PreservedAnalyses run(Function &, FunctionAnalysisManager &) {
+    return PreservedAnalyses::all();
+  }
+
+  static StringRef name() { return "NoOpFunctionPass"; }
+};
+
+/// Minimal CodeGenPassBuilder. addPreISel interleaves module and function
+/// passes, including module passes after function passes, which AddIRPass
+/// used to reject with an assertion.
+class DummyCodeGenPassBuilder
+    : public CodeGenPassBuilder<DummyCodeGenPassBuilder> {
+public:
+  DummyCodeGenPassBuilder(LLVMTargetMachine &TM, CGPassBuilderOption Opts,
+                          PassInstrumentationCallbacks *PIC)
+      : CodeGenPassBuilder(TM, Opts, PIC) {}
+
+  void addPreISel(AddIRPass &addPass) const {
+    addPass(NoOpModulePass());
+    addPass(NoOpFunctionPass());
+    addPass(NoOpFunctionPass());
+    addPass(NoOpFunctionPass());
+    addPass(NoOpModulePass());
+    addPass(NoOpFunctionPass());
+  }
+
+  // Stubs: this test only exercises how IR passes are placed.
+  void addAsmPrinter(AddMachinePass &, CreateMCStreamer) const {}
+
+  Error addInstSelector(AddMachinePass &) const { return Error::success(); }
+};
+
+class CodeGenPassBuilderTest : public testing::Test {
+public:
+  // Owns the target machine for the current test. It stays null when no
+  // target is registered for the host triple; SetUp then skips the test.
+  std::unique_ptr<LLVMTargetMachine> TM;
+
+  static void SetUpTestCase() {
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+  }
+
+  void SetUp() override {
+    std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
+    std::string ErrMsg;
+    const Target *TheTarget = TargetRegistry::lookupTarget(TripleName, ErrMsg);
+    if (!TheTarget)
+      GTEST_SKIP();
+
+    TargetOptions Options;
+    // Build the machine for the triple that was looked up, not an empty one.
+    TM.reset(static_cast<LLVMTargetMachine *>(TheTarget->createTargetMachine(
+        TripleName, "", "", Options, std::nullopt)));
+    if (!TM)
+      GTEST_SKIP();
+  }
+};
+
+TEST_F(CodeGenPassBuilderTest, basic) {
+  // NOTE(review): presumably needed so that PassBuilder records class-name to
+  // pass-name mappings in PIC for getPassNameForClassName below; confirm.
+  // The previous value is restored at the end so other tests are unaffected.
+  bool SavedPrintPipelinePasses = PrintPipelinePasses;
+  PrintPipelinePasses = true;
+
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+
+  PassInstrumentationCallbacks PIC;
+  DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
+  PipelineTuningOptions PTO;
+  PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC);
+
+  PB.registerModuleAnalyses(MAM);
+  PB.registerCGSCCAnalyses(CGAM);
+  PB.registerFunctionAnalyses(FAM);
+  PB.registerLoopAnalyses(LAM);
+  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+
+  ModulePassManager MPM;
+  MachineFunctionPassManager MFPM;
+  Error Err =
+      CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
+  // Consume the error so a failure is reported instead of aborting on an
+  // unchecked llvm::Error.
+  if (Err)
+    ADD_FAILURE() << "buildPipeline failed: " << toString(std::move(Err));
+
+  std::string IRPipeline;
+  raw_string_ostream IROS(IRPipeline);
+  MPM.printPipeline(IROS, [&PIC](StringRef Name) {
+    auto PassName = PIC.getPassNameForClassName(Name);
+    return PassName.empty() ? Name : PassName;
+  });
+
+  std::string MIRPipeline;
+  raw_string_ostream MIROS(MIRPipeline);
+  MFPM.printPipeline(MIROS, [&PIC](StringRef Name) {
+    auto PassName = PIC.getPassNameForClassName(Name);
+    return PassName.empty() ? Name : PassName;
+  });
+  // TODO: Check pipeline string when all pass names are populated.
+
+  PrintPipelinePasses = SavedPrintPipelinePasses;
+}
+
+} // namespace



More information about the llvm-commits mailing list