[Mlir-commits] [mlir] 97fc568 - [mlir][capi] Add DialectRegistry to MLIR C-API

Daniel Resnick llvmlistbot at llvm.org
Tue Feb 1 12:43:03 PST 2022


Author: Daniel Resnick
Date: 2022-02-01T13:42:06-07:00
New Revision: 97fc5682112d019230a3ab23cced3bdf093f6094

URL: https://github.com/llvm/llvm-project/commit/97fc5682112d019230a3ab23cced3bdf093f6094
DIFF: https://github.com/llvm/llvm-project/commit/97fc5682112d019230a3ab23cced3bdf093f6094.diff

LOG: [mlir][capi] Add DialectRegistry to MLIR C-API

Exposes mlir::DialectRegistry to the C API as MlirDialectRegistry along with
helper functions. A hook has been added to MlirDialectHandle that inserts
the dialect into a registry.

A possible future change is to remove mlirDialectHandleRegisterDialect in
favor of mlirDialectHandleInsertDialect, on top of which it is now implemented.

Differential Revision: https://reviews.llvm.org/D118293

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir-c/Registration.h
    mlir/include/mlir/CAPI/IR.h
    mlir/include/mlir/CAPI/Registration.h
    mlir/lib/CAPI/IR/DialectHandle.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e5c20ae70cacf..d999554664d96 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
 
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
+DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
@@ -108,6 +109,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context);
 MLIR_CAPI_EXPORTED intptr_t
 mlirContextGetNumRegisteredDialects(MlirContext context);
 
+/// Append the contents of the given dialect registry to the registry associated
+/// with the context.
+MLIR_CAPI_EXPORTED void
+mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry);
+
 /// Returns the number of dialects loaded by the context.
 
 MLIR_CAPI_EXPORTED intptr_t
@@ -152,6 +158,22 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1,
 /// Returns the namespace of the given dialect.
 MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
 
+//===----------------------------------------------------------------------===//
+// DialectRegistry API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a dialect registry and transfers its ownership to the caller.
+MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate(void);
+
+/// Checks if the dialect registry is null.
+static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
+  return !registry.ptr;
+}
+
+/// Takes a dialect registry owned by the caller and destroys it.
+MLIR_CAPI_EXPORTED void
+mlirDialectRegistryDestroy(MlirDialectRegistry registry);
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h
index e8604329fd381..442449626632c 100644
--- a/mlir/include/mlir-c/Registration.h
+++ b/mlir/include/mlir-c/Registration.h
@@ -44,6 +44,11 @@ typedef struct MlirDialectHandle MlirDialectHandle;
 MLIR_CAPI_EXPORTED
 MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
 
+/// Inserts the dialect associated with the provided dialect handle into the
+/// provided dialect registry.
+MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle,
+                                                       MlirDialectRegistry);
+
 /// Registers the dialect associated with the provided dialect handle.
 MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
                                                          MlirContext);

diff  --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index af7ae89771f94..06cf7762a9c0e 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -22,6 +22,7 @@
 
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
+DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
 DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)

diff  --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h
index ac909d1dd9da7..e57023d30130e 100644
--- a/mlir/include/mlir/CAPI/Registration.h
+++ b/mlir/include/mlir/CAPI/Registration.h
@@ -21,23 +21,23 @@
 //===----------------------------------------------------------------------===//
 
 /// Hooks for dynamic discovery of dialects.
-typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
+typedef void (*MlirDialectRegistryInsertDialectHook)(
+    MlirDialectRegistry registry);
 typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
 typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
 
 /// Structure of dialect registration hooks.
 struct MlirDialectRegistrationHooks {
-  MlirContextRegisterDialectHook registerHook;
+  MlirDialectRegistryInsertDialectHook insertHook;
   MlirContextLoadDialectHook loadHook;
   MlirDialectGetNamespaceHook getNamespaceHook;
 };
 typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
 
 #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)      \
-  static void mlirContextRegister##Name##Dialect(MlirContext context) {        \
-    mlir::DialectRegistry registry;                                            \
-    registry.insert<ClassName>();                                              \
-    unwrap(context)->appendDialectRegistry(registry);                          \
+  static void mlirDialectRegistryInsert##Name##Dialect(                        \
+      MlirDialectRegistry registry) {                                          \
+    unwrap(registry)->insert<ClassName>();                                     \
   }                                                                            \
   static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {     \
     return wrap(unwrap(context)->getOrLoadDialect<ClassName>());               \
@@ -47,8 +47,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
   }                                                                            \
   MlirDialectHandle mlirGetDialectHandle__##Namespace##__() {                  \
     static MlirDialectRegistrationHooks hooks = {                              \
-        mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect,    \
-        mlir##Name##DialectGetNamespace};                                      \
+        mlirDialectRegistryInsert##Name##Dialect,                              \
+        mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace};      \
     return MlirDialectHandle{&hooks};                                          \
   }
 

diff  --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp
index fb972316ebdf0..19f64d9482179 100644
--- a/mlir/lib/CAPI/IR/DialectHandle.cpp
+++ b/mlir/lib/CAPI/IR/DialectHandle.cpp
@@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
   return unwrap(handle)->getNamespaceHook();
 }
 
+void mlirDialectHandleInsertDialect(MlirDialectHandle handle,
+                                    MlirDialectRegistry registry) {
+  unwrap(handle)->insertHook(registry);
+}
+
 void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
                                       MlirContext ctx) {
-  unwrap(handle)->registerHook(ctx);
+  mlir::DialectRegistry registry;
+  mlirDialectHandleInsertDialect(handle, wrap(&registry));
+  unwrap(ctx)->appendDialectRegistry(registry);
 }
 
 MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 26af65ec75b6a..c067b202b71c1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -53,6 +53,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
   return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
 }
 
+void mlirContextAppendDialectRegistry(MlirContext ctx,
+                                      MlirDialectRegistry registry) {
+  unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
+}
+
 // TODO: expose a cheaper way than constructing + sorting a vector only to take
 // its size.
 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
@@ -88,6 +93,18 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
   return wrap(unwrap(dialect)->getNamespace());
 }
 
+//===----------------------------------------------------------------------===//
+// DialectRegistry API.
+//===----------------------------------------------------------------------===//
+
+MlirDialectRegistry mlirDialectRegistryCreate() {
+  return wrap(new DialectRegistry());
+}
+
+void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
+  delete unwrap(registry);
+}
+
 //===----------------------------------------------------------------------===//
 // Printing flags API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 257d5e9b8683d..79fb3344d0819 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1904,6 +1904,45 @@ int testSymbolTable(MlirContext ctx) {
   return 0;
 }
 
+// Verifies that a dialect inserted into a standalone registry becomes
+// registered on a context once the registry is appended to it. Returns 0 on
+// success and 1 if any expectation fails.
+int testDialectRegistry() {
+  fprintf(stderr, "@testDialectRegistry\n");
+
+  MlirDialectRegistry registry = mlirDialectRegistryCreate();
+  if (mlirDialectRegistryIsNull(registry)) {
+    fprintf(stderr, "ERROR: Expected registry to be present\n");
+    return 1;
+  }
+
+  MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
+  mlirDialectHandleInsertDialect(stdHandle, registry);
+
+  // Record failures instead of returning early so that the context and the
+  // registry are still destroyed below.
+  int failed = 0;
+
+  MlirContext ctx = mlirContextCreate();
+  if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
+    fprintf(stderr,
+            "ERROR: Expected no dialects to be registered to new context\n");
+    failed = 1;
+  }
+
+  mlirContextAppendDialectRegistry(ctx, registry);
+  if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
+    fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
+                    "registered to the context\n");
+    failed = 1;
+  }
+
+  mlirContextDestroy(ctx);
+  mlirDialectRegistryDestroy(registry);
+
+  return failed;
+}
+
 void testDiagnostics() {
   MlirContext ctx = mlirContextCreate();
   MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -1988,6 +2027,8 @@ int main() {
     return 13;
   if (testSymbolTable(ctx))
     return 14;
+  if (testDialectRegistry())
+    return 15;
 
   mlirContextDestroy(ctx);
 


        


More information about the Mlir-commits mailing list