[llvm] [Pass] Support start/stop in instrumentation (PR #70912)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 1 01:15:35 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: None (paperchalice)

<details>
<summary>Changes</summary>

Add a special hook to support `-start-*=<passes>` and `-stop-*=<passes>` in `llc`. If we kept using `registerShouldRunOptionalPassCallback`, `llc` could not skip some passes, since that callback is only consulted for optional passes.
Part of #69879.
NOTE: This hook takes precedence over `PassName::isRequired()`, i.e. it can skip a pass even when that pass is marked as required.

---
Full diff: https://github.com/llvm/llvm-project/pull/70912.diff


3 Files Affected:

- (modified) llvm/include/llvm/IR/PassInstrumentation.h (+33) 
- (modified) llvm/lib/CodeGen/TargetPassConfig.cpp (+16-3) 
- (modified) llvm/lib/Passes/StandardInstrumentations.cpp (+4-3) 


``````````diff
diff --git a/llvm/include/llvm/IR/PassInstrumentation.h b/llvm/include/llvm/IR/PassInstrumentation.h
index 519a5e46b4373b7..3f70fcf180af81a 100644
--- a/llvm/include/llvm/IR/PassInstrumentation.h
+++ b/llvm/include/llvm/IR/PassInstrumentation.h
@@ -84,6 +84,23 @@ class PassInstrumentationCallbacks {
   using AfterAnalysisFunc = void(StringRef, Any);
   using AnalysisInvalidatedFunc = void(StringRef, Any);
   using AnalysesClearedFunc = void(StringRef);
+  using StartStopFunc = bool(StringRef, Any);
+
+  struct CodeGenStartStopInfo {
+    StringRef Start;
+    StringRef Stop;
+
+    bool IsStopMachinePass = false;
+
+    llvm::unique_function<StartStopFunc> StartStopCallback;
+
+    bool operator()(StringRef PassID, Any IR) {
+      return StartStopCallback(PassID, IR);
+    }
+    bool isStopMachineFunctionPass() const { return IsStopMachinePass; }
+    bool willCompleteCodeGenPipeline() const { return Stop.empty(); }
+    StringRef getStop() const { return Stop; }
+  };
 
 public:
   PassInstrumentationCallbacks() = default;
@@ -148,6 +165,17 @@ class PassInstrumentationCallbacks {
     AnalysesClearedCallbacks.emplace_back(std::move(C));
   }
 
+  void registerStartStopInfo(CodeGenStartStopInfo &&C) {
+    StartStopInfo = std::move(C);
+  }
+
+  bool isStartStopInfoRegistered() const { return StartStopInfo.has_value(); }
+
+  CodeGenStartStopInfo &getStartStopInfo() {
+    assert(StartStopInfo.has_value() && "StartStopInfo is unregistered!");
+    return *StartStopInfo;
+  }
+
   /// Add a class name to pass name mapping for use by pass instrumentation.
   void addClassToPassName(StringRef ClassName, StringRef PassName);
   /// Get the pass name for a given pass class name.
@@ -183,6 +211,8 @@ class PassInstrumentationCallbacks {
   /// These are run on analyses that have been cleared.
   SmallVector<llvm::unique_function<AnalysesClearedFunc>, 4>
       AnalysesClearedCallbacks;
+  /// For `llc` -start-* -stop-* options.
+  std::optional<CodeGenStartStopInfo> StartStopInfo;
 
   StringMap<std::string> ClassToPassName;
 };
@@ -236,6 +266,9 @@ class PassInstrumentation {
         ShouldRun &= C(Pass.name(), llvm::Any(&IR));
     }
 
+    if (Callbacks->StartStopInfo)
+      ShouldRun &= (*Callbacks->StartStopInfo)(Pass.name(), llvm::Any(&IR));
+
     if (ShouldRun) {
       for (auto &C : Callbacks->BeforeNonSkippedPassCallbacks)
         C(Pass.name(), llvm::Any(&IR));
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index e6ecbc9b03f7149..bea8590a14080ca 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -508,6 +508,9 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
   unsigned StopBeforeInstanceNum = 0;
   unsigned StopAfterInstanceNum = 0;
 
+  bool IsStopBeforeMachinePass = false;
+  bool IsStopAfterMachinePass = false;
+
   std::tie(StartBefore, StartBeforeInstanceNum) =
       getPassNameAndInstanceNum(StartBeforeOpt);
   std::tie(StartAfter, StartAfterInstanceNum) =
@@ -536,7 +539,15 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
     report_fatal_error(Twine(StopBeforeOptName) + Twine(" and ") +
                        Twine(StopAfterOptName) + Twine(" specified!"));
 
-  PIC.registerShouldRunOptionalPassCallback(
+  std::vector<StringRef> SpecialPasses = {"PassManager", "PassAdaptor",
+                                          "PrintMIRPass", "PrintModulePass"};
+
+  PassInstrumentationCallbacks::CodeGenStartStopInfo Info;
+  Info.Start = StartBefore.empty() ? StartAfter : StartBefore;
+  Info.Stop = StopBefore.empty() ? StopAfter : StopBefore;
+
+  Info.IsStopMachinePass = IsStopBeforeMachinePass || IsStopAfterMachinePass;
+  Info.StartStopCallback =
       [=, EnableCurrent = StartBefore.empty() && StartAfter.empty(),
        EnableNext = std::optional<bool>(), StartBeforeCount = 0u,
        StartAfterCount = 0u, StopBeforeCount = 0u,
@@ -567,8 +578,10 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
           EnableCurrent = true;
         if (StopBeforePass && StopBeforeCount++ == StopBeforeInstanceNum)
           EnableCurrent = false;
-        return EnableCurrent;
-      });
+        return EnableCurrent || isSpecialPass(P, SpecialPasses);
+      };
+
+  PIC.registerStartStopInfo(std::move(Info));
 }
 
 void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
diff --git a/llvm/lib/Passes/StandardInstrumentations.cpp b/llvm/lib/Passes/StandardInstrumentations.cpp
index 06cc58c0219632d..c6ec1c34d74418d 100644
--- a/llvm/lib/Passes/StandardInstrumentations.cpp
+++ b/llvm/lib/Passes/StandardInstrumentations.cpp
@@ -1036,9 +1036,10 @@ void PrintPassInstrumentation::registerCallbacks(
     SpecialPasses.emplace_back("PassAdaptor");
   }
 
-  PIC.registerBeforeSkippedPassCallback([this, SpecialPasses](StringRef PassID,
-                                                              Any IR) {
-    assert(!isSpecialPass(PassID, SpecialPasses) &&
+  PIC.registerBeforeSkippedPassCallback([this, SpecialPasses,
+                                         &PIC](StringRef PassID, Any IR) {
+    assert((!isSpecialPass(PassID, SpecialPasses) ||
+            PIC.isStartStopInfoRegistered()) &&
            "Unexpectedly skipping special pass");
 
     print() << "Skipping pass: " << PassID << " on " << getIRName(IR) << "\n";

``````````

</details>


https://github.com/llvm/llvm-project/pull/70912


More information about the llvm-commits mailing list