[Mlir-commits] [mlir] 5099a48 - [MLIR] Replace dialect registration hooks with dialect handle

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 9 09:11:40 PST 2021


Author: George
Date: 2021-02-09T09:02:16-08:00
New Revision: 5099a48a3bdc6a25b83c47281251825101727e96

URL: https://github.com/llvm/llvm-project/commit/5099a48a3bdc6a25b83c47281251825101727e96
DIFF: https://github.com/llvm/llvm-project/commit/5099a48a3bdc6a25b83c47281251825101727e96.diff

LOG: [MLIR] Replace dialect registration hooks with dialect handle

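Replace MlirDialectRegistrationHooks with MlirDialectHandle, which under the hood is an opaque pointer to MlirDialectRegistrationHooks. The functionality previously accessed directly through MlirDialectRegistrationHooks is now exposed as functions that take the opaque MlirDialectHandle struct. The per-dialect mlirContextRegister{NAME}Dialect, mlirContextLoad{NAME}Dialect and mlir{NAME}DialectGetNamespace functions are no longer exported; use mlirGetDialectHandle__{NAMESPACE}__() together with the mlirDialectHandle* functions instead. This makes the actual structure of the registration hooks an implementation detail, and happens to avoid this issue: https://llvm.discourse.group/t/strange-swift-issues-with-dialect-registration-hooks/2759/3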

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D96229

Added: 
    mlir/lib/CAPI/IR/DialectHandle.cpp

Modified: 
    mlir/include/mlir-c/Registration.h
    mlir/include/mlir/CAPI/Registration.h
    mlir/lib/CAPI/IR/CMakeLists.txt
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h
index 7fde05d50bd9..6c7a486373ee 100644
--- a/mlir/include/mlir-c/Registration.h
+++ b/mlir/include/mlir-c/Registration.h
@@ -23,47 +23,39 @@ extern "C" {
 // API name (i.e. "Standard", "Tensor", "Linalg") and namespace (i.e. "std",
 // "tensor", "linalg"). The following declarations are produced:
 //
-//   /// Registers the dialect with the given context. This allows the
-//   /// dialect to be loaded dynamically if needed when parsing. */
-//   void mlirContextRegister{NAME}Dialect(MlirContext);
-//
-//   /// Loads the dialect into the given context. The dialect does _not_
-//   /// have to be registered in advance.
-//   MlirDialect mlirContextLoad{NAME}Dialect(MlirContext context);
-//
-//   /// Returns the namespace of the Standard dialect, suitable for loading it.
-//   MlirStringRef mlir{NAME}DialectGetNamespace();
-//
-//   /// Gets the above hook methods in struct form for a dialect by namespace.
+//   /// Gets the dialect handle for a dialect by namespace. The handle can be
+//   /// passed to the mlirDialectHandle* functions declared below.
 //   /// This is intended to facilitate dynamic lookup and registration of
 //   /// dialects via a plugin facility based on shared library symbol lookup.
-//   const MlirDialectRegistrationHooks *mlirGetDialectHooks__{NAMESPACE}__();
+//   MlirDialectHandle mlirGetDialectHandle__{NAMESPACE}__(void);
 //
 // This is done via a common macro to facilitate future expansion to
 // registration schemes.
 //===----------------------------------------------------------------------===//
 
+/// Opaque handle to the registration hooks of a dialect. Handles are returned
+/// by the mlirGetDialectHandle__{NAMESPACE}__() functions declared through
+/// MLIR_DECLARE_CAPI_DIALECT_REGISTRATION.
+struct MlirDialectHandle {
+  const void *ptr;
+};
+typedef struct MlirDialectHandle MlirDialectHandle;
+
 #define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace)                \
-  MLIR_CAPI_EXPORTED void mlirContextRegister##Name##Dialect(                  \
-      MlirContext context);                                                    \
-  MLIR_CAPI_EXPORTED MlirDialect mlirContextLoad##Name##Dialect(               \
-      MlirContext context);                                                    \
-  MLIR_CAPI_EXPORTED MlirStringRef mlir##Name##DialectGetNamespace();          \
-  MLIR_CAPI_EXPORTED const MlirDialectRegistrationHooks                        \
-      *mlirGetDialectHooks__##Namespace##__()
+  MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__(  \
+      void)
 
-/// Hooks for dynamic discovery of dialects.
-typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
-typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
-typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
+/// Returns the namespace associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED
+MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle);
 
-/// Structure of dialect registration hooks.
-struct MlirDialectRegistrationHooks {
-  MlirContextRegisterDialectHook registerHook;
-  MlirContextLoadDialectHook loadHook;
-  MlirDialectGetNamespaceHook getNamespaceHook;
-};
-typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
+/// Registers the dialect associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED void
+mlirDialectHandleRegisterDialect(MlirDialectHandle handle, MlirContext ctx);
+
+/// Loads the dialect associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED MlirDialect
+mlirDialectHandleLoadDialect(MlirDialectHandle handle, MlirContext ctx);
 
 /// Registers all dialects known to core MLIR with the provided Context.
 /// This is needed before creating IR for these Dialects.

diff  --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h
index da63afb4c515..7601f9fc0e63 100644
--- a/mlir/include/mlir/CAPI/Registration.h
+++ b/mlir/include/mlir/CAPI/Registration.h
@@ -20,21 +20,37 @@
 // of the dialect class.
 //===----------------------------------------------------------------------===//
 
+/// Hooks for dynamic discovery of dialects.
+typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
+typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
+typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
+
+/// Structure of dialect registration hooks.
+struct MlirDialectRegistrationHooks {
+  MlirContextRegisterDialectHook registerHook;
+  MlirContextLoadDialectHook loadHook;
+  MlirDialectGetNamespaceHook getNamespaceHook;
+};
+typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
+
+/// Defines file-local register/load/get-namespace hooks for dialect ClassName
+/// and mlirGetDialectHandle__{Namespace}__(), which returns a handle wrapping
+/// a function-local static table of those hooks.
 #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)      \
-  void mlirContextRegister##Name##Dialect(MlirContext context) {               \
+  static void mlirContextRegister##Name##Dialect(MlirContext context) {        \
     unwrap(context)->getDialectRegistry().insert<ClassName>();                 \
   }                                                                            \
-  MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {            \
+  static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {     \
     return wrap(unwrap(context)->getOrLoadDialect<ClassName>());               \
   }                                                                            \
-  MlirStringRef mlir##Name##DialectGetNamespace() {                            \
+  static MlirStringRef mlir##Name##DialectGetNamespace() {                     \
     return wrap(ClassName::getDialectNamespace());                             \
   }                                                                            \
-  const MlirDialectRegistrationHooks *mlirGetDialectHooks__##Namespace##__() { \
-    static MlirDialectRegistrationHooks hooks = {                              \
+  MlirDialectHandle mlirGetDialectHandle__##Namespace##__() {                  \
+    static const MlirDialectRegistrationHooks hooks = {                        \
         mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect,    \
         mlir##Name##DialectGetNamespace};                                      \
-    return &hooks;                                                             \
+    return MlirDialectHandle{&hooks};                                          \
   }
 
 #endif // MLIR_CAPI_REGISTRATION_H

diff  --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 893ccb6721fb..486ba6e0f8e1 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_public_c_api_library(MLIRCAPIIR
   BuiltinAttributes.cpp
   BuiltinTypes.cpp
   Diagnostics.cpp
+  DialectHandle.cpp
   IntegerSet.cpp
   IR.cpp
   Pass.cpp

diff  --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp
new file mode 100644
index 000000000000..fb972316ebdf
--- /dev/null
+++ b/mlir/lib/CAPI/IR/DialectHandle.cpp
@@ -0,0 +1,32 @@
+//===- DialectHandle.cpp - C Interface for MLIR Dialect Handles -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Registration.h"
+
+/// Recovers the registration hooks table behind an opaque dialect handle.
+/// Handles are produced by the mlirGetDialectHandle__{Namespace}__() functions
+/// that MLIR_DEFINE_CAPI_DIALECT_REGISTRATION defines, which wrap a pointer to
+/// a static MlirDialectRegistrationHooks table.
+static inline const MlirDialectRegistrationHooks *
+unwrap(MlirDialectHandle handle) {
+  return static_cast<const MlirDialectRegistrationHooks *>(handle.ptr);
+}
+
+MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
+  return unwrap(handle)->getNamespaceHook();
+}
+
+void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
+                                      MlirContext ctx) {
+  unwrap(handle)->registerHook(ctx);
+}
+
+MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,
+                                         MlirContext ctx) {
+  return unwrap(handle)->loadHook(ctx);
+}

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 2f81d13160c2..7576133e60d4 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1412,23 +1412,26 @@ int registerOnlyStd() {
   if (mlirContextGetNumLoadedDialects(ctx) != 1)
     return 1;
 
-  MlirDialect std =
-      mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
+
+  MlirDialect std = mlirContextGetOrLoadDialect(
+      ctx, mlirDialectHandleGetNamespace(stdHandle));
   if (!mlirDialectIsNull(std))
     return 2;
 
-  mlirContextRegisterStandardDialect(ctx);
+  mlirDialectHandleRegisterDialect(stdHandle, ctx);
 
-  std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  std = mlirContextGetOrLoadDialect(ctx,
+                                    mlirDialectHandleGetNamespace(stdHandle));
   if (mlirDialectIsNull(std))
     return 3;
 
-  MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
+  MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
   if (!mlirDialectEqual(std, alsoStd))
     return 4;
 
   MlirStringRef stdNs = mlirDialectGetNamespace(std);
-  MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
+  MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
   if (stdNs.length != alsoStdNs.length ||
       strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
     return 5;


        


More information about the Mlir-commits mailing list