[Mlir-commits] [mlir] d09c805 - Add a breakpoint manager that matches based on File/Line/Col Locations

Mehdi Amini llvmlistbot at llvm.org
Fri Apr 21 21:28:43 PDT 2023


Author: Mehdi Amini
Date: 2023-04-21T22:28:27-06:00
New Revision: d09c80515d0e7b1f1a81d3f18a3e799565f5e969

URL: https://github.com/llvm/llvm-project/commit/d09c80515d0e7b1f1a81d3f18a3e799565f5e969
DIFF: https://github.com/llvm/llvm-project/commit/d09c80515d0e7b1f1a81d3f18a3e799565f5e969.diff

LOG: Add a breakpoint manager that matches based on File/Line/Col Locations

This matches the locations attached to the IRUnits passed in as context
with an action.

Differential Revision: https://reviews.llvm.org/D144815

Added: 
    mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
    mlir/lib/Debug/BreakpointManagers/FileLineColLocBreakpointManager.cpp
    mlir/test/Pass/action-logging-filter.mlir
    mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp

Modified: 
    mlir/include/mlir/Debug/Observers/ActionLogging.h
    mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
    mlir/lib/Debug/CMakeLists.txt
    mlir/lib/Debug/Observers/ActionLogging.cpp
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
    mlir/unittests/Debug/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
new file mode 100644
index 0000000000000..d4f9a6ebe7fa4
--- /dev/null
+++ b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
@@ -0,0 +1,137 @@
+//===- FileLineColLocBreakpointManager.h - Location breakpoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRACING_BREAKPOINTMANAGERS_FILELINECOLLOCBREAKPOINTMANAGER_H
+#define MLIR_TRACING_BREAKPOINTMANAGERS_FILELINECOLLOCBREAKPOINTMANAGER_H
+
+#include "mlir/Debug/BreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/IR/Action.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
+#include <memory>
+#include <optional>
+
+namespace mlir {
+namespace tracing {
+
+/// This breakpoint intends to match a FileLineColLocation, that is a tuple of
+/// file name, line number, and column number. Using -1 for the column and the
+/// line number will match any column and line number respectively. Note that
+/// the manager only looks a location up with the column wildcarded, or with
+/// both the line and the column wildcarded, so a breakpoint with a line of -1
+/// and a specific column never matches.
+///
+/// The file name is held as a StringRef: the breakpoint does not own it, so the
+/// underlying storage must outlive the breakpoint.
+class FileLineColLocBreakpoint
+    : public BreakpointBase<FileLineColLocBreakpoint> {
+public:
+  FileLineColLocBreakpoint(StringRef file, int64_t line, int64_t col)
+      : file(file), line(line), col(col) {}
+
+  /// Print this breakpoint as `Location: <file>:<line>:<col>`.
+  void print(raw_ostream &os) const override {
+    os << "Location: " << file << ':' << line << ':' << col;
+  }
+
+  /// Parse a string of the form "<file>", "<file>:<line>" or
+  /// "<file>:<line>:<col>". Return a tuple with the file, line and column; the
+  /// file is a StringRef pointing into `str`, and a missing line or column is
+  /// returned as -1. On failure, `diag` is invoked with an error message.
+  static FailureOr<std::tuple<StringRef, int64_t, int64_t>> parseFromString(
+      StringRef str, llvm::function_ref<void(Twine)> diag = [](Twine) {});
+
+private:
+  /// A filename on which to break.
+  StringRef file;
+
+  /// A particular line on which to break, or -1 to break on any line.
+  int64_t line;
+
+  /// A particular column on which to break, or -1 to break on any column.
+  int64_t col;
+
+  friend class FileLineColLocBreakpointManager;
+};
+
+/// This breakpoint manager is responsible for matching
+/// FileLineColLocBreakpoint. It'll extract the location from the action context
+/// looking for a FileLineColLocation, and match it against the registered
+/// breakpoints.
+class FileLineColLocBreakpointManager
+    : public BreakpointManagerBase<FileLineColLocBreakpointManager> {
+public:
+  /// Return the first enabled breakpoint matching a location attached to the
+  /// IR units of the action context, or nullptr if there is none. An operation
+  /// contributes its own location, a block the locations of the operations it
+  /// directly contains, and a region the location returned by
+  /// `Region::getLoc()`.
+  Breakpoint *match(const Action &action) const override {
+    for (const IRUnit &unit : action.getContextIRUnits()) {
+      if (auto *op = unit.dyn_cast<Operation *>()) {
+        if (auto match = matchFromLocation(op->getLoc()))
+          return *match;
+        continue;
+      }
+      if (auto *block = unit.dyn_cast<Block *>()) {
+        for (auto &op : block->getOperations()) {
+          if (auto match = matchFromLocation(op.getLoc()))
+            return *match;
+        }
+        continue;
+      }
+      if (Region *region = unit.dyn_cast<Region *>()) {
+        if (auto match = matchFromLocation(region->getLoc()))
+          return *match;
+        continue;
+      }
+    }
+    return nullptr;
+  }
+
+  /// Add a breakpoint on `file`:`line`:`col` (-1 meaning "any") and return it.
+  /// If a breakpoint already exists for this exact key, it is returned as is;
+  /// in particular, a disabled breakpoint is not re-enabled. The map is keyed
+  /// on `file` without copying it, so the storage it points to must outlive
+  /// the manager.
+  FileLineColLocBreakpoint *addBreakpoint(StringRef file, int64_t line,
+                                          int64_t col = -1) {
+    auto &breakpoint = breakpoints[std::make_tuple(file, line, col)];
+    if (!breakpoint)
+      breakpoint = std::make_unique<FileLineColLocBreakpoint>(file, line, col);
+    return breakpoint.get();
+  }
+
+private:
+  /// Walk `initialLoc` and the locations nested in it, and return the first
+  /// enabled breakpoint matching one of the FileLineColLoc encountered, or
+  /// std::nullopt if there is none.
+  std::optional<Breakpoint *> matchFromLocation(Location initialLoc) const {
+    std::optional<Breakpoint *> match = std::nullopt;
+    initialLoc->walk([&](Location loc) {
+      auto fileLoc = loc.dyn_cast<FileLineColLoc>();
+      if (!fileLoc)
+        return WalkResult::advance();
+      StringRef file = fileLoc.getFilename();
+      int64_t line = fileLoc.getLine();
+      int64_t col = fileLoc.getColumn();
+      // Try the keys from the most to the least specific: exact location, any
+      // column on this line, then any line in this file. A disabled breakpoint
+      // does not stop the search.
+      const std::tuple<StringRef, int64_t, int64_t> keys[] = {
+          {file, line, col}, {file, line, -1}, {file, -1, -1}};
+      for (const auto &key : keys) {
+        auto lookup = breakpoints.find(key);
+        if (lookup != breakpoints.end() && lookup->second->isEnabled()) {
+          match = lookup->second.get();
+          return WalkResult::interrupt();
+        }
+      }
+      return WalkResult::advance();
+    });
+    return match;
+  }
+
+  /// A map from a (filename, line, column) -> breakpoint.
+  DenseMap<std::tuple<StringRef, int64_t, int64_t>,
+           std::unique_ptr<FileLineColLocBreakpoint>>
+      breakpoints;
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_TRACING_BREAKPOINTMANAGERS_FILELINECOLLOCBREAKPOINTMANAGER_H

diff  --git a/mlir/include/mlir/Debug/Observers/ActionLogging.h b/mlir/include/mlir/Debug/Observers/ActionLogging.h
index bd1d56538906a..2e45ccf0adfbe 100644
--- a/mlir/include/mlir/Debug/Observers/ActionLogging.h
+++ b/mlir/include/mlir/Debug/Observers/ActionLogging.h
@@ -30,11 +30,21 @@ struct ActionLogger : public ExecutionContext::Observer {
                      bool willExecute) override;
   void afterExecute(const ActionActiveStack *action) override;
 
+  /// Register a breakpoint manager used as a logging filter. Once one or more
+  /// managers are registered, only the actions for which at least one of them
+  /// reports a matching breakpoint are logged. The manager is not owned and
+  /// must stay alive while this logger observes actions.
+  void addBreakpointManager(const BreakpointManager *manager) {
+    breakpointManagers.emplace_back(manager);
+  }
+
 private:
+  /// Check if we should log this action or not.
+  bool shouldLog(const ActionActiveStack *action);
+
   raw_ostream &os;
   bool printActions;
   bool printBreakpoints;
   bool printIRUnits;
+  std::vector<const BreakpointManager *> breakpointManagers;
 };
 
 } // namespace tracing

diff  --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index f54c29c8f6e31..39f7cd5e0bd80 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
 #define MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
 
+#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/StringRef.h"
 
@@ -29,6 +30,9 @@ namespace mlir {
 class DialectRegistry;
 class PassPipelineCLParser;
 class PassManager;
+namespace tracing {
+class FileLineColLocBreakpointManager;
+}
 
 /// Configuration options for the mlir-opt tool.
 /// This is intended to help building tools like mlir-opt by collecting the
@@ -82,6 +86,18 @@ class MlirOptMainConfig {
   /// Get the filename to use for logging actions.
   StringRef getLogActionsTo() const { return logActionsToFlag; }
 
+  /// Add a location breakpoint manager used to filter action logging based on
+  /// the IR locations attached to the Action context. Several managers can be
+  /// added. Ownership stays with the caller.
+  void addLogActionLocFilter(tracing::BreakpointManager *breakpointManager) {
+    logActionLocationFilter.emplace_back(breakpointManager);
+  }
+
+  /// Return the location breakpoint managers registered to filter action
+  /// logging, in the order they were added.
+  ArrayRef<tracing::BreakpointManager *> getLogActionsLocFilters() const {
+    return ArrayRef<tracing::BreakpointManager *>(logActionLocationFilter);
+  }
+
   /// Set the callback to populate the pass manager.
   MlirOptMainConfig &
   setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -160,6 +176,9 @@ class MlirOptMainConfig {
   /// Log action execution to the given file (or "-" for stdout)
   std::string logActionsToFlag;
 
+  /// Location breakpoint managers used to filter the action logging. They are
+  /// not owned by this config; see addLogActionLocFilter().
+  std::vector<tracing::BreakpointManager *> logActionLocationFilter;
+
   /// The callback to populate the pass manager.
   std::function<LogicalResult(PassManager &)> passPipelineCallback;
 

diff  --git a/mlir/lib/Debug/BreakpointManagers/FileLineColLocBreakpointManager.cpp b/mlir/lib/Debug/BreakpointManagers/FileLineColLocBreakpointManager.cpp
new file mode 100644
index 0000000000000..6fef83f1a8a58
--- /dev/null
+++ b/mlir/lib/Debug/BreakpointManagers/FileLineColLocBreakpointManager.cpp
@@ -0,0 +1,47 @@
+//===- FileLineColLocBreakpointManager.cpp - Location breakpoints ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+FailureOr<std::tuple<StringRef, int64_t, int64_t>>
+FileLineColLocBreakpoint::parseFromString(StringRef str,
+                                          function_ref<void(Twine)> diag) {
+  // Location filters are expected in one of the forms `fileName:line:col`,
+  // `fileName:line`, or `fileName`. A missing line or column is returned as -1.
+
+  // A Windows absolute path starts with a drive specifier such as `C:\` or
+  // `C:/` whose colon is part of the file name, not a separator. Skip it before
+  // looking for the first separator. This cannot change the meaning of any
+  // other input: a line number never starts with a slash.
+  size_t searchFrom = 0;
+  if (str.size() >= 3 && str[1] == ':' && (str[2] == '\\' || str[2] == '/') &&
+      ((str[0] >= 'a' && str[0] <= 'z') || (str[0] >= 'A' && str[0] <= 'Z')))
+    searchFrom = 2;
+  size_t separator = str.find(':', searchFrom);
+  StringRef file = str.substr(0, separator);
+  StringRef lineCol =
+      separator == StringRef::npos ? StringRef() : str.substr(separator + 1);
+  auto [lineStr, colStr] = lineCol.split(':');
+  if (file.empty()) {
+    if (diag)
+      diag("error: initializing FileLineColLocBreakpoint with empty file name");
+    return failure();
+  }
+
+  // Extract the line and column value
+  int64_t line = -1, col = -1;
+  if (!lineStr.empty() && lineStr.getAsInteger(0, line)) {
+    if (diag)
+      diag("error: initializing FileLineColLocBreakpoint with a non-numeric "
+           "line value: `" +
+           Twine(lineStr) + "`");
+    return failure();
+  }
+  if (!colStr.empty() && colStr.getAsInteger(0, col)) {
+    if (diag)
+      diag("error: initializing FileLineColLocBreakpoint with a non-numeric "
+           "col value: `" +
+           Twine(colStr) + "`");
+    return failure();
+  }
+  return std::tuple<StringRef, int64_t, int64_t>{file, line, col};
+}

diff  --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt
index 481db88983cc3..e4b844a9aad4d 100644
--- a/mlir/lib/Debug/CMakeLists.txt
+++ b/mlir/lib/Debug/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Observers)
 add_mlir_library(MLIRDebug
   DebugCounter.cpp
   ExecutionContext.cpp
+  BreakpointManagers/FileLineColLocBreakpointManager.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug

diff  --git a/mlir/lib/Debug/Observers/ActionLogging.cpp b/mlir/lib/Debug/Observers/ActionLogging.cpp
index 7e7c5acaaee1f..add16e84653e5 100644
--- a/mlir/lib/Debug/Observers/ActionLogging.cpp
+++ b/mlir/lib/Debug/Observers/ActionLogging.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/Debug/BreakpointManager.h"
 #include "mlir/IR/Action.h"
 #include "llvm/Support/Threading.h"
 #include "llvm/Support/raw_ostream.h"
@@ -18,8 +19,20 @@ using namespace mlir::tracing;
 // ActionLogger
 //===----------------------------------------------------------------------===//
 
+/// Decide whether `action` passes the logger's filters. With no breakpoint
+/// manager registered every action is logged; otherwise at least one manager
+/// must report a matching breakpoint.
+bool ActionLogger::shouldLog(const ActionActiveStack *action) {
+  if (breakpointManagers.empty())
+    return true;
+  for (const BreakpointManager *manager : breakpointManagers) {
+    if (manager->match(action->getAction()))
+      return true;
+  }
+  return false;
+}
+
 void ActionLogger::beforeExecute(const ActionActiveStack *action,
                                  Breakpoint *breakpoint, bool willExecute) {
+  if (!shouldLog(action))
+    return;
   SmallVector<char> name;
   llvm::get_thread_name(name);
   if (name.empty()) {
@@ -51,6 +64,8 @@ void ActionLogger::beforeExecute(const ActionActiveStack *action,
 }
 
 void ActionLogger::afterExecute(const ActionActiveStack *action) {
+  if (!shouldLog(action))
+    return;
   SmallVector<char> name;
   llvm::get_thread_name(name);
   if (name.empty()) {

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 8f608ef145aa8..28324508ee4f3 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -32,6 +32,7 @@
 #include "mlir/Tools/ParseUtilities.h"
 #include "mlir/Tools/Plugins/DialectPlugin.h"
 #include "mlir/Tools/Plugins/PassPlugin.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/InitLLVM.h"
@@ -81,6 +82,33 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
                  " '-' is passed"),
         cl::location(logActionsToFlag)};
 
+    static cl::list<std::string> logActionLocationFilter(
+        "log-mlir-actions-filter",
+        cl::desc(
+            "Comma separated list of locations to filter actions from logging"),
+        cl::CommaSeparated,
+        cl::cb<void, std::string>([&](const std::string &location) {
+          // Register the location breakpoint manager as a logging filter the
+          // first time a location is provided.
+          static bool registerOnce = [&] {
+            addLogActionLocFilter(&locBreakpointManager);
+            return true;
+          }();
+          (void)registerOnce;
+
+          // The breakpoint manager keeps StringRefs into these strings, so
+          // they must stay alive and their characters must never move. Each
+          // string gets its own heap allocation: in a std::vector<std::string>
+          // a reallocation moves short (SSO) strings and leaves the earlier
+          // StringRefs dangling.
+          static std::vector<std::unique_ptr<std::string>> locations;
+          locations.push_back(std::make_unique<std::string>(location));
+          StringRef locStr = *locations.back();
+
+          // Parse the individual location filters and set the breakpoints.
+          auto diag = [](Twine msg) { llvm::errs() << msg << "\n"; };
+          auto locBreakpoint =
+              tracing::FileLineColLocBreakpoint::parseFromString(locStr, diag);
+          if (failed(locBreakpoint)) {
+            llvm::errs() << "Invalid location filter: " << locStr << "\n";
+            exit(1);
+          }
+          auto [file, line, col] = *locBreakpoint;
+          locBreakpointManager.addBreakpoint(file, line, col);
+        }));
+
     static cl::opt<bool, /*ExternalStorage=*/true> showDialects(
         "show-dialects",
         cl::desc("Print the list of registered dialects and exit"),
@@ -130,6 +158,9 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
   /// Pointer to static dialectPlugins variable in constructor, needed by
   /// setDialectPluginsCallback(DialectRegistry&).
   cl::list<std::string> *dialectPlugins = nullptr;
+
+  /// The breakpoint manager for the log action location filter. It is
+  /// registered with addLogActionLocFilter() when the first
+  /// `--log-mlir-actions-filter` location is parsed.
+  tracing::FileLineColLocBreakpointManager locBreakpointManager;
 };
 } // namespace
 
@@ -199,6 +230,8 @@ class InstallDebugHandler {
     logActionsFile->keep();
     raw_fd_ostream &logActionsStream = logActionsFile->os();
     actionLogger = std::make_unique<tracing::ActionLogger>(logActionsStream);
+    for (const auto *locationBreakpoint : config.getLogActionsLocFilters())
+      actionLogger->addBreakpointManager(locationBreakpoint);
 
     executionContext.registerObserver(actionLogger.get());
     context.registerActionHandler(executionContext);
@@ -207,6 +240,8 @@ class InstallDebugHandler {
 private:
   std::unique_ptr<llvm::ToolOutputFile> logActionsFile;
   std::unique_ptr<tracing::ActionLogger> actionLogger;
+  /// NOTE(review): not referenced anywhere in this patch; appears to be unused
+  /// and could likely be removed — confirm against the rest of the class.
+  std::vector<std::unique_ptr<tracing::FileLineColLocBreakpoint>>
+      locationBreakpoints;
   tracing::ExecutionContext executionContext;
 };
 

diff  --git a/mlir/test/Pass/action-logging-filter.mlir b/mlir/test/Pass/action-logging-filter.mlir
new file mode 100644
index 0000000000000..e565b18d1ce13
--- /dev/null
+++ b/mlir/test/Pass/action-logging-filter.mlir
@@ -0,0 +1,60 @@
+// Run the canonicalizer on each function; use the --log-mlir-actions-filter=
+// option to select which actions are logged.
+
+func.func @a() {
+  return
+}
+
+func.func @b() {
+  return
+}
+
+func.func @c() {
+  return
+}
+
+////////////////////////////////////
+/// 1. All actions should be logged.
+
+// RUN: mlir-opt %s --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s
+// Specify the current file as filter, expect to see all actions.
+// RUN: mlir-opt %s --log-mlir-actions-filter=%s --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s
+
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `Canonicalizer` on Operation `func.func` (func.func @a() {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `Canonicalizer` on Operation `func.func` (func.func @b() {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `Canonicalizer` on Operation `func.func` (func.func @c() {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+
+////////////////////////////////////
+/// 2. No match
+
+// Specify a non-existing file as filter, expect to see no actions.
+// RUN: mlir-opt %s --log-mlir-actions-filter=foo.mlir --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-NONE --allow-empty
+// Filter on a non-matching line, expect to see no actions.
+// RUN: mlir-opt %s --log-mlir-actions-filter=%s:1 --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-NONE --allow-empty
+
+// Expect no action to be logged.
+// CHECK-NONE-NOT: Canonicalizer
+
+////////////////////////////////////
+/// 3. Matching filters
+
+// Filter on line 8 (function @b) only; @a and @c must not be logged.
+// RUN: mlir-opt %s --log-mlir-actions-filter=%s:8 --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-SECOND
+
+// CHECK-SECOND-NOT: @a
+// CHECK-SECOND: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `Canonicalizer` on Operation `func.func` (func.func @b() {...}
+// CHECK-SECOND-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-SECOND-NOT: @c
+
+// Filter the first and third functions
+// RUN: mlir-opt %s --log-mlir-actions-filter=%s:4,%s:12 --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s  --check-prefix=CHECK-FIRST-THIRD
+
+// CHECK-FIRST-THIRD-NOT: Canonicalizer
+// CHECK-FIRST-THIRD: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `Canonicalizer` on Operation `func.func` (func.func @a() {...}
+// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action`  running `Canonicalizer` on Operation `func.func` (func.func @c() {...}
+// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-FIRST-THIRD-NOT: Canonicalizer

diff  --git a/mlir/unittests/Debug/CMakeLists.txt b/mlir/unittests/Debug/CMakeLists.txt
index 5ea18d2751de0..59728bc819d9a 100644
--- a/mlir/unittests/Debug/CMakeLists.txt
+++ b/mlir/unittests/Debug/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRDebugTests
   DebugCounterTest.cpp
   ExecutionContextTest.cpp
+  FileLineColLocBreakpointManagerTest.cpp
 )
 
 target_link_libraries(MLIRDebugTests

diff  --git a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
new file mode 100644
index 0000000000000..e9cac3949ddbd
--- /dev/null
+++ b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
@@ -0,0 +1,232 @@
+//===- FileLineColLocBreakpointManagerTest.cpp - --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+/// Create a detached operation named `operationName` at `loc` with
+/// `numRegions` empty regions; every other component passed to
+/// Operation::create is empty (std::nullopt). As a side effect, unregistered
+/// dialects are allowed on `context` so that arbitrary names can be used.
+static Operation *createOp(MLIRContext *context, Location loc,
+                           StringRef operationName,
+                           unsigned int numRegions = 0) {
+  context->allowUnregisteredDialects();
+  OperationName name(operationName, context);
+  return Operation::create(loc, name, std::nullopt, std::nullopt,
+                           std::nullopt, std::nullopt, numRegions);
+}
+
+namespace {
+/// Minimal action used by these tests to carry IR units through an
+/// ExecutionContext.
+struct FileLineColLocTestingAction
+    : public ActionImpl<FileLineColLocTestingAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FileLineColLocTestingAction)
+  FileLineColLocTestingAction(ArrayRef<IRUnit> irUnits)
+      : ActionImpl<FileLineColLocTestingAction>(irUnits) {}
+  static constexpr StringLiteral tag = "file-line-col-loc-testing-action";
+};
+
+TEST(FileLineColLocBreakpointManager, OperationMatch) {
+  // This test will process a sequence of operations and check various
+  // situations with a breakpoint hitting or not based on the location attached
+  // to the operation. When a breakpoint hits, the action is skipped and the
+  // counter is not incremented.
+  ExecutionContext executionCtx(
+      [](const ActionActiveStack *) { return ExecutionContext::Skip; });
+  int counter = 0;
+  auto counterInc = [&]() { counter++; };
+
+  // Setup
+
+  MLIRContext context;
+  // Miscellaneous information to define operations
+  std::vector<StringRef> fileNames = {
+      StringRef("foo.bar"), StringRef("baz.qux"), StringRef("quux.corge")};
+  std::vector<std::pair<unsigned, unsigned>> lineColLoc = {{42, 7}, {24, 3}};
+  Location callee = UnknownLoc::get(&context),
+           caller = UnknownLoc::get(&context), loc = UnknownLoc::get(&context);
+
+  // Set of operations over where we are going to be testing the functionality.
+  // Only the operations at index 1 and 5 carry a FileLineColLoc.
+  std::vector<Operation *> operations = {
+      createOp(&context, CallSiteLoc::get(callee, caller),
+               "callSiteLocOperation"),
+      createOp(&context,
+               FileLineColLoc::get(&context, fileNames[0], lineColLoc[0].first,
+                                   lineColLoc[0].second),
+               "fileLineColLocOperation"),
+      createOp(&context, FusedLoc::get(&context, {}, Attribute()),
+               "fusedLocOperation"),
+      createOp(&context, NameLoc::get(StringAttr::get(&context, fileNames[2])),
+               "nameLocOperation"),
+      createOp(&context, OpaqueLoc::get<void *>(nullptr, loc),
+               "opaqueLocOperation"),
+      createOp(&context,
+               FileLineColLoc::get(&context, fileNames[1], lineColLoc[1].first,
+                                   lineColLoc[1].second),
+               "anotherFileLineColLocOperation"),
+      createOp(&context, UnknownLoc::get(&context), "unknownLocOperation"),
+  };
+
+  FileLineColLocBreakpointManager breakpointManager;
+  executionCtx.addBreakpointManager(&breakpointManager);
+
+  // Test
+
+  // Basic case is that no breakpoint is set and the counter is incremented for
+  // every op.
+  auto checkNoMatch = [&]() {
+    counter = 0;
+    for (auto enumeratedOp : llvm::enumerate(operations)) {
+      executionCtx(counterInc,
+                   FileLineColLocTestingAction({enumeratedOp.value()}));
+      EXPECT_EQ(counter, static_cast<int>(enumeratedOp.index() + 1));
+    }
+  };
+  checkNoMatch();
+
+  // Set a breakpoint matching only the second operation in the list (index 1).
+  auto *breakpoint = breakpointManager.addBreakpoint(
+      fileNames[0], lineColLoc[0].first, lineColLoc[0].second);
+  auto checkMatchIdxs = [&](DenseSet<int> idxs) {
+    counter = 0;
+    int reference = 0;
+    for (int i = 0; i < (int)operations.size(); ++i) {
+      executionCtx(counterInc, FileLineColLocTestingAction({operations[i]}));
+      if (!idxs.contains(i))
+        reference++;
+      EXPECT_EQ(counter, reference);
+    }
+  };
+  checkMatchIdxs({1});
+
+  // Check that disabling the breakpoint brings us back to the original
+  // behavior.
+  breakpoint->disable();
+  checkNoMatch();
+
+  // Adding a breakpoint that won't match any location shouldn't affect the
+  // behavior.
+  breakpointManager.addBreakpoint(StringRef("random.file"), 3, 14);
+  checkNoMatch();
+
+  // Set a breakpoint matching only the sixth operation in the list (index 5).
+  auto *exactBreakpoint = breakpointManager.addBreakpoint(
+      fileNames[1], lineColLoc[1].first, lineColLoc[1].second);
+  checkMatchIdxs({5});
+
+  // Re-enable the breakpoint matching only the second operation in the list.
+  // We now expect matching of the operations at index 1 and 5.
+  breakpoint->enable();
+  checkMatchIdxs({1, 5});
+
+  // A column of -1 matches any column on the line: replace the exact
+  // breakpoint on the operation at index 5 with a line-wide one.
+  exactBreakpoint->disable();
+  checkMatchIdxs({1});
+  auto *lineBreakpoint =
+      breakpointManager.addBreakpoint(fileNames[1], lineColLoc[1].first);
+  checkMatchIdxs({1, 5});
+
+  // A line and a column of -1 match anywhere in the file.
+  lineBreakpoint->disable();
+  breakpoint->disable();
+  checkNoMatch();
+  breakpointManager.addBreakpoint(fileNames[0], -1, -1);
+  checkMatchIdxs({1});
+
+  for (auto *op : operations) {
+    op->destroy();
+  }
+}
+
+TEST(FileLineColLocBreakpointManager, BlockMatch) {
+  // Run an action on a block holding two operations, each with its own
+  // FileLineColLoc, and check whether a breakpoint hits. When a breakpoint
+  // hits, the action is skipped and the counter is not incremented.
+  ExecutionContext executionCtx(
+      [](const ActionActiveStack *) { return ExecutionContext::Skip; });
+  int counter = 0;
+  auto counterInc = [&]() { counter++; };
+
+  // Setup
+
+  MLIRContext context;
+  std::vector<StringRef> fileNames = {StringRef("grault.garply"),
+                                      StringRef("waldo.fred")};
+  std::vector<std::pair<unsigned, unsigned>> lineColLoc = {{42, 7}, {24, 3}};
+  auto locAt = [&](size_t idx) {
+    return FileLineColLoc::get(&context, fileNames[idx], lineColLoc[idx].first,
+                               lineColLoc[idx].second);
+  };
+  Block block;
+  block.push_back(createOp(&context, locAt(0), "firstOperation"));
+  block.push_back(createOp(&context, locAt(1), "secondOperation"));
+
+  FileLineColLocBreakpointManager breakpointManager;
+  executionCtx.addBreakpointManager(&breakpointManager);
+  auto runOnBlock = [&]() {
+    executionCtx(counterInc, FileLineColLocTestingAction({&block}));
+  };
+
+  // Test
+
+  runOnBlock();
+  EXPECT_EQ(counter, 1);
+
+  // A breakpoint on the location of either operation makes the whole block
+  // match; disabling it lets the action run again.
+  for (size_t idx = 0; idx < fileNames.size(); ++idx) {
+    auto *breakpoint = breakpointManager.addBreakpoint(
+        fileNames[idx], lineColLoc[idx].first, lineColLoc[idx].second);
+    counter = 0;
+    runOnBlock();
+    EXPECT_EQ(counter, 0);
+    breakpoint->disable();
+    runOnBlock();
+    EXPECT_EQ(counter, 1);
+  }
+}
+
+TEST(FileLineColLocBreakpointManager, RegionMatch) {
+  // Run an action on a region and check whether a breakpoint on the location of
+  // its parent operation hits. When a breakpoint hits, the action is skipped
+  // and the counter is not incremented.
+  ExecutionContext executionCtx(
+      [](const ActionActiveStack *) { return ExecutionContext::Skip; });
+  int counter = 0;
+  auto counterInc = [&]() { counter++; };
+
+  // Setup
+
+  MLIRContext context;
+  StringRef fileName("plugh.xyzzy");
+  unsigned line = 42, col = 7;
+  // The breakpoint below is set on this container operation's location, which
+  // the region is expected to report.
+  Operation *containerOp =
+      createOp(&context, FileLineColLoc::get(&context, fileName, line, col),
+               "containerOperation", /*numRegions=*/1);
+  Region &region = containerOp->getRegion(0);
+
+  FileLineColLocBreakpointManager breakpointManager;
+  executionCtx.addBreakpointManager(&breakpointManager);
+  auto runOnRegion = [&]() {
+    executionCtx(counterInc, FileLineColLocTestingAction({&region}));
+  };
+
+  // Test
+
+  // Without any breakpoint, the action runs.
+  runOnRegion();
+  EXPECT_EQ(counter, 1);
+  // With a matching breakpoint, the action is skipped.
+  auto *breakpoint = breakpointManager.addBreakpoint(fileName, line, col);
+  runOnRegion();
+  EXPECT_EQ(counter, 1);
+  // Disabling the breakpoint lets the action run again.
+  breakpoint->disable();
+  runOnRegion();
+  EXPECT_EQ(counter, 2);
+
+  containerOp->destroy();
+}
+} // namespace


        


More information about the Mlir-commits mailing list