[Mlir-commits] [mlir] [mlir][Python] generate type stubs for dialect extensions (PR #175403)

Maksim Levental llvmlistbot at llvm.org
Sat Jan 10 19:00:02 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175403

>From 9b03a3cab7b1cc3c3e5bc99ba5983461267a1063 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sat, 10 Jan 2026 18:11:42 -0800
Subject: [PATCH] [mlir][Python] generate type stubs for dialect extensions

---
 mlir/lib/Bindings/Python/DialectLinalg.cpp    | 132 +++++++++++-------
 mlir/python/CMakeLists.txt                    |  24 +++-
 .../python/lib/PythonTestModuleNanobind.cpp   |  42 +++---
 3 files changed, 119 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 0b079b404d42d..29ce86fb00fb9 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -1,4 +1,4 @@
-//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
+//===- DialectLinalg.cpp - Nanobind module for Linalg dialect API support -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,16 +8,27 @@
 
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/IRAttributes.h"
+#include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace mlir::python::nanobind_adaptors;
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace linalg {
+
+// Member-less wrappers of the C structs; nanobind binds these types below.
+struct PyLinalgContractionDimensions : MlirLinalgContractionDimensions {};
+
+struct PyLinalgConvolutionDimensions : MlirLinalgConvolutionDimensions {};
 
-static std::optional<MlirLinalgContractionDimensions>
-InferContractionDimensions(MlirOperation op) {
+static std::optional<PyLinalgContractionDimensions>
+InferContractionDimensions(PyOperationBase &op) {
   MlirLinalgContractionDimensions dims =
-      mlirLinalgInferContractionDimensions(op);
+      mlirLinalgInferContractionDimensions(op.getOperation());
 
   // Detect "empty" result. This occurs when `op` is not a contraction op,
   // or when `linalg::inferContractionDims` fails.
@@ -25,13 +36,13 @@ InferContractionDimensions(MlirOperation op) {
       mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
     return std::nullopt;
   }
-  return dims;
+  return PyLinalgContractionDimensions{dims}; // Copy the C struct as-is.
 }
 
-static std::optional<MlirLinalgConvolutionDimensions>
-InferConvolutionDimensions(MlirOperation op) {
+static std::optional<PyLinalgConvolutionDimensions>
+InferConvolutionDimensions(PyOperationBase &op) {
   MlirLinalgConvolutionDimensions dims =
-      mlirLinalgInferConvolutionDimensions(op);
+      mlirLinalgInferConvolutionDimensions(op.getOperation());
 
   // Detect "empty" result. This occurs when `op` is not a convolution op,
   // or when `linalg::inferConvolutionDims` fails.
@@ -45,33 +56,38 @@ InferConvolutionDimensions(MlirOperation op) {
     return std::nullopt;
   }
 
-  return dims;
+  // Initialize the base subobject from `dims` rather than listing fields
+  // positionally, so it cannot drift from the C struct's field order.
+  return PyLinalgConvolutionDimensions{dims};
 }
 
 static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
       "fill_builtin_region",
-      [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
+      [](PyOperationBase &op) {
+        mlirLinalgFillBuiltinNamedOpRegion(op.getOperation());
+      },
       nb::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 
-  m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
-        "Checks if the given operation is a Linalg contraction operation.",
-        nb::arg("op"));
+  m.def(
+      "isa_contraction_op",
+      [](PyOperationBase &op) {
+        return mlirLinalgIsAContractionOp(op.getOperation());
+      },
+      "Checks if the given operation is a Linalg contraction operation.",
+      nb::arg("op"));
 
-  nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
-      .def_prop_ro("batch",
-                   [](const MlirLinalgContractionDimensions &self) {
-                     return self.batch;
-                   })
+  nb::class_<PyLinalgContractionDimensions>(m, "ContractionDimensions")
+      .def_prop_ro(
+          "batch",
+          [](const PyLinalgContractionDimensions &self) { return self.batch; })
       .def_prop_ro(
-          "m",
-          [](const MlirLinalgContractionDimensions &self) { return self.m; })
+          "m", [](const PyLinalgContractionDimensions &self) { return self.m; })
       .def_prop_ro(
-          "n",
-          [](const MlirLinalgContractionDimensions &self) { return self.n; })
-      .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+          "n", [](const PyLinalgContractionDimensions &self) { return self.n; })
+      .def_prop_ro("k", [](const PyLinalgContractionDimensions &self) {
         return self.k;
       });
 
@@ -82,80 +98,92 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
 
   m.def(
       "infer_contraction_dimensions_from_maps",
-      [](std::vector<MlirAffineMap> indexingMaps)
-          -> std::optional<MlirLinalgContractionDimensions> {
+      [](std::vector<PyAffineMap> indexingMaps)
+          -> std::optional<PyLinalgContractionDimensions> {
         if (indexingMaps.empty())
           return std::nullopt;
 
+        // PyAffineMap converts implicitly to MlirAffineMap for the C API.
+        std::vector<MlirAffineMap> mlirMaps(indexingMaps.begin(),
+                                            indexingMaps.end());
         MlirLinalgContractionDimensions dims =
-            mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
-                                                         indexingMaps.size());
+            mlirLinalgInferContractionDimensionsFromMaps(mlirMaps.data(),
+                                                         mlirMaps.size());
 
         // Detect "empty" result from invalid input or failed inference.
         if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
             mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
           return std::nullopt;
         }
-        return dims;
+        // Initialize the base from `dims` so batch/m/n/k keep their C order.
+        return PyLinalgContractionDimensions{dims};
       },
       "Infers contraction dimensions (batch/m/n/k) from a list of affine "
       "maps.",
       nb::arg("indexing_maps"));
 
-  m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
-        "Checks if the given operation is a Linalg convolution operation.",
-        nb::arg("op"));
+  m.def(
+      "isa_convolution_op",
+      [](PyOperationBase &op) {
+        return mlirLinalgIsAConvolutionOp(op.getOperation());
+      },
+      "Checks if the given operation is a Linalg convolution operation.",
+      nb::arg("op"));
 
-  nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
-      .def_prop_ro("batch",
-                   [](const MlirLinalgConvolutionDimensions &self) {
-                     return self.batch;
-                   })
+  nb::class_<PyLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
+      .def_prop_ro(
+          "batch",
+          [](const PyLinalgConvolutionDimensions &self) { return self.batch; })
       .def_prop_ro("output_image",
-                   [](const MlirLinalgConvolutionDimensions &self) {
+                   [](const PyLinalgConvolutionDimensions &self) {
                      return self.outputImage;
                    })
       .def_prop_ro("output_channel",
-                   [](const MlirLinalgConvolutionDimensions &self) {
+                   [](const PyLinalgConvolutionDimensions &self) {
                      return self.outputChannel;
                    })
       .def_prop_ro("filter_loop",
-                   [](const MlirLinalgConvolutionDimensions &self) {
+                   [](const PyLinalgConvolutionDimensions &self) {
                      return self.filterLoop;
                    })
       .def_prop_ro("input_channel",
-                   [](const MlirLinalgConvolutionDimensions &self) {
+                   [](const PyLinalgConvolutionDimensions &self) {
                      return self.inputChannel;
                    })
-      .def_prop_ro("depth",
-                   [](const MlirLinalgConvolutionDimensions &self) {
-                     return self.depth;
-                   })
+      .def_prop_ro(
+          "depth",
+          [](const PyLinalgConvolutionDimensions &self) { return self.depth; })
       .def_prop_ro("strides",
-                   [](const MlirLinalgConvolutionDimensions &self) {
+                   [](const PyLinalgConvolutionDimensions &self) {
                      return self.strides;
                    })
-      .def_prop_ro("dilations",
-                   [](const MlirLinalgConvolutionDimensions &self) {
-                     return self.dilations;
-                   });
+      .def_prop_ro("dilations", [](const PyLinalgConvolutionDimensions &self) {
+        return self.dilations;
+      });
 
   m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
         "Infers convolution dimensions", nb::arg("op"));
 
   m.def(
       "get_indexing_maps",
-      [](MlirOperation op) -> std::optional<MlirAttribute> {
-        MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
+      [](PyOperationBase &op) -> std::optional<PyArrayAttribute> {
+        MlirAttribute attr =
+            mlirLinalgGetIndexingMapsAttribute(op.getOperation());
         if (mlirAttributeIsNull(attr))
           return std::nullopt;
-        return attr;
+        return PyArrayAttribute(op.getOperation().getContext(), attr);
       },
-      "Returns the indexing_maps attribute for a linalg op.");
+      "Returns the indexing_maps attribute for a linalg op.", nb::arg("op"));
 }
+} // namespace linalg
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirDialectsLinalg, m) {
   m.doc() = "MLIR Linalg dialect.";
 
-  populateDialectLinalgSubmodule(m);
+  // Use the configurable domain macro instead of hard-coding `mlir`.
+  mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::linalg::
+      populateDialectLinalgSubmodule(m);
 }
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 003a06b16daac..a7cece87a11d6 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -924,7 +924,25 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
     DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}"
     IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs"
   )
-  set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}")
+  set(_mlir_typestub_gen_targets "${NB_STUBGEN_CUSTOM_TARGET}")
+
+  get_target_property(_linalg_extension_srcs MLIRPythonExtension.Dialects.Linalg.Nanobind INTERFACE_SOURCES)
+  mlir_generate_type_stubs(
+    MODULE_NAME ${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs._mlirDialectsLinalg
+    DEPENDS_TARGETS
+      # You need both _mlir and _mlirDialectsLinalg because dialect modules import _mlir when loaded
+      # (so _mlir needs to be built before calling stubgen).
+      MLIRPythonModules.extension._mlir.dso
+      MLIRPythonModules.extension._mlirDialectsLinalg.dso
+      # You need this one so that ir.py is "built", because mlir._mlir_libs.__init__.py imports mlir.ir in _site_initialize.
+      MLIRPythonModules.sources.MLIRPythonSources.Core.Python
+    OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs"
+    OUTPUTS _mlirDialectsLinalg.pyi
+    DEPENDS_TARGET_SRC_DEPS "${_linalg_extension_srcs}"
+    IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.." # presumably the package's parent dir, since MODULE_NAME is fully qualified — confirm
+  )
+  list(APPEND _mlir_typestub_gen_targets "${NB_STUBGEN_CUSTOM_TARGET}")
+  list(APPEND _core_type_stub_sources "_mlirDialectsLinalg.pyi") # gets the "_mlir_libs/" prefix from the TRANSFORM below
 
   list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/")
   # Note, we do not do ADD_TO_PARENT here so that the type stubs are not associated (as mlir_DEPENDS) with
@@ -943,7 +961,7 @@ if(MLIR_PYTHON_STUBGEN_ENABLED)
     get_target_property(_test_extension_srcs MLIRPythonTestSources.PythonTestExtensionNanobind INTERFACE_SOURCES)
     mlir_generate_type_stubs(
       # This is the FQN path because dialect modules import _mlir when loaded. See above.
-      MODULE_NAME mlir._mlir_libs._mlirPythonTestNanobind
+      MODULE_NAME ${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs._mlirPythonTestNanobind
       DEPENDS_TARGETS
         # You need both _mlir and _mlirPythonTestNanobind because dialect modules import _mlir when loaded
         # (so _mlir needs to be built before calling stubgen).
@@ -994,7 +1012,7 @@ add_mlir_python_modules(MLIRPythonModules
     MLIRPythonCAPI
 )
 if(MLIR_PYTHON_STUBGEN_ENABLED)
-  add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}")
+  add_dependencies(MLIRPythonModules ${_mlir_typestub_gen_targets}) # unquoted: expands the list into separate targets
   if(MLIR_INCLUDE_TESTS)
     add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
   endif()
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index a296b5e814b4b..e9754749352b1 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -22,14 +22,16 @@
 
 namespace nb = nanobind;
 using namespace mlir::python::nanobind_adaptors;
-
+namespace mlir {
+namespace python {
+namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+namespace python_test {
 static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
   return mlirTypeIsARankedTensor(t) &&
          mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
 }
 
-struct PyTestType
-    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
+struct PyTestType : PyConcreteType<PyTestType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
       mlirPythonTestTestTypeGetTypeID;
@@ -39,8 +41,7 @@ struct PyTestType
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
-               context) {
+        [](DefaultingPyMlirContext context) {
           return PyTestType(context->getRef(),
                             mlirPythonTestTestTypeGet(context.get()->get()));
         },
@@ -49,9 +50,7 @@ struct PyTestType
 };
 
 struct PyTestIntegerRankedTensorType
-    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
-          PyTestIntegerRankedTensorType,
-          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+    : PyConcreteType<PyTestIntegerRankedTensorType, PyRankedTensorType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
       mlirRankedTensorTypeGetTypeID;
@@ -62,8 +61,7 @@ struct PyTestIntegerRankedTensorType
     c.def_static(
         "get",
         [](std::vector<int64_t> shape, unsigned width,
-           mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
-               ctx) {
+           DefaultingPyMlirContext ctx) {
           MlirAttribute encoding = mlirAttributeGetNull();
           return PyTestIntegerRankedTensorType(
               ctx->getRef(),
@@ -76,9 +74,7 @@ struct PyTestIntegerRankedTensorType
   }
 };
 
-struct PyTestTensorValue
-    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
-          PyTestTensorValue> {
+struct PyTestTensorValue : PyConcreteValue<PyTestTensorValue> {
   static constexpr IsAFunctionTy isaFunction =
       mlirTypeIsAPythonTestTestTensorValue;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -91,9 +87,7 @@ struct PyTestTensorValue
   }
 };
 
-class PyTestAttr
-    : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
-          PyTestAttr> {
+class PyTestAttr : public PyConcreteAttribute<PyTestAttr> {
 public:
   static constexpr IsAFunctionTy isaFunction =
       mlirAttributeIsAPythonTestTestAttribute;
@@ -105,21 +99,23 @@ class PyTestAttr
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
-               context) {
+        [](DefaultingPyMlirContext context) {
           return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
                                                    context.get()->get()));
         },
         nb::arg("context").none() = nb::none());
   }
 };
+} // namespace python_test
+} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
+} // namespace python
+} // namespace mlir
 
 NB_MODULE(_mlirPythonTestNanobind, m) {
+  using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
   m.def(
       "register_python_test_dialect",
-      [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
-             context,
-         bool load) {
+      [](DefaultingPyMlirContext context, bool load) {
         MlirDialectHandle pythonTestDialect =
             mlirGetDialectHandle__python_test__();
         mlirDialectHandleRegisterDialect(pythonTestDialect,
@@ -144,14 +140,14 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
 
   m.def(
       "test_diagnostics_with_errors_and_notes",
-      [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
-             ctx) {
+      [](DefaultingPyMlirContext ctx) {
         mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
         mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
         throw nb::value_error(handler.takeMessage().c_str());
       },
       nb::arg("context").none() = nb::none());
 
+  using namespace python_test; // Visible via the directive at the top.
   PyTestAttr::bind(m);
   PyTestType::bind(m);
   PyTestIntegerRankedTensorType::bind(m);



More information about the Mlir-commits mailing list