[Mlir-commits] [mlir] [mlir][python] allow upstream dialect registration (PR #74252)
Maksim Levental
llvmlistbot at llvm.org
Sun Dec 3 14:15:04 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/74252
>From bc9e2168179c3e07562d7d00b47dc5d7e01e48eb Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sun, 3 Dec 2023 14:45:16 -0600
Subject: [PATCH] [mlir][python] allow upstream dialect registration
---
.../mlir-c/Dialect/RemainingDialects.h | 41 +++++++++
mlir/lib/Bindings/Python/MainModule.cpp | 83 +++++++++++++++++++
mlir/lib/CAPI/Dialect/CMakeLists.txt | 11 +++
mlir/lib/CAPI/Dialect/RemainingDialects.cpp | 53 ++++++++++++
.../RegisterEverything/RegisterEverything.cpp | 2 -
mlir/python/CMakeLists.txt | 23 +++++
6 files changed, 211 insertions(+), 2 deletions(-)
create mode 100644 mlir/include/mlir-c/Dialect/RemainingDialects.h
create mode 100644 mlir/lib/CAPI/Dialect/RemainingDialects.cpp
diff --git a/mlir/include/mlir-c/Dialect/RemainingDialects.h b/mlir/include/mlir-c/Dialect/RemainingDialects.h
new file mode 100644
index 0000000000000..e98f084798cd6
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/RemainingDialects.h
@@ -0,0 +1,63 @@
+//===-- mlir-c/Dialect/RemainingDialects.h - C API for dialects --*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares mlirGetDialectHandle__<namespace>__() for each upstream dialect
+// listed below. The definitions are provided by
+// mlir/lib/CAPI/Dialect/RemainingDialects.cpp (CMake target
+// MLIRCAPIRemainingDialects).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_REMAINING_DIALECTS_H
+#define MLIR_C_REMAINING_DIALECTS_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Declares the exported handle getter for one dialect namespace. The explicit
+// `void` parameter list makes this a real prototype when compiled as C.
+#define MLIR_C_REMAINING_DIALECTS_DECLARE_(NAMESPACE) \
+ MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__( \
+ void);
+
+// Every dialect namespace declared by this header; each entry needs a
+// matching definition in RemainingDialects.cpp.
+#define MLIR_C_REMAINING_DIALECTS_FORALL_(_) \
+ _(acc) \
+ _(affine) \
+ _(amx) \
+ _(arm_neon) \
+ _(arm_sme) \
+ _(arm_sve) \
+ _(bufferization) \
+ _(complex) \
+ _(dlti) \
+ _(emitc) \
+ _(index) \
+ _(irdl) \
+ _(mesh) \
+ _(spirv) \
+ _(tosa) \
+ _(ub) \
+ _(x86vector)
+
+MLIR_C_REMAINING_DIALECTS_FORALL_(MLIR_C_REMAINING_DIALECTS_DECLARE_)
+
+// The helper macros are private to this header; don't leak them to includers.
+#undef MLIR_C_REMAINING_DIALECTS_DECLARE_
+#undef MLIR_C_REMAINING_DIALECTS_FORALL_
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_REMAINING_DIALECTS_H
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..dc062244b828c 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,32 @@
#include "IRModule.h"
#include "Pass.h"
 
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/Dialect/Arith.h"
+#include "mlir-c/Dialect/Async.h"
+#include "mlir-c/Dialect/ControlFlow.h"
+#include "mlir-c/Dialect/Func.h"
+#include "mlir-c/Dialect/GPU.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/Dialect/MLProgram.h"
+#include "mlir-c/Dialect/Math.h"
+#include "mlir-c/Dialect/MemRef.h"
+#include "mlir-c/Dialect/NVGPU.h"
+#include "mlir-c/Dialect/NVVM.h"
+#include "mlir-c/Dialect/OpenMP.h"
+#include "mlir-c/Dialect/PDL.h"
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/Dialect/ROCDL.h"
+#include "mlir-c/Dialect/SCF.h"
+#include "mlir-c/Dialect/Shape.h"
+#include "mlir-c/Dialect/SparseTensor.h"
+#include "mlir-c/Dialect/Tensor.h"
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir-c/Dialect/Vector.h"
+
+#include "mlir-c/Dialect/RemainingDialects.h"
+
namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
@@ -65,6 +91,73 @@ PYBIND11_MODULE(_mlir, m) {
},
"dialect_class"_a,
"Class decorator for registering a custom Dialect wrapper");
+ // Inserts the upstream dialect named `dialect_namespace` into
+ // `dialect_registry` via its C API dialect handle. Throws std::runtime_error
+ // (surfaced as RuntimeError in Python) for namespaces not listed below.
+ m.def(
+ "add_dialect_to_dialect_registry",
+ [](MlirDialectRegistry registry, const std::string &dialectNamespace) {
+// For each namespace NS in FORALL_DIALECTS this expands to
+//   if (dialectNamespace == "NS") { insert NS's dialect; return; }
+// Every entry needs a declared and linked mlirGetDialectHandle__NS__(), so keep
+// the list in sync with the mlir-c/Dialect includes above and the MLIRCAPI*
+// libraries of MLIRPythonExtension.Core in mlir/python/CMakeLists.txt.
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE) \
+ if (dialectNamespace == #NAMESPACE) { \
+ mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(), \
+ registry); \
+ return; \
+ }
+
+#define FORALL_DIALECTS(_) \
+ _(acc) \
+ _(affine) \
+ _(amdgpu) \
+ _(amx) \
+ _(arith) \
+ _(arm_neon) \
+ _(arm_sme) \
+ _(arm_sve) \
+ _(async) \
+ _(bufferization) \
+ _(cf) \
+ _(complex) \
+ _(dlti) \
+ _(emitc) \
+ _(func) \
+ _(gpu) \
+ _(index) \
+ _(irdl) \
+ _(linalg) \
+ _(llvm) \
+ _(math) \
+ _(memref) \
+ _(mesh) \
+ _(ml_program) \
+ _(nvgpu) \
+ _(nvvm) \
+ _(omp) \
+ _(pdl) \
+ _(quant) \
+ _(rocdl) \
+ _(scf) \
+ _(shape) \
+ _(sparse_tensor) \
+ _(spirv) \
+ _(tensor) \
+ _(tosa) \
+ _(transform) \
+ _(ub) \
+ _(vector) \
+ _(x86vector)
+ FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+ throw std::runtime_error("unknown dialect namespace: " +
+ dialectNamespace);
+ },
+ "dialect_registry"_a, "dialect_namespace"_a);
m.def(
"register_operation",
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index d815eba48d9b9..de0c4b9ac2478 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -224,3 +224,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector
MLIRCAPIIR
MLIRVectorDialect
)
+
+# C API dialect handles defined in RemainingDialects.cpp.
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+add_mlir_upstream_c_api_library(MLIRCAPIRemainingDialects
+ RemainingDialects.cpp
+
+ PARTIAL_SOURCES_INTENDED
+ LINK_LIBS PUBLIC
+ MLIRCAPIIR
+ ${dialect_libs} # every library in the global MLIR_DIALECT_LIBS property
+)
diff --git a/mlir/lib/CAPI/Dialect/RemainingDialects.cpp b/mlir/lib/CAPI/Dialect/RemainingDialects.cpp
new file mode 100644
index 0000000000000..a35814376a3eb
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/RemainingDialects.cpp
@@ -0,0 +1,55 @@
+//===- RemainingDialects.cpp - C API for remaining dialects ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the mlirGetDialectHandle__<namespace>__() functions declared in
+// mlir-c/Dialect/RemainingDialects.h.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/RemainingDialects.h"
+
+#include "mlir/CAPI/Registration.h"
+#include "mlir/InitAllDialects.h"
+
+using namespace mlir;
+
+// Defines mlirGetDialectHandle__NAMESPACE__() for the dialect class
+// NAMESPACE::NAME##Dialect.
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME) \
+ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE, \
+ NAMESPACE::NAME##Dialect)
+
+// Together with dlti below, this must match RemainingDialects.h. Dialects with
+// their own C API library (e.g. arith in MLIRCAPIArith) must not be listed
+// here, or two libraries would define the same exported symbol.
+#define FORALL_DIALECTS(_) \
+ _(acc, OpenACC) \
+ _(affine, Affine) \
+ _(amx, AMX) \
+ _(arm_neon, ArmNeon) \
+ _(arm_sme, ArmSME) \
+ _(arm_sve, ArmSVE) \
+ _(bufferization, Bufferization) \
+ _(complex, Complex) \
+ _(emitc, EmitC) \
+ _(index, Index) \
+ _(irdl, IRDL) \
+ _(mesh, Mesh) \
+ _(spirv, SPIRV) \
+ _(tosa, Tosa) \
+ _(ub, UB) \
+ _(x86vector, X86Vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+
+// The dialect class is mlir::DLTIDialect (no dlti:: namespace), so it cannot
+// go through the NAMESPACE::NAME##Dialect helper above.
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(DLTI, dlti, mlir::DLTIDialect)
diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
index c1c4a418b2552..767e7631de17d 100644
--- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
+++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
@@ -14,8 +14,6 @@
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
-#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
 
void mlirRegisterAllDialects(MlirDialectRegistry registry) {
mlir::registerAllDialects(*unwrap(registry));
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 55731943fb78d..456cf5f205cc5 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -423,7 +423,30 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MLIRCAPIInterfaces
 
# Dialects
+ MLIRCAPIAMDGPU
+ MLIRCAPIArith
+ MLIRCAPIAsync
+ MLIRCAPIControlFlow
MLIRCAPIFunc
+ MLIRCAPIGPU
+ MLIRCAPILLVM
+ MLIRCAPILinalg
+ MLIRCAPIMLProgram
+ MLIRCAPIMath
+ MLIRCAPIMemRef
+ MLIRCAPINVGPU
+ MLIRCAPINVVM
+ MLIRCAPIOpenMP
+ MLIRCAPIPDL
+ MLIRCAPIQuant
+ MLIRCAPIROCDL
+ MLIRCAPISCF
+ MLIRCAPIShape
+ MLIRCAPISparseTensor
+ MLIRCAPITensor
+ MLIRCAPITransformDialect
+ MLIRCAPIVector
+ MLIRCAPIRemainingDialects # handles declared in mlir-c/Dialect/RemainingDialects.h
)
 
# This extension exposes an API to register all dialects, extensions, and passes
More information about the Mlir-commits
mailing list