[Mlir-commits] [mlir] [mlir][Python] generate type stubs for dialect extensions (PR #175403)

Maksim Levental llvmlistbot at llvm.org
Sun Jan 11 18:19:56 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175403

>From 8bff757457813aba9db16dd93f4076cbee5b986d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 18:11:42 -0800
Subject: [PATCH] [mlir][Python] generate type stubs for dialect extensions

---
 mlir/python/CMakeLists.txt                    |  93 ++++++++----
 .../mlir/_mlir_libs/_mlir/dialects/pdl.pyi    |  63 --------
 .../mlir/_mlir_libs/_mlir/dialects/quant.pyi  | 142 ------------------
 .../_mlir/dialects/transform/__init__.pyi     |  25 ---
 .../mlir/_mlir_libs/_mlirExecutionEngine.pyi  |  24 ---
 mlir/python/replace_text.cmake                |   9 ++
 6 files changed, 69 insertions(+), 287 deletions(-)
 delete mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
 delete mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
 delete mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
 delete mode 100644 mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
 create mode 100644 mlir/python/replace_text.cmake

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 003a06b16daac..25243680867da 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -45,7 +45,6 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
   ADD_TO_PARENT MLIRPythonSources
   SOURCES
     execution_engine.py
-    _mlir_libs/_mlirExecutionEngine.pyi
   SOURCES_GLOB
     runtime/*.py
 )
@@ -216,7 +215,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/TransformOps.td
   SOURCES
     dialects/transform/__init__.py
-    _mlir_libs/_mlir/dialects/transform/__init__.pyi
   DIALECT_NAME transform
   GEN_ENUM_BINDINGS_TD_FILE
     "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -406,8 +404,7 @@ declare_mlir_python_sources(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   GEN_ENUM_BINDINGS
   SOURCES
-    dialects/quant.py
-    _mlir_libs/_mlir/dialects/quant.pyi)
+    dialects/quant.py)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -423,7 +420,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/PDLOps.td
   SOURCES
     dialects/pdl.py
-    _mlir_libs/_mlir/dialects/pdl.pyi
   DIALECT_NAME pdl)
 
 declare_mlir_dialect_python_bindings(
@@ -882,7 +878,50 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI
 
 _flatten_mlir_python_targets(mlir_python_sources_deps MLIRPythonSources)
 
+# Generate the type stub for one extension target, post-process it with replace_text.cmake, and register it as a source.
+function(mlir_generate_dialect_extension_type_stubs ext_target)
+  get_target_property(_extension_srcs ${ext_target} INTERFACE_SOURCES)
+  get_target_property(_module_name ${ext_target} mlir_python_EXTENSION_MODULE_NAME)
+  mlir_generate_type_stubs(
+    # This is the FQN path because dialect modules import _mlir when loaded. See above.
+    MODULE_NAME ${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs.${_module_name}
+    DEPENDS_TARGETS
+      # Both _mlir and ${_module_name} are needed: dialect modules import _mlir when loaded, so it must be built before stubgen.
+      MLIRPythonModules.extension._mlir.dso
+      MLIRPythonModules.extension.${_module_name}.dso
+      # Needed so that ir.py is "built", because mlir._mlir_libs.__init__.py imports mlir.ir in _site_initialize.
+      MLIRPythonModules.sources.MLIRPythonSources.Core.Python
+    OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs"
+    OUTPUTS ${_module_name}.pyi
+    DEPENDS_TARGET_SRC_DEPS "${_extension_srcs}"
+    IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.."
+  )
+  set(_stub_file "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs/${_module_name}.pyi")
+  add_custom_target(${NB_STUBGEN_CUSTOM_TARGET}.fixup
+    COMMAND ${CMAKE_COMMAND}
+            "-DINPUT_FILE=${_stub_file}" "-DOUTPUT_FILE=${_stub_file}"
+            "-DMLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}"
+            -P "${CMAKE_CURRENT_SOURCE_DIR}/replace_text.cmake"
+    DEPENDS "${_stub_file}" "${CMAKE_CURRENT_SOURCE_DIR}/replace_text.cmake"
+    COMMENT "Replacing strings in ${_module_name}.pyi at build time"
+    VERBATIM
+  )
+  # Export only after the fixup target exists: PARENT_SCOPE copies the value as of the call, so later appends are lost.
+  list(APPEND _mlir_typestub_gen_targets ${NB_STUBGEN_CUSTOM_TARGET} ${NB_STUBGEN_CUSTOM_TARGET}.fixup)
+  set(_mlir_typestub_gen_targets ${_mlir_typestub_gen_targets} PARENT_SCOPE)
+  declare_mlir_python_sources(${ext_target}.type_stub_gen
+    ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs"
+    ADD_TO_PARENT MLIRPythonSources.Dialects
+    SOURCES _mlir_libs/${_module_name}.pyi
+  )
+  list(APPEND mlir_python_sources_deps ${ext_target}.type_stub_gen)
+  set(mlir_python_sources_deps ${mlir_python_sources_deps} PARENT_SCOPE)
+endfunction()
+
 if(MLIR_PYTHON_STUBGEN_ENABLED)
+  set(_mlir_typestub_gen_targets)
+
   # _mlir stubgen
   # Note: All this needs to come before add_mlir_python_modules(MLIRPythonModules so that the install targets for the
   # generated type stubs get created.
@@ -924,7 +963,7 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
     DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}"
     IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
   )
-  set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}")
+  list(APPEND _mlir_typestub_gen_targets ${NB_STUBGEN_CUSTOM_TARGET})
 
   list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/")
   # Note, we do not do ADD_TO_PARENT here so that the type stubs are not associated (as mlir_DEPENDS) with
@@ -937,32 +976,23 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
   )
   list(APPEND mlir_python_sources_deps MLIRPythonExtension.Core.type_stub_gen)
 
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.Linalg.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.GPU.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.Quant.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.NVGPU.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.PDL.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.SparseTensor.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.Transform.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.IRDL.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.ExecutionEngine)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.SMT.Nanobind)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.TransformInterpreter)
+  mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.AMDGPU.Nanobind)
+
   # _mlirPythonTestNanobind stubgen
 
   if(MLIR_INCLUDE_TESTS)
-    get_target_property(_test_extension_srcs MLIRPythonTestSources.PythonTestExtensionNanobind INTERFACE_SOURCES)
-    mlir_generate_type_stubs(
-      # This is the FQN path because dialect modules import _mlir when loaded. See above.
-      MODULE_NAME mlir._mlir_libs._mlirPythonTestNanobind
-      DEPENDS_TARGETS
-        # You need both _mlir and _mlirPythonTestNanobind because dialect modules import _mlir when loaded
-        # (so _mlir needs to be built before calling stubgen).
-        MLIRPythonModules.extension._mlir.dso
-        MLIRPythonModules.extension._mlirPythonTestNanobind.dso
-        # You need this one so that ir.py "built" because mlir._mlir_libs.__init__.py import mlir.ir in _site_initialize.
-        MLIRPythonModules.sources.MLIRPythonSources.Core.Python
-      OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs"
-      OUTPUTS _mlirPythonTestNanobind.pyi
-      DEPENDS_TARGET_SRC_DEPS "${_test_extension_srcs}"
-      IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.."
-    )
-    set(_mlirPythonTestNanobind_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}")
-    declare_mlir_python_sources(
-      MLIRPythonTestSources.PythonTestExtensionNanobind.type_stub_gen
-      ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs"
-      ADD_TO_PARENT MLIRPythonTestSources.Dialects
-      SOURCES _mlir_libs/_mlirPythonTestNanobind.pyi
-    )
+    mlir_generate_dialect_extension_type_stubs(MLIRPythonTestSources.PythonTestExtensionNanobind)
   endif()
 endif()
 
@@ -994,8 +1024,5 @@ add_mlir_python_modules(MLIRPythonModules
     MLIRPythonCAPI
 )
 if(MLIR_PYTHON_STUBGEN_ENABLED)
-  add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}")
-  if(MLIR_INCLUDE_TESTS)
-    add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
-  endif()
+  add_dependencies(MLIRPythonModules ${_mlir_typestub_gen_targets})
 endif()
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
deleted file mode 100644
index d12c6839deaba..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
+++ /dev/null
@@ -1,63 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-from mlir.ir import Type, Context
-
-__all__ = [
-    'PDLType',
-    'AttributeType',
-    'OperationType',
-    'RangeType',
-    'TypeType',
-    'ValueType',
-]
-
-
-class PDLType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-
-class AttributeType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(context: Context | None = None) -> AttributeType: ...
-
-
-class OperationType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(context: Context | None = None) -> OperationType: ...
-
-
-class RangeType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(element_type: Type) -> RangeType: ...
-
-  @property
-  def element_type(self) -> Type: ...
-
-
-class TypeType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(context: Context | None = None) -> TypeType: ...
-
-
-class ValueType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(context: Context | None = None) -> ValueType: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
deleted file mode 100644
index 3f5304584edef..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
+++ /dev/null
@@ -1,142 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-from mlir.ir import DenseElementsAttr, Type
-
-__all__ = [
-  "QuantizedType",
-  "AnyQuantizedType",
-  "UniformQuantizedType",
-  "UniformQuantizedPerAxisType",
-  "CalibratedQuantizedType",
-]
-
-class QuantizedType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int:
-    ...
-
-  @staticmethod
-  def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int:
-    ...
-
-  @property
-  def expressed_type(self) -> Type: ...
-
-  @property
-  def flags(self) -> int: ...
-
-  @property
-  def is_signed(self) -> bool: ...
-
-  @property
-  def storage_type(self) -> Type: ...
-
-  @property
-  def storage_type_min(self) -> int: ...
-
-  @property
-  def storage_type_max(self) -> int: ...
-
-  @property
-  def storage_type_integral_width(self) -> int: ...
-
-  def is_compatible_expressed_type(self, candidate: Type) -> bool: ...
-
-  @property
-  def quantized_element_type(self) -> Type: ...
-
-  def cast_from_storage_type(self, candidate: Type) -> Type: ...
-
-  @staticmethod
-  def cast_to_storage_type(type: Type) -> Type: ...
-
-  def cast_from_expressed_type(self, candidate: Type) -> Type: ...
-
-  @staticmethod
-  def cast_to_expressed_type(type: Type) -> Type: ...
-
-  def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ...
-
-
-class AnyQuantizedType(QuantizedType):
-
-  @classmethod
-  def get(cls, flags: int, storage_type: Type, expressed_type: Type,
-          storage_type_min: int, storage_type_max: int) -> Type:
-    ...
-
-
-class UniformQuantizedType(QuantizedType):
-
-  @classmethod
-  def get(cls, flags: int, storage_type: Type, expressed_type: Type,
-          scale: float, zero_point: int, storage_type_min: int,
-          storage_type_max: int) -> Type: ...
-
-  @property
-  def scale(self) -> float: ...
-
-  @property
-  def zero_point(self) -> int: ...
-
-  @property
-  def is_fixed_point(self) -> bool: ...
-
-
-class UniformQuantizedPerAxisType(QuantizedType):
-
-  @classmethod
-  def get(cls, flags: int, storage_type: Type, expressed_type: Type,
-          scales: list[float], zero_points: list[int], quantized_dimension: int,
-          storage_type_min: int, storage_type_max: int):
-    ...
-
-  @property
-  def scales(self) -> list[float]: ...
-
-  @property
-  def zero_points(self) -> list[int]: ...
-
-  @property
-  def quantized_dimension(self) -> int: ...
-
-  @property
-  def is_fixed_point(self) -> bool: ...
-
-class UniformQuantizedSubChannelType(QuantizedType):
-
-  @classmethod
-  def get(cls, flags: int, storage_type: Type, expressed_type: Type,
-          scales: DenseElementsAttr, zero_points: DenseElementsAttr,
-          quantized_dimensions: list[int], block_sizes: list[int],
-          storage_type_min: int, storage_type_max: int):
-    ...
-
-  @property
-  def quantized_dimensions(self) -> list[int]: ...
-
-  @property
-  def block_sizes(self) -> list[int]: ...
-
-  @property
-  def scales(self) -> DenseElementsAttr: ...
-
-  @property
-  def zero_points(self) -> DenseElementsAttr: ...
-
-def CalibratedQuantizedType(QuantizedType):
-
-  @classmethod
-  def get(cls, expressed_type: Type, min: float, max: float): ...
-
-  @property
-  def min(self) -> float: ...
-
-  @property
-  def max(self) -> float: ...
\ No newline at end of file
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
deleted file mode 100644
index a3f1b09102379..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
+++ /dev/null
@@ -1,25 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-from mlir.ir import Type, Context
-
-
-class AnyOpType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(context: Context | None = None) -> AnyOpType: ...
-
-
-class OperationType(Type):
-  @staticmethod
-  def isinstance(type: Type) -> bool: ...
-
-  @staticmethod
-  def get(operation_name: str, context: Context | None = None) -> OperationType: ...
-
-  @property
-  def operation_name(self) -> str: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
deleted file mode 100644
index 4b82c78489295..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
+++ /dev/null
@@ -1,24 +0,0 @@
-# Originally imported via:
-#   stubgen {...} -m mlir._mlir_libs._mlirExecutionEngine
-# Local modifications:
-#   * Relative imports for cross-module references.
-#   * Add __all__
-
-from collections.abc import Sequence
-
-from ._mlir import ir as _ir
-
-__all__ = [
-    "ExecutionEngine",
-]
-
-class ExecutionEngine:
-    def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
-    def _CAPICreate(self) -> object: ...
-    def _testing_release(self) -> None: ...
-    def dump_to_object_file(self, file_name: str) -> None: ...
-    def raw_lookup(self, func_name: str) -> int: ...
-    def raw_register_runtime(self, name: str, callback: object) -> None: ...
-    def init() -> None: ...
-    @property
-    def _CAPIPtr(self) -> object: ...
diff --git a/mlir/python/replace_text.cmake b/mlir/python/replace_text.cmake
new file mode 100644
index 0000000000000..f675fb7927815
--- /dev/null
+++ b/mlir/python/replace_text.cmake
@@ -0,0 +1,9 @@
+# replace_text.cmake
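+# Variables INPUT_FILE and OUTPUT_FILE must be passed via -D. MLIR_PYTHON_PACKAGE_PREFIX is also read below; in -P script mode it is empty unless passed via -D.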
+
+file(READ "${INPUT_FILE}" CONTENT)
+
+# Rewrite references to the "<prefix>._mlir_libs._mlir.ir" module path as "<prefix>.ir".
+string(REPLACE "${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs._mlir.ir" "${MLIR_PYTHON_PACKAGE_PREFIX}.ir" MODIFIED_CONTENT "${CONTENT}")
+
+file(WRITE "${OUTPUT_FILE}" "${MODIFIED_CONTENT}")



More information about the Mlir-commits mailing list