[Mlir-commits] [mlir] fa51c17 - Introduce mlir::tracing::ExecutionContext

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 12 14:21:07 PDT 2023


Author: Mehdi Amini
Date: 2023-03-12T22:20:50+01:00
New Revision: fa51c1753a274fbb7a71d8fe91fd4e5caf2fa4d3

URL: https://github.com/llvm/llvm-project/commit/fa51c1753a274fbb7a71d8fe91fd4e5caf2fa4d3
DIFF: https://github.com/llvm/llvm-project/commit/fa51c1753a274fbb7a71d8fe91fd4e5caf2fa4d3.diff

LOG: Introduce mlir::tracing::ExecutionContext

This component acts as an action handler that can be registered in the
MLIRContext. It is the main orchestrator of the infrastructure and lets
clients hook into it to observe or control the execution.
This is the basis to build tracing as well as a "gdb-like" control of the
compilation flow.

The ExecutionContext acts as a handler in the MLIRContext for executing an
Action. When an action is dispatched, it queries its set of breakpoint
managers for a breakpoint matching this action. If a breakpoint is hit, it
passes the action and the breakpoint information to a callback. The callback
is responsible for controlling the execution of the action through an enum
value it returns. Optionally, observers can be registered to be notified
before and after the callback is executed.

Differential Revision: https://reviews.llvm.org/D144812

Added: 
    mlir/include/mlir/Debug/BreakpointManager.h
    mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h
    mlir/include/mlir/Debug/ExecutionContext.h
    mlir/lib/Debug/ExecutionContext.cpp
    mlir/unittests/Debug/ExecutionContextTest.cpp

Modified: 
    mlir/lib/Debug/CMakeLists.txt
    mlir/unittests/Debug/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Debug/BreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManager.h
new file mode 100644
index 0000000000000..cc4be60cb69a3
--- /dev/null
+++ b/mlir/include/mlir/Debug/BreakpointManager.h
@@ -0,0 +1,95 @@
+//===- BreakpointManager.h - Breakpoint Manager Support ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRACING_BREAKPOINTMANAGER_H
+#define MLIR_TRACING_BREAKPOINTMANAGER_H
+
+#include "mlir/IR/Action.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+namespace tracing {
+
+/// Abstract base for every kind of breakpoint. Each instance remembers the
+/// TypeID of its concrete subclass (for LLVM-style casting) and whether it is
+/// currently enabled; managers are expected to ignore disabled breakpoints
+/// when matching (see BreakpointManager::match).
+class Breakpoint {
+public:
+  virtual ~Breakpoint() = default;
+
+  /// Returns the TypeID of the concrete subclass; `classof` implementations
+  /// compare against it to support isa/dyn_cast.
+  TypeID getTypeID() const { return typeID; }
+
+  /// Returns true while the breakpoint is active.
+  bool isEnabled() const { return enableStatus; }
+  /// Turns the breakpoint on.
+  void enable() { enableStatus = true; }
+  /// Turns the breakpoint off without removing it.
+  void disable() { enableStatus = false; }
+
+  /// Writes a human-readable description of the breakpoint to `os`.
+  virtual void print(raw_ostream &os) const = 0;
+
+protected:
+  /// Breakpoints are created in the enabled state.
+  Breakpoint(TypeID typeID) : enableStatus{true}, typeID{typeID} {}
+
+private:
+  /// Whether the breakpoint is enabled (true) or disabled (false).
+  bool enableStatus;
+  /// TypeID of the concrete subclass.
+  TypeID typeID;
+};
+
+/// Streams a description of `breakpoint` by delegating to its virtual
+/// `print` method, and returns the stream to allow chaining.
+inline raw_ostream &operator<<(raw_ostream &os, const Breakpoint &breakpoint) {
+  raw_ostream &stream = os;
+  breakpoint.print(stream);
+  return stream;
+}
+
+/// CRTP helper for concrete breakpoint classes: it passes the derived class's
+/// TypeID to the `Breakpoint` base and provides `classof`, so that
+/// `isa`/`dyn_cast` work on `Breakpoint` pointers.
+template <typename Derived>
+class BreakpointBase : public Breakpoint {
+public:
+  /// Support isa/dyn_cast functionality for the derived breakpoint class.
+  static bool classof(const Breakpoint *breakpoint) {
+    return breakpoint->getTypeID() == TypeID::get<Derived>();
+  }
+
+protected:
+  /// Only derived breakpoint classes construct this base.
+  BreakpointBase() : Breakpoint(TypeID::get<Derived>()) {}
+};
+
+/// Abstract interface for an owner of a set of breakpoints. Given an action,
+/// a manager reports which of its breakpoints (if any) the action hits.
+class BreakpointManager {
+public:
+  virtual ~BreakpointManager() = default;
+
+  /// Returns the TypeID of the concrete manager; `classof` implementations
+  /// compare against it to support isa/dyn_cast.
+  TypeID getTypeID() const { return typeID; }
+
+  /// Looks for a breakpoint matching `action`. Implementations return the
+  /// breakpoint only if it matches and is enabled, and nullptr otherwise.
+  virtual Breakpoint *match(const Action &action) const = 0;
+
+protected:
+  BreakpointManager(TypeID typeID) : typeID{typeID} {}
+
+  /// TypeID of the concrete subclass.
+  TypeID typeID;
+};
+
+/// CRTP base class for BreakpointManager implementations: it records the
+/// derived class's TypeID and provides `classof` for isa/dyn_cast.
+template <typename Derived>
+class BreakpointManagerBase : public BreakpointManager {
+public:
+  /// Provide classof to allow casting between breakpoint manager types.
+  static bool classof(const BreakpointManager *breakpointManager) {
+    return breakpointManager->getTypeID() == TypeID::get<Derived>();
+  }
+
+protected:
+  /// Only derived managers construct this base, mirroring BreakpointBase.
+  /// (The class is abstract, so it could never be constructed directly.)
+  BreakpointManagerBase() : BreakpointManager(TypeID::get<Derived>()) {}
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_TRACING_BREAKPOINTMANAGER_H

diff  --git a/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h
new file mode 100644
index 0000000000000..85fdb9a63286a
--- /dev/null
+++ b/mlir/include/mlir/Debug/BreakpointManagers/TagBreakpointManager.h
@@ -0,0 +1,65 @@
+//===- TagBreakpointManager.h - Simple breakpoint Support -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
+#define MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
+
+#include "mlir/Debug/BreakpointManager.h"
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/IR/Action.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+namespace tracing {
+
+/// A breakpoint identified by an action tag string.
+class TagBreakpoint : public BreakpointBase<TagBreakpoint> {
+public:
+  TagBreakpoint(StringRef tag) : tag(tag) {}
+
+  /// Prints the breakpoint as: Tag: `<tag>`
+  void print(raw_ostream &os) const override {
+    os << "Tag: `";
+    os << tag;
+    os << '`';
+  }
+
+private:
+  /// Owned copy of the tag this breakpoint is associated with.
+  std::string tag;
+
+  /// Grants TagBreakpointManager access to `tag`.
+  friend class TagBreakpointManager;
+};
+
+/// Stores a collection of breakpoints keyed by action tag. An action hits a
+/// breakpoint when its tag equals the breakpoint's tag and that breakpoint is
+/// enabled.
+class TagBreakpointManager
+    : public BreakpointManagerBase<TagBreakpointManager> {
+public:
+  /// Return the breakpoint registered for `action`'s tag if it exists and is
+  /// enabled, nullptr otherwise.
+  Breakpoint *match(const Action &action) const override {
+    auto it = breakpoints.find(action.getTag());
+    if (it == breakpoints.end() || !it->second->isEnabled())
+      return nullptr;
+    return it->second.get();
+  }
+
+  /// Add a breakpoint to the manager for the given tag and return it.
+  /// If a breakpoint already exists for the given tag, return the existing
+  /// instance unchanged (including its enabled/disabled state).
+  TagBreakpoint *addBreakpoint(StringRef tag) {
+    auto [it, inserted] = breakpoints.try_emplace(tag);
+    // Construct from the StringRef directly; going through `tag.str()` would
+    // materialize an extra temporary std::string.
+    if (inserted)
+      it->second = std::make_unique<TagBreakpoint>(tag);
+    return it->second.get();
+  }
+
+private:
+  /// Breakpoints indexed by tag. Each one is owned through a unique_ptr, so
+  /// the TagBreakpoint pointers handed out by addBreakpoint stay stable.
+  llvm::StringMap<std::unique_ptr<TagBreakpoint>> breakpoints;
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H

diff  --git a/mlir/include/mlir/Debug/ExecutionContext.h b/mlir/include/mlir/Debug/ExecutionContext.h
new file mode 100644
index 0000000000000..43a838f1f65a7
--- /dev/null
+++ b/mlir/include/mlir/Debug/ExecutionContext.h
@@ -0,0 +1,132 @@
+//===- ExecutionContext.h -  Execution Context Support *- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRACING_EXECUTIONCONTEXT_H
+#define MLIR_TRACING_EXECUTIONCONTEXT_H
+
+#include "mlir/Debug/BreakpointManager.h"
+#include "mlir/IR/Action.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <functional>
+#include <optional>
+
+namespace mlir {
+namespace tracing {
+
+/// One frame of the chain of actions currently being executed. Each frame
+/// records its action, its nesting depth (0 for an outermost action) and a
+/// pointer to the enclosing frame, so clients can walk from the innermost
+/// action back out to the outermost one.
+struct ActionActiveStack {
+  ActionActiveStack(const ActionActiveStack *parent, const Action &action,
+                    int depth)
+      : parent(parent), action(action), depth(depth) {}
+
+  /// Enclosing frame, or nullptr for an outermost action.
+  const ActionActiveStack *getParent() const { return parent; }
+  /// The action this frame was created for.
+  const Action &getAction() const { return action; }
+  /// Nesting level of this frame.
+  int getDepth() const { return depth; }
+
+private:
+  const ActionActiveStack *parent;
+  const Action &action;
+  int depth;
+};
+
+/// The ExecutionContext is the main orchestrator of the infrastructure: it
+/// acts as a handler in the MLIRContext for executing an Action. When an action
+/// is dispatched, it queries its set of breakpoint managers for a breakpoint
+/// matching this action. If a breakpoint is hit, it passes the action and the
+/// breakpoint information to a callback. The callback is responsible for
+/// controlling the execution of the action through an enum value it returns.
+/// Optionally, observers can be registered to be notified before and after the
+/// callback is executed.
+class ExecutionContext {
+public:
+  /// Enum that allows the client of the context to control the execution of the
+  /// action.
+  /// - Apply: The action is executed.
+  /// - Skip: The action is skipped.
+  /// - Step: The action is executed and the execution is paused before the next
+  ///         action, including for nested actions encountered before the
+  ///         current action finishes.
+  /// - Next: The action is executed and the execution is paused after the
+  ///         current action finishes before the next action.
+  /// - Finish: The action is executed and the execution is paused only when we
+  ///           reach the parent/enclosing action. If there is no enclosing
+  ///           action, the execution continues without stopping.
+  enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 };
+
+  /// The type of the callback that is used to control the execution.
+  /// The callback is passed the stack of active actions, whose top entry is the
+  /// current action, and returns how execution should proceed.
+  /// The context keeps its own copy of the callable (`std::function`): a
+  /// non-owning `function_ref` member would dangle whenever the context
+  /// outlives the callable it was given, e.g. a temporary lambda.
+  using CallbackTy = std::function<Control(const ActionActiveStack *)>;
+
+  /// Create an ExecutionContext with a callback that is used to control the
+  /// execution.
+  ExecutionContext(CallbackTy callback) { setCallback(std::move(callback)); }
+  ExecutionContext() = default;
+
+  /// Set the callback that is used to control the execution.
+  void setCallback(CallbackTy callback);
+
+  /// This abstract class defines the interface used to observe an Action
+  /// execution. It allows to be notified before and after the callback is
+  /// processed, but can't affect the execution.
+  struct Observer {
+    virtual ~Observer() = default;
+    /// This method is called before every Action is executed (or skipped).
+    /// If a breakpoint was hit, it is passed as the `breakpoint` argument;
+    /// otherwise `breakpoint` is nullptr.
+    /// The `willExecute` argument indicates whether the action will be executed
+    /// or not.
+    /// Note that this method will be called from multiple threads concurrently
+    /// when MLIR multi-threading is enabled.
+    virtual void beforeExecute(const ActionActiveStack *action,
+                               Breakpoint *breakpoint, bool willExecute) {}
+
+    /// This method is called after the Action is executed, if it was executed.
+    /// It is not called if the action is skipped.
+    /// Note that this method will be called from multiple threads concurrently
+    /// when MLIR multi-threading is enabled.
+    virtual void afterExecute(const ActionActiveStack *action) {}
+  };
+
+  /// Register a new `Observer` on this context. It'll be notified before and
+  /// after executing an action. Note that this method is not thread-safe: it
+  /// isn't supported to add a new observer while actions may be executed.
+  void registerObserver(Observer *observer);
+
+  /// Register a new `BreakpointManager` on this context. It'll have a chance to
+  /// match an action before it gets executed. Note that this method is not
+  /// thread-safe: it isn't supported to add a new manager while actions may be
+  /// executed.
+  void addBreakpointManager(BreakpointManager *manager) {
+    breakpoints.push_back(manager);
+  }
+
+  /// Process the given action. This is the operator called by MLIRContext on
+  /// `executeAction()`.
+  void operator()(function_ref<void()> transform, const Action &action);
+
+private:
+  /// Callback that is executed when a breakpoint is hit and allows the client
+  /// to control the execution.
+  CallbackTy onBreakpointControlExecutionCallback;
+
+  /// Next point to stop execution as described by the `Control` enum. This is
+  /// handled by recording the maximum nesting depth at which the next break
+  /// may happen; std::nullopt means no pending stop.
+  /// NOTE(review): this member is updated by every call to operator() without
+  /// synchronization, while observers are documented to run on multiple
+  /// threads concurrently — confirm this is acceptable.
+  std::optional<int> depthToBreak;
+
+  /// Observers that are notified before and after the callback is executed.
+  SmallVector<Observer *> observers;
+
+  /// The list of managers that are queried for breakpoints.
+  SmallVector<BreakpointManager *> breakpoints;
+};
+
+} // namespace tracing
+} // namespace mlir
+
+#endif // MLIR_TRACING_EXECUTIONCONTEXT_H

diff  --git a/mlir/lib/Debug/CMakeLists.txt b/mlir/lib/Debug/CMakeLists.txt
index c3a8d37f33b87..336749078dae4 100644
--- a/mlir/lib/Debug/CMakeLists.txt
+++ b/mlir/lib/Debug/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRDebug
   DebugCounter.cpp
+  ExecutionContext.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug

diff  --git a/mlir/lib/Debug/ExecutionContext.cpp b/mlir/lib/Debug/ExecutionContext.cpp
new file mode 100644
index 0000000000000..ee7c33a6b3f14
--- /dev/null
+++ b/mlir/lib/Debug/ExecutionContext.cpp
@@ -0,0 +1,97 @@
+//===- ExecutionContext.cpp - Debug Execution Context Support -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/ExecutionContext.h"
+
+#include "llvm/ADT/ScopeExit.h"
+
+#include <cstddef>
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+//===----------------------------------------------------------------------===//
+// ExecutionContext
+//===----------------------------------------------------------------------===//
+
+/// Innermost action currently executing on this thread, or nullptr when none.
+/// It is shared by every ExecutionContext on the thread: operator() installs
+/// a new frame on entry and restores the previous one on exit.
+static thread_local const ActionActiveStack *actionStack = nullptr;
+
+/// Installs (or replaces) the callback consulted whenever execution pauses.
+void ExecutionContext::setCallback(CallbackTy callback) {
+  this->onBreakpointControlExecutionCallback = callback;
+}
+
+/// Appends `observer` to the list notified around each action.
+void ExecutionContext::registerObserver(Observer *observer) {
+  observers.emplace_back(observer);
+}
+
+/// Runs `transform` for `action`. The sequence is:
+/// 1. Push a frame on the thread-local action stack.
+/// 2. Consult the breakpoint managers, and any pending Step/Next/Finish
+///    request, to decide whether to ask the client callback how to proceed.
+/// 3. Notify observers, and run the transform unless it was skipped.
+/// 4. Give the callback another chance to pause once the action is done.
+void ExecutionContext::operator()(llvm::function_ref<void()> transform,
+                                  const Action &action) {
+  // Push a frame for this action; it is popped when this function returns.
+  const int depth = actionStack ? actionStack->getDepth() + 1 : 0;
+  ActionActiveStack info{actionStack, action, depth};
+  actionStack = &info;
+  auto popFrame =
+      llvm::make_scope_exit([&]() { actionStack = info.getParent(); });
+
+  // Ask the client how to proceed and remember where the next pause should
+  // happen. Returns whether the current action should be executed.
+  auto askClient = [&]() -> bool {
+    if (!onBreakpointControlExecutionCallback)
+      return true;
+    switch (onBreakpointControlExecutionCallback(actionStack)) {
+    case ExecutionContext::Apply:
+      depthToBreak = std::nullopt;
+      return true;
+    case ExecutionContext::Skip:
+      depthToBreak = std::nullopt;
+      return false;
+    case ExecutionContext::Step:
+      // Pause on the next action, even one nested inside this one.
+      depthToBreak = depth + 1;
+      return true;
+    case ExecutionContext::Next:
+      // Pause on the next action at this depth or shallower.
+      depthToBreak = depth;
+      return true;
+    case ExecutionContext::Finish:
+      // Pause once back in the enclosing action; at depth 0 this yields -1,
+      // which no action depth satisfies.
+      depthToBreak = depth - 1;
+      return true;
+    }
+    llvm::report_fatal_error("Unknown control request");
+  };
+
+  // True when a pending Step/Next/Finish request covers the current depth.
+  auto shouldPause = [&]() { return depthToBreak && depth <= *depthToBreak; };
+
+  // Report only the first matching breakpoint, in manager registration order.
+  Breakpoint *breakpoint = nullptr;
+  for (BreakpointManager *manager : breakpoints) {
+    if ((breakpoint = manager->match(action)) != nullptr)
+      break;
+  }
+
+  const bool shouldExecuteAction =
+      (breakpoint || shouldPause()) ? askClient() : true;
+
+  for (Observer *observer : observers)
+    observer->beforeExecute(actionStack, breakpoint, shouldExecuteAction);
+
+  if (shouldExecuteAction) {
+    transform();
+    for (Observer *observer : observers)
+      observer->afterExecute(actionStack);
+  }
+
+  // Pause again after the action if a pending request asks for it. The
+  // returned decision is ignored; the call only updates `depthToBreak` and
+  // lets the client inspect the state.
+  if (shouldPause())
+    (void)askClient();
+}

diff  --git a/mlir/unittests/Debug/CMakeLists.txt b/mlir/unittests/Debug/CMakeLists.txt
index 1d6644083049a..5ea18d2751de0 100644
--- a/mlir/unittests/Debug/CMakeLists.txt
+++ b/mlir/unittests/Debug/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRDebugTests
   DebugCounterTest.cpp
+  ExecutionContextTest.cpp
 )
 
 target_link_libraries(MLIRDebugTests

diff  --git a/mlir/unittests/Debug/ExecutionContextTest.cpp b/mlir/unittests/Debug/ExecutionContextTest.cpp
new file mode 100644
index 0000000000000..d757d5451afec
--- /dev/null
+++ b/mlir/unittests/Debug/ExecutionContextTest.cpp
@@ -0,0 +1,352 @@
+//===- ExecutionContextTest.cpp - Debug Execution Context first impl ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Debug/ExecutionContext.h"
+#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
+#include "llvm/ADT/MapVector.h"
+#include "gmock/gmock.h"
+
+using namespace mlir;
+using namespace mlir::tracing;
+
+namespace {
+/// Minimal actions used by the tests; each one is distinguished by its tag.
+struct DebuggerAction : ActionImpl<DebuggerAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebuggerAction)
+  static constexpr StringLiteral tag = "debugger-action";
+};
+struct OtherAction : ActionImpl<OtherAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherAction)
+  static constexpr StringLiteral tag = "other-action";
+};
+struct ThirdAction : ActionImpl<ThirdAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ThirdAction)
+  static constexpr StringLiteral tag = "third-action";
+};
+
+/// Action body that intentionally does nothing.
+void noOp() {}
+
+/// This test executes a stack of nested actions and checks that the backtrace
+/// passed to the callback is as expected at every level.
+TEST(ExecutionContext, ActionActiveStackTest) {
+  // We'll break three times, once on each action; the backtraces should match
+  // each of the entries here (innermost action first).
+  std::vector<std::vector<StringRef>> expectedStacks = {
+      {DebuggerAction::tag},
+      {OtherAction::tag, DebuggerAction::tag},
+      {ThirdAction::tag, OtherAction::tag, DebuggerAction::tag}};
+
+  auto checkStacks = [&](const ActionActiveStack *backtrace,
+                         const std::vector<StringRef> &currentStack) {
+    // Check for null before dereferencing to read the depth.
+    ASSERT_NE(backtrace, nullptr);
+    ASSERT_EQ((int)currentStack.size(), backtrace->getDepth() + 1);
+    for (StringRef stackEntry : currentStack) {
+      ASSERT_NE(backtrace, nullptr);
+      ASSERT_EQ(stackEntry, backtrace->getAction().getTag());
+      backtrace = backtrace->getParent();
+    }
+    // The outermost frame must not have a parent.
+    EXPECT_EQ(backtrace, nullptr);
+  };
+
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Step, ExecutionContext::Step, ExecutionContext::Apply};
+  int idx = 0;
+  StringRef current;
+  int currentDepth = -1;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    current = backtrace->getAction().getTag();
+    currentDepth = backtrace->getDepth();
+    // An unexpected extra callback would otherwise index past the end of the
+    // expectation vectors.
+    EXPECT_LT(idx, (int)controlSequence.size());
+    if (idx >= (int)controlSequence.size())
+      return ExecutionContext::Apply;
+    checkStacks(backtrace, expectedStacks[idx]);
+    return controlSequence[idx++];
+  };
+
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  std::vector<TagBreakpoint *> breakpoints;
+  breakpoints.push_back(simpleManager.addBreakpoint(DebuggerAction::tag));
+  breakpoints.push_back(simpleManager.addBreakpoint(OtherAction::tag));
+  breakpoints.push_back(simpleManager.addBreakpoint(ThirdAction::tag));
+
+  auto third = [&]() {
+    EXPECT_EQ(current, ThirdAction::tag);
+    EXPECT_EQ(currentDepth, 2);
+  };
+  auto nested = [&]() {
+    EXPECT_EQ(current, OtherAction::tag);
+    EXPECT_EQ(currentDepth, 1);
+    executionCtx(third, ThirdAction{});
+  };
+  auto original = [&]() {
+    EXPECT_EQ(current, DebuggerAction::tag);
+    EXPECT_EQ(currentDepth, 0);
+    executionCtx(nested, OtherAction{});
+  };
+
+  executionCtx(original, DebuggerAction{});
+  // Every expected breakpoint must have been hit.
+  EXPECT_EQ(idx, (int)controlSequence.size());
+}
+
+TEST(ExecutionContext, DebuggerTest) {
+  // Exercise matching and non-matching breakpoints under various
+  // enable/disable states. The callback counts its invocations and always
+  // skips the action.
+  int hits = 0;
+  auto onBreakpoint = [&hits](const ActionActiveStack *) {
+    ++hits;
+    return ExecutionContext::Skip;
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+
+  // No breakpoint registered yet: nothing is hit.
+  executionCtx(noOp, DebuggerAction{});
+  EXPECT_EQ(hits, 0);
+
+  // A breakpoint on the dispatched tag is hit.
+  Breakpoint *dbgBreakpoint = simpleManager.addBreakpoint(DebuggerAction::tag);
+  executionCtx(noOp, DebuggerAction{});
+  EXPECT_EQ(hits, 1);
+
+  // A disabled breakpoint is ignored...
+  dbgBreakpoint->disable();
+  executionCtx(noOp, DebuggerAction{});
+  EXPECT_EQ(hits, 1);
+
+  // ...and is hit again once re-enabled.
+  dbgBreakpoint->enable();
+  executionCtx(noOp, DebuggerAction{});
+  EXPECT_EQ(hits, 2);
+
+  // An action with a different tag does not match.
+  executionCtx(noOp, OtherAction{});
+  EXPECT_EQ(hits, 2);
+}
+
+TEST(ExecutionContext, ApplyTest) {
+  // Test the "apply" control: the callback runs once, then the action runs and
+  // no further pause happens.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Apply};
+  int idx = 0, counter = 0, executionCounter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  // Count executions so the test fails if the action never runs; otherwise the
+  // EXPECT inside the body would be silently skipped.
+  auto callback = [&]() {
+    ++executionCounter;
+    EXPECT_EQ(counter, 1);
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+
+  executionCtx(callback, DebuggerAction{});
+  EXPECT_EQ(counter, 1);
+  EXPECT_EQ(executionCounter, 1);
+}
+
+TEST(ExecutionContext, SkipTest) {
+  // Test the "skip" control: the first dispatch is applied and the second is
+  // skipped, so the callback runs twice but the action body runs only once.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag,
+                                        DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Apply, ExecutionContext::Skip};
+  int idx = 0;
+  int counter = 0;
+  int executionCounter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  auto callback = [&]() { ++executionCounter; };
+
+  TagBreakpointManager simpleManager;
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+
+  for (int dispatch = 0; dispatch < 2; ++dispatch)
+    executionCtx(callback, DebuggerAction{});
+  EXPECT_EQ(counter, 2);
+  EXPECT_EQ(executionCounter, 1);
+}
+
+TEST(ExecutionContext, StepApplyTest) {
+  // Test the "step" control with a nested action: stepping from the outer
+  // action pauses on the nested OtherAction even though only DebuggerAction
+  // has a breakpoint.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag, OtherAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Step, ExecutionContext::Apply};
+  int idx = 0;
+  int counter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  TagBreakpointManager simpleManager;
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+
+  auto nested = [&]() { EXPECT_EQ(counter, 2); };
+  auto original = [&]() {
+    EXPECT_EQ(counter, 1);
+    executionCtx(nested, OtherAction{});
+  };
+  executionCtx(original, DebuggerAction{});
+  EXPECT_EQ(counter, 2);
+}
+
+TEST(ExecutionContext, StepNothingInsideTest) {
+  // Test the "step" control without a nested action: with nothing to step
+  // into, the next pause happens after the action completes.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag,
+                                        DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Step, ExecutionContext::Step};
+  int idx = 0, counter = 0, executionCounter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  // Count executions so the test fails if the action never runs.
+  auto callback = [&]() {
+    ++executionCounter;
+    EXPECT_EQ(counter, 1);
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+
+  executionCtx(callback, DebuggerAction{});
+  EXPECT_EQ(counter, 2);
+  EXPECT_EQ(executionCounter, 1);
+}
+
+TEST(ExecutionContext, NextTest) {
+  // Test the "next" control: the action runs, then execution pauses again
+  // once it completes.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag,
+                                        DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Next, ExecutionContext::Next};
+  int idx = 0, counter = 0, executionCounter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  // Count executions so the test fails if the action never runs.
+  auto callback = [&]() {
+    ++executionCounter;
+    EXPECT_EQ(counter, 1);
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+
+  executionCtx(callback, DebuggerAction{});
+  EXPECT_EQ(counter, 2);
+  EXPECT_EQ(executionCounter, 1);
+}
+
+TEST(ExecutionContext, FinishTest) {
+  // Test the "finish" control.
+  // Expected pauses: (1) breakpoint on the outer DebuggerAction, answered with
+  // Step; (2) the nested OtherAction, reached through the Step, answered with
+  // Finish; (3) the outer DebuggerAction again, once it completes, because
+  // Finish pauses when control is back in the enclosing action.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag, OtherAction::tag,
+                                        DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Step, ExecutionContext::Finish,
+      ExecutionContext::Apply};
+  int idx = 0, counter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+  auto nested = [&]() { EXPECT_EQ(counter, 2); };
+  auto original = [&]() {
+    EXPECT_EQ(counter, 1);
+    executionCtx(nested, OtherAction{});
+    // Finish on the nested action does not pause when the nested action ends.
+    EXPECT_EQ(counter, 2);
+  };
+
+  executionCtx(original, DebuggerAction{});
+  EXPECT_EQ(counter, 3);
+}
+
+TEST(ExecutionContext, FinishBreakpointInNestedTest) {
+  // Test the "finish" control with a breakpoint in the nested action.
+  // Only OtherAction has a breakpoint, so the outer DebuggerAction starts
+  // without pausing. Finish on the nested action then pauses once the outer
+  // action completes, so the second callback reports DebuggerAction.
+  std::vector<StringRef> tagSequence = {OtherAction::tag, DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Finish, ExecutionContext::Apply};
+  int idx = 0, counter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(OtherAction::tag);
+
+  auto nested = [&]() { EXPECT_EQ(counter, 1); };
+  auto original = [&]() {
+    // No pause has happened yet: the outer action has no breakpoint.
+    EXPECT_EQ(counter, 0);
+    executionCtx(nested, OtherAction{});
+    EXPECT_EQ(counter, 1);
+  };
+
+  executionCtx(original, DebuggerAction{});
+  EXPECT_EQ(counter, 2);
+}
+
+TEST(ExecutionContext, FinishNothingBackTest) {
+  // Test the "finish" control without an enclosing action: execution continues
+  // without pausing again.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Finish};
+  int idx = 0, counter = 0, executionCounter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+  // Count executions so the test fails if the action never runs.
+  auto callback = [&]() {
+    ++executionCounter;
+    EXPECT_EQ(counter, 1);
+  };
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+
+  executionCtx(callback, DebuggerAction{});
+  EXPECT_EQ(counter, 1);
+  EXPECT_EQ(executionCounter, 1);
+}
+
+TEST(ExecutionContext, EnableDisableBreakpointOnCallback) {
+  // Test enabling and disabling breakpoints while executing the action.
+  // While it runs, the outer action disables the OtherAction breakpoint and
+  // adds one on ThirdAction. The expected pauses are therefore:
+  // - DebuggerAction: its own breakpoint, answered with Apply.
+  // - ThirdAction: the newly added breakpoint, answered with Finish.
+  // - OtherAction: reached by unwinding one level from that Finish.
+  // - DebuggerAction: reached by unwinding one level from the next Finish.
+  std::vector<StringRef> tagSequence = {DebuggerAction::tag, ThirdAction::tag,
+                                        OtherAction::tag, DebuggerAction::tag};
+  std::vector<ExecutionContext::Control> controlSequence = {
+      ExecutionContext::Apply, ExecutionContext::Finish,
+      ExecutionContext::Finish, ExecutionContext::Apply};
+  int idx = 0, counter = 0;
+  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
+    ++counter;
+    EXPECT_EQ(tagSequence[idx], backtrace->getAction().getTag());
+    return controlSequence[idx++];
+  };
+
+  TagBreakpointManager simpleManager;
+  ExecutionContext executionCtx(onBreakpoint);
+  executionCtx.addBreakpointManager(&simpleManager);
+  simpleManager.addBreakpoint(DebuggerAction::tag);
+  Breakpoint *toBeDisabled = simpleManager.addBreakpoint(OtherAction::tag);
+
+  auto third = [&]() { EXPECT_EQ(counter, 2); };
+  auto nested = [&]() {
+    // The OtherAction breakpoint was disabled, so no pause on entry here.
+    EXPECT_EQ(counter, 1);
+    executionCtx(third, ThirdAction{});
+    EXPECT_EQ(counter, 2);
+  };
+  auto original = [&]() {
+    EXPECT_EQ(counter, 1);
+    toBeDisabled->disable();
+    simpleManager.addBreakpoint(ThirdAction::tag);
+    executionCtx(nested, OtherAction{});
+    EXPECT_EQ(counter, 3);
+  };
+
+  executionCtx(original, DebuggerAction{});
+  EXPECT_EQ(counter, 4);
+}
+} // namespace


        


More information about the Mlir-commits mailing list