[llvm] [Pass] Support eraseIf in pass manager (PR #116734)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 19:12:36 PST 2024


https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/116734

>From ca98c86e3377ec2fe7d4248f7f4be11294ddbb46 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Mon, 19 Aug 2024 08:19:28 +0800
Subject: [PATCH] Support eraseIf in pass manager

---
 llvm/include/llvm/Analysis/CGSCCPassManager.h |  6 +++
 .../include/llvm/CodeGen/MachinePassManager.h |  4 ++
 llvm/include/llvm/IR/PassManager.h            | 20 +++++++++
 llvm/include/llvm/IR/PassManagerInternal.h    | 34 ++++++++++++++
 .../llvm/Transforms/Scalar/LoopPassManager.h  |  8 ++++
 llvm/lib/Analysis/CGSCCPassManager.cpp        | 23 ++++++++++
 llvm/lib/IR/PassManager.cpp                   | 11 +++++
 .../lib/Transforms/Scalar/LoopPassManager.cpp | 44 +++++++++++++++++++
 8 files changed, 150 insertions(+)

diff --git a/llvm/include/llvm/Analysis/CGSCCPassManager.h b/llvm/include/llvm/Analysis/CGSCCPassManager.h
index 15b7f226fd8283..4a49bfb283b6a2 100644
--- a/llvm/include/llvm/Analysis/CGSCCPassManager.h
+++ b/llvm/include/llvm/Analysis/CGSCCPassManager.h
@@ -346,6 +346,10 @@ class ModuleToPostOrderCGSCCPassAdaptor
 
   static bool isRequired() { return true; }
 
+  /// True once this adaptor no longer holds a wrapped pass (\c Pass is null).
+  bool isEmpty() const { return !Pass; }
+
+  /// Erase the wrapped pass when \p Pred accepts its name. If the wrapped pass
+  /// is itself a pass manager or adaptor, \p Pred is forwarded to it instead
+  /// and the wrapped pass is dropped only if it ends up empty.
+  void eraseIf(function_ref<bool(StringRef)> Pred);
+
 private:
   std::unique_ptr<PassConceptT> Pass;
 };
@@ -488,6 +492,8 @@ class CGSCCToFunctionPassAdaptor
 
   static bool isRequired() { return true; }
 
+  /// Returns true when this adaptor no longer owns a wrapped pass.
+  /// eraseIf() resets \c Pass when the wrapped pass is erased; exposing that
+  /// state lets the owning pass manager detect and drop the hollow adaptor
+  /// instead of keeping it around with a null \c Pass.
+  bool isEmpty() const { return Pass == nullptr; }
+
+  /// Erase the wrapped pass when \p Pred accepts its name. If the wrapped pass
+  /// is itself a pass manager or adaptor, \p Pred is forwarded to it instead
+  /// and the wrapped pass is dropped only if it ends up empty.
+  void eraseIf(function_ref<bool(StringRef)> Pred);
+
 private:
   std::unique_ptr<PassConceptT> Pass;
   bool EagerlyInvalidate;
diff --git a/llvm/include/llvm/CodeGen/MachinePassManager.h b/llvm/include/llvm/CodeGen/MachinePassManager.h
index 69b5f6e92940c4..3d7c7d87af480a 100644
--- a/llvm/include/llvm/CodeGen/MachinePassManager.h
+++ b/llvm/include/llvm/CodeGen/MachinePassManager.h
@@ -207,6 +207,10 @@ class FunctionToMachineFunctionPassAdaptor
 
   static bool isRequired() { return true; }
 
+  /// True once this adaptor no longer holds a wrapped pass (\c Pass is null).
+  bool isEmpty() const { return !Pass; }
+
+  /// Erase passes selected by \p Pred from the wrapped pass.
+  /// NOTE(review): this patch declares the method but its diffstat does not
+  /// touch any MachinePassManager source file, so no definition is visible
+  /// here -- confirm one exists, otherwise any use fails to link.
+  void eraseIf(function_ref<bool(StringRef)> Pred);
+
 private:
   std::unique_ptr<PassConceptT> Pass;
 };
diff --git a/llvm/include/llvm/IR/PassManager.h b/llvm/include/llvm/IR/PassManager.h
index d269221fac0701..e06643bdb2b738 100644
--- a/llvm/include/llvm/IR/PassManager.h
+++ b/llvm/include/llvm/IR/PassManager.h
@@ -218,6 +218,22 @@ class PassManager : public PassInfoMixin<
 
   static bool isRequired() { return true; }
 
+  /// Prune this pipeline according to \p Pred. For internal use only!
+  ///
+  /// Every stored pass is first asked to prune whatever it contains. A pass
+  /// whose name ends with "PassAdaptor" or contains "PassManager" is treated
+  /// as a container and is removed only when it reports itself empty; any
+  /// other pass is removed when \p Pred returns true for its name.
+  void eraseIf(function_ref<bool(StringRef)> Pred) {
+    auto It = Passes.begin();
+    while (It != Passes.end()) {
+      auto &Current = **It;
+      // Let nested managers/adaptors prune themselves before judging them.
+      Current.eraseIf(Pred);
+
+      const StringRef Name = Current.name();
+      const bool IsContainer =
+          Name.ends_with("PassAdaptor") || Name.contains("PassManager");
+      const bool Matches = Pred(Name);
+      // Containers are never erased by name, only when nothing is left inside.
+      const bool Drop = IsContainer ? Current.isEmpty() : Matches;
+
+      if (Drop) {
+        It = Passes.erase(It);
+        continue;
+      }
+      ++It;
+    }
+  }
+
 protected:
   using PassConceptT =
       detail::PassConcept<IRUnitT, AnalysisManagerT, ExtraArgTs...>;
@@ -836,6 +852,10 @@ class ModuleToFunctionPassAdaptor
 
   static bool isRequired() { return true; }
 
+  /// True once this adaptor no longer holds a wrapped pass (\c Pass is null).
+  bool isEmpty() const { return !Pass; }
+
+  /// Erase the wrapped pass when \p Pred accepts its name. If the wrapped pass
+  /// is itself a pass manager or adaptor, \p Pred is forwarded to it instead
+  /// and the wrapped pass is dropped only if it ends up empty.
+  void eraseIf(function_ref<bool(StringRef)> Pred);
+
 private:
   std::unique_ptr<PassConceptT> Pass;
   bool EagerlyInvalidate;
diff --git a/llvm/include/llvm/IR/PassManagerInternal.h b/llvm/include/llvm/IR/PassManagerInternal.h
index 4ada6ee5dd6831..36caf7cd85a3d5 100644
--- a/llvm/include/llvm/IR/PassManagerInternal.h
+++ b/llvm/include/llvm/IR/PassManagerInternal.h
@@ -59,6 +59,13 @@ struct PassConcept {
   /// To opt-in, pass should implement `static bool isRequired()`. It's no-op
   /// to have `isRequired` always return false since that is the default.
   virtual bool isRequired() const = 0;
+
+  /// Polymorphic hook used to prune a pass pipeline: asks the wrapped pass to
+  /// erase the passes nested inside it that \p Pred selects by name.
+  virtual void eraseIf(function_ref<bool(StringRef)> Pred) = 0;
+
+  /// Reports whether the wrapped pass is an empty container. Erasing passes
+  /// can leave pass managers or adaptors with nothing inside; their owner
+  /// queries this to remove them as well.
+  virtual bool isEmpty() const = 0;
 };
 
 /// A template wrapper used to implement the polymorphic API.
@@ -114,6 +121,33 @@ struct PassModel : PassConcept<IRUnitT, AnalysisManagerT, ExtraArgTs...> {
 
   bool isRequired() const override { return passIsRequiredImpl<PassT>(); }
 
+  /// Detection helper: names the result type of calling
+  /// \c eraseIf(function_ref<bool(StringRef)>) on a non-const \p T and is
+  /// ill-formed when no such call is valid.
+  template <typename T>
+  using has_erase_if_t = decltype(std::declval<T &>().eraseIf(
+      std::declval<function_ref<bool(StringRef)>>()));
+
+  /// Selected when \p T provides a matching eraseIf(): forwards \p Pred to the
+  /// wrapped pass.
+  template <typename T>
+  std::enable_if_t<is_detected<has_erase_if_t, T>::value>
+  eraseIfImpl(function_ref<bool(StringRef)> Pred) {
+    Pass.eraseIf(Pred);
+  }
+
+  /// Selected when \p T has no eraseIf(): there is nothing nested to erase, so
+  /// the predicate is ignored.
+  template <typename T>
+  std::enable_if_t<!is_detected<has_erase_if_t, T>::value>
+  eraseIfImpl(function_ref<bool(StringRef)>) {}
+
+  /// PassConcept hook: dispatches on the wrapped pass type \c PassT to one of
+  /// the two eraseIfImpl overloads.
+  void eraseIf(function_ref<bool(StringRef)> Pred) override {
+    eraseIfImpl<PassT>(Pred);
+  }
+
+  /// Detects a const-callable \c isEmpty() on \p T. Probing through a const
+  /// reference matches how isEmpty() below invokes it on the const \c Pass
+  /// member, so a pass that only offers a non-const isEmpty() is reported as
+  /// "not detected" instead of being detected and then failing to compile.
+  template <typename T>
+  using has_is_empty_t = decltype(std::declval<const T &>().isEmpty());
+
+  /// PassConcept hook: forwards to the wrapped pass's isEmpty() when it
+  /// provides one; every other pass is considered non-empty.
+  bool isEmpty() const override {
+    if constexpr (is_detected<has_is_empty_t, PassT>::value)
+      return Pass.isEmpty();
+    else
+      return false;
+  }
+
   PassT Pass;
 };
 
diff --git a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
index f55022fbff07c1..33ebbd711bec04 100644
--- a/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopPassManager.h
@@ -134,6 +134,10 @@ class PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
 
   static bool isRequired() { return true; }
 
+  /// Erase the passes selected by \p Pred from both the loop-pass and the
+  /// loop-nest-pass lists, keeping \c IsLoopNestPass in step with what
+  /// remains.
+  /// For internal use only!
+  void eraseIf(function_ref<bool(StringRef)> Pred);
+
   size_t getNumLoopPasses() const { return LoopPasses.size(); }
   size_t getNumLoopNestPasses() const { return LoopNestPasses.size(); }
 
@@ -425,6 +429,10 @@ class FunctionToLoopPassAdaptor
 
   static bool isRequired() { return true; }
 
+  /// True once this adaptor no longer holds a wrapped pass (\c Pass is null).
+  bool isEmpty() const { return !Pass; }
+
+  /// Erase the wrapped pass when \p Pred accepts its name. If the wrapped pass
+  /// is itself a pass manager or adaptor, \p Pred is forwarded to it instead
+  /// and the wrapped pass is dropped only if it ends up empty.
+  void eraseIf(function_ref<bool(StringRef)> Pred);
+
   bool isLoopNestMode() const { return LoopNestMode; }
 
 private:
diff --git a/llvm/lib/Analysis/CGSCCPassManager.cpp b/llvm/lib/Analysis/CGSCCPassManager.cpp
index 948bc2435ab275..d87ed86655ab30 100644
--- a/llvm/lib/Analysis/CGSCCPassManager.cpp
+++ b/llvm/lib/Analysis/CGSCCPassManager.cpp
@@ -576,6 +576,17 @@ PreservedAnalyses CGSCCToFunctionPassAdaptor::run(LazyCallGraph::SCC &C,
   return PA;
 }
 
+/// Erase the wrapped pass, or passes nested inside it, according to \p Pred.
+///
+/// A wrapped pass whose name contains "PassManager" or ends with "PassAdaptor"
+/// is treated as a container: \p Pred is forwarded to it and the container is
+/// dropped only if that leaves it empty. Any other pass is dropped when
+/// \p Pred returns true for its name. Dropping resets \c Pass to null.
+void CGSCCToFunctionPassAdaptor::eraseIf(function_ref<bool(StringRef)> Pred) {
+  // An earlier call may already have erased the wrapped pass; without this
+  // guard the name lookup below would dereference a null pointer.
+  if (!Pass)
+    return;
+
+  StringRef PassName = Pass->name();
+  if (PassName.contains("PassManager") || PassName.ends_with("PassAdaptor")) {
+    Pass->eraseIf(Pred);
+    if (Pass->isEmpty())
+      Pass.reset();
+  } else if (Pred(PassName)) {
+    Pass.reset();
+  }
+}
+
 bool CGSCCAnalysisManagerModuleProxy::Result::invalidate(
     Module &M, const PreservedAnalyses &PA,
     ModuleAnalysisManager::Invalidator &Inv) {
@@ -751,6 +762,18 @@ bool FunctionAnalysisManagerCGSCCProxy::Result::invalidate(
   return false;
 }
 
+/// Erase the wrapped pass, or passes nested inside it, according to \p Pred.
+///
+/// A wrapped pass whose name contains "PassManager" or ends with "PassAdaptor"
+/// is treated as a container: \p Pred is forwarded to it and the container is
+/// dropped only if that leaves it empty. Any other pass is dropped when
+/// \p Pred returns true for its name. Dropping resets \c Pass to null.
+void ModuleToPostOrderCGSCCPassAdaptor::eraseIf(
+    function_ref<bool(StringRef)> Pred) {
+  // An earlier call may already have erased the wrapped pass; without this
+  // guard the name lookup below would dereference a null pointer.
+  if (!Pass)
+    return;
+
+  StringRef PassName = Pass->name();
+  if (PassName.contains("PassManager") || PassName.ends_with("PassAdaptor")) {
+    Pass->eraseIf(Pred);
+    if (Pass->isEmpty())
+      Pass.reset();
+  } else if (Pred(PassName)) {
+    Pass.reset();
+  }
+}
+
 } // end namespace llvm
 
 /// When a new SCC is created for the graph we first update the
diff --git a/llvm/lib/IR/PassManager.cpp b/llvm/lib/IR/PassManager.cpp
index bf8f7906d3368d..0c9ee99f24dd19 100644
--- a/llvm/lib/IR/PassManager.cpp
+++ b/llvm/lib/IR/PassManager.cpp
@@ -145,6 +145,17 @@ PreservedAnalyses ModuleToFunctionPassAdaptor::run(Module &M,
   return PA;
 }
 
+/// Erase the wrapped pass, or passes nested inside it, according to \p Pred.
+///
+/// A wrapped pass whose name contains "PassManager" or ends with "PassAdaptor"
+/// is treated as a container: \p Pred is forwarded to it and the container is
+/// dropped only if that leaves it empty. Any other pass is dropped when
+/// \p Pred returns true for its name. Dropping resets \c Pass to null.
+void ModuleToFunctionPassAdaptor::eraseIf(function_ref<bool(StringRef)> Pred) {
+  // An earlier call may already have erased the wrapped pass; without this
+  // guard the name lookup below would dereference a null pointer.
+  if (!Pass)
+    return;
+
+  StringRef PassName = Pass->name();
+  if (PassName.contains("PassManager") || PassName.ends_with("PassAdaptor")) {
+    Pass->eraseIf(Pred);
+    if (Pass->isEmpty())
+      Pass.reset();
+  } else if (Pred(PassName)) {
+    Pass.reset();
+  }
+}
+
 template <>
 void llvm::printIRUnitNameForStackTrace<Module>(raw_ostream &OS,
                                                 const Module &IR) {
diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
index 9f4270f5d62f5c..081303c2dda4c9 100644
--- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -62,6 +62,39 @@ void PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
   }
 }
 
+/// Erase the loop passes and loop-nest passes selected by \p Pred.
+///
+/// \c IsLoopNestPass holds one flag per stored pass, in pipeline order, saying
+/// whether that pass lives in \c LoopNestPasses (true) or \c LoopPasses
+/// (false). The three containers are walked in lock-step so the flags stay in
+/// sync with the passes that remain.
+///
+/// The per-pass policy matches the generic PassManager::eraseIf: each pass is
+/// first asked to prune whatever it contains; a pass whose name ends with
+/// "PassAdaptor" or contains "PassManager" is removed only when it reports
+/// itself empty, and any other pass is removed when \p Pred accepts its name.
+void PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
+                 LPMUpdater &>::eraseIf(function_ref<bool(StringRef)> Pred) {
+  // Shared decision for both pass lists; generic because the two lists store
+  // different concept types.
+  auto ShouldErase = [&Pred](auto &P) {
+    P.eraseIf(Pred);
+    StringRef Name = P.name();
+    if (Name.ends_with("PassAdaptor") || Name.contains("PassManager"))
+      return P.isEmpty();
+    return Pred(Name);
+  };
+
+  // Flags of the surviving passes, built in a single forward pass instead of
+  // erasing from the middle of a temporary copy.
+  decltype(IsLoopNestPass) Remaining;
+
+  auto ILP = LoopPasses.begin();
+  auto ILNP = LoopNestPasses.begin();
+  for (unsigned Idx = 0, Sz = IsLoopNestPass.size(); Idx != Sz; ++Idx) {
+    const bool IsNest = IsLoopNestPass[Idx];
+    if (IsNest) {
+      if (ShouldErase(**ILNP)) {
+        ILNP = LoopNestPasses.erase(ILNP);
+        continue;
+      }
+      ++ILNP;
+    } else {
+      if (ShouldErase(**ILP)) {
+        ILP = LoopPasses.erase(ILP);
+        continue;
+      }
+      ++ILP;
+    }
+    Remaining.push_back(IsNest);
+  }
+
+  IsLoopNestPass = Remaining;
+}
+
 // Run both loop passes and loop-nest passes on top-level loop \p L.
 PreservedAnalyses
 LoopPassManager::runWithLoopNestPasses(Loop &L, LoopAnalysisManager &AM,
@@ -359,6 +392,17 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F,
   return PA;
 }
 
+/// Erase the wrapped pass, or passes nested inside it, according to \p Pred.
+///
+/// A wrapped pass whose name contains "PassManager" or ends with "PassAdaptor"
+/// is treated as a container: \p Pred is forwarded to it and the container is
+/// dropped only if that leaves it empty. Any other pass is dropped when
+/// \p Pred returns true for its name. Dropping resets \c Pass to null.
+void FunctionToLoopPassAdaptor::eraseIf(function_ref<bool(StringRef)> Pred) {
+  // An earlier call may already have erased the wrapped pass; without this
+  // guard the name lookup below would dereference a null pointer.
+  if (!Pass)
+    return;
+
+  StringRef PassName = Pass->name();
+  if (PassName.contains("PassManager") || PassName.ends_with("PassAdaptor")) {
+    Pass->eraseIf(Pred);
+    if (Pass->isEmpty())
+      Pass.reset();
+  } else if (Pred(PassName)) {
+    Pass.reset();
+  }
+}
+
 PrintLoopPass::PrintLoopPass() : OS(dbgs()) {}
 PrintLoopPass::PrintLoopPass(raw_ostream &OS, const std::string &Banner)
     : OS(OS), Banner(Banner) {}



More information about the llvm-commits mailing list