[Mlir-commits] [mlir] 0580901 - Fix #58322: Handlers for debug actions with equal parameter types must not override each other
River Riddle
llvmlistbot at llvm.org
Sat Oct 22 14:30:08 PDT 2022
Author: Tomás Longeri
Date: 2022-10-22T14:18:00-07:00
New Revision: 0580901bbb0332547a08f37072a6ff8ca9e7c893
URL: https://github.com/llvm/llvm-project/commit/0580901bbb0332547a08f37072a6ff8ca9e7c893
DIFF: https://github.com/llvm/llvm-project/commit/0580901bbb0332547a08f37072a6ff8ca9e7c893.diff
LOG: Fix #58322: Handlers for debug actions with equal parameter types must not override each other
Also clean up redundant public access specifiers.
Reviewed By: mehdi_amini, rriddle
Differential Revision: https://reviews.llvm.org/D135924
Added:
Modified:
mlir/docs/DebugActions.md
mlir/include/mlir/Support/DebugAction.h
mlir/unittests/Support/DebugActionTest.cpp
mlir/unittests/Support/DebugCounterTest.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DebugActions.md b/mlir/docs/DebugActions.md
index 777be8eaf4b1f..a40459c6161d6 100644
--- a/mlir/docs/DebugActions.md
+++ b/mlir/docs/DebugActions.md
@@ -54,10 +54,12 @@ rewrite patterns.
/// * The Tag is specified via a static `StringRef getTag()` method.
/// * The Description is specified via a static `StringRef getDescription()`
/// method.
-/// * The parameters for the action are provided via template parameters when
-/// inheriting from `DebugAction`.
+/// * `DebugAction` is a CRTP class, so the first template parameter is the
+/// action type class itself.
+/// * The parameters for the action are provided via additional template
+/// parameters when inheriting from `DebugAction`.
struct ApplyPatternAction
- : public DebugAction<Operation *, const Pattern &> {
+ : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
static StringRef getTag() { return "apply-pattern"; }
static StringRef getDescription() {
return "Control the application of rewrite patterns";
@@ -95,7 +97,7 @@ usage of the `shouldExecute` query is shown below:
```c++
/// A debug action that allows for controlling the application of patterns.
struct ApplyPatternAction
- : public DebugAction<Operation *, const Pattern &> {
+ : public DebugAction<ApplyPatternAction, Operation *, const Pattern &> {
static StringRef getTag() { return "apply-pattern"; }
static StringRef getDescription() {
return "Control the application of rewrite patterns";
diff --git a/mlir/include/mlir/Support/DebugAction.h b/mlir/include/mlir/Support/DebugAction.h
index 41ec8b111f94a..e1dc25e7792f8 100644
--- a/mlir/include/mlir/Support/DebugAction.h
+++ b/mlir/include/mlir/Support/DebugAction.h
@@ -190,14 +190,12 @@ class DebugActionManager {
/// This class provides a handler class that can be derived from to handle
/// instances of this action. The parameters to its query methods map 1-1 to the
/// types on the action type.
-template <typename... ParameterTs>
+template <typename Derived, typename... ParameterTs>
class DebugAction {
public:
class Handler : public DebugActionManager::HandlerBase {
public:
- Handler()
- : HandlerBase(
- TypeID::get<typename DebugAction<ParameterTs...>::Handler>()) {}
+ Handler() : HandlerBase(TypeID::get<Derived>()) {}
/// This hook allows for controlling whether an action should execute or
/// not. `parameters` correspond to the set of values provided by the
@@ -209,8 +207,7 @@ class DebugAction {
/// Provide classof to allow casting between handler types.
static bool classof(const DebugActionManager::HandlerBase *handler) {
- return handler->getHandlerID() ==
- TypeID::get<typename DebugAction<ParameterTs...>::Handler>();
+ return handler->getHandlerID() == TypeID::get<Derived>();
}
};
diff --git a/mlir/unittests/Support/DebugActionTest.cpp b/mlir/unittests/Support/DebugActionTest.cpp
index 4ecaf128a1050..be2ca2e1b1b1a 100644
--- a/mlir/unittests/Support/DebugActionTest.cpp
+++ b/mlir/unittests/Support/DebugActionTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Support/DebugAction.h"
+#include "mlir/Support/TypeID.h"
#include "gmock/gmock.h"
// DebugActionManager is only enabled in DEBUG mode.
@@ -15,11 +16,20 @@
using namespace mlir;
namespace {
-struct SimpleAction : public DebugAction<> {
+struct SimpleAction : DebugAction<SimpleAction> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimpleAction)
static StringRef getTag() { return "simple-action"; }
static StringRef getDescription() { return "simple-action-description"; }
};
-struct ParametricAction : public DebugAction<bool> {
+struct OtherSimpleAction : DebugAction<OtherSimpleAction> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OtherSimpleAction)
+ static StringRef getTag() { return "other-simple-action"; }
+ static StringRef getDescription() {
+ return "other-simple-action-description";
+ }
+};
+struct ParametricAction : DebugAction<ParametricAction, bool> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ParametricAction)
static StringRef getTag() { return "param-action"; }
static StringRef getDescription() { return "param-action-description"; }
};
@@ -29,7 +39,7 @@ TEST(DebugActionTest, GenericHandler) {
// A generic handler that always executes the simple action, but not the
// parametric action.
- struct GenericHandler : public DebugActionManager::GenericHandler {
+ struct GenericHandler : DebugActionManager::GenericHandler {
FailureOr<bool> shouldExecute(StringRef tag, StringRef desc) final {
if (tag == SimpleAction::getTag()) {
EXPECT_EQ(desc, SimpleAction::getDescription());
@@ -51,7 +61,7 @@ TEST(DebugActionTest, ActionSpecificHandler) {
DebugActionManager manager;
// Handler that simply uses the input as the decider.
- struct ActionSpecificHandler : public ParametricAction::Handler {
+ struct ActionSpecificHandler : ParametricAction::Handler {
FailureOr<bool> shouldExecute(bool shouldExecuteParam) final {
return shouldExecuteParam;
}
@@ -69,7 +79,7 @@ TEST(DebugActionTest, DebugCounterHandler) {
DebugActionManager manager;
// Handler that uses the number of action executions as the decider.
- struct DebugCounterHandler : public SimpleAction::Handler {
+ struct DebugCounterHandler : SimpleAction::Handler {
FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
unsigned numExecutions = 0;
};
@@ -83,6 +93,22 @@ TEST(DebugActionTest, DebugCounterHandler) {
EXPECT_FALSE(manager.shouldExecute<SimpleAction>());
}
+TEST(DebugActionTest, NonOverlappingActionSpecificHandlers) {
+ DebugActionManager manager;
+
+ // One handler returns true and another returns false
+ struct SimpleActionHandler : SimpleAction::Handler {
+ FailureOr<bool> shouldExecute() final { return true; }
+ };
+ struct OtherSimpleActionHandler : OtherSimpleAction::Handler {
+ FailureOr<bool> shouldExecute() final { return false; }
+ };
+ manager.registerActionHandler<SimpleActionHandler>();
+ manager.registerActionHandler<OtherSimpleActionHandler>();
+ EXPECT_TRUE(manager.shouldExecute<SimpleAction>());
+ EXPECT_FALSE(manager.shouldExecute<OtherSimpleAction>());
+}
+
} // namespace
#endif
diff --git a/mlir/unittests/Support/DebugCounterTest.cpp b/mlir/unittests/Support/DebugCounterTest.cpp
index 7ca48cdf35731..bf8d0279938c3 100644
--- a/mlir/unittests/Support/DebugCounterTest.cpp
+++ b/mlir/unittests/Support/DebugCounterTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Support/DebugCounter.h"
+#include "mlir/Support/TypeID.h"
#include "gmock/gmock.h"
using namespace mlir;
@@ -16,7 +17,8 @@ using namespace mlir;
namespace {
-struct CounterAction : public DebugAction<> {
+struct CounterAction : public DebugAction<CounterAction> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CounterAction)
static StringRef getTag() { return "counter-action"; }
static StringRef getDescription() { return "Test action for debug counters"; }
};
More information about the Mlir-commits mailing list