[Mlir-commits] [mlir] 58e7bf7 - [mlir] Add isa/dyn_cast support for dialect interfaces

River Riddle llvmlistbot at llvm.org
Mon Jan 31 19:24:55 PST 2022


Author: River Riddle
Date: 2022-01-31T19:24:34-08:00
New Revision: 58e7bf78a3ef724b70304912fb3bb66af8c4a10c

URL: https://github.com/llvm/llvm-project/commit/58e7bf78a3ef724b70304912fb3bb66af8c4a10c
DIFF: https://github.com/llvm/llvm-project/commit/58e7bf78a3ef724b70304912fb3bb66af8c4a10c.diff

LOG: [mlir] Add isa/dyn_cast support for dialect interfaces

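This matches the isa/cast/dyn_cast API usage already available for attributes, operations, and types. Note that casting a `Dialect` to a dialect interface yields a const-qualified result (`const T *` / `const T &`). For example: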

```c++
Dialect *dialect = ...;

// Instead of this:
if (auto *interface = dialect->getRegisteredInterface<DialectInlinerInterface>())

// You can do this:
if (auto *interface = dyn_cast<DialectInlinerInterface>(dialect))
```

Differential Revision: https://reviews.llvm.org/D117859

Added: 
    

Modified: 
    mlir/docs/Interfaces.md
    mlir/include/mlir/IR/Dialect.h
    mlir/lib/Dialect/DLTI/DLTI.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/Interfaces/DataLayoutInterfaces.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index c181a603ffa71..b51aec9b603d6 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -77,8 +77,7 @@ or transformation without the need to determine the specific dialect subclass:
 
 ```c++
 Dialect *dialect = ...;
-if (DialectInlinerInterface *interface
-      = dialect->getRegisteredInterface<DialectInlinerInterface>()) {
+if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
   // The dialect has provided an implementation of this interface.
   ...
 }

diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 7fb298c1b9425..798d66faccdfa 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -440,11 +440,58 @@ class DialectRegistry {
 
 namespace llvm {
 /// Provide isa functionality for Dialects.
-template <typename T> struct isa_impl<T, ::mlir::Dialect> {
+template <typename T>
+struct isa_impl<T, ::mlir::Dialect,
+                std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
   static inline bool doit(const ::mlir::Dialect &dialect) {
     return mlir::TypeID::get<T>() == dialect.getTypeID();
   }
 };
+template <typename T>
+struct isa_impl<
+    T, ::mlir::Dialect,
+    std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
+  static inline bool doit(const ::mlir::Dialect &dialect) {
+    return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
+  }
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect *> {
+  using ret_type =
+      std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
+                         const T *>;
+};
+template <typename T>
+struct cast_retty_impl<T, ::mlir::Dialect> {
+  using ret_type =
+      std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
+                         const T &>;
+};
+
+template <typename T>
+struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
+  template <typename To>
+  static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
+  doitImpl(::mlir::Dialect &dialect) {
+    return static_cast<To &>(dialect);
+  }
+  template <typename To>
+  static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
+                          const To &>
+  doitImpl(::mlir::Dialect &dialect) {
+    return *dialect.getRegisteredInterface<To>();
+  }
+
+  static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
+};
+template <class T>
+struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
+  static auto doit(::mlir::Dialect *dialect) {
+    return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
+        *dialect);
+  }
+};
+
 } // namespace llvm
 
 #endif

diff  --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index cf1573ded67fd..7382fbad3e1cb 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -231,8 +231,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
     // dialect is not loaded for some reason, use the default combinator
     // that conservatively accepts identical entries only.
     entriesForID[id] =
-        dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
-                      ->combine(entriesForID[id], kvp.second)
+        dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
+                      entriesForID[id], kvp.second)
                 : DataLayoutDialectInterface::defaultCombine(entriesForID[id],
                                                              kvp.second);
     if (!entriesForID[id])

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 802df2dada9d0..79e80f7c1317c 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1236,8 +1236,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
   Dialect *dialect = getContext()->getLoadedDialect(getDialect());
   if (!dialect)
     return true;
-  auto *interface =
-      dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
+  auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
   if (!interface)
     return true;
   return failed(interface->decode(*this, result));

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 1ca4d684475bd..e67933790de11 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -506,7 +506,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
   if (!dialect)
     return failure();
 
-  auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
+  auto *interface = dyn_cast<DialectFoldInterface>(dialect);
   if (!interface)
     return failure();
 

diff  --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 2b7ff5ef60583..ac6397c632ac4 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -438,8 +438,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
     if (!dialect)
       continue;
 
-    const auto *iface =
-        dialect->getRegisteredInterface<DataLayoutDialectInterface>();
+    const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
     if (!iface) {
       return emitError(loc)
              << "the '" << dialect->getNamespace()

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index ca89e824a7ed7..b4fd697e08820 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -68,18 +68,17 @@ TEST(Dialect, DelayedInterfaceRegistration) {
   MLIRContext context(registry);
 
   // Load the TestDialect and check that the interface got registered for it.
-  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
   ASSERT_TRUE(testDialect != nullptr);
-  auto *testDialectInterface =
-      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
   EXPECT_TRUE(testDialectInterface != nullptr);
 
   // Load the SecondTestDialect and check that the interface is not registered
   // for it.
-  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
+  Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
   ASSERT_TRUE(secondTestDialect != nullptr);
   auto *secondTestDialectInterface =
-      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+      dyn_cast<SecondTestDialectInterface>(secondTestDialect);
   EXPECT_TRUE(secondTestDialectInterface == nullptr);
 
   // Use the same mechanism as for delayed registration but for an already
@@ -90,7 +89,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
       .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
   context.appendDialectRegistry(secondRegistry);
   secondTestDialectInterface =
-      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+      dyn_cast<SecondTestDialectInterface>(secondTestDialect);
   EXPECT_TRUE(secondTestDialectInterface != nullptr);
 }
 
@@ -102,10 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
   MLIRContext context(registry);
 
   // Load the TestDialect and check that the interface got registered for it.
-  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
   ASSERT_TRUE(testDialect != nullptr);
-  auto *testDialectInterface =
-      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
   EXPECT_TRUE(testDialectInterface != nullptr);
 
   // Try adding the same dialect interface again and check that we don't crash
@@ -114,8 +112,7 @@ TEST(Dialect, RepeatedDelayedRegistration) {
   secondRegistry.insert<TestDialect>();
   secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
   context.appendDialectRegistry(secondRegistry);
-  testDialectInterface =
-      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
   EXPECT_TRUE(testDialectInterface != nullptr);
 }
 


        


More information about the Mlir-commits mailing list