[Mlir-commits] [mlir] [mlir][Python] generate type stubs for dialect extensions (PR #175403)
Maksim Levental
llvmlistbot at llvm.org
Sat Jan 10 19:00:02 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175403
>From 9b03a3cab7b1cc3c3e5bc99ba5983461267a1063 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 18:11:42 -0800
Subject: [PATCH] [mlir][Python] generate type stubs for dialect extensions
---
mlir/lib/Bindings/Python/DialectLinalg.cpp | 132 +++++++++++-------
mlir/python/CMakeLists.txt | 24 +++-
.../python/lib/PythonTestModuleNanobind.cpp | 42 +++---
3 files changed, 119 insertions(+), 79 deletions(-)
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 0b079b404d42d..29ce86fb00fb9 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -1,4 +1,4 @@
-//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//===- DialectLinalg.cpp - Nanobind module for Linalg dialect API support -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,16 +8,27 @@
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
+#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace linalg {
+// C++ wrapper over the C API struct; bound below as `ContractionDimensions`.
+struct PyLinalgContractionDimensions : MlirLinalgContractionDimensions {};
+// C++ wrapper over the C API struct; bound below as `ConvolutionDimensions`.
+struct PyLinalgConvolutionDimensions : MlirLinalgConvolutionDimensions {};
-static std::optional<MlirLinalgContractionDimensions>
-InferContractionDimensions(MlirOperation op) {
+static std::optional<PyLinalgContractionDimensions>
+InferContractionDimensions(PyOperationBase &op) {
MlirLinalgContractionDimensions dims =
- mlirLinalgInferContractionDimensions(op);
+ mlirLinalgInferContractionDimensions(op.getOperation());
// Detect "empty" result. This occurs when `op` is not a contraction op,
// or when `linalg::inferContractionDims` fails.
@@ -25,13 +36,13 @@ InferContractionDimensions(MlirOperation op) {
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return dims;
+ return PyLinalgContractionDimensions{dims}; // Base copy keeps batch/m/n/k order.
}
-static std::optional<MlirLinalgConvolutionDimensions>
-InferConvolutionDimensions(MlirOperation op) {
+static std::optional<PyLinalgConvolutionDimensions>
+InferConvolutionDimensions(PyOperationBase &op) {
MlirLinalgConvolutionDimensions dims =
- mlirLinalgInferConvolutionDimensions(op);
+ mlirLinalgInferConvolutionDimensions(op.getOperation());
// Detect "empty" result. This occurs when `op` is not a convolution op,
// or when `linalg::inferConvolutionDims` fails.
@@ -45,33 +56,38 @@ InferConvolutionDimensions(MlirOperation op) {
return std::nullopt;
}
- return dims;
+ // Initialize the C-struct base directly from `dims` instead of listing the
+ // eight fields positionally, so the wrapper cannot drift from the C layout.
+ return PyLinalgConvolutionDimensions{dims};
}
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"fill_builtin_region",
- [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
+ [](PyOperationBase &op) {
+ mlirLinalgFillBuiltinNamedOpRegion(op.getOperation());
+ },
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
- m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
- "Checks if the given operation is a Linalg contraction operation.",
- nb::arg("op"));
+ m.def(
+ "isa_contraction_op",
+ [](PyOperationBase &op) {
+ return mlirLinalgIsAContractionOp(op.getOperation());
+ },
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
- nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
- .def_prop_ro("batch",
- [](const MlirLinalgContractionDimensions &self) {
- return self.batch;
- })
+ nb::class_<PyLinalgContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro(
+ "batch",
+ [](const PyLinalgContractionDimensions &self) { return self.batch; })
.def_prop_ro(
- "m",
- [](const MlirLinalgContractionDimensions &self) { return self.m; })
+ "m", [](const PyLinalgContractionDimensions &self) { return self.m; })
.def_prop_ro(
- "n",
- [](const MlirLinalgContractionDimensions &self) { return self.n; })
- .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+ "n", [](const PyLinalgContractionDimensions &self) { return self.n; })
+ .def_prop_ro("k", [](const PyLinalgContractionDimensions &self) {
return self.k;
});
@@ -82,80 +98,90 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
"infer_contraction_dimensions_from_maps",
- [](std::vector<MlirAffineMap> indexingMaps)
- -> std::optional<MlirLinalgContractionDimensions> {
+ [](std::vector<PyAffineMap> indexingMaps)
+ -> std::optional<PyLinalgContractionDimensions> {
if (indexingMaps.empty())
return std::nullopt;
+ // PyAffineMap converts implicitly to MlirAffineMap; gather raw handles.
+ std::vector<MlirAffineMap> rawMaps(indexingMaps.begin(),
+ indexingMaps.end());
MlirLinalgContractionDimensions dims =
- mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
- indexingMaps.size());
+ mlirLinalgInferContractionDimensionsFromMaps(rawMaps.data(),
+ rawMaps.size());
// Detect "empty" result from invalid input or failed inference.
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return dims;
+ // Copy the C struct into the base so batch/m/n/k keep their order.
+ return PyLinalgContractionDimensions{dims};
},
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
"maps.",
nb::arg("indexing_maps"));
- m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
- "Checks if the given operation is a Linalg convolution operation.",
- nb::arg("op"));
+ m.def(
+ "isa_convolution_op",
+ [](PyOperationBase &op) {
+ return mlirLinalgIsAConvolutionOp(op.getOperation());
+ },
+ "Checks if the given operation is a Linalg convolution operation.",
+ nb::arg("op"));
- nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
- .def_prop_ro("batch",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.batch;
- })
+ nb::class_<PyLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
+ .def_prop_ro(
+ "batch",
+ [](const PyLinalgConvolutionDimensions &self) { return self.batch; })
.def_prop_ro("output_image",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.outputImage;
})
.def_prop_ro("output_channel",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.outputChannel;
})
.def_prop_ro("filter_loop",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.filterLoop;
})
.def_prop_ro("input_channel",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.inputChannel;
})
- .def_prop_ro("depth",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.depth;
- })
+ .def_prop_ro(
+ "depth",
+ [](const PyLinalgConvolutionDimensions &self) { return self.depth; })
.def_prop_ro("strides",
- [](const MlirLinalgConvolutionDimensions &self) {
+ [](const PyLinalgConvolutionDimensions &self) {
return self.strides;
})
- .def_prop_ro("dilations",
- [](const MlirLinalgConvolutionDimensions &self) {
- return self.dilations;
- });
+ .def_prop_ro("dilations", [](const PyLinalgConvolutionDimensions &self) {
+ return self.dilations;
+ });
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op"));
m.def(
"get_indexing_maps",
- [](MlirOperation op) -> std::optional<MlirAttribute> {
- MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
+ [](PyOperationBase &op) -> std::optional<PyAttribute> {
+ MlirAttribute attr =
+ mlirLinalgGetIndexingMapsAttribute(op.getOperation());
if (mlirAttributeIsNull(attr))
return std::nullopt;
- return attr;
+ return PyArrayAttribute(op.getContext(), attr);
},
"Returns the indexing_maps attribute for a linalg op.");
}
+} // namespace linalg
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirDialectsLinalg, m) {
m.doc() = "MLIR Linalg dialect.";
- populateDialectLinalgSubmodule(m);
+ mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::linalg::populateDialectLinalgSubmodule(m);
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 003a06b16daac..a7cece87a11d6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -924,7 +924,25 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}"
IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
)
- set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}")
+ set(_mlir_typestub_gen_targets "${NB_STUBGEN_CUSTOM_TARGET}")
+
+ get_target_property(_linalg_extension_srcs MLIRPythonExtension.Dialects.Linalg.Nanobind INTERFACE_SOURCES)
+ mlir_generate_type_stubs(
+ MODULE_NAME ${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs._mlirDialectsLinalg
+ DEPENDS_TARGETS
+ # You need both _mlir and _mlirDialectsLinalg because dialect modules import _mlir when loaded
+ # (so _mlir needs to be built before calling stubgen).
+ MLIRPythonModules.extension._mlir.dso
+ MLIRPythonModules.extension._mlirDialectsLinalg.dso
+ # You need this one so that ir.py is "built", because mlir._mlir_libs.__init__.py imports mlir.ir in _site_initialize.
+ MLIRPythonModules.sources.MLIRPythonSources.Core.Python
+ OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs"
+ OUTPUTS _mlirDialectsLinalg.pyi
+ DEPENDS_TARGET_SRC_DEPS "${_linalg_extension_srcs}"
+ IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.."
+ )
+ list(APPEND _mlir_typestub_gen_targets "${NB_STUBGEN_CUSTOM_TARGET}")
+ list(APPEND _core_type_stub_sources "_mlirDialectsLinalg.pyi")
list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/")
# Note, we do not do ADD_TO_PARENT here so that the type stubs are not associated (as mlir_DEPENDS) with
@@ -943,7 +961,7 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
get_target_property(_test_extension_srcs MLIRPythonTestSources.PythonTestExtensionNanobind INTERFACE_SOURCES)
mlir_generate_type_stubs(
# This is the FQN path because dialect modules import _mlir when loaded. See above.
- MODULE_NAME mlir._mlir_libs._mlirPythonTestNanobind
+ MODULE_NAME ${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs._mlirPythonTestNanobind
DEPENDS_TARGETS
# You need both _mlir and _mlirPythonTestNanobind because dialect modules import _mlir when loaded
# (so _mlir needs to be built before calling stubgen).
@@ -994,7 +1012,7 @@ add_mlir_python_modules(MLIRPythonModules
MLIRPythonCAPI
)
if(MLIR_PYTHON_STUBGEN_ENABLED)
- add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}")
+ add_dependencies(MLIRPythonModules ${_mlir_typestub_gen_targets})
if(MLIR_INCLUDE_TESTS)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
endif()
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a296b5e814b4b..e9754749352b1 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -22,14 +22,16 @@
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
-
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace python_test {
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
return mlirTypeIsARankedTensor(t) &&
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-struct PyTestType
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
+struct PyTestType : PyConcreteType<PyTestType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPythonTestTestTypeGetTypeID;
@@ -39,8 +41,7 @@ struct PyTestType
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context) {
+ [](DefaultingPyMlirContext context) {
return PyTestType(context->getRef(),
mlirPythonTestTestTypeGet(context.get()->get()));
},
@@ -49,9 +50,7 @@ struct PyTestType
};
struct PyTestIntegerRankedTensorType
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
- PyTestIntegerRankedTensorType,
- mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+ : PyConcreteType<PyTestIntegerRankedTensorType, PyRankedTensorType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirRankedTensorTypeGetTypeID;
@@ -62,8 +61,7 @@ struct PyTestIntegerRankedTensorType
c.def_static(
"get",
[](std::vector<int64_t> shape, unsigned width,
- mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- ctx) {
+ DefaultingPyMlirContext ctx) {
MlirAttribute encoding = mlirAttributeGetNull();
return PyTestIntegerRankedTensorType(
ctx->getRef(),
@@ -76,9 +74,7 @@ struct PyTestIntegerRankedTensorType
}
};
-struct PyTestTensorValue
- : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
- PyTestTensorValue> {
+struct PyTestTensorValue : PyConcreteValue<PyTestTensorValue> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsAPythonTestTestTensorValue;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -91,9 +87,7 @@ struct PyTestTensorValue
}
};
-class PyTestAttr
- : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
- PyTestAttr> {
+class PyTestAttr : public PyConcreteAttribute<PyTestAttr> {
public:
static constexpr IsAFunctionTy isaFunction =
mlirAttributeIsAPythonTestTestAttribute;
@@ -105,21 +99,23 @@ class PyTestAttr
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context) {
+ [](DefaultingPyMlirContext context) {
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
context.get()->get()));
},
nb::arg("context").none() = nb::none());
}
};
+} // namespace python_test
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
NB_MODULE(_mlirPythonTestNanobind, m) {
+ using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
m.def(
"register_python_test_dialect",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- context,
- bool load) {
+ [](DefaultingPyMlirContext context, bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleRegisterDialect(pythonTestDialect,
@@ -144,14 +140,14 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
m.def(
"test_diagnostics_with_errors_and_notes",
- [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
- ctx) {
+ [](DefaultingPyMlirContext ctx) {
mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
throw nb::value_error(handler.takeMessage().c_str());
},
nb::arg("context").none() = nb::none());
+ using namespace python_test;
PyTestAttr::bind(m);
PyTestType::bind(m);
PyTestIntegerRankedTensorType::bind(m);
More information about the Mlir-commits
mailing list