[Mlir-commits] [mlir] 97fc568 - [mlir][capi] Add DialectRegistry to MLIR C-API
Daniel Resnick
llvmlistbot at llvm.org
Tue Feb 1 12:43:03 PST 2022
Author: Daniel Resnick
Date: 2022-02-01T13:42:06-07:00
New Revision: 97fc5682112d019230a3ab23cced3bdf093f6094
URL: https://github.com/llvm/llvm-project/commit/97fc5682112d019230a3ab23cced3bdf093f6094
DIFF: https://github.com/llvm/llvm-project/commit/97fc5682112d019230a3ab23cced3bdf093f6094.diff
LOG: [mlir][capi] Add DialectRegistry to MLIR C-API
Exposes mlir::DialectRegistry to the C API as MlirDialectRegistry along with
helper functions. A hook has been added to MlirDialectHandle that inserts
the dialect into a registry.
A future possible change is removing mlirDialectHandleRegisterDialect in
favor of using mlirDialectHandleInsertDialect, which it is now implemented with.
Differential Revision: https://reviews.llvm.org/D118293
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir-c/Registration.h
mlir/include/mlir/CAPI/IR.h
mlir/include/mlir/CAPI/Registration.h
mlir/lib/CAPI/IR/DialectHandle.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e5c20ae70cacf..d999554664d96 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
+DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
@@ -108,6 +109,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context);
MLIR_CAPI_EXPORTED intptr_t
mlirContextGetNumRegisteredDialects(MlirContext context);
+/// Append the contents of the given dialect registry to the registry associated
+/// with the context.
+MLIR_CAPI_EXPORTED void
+mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry);
+
/// Returns the number of dialects loaded by the context.
MLIR_CAPI_EXPORTED intptr_t
@@ -152,6 +158,22 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1,
/// Returns the namespace of the given dialect.
MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
+//===----------------------------------------------------------------------===//
+// DialectRegistry API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a dialect registry and transfers its ownership to the caller.
+MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate();
+
+/// Checks if the dialect registry is null.
+static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
+ return !registry.ptr;
+}
+
+/// Takes a dialect registry owned by the caller and destroys it.
+MLIR_CAPI_EXPORTED void
+mlirDialectRegistryDestroy(MlirDialectRegistry registry);
+
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h
index e8604329fd381..442449626632c 100644
--- a/mlir/include/mlir-c/Registration.h
+++ b/mlir/include/mlir-c/Registration.h
@@ -44,6 +44,11 @@ typedef struct MlirDialectHandle MlirDialectHandle;
MLIR_CAPI_EXPORTED
MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
+/// Inserts the dialect associated with the provided dialect handle into the
+/// provided dialect registry
+MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle,
+ MlirDialectRegistry);
+
/// Registers the dialect associated with the provided dialect handle.
MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
MlirContext);
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index af7ae89771f94..06cf7762a9c0e 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -22,6 +22,7 @@
DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
+DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h
index ac909d1dd9da7..e57023d30130e 100644
--- a/mlir/include/mlir/CAPI/Registration.h
+++ b/mlir/include/mlir/CAPI/Registration.h
@@ -21,23 +21,23 @@
//===----------------------------------------------------------------------===//
/// Hooks for dynamic discovery of dialects.
-typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
+typedef void (*MlirDialectRegistryInsertDialectHook)(
+ MlirDialectRegistry registry);
typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
/// Structure of dialect registration hooks.
struct MlirDialectRegistrationHooks {
- MlirContextRegisterDialectHook registerHook;
+ MlirDialectRegistryInsertDialectHook insertHook;
MlirContextLoadDialectHook loadHook;
MlirDialectGetNamespaceHook getNamespaceHook;
};
typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \
- static void mlirContextRegister##Name##Dialect(MlirContext context) { \
- mlir::DialectRegistry registry; \
- registry.insert<ClassName>(); \
- unwrap(context)->appendDialectRegistry(registry); \
+ static void mlirDialectRegistryInsert##Name##Dialect( \
+ MlirDialectRegistry registry) { \
+ unwrap(registry)->insert<ClassName>(); \
} \
static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \
return wrap(unwrap(context)->getOrLoadDialect<ClassName>()); \
@@ -47,8 +47,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
} \
MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \
static MlirDialectRegistrationHooks hooks = { \
- mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \
- mlir##Name##DialectGetNamespace}; \
+ mlirDialectRegistryInsert##Name##Dialect, \
+ mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace}; \
return MlirDialectHandle{&hooks}; \
}
diff --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp
index fb972316ebdf0..19f64d9482179 100644
--- a/mlir/lib/CAPI/IR/DialectHandle.cpp
+++ b/mlir/lib/CAPI/IR/DialectHandle.cpp
@@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
return unwrap(handle)->getNamespaceHook();
}
+void mlirDialectHandleInsertDialect(MlirDialectHandle handle,
+ MlirDialectRegistry registry) {
+ unwrap(handle)->insertHook(registry);
+}
+
void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
MlirContext ctx) {
- unwrap(handle)->registerHook(ctx);
+ mlir::DialectRegistry registry;
+ mlirDialectHandleInsertDialect(handle, wrap(&registry));
+ unwrap(ctx)->appendDialectRegistry(registry);
}
MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 26af65ec75b6a..c067b202b71c1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -53,6 +53,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
}
+void mlirContextAppendDialectRegistry(MlirContext ctx,
+ MlirDialectRegistry registry) {
+ unwrap(ctx)->appendDialectRegistry(*unwrap(registry));
+}
+
// TODO: expose a cheaper way than constructing + sorting a vector only to take
// its size.
intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
@@ -88,6 +93,18 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
return wrap(unwrap(dialect)->getNamespace());
}
+//===----------------------------------------------------------------------===//
+// DialectRegistry API.
+//===----------------------------------------------------------------------===//
+
+MlirDialectRegistry mlirDialectRegistryCreate() {
+ return wrap(new DialectRegistry());
+}
+
+void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
+ delete unwrap(registry);
+}
+
//===----------------------------------------------------------------------===//
// Printing flags API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 257d5e9b8683d..79fb3344d0819 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1904,6 +1904,36 @@ int testSymbolTable(MlirContext ctx) {
return 0;
}
+int testDialectRegistry() {
+ fprintf(stderr, "@testDialectRegistry\n");
+
+ MlirDialectRegistry registry = mlirDialectRegistryCreate();
+ if (mlirDialectRegistryIsNull(registry)) {
+ fprintf(stderr, "ERROR: Expected registry to be present\n");
+ return 1;
+ }
+
+ MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
+ mlirDialectHandleInsertDialect(stdHandle, registry);
+
+ MlirContext ctx = mlirContextCreate();
+ if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
+ fprintf(stderr,
+ "ERROR: Expected no dialects to be registered to new context\n");
+ }
+
+ mlirContextAppendDialectRegistry(ctx, registry);
+ if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
+ fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
+ "registered to the context\n");
+ }
+
+ mlirContextDestroy(ctx);
+ mlirDialectRegistryDestroy(registry);
+
+ return 0;
+}
+
void testDiagnostics() {
MlirContext ctx = mlirContextCreate();
MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
@@ -1988,6 +2018,8 @@ int main() {
return 13;
if (testSymbolTable(ctx))
return 14;
+ if (testDialectRegistry())
+ return 15;
mlirContextDestroy(ctx);
More information about the Mlir-commits
mailing list