[Mlir-commits] [mlir] [MLIR] [Python] More improvements to type annotations (PR #188468)

Sergei Lebedev llvmlistbot at llvm.org
Wed Mar 25 07:09:50 PDT 2026


https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/188468

>From f48a60ba31e66d8cf2a14c6dec351c70e9a36bd0 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Wed, 25 Mar 2026 11:52:48 +0000
Subject: [PATCH] [MLIR] [Python] More improvements to type annotations

* `mlir.ir` now exports `_OperationBase`. It is handy to use when both
  `Operation` and `OpView` are accepted.
* Added type arguments where they were missing, e.g. `list[ir.Attribute]`
  instead of just `list`.
* Changed `OpView.build_generic` and `OpView.parse` to return `Self`
  instead of the supertype `Type`.
* Changed the bindings generator to emit a parameterized `OpResult` when the
  exact type is available.
---
 mlir/lib/Bindings/Python/IRAffine.cpp         | 17 +++++++-----
 mlir/lib/Bindings/Python/IRAttributes.cpp     |  6 +++--
 mlir/lib/Bindings/Python/IRCore.cpp           | 27 +++++++++++++------
 mlir/lib/Bindings/Python/IRTypes.cpp          |  4 +--
 mlir/python/mlir/ir.py                        |  1 +
 mlir/test/mlir-tblgen/op-python-bindings.td   | 15 +++++++++++
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |  4 +++
 7 files changed, 55 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 116e20ee834e9..2ec13f10df380 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -733,7 +733,8 @@ void populateIRAffine(nb::module_ &m) {
            })
       .def_static(
           "compress_unused_symbols",
-          [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+          [](nb::typed<nb::list, PyAffineMap> affineMaps,
+             DefaultingPyMlirContext context) {
             std::vector<MlirAffineMap> maps;
             pyListToVector<PyAffineMap, MlirAffineMap>(
                 affineMaps, maps, "attempting to create an AffineMap");
@@ -760,7 +761,8 @@ void populateIRAffine(nb::module_ &m) {
           kDumpDocstring)
       .def_static(
           "get",
-          [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs,
+          [](intptr_t dimCount, intptr_t symbolCount,
+             nb::typed<nb::list, PyAffineExpr> exprs,
              DefaultingPyMlirContext context) {
             std::vector<MlirAffineExpr> affineExprs;
             pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -927,8 +929,9 @@ void populateIRAffine(nb::module_ &m) {
           kDumpDocstring)
       .def_static(
           "get",
-          [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs,
-             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+          [](intptr_t numDims, intptr_t numSymbols,
+             nb::typed<nb::list, PyAffineExpr> exprs, std::vector<bool> eqFlags,
+             DefaultingPyMlirContext context) {
             if (exprs.size() != eqFlags.size())
               throw nb::value_error(
                   "Expected the number of constraints to match "
@@ -960,9 +963,9 @@ void populateIRAffine(nb::module_ &m) {
           nb::arg("context") = nb::none())
       .def(
           "get_replaced",
-          [](PyIntegerSet &self, const nb::list &dimExprs,
-             const nb::list &symbolExprs, intptr_t numResultDims,
-             intptr_t numResultSymbols) {
+          [](PyIntegerSet &self, nb::typed<nb::list, PyAffineExpr> dimExprs,
+             nb::typed<nb::list, PyAffineExpr> symbolExprs,
+             intptr_t numResultDims, intptr_t numResultSymbols) {
             if (static_cast<intptr_t>(dimExprs.size()) !=
                 mlirIntegerSetGetNumDims(self))
               throw nb::value_error(
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 9f5602cc61b35..5aebfabf5bc18 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -280,7 +280,8 @@ MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
 void PyArrayAttribute::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
-      [](const nb::list &attributes, DefaultingPyMlirContext context) {
+      [](nb::typed<nb::list, PyAttribute> attributes,
+         DefaultingPyMlirContext context) {
         std::vector<MlirAttribute> mlirAttributes;
         mlirAttributes.reserve(nb::len(attributes));
         for (auto attribute : attributes) {
@@ -306,7 +307,8 @@ void PyArrayAttribute::bindDerived(ClassTy &c) {
       .def("__iter__", [](const PyArrayAttribute &arr) {
         return PyArrayAttributeIterator(arr);
       });
-  c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+  c.def("__add__", [](PyArrayAttribute arr,
+                      nb::typed<nb::list, PyAttribute> extras) {
     std::vector<MlirAttribute> attributes;
     intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
     attributes.reserve(numOldElements + nb::len(extras));
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3f1ee4ce343f..ee36659634fec 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2489,7 +2489,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
            "exist.")
       .def(
           "__iter__",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::iterator, nb::str> {
             nb::list keys;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2500,7 +2500,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Iterates over attribute names.")
       .def(
           "keys",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::list, nb::str> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2511,7 +2511,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of attribute names.")
       .def(
           "values",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
@@ -2523,7 +2523,9 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of attribute values.")
       .def(
           "items",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self)
+              -> nb::typed<nb::list,
+                           nb::typed<nb::tuple, nb::str, PyAttribute>> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(),
@@ -4139,6 +4141,9 @@ void populateIRCore(nb::module_ &m) {
       "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
       "attributes"_a = nb::none(), "successors"_a = nb::none(),
       "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
+      // clang-format off
+      nb::sig("def build_generic(cls, results: Sequence[Type] | None = None, operands: Sequence[Value] | None = None, attributes: dict[str, Attribute] | None = None, successors: Sequence[Block] | None = None, regions: int | None = None, loc: Location | None = None, ip: InsertionPoint | None = None) -> typing.Self"),
+      // clang-format on
       "Builds a specific, generated OpView based on class level attributes.");
   opViewClass.attr("parse") = classmethod(
       [](const nb::object &cls, const std::string &sourceStr,
@@ -4164,6 +4169,9 @@ void populateIRCore(nb::module_ &m) {
       },
       "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
       "context"_a = nb::none(),
+      // clang-format off
+      nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
+      // clang-format on
       "Parses a specific, generated OpView based on class level attributes.");
 
   PyOpAdaptor::bind(m);
@@ -4262,8 +4270,9 @@ void populateIRCore(nb::module_ &m) {
           "Returns a forward-optimized sequence of operations.")
       .def_static(
           "create_at_start",
-          [](PyRegion &parent, const nb::sequence &pyArgTypes,
-             const std::optional<nb::sequence> &pyArgLocs) {
+          [](PyRegion &parent, nb::typed<nb::sequence, PyType> pyArgTypes,
+             const std::optional<nb::typed<nb::sequence, PyLocation>>
+                 &pyArgLocs) {
             parent.checkValid();
             MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
             mlirRegionInsertOwnedBlock(parent, 0, block);
@@ -4291,7 +4300,8 @@ void populateIRCore(nb::module_ &m) {
       .def(
           "create_before",
           [](PyBlock &self, const nb::args &pyArgTypes,
-             const std::optional<nb::sequence> &pyArgLocs) {
+             const std::optional<nb::typed<nb::sequence, PyLocation>>
+                 &pyArgLocs) {
             self.checkValid();
             MlirBlock block =
                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
@@ -4305,7 +4315,8 @@ void populateIRCore(nb::module_ &m) {
       .def(
           "create_after",
           [](PyBlock &self, const nb::args &pyArgTypes,
-             const std::optional<nb::sequence> &pyArgLocs) {
+             const std::optional<nb::typed<nb::sequence, PyLocation>>
+                 &pyArgLocs) {
             self.checkValid();
             MlirBlock block =
                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index e04ff99b6c5fc..340deb019eb20 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -790,7 +790,7 @@ void PyFunctionType::bindDerived(ClassTy &c) {
       "Gets a FunctionType from a list of input and result types");
   c.def_prop_ro(
       "inputs",
-      [](PyFunctionType &self) {
+      [](PyFunctionType &self) -> nb::typed<nb::list, PyType> {
         MlirType t = self;
         nb::list types;
         for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
@@ -802,7 +802,7 @@ void PyFunctionType::bindDerived(ClassTy &c) {
       "Returns the list of input types in the FunctionType.");
   c.def_prop_ro(
       "results",
-      [](PyFunctionType &self) {
+      [](PyFunctionType &self) -> nb::typed<nb::list, PyType> {
         nb::list types;
         for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
              ++i) {
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 3795f5cb2e036..c5b00a561831b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -9,6 +9,7 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
+from ._mlir_libs._mlir.ir import _OperationBase
 from ._mlir_libs._mlir import (
     register_type_caster,
     register_value_caster,
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9bfda1ec02303..141cf430f36ef 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -646,6 +646,21 @@ def SimpleOp : TestOp<"simple"> {
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
 // CHECK:   return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
 
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class SingleTypedResultOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.single_typed_result"
+def SingleTypedResultOp : TestOp<"single_typed_result"> {
+  // CHECK: @builtins.property
+  // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
+  // CHECK:   return self.operation.results[0]
+  let arguments = (ins AnyType:$in);
+  let results = (outs I64:$i64);
+}
+
+// CHECK: def single_typed_result(in_, *, results=None, loc=None, ip=None)
+// CHECK-SAME: -> _ods_ir.OpResult[_ods_ir.IntegerType]:
+// CHECK:   return SingleTypedResultOp(in_=in_, results=results, loc=loc, ip=ip).result
+
 // CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 84dce9bdf0c6d..39e79e5631479 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -1261,6 +1261,10 @@ static void emitValueBuilder(const Operator &op,
       results = ".results";
     } else if (op.getNumResults() == 1) {
       type = "_ods_ir.OpResult";
+      if (StringRef pythonType =
+              getPythonType(op.getResult(0).constraint.getCppType());
+          !pythonType.empty())
+        type = llvm::formatv("{0}[{1}]", type, pythonType);
       results = ".result";
     }
     os << formatv(valueBuilderTemplate, nameWithoutDialect,



More information about the Mlir-commits mailing list