[Mlir-commits] [mlir] 58e7bf7 - [mlir] Add isa/dyn_cast support for dialect interfaces
River Riddle
llvmlistbot at llvm.org
Mon Jan 31 19:24:55 PST 2022
Author: River Riddle
Date: 2022-01-31T19:24:34-08:00
New Revision: 58e7bf78a3ef724b70304912fb3bb66af8c4a10c
URL: https://github.com/llvm/llvm-project/commit/58e7bf78a3ef724b70304912fb3bb66af8c4a10c
DIFF: https://github.com/llvm/llvm-project/commit/58e7bf78a3ef724b70304912fb3bb66af8c4a10c.diff
LOG: [mlir] Add isa/dyn_cast support for dialect interfaces
This matches the isa/dyn_cast API already used for attributes, operations, and types. For example:
```c++
Dialect *dialect = ...;
// Instead of this:
if (auto *interface = dialect->getRegisteredInterface<DialectInlinerInterface>())
// You can do this:
if (auto *interface = dyn_cast<DialectInlinerInterface>(dialect))
```
Differential Revision: https://reviews.llvm.org/D117859
Added:
Modified:
mlir/docs/Interfaces.md
mlir/include/mlir/IR/Dialect.h
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/unittests/IR/DialectTest.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index c181a603ffa71..b51aec9b603d6 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -77,8 +77,7 @@ or transformation without the need to determine the specific dialect subclass:
```c++
Dialect *dialect = ...;
-if (DialectInlinerInterface *interface
- = dialect->getRegisteredInterface<DialectInlinerInterface>()) {
+if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
// The dialect has provided an implementation of this interface.
...
}
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 7fb298c1b9425..798d66faccdfa 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -440,11 +440,58 @@ class DialectRegistry {
namespace llvm {
/// Provide isa functionality for Dialects.
-template <typename T> struct isa_impl<T, ::mlir::Dialect> {
+template <typename T>
+struct isa_impl<T, ::mlir::Dialect,
+ std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
static inline bool doit(const ::mlir::Dialect &dialect) {
return mlir::TypeID::get<T>() == dialect.getTypeID();
}
};
+template <typename T>
+struct isa_impl<
+ T, ::mlir::Dialect,
+ std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
+ static inline bool doit(const ::mlir::Dialect &dialect) {
+ return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
+ }
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect *> {
+ using ret_type =
+ std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
+ const T *>;
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect> {
+ using ret_type =
+ std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
+ const T &>;
+};
+
+template <typename T>
+struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
+ template <typename To>
+ static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
+ doitImpl(::mlir::Dialect &dialect) {
+ return static_cast<To &>(dialect);
+ }
+ template <typename To>
+ static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
+ const To &>
+ doitImpl(::mlir::Dialect &dialect) {
+ return *dialect.getRegisteredInterface<To>();
+ }
+
+ static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
+};
+template <class T>
+struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
+ static auto doit(::mlir::Dialect *dialect) {
+ return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
+ *dialect);
+ }
+};
+
} // namespace llvm
#endif
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index cf1573ded67fd..7382fbad3e1cb 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -231,8 +231,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
// dialect is not loaded for some reason, use the default combinator
// that conservatively accepts identical entries only.
entriesForID[id] =
- dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
- ->combine(entriesForID[id], kvp.second)
+ dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
+ entriesForID[id], kvp.second)
: DataLayoutDialectInterface::defaultCombine(entriesForID[id],
kvp.second);
if (!entriesForID[id])
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 802df2dada9d0..79e80f7c1317c 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1236,8 +1236,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
if (!dialect)
return true;
- auto *interface =
- dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
+ auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
if (!interface)
return true;
return failed(interface->decode(*this, result));
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 1ca4d684475bd..e67933790de11 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -506,7 +506,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
if (!dialect)
return failure();
- auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
+ auto *interface = dyn_cast<DialectFoldInterface>(dialect);
if (!interface)
return failure();
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 2b7ff5ef60583..ac6397c632ac4 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -438,8 +438,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
if (!dialect)
continue;
- const auto *iface =
- dialect->getRegisteredInterface<DataLayoutDialectInterface>();
+ const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
if (!iface) {
return emitError(loc)
<< "the '" << dialect->getNamespace()
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index ca89e824a7ed7..b4fd697e08820 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -68,18 +68,17 @@ TEST(Dialect, DelayedInterfaceRegistration) {
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
- auto *testDialect = context.getOrLoadDialect<TestDialect>();
+ Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
- auto *testDialectInterface =
- testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+ auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Load the SecondTestDialect and check that the interface is not registered
// for it.
- auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
+ Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
ASSERT_TRUE(secondTestDialect != nullptr);
auto *secondTestDialectInterface =
- secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+ dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface == nullptr);
// Use the same mechanism as for delayed registration but for an already
@@ -90,7 +89,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
context.appendDialectRegistry(secondRegistry);
secondTestDialectInterface =
- secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+ dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface != nullptr);
}
@@ -102,10 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
- auto *testDialect = context.getOrLoadDialect<TestDialect>();
+ Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
- auto *testDialectInterface =
- testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+ auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Try adding the same dialect interface again and check that we don't crash
@@ -114,8 +112,7 @@ TEST(Dialect, RepeatedDelayedRegistration) {
secondRegistry.insert<TestDialect>();
secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
context.appendDialectRegistry(secondRegistry);
- testDialectInterface =
- testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+ testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
}
More information about the Mlir-commits mailing list