[Mlir-commits] [mlir] 3da5152 - [mlir] enable delayed registration of dialect interfaces

Alex Zinenko llvmlistbot at llvm.org
Wed Feb 10 03:07:43 PST 2021


Author: Alex Zinenko
Date: 2021-02-10T12:07:32+01:00
New Revision: 3da51522fb4f72b7d4619f2dfd454bb3073ab460

URL: https://github.com/llvm/llvm-project/commit/3da51522fb4f72b7d4619f2dfd454bb3073ab460
DIFF: https://github.com/llvm/llvm-project/commit/3da51522fb4f72b7d4619f2dfd454bb3073ab460.diff

LOG: [mlir] enable delayed registration of dialect interfaces

This introduces a mechanism to register interfaces for a dialect without making
the dialect itself depend on the interface. The registration request is made on
the DialectRegistry; if the dialect has not been loaded yet, the actual
registration is delayed until the dialect is loaded. This requires
DialectRegistry to be aware of the context that owns it, and the context to
expose methods for querying whether a dialect is loaded.

This enables a simple extension mechanism for dialects whose interfaces can be
defined outside of the dialect code. It is particularly helpful for, e.g.,
translation to LLVM IR, where we don't want the dialect itself to depend on
LLVM IR libraries.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96137

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/lib/IR/Dialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Support/MlirOptMain.cpp
    mlir/unittests/IR/DialectTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index cd64d388aec2..978531f3098c 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -26,6 +26,8 @@ class OpBuilder;
 class Type;
 
 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
+using InterfaceAllocatorFunction =
+    std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
 
 /// Dialects are groups of MLIR operations, types and attributes, as well as
 /// behavior associated with the entire group.  For example, hooks into other
@@ -222,6 +224,7 @@ class Dialect {
   /// A collection of registered dialect interfaces.
   DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
 
+  friend class DialectRegistry;
   friend void registerDialect();
   friend class MLIRContext;
 };
@@ -234,8 +237,13 @@ class Dialect {
 class DialectRegistry {
   using MapTy =
       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
+  using InterfaceMapTy =
+      DenseMap<TypeID, SmallVector<InterfaceAllocatorFunction, 2>>;
 
 public:
+  explicit DialectRegistry(MLIRContext *context = nullptr)
+      : owningContext(context) {}
+
   template <typename ConcreteDialect>
   void insert() {
     insert(TypeID::get<ConcreteDialect>(),
@@ -254,7 +262,9 @@ class DialectRegistry {
     insert<OtherDialect, MoreDialects...>();
   }
 
-  /// Add a new dialect constructor to the registry.
+  /// Add a new dialect constructor to the registry. The constructor must call
+  /// MLIRContext::getOrLoadDialect so that the context takes ownership of the
+  /// dialect and any delayed interface registration happens.
   void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
 
   /// Load a dialect for this namespace in the provided context.
@@ -267,6 +277,7 @@ class DialectRegistry {
       destination.insert(nameAndRegistrationIt.second.first,
                          nameAndRegistrationIt.first,
                          nameAndRegistrationIt.second.second);
+    destination.interfaces.insert(interfaces.begin(), interfaces.end());
   }
   // Load all dialects available in the registry in the provided context.
   void loadAll(MLIRContext *context) {
@@ -274,11 +285,47 @@ class DialectRegistry {
       nameAndRegistrationIt.second.second(context);
   }
 
-  MapTy::const_iterator begin() const { return registry.begin(); }
-  MapTy::const_iterator end() const { return registry.end(); }
+  /// Return the names of dialects known to this registry.
+  auto getDialectNames() {
+    return llvm::map_range(
+        registry, [](const MapTy::value_type &item) { return item.first; });
+  }
+
+  /// Add an interface constructed with the given allocation function to the
+  /// dialect provided as a template parameter. The dialect must be present in
+  /// the registry, but may or may not be loaded. If it is not loaded, the
+  /// interface registration is delayed until the dialect is loaded.
+  template <typename DialectTy>
+  void addDialectInterface(InterfaceAllocatorFunction allocator) {
+    addDialectInterface(DialectTy::getDialectNamespace(), allocator);
+  }
+
+  /// Add an interface to a dialect, both provided as template parameters. The
+  /// dialect must be present in the registry, but may or may not be loaded. If
+  /// it is not loaded, the interface registration is delayed until it loads.
+  template <typename DialectTy, typename InterfaceTy>
+  void addDialectInterface() {
+    addDialectInterface<DialectTy>([](Dialect *dialect) {
+      return std::make_unique<InterfaceTy>(dialect);
+    });
+  }
+
+  /// Register any interfaces required for the given dialect (based on its
+  /// TypeID). Users are not expected to call this directly.
+  void registerDelayedInterfaces(Dialect *dialect);
 
 private:
+  /// Add an interface constructed with the given allocation function to the
+  /// dialect identified by its namespace.
+  void addDialectInterface(StringRef dialectName,
+                           InterfaceAllocatorFunction allocator);
+
   MapTy registry;
+  InterfaceMapTy interfaces;
+
+  /// If this registry belongs to a context, this points back to the context.
+  /// Useful for checking if a dialect is loaded in the context.
+  MLIRContext *owningContext;
 };
 
 } // namespace mlir

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index beabd4881c29..01f8ec1f38cc 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -22,6 +22,29 @@ using namespace detail;
 
 DialectAsmParser::~DialectAsmParser() {}
 
+//===----------------------------------------------------------------------===//
+// DialectRegistry
+//===----------------------------------------------------------------------===//
+
+void DialectRegistry::addDialectInterface(
+    StringRef dialectName, InterfaceAllocatorFunction allocator) {
+  assert(allocator && "unexpected null interface allocation function");
+
+  // If the dialect is already loaded, directly add the interface.
+  if (Dialect *dialect = owningContext
+                             ? owningContext->getLoadedDialect(dialectName)
+                             : nullptr) {
+    dialect->addInterface(allocator(dialect));
+    return;
+  }
+
+  // Otherwise, store it in the interface map for delayed registration.
+  auto it = registry.find(dialectName.str());
+  assert(it != registry.end() &&
+         "adding an interface for an unregistered dialect");
+  interfaces[it->second.first].push_back(allocator);
+}
+
 Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
   auto it = registry.find(name.str());
   if (it == registry.end())
@@ -40,6 +63,15 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
   }
 }
 
+void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) {
+  auto it = interfaces.find(dialect->getTypeID());
+  if (it == interfaces.end())
+    return;
+
+  for (const InterfaceAllocatorFunction &createInterface : it->second)
+    dialect->addInterface(createInterface(dialect));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f56eb75390a9..832eea747771 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -326,7 +326,8 @@ class MLIRContextImpl {
   DictionaryAttr emptyDictionaryAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {}
+  MLIRContextImpl(MLIRContext *ctx)
+      : dialectsRegistry(ctx), identifiers(identifierAllocator) {}
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
       typeMapping.second->~AbstractType();
@@ -336,7 +337,7 @@ class MLIRContextImpl {
 };
 } // end namespace mlir
 
-MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
     disableMultithreading(clOptions->disableThreading);
@@ -441,8 +442,8 @@ std::vector<Dialect *> MLIRContext::getLoadedDialects() {
 }
 std::vector<StringRef> MLIRContext::getAvailableDialects() {
   std::vector<StringRef> result;
-  for (auto &dialect : impl->dialectsRegistry)
-    result.push_back(dialect.first);
+  for (auto dialect : impl->dialectsRegistry.getDialectNames())
+    result.push_back(dialect);
   return result;
 }
 
@@ -493,6 +494,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
           identifierEntry.first().startswith(dialectNamespace))
         identifierEntry.second = dialect.get();
 
+    // Actually register the interfaces with delayed registration.
+    impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
     return dialect.get();
   }
 

diff  --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 85891fd591f7..27968517987d 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -201,10 +201,8 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
   {
     llvm::raw_string_ostream os(helpHeader);
     MLIRContext context;
-    interleaveComma(registry, os, [&](auto &registryEntry) {
-      StringRef name = registryEntry.first;
-      os << name;
-    });
+    interleaveComma(registry.getDialectNames(), os,
+                    [&](auto name) { os << name; });
   }
   // Parse pass names in main to ensure static initialization completed.
   cl::ParseCommandLineOptions(argc, argv, helpHeader);
@@ -212,8 +210,8 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
   if (showDialects) {
     llvm::outs() << "Available Dialects:\n";
     interleave(
-        registry, llvm::outs(),
-        [](auto &registryEntry) { llvm::outs() << registryEntry.first; }, "\n");
+        registry.getDialectNames(), llvm::outs(),
+        [](auto name) { llvm::outs() << name; }, "\n");
     return success();
   }
 

diff  --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index 2410be0263b3..ed19558ef5ca 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -34,4 +35,61 @@ TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
 }
 
+struct SecondTestDialect : public Dialect {
+  static StringRef getDialectNamespace() { return "test2"; }
+  SecondTestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context,
+                TypeID::get<SecondTestDialect>()) {}
+};
+
+struct TestDialectInterfaceBase
+    : public DialectInterface::Base<TestDialectInterfaceBase> {
+  TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
+  virtual int function() const { return 42; }
+};
+
+struct TestDialectInterface : public TestDialectInterfaceBase {
+  using TestDialectInterfaceBase::TestDialectInterfaceBase;
+  int function() const final { return 56; }
+};
+
+struct SecondTestDialectInterface : public TestDialectInterfaceBase {
+  using TestDialectInterfaceBase::TestDialectInterfaceBase;
+  int function() const final { return 78; }
+};
+
+TEST(Dialect, DelayedInterfaceRegistration) {
+  DialectRegistry registry;
+  registry.insert<TestDialect, SecondTestDialect>();
+
+  // Delayed registration of an interface for TestDialect.
+  registry.addDialectInterface<TestDialect, TestDialectInterface>();
+
+  MLIRContext context;
+  registry.appendTo(context.getDialectRegistry());
+
+  // Load the TestDialect and check that the interface got registered for it.
+  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+  auto *testDialectInterface =
+      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(testDialectInterface != nullptr);
+
+  // Load the SecondTestDialect and check that the interface is not registered
+  // for it.
+  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
+  ASSERT_TRUE(secondTestDialect != nullptr);
+  auto *secondTestDialectInterface =
+      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+  EXPECT_TRUE(secondTestDialectInterface == nullptr);
+
+  // Use the same mechanism as for delayed registration but for an already
+  // loaded dialect and check that the interface is now registered.
+  context.getDialectRegistry()
+      .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
+  secondTestDialectInterface =
+      secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
+  EXPECT_TRUE(secondTestDialectInterface != nullptr);
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list