[Mlir-commits] [mlir] f65a3f7 - Make MLIR Pass Timing output configurable through injection

Mehdi Amini llvmlistbot at llvm.org
Mon Apr 27 18:39:38 PDT 2020


Author: Mehdi Amini
Date: 2020-04-28T01:39:25Z
New Revision: f65a3f7c83b29fe76e91abd88b34d6b8755b0e0e

URL: https://github.com/llvm/llvm-project/commit/f65a3f7c83b29fe76e91abd88b34d6b8755b0e0e
DIFF: https://github.com/llvm/llvm-project/commit/f65a3f7c83b29fe76e91abd88b34d6b8755b0e0e.diff

LOG: Make MLIR Pass Timing output configurable through injection

This makes it possible for the client to control where the pass timings will
be printed.

Differential Revision: https://reviews.llvm.org/D78891

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/PassManagerOptions.cpp
    mlir/lib/Pass/PassTiming.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 3117d0f33dc6..98101b82f542 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -229,12 +229,38 @@ class PassManager : public OpPassManager {
   //===--------------------------------------------------------------------===//
   // Pass Timing
 
+  /// A configuration object provided to the pass timing feature.
+  class PassTimingConfig {
+  public:
+    using PrintCallbackFn = function_ref<void(raw_ostream &)>;
+
+    /// Initialize the configuration.
+    /// * 'displayMode' selects between list and pipeline display (see the
+    /// `PassDisplayMode` enum documentation).
+    explicit PassTimingConfig(
+        PassDisplayMode displayMode = PassDisplayMode::Pipeline)
+        : displayMode(displayMode) {}
+
+    virtual ~PassTimingConfig();
+
+    /// Hook a derived config may override to choose the output stream. The
+    /// framework-supplied callback prints (and then resets) the timing data,
+    /// so an override should invoke it exactly once with the target stream.
+    virtual void printTiming(PrintCallbackFn printCallback);
+
+    /// Return the `PassDisplayMode` this config was created with.
+    PassDisplayMode getDisplayMode() const { return displayMode; }
+
+  private:
+    PassDisplayMode displayMode;
+  };
+
   /// Add an instrumentation to time the execution of passes and the computation
   /// of analyses.
   /// Note: Timing should be enabled after all other instrumentations to avoid
   /// any potential "ghost" timing from other instrumentations being
   /// unintentionally included in the timing results.
-  void enableTiming(PassDisplayMode displayMode = PassDisplayMode::Pipeline);
+  void enableTiming(std::unique_ptr<PassTimingConfig> config = nullptr);
 
   /// Prompts the pass manager to print the statistics collected for each of the
   /// held passes after each call to 'run'.

diff  --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index e0c4df56cf72..953faa28c28a 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -141,7 +141,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
 /// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
 void PassManagerOptions::addTimingInstrumentation(PassManager &pm) {
   if (passTiming)
-    pm.enableTiming(passTimingDisplayMode);
+    pm.enableTiming(std::make_unique<PassManager::PassTimingConfig>(
+        passTimingDisplayMode));
 }
 
 void mlir::registerPassManagerCLOptions() {

diff  --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index 663cbdad7c39..c8f0ad8afa50 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -160,7 +160,8 @@ struct Timer {
 };
 
 struct PassTiming : public PassInstrumentation {
-  PassTiming(PassDisplayMode displayMode) : displayMode(displayMode) {}
+  PassTiming(std::unique_ptr<PassManager::PassTimingConfig> config)
+      : config(std::move(config)) {}
   ~PassTiming() override { print(); }
 
   /// Setup the instrumentation hooks.
@@ -231,8 +232,8 @@ struct PassTiming : public PassInstrumentation {
   /// A stack of the currently active pass timers per thread.
   DenseMap<uint64_t, SmallVector<Timer *, 4>> activeThreadTimers;
 
-  /// The display mode to use when printing the timing results.
-  PassDisplayMode displayMode;
+  /// The configuration object to use when printing the timing results.
+  std::unique_ptr<PassManager::PassTimingConfig> config;
 
   /// A mapping of pipeline timers that need to be merged into the parent
   /// collection. The timers are mapped to the parent info to merge into.
@@ -353,28 +354,37 @@ void PassTiming::print() {
     return;
 
   assert(rootTimers.size() == 1 && "expected one remaining root timer");
-  auto &rootTimer = rootTimers.begin()->second;
-  auto os = llvm::CreateInfoOutputFile();
-
-  // Print the timer header.
-  TimeRecord totalTime = rootTimer->getTotalTime();
-  printTimerHeader(*os, totalTime);
-
-  // Defer to a specialized printer for each display mode.
-  switch (displayMode) {
-  case PassDisplayMode::List:
-    printResultsAsList(*os, rootTimer.get(), totalTime);
-    break;
-  case PassDisplayMode::Pipeline:
-    printResultsAsPipeline(*os, rootTimer.get(), totalTime);
-    break;
-  }
-  printTimeEntry(*os, 0, "Total", totalTime, totalTime);
-  os->flush();
 
-  // Reset root timers.
-  rootTimers.clear();
-  activeThreadTimers.clear();
+  auto printCallback = [&](raw_ostream &os) {
+    auto &rootTimer = rootTimers.begin()->second;
+    // Print the timer header.
+    TimeRecord totalTime = rootTimer->getTotalTime();
+    printTimerHeader(os, totalTime);
+    // Defer to a specialized printer for each display mode.
+    switch (config->getDisplayMode()) {
+    case PassDisplayMode::List:
+      printResultsAsList(os, rootTimer.get(), totalTime);
+      break;
+    case PassDisplayMode::Pipeline:
+      printResultsAsPipeline(os, rootTimer.get(), totalTime);
+      break;
+    }
+    printTimeEntry(os, 0, "Total", totalTime, totalTime);
+    os.flush();
+
+    // Reset root timers.
+    rootTimers.clear();
+    activeThreadTimers.clear();
+  };
+
+  config->printTiming(printCallback);
+}
+
+// The default implementation of printTiming prints to the stream created by
+// `llvm::CreateInfoOutputFile()`; clients can override it to redirect the
+// output elsewhere.
+void PassManager::PassTimingConfig::printTiming(PrintCallbackFn printCallback) {
+  printCallback(*llvm::CreateInfoOutputFile());
 }
 
 /// Print the timing result in list mode.
@@ -449,16 +459,21 @@ void PassTiming::printResultsAsPipeline(raw_ostream &os, Timer *root,
     printTimer(0, topLevelTimer.second.get());
 }
 
+// Defined out-of-line so it acts as the key function anchoring the vtable.
+PassManager::PassTimingConfig::~PassTimingConfig() {}
+
 //===----------------------------------------------------------------------===//
 // PassManager
 //===----------------------------------------------------------------------===//
 
 /// Add an instrumentation to time the execution of passes and the computation
 /// of analyses.
-void PassManager::enableTiming(PassDisplayMode displayMode) {
+void PassManager::enableTiming(std::unique_ptr<PassTimingConfig> config) {
   // Check if pass timing is already enabled.
   if (passTiming)
     return;
-  addInstrumentation(std::make_unique<PassTiming>(displayMode));
+  auto timingConfig =
+      config ? std::move(config) : std::make_unique<PassTimingConfig>();
+  addInstrumentation(std::make_unique<PassTiming>(std::move(timingConfig)));
   passTiming = true;
 }


        


More information about the Mlir-commits mailing list