[Mlir-commits] [mlir] [MLIR][Python] enable precise registration (PR #160742)
Maksim Levental
llvmlistbot at llvm.org
Thu Sep 25 21:32:43 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160742
From 5815846788bc115ee7e7db71c6aaf17060596d8b Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Sep 2025 09:40:53 -0700
Subject: [PATCH 1/2] [MLIR][Python] enable precise registration
---
mlir/cmake/modules/AddMLIRPython.cmake | 14 +++++++++++---
mlir/python/CMakeLists.txt | 20 ++++++++++++++++++--
mlir/python/mlir/_mlir_libs/_capi.py.in | 8 ++++++++
mlir/test/python/ir/capi.py | 6 ++++++
4 files changed, 43 insertions(+), 5 deletions(-)
create mode 100644 mlir/python/mlir/_mlir_libs/_capi.py.in
create mode 100644 mlir/test/python/ir/capi.py
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 208cbdd1dd535..d8b6d493f985c 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -23,11 +23,14 @@
# grouping. Source groupings form a DAG.
# SOURCES: List of specific source files relative to ROOT_DIR to include.
# SOURCES_GLOB: List of glob patterns relative to ROOT_DIR to include.
+# EMBED_CAPI_LINK_LIBS: CAPI libraries that this source grouping depends
+# on. These will be collected for all extensions and put into an
+# aggregate dylib that is linked against.
function(declare_mlir_python_sources name)
cmake_parse_arguments(ARG
""
"ROOT_DIR;ADD_TO_PARENT"
- "SOURCES;SOURCES_GLOB"
+ "SOURCES;SOURCES_GLOB;EMBED_CAPI_LINK_LIBS"
${ARGN})
if(NOT ARG_ROOT_DIR)
@@ -53,9 +56,10 @@ function(declare_mlir_python_sources name)
set_target_properties(${name} PROPERTIES
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
- EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_DEPENDS"
+ EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_DEPENDS;mlir_python_EMBED_CAPI_LINK_LIBS"
mlir_python_SOURCES_TYPE pure
mlir_python_DEPENDS ""
+ mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
)
# Use the interface include directories and sources on the target to carry the
@@ -374,6 +378,9 @@ endfunction()
# This file is where the *EnumAttrs are defined, not where the *Enums are defined.
# **WARNING**: This arg will shortly be removed when the just-below TODO is satisfied. Use at your
# risk.
+# EMBED_CAPI_LINK_LIBS: CAPI libraries that these dialect bindings depend
+# on. These will be collected for all extensions and put into an
+# aggregate dylib that is linked against.
#
# TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we
# use its path to determine where to place the generated python file. If
@@ -383,7 +390,7 @@ function(declare_mlir_dialect_python_bindings)
cmake_parse_arguments(ARG
"GEN_ENUM_BINDINGS"
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME"
- "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
+ "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE;EMBED_CAPI_LINK_LIBS"
${ARGN})
# Sources.
set(_dialect_target "${ARG_ADD_TO_PARENT}.${ARG_DIALECT_NAME}")
@@ -424,6 +431,7 @@ function(declare_mlir_dialect_python_bindings)
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT "${_dialect_target}"
SOURCES ${_sources}
+ EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
)
endif()
endfunction()
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d6686bb89ce4e..7e0be1c10433e 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -326,7 +326,10 @@ declare_mlir_dialect_python_bindings(
SOURCES
dialects/arith.py
DIALECT_NAME arith
- GEN_ENUM_BINDINGS)
+ GEN_ENUM_BINDINGS
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIArith
+ )
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -846,8 +849,20 @@ endif()
# once ready.
################################################################################
+set(MLIR_PYTHON_CAPI_DYLIB_NAME MLIRPythonCAPI)
+configure_file(
+ "${CMAKE_CURRENT_LIST_DIR}/mlir/_mlir_libs/_capi.py.in"
+ "${CMAKE_CURRENT_BINARY_DIR}/_mlir_libs/_capi.py"
+ @ONLY
+)
+declare_mlir_python_sources(
+ MLIRPythonCAPICTypesBinding
+ ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+ SOURCES _mlir_libs/_capi.py
+)
+
set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
-add_mlir_python_common_capi_library(MLIRPythonCAPI
+add_mlir_python_common_capi_library(${MLIR_PYTHON_CAPI_DYLIB_NAME}
INSTALL_COMPONENT MLIRPythonModules
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
@@ -966,6 +981,7 @@ add_mlir_python_modules(MLIRPythonModules
MLIRPythonSources
MLIRPythonExtension.RegisterEverything
MLIRPythonExtension.Core.type_stub_gen
+ MLIRPythonCAPICTypesBinding
${_ADDL_TEST_SOURCES}
COMMON_CAPI_LINK_LIBS
MLIRPythonCAPI
diff --git a/mlir/python/mlir/_mlir_libs/_capi.py.in b/mlir/python/mlir/_mlir_libs/_capi.py.in
new file mode 100644
index 0000000000000..9568845e67de9
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_capi.py.in
@@ -0,0 +1,8 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import ctypes
+from pathlib import Path
+
+_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
\ No newline at end of file
diff --git a/mlir/test/python/ir/capi.py b/mlir/test/python/ir/capi.py
new file mode 100644
index 0000000000000..d60fbd820f91e
--- /dev/null
+++ b/mlir/test/python/ir/capi.py
@@ -0,0 +1,6 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir._mlir_libs._capi import _capi
+
+print("success")
+# CHECK: success
\ No newline at end of file
From 52e38f8897fd78494f489f951559b51558139025 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Sep 2025 14:59:58 -0700
Subject: [PATCH 2/2] [MLIR][Python] enable precise registration
---
mlir/include/mlir-c/Bindings/Python/Interop.h | 9 ++++
mlir/include/mlir-c/Dialect/Builtin.h | 33 ++++++++++++++
mlir/include/mlir-c/IR.h | 18 +++++---
mlir/lib/Bindings/Python/IRCore.cpp | 34 ++++++++++++++-
mlir/lib/CAPI/Dialect/Builtin.cpp | 13 ++++++
mlir/lib/CAPI/Dialect/CMakeLists.txt | 8 ++++
mlir/lib/CAPI/IR/IR.cpp | 12 ++++++
mlir/python/CMakeLists.txt | 43 +++++++++----------
mlir/python/mlir/_mlir_libs/_capi.py.in | 10 ++++-
mlir/test/python/ir/capi.py | 39 ++++++++++++++++-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++++++-
11 files changed, 204 insertions(+), 32 deletions(-)
create mode 100644 mlir/include/mlir-c/Dialect/Builtin.h
create mode 100644 mlir/lib/CAPI/Dialect/Builtin.cpp
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index a33190c380d37..89559da689017 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -84,6 +84,8 @@
#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
#define MLIR_PYTHON_CAPSULE_TYPEID \
MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
+#define MLIR_PYTHON_CAPSULE_DIALECT_HANDLE \
+ MAKE_MLIR_PYTHON_QUALNAME("ir.DialectHandle._CAPIPtr")
/** Attribute on MLIR Python objects that expose their C-API pointer.
* This will be a type-specific capsule created as per one of the helpers
@@ -457,6 +459,13 @@ static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) {
return value;
}
+static inline MlirDialectHandle
+mlirPythonCapsuleToDialectHandle(PyObject *capsule) {
+ void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE);
+ MlirDialectHandle handle = {ptr};
+ return handle;
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/Dialect/Builtin.h b/mlir/include/mlir-c/Dialect/Builtin.h
new file mode 100644
index 0000000000000..c5d958249b36f
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Builtin.h
@@ -0,0 +1,33 @@
+//===-- mlir-c/Dialect/Builtin.h - C API for Builtin dialect ------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C interface for registering and accessing the
+// Builtin dialect. A dialect should be registered with a context to make it
+// available to users of the context. These users must load the dialect
+// before using any of its attributes, operations or types. Parser and pass
+// manager can load registered dialects automatically.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_BUILTIN_H
+#define MLIR_C_DIALECT_BUILTIN_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Builtin, builtin);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_BUILTIN_H
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 061d7620ba077..55cc86accb8a0 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -66,6 +66,7 @@ DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
DEFINE_C_API_STRUCT(MlirType, const void);
DEFINE_C_API_STRUCT(MlirValue, const void);
+DEFINE_C_API_STRUCT(MlirDialectHandle, const void);
#undef DEFINE_C_API_STRUCT
@@ -207,11 +208,6 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
// registration schemes.
//===----------------------------------------------------------------------===//
-struct MlirDialectHandle {
- const void *ptr;
-};
-typedef struct MlirDialectHandle MlirDialectHandle;
-
#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \
MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__( \
void)
@@ -233,6 +229,11 @@ MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle,
MlirContext);
+/// Checks if the dialect handle is null.
+static inline bool mlirDialectHandleIsNull(MlirDialectHandle handle) {
+ return !handle.ptr;
+}
+
//===----------------------------------------------------------------------===//
// DialectRegistry API.
//===----------------------------------------------------------------------===//
@@ -249,6 +250,13 @@ static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
MLIR_CAPI_EXPORTED void
mlirDialectRegistryDestroy(MlirDialectRegistry registry);
+MLIR_CAPI_EXPORTED int64_t
+mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry);
+
+MLIR_CAPI_EXPORTED void
+mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry,
+ MlirStringRef *dialectNames);
+
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 83a8757bb72c7..c3d28733c2cee 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2897,6 +2897,14 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
// Populates the core exports of the 'ir' submodule.
//------------------------------------------------------------------------------
+MlirDialectHandle createMlirDialectHandleFromCapsule(nb::object capsule) {
+ MlirDialectHandle rawRegistry =
+ mlirPythonCapsuleToDialectHandle(capsule.ptr());
+ if (mlirDialectHandleIsNull(rawRegistry))
+ throw nb::python_error();
+ return rawRegistry;
+}
+
void mlir::python::populateIRCore(nb::module_ &m) {
// disable leak warnings which tend to be false positives.
nb::set_leak_warnings(false);
@@ -3126,6 +3134,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::sig("def __repr__(self) -> str"));
+ //----------------------------------------------------------------------------
+ // Mapping of MlirDialectHandle
+ //----------------------------------------------------------------------------
+
+ nb::class_<MlirDialectHandle>(m, "DialectHandle")
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &createMlirDialectHandleFromCapsule);
+
//----------------------------------------------------------------------------
// Mapping of PyDialectRegistry
//----------------------------------------------------------------------------
@@ -3133,7 +3149,23 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyDialectRegistry::createFromCapsule)
- .def(nb::init<>());
+ .def(nb::init<>())
+ .def("insert_dialect",
+ [](PyDialectRegistry &self, MlirDialectHandle handle) {
+ mlirDialectHandleInsertDialect(handle, self.get());
+ })
+ .def("insert_dialect",
+ [](PyDialectRegistry &self, intptr_t ptr) {
+ mlirDialectHandleInsertDialect(
+ {reinterpret_cast<const void *>(ptr)}, self.get());
+ })
+ .def_prop_ro("dialect_names", [](PyDialectRegistry &self) {
+ int64_t numDialectNames =
+ mlirDialectRegistryGetNumDialectNames(self.get());
+ std::vector<MlirStringRef> dialectNames(numDialectNames);
+ mlirDialectRegistryGetDialectNames(self.get(), dialectNames.data());
+ return dialectNames;
+ });
//----------------------------------------------------------------------------
// Mapping of Location
diff --git a/mlir/lib/CAPI/Dialect/Builtin.cpp b/mlir/lib/CAPI/Dialect/Builtin.cpp
new file mode 100644
index 0000000000000..d095daa294c56
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/Builtin.cpp
@@ -0,0 +1,13 @@
+//===- Builtin.cpp - C Interface for Builtin dialect ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Builtin.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/IR/BuiltinDialect.h"
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Builtin, builtin, mlir::BuiltinDialect)
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index bb1fdf8be3c8f..0dc6fbd224882 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -16,6 +16,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIArith
MLIRArithDialect
)
+add_mlir_upstream_c_api_library(MLIRCAPIBuiltin
+ Builtin.cpp
+
+ PARTIAL_SOURCES_INTENDED
+ LINK_LIBS PUBLIC
+ MLIRCAPIIR
+)
+
add_mlir_upstream_c_api_library(MLIRCAPIAsync
Async.cpp
AsyncPasses.cpp
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e9844a7cc1909..a81e2a14e5255 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -150,6 +150,18 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
delete unwrap(registry);
}
+int64_t mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry) {
+ auto dialectNames = unwrap(registry)->getDialectNames();
+ return std::distance(dialectNames.begin(), dialectNames.end());
+}
+
+void mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry,
+ MlirStringRef *dialectNames) {
+ for (auto [i, location] :
+ llvm::enumerate(unwrap(registry)->getDialectNames()))
+ dialectNames[i] = wrap(location);
+}
+
//===----------------------------------------------------------------------===//
// AsmState API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7e0be1c10433e..71180ec6e3c35 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -528,31 +528,32 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MLIRCAPIDebug
MLIRCAPIIR
MLIRCAPIInterfaces
+ MLIRCAPITransforms
+ MLIRCAPIBuiltin
# Dialects
MLIRCAPIFunc
)
-# This extension exposes an API to register all dialects, extensions, and passes
-# packaged in upstream MLIR and it is used for the upstream "mlir" Python
-# package. Downstreams will likely want to provide their own and not depend
-# on this one, since it links in the world.
-# Note that this is not added to any top-level source target for transitive
-# inclusion: It must be included explicitly by downstreams if desired. Note that
-# this has a very large impact on what gets built/packaged.
-declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
- MODULE_NAME _mlirRegisterEverything
- ROOT_DIR "${PYTHON_SOURCE_DIR}"
- PYTHON_BINDINGS_LIBRARY nanobind
- SOURCES
- RegisterEverything.cpp
- PRIVATE_LINK_LIBS
- LLVMSupport
- EMBED_CAPI_LINK_LIBS
- MLIRCAPIConversion
- MLIRCAPITransforms
- MLIRCAPIRegisterEverything
-)
+## This extension exposes an API to register all dialects, extensions, and passes
+## packaged in upstream MLIR and it is used for the upstream "mlir" Python
+## package. Downstreams will likely want to provide their own and not depend
+## on this one, since it links in the world.
+## Note that this is not added to any top-level source target for transitive
+## inclusion: It must be included explicitly by downstreams if desired. Note that
+## this has a very large impact on what gets built/packaged.
+#declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
+# MODULE_NAME _mlirRegisterEverything
+# ROOT_DIR "${PYTHON_SOURCE_DIR}"
+# PYTHON_BINDINGS_LIBRARY nanobind
+# SOURCES
+# RegisterEverything.cpp
+# PRIVATE_LINK_LIBS
+# LLVMSupport
+# MLIRCAPIConversion
+# MLIRCAPITransforms
+# MLIRCAPIRegisterEverything
+#)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
MODULE_NAME _mlirDialectsLinalg
@@ -871,7 +872,6 @@ add_mlir_python_common_capi_library(${MLIR_PYTHON_CAPI_DYLIB_NAME}
MLIRPythonCAPI.HeaderSources
DECLARED_SOURCES
MLIRPythonSources
- MLIRPythonExtension.RegisterEverything
${_ADDL_TEST_SOURCES}
)
@@ -979,7 +979,6 @@ add_mlir_python_modules(MLIRPythonModules
INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
DECLARED_SOURCES
MLIRPythonSources
- MLIRPythonExtension.RegisterEverything
MLIRPythonExtension.Core.type_stub_gen
MLIRPythonCAPICTypesBinding
${_ADDL_TEST_SOURCES}
diff --git a/mlir/python/mlir/_mlir_libs/_capi.py.in b/mlir/python/mlir/_mlir_libs/_capi.py.in
index 9568845e67de9..b79d14957a511 100644
--- a/mlir/python/mlir/_mlir_libs/_capi.py.in
+++ b/mlir/python/mlir/_mlir_libs/_capi.py.in
@@ -5,4 +5,12 @@
import ctypes
from pathlib import Path
-_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
\ No newline at end of file
+_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
+
+PyCapsule_New = ctypes.pythonapi.PyCapsule_New
+PyCapsule_New.restype = ctypes.py_object
+PyCapsule_New.argtypes = ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p
+
+MLIR_PYTHON_CAPSULE_DIALECT_HANDLE = (
+ "@MLIR_PYTHON_PACKAGE_PREFIX@.ir.DialectHandle._CAPIPtr"
+).encode()
diff --git a/mlir/test/python/ir/capi.py b/mlir/test/python/ir/capi.py
index d60fbd820f91e..934fab91f05a0 100644
--- a/mlir/test/python/ir/capi.py
+++ b/mlir/test/python/ir/capi.py
@@ -1,6 +1,41 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir._mlir_libs._capi import _capi
+import ctypes
+
+from mlir._mlir_libs import get_dialect_registry
+from mlir._mlir_libs._capi import (
+ _capi,
+ PyCapsule_New,
+ MLIR_PYTHON_CAPSULE_DIALECT_HANDLE,
+)
+from mlir.ir import DialectHandle
print("success")
-# CHECK: success
\ No newline at end of file
+# CHECK: success
+
+
+if not hasattr(_capi, "mlirGetDialectHandle__arith__"):
+ raise Exception("missing API")
+_capi.mlirGetDialectHandle__arith__.argtypes = []
+_capi.mlirGetDialectHandle__arith__.restype = ctypes.c_void_p
+
+if not hasattr(_capi, "mlirGetDialectHandle__quant__"):
+ raise Exception("missing API")
+_capi.mlirGetDialectHandle__quant__.argtypes = []
+_capi.mlirGetDialectHandle__quant__.restype = ctypes.c_void_p
+
+dialect_registry = get_dialect_registry()
+# CHECK: ['builtin']
+print(dialect_registry.dialect_names)
+
+arith_handle = _capi.mlirGetDialectHandle__arith__()
+dialect_registry.insert_dialect(arith_handle)
+# CHECK: ['arith', 'builtin']
+print(dialect_registry.dialect_names)
+
+quant_handle = _capi.mlirGetDialectHandle__quant__()
+capsule = PyCapsule_New(quant_handle, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE, None)
+dialect_handle = DialectHandle._CAPICreate(capsule)
+dialect_registry.insert_dialect(dialect_handle)
+# CHECK: ['arith', 'builtin', 'quant']
+print(dialect_registry.dialect_names)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0172b3fa38a6b..248ac57f2a947 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -47,6 +47,21 @@ _ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
+import ctypes
+from .._mlir_libs import get_dialect_registry
+from .._mlir_libs._capi import _capi
+
+_dialect_registry = get_dialect_registry()
+
+if not hasattr(_capi, "mlirGetDialectHandle__{0}__"):
+ raise Exception("missing API")
+
+_capi.mlirGetDialectHandle__{0}__.argtypes = []
+_capi.mlirGetDialectHandle__{0}__.restype = ctypes.c_void_p
+
+_{0}_handle = _capi.mlirGetDialectHandle__{0}__()
+_dialect_registry.insert_dialect(_{0}_handle)
+
)Py";
/// Template for dialect class:
@@ -1191,7 +1206,7 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
- os << fileHeader;
+ os << formatv(fileHeader, clDialectName.getValue());
if (!clDialectExtensionName.empty())
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
else
More information about the Mlir-commits
mailing list