[Mlir-commits] [mlir] [mlir][python] allow upstream dialect registration (PR #74252)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 3 13:00:01 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR (currently just a sketch) allows users who **do not** build the `MLIRPythonExtension.RegisterEverything` target to still register upstream dialects from Python (discussed in https://github.com/llvm/llvm-project/issues/74245). It uses the `get_dialect_registry` API introduced in https://github.com/llvm/llvm-project/pull/72488 and adds or exposes `MlirDialectHandle`s for all upstream dialects.
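For context: a dialect handle is, roughly, an opaque pointer to a static table of per-dialect hooks (namespace lookup, insertion into a registry), which is what lets a caller register a dialect without including that dialect's C++ headers. The following is a simplified, self-contained model of that idea under this assumption; `DialectRegistry`, `DialectHooks`, `DialectHandle`, `dialectHandleInsertDialect`, and `getDialectHandle_affine` are hypothetical stand-ins, not MLIR's actual C API.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for a dialect registry: it only records namespaces.
struct DialectRegistry {
  std::set<std::string> namespaces;
};

// Hypothetical per-dialect hook table; a handle points at one of these.
struct DialectHooks {
  const char *(*getNamespace)();
  void (*insertIntoRegistry)(DialectRegistry &);
};

// The handle itself is just a thin, copyable wrapper around that pointer.
struct DialectHandle {
  const DialectHooks *hooks;
};

// Inserting through a handle forwards to the dialect's own hook, so the
// caller never needs the dialect's C++ class.
inline void dialectHandleInsertDialect(DialectHandle handle,
                                       DialectRegistry &registry) {
  assert(handle.hooks != nullptr && "null dialect handle");
  handle.hooks->insertIntoRegistry(registry);
}

// Roughly what one generated getter could look like: a function-local static
// hook table whose address is returned on every call.
inline DialectHandle getDialectHandle_affine() {
  static const DialectHooks hooks = {
      [] { return "affine"; },
      [](DialectRegistry &registry) { registry.namespaces.insert("affine"); }};
  return DialectHandle{&hooks};
}
```

Because every call returns the address of the same static table, handles are cheap to copy and compare.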

Right now the API is crude: you pass a string identifying the dialect (its namespace), and the matching `MlirDialectHandle` is obtained internally by calling the appropriate C API function. Setting aside whether we should keep it this way, the basic idea is captured in the `m.def`:

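```cpp
  m.def(
      "add_dialect_to_dialect_registry",
      [](MlirDialectRegistry registry, const std::string &dialectNamespace) {
        // One `if` per dialect namespace; a match inserts that dialect's
        // handle into the registry and returns.
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
  if (dialectNamespace == #NAMESPACE) {                                        \
    mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(),    \
                                   registry);                                  \
    return;                                                                    \
  }
        // FORALL_DIALECTS applies the macro above to every upstream
        // namespace (acc, affine, ..., x86vector; full list in the diff).
        FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
        // No namespace matched.
        throw std::runtime_error("unknown dialect namespace: " +
                                 dialectNamespace);
      },
      "dialect_registry"_a, "dialect_namespace"_a);
```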
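The full `m.def` in the diff relies on an X-macro: `FORALL_DIALECTS(_)` applies whatever macro is passed as `_` to each namespace in the list, so the lambda body expands into a chain of `if (dialectNamespace == "acc") { ...; return; }` checks that falls through to the `throw`. A minimal standalone sketch of that expansion, assuming stub getters and a stub registry in place of the MLIR C API (`Registry`, `getStubHandle_*`, and `addDialectToRegistry` are hypothetical names):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for MlirDialectRegistry: records inserted namespaces.
struct Registry {
  std::vector<std::string> inserted;
};

// The dialect list, written once (a small subset of the real one).
#define FORALL_DIALECTS(_) _(affine) _(arith) _(scf)

// Hypothetical stub getters, one per list entry; each returns its namespace.
#define DEFINE_STUB_GETTER(NAMESPACE)                                          \
  inline std::string getStubHandle_##NAMESPACE() { return #NAMESPACE; }
FORALL_DIALECTS(DEFINE_STUB_GETTER)
#undef DEFINE_STUB_GETTER

// Expands to one `if` per list entry; `#NAMESPACE` turns the macro argument
// into a string literal and `##` pastes it into the getter's name.
inline void addDialectToRegistry(Registry &registry,
                                 const std::string &dialectNamespace) {
#define DISPATCH(NAMESPACE)                                                    \
  if (dialectNamespace == #NAMESPACE) {                                        \
    registry.inserted.push_back(getStubHandle_##NAMESPACE());                  \
    return;                                                                    \
  }
  FORALL_DIALECTS(DISPATCH)
#undef DISPATCH
  // Only reached if no entry matched.
  throw std::runtime_error("unknown dialect namespace: " + dialectNamespace);
}
```

Each lookup is a linear scan of string comparisons, which is harmless for a few dozen dialects.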

Before I polish the API further, I want to make sure this is what everyone wants. I'm also not sure how to test this upstream, since I can't turn off `MLIRPythonExtension.RegisterEverything` for the core bindings. The `Standalone` example conveniently has a TODO about [moving to finer grained registration](https://github.com/llvm/llvm-project/blob/5e83a5b4752da6631d79c446f21e5d128b5c5495/mlir/examples/standalone/python/CMakeLists.txt#L43), so maybe now is the time to check off that TODO?

---
Full diff: https://github.com/llvm/llvm-project/pull/74252.diff


3 Files Affected:

- (modified) mlir/include/mlir-c/RegisterEverything.h (+26) 
- (modified) mlir/lib/Bindings/Python/MainModule.cpp (+82) 
- (modified) mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp (+32-2) 


``````````diff
diff --git a/mlir/include/mlir-c/RegisterEverything.h b/mlir/include/mlir-c/RegisterEverything.h
index ea2ea86449727..f894419ecb1e4 100644
--- a/mlir/include/mlir-c/RegisterEverything.h
+++ b/mlir/include/mlir-c/RegisterEverything.h
@@ -31,6 +31,32 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context);
 /// Register all compiler passes of MLIR.
 MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(void);
 
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
+  MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##NAMESPACE##__();
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc)                                                                       \
+  _(affine)                                                                    \
+  _(amx)                                                                       \
+  _(arm_neon)                                                                  \
+  _(arm_sme)                                                                   \
+  _(arm_sve)                                                                   \
+  _(bufferization)                                                             \
+  _(complex)                                                                   \
+  _(emitc)                                                                     \
+  _(index)                                                                     \
+  _(irdl)                                                                      \
+  _(mesh)                                                                      \
+  _(spirv)                                                                     \
+  _(tosa)                                                                      \
+  _(ub)                                                                        \
+  _(x86vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 17272472ccca4..4605ccc0f2935 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,31 @@
 #include "IRModule.h"
 #include "Pass.h"
 
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/Dialect/Arith.h"
+#include "mlir-c/Dialect/Async.h"
+#include "mlir-c/Dialect/ControlFlow.h"
+#include "mlir-c/Dialect/Func.h"
+#include "mlir-c/Dialect/GPU.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/Dialect/Linalg.h"
+#include "mlir-c/Dialect/MLProgram.h"
+#include "mlir-c/Dialect/Math.h"
+#include "mlir-c/Dialect/MemRef.h"
+#include "mlir-c/Dialect/NVGPU.h"
+#include "mlir-c/Dialect/NVVM.h"
+#include "mlir-c/Dialect/OpenMP.h"
+#include "mlir-c/Dialect/PDL.h"
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/Dialect/ROCDL.h"
+#include "mlir-c/Dialect/SCF.h"
+#include "mlir-c/Dialect/Shape.h"
+#include "mlir-c/Dialect/SparseTensor.h"
+#include "mlir-c/Dialect/Tensor.h"
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir-c/Dialect/Vector.h"
+#include "mlir-c/RegisterEverything.h"
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace py::literals;
@@ -65,6 +90,63 @@ PYBIND11_MODULE(_mlir, m) {
       },
       "dialect_class"_a,
       "Class decorator for registering a custom Dialect wrapper");
+  m.def(
+      "add_dialect_to_dialect_registry",
+      [](MlirDialectRegistry registry, const std::string &dialectNamespace) {
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE)                      \
+  if (dialectNamespace == #NAMESPACE) {                                        \
+    mlirDialectHandleInsertDialect(mlirGetDialectHandle__##NAMESPACE##__(),    \
+                                   registry);                                  \
+    return;                                                                    \
+  }
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc)                                                                       \
+  _(affine)                                                                    \
+  _(amdgpu)                                                                    \
+  _(amx)                                                                       \
+  _(arith)                                                                     \
+  _(arm_neon)                                                                  \
+  _(arm_sme)                                                                   \
+  _(arm_sve)                                                                   \
+  _(async)                                                                     \
+  _(bufferization)                                                             \
+  _(cf)                                                                        \
+  _(complex)                                                                   \
+  _(emitc)                                                                     \
+  _(func)                                                                      \
+  _(gpu)                                                                       \
+  _(index)                                                                     \
+  _(irdl)                                                                      \
+  _(linalg)                                                                    \
+  _(llvm)                                                                      \
+  _(math)                                                                      \
+  _(memref)                                                                    \
+  _(mesh)                                                                      \
+  _(ml_program)                                                                \
+  _(nvgpu)                                                                     \
+  _(nvvm)                                                                      \
+  _(omp)                                                                       \
+  _(pdl)                                                                       \
+  _(quant)                                                                     \
+  _(rocdl)                                                                     \
+  _(scf)                                                                       \
+  _(shape)                                                                     \
+  _(spirv)                                                                     \
+  _(tensor)                                                                    \
+  _(tosa)                                                                      \
+  _(ub)                                                                        \
+  _(vector)                                                                    \
+  _(x86vector)
+        FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
+        throw std::runtime_error("unknown dialect namespace: " +
+                                 dialectNamespace);
+      },
+      "dialect_registry"_a, "dialect_namespace"_a);
   m.def(
       "register_operation",
       [](const py::object &dialectClass, bool replace) -> py::cpp_function {
diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
index c1c4a418b2552..debebe58ad64f 100644
--- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
+++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp
@@ -9,13 +9,43 @@
 #include "mlir-c/RegisterEverything.h"
 
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Registration.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllExtensions.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Target/LLVMIR/Dialect/All.h"
-#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+
+using namespace mlir;
+
+#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_(NAMESPACE, NAME)                \
+  MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NAME, NAMESPACE,                       \
+                                        NAMESPACE::NAME##Dialect)
+
+#define FORALL_DIALECTS(_)                                                     \
+  _(acc, OpenACC)                                                              \
+  _(affine, Affine)                                                            \
+  _(amx, AMX)                                                                  \
+  _(arith, Arith)                                                              \
+  _(arm_neon, ArmNeon)                                                         \
+  _(arm_sme, ArmSME)                                                           \
+  _(arm_sve, ArmSVE)                                                           \
+  _(bufferization, Bufferization)                                              \
+  _(complex, Complex)                                                          \
+  _(emitc, EmitC)                                                              \
+  _(mlir, DLTI)                                                                \
+  _(index, Index)                                                              \
+  _(irdl, IRDL)                                                                \
+  _(mesh, Mesh)                                                                \
+  _(spirv, SPIRV)                                                              \
+  _(tosa, Tosa)                                                                \
+  _(ub, UB)                                                                    \
+  _(x86vector, X86Vector)
+
+FORALL_DIALECTS(MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_)
+
+#undef MLIR_DEFINE_CAPI_DIALECT_REGISTRATION_
+#undef FORALL_DIALECTS
 
 void mlirRegisterAllDialects(MlirDialectRegistry registry) {
   mlir::registerAllDialects(*unwrap(registry));

``````````

</details>
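On the API question: the diff writes dialect lists out in three places. The lists in `RegisterEverything.h` and `RegisterEverything.cpp` are meant to mirror each other but already differ (the `.cpp` has `arith` and `mlir` (DLTI) entries that the header lacks), and `MainModule.cpp` keeps a third, longer list. One option is a single list that drives both the getter definitions and a namespace-to-getter table. A minimal sketch, assuming stand-in types; `Handle`, `Registry`, `getHandle_*`, `getterTable`, `addDialect`, and `supportedNamespaces` are hypothetical names, not existing MLIR APIs:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for MlirDialectHandle and MlirDialectRegistry.
struct Handle {
  std::string_view dialectNamespace;
};
struct Registry {
  std::vector<std::string_view> inserted;
};
using HandleGetter = Handle (*)();

// Single source of truth for the dialect list (a small illustrative subset).
#define FORALL_DIALECTS(_) _(affine) _(arith) _(scf) _(vector)

// Pass 1: one getter per entry. These are stubs; in the PR the equivalent
// functions are the mlirGetDialectHandle__<namespace>__ C API functions.
#define DEFINE_GETTER(NS)                                                      \
  inline Handle getHandle_##NS() { return Handle{#NS}; }
FORALL_DIALECTS(DEFINE_GETTER)
#undef DEFINE_GETTER

// Pass 2: a namespace -> getter table built from the same list.
inline const std::unordered_map<std::string_view, HandleGetter> &
getterTable() {
#define TABLE_ENTRY(NS) {#NS, &getHandle_##NS},
  static const std::unordered_map<std::string_view, HandleGetter> table = {
      FORALL_DIALECTS(TABLE_ENTRY)};
#undef TABLE_ENTRY
  return table;
}

// One lookup replaces the chain of string comparisons; unknown names throw.
inline void addDialect(Registry &registry, std::string_view dialectNamespace) {
  const auto &table = getterTable();
  auto it = table.find(dialectNamespace);
  if (it == table.end())
    throw std::invalid_argument("unknown dialect namespace: " +
                                std::string(dialectNamespace));
  registry.inserted.push_back(it->second().dialectNamespace);
}

// The same table can list every supported namespace (order is unspecified).
inline std::vector<std::string_view> supportedNamespaces() {
  std::vector<std::string_view> names;
  for (const auto &entry : getterTable())
    names.push_back(entry.first);
  return names;
}
```

The gain here is not lookup speed but maintenance: adding a dialect touches one list, and the table can be enumerated for a helpful error message or a Python-side listing. Throwing `std::invalid_argument` instead of `std::runtime_error` would surface in Python as `ValueError` under pybind11's default exception translation.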


https://github.com/llvm/llvm-project/pull/74252

