[Mlir-commits] [mlir] 7fbcf10 - Change the DebugAction paradigm to delegate the control to the handler

Mehdi Amini llvmlistbot at llvm.org
Mon Mar 6 06:59:31 PST 2023


Author: Mehdi Amini
Date: 2023-03-06T15:58:26+01:00
New Revision: 7fbcf10e2e7a408dc551cab5c180ebc2fb679454

URL: https://github.com/llvm/llvm-project/commit/7fbcf10e2e7a408dc551cab5c180ebc2fb679454
DIFF: https://github.com/llvm/llvm-project/commit/7fbcf10e2e7a408dc551cab5c180ebc2fb679454.diff

LOG: Change the DebugAction paradigm to delegate the control to the handler

At the moment, we invoke `shouldExecute()` this way:

```
if (manager.shouldExecute<DebugAction>(currentOp)) {
  // apply a transformation
  …
}
```

In this sequence, the manager isn’t involved in the actual execution
of the action and can’t provide rich instrumentation. Instead, the API
could hand control to the handler itself:

```
// Execute the action under the control of the manager
manager.execute<DebugAction>(currentOp, [&]() {
  // apply the transformation in this callback
  …
});
```

This inversion of control (by injecting a callback) allows handlers to
implement potentially interesting new features: for example, snapshotting
the IR before and after the action, or recording an action's execution time.
More importantly, it will allow capturing the nested execution of
actions.

On the other side, handlers now receive a DebugAction object that wraps
generic information (notably the tag and description) as well as
action-specific data.

Finally, the DebugActionManager is now enabled in release builds as
well.

Differential Revision: https://reviews.llvm.org/D144808

Added: 
    

Modified: 
    mlir/include/mlir/Support/DebugAction.h
    mlir/include/mlir/Support/DebugCounter.h
    mlir/lib/Support/DebugCounter.cpp
    mlir/unittests/Support/DebugActionTest.cpp
    mlir/unittests/Support/DebugCounterTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Support/DebugAction.h b/mlir/include/mlir/Support/DebugAction.h
index e1dc25e7792f8..aa440289862bd 100644
--- a/mlir/include/mlir/Support/DebugAction.h
+++ b/mlir/include/mlir/Support/DebugAction.h
@@ -8,9 +8,7 @@
 //
 // This file contains definitions for the debug action framework. This framework
 // allows for external entities to control certain actions taken by the compiler
-// by registering handler functions. A debug action handler provides the
-// internal implementation for the various queries on a debug action, such as
-// whether it should execute or not.
+// by registering handler functions.
 //
 //===----------------------------------------------------------------------===//
 
@@ -29,6 +27,34 @@
 
 namespace mlir {
 
+/// Base class of debug actions, exposing their TypeID, tag and description.
+class DebugActionBase {
+public:
+  virtual ~DebugActionBase() = default;
+
+  /// Return the unique action id of this action, used for casting
+  /// functionality (see `DebugAction::classof`).
+  TypeID getActionID() const { return actionID; }
+
+  StringRef getTag() const { return tag; }
+
+  StringRef getDescription() const { return desc; }
+
+  virtual void print(raw_ostream &os) const {
+    os << "Action \"" << tag << "\" : " << desc << "\n";
+  }
+
+protected:
+  DebugActionBase(TypeID actionID, StringRef tag, StringRef desc)
+      : actionID(actionID), tag(tag), desc(desc) {}
+
+  /// The type of the derived action class. This allows for detecting the
+  /// specific handler of a given action type.
+  TypeID actionID;
+  StringRef tag;  // Non-owning: the referenced string must outlive the action.
+  StringRef desc; // Non-owning, like `tag`.
+};
+
 //===----------------------------------------------------------------------===//
 // DebugActionManager
 //===----------------------------------------------------------------------===//
@@ -74,11 +100,11 @@ class DebugActionManager {
   public:
     GenericHandler() : HandlerBase(TypeID::get<GenericHandler>()) {}
 
-    /// This hook allows for controlling whether an action should execute or
-    /// not. It should return failure if the handler could not process the
-    /// action, passing it to the next registered handler.
-    virtual FailureOr<bool> shouldExecute(StringRef actionTag,
-                                          StringRef description) {
+    /// This hook allows for controlling the execution of an action. It should
+    /// return failure if the handler could not process the action; otherwise
+    /// it should return whether the `transform` callback was executed.
+    virtual FailureOr<bool> execute(function_ref<void()> transform,
+                                    const DebugActionBase &action) {
       return failure();
     }
 
@@ -90,10 +116,7 @@ class DebugActionManager {
 
   /// Register the given action handler with the manager.
   void registerActionHandler(std::unique_ptr<HandlerBase> handler) {
-    // The manager is always disabled if built without debug.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
     actionHandlers.emplace_back(std::move(handler));
-#endif
   }
   template <typename T>
   void registerActionHandler() {
@@ -104,31 +127,35 @@ class DebugActionManager {
   // Action Queries
   //===--------------------------------------------------------------------===//
 
-  /// Returns true if the given action type should be executed, false otherwise.
-  /// `Args` are a set of parameters used by handlers of `ActionType` to
-  /// determine if the action should be executed.
+  /// Dispatch an action represented by the `transform` callback. If no handler
+  /// processes it, `transform` is called directly and true is returned.
+  /// Otherwise, return the handler's result: whether `transform` was executed.
   template <typename ActionType, typename... Args>
-  bool shouldExecute(Args &&...args) {
-    // The manager is always disabled if built without debug.
-#if !LLVM_ENABLE_ABI_BREAKING_CHECKS
-    return true;
-#else
-    // Invoke the `shouldExecute` method on the provided handler.
-    auto shouldExecuteFn = [&](auto *handler, auto &&...handlerParams) {
-      return handler->shouldExecute(
-          std::forward<decltype(handlerParams)>(handlerParams)...);
+  bool execute(function_ref<void()> transform, Args &&...args) {
+    if (actionHandlers.empty()) {
+      transform();
+      return true;
+    }
+
+    // Build an ActionType from the arguments and pass it to the handler.
+    auto executeFn = [&](auto *handler, auto &&...handlerParams) {
+      return handler->execute(
+          transform,
+          ActionType(std::forward<decltype(handlerParams)>(handlerParams)...));
     };
     FailureOr<bool> result = dispatchToHandler<ActionType, bool>(
-        shouldExecuteFn, std::forward<Args>(args)...);
+        executeFn, std::forward<Args>(args)...);
+    // No handler processed the action: execute the transform directly.
+    if (failed(result)) {
+      transform();
+      return true;
+    }

-    // If the action wasn't handled, execute the action by default.
-    return succeeded(result) ? *result : true;
-#endif
+    // Forward the handler's answer on whether `transform` was executed.
+    return *result;
   }
 
 private:
-// The manager is always disabled if built without debug.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
   //===--------------------------------------------------------------------===//
   // Query to Handler Dispatch
   //===--------------------------------------------------------------------===//
@@ -145,16 +172,13 @@ class DebugActionManager {
                   "cannot execute action with the given set of parameters");
 
     // Process any generic or action specific handlers.
-    // TODO: We currently just pick the first handler that gives us a result,
-    // but in the future we may want to employ a reduction over all of the
-    // values returned.
-    for (std::unique_ptr<HandlerBase> &it : llvm::reverse(actionHandlers)) {
+    // The first handler that gives us a result is the one that we will return.
+    for (std::unique_ptr<HandlerBase> &it : reverse(actionHandlers)) {
       FailureOr<ResultT> result = failure();
       if (auto *handler = dyn_cast<typename ActionType::Handler>(&*it)) {
         result = handlerCallback(handler, std::forward<Args>(args)...);
       } else if (auto *genericHandler = dyn_cast<GenericHandler>(&*it)) {
-        result = handlerCallback(genericHandler, ActionType::getTag(),
-                                 ActionType::getDescription());
+        result = handlerCallback(genericHandler, std::forward<Args>(args)...);
       }
 
       // If the handler succeeded, return the result. Otherwise, try a new
@@ -167,7 +191,6 @@ class DebugActionManager {
 
   /// The set of action handlers that have been registered with the manager.
   SmallVector<std::unique_ptr<HandlerBase>> actionHandlers;
-#endif
 };
 
 //===----------------------------------------------------------------------===//
@@ -191,17 +214,27 @@ class DebugActionManager {
 /// instances of this action. The parameters to its query methods map 1-1 to the
 /// types on the action type.
 template <typename Derived, typename... ParameterTs>
-class DebugAction {
+class DebugAction : public DebugActionBase {
 public:
+  DebugAction()
+      : DebugActionBase(TypeID::get<Derived>(), Derived::getTag(),
+                        Derived::getDescription()) {}
+
+  /// Provide classof to allow casting between action types.
+  static bool classof(const DebugActionBase *action) {
+    return action->getActionID() == TypeID::get<Derived>();
+  }
+
   class Handler : public DebugActionManager::HandlerBase {
   public:
     Handler() : HandlerBase(TypeID::get<Derived>()) {}
 
-    /// This hook allows for controlling whether an action should execute or
-    /// not. `parameters` correspond to the set of values provided by the
+    /// This hook allows for controlling the execution of `transform`. The
+    /// `action` instance carries the values provided by the
     /// action as context. It should return failure if the handler could not
     /// process the action, passing it to the next registered handler.
-    virtual FailureOr<bool> shouldExecute(ParameterTs... parameters) {
+    virtual FailureOr<bool> execute(function_ref<void()> transform,
+                                    const Derived &action) {
       return failure();
     }
 

diff  --git a/mlir/include/mlir/Support/DebugCounter.h b/mlir/include/mlir/Support/DebugCounter.h
index 5b6c015aadda5..83fd69d6526fe 100644
--- a/mlir/include/mlir/Support/DebugCounter.h
+++ b/mlir/include/mlir/Support/DebugCounter.h
@@ -38,7 +38,8 @@ class DebugCounter : public DebugActionManager::GenericHandler {
                   int64_t countToStopAfter);
 
   /// Register a counter with the specified name.
-  FailureOr<bool> shouldExecute(StringRef tag, StringRef description) final;
+  FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                          const DebugActionBase &action) final;
 
   /// Print the counters that have been registered with this instance to the
   /// provided output stream.

diff  --git a/mlir/lib/Support/DebugCounter.cpp b/mlir/lib/Support/DebugCounter.cpp
index 44bcdf4ce1b75..a587b43914806 100644
--- a/mlir/lib/Support/DebugCounter.cpp
+++ b/mlir/lib/Support/DebugCounter.cpp
@@ -62,9 +62,9 @@ void DebugCounter::addCounter(StringRef actionTag, int64_t countToSkip,
 }
 
 // Register a counter with the specified name.
-FailureOr<bool> DebugCounter::shouldExecute(StringRef tag,
-                                            StringRef description) {
-  auto counterIt = counters.find(tag);
+FailureOr<bool> DebugCounter::execute(llvm::function_ref<void()> transform,
+                                      const DebugActionBase &action) {
+  auto counterIt = counters.find(action.getTag());
   if (counterIt == counters.end())
     return true;
 

diff  --git a/mlir/unittests/Support/DebugActionTest.cpp b/mlir/unittests/Support/DebugActionTest.cpp
index be2ca2e1b1b1a..1e80a0e612591 100644
--- a/mlir/unittests/Support/DebugActionTest.cpp
+++ b/mlir/unittests/Support/DebugActionTest.cpp
@@ -10,9 +10,6 @@
 #include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
-// DebugActionManager is only enabled in DEBUG mode.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-
 using namespace mlir;
 
 namespace {
@@ -30,6 +27,8 @@ struct OtherSimpleAction : DebugAction<OtherSimpleAction> {
 };
 struct ParametricAction : DebugAction<ParametricAction, bool> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
+  ParametricAction(bool executeParam) : executeParam(executeParam) {}
+  bool executeParam;
   static StringRef getTag() { return "param-action"; }
   static StringRef getDescription() { return "param-action-description"; }
 };
@@ -40,21 +39,25 @@ TEST(DebugActionTest, GenericHandler) {
   // A generic handler that always executes the simple action, but not the
   // parametric action.
   struct GenericHandler : DebugActionManager::GenericHandler {
-    FailureOr<bool> shouldExecute(StringRef tag, StringRef desc) final {
-      if (tag == SimpleAction::getTag()) {
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const DebugActionBase &action) final {
+      StringRef desc = action.getDescription();
+      if (isa<SimpleAction>(action)) {
         EXPECT_EQ(desc, SimpleAction::getDescription());
+        transform();
         return true;
       }
 
-      EXPECT_EQ(tag, ParametricAction::getTag());
+      EXPECT_TRUE(isa<ParametricAction>(action));
       EXPECT_EQ(desc, ParametricAction::getDescription());
       return false;
     }
   };
   manager.registerActionHandler<GenericHandler>();
 
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<ParametricAction>(true));
+  auto noOp = []() { return; };
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<ParametricAction>(noOp, true));
 }
 
 TEST(DebugActionTest, ActionSpecificHandler) {
@@ -62,17 +65,25 @@ TEST(DebugActionTest, ActionSpecificHandler) {
 
   // Handler that simply uses the input as the decider.
   struct ActionSpecificHandler : ParametricAction::Handler {
-    FailureOr<bool> shouldExecute(bool shouldExecuteParam) final {
-      return shouldExecuteParam;
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const ParametricAction &action) final {
+      if (action.executeParam)
+        transform();
+      return action.executeParam;
     }
   };
   manager.registerActionHandler<ActionSpecificHandler>();
 
-  EXPECT_TRUE(manager.shouldExecute<ParametricAction>(true));
-  EXPECT_FALSE(manager.shouldExecute<ParametricAction>(false));
+  int count = 0;
+  auto incCount = [&]() { count++; };
+  EXPECT_TRUE(manager.execute<ParametricAction>(incCount, true));
+  EXPECT_EQ(count, 1);
+  EXPECT_FALSE(manager.execute<ParametricAction>(incCount, false));
+  EXPECT_EQ(count, 1);
 
   // There is no handler for the simple action, so it is always executed.
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
+  EXPECT_TRUE(manager.execute<SimpleAction>(incCount));
+  EXPECT_EQ(count, 2);
 }
 
 TEST(DebugActionTest, DebugCounterHandler) {
@@ -80,17 +91,24 @@ TEST(DebugActionTest, DebugCounterHandler) {
 
   // Handler that uses the number of action executions as the decider.
   struct DebugCounterHandler : SimpleAction::Handler {
-    FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const SimpleAction &action) final {
+      bool shouldExecute = numExecutions++ < 3;
+      if (shouldExecute)
+        transform();
+      return shouldExecute;
+    }
     unsigned numExecutions = 0;
   };
   manager.registerActionHandler<DebugCounterHandler>();
 
   // Check that the action is executed 3 times, but no more after.
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
+  auto noOp = []() { return; };
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<SimpleAction>(noOp));
 }
 
 TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
@@ -98,17 +116,24 @@ TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
 
   // One handler returns true and another returns false
   struct SimpleActionHandler : SimpleAction::Handler {
-    FailureOr<bool> shouldExecute() final { return true; }
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const SimpleAction &action) final {
+      transform();
+      return true;
+    }
   };
   struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
-    FailureOr<bool> shouldExecute() final { return false; }
+    FailureOr<bool> execute(llvm::function_ref<void()> transform,
+                            const OtherSimpleAction &action) final {
+      transform();
+      return false;
+    }
   };
   manager.registerActionHandler<SimpleActionHandler>();
   manager.registerActionHandler<OtherSimpleActionHandler>();
-  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
-  EXPECT_FALSE(manager.shouldExecute<OtherSimpleAction>());
+  auto noOp = []() { return; };
+  EXPECT_TRUE(manager.execute<SimpleAction>(noOp));
+  EXPECT_FALSE(manager.execute<OtherSimpleAction>(noOp));
 }
 
 } // namespace
-
-#endif

diff  --git a/mlir/unittests/Support/DebugCounterTest.cpp b/mlir/unittests/Support/DebugCounterTest.cpp
index bf8d0279938c3..c46550b58f179 100644
--- a/mlir/unittests/Support/DebugCounterTest.cpp
+++ b/mlir/unittests/Support/DebugCounterTest.cpp
@@ -12,9 +12,6 @@
 
 using namespace mlir;
 
-// DebugActionManager is only enabled in DEBUG mode.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-
 namespace {
 
 struct CounterAction : public DebugAction<CounterAction> {
@@ -31,16 +28,16 @@ TEST(DebugCounterTest, CounterTest) {
   DebugActionManager manager;
   manager.registerActionHandler(std::move(counter));
 
+  auto noOp = []() { return; };
+
   // The first execution is skipped.
-  EXPECT_FALSE(manager.shouldExecute<CounterAction>());
+  EXPECT_FALSE(manager.execute<CounterAction>(noOp));
 
   // The counter stops after 3 successful executions.
-  EXPECT_TRUE(manager.shouldExecute<CounterAction>());
-  EXPECT_TRUE(manager.shouldExecute<CounterAction>());
-  EXPECT_TRUE(manager.shouldExecute<CounterAction>());
-  EXPECT_FALSE(manager.shouldExecute<CounterAction>());
+  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
+  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
+  EXPECT_TRUE(manager.execute<CounterAction>(noOp));
+  EXPECT_FALSE(manager.execute<CounterAction>(noOp));
 }
 
 } // namespace
-
-#endif


        


More information about the Mlir-commits mailing list