[Mlir-commits] [mlir] [MLIR][Python] enable precise registration (PR #160742)

Maksim Levental llvmlistbot at llvm.org
Thu Sep 25 21:32:43 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160742

>From 5815846788bc115ee7e7db71c6aaf17060596d8b Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Sep 2025 09:40:53 -0700
Subject: [PATCH 1/2] [MLIR][Python] enable precise registration

---
 mlir/cmake/modules/AddMLIRPython.cmake  | 14 +++++++++++---
 mlir/python/CMakeLists.txt              | 20 ++++++++++++++++++--
 mlir/python/mlir/_mlir_libs/_capi.py.in |  8 ++++++++
 mlir/test/python/ir/capi.py             |  6 ++++++
 4 files changed, 43 insertions(+), 5 deletions(-)
 create mode 100644 mlir/python/mlir/_mlir_libs/_capi.py.in
 create mode 100644 mlir/test/python/ir/capi.py

diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 208cbdd1dd535..d8b6d493f985c 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -23,11 +23,14 @@
 #     grouping. Source groupings form a DAG.
 #   SOURCES: List of specific source files relative to ROOT_DIR to include.
 #   SOURCES_GLOB: List of glob patterns relative to ROOT_DIR to include.
+#   EMBED_CAPI_LINK_LIBS: CAPI libraries that these sources depend on. Like
+#     the extension option of the same name, these are collected and put into
+#     the aggregate CAPI dylib that is linked against.
 function(declare_mlir_python_sources name)
   cmake_parse_arguments(ARG
     ""
     "ROOT_DIR;ADD_TO_PARENT"
-    "SOURCES;SOURCES_GLOB"
+    "SOURCES;SOURCES_GLOB;EMBED_CAPI_LINK_LIBS"
     ${ARGN})
 
   if(NOT ARG_ROOT_DIR)
@@ -53,9 +56,10 @@ function(declare_mlir_python_sources name)
   set_target_properties(${name} PROPERTIES
     # Yes: Leading-lowercase property names are load bearing and the recommended
     # way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
-    EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_DEPENDS"
+    EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_DEPENDS;mlir_python_EMBED_CAPI_LINK_LIBS"
     mlir_python_SOURCES_TYPE pure
     mlir_python_DEPENDS ""
+    mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
   )
 
   # Use the interface include directories and sources on the target to carry the
@@ -374,6 +378,9 @@ endfunction()
 #     This file is where the *EnumAttrs are defined, not where the *Enums are defined.
 #     **WARNING**: This arg will shortly be removed when the just-below TODO is satisfied. Use at your
 #     risk.
+#   EMBED_CAPI_LINK_LIBS: CAPI libraries that these dialect bindings depend
+#     on. Like the extension option of the same name, these are collected and
+#     put into the aggregate CAPI dylib that is linked against.
 #
 # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we
 #       use its path to determine where to place the generated python file. If
@@ -383,7 +390,7 @@ function(declare_mlir_dialect_python_bindings)
   cmake_parse_arguments(ARG
     "GEN_ENUM_BINDINGS"
     "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME"
-    "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
+    "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE;EMBED_CAPI_LINK_LIBS"
     ${ARGN})
   # Sources.
   set(_dialect_target "${ARG_ADD_TO_PARENT}.${ARG_DIALECT_NAME}")
@@ -424,6 +431,7 @@ function(declare_mlir_dialect_python_bindings)
       ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
       ADD_TO_PARENT "${_dialect_target}"
       SOURCES ${_sources}
+      EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
     )
   endif()
 endfunction()
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d6686bb89ce4e..7e0be1c10433e 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -326,7 +326,10 @@ declare_mlir_dialect_python_bindings(
   SOURCES
     dialects/arith.py
   DIALECT_NAME arith
-  GEN_ENUM_BINDINGS)
+  GEN_ENUM_BINDINGS
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIArith
+  )
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -846,8 +849,20 @@ endif()
 # once ready.
 ################################################################################
 
+# Base name of the aggregate CAPI shared library. It is substituted into
+# _capi.py and reused as the add_mlir_python_common_capi_library target name.
+set(MLIR_PYTHON_CAPI_DYLIB_NAME MLIRPythonCAPI)
+# Generate _capi.py, which loads that library through ctypes.
+configure_file("${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs/_capi.py.in"
+               "${CMAKE_CURRENT_BINARY_DIR}/_mlir_libs/_capi.py" @ONLY)
+# Ship the generated module as a pure Python source rooted in the build tree.
+declare_mlir_python_sources(MLIRPythonCAPICTypesBinding
+  ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+  SOURCES _mlir_libs/_capi.py
+)
+
 set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
-add_mlir_python_common_capi_library(MLIRPythonCAPI
+add_mlir_python_common_capi_library(${MLIR_PYTHON_CAPI_DYLIB_NAME}
   INSTALL_COMPONENT MLIRPythonModules
   INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
   OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
@@ -966,6 +981,7 @@ add_mlir_python_modules(MLIRPythonModules
     MLIRPythonSources
     MLIRPythonExtension.RegisterEverything
     MLIRPythonExtension.Core.type_stub_gen
+    MLIRPythonCAPICTypesBinding
     ${_ADDL_TEST_SOURCES}
   COMMON_CAPI_LINK_LIBS
     MLIRPythonCAPI
diff --git a/mlir/python/mlir/_mlir_libs/_capi.py.in b/mlir/python/mlir/_mlir_libs/_capi.py.in
new file mode 100644
index 0000000000000..9568845e67de9
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_capi.py.in
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import ctypes
+from pathlib import Path
+
+_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
\ No newline at end of file
diff --git a/mlir/test/python/ir/capi.py b/mlir/test/python/ir/capi.py
new file mode 100644
index 0000000000000..d60fbd820f91e
--- /dev/null
+++ b/mlir/test/python/ir/capi.py
@@ -0,0 +1,6 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir._mlir_libs._capi import _capi
+
+print("success")
+# CHECK: success
\ No newline at end of file

>From 52e38f8897fd78494f489f951559b51558139025 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Thu, 25 Sep 2025 14:59:58 -0700
Subject: [PATCH 2/2] [MLIR][Python] enable precise registration

---
 mlir/include/mlir-c/Bindings/Python/Interop.h |  9 ++++
 mlir/include/mlir-c/Dialect/Builtin.h         | 33 ++++++++++++++
 mlir/include/mlir-c/IR.h                      | 18 +++++---
 mlir/lib/Bindings/Python/IRCore.cpp           | 34 ++++++++++++++-
 mlir/lib/CAPI/Dialect/Builtin.cpp             | 13 ++++++
 mlir/lib/CAPI/Dialect/CMakeLists.txt          |  8 ++++
 mlir/lib/CAPI/IR/IR.cpp                       | 12 ++++++
 mlir/python/CMakeLists.txt                    | 43 +++++++++----------
 mlir/python/mlir/_mlir_libs/_capi.py.in       | 10 ++++-
 mlir/test/python/ir/capi.py                   | 39 ++++++++++++++++-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++++++-
 11 files changed, 204 insertions(+), 32 deletions(-)
 create mode 100644 mlir/include/mlir-c/Dialect/Builtin.h
 create mode 100644 mlir/lib/CAPI/Dialect/Builtin.cpp

diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index a33190c380d37..89559da689017 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -84,6 +84,8 @@
 #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
 #define MLIR_PYTHON_CAPSULE_TYPEID                                             \
   MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
+#define MLIR_PYTHON_CAPSULE_DIALECT_HANDLE                                     \
+  MAKE_MLIR_PYTHON_QUALNAME("ir.DialectHandle._CAPIPtr")
 
 /** Attribute on MLIR Python objects that expose their C-API pointer.
  * This will be a type-specific capsule created as per one of the helpers
@@ -457,6 +459,17 @@ static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) {
   return value;
 }
 
+/** Extracts an MlirDialectHandle from a capsule named
+ * MLIR_PYTHON_CAPSULE_DIALECT_HANDLE. If the capsule is not of the right type,
+ * a null handle is returned (as checked via mlirDialectHandleIsNull) and a
+ * Python error is set by PyCapsule_GetPointer. */
+static inline MlirDialectHandle
+mlirPythonCapsuleToDialectHandle(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE);
+  MlirDialectHandle handle = {ptr};
+  return handle;
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir-c/Dialect/Builtin.h b/mlir/include/mlir-c/Dialect/Builtin.h
new file mode 100644
index 0000000000000..c5d958249b36f
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Builtin.h
@@ -0,0 +1,33 @@
+//===-- mlir-c/Dialect/Builtin.h - C API for Builtin dialect ------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C interface for registering and accessing the
+// Builtin dialect. A dialect should be registered with a context to make it
+// available to users of the context. These users must load the dialect
+// before using any of its attributes, operations or types. Parser and pass
+// manager can load registered dialects automatically.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_BUILTIN_H
+#define MLIR_C_DIALECT_BUILTIN_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Declares mlirGetDialectHandle__builtin__(void) -> MlirDialectHandle.
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Builtin, builtin);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_BUILTIN_H
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 061d7620ba077..55cc86accb8a0 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -66,6 +66,7 @@ DEFINE_C_API_STRUCT(MlirLocation, const void);
 DEFINE_C_API_STRUCT(MlirModule, const void);
 DEFINE_C_API_STRUCT(MlirType, const void);
 DEFINE_C_API_STRUCT(MlirValue, const void);
+DEFINE_C_API_STRUCT(MlirDialectHandle, const void);
 
 #undef DEFINE_C_API_STRUCT
 
@@ -207,11 +208,6 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
 // registration schemes.
 //===----------------------------------------------------------------------===//
 
-struct MlirDialectHandle {
-  const void *ptr;
-};
-typedef struct MlirDialectHandle MlirDialectHandle;
-
 #define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace)                \
   MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__(  \
       void)
@@ -233,6 +229,11 @@ MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
 MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle,
                                                             MlirContext);
 
+/// Checks if the dialect handle is null.
+static inline bool mlirDialectHandleIsNull(MlirDialectHandle handle) {
+  return !handle.ptr;
+}
+
 //===----------------------------------------------------------------------===//
 // DialectRegistry API.
 //===----------------------------------------------------------------------===//
@@ -249,6 +250,18 @@ static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
 MLIR_CAPI_EXPORTED void
 mlirDialectRegistryDestroy(MlirDialectRegistry registry);
 
+/// Returns the number of dialect names in the given dialect registry.
+MLIR_CAPI_EXPORTED int64_t
+mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry);
+
+/// Writes the names of the dialects in the given dialect registry into
+/// `dialectNames`, which must have room for at least
+/// `mlirDialectRegistryGetNumDialectNames(registry)` elements. The names are
+/// non-owning references; do not use them after the registry is destroyed.
+MLIR_CAPI_EXPORTED void
+mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry,
+                                   MlirStringRef *dialectNames);
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 83a8757bb72c7..c3d28733c2cee 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2897,6 +2897,17 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
 // Populates the core exports of the 'ir' submodule.
 //------------------------------------------------------------------------------
 
+/// Creates an MlirDialectHandle from a capsule named
+/// MLIR_PYTHON_CAPSULE_DIALECT_HANDLE. Throws nb::python_error, carrying the
+/// error set by the capsule lookup, if the capsule has the wrong type.
+static MlirDialectHandle
+createMlirDialectHandleFromCapsule(nb::object capsule) {
+  MlirDialectHandle handle = mlirPythonCapsuleToDialectHandle(capsule.ptr());
+  if (mlirDialectHandleIsNull(handle))
+    throw nb::python_error();
+  return handle;
+}
+
 void mlir::python::populateIRCore(nb::module_ &m) {
   // disable leak warnings which tend to be false positives.
   nb::set_leak_warnings(false);
@@ -3126,6 +3134,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::sig("def __repr__(self) -> str"));
 
+  //----------------------------------------------------------------------------
+  // Mapping of MlirDialectHandle
+  //----------------------------------------------------------------------------
+
+  nb::class_<MlirDialectHandle>(m, "DialectHandle")
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &createMlirDialectHandleFromCapsule);
+
   //----------------------------------------------------------------------------
   // Mapping of PyDialectRegistry
   //----------------------------------------------------------------------------
@@ -3133,7 +3149,23 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
       .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
                   &PyDialectRegistry::createFromCapsule)
-      .def(nb::init<>());
+      .def(nb::init<>())
+      .def("insert_dialect",
+           [](PyDialectRegistry &self, MlirDialectHandle handle) {
+             mlirDialectHandleInsertDialect(handle, self.get());
+           })
+      .def("insert_dialect",
+           [](PyDialectRegistry &self, intptr_t ptr) {
+             mlirDialectHandleInsertDialect(
+                 {reinterpret_cast<const void *>(ptr)}, self.get());
+           })
+      .def_prop_ro("dialect_names", [](PyDialectRegistry &self) {
+        int64_t numDialectNames =
+            mlirDialectRegistryGetNumDialectNames(self.get());
+        std::vector<MlirStringRef> dialectNames(numDialectNames);
+        mlirDialectRegistryGetDialectNames(self.get(), dialectNames.data());
+        return dialectNames;
+      });
 
   //----------------------------------------------------------------------------
   // Mapping of Location
diff --git a/mlir/lib/CAPI/Dialect/Builtin.cpp b/mlir/lib/CAPI/Dialect/Builtin.cpp
new file mode 100644
index 0000000000000..d095daa294c56
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/Builtin.cpp
@@ -0,0 +1,13 @@
+//===- Builtin.cpp - C Interface for Builtin dialect ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Builtin.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/IR/BuiltinDialect.h"
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Builtin, builtin, mlir::BuiltinDialect)
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index bb1fdf8be3c8f..0dc6fbd224882 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -16,6 +16,16 @@ add_mlir_upstream_c_api_library(MLIRCAPIArith
   MLIRArithDialect
 )
 
+# C API for the builtin dialect; provides mlirGetDialectHandle__builtin__.
+add_mlir_upstream_c_api_library(MLIRCAPIBuiltin
+  Builtin.cpp
+
+  PARTIAL_SOURCES_INTENDED
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  MLIRIR
+)
+
 add_mlir_upstream_c_api_library(MLIRCAPIAsync
   Async.cpp
   AsyncPasses.cpp
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e9844a7cc1909..a81e2a14e5255 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -150,6 +150,18 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
   delete unwrap(registry);
 }
 
+int64_t mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry) {
+  auto nameRange = unwrap(registry)->getDialectNames();
+  return std::distance(nameRange.begin(), nameRange.end());
+}
+
+void mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry,
+                                        MlirStringRef *dialectNames) {
+  int64_t idx = 0;
+  for (const auto &name : unwrap(registry)->getDialectNames())
+    dialectNames[idx++] = wrap(name);
+}
+
 //===----------------------------------------------------------------------===//
 // AsmState API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7e0be1c10433e..71180ec6e3c35 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -528,31 +528,32 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     MLIRCAPIDebug
     MLIRCAPIIR
     MLIRCAPIInterfaces
+    MLIRCAPITransforms
+    MLIRCAPIBuiltin
 
     # Dialects
     MLIRCAPIFunc
 )
 
-# This extension exposes an API to register all dialects, extensions, and passes
-# packaged in upstream MLIR and it is used for the upstream "mlir" Python
-# package. Downstreams will likely want to provide their own and not depend
-# on this one, since it links in the world.
-# Note that this is not added to any top-level source target for transitive
-# inclusion: It must be included explicitly by downstreams if desired. Note that
-# this has a very large impact on what gets built/packaged.
-declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
-  MODULE_NAME _mlirRegisterEverything
-  ROOT_DIR "${PYTHON_SOURCE_DIR}"
-  PYTHON_BINDINGS_LIBRARY nanobind
-  SOURCES
-    RegisterEverything.cpp
-  PRIVATE_LINK_LIBS
-    LLVMSupport
-  EMBED_CAPI_LINK_LIBS
-    MLIRCAPIConversion
-    MLIRCAPITransforms
-    MLIRCAPIRegisterEverything
-)
+## This extension exposes an API to register all dialects, extensions, and passes
+## packaged in upstream MLIR and it is used for the upstream "mlir" Python
+## package. Downstreams will likely want to provide their own and not depend
+## on this one, since it links in the world.
+## Note that this is not added to any top-level source target for transitive
+## inclusion: It must be included explicitly by downstreams if desired. Note that
+## this has a very large impact on what gets built/packaged.
+#declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
+#  MODULE_NAME _mlirRegisterEverything
+#  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+#  PYTHON_BINDINGS_LIBRARY nanobind
+#  SOURCES
+#    RegisterEverything.cpp
+#  PRIVATE_LINK_LIBS
+#    LLVMSupport
+#    MLIRCAPIConversion
+#    MLIRCAPITransforms
+#    MLIRCAPIRegisterEverything
+#)
 
 declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
   MODULE_NAME _mlirDialectsLinalg
@@ -871,7 +872,6 @@ add_mlir_python_common_capi_library(${MLIR_PYTHON_CAPI_DYLIB_NAME}
     MLIRPythonCAPI.HeaderSources
   DECLARED_SOURCES
     MLIRPythonSources
-    MLIRPythonExtension.RegisterEverything
     ${_ADDL_TEST_SOURCES}
 )
 
@@ -979,7 +979,6 @@ add_mlir_python_modules(MLIRPythonModules
   INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
   DECLARED_SOURCES
     MLIRPythonSources
-    MLIRPythonExtension.RegisterEverything
     MLIRPythonExtension.Core.type_stub_gen
     MLIRPythonCAPICTypesBinding
     ${_ADDL_TEST_SOURCES}
diff --git a/mlir/python/mlir/_mlir_libs/_capi.py.in b/mlir/python/mlir/_mlir_libs/_capi.py.in
index 9568845e67de9..b79d14957a511 100644
--- a/mlir/python/mlir/_mlir_libs/_capi.py.in
+++ b/mlir/python/mlir/_mlir_libs/_capi.py.in
@@ -5,4 +5,12 @@
 import ctypes
 from pathlib import Path
 
-_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
\ No newline at end of file
+_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
+# ctypes binding for the Python C API function PyCapsule_New(pointer, name, destructor).
+PyCapsule_New = ctypes.pythonapi.PyCapsule_New
+PyCapsule_New.restype = ctypes.py_object
+PyCapsule_New.argtypes = ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p
+
+# Must match MLIR_PYTHON_CAPSULE_DIALECT_HANDLE in mlir-c/Bindings/Python/Interop.h.
+# Kept alive at module scope because PyCapsule_New stores the name without copying it.
+MLIR_PYTHON_CAPSULE_DIALECT_HANDLE = "@MLIR_PYTHON_PACKAGE_PREFIX@.ir.DialectHandle._CAPIPtr".encode()
diff --git a/mlir/test/python/ir/capi.py b/mlir/test/python/ir/capi.py
index d60fbd820f91e..934fab91f05a0 100644
--- a/mlir/test/python/ir/capi.py
+++ b/mlir/test/python/ir/capi.py
@@ -1,6 +1,41 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir._mlir_libs._capi import _capi
+import ctypes
+
+from mlir._mlir_libs import get_dialect_registry
+from mlir._mlir_libs._capi import (
+    _capi,
+    PyCapsule_New,
+    MLIR_PYTHON_CAPSULE_DIALECT_HANDLE,
+)
+from mlir.ir import DialectHandle
 
 print("success")
-# CHECK: success
\ No newline at end of file
+# CHECK: success
+
+
+# Give the dialect-handle getters used below a ctypes signature: no arguments,
+# result read as c_void_p (MlirDialectHandle wraps a single pointer).
+for _namespace in ("arith", "quant"):
+    _getter_name = f"mlirGetDialectHandle__{_namespace}__"
+    if not hasattr(_capi, _getter_name):
+        raise Exception("missing API")
+    _getter = getattr(_capi, _getter_name)
+    _getter.argtypes = []
+    _getter.restype = ctypes.c_void_p
+
+dialect_registry = get_dialect_registry()
+# CHECK: ['builtin']
+print(dialect_registry.dialect_names)
+
+arith_handle = _capi.mlirGetDialectHandle__arith__()
+dialect_registry.insert_dialect(arith_handle)
+# CHECK: ['arith', 'builtin']
+print(dialect_registry.dialect_names)
+
+quant_handle = _capi.mlirGetDialectHandle__quant__()
+capsule = PyCapsule_New(quant_handle, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE, None)
+dialect_handle = DialectHandle._CAPICreate(capsule)
+dialect_registry.insert_dialect(dialect_handle)
+# CHECK: ['arith', 'builtin', 'quant']
+print(dialect_registry.dialect_names)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0172b3fa38a6b..248ac57f2a947 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -47,6 +47,21 @@ _ods_cext.globals.register_traceback_file_exclusion(__file__)
 import builtins
 from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
 
+import ctypes
+from .._mlir_libs import get_dialect_registry
+from .._mlir_libs._capi import _capi
+
+_dialect_registry = get_dialect_registry()
+
+if not hasattr(_capi, "mlirGetDialectHandle__{0}__"):
+    raise Exception("missing API")
+
+_capi.mlirGetDialectHandle__{0}__.argtypes = []
+_capi.mlirGetDialectHandle__{0}__.restype = ctypes.c_void_p
+
+_{0}_handle = _capi.mlirGetDialectHandle__{0}__()
+_dialect_registry.insert_dialect(_{0}_handle)
+
 )Py";
 
 /// Template for dialect class:
@@ -1191,7 +1206,7 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
   if (clDialectName.empty())
     llvm::PrintFatalError("dialect name not provided");
 
-  os << fileHeader;
+  os << formatv(fileHeader, clDialectName.getValue());
   if (!clDialectExtensionName.empty())
     os << formatv(dialectExtensionTemplate, clDialectName.getValue());
   else



More information about the Mlir-commits mailing list