[Mlir-commits] [mlir] 34ea608 - [mlir] Support repeated delayed registration of dialect interfaces

Alex Zinenko llvmlistbot at llvm.org
Mon Feb 15 01:46:35 PST 2021


Author: Alex Zinenko
Date: 2021-02-15T10:46:26+01:00
New Revision: 34ea608a473a5c67263c49255551ea348ffc1700

URL: https://github.com/llvm/llvm-project/commit/34ea608a473a5c67263c49255551ea348ffc1700
DIFF: https://github.com/llvm/llvm-project/commit/34ea608a473a5c67263c49255551ea348ffc1700.diff

LOG: [mlir] Support repeated delayed registration of dialect interfaces

Dialects themselves do not support repeated addition of interfaces with the
same TypeID. However, with delayed registration, the registry may contain an
interface that the dialect already has, or the same interface may be registered
several times due to, e.g., dependencies. Make sure the delayed registration
does not attempt to add an interface with the same TypeID more than once.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96606

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/lib/IR/Dialect.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 7798f92400b3..4a816ccb79c9 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -239,7 +239,8 @@ class DialectRegistry {
   using MapTy =
       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
   using InterfaceMapTy =
-      DenseMap<TypeID, SmallVector<InterfaceAllocatorFunction, 2>>;
+      DenseMap<TypeID,
+               SmallVector<std::pair<TypeID, InterfaceAllocatorFunction>, 2>>;
 
 public:
   explicit DialectRegistry() {}
@@ -292,17 +293,20 @@ class DialectRegistry {
   /// dialect provided as template parameter. The dialect must be present in
   /// the registry.
   template <typename DialectTy>
-  void addDialectInterface(InterfaceAllocatorFunction allocator) {
-    addDialectInterface(DialectTy::getDialectNamespace(), allocator);
+  void addDialectInterface(TypeID interfaceTypeID,
+                           InterfaceAllocatorFunction allocator) {
+    addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
+                        allocator);
   }
 
   /// Add an interface to the dialect, both provided as template parameter. The
   /// dialect must be present in the registry.
   template <typename DialectTy, typename InterfaceTy>
   void addDialectInterface() {
-    addDialectInterface<DialectTy>([](Dialect *dialect) {
-      return std::make_unique<InterfaceTy>(dialect);
-    });
+    addDialectInterface<DialectTy>(
+        InterfaceTy::getInterfaceID(), [](Dialect *dialect) {
+          return std::make_unique<InterfaceTy>(dialect);
+        });
   }
 
   /// Register any interfaces required for the given dialect (based on its
@@ -312,7 +316,7 @@ class DialectRegistry {
 private:
   /// Add an interface constructed with the given allocation function to the
   /// dialect identified by its namespace.
-  void addDialectInterface(StringRef dialectName,
+  void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
                            InterfaceAllocatorFunction allocator);
 
   MapTy registry;

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index a77f31cc7f40..228bb8d6f327 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -14,9 +14,12 @@
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/Regex.h"
 
+#define DEBUG_TYPE "dialect"
+
 using namespace mlir;
 using namespace detail;
 
@@ -27,12 +30,28 @@ DialectAsmParser::~DialectAsmParser() {}
 //===----------------------------------------------------------------------===//
 
 void DialectRegistry::addDialectInterface(
-    StringRef dialectName, InterfaceAllocatorFunction allocator) {
+    StringRef dialectName, TypeID interfaceTypeID,
+    InterfaceAllocatorFunction allocator) {
   assert(allocator && "unexpected null interface allocation function");
   auto it = registry.find(dialectName.str());
   assert(it != registry.end() &&
          "adding an interface for an unregistered dialect");
-  interfaces[it->second.first].push_back(allocator);
+
+  // Bail out if the interface with the given ID is already in the registry for
+  // the given dialect. We expect a small number (dozens) of interfaces so a
+  // linear search is fine here.
+  auto &dialectInterfaces = interfaces[it->second.first];
+  for (const auto &kvp : dialectInterfaces) {
+    if (kvp.first == interfaceTypeID) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "[" DEBUG_TYPE
+                    "] repeated interface registration for dialect "
+                 << dialectName);
+      return;
+    }
+  }
+
+  dialectInterfaces.emplace_back(interfaceTypeID, allocator);
 }
 
 DialectAllocatorFunctionRef
@@ -59,8 +78,12 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
   if (it == interfaces.end())
     return;
 
-  for (const InterfaceAllocatorFunction &createInterface : it->second)
-    dialect->addInterface(createInterface(dialect));
+  // Add an interface if it is not already present.
+  for (const auto &kvp : it->second) {
+    if (dialect->getRegisteredInterface(kvp.first))
+      continue;
+    dialect->addInterface(kvp.second(dialect));
+  }
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index 64d207bec453..ddabec514b93 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -94,4 +94,52 @@ TEST(Dialect, DelayedInterfaceRegistration) {
   EXPECT_TRUE(secondTestDialectInterface != nullptr);
 }
 
+TEST(Dialect, RepeatedDelayedRegistration) {
+  // Set up the delayed registration.
+  DialectRegistry registry;
+  registry.insert<TestDialect>();
+  registry.addDialectInterface<TestDialect, TestDialectInterface>();
+  MLIRContext context(registry);
+
+  // Load the TestDialect and check that the interface got registered for it.
+  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+  auto *testDialectInterface =
+      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(testDialectInterface != nullptr);
+
+  // Try adding the same dialect interface again and check that we don't crash
+  // on repeated interface registration.
+  DialectRegistry secondRegistry;
+  secondRegistry.insert<TestDialect>();
+  secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
+  context.appendDialectRegistry(secondRegistry);
+  testDialectInterface =
+      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(testDialectInterface != nullptr);
+}
+
+// A dialect that registers two interfaces with the same InterfaceID, triggering
+// an assertion failure.
+struct RepeatedRegistrationDialect : public Dialect {
+  static StringRef getDialectNamespace() { return "repeatedreg"; }
+  RepeatedRegistrationDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context,
+                TypeID::get<RepeatedRegistrationDialect>()) {
+    addInterfaces<TestDialectInterface>();
+    addInterfaces<SecondTestDialectInterface>();
+  }
+};
+
+TEST(Dialect, RepeatedInterfaceRegistrationDeath) {
+  MLIRContext context;
+  (void)context;
+
+  // This triggers an assertion in debug mode.
+#ifndef NDEBUG
+  ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(),
+               "interface kind has already been registered");
+#endif
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list