[llvm] [Pass] Use `function_traits` to infer `IRUnitT` and `AnalysisManagerT` in `RequireAnalysisPass` (PR #89538)

via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 21 04:09:05 PDT 2024


https://github.com/paperchalice created https://github.com/llvm/llvm-project/pull/89538

We have `function_traits` now, so we can infer `IRUnitT` and `AnalysisManagerT` from the first and second parameters of `AnalysisT::run`, respectively.

>From 0d11a2cd93cb134a0418aa69596abf40980b4854 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Sun, 21 Apr 2024 19:06:17 +0800
Subject: [PATCH] [Pass] Use `function_traits` to infer `IRUnitT` and
 `AnalysisManagerT`. We have `function_traits` now, so we can infer `IRUnitT`
 and `AnalysisManagerT` from the first and second parameters of `AnalysisT::run`.

---
 llvm/include/llvm/Analysis/CGSCCPassManager.h |  6 ++---
 llvm/include/llvm/IR/PassManager.h            | 12 ++++++----
 llvm/include/llvm/Passes/CodeGenPassBuilder.h |  4 ++--
 llvm/include/llvm/Passes/PassBuilder.h        |  3 +--
 .../llvm/Transforms/Scalar/LoopPassManager.h  |  7 +++---
 llvm/lib/Passes/PassBuilder.cpp               | 24 ++++++++-----------
 llvm/lib/Passes/PassBuilderPipelines.cpp      | 16 ++++++-------
 7 files changed, 33 insertions(+), 39 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CGSCCPassManager.h b/llvm/include/llvm/Analysis/CGSCCPassManager.h
index 5654ad46d6eab0..e7eef7952fdd91 100644
--- a/llvm/include/llvm/Analysis/CGSCCPassManager.h
+++ b/llvm/include/llvm/Analysis/CGSCCPassManager.h
@@ -145,10 +145,8 @@ using CGSCCPassManager =
 
 /// An explicit specialization of the require analysis template pass.
 template <typename AnalysisT>
-struct RequireAnalysisPass<AnalysisT, LazyCallGraph::SCC, CGSCCAnalysisManager,
-                           LazyCallGraph &, CGSCCUpdateResult &>
-    : PassInfoMixin<RequireAnalysisPass<AnalysisT, LazyCallGraph::SCC,
-                                        CGSCCAnalysisManager, LazyCallGraph &,
+struct RequireAnalysisPass<AnalysisT, LazyCallGraph &, CGSCCUpdateResult &>
+    : PassInfoMixin<RequireAnalysisPass<AnalysisT, LazyCallGraph &,
                                         CGSCCUpdateResult &>> {
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &) {
diff --git a/llvm/include/llvm/IR/PassManager.h b/llvm/include/llvm/IR/PassManager.h
index d701481202f8db..6bbcb42fc3419b 100644
--- a/llvm/include/llvm/IR/PassManager.h
+++ b/llvm/include/llvm/IR/PassManager.h
@@ -934,12 +934,14 @@ createModuleToFunctionPassAdaptor(FunctionPassT &&Pass,
 ///
 /// Specific patterns of run-method extra arguments and analysis manager extra
 /// arguments will have to be defined as appropriate specializations.
-template <typename AnalysisT, typename IRUnitT,
-          typename AnalysisManagerT = AnalysisManager<IRUnitT>,
-          typename... ExtraArgTs>
+template <typename AnalysisT, typename... ExtraArgTs>
 struct RequireAnalysisPass
-    : PassInfoMixin<RequireAnalysisPass<AnalysisT, IRUnitT, AnalysisManagerT,
-                                        ExtraArgTs...>> {
+    : PassInfoMixin<RequireAnalysisPass<AnalysisT, ExtraArgTs...>> {
+  using IRUnitT = std::remove_reference_t<
+      typename function_traits<typename AnalysisT::run>::template arg_t<0>>;
+  using AnalysisManagerT = std::remove_reference_t<
+      typename function_traits<typename AnalysisT::run>::template arg_t<1>>;
+
   /// Run this pass over some unit of IR.
   ///
   /// This pass can be run over any unit of IR and use any analysis manager
diff --git a/llvm/include/llvm/Passes/CodeGenPassBuilder.h b/llvm/include/llvm/Passes/CodeGenPassBuilder.h
index 2e94a19502131a..5ba6ab2351b2b5 100644
--- a/llvm/include/llvm/Passes/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/Passes/CodeGenPassBuilder.h
@@ -512,8 +512,8 @@ Error CodeGenPassBuilder<Derived, TargetMachineT>::buildPipeline(
 
   {
     AddIRPass addIRPass(MPM, derived());
-    addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
-    addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
+    addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis>());
+    addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis>());
     addISelPasses(addIRPass);
   }
 
diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index c8f643452bb158..ee5cd437b3a978 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -828,8 +828,7 @@ bool parseAnalysisUtilityPasses(
     PipelineName = PipelineName.substr(8, PipelineName.size() - 9);
     if (PipelineName != AnalysisName)
       return false;
-    PM.addPass(RequireAnalysisPass<AnalysisT, IRUnitT, AnalysisManagerT,
-                                   ExtraArgTs...>());
+    PM.addPass(RequireAnalysisPass<AnalysisT, ExtraArgTs...>());
     return true;
   }
 
diff --git a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
index 6aab1f98e67816..ecda329faa3b54 100644
--- a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
@@ -217,8 +217,8 @@ typedef PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
 /// the extra parameters from a transformation's run method to the
 /// AnalysisManager's getResult.
 template <typename AnalysisT>
-struct RequireAnalysisPass<AnalysisT, Loop, LoopAnalysisManager,
-                           LoopStandardAnalysisResults &, LPMUpdater &>
+struct RequireAnalysisPass<AnalysisT, LoopStandardAnalysisResults &,
+                           LPMUpdater &>
     : PassInfoMixin<
           RequireAnalysisPass<AnalysisT, Loop, LoopAnalysisManager,
                               LoopStandardAnalysisResults &, LPMUpdater &>> {
@@ -238,8 +238,7 @@ struct RequireAnalysisPass<AnalysisT, Loop, LoopAnalysisManager,
 /// An alias template to easily name a require analysis loop pass.
 template <typename AnalysisT>
 using RequireAnalysisLoopPass =
-    RequireAnalysisPass<AnalysisT, Loop, LoopAnalysisManager,
-                        LoopStandardAnalysisResults &, LPMUpdater &>;
+    RequireAnalysisPass<AnalysisT, LoopStandardAnalysisResults &, LPMUpdater &>;
 
 class FunctionToLoopPassAdaptor;
 
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 8d408ca2363a98..fe91db7a467187 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -1503,9 +1503,8 @@ Error PassBuilder::parseModulePass(ModulePassManager &MPM,
   }
 #define MODULE_ANALYSIS(NAME, CREATE_PASS)                                     \
   if (Name == "require<" NAME ">") {                                           \
-    MPM.addPass(                                                               \
-        RequireAnalysisPass<                                                   \
-            std::remove_reference_t<decltype(CREATE_PASS)>, Module>());        \
+    MPM.addPass(RequireAnalysisPass<                                           \
+                std::remove_reference_t<decltype(CREATE_PASS)>>());            \
     return Error::success();                                                   \
   }                                                                            \
   if (Name == "invalidate<" NAME ">") {                                        \
@@ -1638,10 +1637,9 @@ Error PassBuilder::parseCGSCCPass(CGSCCPassManager &CGPM,
   }
 #define CGSCC_ANALYSIS(NAME, CREATE_PASS)                                      \
   if (Name == "require<" NAME ">") {                                           \
-    CGPM.addPass(RequireAnalysisPass<                                          \
-                 std::remove_reference_t<decltype(CREATE_PASS)>,               \
-                 LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &,    \
-                 CGSCCUpdateResult &>());                                      \
+    CGPM.addPass(                                                              \
+        RequireAnalysisPass<std::remove_reference_t<decltype(CREATE_PASS)>,    \
+                            LazyCallGraph &, CGSCCUpdateResult &>());          \
     return Error::success();                                                   \
   }                                                                            \
   if (Name == "invalidate<" NAME ">") {                                        \
@@ -1759,9 +1757,8 @@ Error PassBuilder::parseFunctionPass(FunctionPassManager &FPM,
   }
 #define FUNCTION_ANALYSIS(NAME, CREATE_PASS)                                   \
   if (Name == "require<" NAME ">") {                                           \
-    FPM.addPass(                                                               \
-        RequireAnalysisPass<                                                   \
-            std::remove_reference_t<decltype(CREATE_PASS)>, Function>());      \
+    FPM.addPass(RequireAnalysisPass<                                           \
+                std::remove_reference_t<decltype(CREATE_PASS)>>());            \
     return Error::success();                                                   \
   }                                                                            \
   if (Name == "invalidate<" NAME ">") {                                        \
@@ -1856,10 +1853,9 @@ Error PassBuilder::parseLoopPass(LoopPassManager &LPM,
   }
 #define LOOP_ANALYSIS(NAME, CREATE_PASS)                                       \
   if (Name == "require<" NAME ">") {                                           \
-    LPM.addPass(RequireAnalysisPass<                                           \
-                std::remove_reference_t<decltype(CREATE_PASS)>, Loop,          \
-                LoopAnalysisManager, LoopStandardAnalysisResults &,            \
-                LPMUpdater &>());                                              \
+    LPM.addPass(                                                               \
+        RequireAnalysisPass<std::remove_reference_t<decltype(CREATE_PASS)>,    \
+                            LoopStandardAnalysisResults &, LPMUpdater &>());   \
     return Error::success();                                                   \
   }                                                                            \
   if (Name == "invalidate<" NAME ">") {                                        \
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 3bb2ce0ae3460b..596c3b9fe88f20 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -816,7 +816,7 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
         PGOInstrumentationUse(ProfileFile, ProfileRemappingFile, IsCS, FS));
     // Cache ProfileSummaryAnalysis once to avoid the potential need to insert
     // RequireAnalysisPass for PSI before subsequent non-module passes.
-    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
+    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis>());
     return;
   }
 
@@ -855,7 +855,7 @@ void PassBuilder::addPGOInstrPassesForO0(
         PGOInstrumentationUse(ProfileFile, ProfileRemappingFile, IsCS, FS));
     // Cache ProfileSummaryAnalysis once to avoid the potential need to insert
     // RequireAnalysisPass for PSI before subsequent non-module passes.
-    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
+    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis>());
     return;
   }
 
@@ -904,7 +904,7 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
   // Require the GlobalsAA analysis for the module so we can query it within
   // the CGSCC pipeline.
   if (EnableGlobalAnalyses) {
-    MIWP.addModulePass(RequireAnalysisPass<GlobalsAA, Module>());
+    MIWP.addModulePass(RequireAnalysisPass<GlobalsAA>());
     // Invalidate AAManager so it can be recreated and pick up the newly
     // available GlobalsAA.
     MIWP.addModulePass(
@@ -913,7 +913,7 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
 
   // Require the ProfileSummaryAnalysis for the module so we can query it within
   // the inliner pass.
-  MIWP.addModulePass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
+  MIWP.addModulePass(RequireAnalysisPass<ProfileSummaryAnalysis>());
 
   // Now begin the main postorder CGSCC pipeline.
   // FIXME: The current CGSCC pipeline has its origins in the legacy pass
@@ -961,7 +961,7 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
   // simplified again if we somehow revisit it due to CGSCC mutations unless
   // it's been modified since.
   MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor(
-      RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis, Function>()));
+      RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis>()));
 
   MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0));
 
@@ -1084,7 +1084,7 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
                                         PGOOpt->ProfileRemappingFile, Phase));
     // Cache ProfileSummaryAnalysis once to avoid the potential need to insert
     // RequireAnalysisPass for PSI before subsequent non-module passes.
-    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
+    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis>());
     // Do not invoke ICP in the LTOPrelink phase as it makes it hard
     // for the profile annotation to be accurate in the LTO backend.
     if (!isLTOPreLink(Phase))
@@ -1740,7 +1740,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
                                         ThinOrFullLTOPhase::FullLTOPostLink));
     // Cache ProfileSummaryAnalysis once to avoid the potential need to insert
     // RequireAnalysisPass for PSI before subsequent non-module passes.
-    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
+    MPM.addPass(RequireAnalysisPass<ProfileSummaryAnalysis>());
   }
 
   // Try to run OpenMP optimizations, quick no-op if no OpenMP metadata present.
@@ -1915,7 +1915,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
   // Require the GlobalsAA analysis for the module so we can query it within
   // MainFPM.
   if (EnableGlobalAnalyses) {
-    MPM.addPass(RequireAnalysisPass<GlobalsAA, Module>());
+    MPM.addPass(RequireAnalysisPass<GlobalsAA>());
     // Invalidate AAManager so it can be recreated and pick up the newly
     // available GlobalsAA.
     MPM.addPass(



More information about the llvm-commits mailing list