[Mlir-commits] [mlir] [mlir][python] allow upstream dialect registration (PR #74252)

Maksim Levental llvmlistbot at llvm.org
Sun Dec 3 17:30:25 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/74252

From 719695dcf53084c4d439f4de53d6c05c80cb32e9 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sun, 3 Dec 2023 14:45:16 -0600
Subject: [PATCH] [mlir][python] allow upstream dialect registration

---
 .../examples/standalone/python/CMakeLists.txt |  6 --
 .../standalone/test/python/smoketest.py       |  2 +
 .../mlir-c/Dialect/RemainingDialects.h        | 41 +++++++++
 mlir/lib/Bindings/Python/MainModule.cpp       | 83 +++++++++++++++++++
 mlir/lib/CAPI/Dialect/CMakeLists.txt          | 11 +++
 mlir/lib/CAPI/Dialect/RemainingDialects.cpp   | 53 ++++++++++++
 .../RegisterEverything/RegisterEverything.cpp |  2 -
 mlir/python/CMakeLists.txt                    | 23 +++++
 mlir/python/mlir/ir.py                        |  6 +-
 9 files changed, 218 insertions(+), 9 deletions(-)
 create mode 100644 mlir/include/mlir-c/Dialect/RemainingDialects.h
 create mode 100644 mlir/lib/CAPI/Dialect/RemainingDialects.cpp

diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index a8c43827a5a37..014d6061f7f0f 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -40,9 +40,6 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI
   RELATIVE_INSTALL_ROOT "../../../.."
   DECLARED_SOURCES
     StandalonePythonSources
-    # TODO: Remove this in favor of showing fine grained registration once
-    # available.
-    MLIRPythonExtension.RegisterEverything
     MLIRPythonSources.Core
 )
 
@@ -55,9 +52,6 @@ add_mlir_python_modules(StandalonePythonModules
   INSTALL_PREFIX "python_packages/standalone/mlir_standalone"
   DECLARED_SOURCES
     StandalonePythonSources
-    # TODO: Remove this in favor of showing fine grained registration once
-    # available.
-    MLIRPythonExtension.RegisterEverything
     MLIRPythonSources
   COMMON_CAPI_LINK_LIBS
     StandalonePythonCAPI
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index 08e08cbd2fe24..6e82a91e0bfc7 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -3,6 +3,8 @@
 from mlir_standalone.ir import *
 from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
 
+add_dialect_to_dialect_registry(get_dialect_registry(), "arith")
+
 with Context():
     standalone_d.register_dialect()
     module = Module.parse(
diff --git a/mlir/include/mlir-c/Dialect/RemainingDialects.h b/mlir/include/mlir-c/Dialect/RemainingDialects.h
new file mode 100644
index 0000000000000..e98f084798cd6
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/RemainingDialects.h
@@ -0,0 +1,41 @@
+#ifndef MLIR_C_REMAINING_DIALECTS_H
+#define MLIR_C_REMAINING_DIALECTS_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
+  MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__();
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc)                                                                       \
+  _(affine)                                                                    \
+  _(amx)                                                                       \
+  _(arm_neon)                                                                  \
+  _(arm_sme)                                                                   \
+  _(arm_sve)                                                                   \
+  _(bufferization)                                                             \
+  _(complex)                                                                   \
+  _(dlti)                                                                      \
+  _(emitc)                                                                     \
+  _(index)                                                                     \
+  _(irdl)                                                                      \
+  _(mesh)                                                                      \
+  _(spirv)                                                                     \
+  _(tosa)                                                                      \
+  _(ub)                                                                        \
+  _(x86vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_REMAINING_DIALECTS_H
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..dc062244b828c 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,32 @@
 #include "IRModule.h"
 #include "Pass.h"
 
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/Dialect/Arith.h"
+#include "mlir-c/Dialect/Async.h"
+#include "mlir-c/Dialect/ControlFlow.h"
+#include "mlir-c/Dialect/Func.h"
+#include "mlir-c/Dialect/GPU.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/Dialect/MLProgram.h"
+#include "mlir-c/Dialect/Math.h"
+#include "mlir-c/Dialect/MemRef.h"
+#include "mlir-c/Dialect/NVGPU.h"
+#include "mlir-c/Dialect/NVVM.h"
+#include "mlir-c/Dialect/OpenMP.h"
+#include "mlir-c/Dialect/PDL.h"
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/Dialect/ROCDL.h"
+#include "mlir-c/Dialect/SCF.h"
+#include "mlir-c/Dialect/Shape.h"
+#include "mlir-c/Dialect/SparseTensor.h"
+#include "mlir-c/Dialect/Tensor.h"
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir-c/Dialect/Vector.h"
+
+#include "mlir-c/Dialect/RemainingDialects.h"
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace py::literals;
@@ -65,6 +91,63 @@ PYBIND11_MODULE(_mlir, m) {
       },
       "dialect_class"_a,
       "Class decorator for registering a custom Dialect wrapper");
+  m.def(
+      "add_dialect_to_dialect_registry",
+      [](MlirDialectRegistry registry, const std::string &dialectNamespace) {
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
+  if (dialectNamespace == #NAMESPACE) {                                        \
+    mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(),    \
+                                   registry);                                  \
+    return;                                                                    \
+  }
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc)                                                                       \
+  _(affine)                                                                    \
+  _(amdgpu)                                                                    \
+  _(amx)                                                                       \
+  _(arith)                                                                     \
+  _(arm_neon)                                                                  \
+  _(arm_sme)                                                                   \
+  _(arm_sve)                                                                   \
+  _(async)                                                                     \
+  _(bufferization)                                                             \
+  _(cf)                                                                        \
+  _(complex)                                                                   \
+  _(emitc)                                                                     \
+  _(func)                                                                      \
+  _(gpu)                                                                       \
+  _(index)                                                                     \
+  _(irdl)                                                                      \
+  _(linalg)                                                                    \
+  _(llvm)                                                                      \
+  _(math)                                                                      \
+  _(memref)                                                                    \
+  _(mesh)                                                                      \
+  _(ml_program)                                                                \
+  _(nvgpu)                                                                     \
+  _(nvvm)                                                                      \
+  _(omp)                                                                       \
+  _(pdl)                                                                       \
+  _(quant)                                                                     \
+  _(rocdl)                                                                     \
+  _(scf)                                                                       \
+  _(shape)                                                                     \
+  _(spirv)                                                                     \
+  _(tensor)                                                                    \
+  _(tosa)                                                                      \
+  _(ub)                                                                        \
+  _(vector)                                                                    \
+  _(x86vector)
+        FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+        throw std::runtime_error("unknown dialect namespace: " +
+                                 dialectNamespace);
+      },
+      "dialect_registry"_a, "dialect_namespace"_a);
   m.def(
       "register_operation",
       [](const py::object &dialectClass, bool replace) -> py::cpp_function {
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index d815eba48d9b9..de0c4b9ac2478 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -224,3 +224,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector
   MLIRCAPIIR
   MLIRVectorDialect
 )
+
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+add_mlir_upstream_c_api_library(MLIRCAPIRemainingDialects
+  RemainingDialects.cpp
+
+  PARTIAL_SOURCES_INTENDED
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  ${dialect_libs}
+)
diff --git a/mlir/lib/CAPI/Dialect/RemainingDialects.cpp b/mlir/lib/CAPI/Dialect/RemainingDialects.cpp
new file mode 100644
index 0000000000000..a35814376a3eb
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/RemainingDialects.cpp
@@ -0,0 +1,53 @@
+#include "mlir-c/Dialect/RemainingDialects.h"
+
+#include "mlir/CAPI/Registration.h"
+#include "mlir/InitAllDialects.h"
+
+using namespace mlir;
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME)                \
+  MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE,                       \
+                                        NAMESPACE::NAME##Dialect)
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc, OpenACC)                                                              \
+  _(affine, Affine)                                                            \
+  _(amx, AMX)                                                                  \
+  _(arith, Arith)                                                              \
+  _(arm_neon, ArmNeon)                                                         \
+  _(arm_sme, ArmSME)                                                           \
+  _(arm_sve, ArmSVE)                                                           \
+  _(bufferization, Bufferization)                                              \
+  _(complex, Complex)                                                          \
+  _(emitc, EmitC)                                                              \
+  _(index, Index)                                                              \
+  _(irdl, IRDL)                                                                \
+  _(mesh, Mesh)                                                                \
+  _(spirv, SPIRV)                                                              \
+  _(tosa, Tosa)                                                                \
+  _(ub, UB)                                                                    \
+  _(x86vector, X86Vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+
+static void mlirDialectRegistryInsertDLTIDialect(MlirDialectRegistry registry) {
+  unwrap(registry)->insert<mlir::DLTIDialect>();
+}
+
+static MlirDialect mlirContextLoadDLTIDialect(MlirContext context) {
+  return wrap(unwrap(context)->getOrLoadDialect<mlir::DLTIDialect>());
+}
+
+static MlirStringRef mlirDLTIDialectGetNamespace() {
+  return wrap(mlir::DLTIDialect::getDialectNamespace());
+}
+
+MlirDialectHandle mlirGetDialectHandle__dlti__() {
+  static MlirDialectRegistrationHooks hooks = {
+      mlirDialectRegistryInsertDLTIDialect, mlirContextLoadDLTIDialect,
+      mlirDLTIDialectGetNamespace};
+  return MlirDialectHandle{&hooks};
+}
diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
index c1c4a418b2552..767e7631de17d 100644
--- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
+++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
@@ -14,8 +14,6 @@
 #include "mlir/InitAllExtensions.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Target/LLVMIR/Dialect/All.h"
-#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 
 void mlirRegisterAllDialects(MlirDialectRegistry registry) {
   mlir::registerAllDialects(*unwrap(registry));
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55731943fb78d..456cf5f205cc5 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -423,7 +423,30 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     MLIRCAPIInterfaces
 
     # Dialects
+    MLIRCAPIAMDGPU
+    MLIRCAPIArith
+    MLIRCAPIAsync
+    MLIRCAPIControlFlow
     MLIRCAPIFunc
+    MLIRCAPIGPU
+    MLIRCAPILLVM
+    MLIRCAPILinalg
+    MLIRCAPIMLProgram
+    MLIRCAPIMath
+    MLIRCAPIMemRef
+    MLIRCAPINVGPU
+    MLIRCAPINVVM
+    MLIRCAPIOpenMP
+    MLIRCAPIPDL
+    MLIRCAPIQuant
+    MLIRCAPIROCDL
+    MLIRCAPISCF
+    MLIRCAPIShape
+    MLIRCAPISparseTensor
+    MLIRCAPITensor
+    MLIRCAPITransformDialect
+    MLIRCAPIVector
+    MLIRCAPIRemainingDialects
 )
 
 # This extension exposes an API to register all dialects, extensions, and passes
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6d21da3b4179f..d46134b24416e 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,7 +4,11 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster, register_value_caster
+from ._mlir_libs._mlir import (
+    register_type_caster,
+    register_value_caster,
+    add_dialect_to_dialect_registry,
+)
 from ._mlir_libs import get_dialect_registry
 
 



More information about the Mlir-commits mailing list