[Mlir-commits] [mlir] 0580901 - Fix #58322: Handlers for debug actions with equal parameter types must not override each other

River Riddle llvmlistbot at llvm.org
Sat Oct 22 14:30:08 PDT 2022


Author: Tomás Longeri
Date: 2022-10-22T14:18:00-07:00
New Revision: 0580901bbb0332547a08f37072a6ff8ca9e7c893

URL: https://github.com/llvm/llvm-project/commit/0580901bbb0332547a08f37072a6ff8ca9e7c893
DIFF: https://github.com/llvm/llvm-project/commit/0580901bbb0332547a08f37072a6ff8ca9e7c893.diff

LOG: Fix #58322: Handlers for debug actions with equal parameter types must not override each other

Also clean up redundant public access specifiers.

Reviewed By: mehdi_amini, rriddle

Differential Revision: https://reviews.llvm.org/D135924

Added: 
    

Modified: 
    mlir/docs/DebugActions.md
    mlir/include/mlir/Support/DebugAction.h
    mlir/unittests/Support/DebugActionTest.cpp
    mlir/unittests/Support/DebugCounterTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DebugActions.md b/mlir/docs/DebugActions.md
index 777be8eaf4b1f..a40459c6161d6 100644
--- a/mlir/docs/DebugActions.md
+++ b/mlir/docs/DebugActions.md
@@ -54,10 +54,12 @@ rewrite patterns.
 /// * The Tag is specified via a static `StringRef getTag()` method.
 /// * The Description is specified via a static `StringRef getDescription()`
 ///   method.
-/// * The parameters for the action are provided via template parameters when
-///   inheriting from `DebugAction`.
+/// * `DebugAction` is a CRTP class, so the first template parameter is the
+///   action type class itself.
+/// * The parameters for the action are provided via additional template
+///   parameters when inheriting from `DebugAction`.
 struct ApplyPatternAction
-    : public DebugAction<Operation *, const Pattern &> {
+    : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
   static StringRef getTag() { return "apply-pattern"; }
   static StringRef getDescription() {
     return "Control the application of rewrite patterns";
@@ -95,7 +97,7 @@ usage of the `shouldExecute` query is shown below:
 ```c++
 /// A debug action that allows for controlling the application of patterns.
 struct ApplyPatternAction
-    : public DebugAction<Operation *, const Pattern &> {
+    : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
   static StringRef getTag() { return "apply-pattern"; }
   static StringRef getDescription() {
     return "Control the application of rewrite patterns";

diff  --git a/mlir/include/mlir/Support/DebugAction.h b/mlir/include/mlir/Support/DebugAction.h
index 41ec8b111f94a..e1dc25e7792f8 100644
--- a/mlir/include/mlir/Support/DebugAction.h
+++ b/mlir/include/mlir/Support/DebugAction.h
@@ -190,14 +190,12 @@ class DebugActionManager {
 /// This class provides a handler class that can be derived from to handle
 /// instances of this action. The parameters to its query methods map 1-1 to the
 /// types on the action type.
-template <typename... ParameterTs>
+template <typename Derived, typename... ParameterTs>
 class DebugAction {
 public:
   class Handler : public DebugActionManager::HandlerBase {
   public:
-    Handler()
-        : HandlerBase(
-              TypeID::get<typename DebugAction<ParameterTs...>::Handler>()) {}
+    Handler() : HandlerBase(TypeID::get<Derived>()) {}
 
     /// This hook allows for controlling whether an action should execute or
     /// not. `parameters` correspond to the set of values provided by the
@@ -209,8 +207,7 @@ class DebugAction {
 
     /// Provide classof to allow casting between handler types.
     static bool classof(const DebugActionManager::HandlerBase *handler) {
-      return handler->getHandlerID() ==
-             TypeID::get<typename DebugAction<ParameterTs...>::Handler>();
+      return handler->getHandlerID() == TypeID::get<Derived>();
     }
   };
 

diff  --git a/mlir/unittests/Support/DebugActionTest.cpp b/mlir/unittests/Support/DebugActionTest.cpp
index 4ecaf128a1050..be2ca2e1b1b1a 100644
--- a/mlir/unittests/Support/DebugActionTest.cpp
+++ b/mlir/unittests/Support/DebugActionTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Support/DebugAction.h"
+#include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
 // DebugActionManager is only enabled in DEBUG mode.
@@ -15,11 +16,20 @@
 using namespace mlir;
 
 namespace {
-struct SimpleAction : public DebugAction<> {
+struct SimpleAction : DebugAction<SimpleAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction)
   static StringRef getTag() { return "simple-action"; }
   static StringRef getDescription() { return "simple-action-description"; }
 };
-struct ParametricAction : public DebugAction<bool> {
+struct OtherSimpleAction : DebugAction<OtherSimpleAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction)
+  static StringRef getTag() { return "other-simple-action"; }
+  static StringRef getDescription() {
+    return "other-simple-action-description";
+  }
+};
+struct ParametricAction : DebugAction<ParametricAction, bool> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
   static StringRef getTag() { return "param-action"; }
   static StringRef getDescription() { return "param-action-description"; }
 };
@@ -29,7 +39,7 @@ TEST(DebugActionTest, GenericHandler) {
 
   // A generic handler that always executes the simple action, but not the
   // parametric action.
-  struct GenericHandler : public DebugActionManager::GenericHandler {
+  struct GenericHandler : DebugActionManager::GenericHandler {
     FailureOr<bool> shouldExecute(StringRef tag, StringRef desc) final {
       if (tag == SimpleAction::getTag()) {
         EXPECT_EQ(desc, SimpleAction::getDescription());
@@ -51,7 +61,7 @@ TEST(DebugActionTest, ActionSpecificHandler) {
   DebugActionManager manager;
 
   // Handler that simply uses the input as the decider.
-  struct ActionSpecificHandler : public ParametricAction::Handler {
+  struct ActionSpecificHandler : ParametricAction::Handler {
     FailureOr<bool> shouldExecute(bool shouldExecuteParam) final {
       return shouldExecuteParam;
     }
@@ -69,7 +79,7 @@ TEST(DebugActionTest, DebugCounterHandler) {
   DebugActionManager manager;
 
   // Handler that uses the number of action executions as the decider.
-  struct DebugCounterHandler : public SimpleAction::Handler {
+  struct DebugCounterHandler : SimpleAction::Handler {
     FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
     unsigned numExecutions = 0;
   };
@@ -83,6 +93,22 @@ TEST(DebugActionTest, DebugCounterHandler) {
   EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
 }
 
+TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
+  DebugActionManager manager;
+
+  // One handler returns true and another returns false
+  struct SimpleActionHandler : SimpleAction::Handler {
+    FailureOr<bool> shouldExecute() final { return true; }
+  };
+  struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
+    FailureOr<bool> shouldExecute() final { return false; }
+  };
+  manager.registerActionHandler<SimpleActionHandler>();
+  manager.registerActionHandler<OtherSimpleActionHandler>();
+  EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
+  EXPECT_FALSE(manager.shouldExecute<OtherSimpleAction>());
+}
+
 } // namespace
 
 #endif

diff  --git a/mlir/unittests/Support/DebugCounterTest.cpp b/mlir/unittests/Support/DebugCounterTest.cpp
index 7ca48cdf35731..bf8d0279938c3 100644
--- a/mlir/unittests/Support/DebugCounterTest.cpp
+++ b/mlir/unittests/Support/DebugCounterTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Support/DebugCounter.h"
+#include "mlir/Support/TypeID.h"
 #include "gmock/gmock.h"
 
 using namespace mlir;
@@ -16,7 +17,8 @@ using namespace mlir;
 
 namespace {
 
-struct CounterAction : public DebugAction<> {
+struct CounterAction : public DebugAction<CounterAction> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CounterAction)
   static StringRef getTag() { return "counter-action"; }
   static StringRef getDescription() { return "Test action for debug counters"; }
 };


        


More information about the Mlir-commits mailing list