[Mlir-commits] [mlir] [mlir][Python] generate type stubs for dialect extensions (PR #175403)
Maksim Levental
llvmlistbot at llvm.org
Sun Jan 11 18:19:56 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175403
>From 8bff757457813aba9db16dd93f4076cbee5b986d Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 18:11:42 -0800
Subject: [PATCH] [mlir][Python] generate type stubs for dialect extensions
---
mlir/python/CMakeLists.txt | 93 ++++++++----
.../mlir/_mlir_libs/_mlir/dialects/pdl.pyi | 63 --------
.../mlir/_mlir_libs/_mlir/dialects/quant.pyi | 142 ------------------
.../_mlir/dialects/transform/__init__.pyi | 25 ---
.../mlir/_mlir_libs/_mlirExecutionEngine.pyi | 24 ---
mlir/python/replace_text.cmake | 9 ++
6 files changed, 69 insertions(+), 287 deletions(-)
delete mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
delete mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
delete mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
delete mode 100644 mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
create mode 100644 mlir/python/replace_text.cmake
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 003a06b16daac..25243680867da 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -45,7 +45,6 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
ADD_TO_PARENT MLIRPythonSources
SOURCES
execution_engine.py
- _mlir_libs/_mlirExecutionEngine.pyi
SOURCES_GLOB
runtime/*.py
)
@@ -216,7 +215,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TransformOps.td
SOURCES
dialects/transform/__init__.py
- _mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -406,8 +404,7 @@ declare_mlir_python_sources(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
- dialects/quant.py
- _mlir_libs/_mlir/dialects/quant.pyi)
+ dialects/quant.py)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -423,7 +420,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
- _mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)
declare_mlir_dialect_python_bindings(
@@ -882,7 +878,99 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI
_flatten_mlir_python_targets(mlir_python_sources_deps MLIRPythonSources)
+# Generate nanobind type stubs (.pyi) for the Python extension module declared
+# by `ext_target` and register the generated stub as a Python source.
+#
+# Arguments:
+#   ext_target    - Extension declaration target; its INTERFACE_SOURCES and
+#                   mlir_python_EXTENSION_MODULE_NAME properties are read.
+#   ADD_TO_PARENT - Optional source group that the generated stub is added to
+#                   (default: MLIRPythonSources.Dialects).
+#
+# Results, set in the caller's scope:
+#   _mlir_typestub_gen_targets - gains the stubgen target and its `.fixup` target.
+#   mlir_python_sources_deps   - gains `${ext_target}.type_stub_gen`.
+#
+# Examples:
+#   mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.PDL.Nanobind)
+#   mlir_generate_dialect_extension_type_stubs(
+#     MLIRPythonTestSources.PythonTestExtensionNanobind
+#     ADD_TO_PARENT MLIRPythonTestSources.Dialects)
+function(mlir_generate_dialect_extension_type_stubs ext_target)
+  cmake_parse_arguments(PARSE_ARGV 1 ARG "" "ADD_TO_PARENT" "")
+  if(ARG_UNPARSED_ARGUMENTS)
+    message(FATAL_ERROR "mlir_generate_dialect_extension_type_stubs: "
+      "unexpected arguments: ${ARG_UNPARSED_ARGUMENTS}")
+  endif()
+  if(NOT ARG_ADD_TO_PARENT)
+    set(ARG_ADD_TO_PARENT MLIRPythonSources.Dialects)
+  endif()
+
+  get_target_property(_extension_srcs ${ext_target} INTERFACE_SOURCES)
+  get_target_property(_module_name ${ext_target} mlir_python_EXTENSION_MODULE_NAME)
+  if(NOT _module_name)
+    message(FATAL_ERROR "mlir_generate_dialect_extension_type_stubs: target "
+      "'${ext_target}' has no mlir_python_EXTENSION_MODULE_NAME property")
+  endif()
+
+  mlir_generate_type_stubs(
+    # Use the fully-qualified module name because dialect extension modules
+    # import _mlir when loaded.
+    MODULE_NAME ${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs.${_module_name}
+    DEPENDS_TARGETS
+      # Both _mlir and ${_module_name} are required because extension modules
+      # import _mlir when loaded (so _mlir must be built before stubgen runs).
+      MLIRPythonModules.extension._mlir.dso
+      MLIRPythonModules.extension.${_module_name}.dso
+      # Required so that ir.py is in place: mlir._mlir_libs.__init__.py imports
+      # mlir.ir in _site_initialize.
+      MLIRPythonModules.sources.MLIRPythonSources.Core.Python
+    OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs"
+    OUTPUTS ${_module_name}.pyi
+    DEPENDS_TARGET_SRC_DEPS "${_extension_srcs}"
+    IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.."
+  )
+  # mlir_generate_type_stubs reports the stubgen target it created through
+  # NB_STUBGEN_CUSTOM_TARGET.
+  set(_stubgen_target "${NB_STUBGEN_CUSTOM_TARGET}")
+  set(_stub_file "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs/${_module_name}.pyi")
+
+  # Post-process the generated stub in place with replace_text.cmake. That
+  # script runs in a separate `cmake -P` process, so every variable it reads,
+  # including MLIR_PYTHON_PACKAGE_PREFIX, must be forwarded with -D.
+  add_custom_target(${_stubgen_target}.fixup
+    COMMAND ${CMAKE_COMMAND}
+      "-DINPUT_FILE=${_stub_file}"
+      "-DOUTPUT_FILE=${_stub_file}"
+      "-DMLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}"
+      -P "${CMAKE_CURRENT_SOURCE_DIR}/replace_text.cmake"
+    DEPENDS "${_stub_file}" "${CMAKE_CURRENT_SOURCE_DIR}/replace_text.cmake"
+    COMMENT "Replacing strings in ${_module_name}.pyi at build time"
+    VERBATIM
+  )
+  # Make the ordering explicit: the stub must be generated before it is fixed up.
+  add_dependencies(${_stubgen_target}.fixup ${_stubgen_target})
+
+  # Propagate to the caller only after BOTH targets are appended: anything
+  # appended after the PARENT_SCOPE assignment stays local, so the `.fixup`
+  # target would never reach the caller's MLIRPythonModules dependencies.
+  list(APPEND _mlir_typestub_gen_targets ${_stubgen_target} ${_stubgen_target}.fixup)
+  set(_mlir_typestub_gen_targets ${_mlir_typestub_gen_targets} PARENT_SCOPE)
+
+  # Expose the generated stub as a Python source under ARG_ADD_TO_PARENT.
+  declare_mlir_python_sources(
+    ${ext_target}.type_stub_gen
+    ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs"
+    ADD_TO_PARENT ${ARG_ADD_TO_PARENT}
+    SOURCES _mlir_libs/${_module_name}.pyi
+  )
+  list(APPEND mlir_python_sources_deps ${ext_target}.type_stub_gen)
+  set(mlir_python_sources_deps ${mlir_python_sources_deps} PARENT_SCOPE)
+endfunction()
+
if(MLIR_PYTHON_STUBGEN_ENABLED)
+ set(_mlir_typestub_gen_targets)
+
# _mlir stubgen
# Note: All this needs to come before add_mlir_python_modules(MLIRPythonModules so that the install targets for the
# generated type stubs get created.
@@ -924,7 +963,7 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}"
IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
)
- set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}")
+ list(APPEND _mlir_typestub_gen_targets ${NB_STUBGEN_CUSTOM_TARGET})
list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/")
# Note, we do not do ADD_TO_PARENT here so that the type stubs are not associated (as mlir_DEPENDS) with
@@ -937,32 +976,23 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
)
list(APPEND mlir_python_sources_deps MLIRPythonExtension.Core.type_stub_gen)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.Linalg.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.GPU.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.Quant.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.NVGPU.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.PDL.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.SparseTensor.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.Transform.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.IRDL.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.ExecutionEngine)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.SMT.Nanobind)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.TransformInterpreter)
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonExtension.Dialects.AMDGPU.Nanobind)
+
# _mlirPythonTestNanobind stubgen
if(MLIR_INCLUDE_TESTS)
- get_target_property(_test_extension_srcs MLIRPythonTestSources.PythonTestExtensionNanobind INTERFACE_SOURCES)
- mlir_generate_type_stubs(
- # This is the FQN path because dialect modules import _mlir when loaded. See above.
- MODULE_NAME mlir._mlir_libs._mlirPythonTestNanobind
- DEPENDS_TARGETS
- # You need both _mlir and _mlirPythonTestNanobind because dialect modules import _mlir when loaded
- # (so _mlir needs to be built before calling stubgen).
- MLIRPythonModules.extension._mlir.dso
- MLIRPythonModules.extension._mlirPythonTestNanobind.dso
- # You need this one so that ir.py "built" because mlir._mlir_libs.__init__.py import mlir.ir in _site_initialize.
- MLIRPythonModules.sources.MLIRPythonSources.Core.Python
- OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs"
- OUTPUTS _mlirPythonTestNanobind.pyi
- DEPENDS_TARGET_SRC_DEPS "${_test_extension_srcs}"
- IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.."
- )
- set(_mlirPythonTestNanobind_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}")
- declare_mlir_python_sources(
- MLIRPythonTestSources.PythonTestExtensionNanobind.type_stub_gen
- ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs"
- ADD_TO_PARENT MLIRPythonTestSources.Dialects
- SOURCES _mlir_libs/_mlirPythonTestNanobind.pyi
- )
+ mlir_generate_dialect_extension_type_stubs(MLIRPythonTestSources.PythonTestExtensionNanobind)
endif()
endif()
@@ -994,8 +1024,5 @@ add_mlir_python_modules(MLIRPythonModules
MLIRPythonCAPI
)
if(MLIR_PYTHON_STUBGEN_ENABLED)
- add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}")
- if(MLIR_INCLUDE_TESTS)
- add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
- endif()
+ add_dependencies(MLIRPythonModules ${_mlir_typestub_gen_targets})
endif()
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
deleted file mode 100644
index d12c6839deaba..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
+++ /dev/null
@@ -1,63 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-from mlir.ir import Type, Context
-
-__all__ = [
- 'PDLType',
- 'AttributeType',
- 'OperationType',
- 'RangeType',
- 'TypeType',
- 'ValueType',
-]
-
-
-class PDLType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
-
-class AttributeType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(context: Context | None = None) -> AttributeType: ...
-
-
-class OperationType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(context: Context | None = None) -> OperationType: ...
-
-
-class RangeType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(element_type: Type) -> RangeType: ...
-
- @property
- def element_type(self) -> Type: ...
-
-
-class TypeType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(context: Context | None = None) -> TypeType: ...
-
-
-class ValueType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(context: Context | None = None) -> ValueType: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
deleted file mode 100644
index 3f5304584edef..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
+++ /dev/null
@@ -1,142 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-from mlir.ir import DenseElementsAttr, Type
-
-__all__ = [
- "QuantizedType",
- "AnyQuantizedType",
- "UniformQuantizedType",
- "UniformQuantizedPerAxisType",
- "CalibratedQuantizedType",
-]
-
-class QuantizedType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int:
- ...
-
- @staticmethod
- def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int:
- ...
-
- @property
- def expressed_type(self) -> Type: ...
-
- @property
- def flags(self) -> int: ...
-
- @property
- def is_signed(self) -> bool: ...
-
- @property
- def storage_type(self) -> Type: ...
-
- @property
- def storage_type_min(self) -> int: ...
-
- @property
- def storage_type_max(self) -> int: ...
-
- @property
- def storage_type_integral_width(self) -> int: ...
-
- def is_compatible_expressed_type(self, candidate: Type) -> bool: ...
-
- @property
- def quantized_element_type(self) -> Type: ...
-
- def cast_from_storage_type(self, candidate: Type) -> Type: ...
-
- @staticmethod
- def cast_to_storage_type(type: Type) -> Type: ...
-
- def cast_from_expressed_type(self, candidate: Type) -> Type: ...
-
- @staticmethod
- def cast_to_expressed_type(type: Type) -> Type: ...
-
- def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ...
-
-
-class AnyQuantizedType(QuantizedType):
-
- @classmethod
- def get(cls, flags: int, storage_type: Type, expressed_type: Type,
- storage_type_min: int, storage_type_max: int) -> Type:
- ...
-
-
-class UniformQuantizedType(QuantizedType):
-
- @classmethod
- def get(cls, flags: int, storage_type: Type, expressed_type: Type,
- scale: float, zero_point: int, storage_type_min: int,
- storage_type_max: int) -> Type: ...
-
- @property
- def scale(self) -> float: ...
-
- @property
- def zero_point(self) -> int: ...
-
- @property
- def is_fixed_point(self) -> bool: ...
-
-
-class UniformQuantizedPerAxisType(QuantizedType):
-
- @classmethod
- def get(cls, flags: int, storage_type: Type, expressed_type: Type,
- scales: list[float], zero_points: list[int], quantized_dimension: int,
- storage_type_min: int, storage_type_max: int):
- ...
-
- @property
- def scales(self) -> list[float]: ...
-
- @property
- def zero_points(self) -> list[int]: ...
-
- @property
- def quantized_dimension(self) -> int: ...
-
- @property
- def is_fixed_point(self) -> bool: ...
-
-class UniformQuantizedSubChannelType(QuantizedType):
-
- @classmethod
- def get(cls, flags: int, storage_type: Type, expressed_type: Type,
- scales: DenseElementsAttr, zero_points: DenseElementsAttr,
- quantized_dimensions: list[int], block_sizes: list[int],
- storage_type_min: int, storage_type_max: int):
- ...
-
- @property
- def quantized_dimensions(self) -> list[int]: ...
-
- @property
- def block_sizes(self) -> list[int]: ...
-
- @property
- def scales(self) -> DenseElementsAttr: ...
-
- @property
- def zero_points(self) -> DenseElementsAttr: ...
-
-def CalibratedQuantizedType(QuantizedType):
-
- @classmethod
- def get(cls, expressed_type: Type, min: float, max: float): ...
-
- @property
- def min(self) -> float: ...
-
- @property
- def max(self) -> float: ...
\ No newline at end of file
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
deleted file mode 100644
index a3f1b09102379..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
+++ /dev/null
@@ -1,25 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-from mlir.ir import Type, Context
-
-
-class AnyOpType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(context: Context | None = None) -> AnyOpType: ...
-
-
-class OperationType(Type):
- @staticmethod
- def isinstance(type: Type) -> bool: ...
-
- @staticmethod
- def get(operation_name: str, context: Context | None = None) -> OperationType: ...
-
- @property
- def operation_name(self) -> str: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
deleted file mode 100644
index 4b82c78489295..0000000000000
--- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
+++ /dev/null
@@ -1,24 +0,0 @@
-# Originally imported via:
-# stubgen {...} -m mlir._mlir_libs._mlirExecutionEngine
-# Local modifications:
-# * Relative imports for cross-module references.
-# * Add __all__
-
-from collections.abc import Sequence
-
-from ._mlir import ir as _ir
-
-__all__ = [
- "ExecutionEngine",
-]
-
-class ExecutionEngine:
- def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
- def _CAPICreate(self) -> object: ...
- def _testing_release(self) -> None: ...
- def dump_to_object_file(self, file_name: str) -> None: ...
- def raw_lookup(self, func_name: str) -> int: ...
- def raw_register_runtime(self, name: str, callback: object) -> None: ...
- def init() -> None: ...
- @property
- def _CAPIPtr(self) -> object: ...
diff --git a/mlir/python/replace_text.cmake b/mlir/python/replace_text.cmake
new file mode 100644
index 0000000000000..f675fb7927815
--- /dev/null
+++ b/mlir/python/replace_text.cmake
@@ -0,0 +1,38 @@
+# replace_text.cmake
+#
+# Rewrites the private module path `<prefix>._mlir_libs._mlir.ir` to the public
+# `<prefix>.ir` in a generated .pyi stub.
+#
+# Run in script mode (`cmake -P`); inputs must be passed via -D:
+#   INPUT_FILE                 - file to read (required).
+#   OUTPUT_FILE                - file to write (defaults to INPUT_FILE).
+#   MLIR_PYTHON_PACKAGE_PREFIX - Python package prefix, e.g. `mlir`. If unset,
+#                                any `._mlir_libs._mlir.ir` occurrence is
+#                                rewritten to `.ir`.
+
+# Script mode does not inherit the project's policy settings; set them here.
+cmake_minimum_required(VERSION 3.20)
+
+if("${INPUT_FILE}" STREQUAL "")
+  message(FATAL_ERROR "replace_text.cmake: pass -DINPUT_FILE=<path>")
+endif()
+if("${OUTPUT_FILE}" STREQUAL "")
+  set(OUTPUT_FILE "${INPUT_FILE}")
+endif()
+
+file(READ "${INPUT_FILE}" content)
+string(REPLACE
+  "${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs._mlir.ir"
+  "${MLIR_PYTHON_PACKAGE_PREFIX}.ir"
+  modified_content "${content}")
+
+# Only write when the output would change. This script may be invoked on every
+# build (e.g. from a custom target), and an unconditional rewrite would update
+# the output's timestamp each time.
+set(existing_output "")
+if(EXISTS "${OUTPUT_FILE}")
+  file(READ "${OUTPUT_FILE}" existing_output)
+endif()
+if(NOT EXISTS "${OUTPUT_FILE}" OR NOT "${existing_output}" STREQUAL "${modified_content}")
+  file(WRITE "${OUTPUT_FILE}" "${modified_content}")
+endif()
More information about the Mlir-commits
mailing list