[Mlir-commits] [mlir] cca5106 - Refactor the mlir-opt command line options related to debugging in a helper

Mehdi Amini llvmlistbot at llvm.org
Mon Apr 24 14:34:37 PDT 2023


Author: Mehdi Amini
Date: 2023-04-24T14:34:15-07:00
New Revision: cca510640bf0aa3ef356a8ad51652de16b5a557a

URL: https://github.com/llvm/llvm-project/commit/cca510640bf0aa3ef356a8ad51652de16b5a557a
DIFF: https://github.com/llvm/llvm-project/commit/cca510640bf0aa3ef356a8ad51652de16b5a557a.diff

LOG: Refactor the mlir-opt command line options related to debugging in a helper

This makes it reusable across various tooling and reduces the amount of
boilerplate needed.

Differential Revision: https://reviews.llvm.org/D144818

Added: 
    mlir/include/mlir/Debug/CLOptionsSetup.h
    mlir/lib/Debug/CLOptionsSetup.cpp

Modified: 
    mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
    mlir/lib/Debug/CMakeLists.txt
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Debug/CLOptionsSetup.h b/mlir/include/mlir/Debug/CLOptionsSetup.h
new file mode 100644
index 0000000000000..ab8e7b436a1e9
--- /dev/null
+++ b/mlir/include/mlir/Debug/CLOptionsSetup.h
@@ -0,0 +1,93 @@
+//===- CLOptionsSetup.h - Helpers to setup debug CL options -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEBUG_CLOPTIONSSETUP_H
+#define MLIR_DEBUG_CLOPTIONSSETUP_H
+
+#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+namespace tracing {
+class BreakpointManager;
+
+/// Configuration of the debugging facilities (debugger action hook and action
+/// logging) that a tool can install on an MLIRContext. A config is a plain
+/// value: it can be filled in through the setters below, or obtained from the
+/// registered command line options with `createFromCLOptions()`.
+class DebugConfig {
+public:
+  /// Register the options as global LLVM command line options.
+  static void registerCLOptions();
+
+  /// Create a new config with the default set from the CL options.
+  static DebugConfig createFromCLOptions();
+
+  ///
+  /// Options.
+  ///
+
+  /// Enable the Debugger action hook: it makes a debugger (like gdb or lldb)
+  /// able to intercept MLIR Actions. Returns this config so that setters can
+  /// be chained.
+  DebugConfig &enableDebuggerActionHook(bool enabled = true) {
+    enableDebuggerActionHookFlag = enabled;
+    return *this;
+  }
+
+  /// Return true if the debugger action hook is enabled.
+  bool isDebuggerActionHookEnabled() const {
+    return enableDebuggerActionHookFlag;
+  }
+
+  /// Set the filename to use for logging actions, use "-" for stdout. Returns
+  /// this config so that setters can be chained.
+  DebugConfig &logActionsTo(StringRef filename) {
+    logActionsToFlag = filename;
+    return *this;
+  }
+  /// Get the filename to use for logging actions (empty when none was set).
+  StringRef getLogActionsTo() const { return logActionsToFlag; }
+
+  /// Append a location breakpoint manager used to filter action logging based
+  /// on the attached IR location in the Action context. Only the pointer is
+  /// stored: ownership stays with the caller. Returns this config so that
+  /// setters can be chained.
+  DebugConfig &
+  addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
+    logActionLocationFilter.push_back(breakpointManager);
+    return *this;
+  }
+
+  /// Get the location breakpoint managers to use to filter out action logging.
+  ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
+    return logActionLocationFilter;
+  }
+
+protected:
+  /// Enable the Debugger action hook: a debugger (like gdb or lldb) can
+  /// intercept MLIR Actions.
+  bool enableDebuggerActionHookFlag = false;
+
+  /// Log action execution to the given file (or "-" for stdout)
+  std::string logActionsToFlag;
+
+  /// Location Breakpoints to filter the action logging (non-owning).
+  std::vector<tracing::BreakpointManager *> logActionLocationFilter;
+};
+
+/// This is a RAII class that installs the debug handlers on the context
+/// based on the provided configuration.
+/// All state needed by the installed handlers lives in a private
+/// implementation object owned by this class and released when it is
+/// destroyed.
+class InstallDebugHandler {
+public:
+  /// Install on `context` the handlers selected by `config`.
+  InstallDebugHandler(MLIRContext &context, const DebugConfig &config);
+  /// Declared here and defined out-of-line: `Impl` is an incomplete type in
+  /// this header.
+  ~InstallDebugHandler();
+
+private:
+  /// Opaque implementation; only forward-declared in this header.
+  class Impl;
+  std::unique_ptr<Impl> impl;
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_DEBUG_CLOPTIONSSETUP_H

diff  --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index d393b24252f7f..02b0eb098f954 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -13,7 +13,7 @@
 #ifndef MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
 #define MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
 
-#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
+#include "mlir/Debug/CLOptionsSetup.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/StringRef.h"
 
@@ -30,9 +30,6 @@ namespace mlir {
 class DialectRegistry;
 class PassPipelineCLParser;
 class PassManager;
-namespace tracing {
-class FileLineColLocBreakpointManager;
-}
 
 /// Configuration options for the mlir-opt tool.
 /// This is intended to help building tools like mlir-opt by collecting the
@@ -64,6 +61,14 @@ class MlirOptMainConfig {
     return allowUnregisteredDialectsFlag;
   }
 
+  /// Set the debug configuration to use.
+  MlirOptMainConfig &setDebugConfig(tracing::DebugConfig config) {
+    debugConfig = std::move(config);
+    return *this;
+  }
+  tracing::DebugConfig &getDebugConfig() { return debugConfig; }
+  const tracing::DebugConfig &getDebugConfig() const { return debugConfig; }
+
   /// Print the pass-pipeline as text before executing.
   MlirOptMainConfig &dumpPassPipeline(bool dump) {
     dumpPassPipelineFlag = dump;
@@ -78,17 +83,6 @@ class MlirOptMainConfig {
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
-  /// Enable the debugger action hook: it makes the debugger able to intercept
-  /// MLIR Actions.
-  void enableDebuggerActionHook(bool enabled = true) {
-    enableDebuggerActionHookFlag = enabled;
-  }
-
-  /// Return true if the Debugger action hook is enabled.
-  bool isDebuggerActionHookEnabled() const {
-    return enableDebuggerActionHookFlag;
-  }
-
   /// Set the IRDL file to load before processing the input.
   MlirOptMainConfig &setIrdlFile(StringRef file) {
     irdlFileFlag = file;
@@ -96,26 +90,6 @@ class MlirOptMainConfig {
   }
   StringRef getIrdlFile() const { return irdlFileFlag; }
 
-  /// Set the filename to use for logging actions, use "-" for stdout.
-  MlirOptMainConfig &logActionsTo(StringRef filename) {
-    logActionsToFlag = filename;
-    return *this;
-  }
-  /// Get the filename to use for logging actions.
-  StringRef getLogActionsTo() const { return logActionsToFlag; }
-
-  /// Set a location breakpoint manager to filter out action logging based on
-  /// the attached IR location in the Action context. Ownership stays with the
-  /// caller.
-  void addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
-    logActionLocationFilter.push_back(breakpointManager);
-  }
-
-  /// Get the location breakpoint managers to use to filter out action logging.
-  ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
-    return logActionLocationFilter;
-  }
-
   /// Set the callback to populate the pass manager.
   MlirOptMainConfig &
   setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -185,6 +159,9 @@ class MlirOptMainConfig {
   /// general.
   bool allowUnregisteredDialectsFlag = false;
 
+  /// Configuration for the debugging hooks.
+  tracing::DebugConfig debugConfig;
+
   /// Print the pipeline that will be run.
   bool dumpPassPipelineFlag = false;
 
@@ -197,9 +174,6 @@ class MlirOptMainConfig {
   /// IRDL file to register before processing the input.
   std::string irdlFileFlag = "";
 
-  /// Log action execution to the given file (or "-" for stdout)
-  std::string logActionsToFlag;
-
   /// Location Breakpoints to filter the action logging.
   std::vector<tracing::BreakpointManager *> logActionLocationFilter;
 

diff  --git a/mlir/lib/Debug/CLOptionsSetup.cpp b/mlir/lib/Debug/CLOptionsSetup.cpp
new file mode 100644
index 0000000000000..96e6ddc32be9b
--- /dev/null
+++ b/mlir/lib/Debug/CLOptionsSetup.cpp
@@ -0,0 +1,120 @@
+//===- CLOptionsSetup.cpp - Helpers to setup debug CL options ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/CLOptionsSetup.h"
+
+#include "mlir/Debug/Counter.h"
+#include "mlir/Debug/DebuggerExecutionContextHook.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+using namespace mlir::tracing;
+using namespace llvm;
+
+namespace {
+/// Binds the debug-related command line flags to the `DebugConfig` fields
+/// inherited by this struct. The flags are function-local statics of the
+/// constructor, so they are registered the first time an instance is built;
+/// the callbacks capture `this`, which assumes the instance outlives command
+/// line parsing (it is held in a ManagedStatic in this file).
+struct DebugConfigCLOptions : public DebugConfig {
+  DebugConfigCLOptions() {
+    // `--log-actions-to` writes directly into the inherited
+    // `logActionsToFlag` through external storage.
+    static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
+        "log-actions-to",
+        cl::desc("Log action execution to a file, or stdout if "
+                 "'-' is passed"),
+        cl::location(logActionsToFlag)};
+
+    static cl::list<std::string> logActionLocationFilter(
+        "log-mlir-actions-filter",
+        cl::desc(
+            "Comma separated list of locations to filter actions from logging"),
+        cl::CommaSeparated,
+        cl::cb<void, std::string>([&](const std::string &location) {
+          // Attach the breakpoint manager to the config only once, the first
+          // time a filter is actually provided.
+          static bool register_once = [&] {
+            addLogActionLocFilter(&locBreakpointManager);
+            return true;
+          }();
+          (void)register_once;
+          // Own each filter string at a stable heap address. A plain
+          // std::vector<std::string> relocates its elements when it grows,
+          // which would invalidate any StringRef previously derived from a
+          // short (inline-stored) string.
+          // NOTE(review): this presumes the breakpoint manager keeps
+          // referencing the parsed file name rather than copying it, which is
+          // what the static storage here suggests - confirm.
+          static std::vector<std::unique_ptr<std::string>> locations;
+          locations.push_back(std::make_unique<std::string>(location));
+          StringRef locStr = *locations.back();
+
+          // Parse the individual location filters and set the breakpoints.
+          auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
+          auto locBreakpoint =
+              tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
+          if (failed(locBreakpoint)) {
+            llvm::errs() << "Invalid location filter: " << locStr << "\n";
+            exit(1);
+          }
+          auto [file, line, col] = *locBreakpoint;
+          locBreakpointManager.addBreakpoint(file, line, col);
+        }));
+  }
+  /// The breakpoint manager for the log action location filter.
+  tracing::FileLineColLocBreakpointManager locBreakpointManager;
+};
+
+} // namespace
+
+/// File-local holder of the command-line-bound debug options.
+static ManagedStatic<DebugConfigCLOptions> clOptionsConfig;
+
+void DebugConfig::registerCLOptions() {
+  // Accessing the managed object is all that is needed: the flags are
+  // declared in the DebugConfigCLOptions constructor.
+  (void)*clOptionsConfig;
+}
+
+DebugConfig DebugConfig::createFromCLOptions() {
+  // Return a copy of just the DebugConfig base of the CL-bound object.
+  const DebugConfig &fromCommandLine = *clOptionsConfig;
+  return fromCommandLine;
+}
+
+/// Private state behind `InstallDebugHandler`: performs the installation in
+/// its constructor and owns the objects the installed handlers rely on.
+class InstallDebugHandler::Impl {
+public:
+  /// Install the handlers requested by `config` on `context`.
+  ///
+  /// When neither action logging nor the debugger hook is requested, only the
+  /// debug counters are registered (if activated). Otherwise an execution
+  /// context is set up with the requested observers/hooks and registered as
+  /// the action handler of `context`; a notice is printed on stderr.
+  Impl(MLIRContext &context, const DebugConfig &config) {
+    if (config.getLogActionsTo().empty() &&
+        !config.isDebuggerActionHookEnabled()) {
+      if (tracing::DebugCounter::isActivated())
+        context.registerActionHandler(tracing::DebugCounter());
+      return;
+    }
+    errs() << "ExecutionContext registered on the context";
+    // Debug counters are not registered on this path; tell the user.
+    if (tracing::DebugCounter::isActivated())
+      emitError(UnknownLoc::get(&context),
+                "Debug counters are incompatible with --log-actions-to and "
+                "--mlir-enable-debugger-hook options and are disabled");
+    if (!config.getLogActionsTo().empty()) {
+      std::string errorMessage;
+      logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
+      if (!logActionsFile) {
+        // Nothing gets registered on the context when the log file cannot be
+        // opened.
+        emitError(UnknownLoc::get(&context),
+                  "Opening file for --log-actions-to failed: ")
+            << errorMessage << "\n";
+        return;
+      }
+      logActionsFile->keep();
+      raw_fd_ostream &logActionsStream = logActionsFile->os();
+      actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
+      for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
+        actionLogger->addBreakpointManager(locationBreakpoint);
+      executionContext.registerObserver(actionLogger.get());
+    }
+    if (config.isDebuggerActionHookEnabled()) {
+      errs() << " (with Debugger hook)";
+      setupDebuggerExecutionContextHook(executionContext);
+    }
+    errs() << "\n";
+    context.registerActionHandler(executionContext);
+  }
+
+private:
+  /// Destination of the action log, when logging is requested.
+  std::unique_ptr<ToolOutputFile> logActionsFile;
+  /// Execution context registered as the action handler on the context.
+  tracing::ExecutionContext executionContext;
+  /// Observer writing the actions to `logActionsFile`.
+  std::unique_ptr<tracing::ActionLogger> actionLogger;
+  /// Not referenced anywhere in this class at the moment.
+  std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
+      locationBreakpoints;
+};
+
+/// Build the private implementation, whose constructor performs the actual
+/// installation of the handlers on `context`.
+InstallDebugHandler::InstallDebugHandler(MLIRContext &context,
+                                         const DebugConfig &config)
+    : impl(std::make_unique<Impl>(context, config)) {}
+
+/// Defaulted here rather than in the header: `Impl` is only a complete type
+/// in this file, and destroying the `std::unique_ptr<Impl>` needs it.
+InstallDebugHandler::~InstallDebugHandler() = default;

diff  --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt
index b1b0c6e3bca1b..65342beb5bd06 100644
--- a/mlir/lib/Debug/CMakeLists.txt
+++ b/mlir/lib/Debug/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Observers)
 
 add_mlir_library(MLIRDebug
+  CLOptionsSetup.cpp
   DebugCounter.cpp
   ExecutionContext.cpp
   BreakpointManagers/FileLineColLocBreakpointManager.cpp

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index e31c29e9cb105..5112d2459d3f9 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/Debug/CLOptionsSetup.h"
 #include "mlir/Debug/Counter.h"
 #include "mlir/Debug/DebuggerExecutionContextHook.h"
 #include "mlir/Debug/ExecutionContext.h"
@@ -89,39 +90,6 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
                  "parsing"),
         cl::location(useExplicitModuleFlag), cl::init(false));
 
-    static cl::opt<std::string, /*ExternalStorage=*/true> logActionsTo{
-        "log-actions-to",
-        cl::desc("Log action execution to a file, or stderr if "
-                 " '-' is passed"),
-        cl::location(logActionsToFlag)};
-
-    static cl::list<std::string> logActionLocationFilter(
-        "log-mlir-actions-filter",
-        cl::desc(
-            "Comma separated list of locations to filter actions from logging"),
-        cl::CommaSeparated,
-        cl::cb<void, std::string>([&](const std::string &location) {
-          static bool register_once = [&] {
-            addLogActionLocFilter(&locBreakpointManager);
-            return true;
-          }();
-          (void)register_once;
-          static std::vector<std::string> locations;
-          locations.push_back(location);
-          StringRef locStr = locations.back();
-
-          // Parse the individual location filters and set the breakpoints.
-          auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
-          auto locBreakpoint =
-              tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
-          if (failed(locBreakpoint)) {
-            llvm::errs() << "Invalid location filter: " << locStr << "\n";
-            exit(1);
-          }
-          auto [file, line, col] = *locBreakpoint;
-          locBreakpointManager.addBreakpoint(file, line, col);
-        }));
-
     static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
         "show-dialects",
         cl::desc("Print the list of registered dialects and exit"),
@@ -171,9 +139,6 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
   /// Pointer to static dialectPlugins variable in constructor, needed by
   /// setDialectPluginsCallback(DialectRegistry&).
   cl::list<std::string> *dialectPlugins = nullptr;
-
-  /// The breakpoint manager for the log action location filter.
-  tracing::FileLineColLocBreakpointManager locBreakpointManager;
 };
 } // namespace
 
@@ -181,9 +146,11 @@ ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
 
 void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
   clOptionsConfig->setDialectPluginsCallback(registry);
+  tracing::DebugConfig::registerCLOptions();
 }
 
 MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
+  clOptionsConfig->setDebugConfig(tracing::DebugConfig::createFromCLOptions());
   return *clOptionsConfig;
 }
 
@@ -219,53 +186,6 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
   });
 }
 
-/// Set the ExecutionContext on the context and handle the observers.
-class InstallDebugHandler {
-public:
-  InstallDebugHandler(MLIRContext &context, const MlirOptMainConfig &config) {
-    if (config.getLogActionsTo().empty() &&
-        !config.isDebuggerActionHookEnabled()) {
-      if (tracing::DebugCounter::isActivated())
-        context.registerActionHandler(tracing::DebugCounter());
-      return;
-    }
-    llvm::errs() << "ExecutionContext registered on the context";
-    if (tracing::DebugCounter::isActivated())
-      emitError(UnknownLoc::get(&context),
-                "Debug counters are incompatible with --log-actions-to and "
-                "--mlir-enable-debugger-hook options and are disabled");
-    if (!config.getLogActionsTo().empty()) {
-      std::string errorMessage;
-      logActionsFile = openOutputFile(config.getLogActionsTo(), &errorMessage);
-      if (!logActionsFile) {
-        emitError(UnknownLoc::get(&context),
-                  "Opening file for --log-actions-to failed: ")
-            << errorMessage << "\n";
-        return;
-      }
-      logActionsFile->keep();
-      raw_fd_ostream &logActionsStream = logActionsFile->os();
-      actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
-      for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
-        actionLogger->addBreakpointManager(locationBreakpoint);
-      executionContext.registerObserver(actionLogger.get());
-    }
-    if (config.isDebuggerActionHookEnabled()) {
-      llvm::errs() << " (with Debugger hook)";
-      setupDebuggerExecutionContextHook(executionContext);
-    }
-    llvm::errs() << "\n";
-    context.registerActionHandler(executionContext);
-  }
-
-private:
-  std::unique_ptr<llvm::ToolOutputFile> logActionsFile;
-  std::unique_ptr<tracing::ActionLogger> actionLogger;
-  std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
-      locationBreakpoints;
-  tracing::ExecutionContext executionContext;
-};
-
 /// Perform the actions on the input file indicated by the command line flags
 /// within the specified context.
 ///
@@ -386,7 +306,8 @@ static LogicalResult processBuffer(raw_ostream &os,
   if (config.shouldVerifyDiagnostics())
     context.printOpOnDiagnostic(false);
 
-  InstallDebugHandler installDebugHandler(context, config);
+  tracing::InstallDebugHandler installDebugHandler(context,
+                                                   config.getDebugConfig());
 
   // If we are in verify diagnostics mode then we have a lot of work to do,
   // otherwise just perform the actions without worrying about it.


        


More information about the Mlir-commits mailing list