[Mlir-commits] [mlir] b72e13c - [MLIR] Deduplicate dialect registration by ClassID

River Riddle llvmlistbot at llvm.org
Wed Mar 18 19:56:37 PDT 2020


Author: Geoffrey Martin-Noble
Date: 2020-03-18T19:52:27-07:00
New Revision: b72e13c242d9bbe1a4c7e471da98718bde85fa78

URL: https://github.com/llvm/llvm-project/commit/b72e13c242d9bbe1a4c7e471da98718bde85fa78
DIFF: https://github.com/llvm/llvm-project/commit/b72e13c242d9bbe1a4c7e471da98718bde85fa78.diff

LOG: [MLIR] Deduplicate dialect registration by ClassID

Summary:
With the move towards dialect registration that does not depend only on
static initialization, we are running into more cases where the dialects
are registered by different methods. For example, TensorFlow still uses
static initialization to register all MLIR core dialects, which prevents
explicit registration of any of them when linking it in. We ran into this
issue in https://github.com/google/iree/pull/982.

To address potential issues with conflicts from non-standard
allocators passed to registerDialectAllocator, this change makes that
method private. Dialects can now only be registered with their
constructor.

This change similarly deduplicates DialectHooks for consistency and makes
their registration follow the same pattern.

Differential Revision: https://reviews.llvm.org/D76329

Added: 
    

Modified: 
    mlir/include/mlir/IR/Dialect.h
    mlir/include/mlir/IR/DialectHooks.h
    mlir/lib/IR/Dialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 2f17ad002b30..f65060238171 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -28,6 +28,7 @@ using DialectConstantFoldHook = std::function<LogicalResult(
     Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
 using DialectExtractElementHook =
     std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
+using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
 
 /// Dialects are groups of MLIR operations and behavior associated with the
 /// entire group.  For example, hooks into other systems for constant folding,
@@ -241,24 +242,30 @@ class Dialect {
 
   /// A collection of registered dialect interfaces.
   DenseMap<ClassID *, std::unique_ptr<DialectInterface>> registeredInterfaces;
-};
-
-using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
-
-/// Registers a specific dialect creation function with the system, typically
-/// used through the DialectRegistration template.
-void registerDialectAllocator(const DialectAllocatorFunction &function);
 
-/// Registers all dialects with the specified MLIRContext.
+  /// Registers a specific dialect creation function with the global registry.
+  /// Used through the registerDialect template.
+  /// Registrations are deduplicated by dialect ClassID and only the first
+  /// registration will be used.
+  static void
+  registerDialectAllocator(const ClassID *classId,
+                           const DialectAllocatorFunction &function);
+  template <typename ConcreteDialect>
+  friend void registerDialect();
+};
+/// Registers all dialects and hooks from the global registries with the
+/// specified MLIRContext.
 void registerAllDialects(MLIRContext *context);
 
 /// Utility to register a dialect. Client can register their dialect with the
 /// global registry by calling registerDialect<MyDialect>();
 template <typename ConcreteDialect> void registerDialect() {
-  registerDialectAllocator([](MLIRContext *ctx) {
-    // Just allocate the dialect, the context takes ownership of it.
-    new ConcreteDialect(ctx);
-  });
+  Dialect::registerDialectAllocator(ClassID::getID<ConcreteDialect>(),
+                                    [](MLIRContext *ctx) {
+                                      // Just allocate the dialect, the context
+                                      // takes ownership of it.
+                                      new ConcreteDialect(ctx);
+                                    });
 }
 
 /// DialectRegistration provides a global initializer that registers a Dialect

diff  --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h
index 2dce1c2b203a..4e59b4953e65 100644
--- a/mlir/include/mlir/IR/DialectHooks.h
+++ b/mlir/include/mlir/IR/DialectHooks.h
@@ -35,36 +35,53 @@ class DialectHooks {
   DialectConstantDecodeHook getDecodeHook() { return nullptr; }
   // Returns hook to extract an element of an opaque constant tensor.
   DialectExtractElementHook getExtractElementHook() { return nullptr; }
+
+private:
+  /// Registers a function that will set hooks in the registered dialects.
+  /// Registrations are deduplicated by dialect ClassID and only the first
+  /// registration will be used.
+  static void registerDialectHooksSetter(const ClassID *classId,
+                                         const DialectHooksSetter &function);
+  template <typename ConcreteHooks>
+  friend void registerDialectHooks(StringRef dialectName);
 };
 
-/// Registers a function that will set hooks in the registered dialects
-/// based on information coming from DialectHooksRegistration.
-void registerDialectHooksSetter(const DialectHooksSetter &function);
+void registerDialectHooksSetter(const ClassID *classId,
+                                const DialectHooksSetter &function);
+
+/// Utility to register dialect hooks. Client can register their dialect hooks
+/// with the global registry by calling
+/// registerDialectHooks<MyHooks>("dialect_namespace");
+template <typename ConcreteHooks>
+void registerDialectHooks(StringRef dialectName) {
+  DialectHooks::registerDialectHooksSetter(
+      ClassID::getID<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
+        Dialect *dialect = ctx->getRegisteredDialect(dialectName);
+        if (!dialect) {
+          llvm::errs() << "error: cannot register hooks for unknown dialect '"
+                       << dialectName << "'\n";
+          abort();
+        }
+        // Set hooks.
+        ConcreteHooks hooks;
+        if (auto h = hooks.getConstantFoldHook())
+          dialect->constantFoldHook = h;
+        if (auto h = hooks.getDecodeHook())
+          dialect->decodeHook = h;
+        if (auto h = hooks.getExtractElementHook())
+          dialect->extractElementHook = h;
+      });
+}
 
 /// DialectHooksRegistration provides a global initializer that registers
 /// a dialect hooks setter routine.
 /// Usage:
 ///
 ///   // At namespace scope.
-///   static DialectHooksRegistration<MyHooks, MyDialect> unused;
+///   static DialectHooksRegistration<MyHooks> Unused("dialect_namespace");
 template <typename ConcreteHooks> struct DialectHooksRegistration {
   DialectHooksRegistration(StringRef dialectName) {
-    registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
-      Dialect *dialect = ctx->getRegisteredDialect(dialectName);
-      if (!dialect) {
-        llvm::errs() << "error: cannot register hooks for unknown dialect '"
-                     << dialectName << "'\n";
-        abort();
-      }
-      // Set hooks.
-      ConcreteHooks hooks;
-      if (auto h = hooks.getConstantFoldHook())
-        dialect->constantFoldHook = h;
-      if (auto h = hooks.getDecodeHook())
-        dialect->decodeHook = h;
-      if (auto h = hooks.getExtractElementHook())
-        dialect->extractElementHook = h;
-    });
+    registerDialectHooks<ConcreteHooks>(dialectName);
   }
 };
 

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 4ce461e2d7d9..e48e7f64010d 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/Regex.h"
@@ -26,39 +27,40 @@ DialectAsmParser::~DialectAsmParser() {}
 // Dialect Registration
 //===----------------------------------------------------------------------===//
 
-// Registry for all dialect allocation functions.
-static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
+/// Registry for all dialect allocation functions.
+static llvm::ManagedStatic<
+    llvm::MapVector<const ClassID *, DialectAllocatorFunction>>
     dialectRegistry;
 
-// Registry for functions that set dialect hooks.
-static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
+/// Registry for functions that set dialect hooks.
+static llvm::ManagedStatic<llvm::MapVector<const ClassID *, DialectHooksSetter>>
     dialectHooksRegistry;
 
-/// Registers a specific dialect creation function with the system, typically
-/// used through the DialectRegistration template.
-void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
+void Dialect::registerDialectAllocator(
+    const ClassID *classId, const DialectAllocatorFunction &function) {
   assert(function &&
          "Attempting to register an empty dialect initialize function");
-  dialectRegistry->push_back(function);
+  dialectRegistry->insert({classId, function});
 }
 
 /// Registers a function to set specific hooks for a specific dialect, typically
 /// used through the DialectHooksRegistration template.
-void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
+void DialectHooks::registerDialectHooksSetter(
+    const ClassID *classId, const DialectHooksSetter &function) {
   assert(
       function &&
       "Attempting to register an empty dialect hooks initialization function");
 
-  dialectHooksRegistry->push_back(function);
+  dialectHooksRegistry->insert({classId, function});
 }
 
-/// Registers all dialects and their const folding hooks with the specified
-/// MLIRContext.
+/// Registers all dialects and hooks from the global registries with the
+/// specified MLIRContext.
 void mlir::registerAllDialects(MLIRContext *context) {
-  for (const auto &fn : *dialectRegistry)
-    fn(context);
-  for (const auto &fn : *dialectHooksRegistry) {
-    fn(context);
+  for (const auto &it : *dialectRegistry)
+    it.second(context);
+  for (const auto &it : *dialectHooksRegistry) {
+    it.second(context);
   }
 }
 


        


More information about the Mlir-commits mailing list