[llvm] [CodeGen] Allow `CodeGenPassBuilder` to add module pass after function pass (PR #77084)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 5 17:29:51 PST 2024


https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/77084

>From 9755fa616705a20cdd9f44e861a5520afb10bbf1 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Fri, 5 Jan 2024 19:35:43 +0800
Subject: [PATCH] [CodeGen] Allow `CodeGenPassBuilder` to add module pass after
 function pass

Several backends (e.g. AArch64 and AMDGPU) add module passes after function passes, so this patch removes the constraint that module passes can only be added before any function pass.
This patch also adds a simple unit test for `CodeGenPassBuilder`.
---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h |  53 +++----
 llvm/unittests/CodeGen/CMakeLists.txt         |   4 +
 .../CodeGen/CodeGenPassBuilderTest.cpp        | 130 ++++++++++++++++++
 3 files changed, 153 insertions(+), 34 deletions(-)
 create mode 100644 llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 2100c30aad1180..3f903fb616e49d 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -172,45 +172,33 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen IR passes.
   class AddIRPass {
   public:
-    AddIRPass(ModulePassManager &MPM, bool DebugPM, bool Check = true)
-        : MPM(MPM) {
-      if (Check)
-        AddingFunctionPasses = false;
-    }
+    AddIRPass(ModulePassManager &MPM, bool DebugPM) : MPM(MPM) {}
     ~AddIRPass() {
-      MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
-    }
-
-    // Add Function Pass
-    template <typename PassT>
-    std::enable_if_t<is_detected<is_function_pass_t, PassT>::value>
-    operator()(PassT &&Pass) {
-      if (AddingFunctionPasses && !*AddingFunctionPasses)
-        AddingFunctionPasses = true;
-      FPM.addPass(std::forward<PassT>(Pass));
+      if (!FPM.isEmpty())
+        MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
     }
 
-    // Add Module Pass
-    template <typename PassT>
-    std::enable_if_t<is_detected<is_module_pass_t, PassT>::value &&
-                     !is_detected<is_function_pass_t, PassT>::value>
-    operator()(PassT &&Pass) {
-      assert((!AddingFunctionPasses || !*AddingFunctionPasses) &&
-             "could not add module pass after adding function pass");
-      MPM.addPass(std::forward<PassT>(Pass));
+    template <typename PassT> void operator()(PassT &&Pass) {
+      // Add Function Pass
+      if constexpr (is_detected<is_function_pass_t, PassT>::value) {
+        FPM.addPass(std::forward<PassT>(Pass));
+        return;
+      }
+
+      // Add Module Pass
+      if constexpr (is_detected<is_module_pass_t, PassT>::value &&
+                    !is_detected<is_function_pass_t, PassT>::value) {
+        if (!FPM.isEmpty()) {
+          MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+          FPM = FunctionPassManager();
+        }
+        MPM.addPass(std::forward<PassT>(Pass));
+      }
     }
 
   private:
     ModulePassManager &MPM;
     FunctionPassManager FPM;
-    // The codegen IR pipeline are mostly function passes with the exceptions of
-    // a few loop and module passes. `AddingFunctionPasses` make sures that
-    // we could only add module passes at the beginning of the pipeline. Once
-    // we begin adding function passes, we could no longer add module passes.
-    // This special-casing introduces less adaptor passes. If we have the need
-    // of adding module passes after function passes, we could change the
-    // implementation to accommodate that.
-    std::optional<bool> AddingFunctionPasses;
   };
 
   // Function object to maintain state while adding codegen machine passes.
@@ -643,9 +631,6 @@ void CodeGenPassBuilder<Derived>::addIRPasses(AddIRPass &addPass) const {
   // Run GC lowering passes for builtin collectors
   // TODO: add a pass insertion point here
   addPass(GCLoweringPass());
-  // FIXME: `ShadowStackGCLoweringPass` now is a
-  // module pass, so it will trigger assertion.
-  // See comment of `AddingFunctionPasses`
   addPass(ShadowStackGCLoweringPass());
   addPass(LowerConstantIntrinsicsPass());
 
diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index fa6c9cf7c5aebf..c78cbfcc281939 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -7,13 +7,16 @@ set(LLVM_LINK_COMPONENTS
   CodeGenTypes
   Core
   FileCheck
+  IRPrinter
   MC
   MIRParser
   Passes
+  ScalarOpts
   SelectionDAG
   Support
   Target
   TargetParser
+  TransformUtils
   )
 
 add_llvm_unittest(CodeGenTests
@@ -22,6 +25,7 @@ add_llvm_unittest(CodeGenTests
   AMDGPUMetadataTest.cpp
   AsmPrinterDwarfTest.cpp
   CCStateTest.cpp
+  CodeGenPassBuilderTest.cpp
   DIEHashTest.cpp
   DIETest.cpp
   DwarfStringPoolEntryRefTest.cpp
diff --git a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
new file mode 100644
index 00000000000000..e3450ff1428fe0
--- /dev/null
+++ b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
@@ -0,0 +1,130 @@
+//===- llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/CodeGenPassBuilder.h"
+#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/MC/MCStreamer.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Host.h"
+#include "gtest/gtest.h"
+#include <string>
+
+namespace llvm {
+
+extern cl::opt<bool> PrintPipelinePasses;
+
+}
+
+using namespace llvm;
+
+namespace {
+
+struct NoOpModulePass : PassInfoMixin<NoOpModulePass> {
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
+    return PreservedAnalyses::all();
+  }
+
+  static StringRef name() { return "NoOpModulePass"; }
+};
+
+struct NoOpFunctionPass : PassInfoMixin<NoOpFunctionPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &) {
+    return PreservedAnalyses::all();
+  }
+  static StringRef name() { return "NoOpFunctionPass"; }
+};
+
+class DummyCodeGenPassBuilder
+    : public CodeGenPassBuilder<DummyCodeGenPassBuilder> {
+public:
+  DummyCodeGenPassBuilder(LLVMTargetMachine &TM, CGPassBuilderOption Opts,
+                          PassInstrumentationCallbacks *PIC)
+      : CodeGenPassBuilder(TM, Opts, PIC){};
+
+  void addPreISel(AddIRPass &addPass) const {
+    addPass(NoOpModulePass());
+    addPass(NoOpFunctionPass());
+    addPass(NoOpFunctionPass());
+    addPass(NoOpFunctionPass());
+    addPass(NoOpModulePass());
+    addPass(NoOpFunctionPass());
+  }
+
+  void addAsmPrinter(AddMachinePass &, CreateMCStreamer) const {}
+
+  Error addInstSelector(AddMachinePass &) const { return Error::success(); }
+};
+
+class CodeGenPassBuilderTest : public testing::Test {
+public:
+  LLVMTargetMachine *TM;
+
+  static void SetUpTestCase() {
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+  }
+
+  void SetUp() override {
+    std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
+    std::string Error;
+    const Target *TheTarget = TargetRegistry::lookupTarget(TripleName, Error);
+    if (!TheTarget)
+      GTEST_SKIP();
+
+    TargetOptions Options;
+    TM = static_cast<LLVMTargetMachine *>(
+        TheTarget->createTargetMachine("", "", "", Options, std::nullopt));
+  }
+};
+
+TEST_F(CodeGenPassBuilderTest, basic) {
+  PrintPipelinePasses = true;
+
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+
+  PassInstrumentationCallbacks PIC;
+  DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
+  PipelineTuningOptions PTO;
+  PassBuilder PB(TM, PTO, std::nullopt, &PIC);
+
+  PB.registerModuleAnalyses(MAM);
+  PB.registerCGSCCAnalyses(CGAM);
+  PB.registerFunctionAnalyses(FAM);
+  PB.registerLoopAnalyses(LAM);
+  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+
+  ModulePassManager MPM;
+  MachineFunctionPassManager MFPM;
+  Error Err =
+      CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
+  EXPECT_FALSE(Err);
+  std::string IRPipeline;
+  raw_string_ostream IROS(IRPipeline);
+  MPM.printPipeline(IROS, [&PIC](StringRef Name) {
+    auto PassName = PIC.getPassNameForClassName(Name);
+    return PassName.empty() ? Name : PassName;
+  });
+
+  std::string MIRPipeline;
+  raw_string_ostream MIROS(MIRPipeline);
+  MFPM.printPipeline(MIROS, [&PIC](StringRef Name) {
+    auto PassName = PIC.getPassNameForClassName(Name);
+    return PassName.empty() ? Name : PassName;
+  });
+  // TODO: Check pipeline string when all pass names are populated.
+}
+
+} // namespace



More information about the llvm-commits mailing list