[Mlir-commits] [mlir] 64c0c9f - [mlir] Expose Dialect class and registration/loading to C API

Alex Zinenko llvmlistbot at llvm.org
Tue Sep 29 07:30:14 PDT 2020


Author: Alex Zinenko
Date: 2020-09-29T16:30:08+02:00
New Revision: 64c0c9f01511dc300b29e7a20a13958c5932e314

URL: https://github.com/llvm/llvm-project/commit/64c0c9f01511dc300b29e7a20a13958c5932e314
DIFF: https://github.com/llvm/llvm-project/commit/64c0c9f01511dc300b29e7a20a13958c5932e314.diff

LOG: [mlir] Expose Dialect class and registration/loading to C API

- Add a minimalist C API for mlir::Dialect.
- Allow one to query the context about registered and loaded dialects.
- Add API for loading dialects.
- Provide functions to register the Standard dialect.

When used naively, this requires registering each dialect separately. Once more
than one dialect is exposed, we can add variadic macros that expand to the
individual calls.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D88162

Added: 
    mlir/include/mlir-c/StandardDialect.h
    mlir/lib/CAPI/Standard/CMakeLists.txt
    mlir/lib/CAPI/Standard/StandardDialect.cpp

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/CAPI/IR.h
    mlir/lib/CAPI/CMakeLists.txt
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/CMakeLists.txt
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 4aca261868f3..82149c7fce06 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -20,6 +20,8 @@
 
 #include <stdint.h>
 
+#include "mlir-c/Support.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -46,6 +48,7 @@ extern "C" {
   typedef struct name name
 
 DEFINE_C_API_STRUCT(MlirContext, void);
+DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -97,6 +100,39 @@ void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow);
 /** Returns whether the context allows unregistered dialects. */
 int mlirContextGetAllowUnregisteredDialects(MlirContext context);
 
+/** Returns the number of dialects registered with the given context. A
+ * registered dialect will be loaded if needed by the parser. */
+intptr_t mlirContextGetNumRegisteredDialects(MlirContext context);
+
+/** Returns the number of dialects loaded by the context.
+ */
+intptr_t mlirContextGetNumLoadedDialects(MlirContext context);
+
+/** Gets the dialect instance owned by the given context using the dialect
+ * namespace to identify it, loads (i.e., constructs the instance of) the
+ * dialect if necessary. If the dialect is not registered with the context,
+ * returns null. Use mlirContextLoad<Name>Dialect to load an unregistered
+ * dialect. */
+MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
+                                        MlirStringRef name);
+
+/*============================================================================*/
+/* Dialect API.                                                               */
+/*============================================================================*/
+
+/** Returns the context that owns the dialect. */
+MlirContext mlirDialectGetContext(MlirDialect dialect);
+
+/** Checks if the dialect is null. */
+int mlirDialectIsNull(MlirDialect dialect);
+
+/** Checks if two dialects that belong to the same context are equal. Dialects
+ * from different contexts will not compare equal. */
+int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2);
+
+/** Returns the namespace of the given dialect. */
+MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
+
 /*============================================================================*/
 /* Location API.                                                              */
 /*============================================================================*/

diff --git a/mlir/include/mlir-c/StandardDialect.h b/mlir/include/mlir-c/StandardDialect.h
new file mode 100644
index 000000000000..946d14859d5d
--- /dev/null
+++ b/mlir/include/mlir-c/StandardDialect.h
@@ -0,0 +1,42 @@
+/*===-- mlir-c/StandardDialect.h - C API for Standard dialect -----*- C -*-===*\
+|*                                                                            *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM          *|
+|* Exceptions.                                                                *|
+|* See https://llvm.org/LICENSE.txt for license information.                  *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception                    *|
+|*                                                                            *|
+|*===----------------------------------------------------------------------===*|
+|*                                                                            *|
+|* This header declares the C interface for registering and accessing the     *|
+|* Standard dialect. A dialect should be registered with a context to make it *|
+|* available to users of the context. These users must load the dialect       *|
+|* before using any of its attributes, operations or types. Parser and pass   *|
+|* manager can load registered dialects automatically.                        *|
+|*                                                                            *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_STANDARDDIALECT_H
+#define MLIR_C_STANDARDDIALECT_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/** Registers the Standard dialect with the given context. This allows the
+ * dialect to be loaded dynamically if needed when parsing. */
+void mlirContextRegisterStandardDialect(MlirContext context);
+
+/** Loads the Standard dialect into the given context. The dialect does _not_
+ * have to be registered in advance. */
+MlirDialect mlirContextLoadStandardDialect(MlirContext context);
+
+/** Returns the namespace of the Standard dialect, suitable for loading it. */
+MlirStringRef mlirStandardDialectGetNamespace();
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_STANDARDDIALECT_H

diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index 9a60ecf04fc8..dce293d05588 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -21,6 +21,7 @@
 #include "mlir/IR/Operation.h"
 
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
+DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
 DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)

diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt
index 79d472b2d026..b9d2c4601b98 100644
--- a/mlir/lib/CAPI/CMakeLists.txt
+++ b/mlir/lib/CAPI/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Registration)
+add_subdirectory(Standard)

diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 3b99f8ac4748..359ee69708eb 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -7,8 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
 
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
@@ -41,6 +43,40 @@ void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
 int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
   return unwrap(context)->allowsUnregisteredDialects();
 }
+intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
+  return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size());
+}
+
+// TODO: expose a cheaper way than constructing + sorting a vector only to take
+// its size.
+intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
+  return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size());
+}
+
+MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
+                                        MlirStringRef name) {
+  return wrap(unwrap(context)->getOrLoadDialect(unwrap(name)));
+}
+
+/* ========================================================================== */
+/* Dialect API.                                                               */
+/* ========================================================================== */
+
+MlirContext mlirDialectGetContext(MlirDialect dialect) {
+  return wrap(unwrap(dialect)->getContext());
+}
+
+int mlirDialectIsNull(MlirDialect dialect) {
+  return unwrap(dialect) == nullptr;
+}
+
+int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
+  return unwrap(dialect1) == unwrap(dialect2);
+}
+
+MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
+  return wrap(unwrap(dialect)->getNamespace());
+}
 
 /* ========================================================================== */
 /* Location API.                                                              */

diff --git a/mlir/lib/CAPI/Standard/CMakeLists.txt b/mlir/lib/CAPI/Standard/CMakeLists.txt
new file mode 100644
index 000000000000..662841c2d235
--- /dev/null
+++ b/mlir/lib/CAPI/Standard/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRCAPIStandard
+
+  StandardDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir-c
+
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  MLIRStandardOps
+  )

diff --git a/mlir/lib/CAPI/Standard/StandardDialect.cpp b/mlir/lib/CAPI/Standard/StandardDialect.cpp
new file mode 100644
index 000000000000..f78c9c916873
--- /dev/null
+++ b/mlir/lib/CAPI/Standard/StandardDialect.cpp
@@ -0,0 +1,25 @@
+//===- StandardDialect.cpp - C Interface for Standard dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/StandardDialect.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+void mlirContextRegisterStandardDialect(MlirContext context) {
+  unwrap(context)->getDialectRegistry().insert<mlir::StandardOpsDialect>();
+}
+
+MlirDialect mlirContextLoadStandardDialect(MlirContext context) {
+  return wrap(unwrap(context)->getOrLoadDialect<mlir::StandardOpsDialect>());
+}
+
+MlirStringRef mlirStandardDialectGetNamespace() {
+  return wrap(mlir::StandardOpsDialect::getDialectNamespace());
+}

diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index 19deda5e3f11..876d701d7211 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -13,4 +13,5 @@ target_link_libraries(mlir-capi-ir-test
   PRIVATE
   MLIRCAPIIR
   MLIRCAPIRegistration
+  MLIRCAPIStandard
   ${dialect_libs})

diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 909929647a84..ae60d56a22ed 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -14,6 +14,7 @@
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/StandardAttributes.h"
+#include "mlir-c/StandardDialect.h"
 #include "mlir-c/StandardTypes.h"
 
 #include <assert.h>
@@ -790,6 +791,42 @@ int printAffineMap(MlirContext ctx) {
   return 0;
 }
 
+int registerOnlyStd() {
+  MlirContext ctx = mlirContextCreate();
+  // The built-in dialect is always loaded.
+  if (mlirContextGetNumLoadedDialects(ctx) != 1)
+    return 1;
+
+  MlirDialect std =
+      mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  if (!mlirDialectIsNull(std))
+    return 2;
+
+  mlirContextRegisterStandardDialect(ctx);
+  if (mlirContextGetNumRegisteredDialects(ctx) != 1)
+    return 3;
+  if (mlirContextGetNumLoadedDialects(ctx) != 1)
+    return 4;
+
+  std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  if (mlirDialectIsNull(std))
+    return 5;
+  if (mlirContextGetNumLoadedDialects(ctx) != 2)
+    return 6;
+
+  MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
+  if (!mlirDialectEqual(std, alsoStd))
+    return 7;
+
+  MlirStringRef stdNs = mlirDialectGetNamespace(std);
+  MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
+  if (stdNs.length != alsoStdNs.length ||
+      strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
+    return 8;
+
+  return 0;
+}
+
 int main() {
   MlirContext ctx = mlirContextCreate();
   mlirRegisterAllDialects(ctx);
@@ -935,6 +972,14 @@ int main() {
   errcode = printAffineMap(ctx);
   fprintf(stderr, "%d\n", errcode);
 
+  fprintf(stderr, "@registration\n");
+  errcode = registerOnlyStd();
+  fprintf(stderr, "%d\n", errcode);
+  // clang-format off
+  // CHECK-LABEL: @registration
+  // CHECK: 0
+  // clang-format on
+
   mlirContextDestroy(ctx);
 
   return 0;


        


More information about the Mlir-commits mailing list