[Mlir-commits] [mlir] 396d638 - [MLIR] [Python] More improvements to type annotations (#188468)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 25 07:18:34 PDT 2026


Author: Sergei Lebedev
Date: 2026-03-25T14:18:29Z
New Revision: 396d63813ccb4e3e9ed9e015bc236a3941630d6e

URL: https://github.com/llvm/llvm-project/commit/396d63813ccb4e3e9ed9e015bc236a3941630d6e
DIFF: https://github.com/llvm/llvm-project/commit/396d63813ccb4e3e9ed9e015bc236a3941630d6e.diff

LOG: [MLIR] [Python] More improvements to type annotations (#188468)

* `mlir.ir` now exports `_OperationBase`. It is handy as a type annotation
`when both `Operation` and `OpView` are accepted.
* Added type arguments where they were missing, e.g.
`list[ir.Attribute]` instead of just `list`.
* Changed `OpView.build_generic` and `OpView.parse` to return `Self`
instead of the supertype `Type`.
* Changed the bindings generator to emit a parameterized `OpResult` (e.g.
`OpResult[IntegerType]`) when the exact result type is known.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/python/mlir/ir.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 116e20ee834e9..2ec13f10df380 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -733,7 +733,8 @@ void populateIRAffine(nb::module_ &m) {
            })
       .def_static(
           "compress_unused_symbols",
-          [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+          [](nb::typed<nb::list, PyAffineMap> affineMaps,
+             DefaultingPyMlirContext context) {
             std::vector<MlirAffineMap> maps;
             pyListToVector<PyAffineMap, MlirAffineMap>(
                 affineMaps, maps, "attempting to create an AffineMap");
@@ -760,7 +761,8 @@ void populateIRAffine(nb::module_ &m) {
           kDumpDocstring)
       .def_static(
           "get",
-          [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs,
+          [](intptr_t dimCount, intptr_t symbolCount,
+             nb::typed<nb::list, PyAffineExpr> exprs,
              DefaultingPyMlirContext context) {
             std::vector<MlirAffineExpr> affineExprs;
             pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -927,8 +929,9 @@ void populateIRAffine(nb::module_ &m) {
           kDumpDocstring)
       .def_static(
           "get",
-          [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs,
-             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+          [](intptr_t numDims, intptr_t numSymbols,
+             nb::typed<nb::list, PyAffineExpr> exprs, std::vector<bool> eqFlags,
+             DefaultingPyMlirContext context) {
             if (exprs.size() != eqFlags.size())
               throw nb::value_error(
                   "Expected the number of constraints to match "
@@ -960,9 +963,9 @@ void populateIRAffine(nb::module_ &m) {
           nb::arg("context") = nb::none())
       .def(
           "get_replaced",
-          [](PyIntegerSet &self, const nb::list &dimExprs,
-             const nb::list &symbolExprs, intptr_t numResultDims,
-             intptr_t numResultSymbols) {
+          [](PyIntegerSet &self, nb::typed<nb::list, PyAffineExpr> dimExprs,
+             nb::typed<nb::list, PyAffineExpr> symbolExprs,
+             intptr_t numResultDims, intptr_t numResultSymbols) {
             if (static_cast<intptr_t>(dimExprs.size()) !=
                 mlirIntegerSetGetNumDims(self))
               throw nb::value_error(

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 9f5602cc61b35..5aebfabf5bc18 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -280,7 +280,8 @@ MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
 void PyArrayAttribute::bindDerived(ClassTy &c) {
   c.def_static(
       "get",
-      [](const nb::list &attributes, DefaultingPyMlirContext context) {
+      [](nb::typed<nb::list, PyAttribute> attributes,
+         DefaultingPyMlirContext context) {
         std::vector<MlirAttribute> mlirAttributes;
         mlirAttributes.reserve(nb::len(attributes));
         for (auto attribute : attributes) {
@@ -306,7 +307,8 @@ void PyArrayAttribute::bindDerived(ClassTy &c) {
       .def("__iter__", [](const PyArrayAttribute &arr) {
         return PyArrayAttributeIterator(arr);
       });
-  c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+  c.def("__add__", [](PyArrayAttribute arr,
+                      nb::typed<nb::list, PyAttribute> extras) {
     std::vector<MlirAttribute> attributes;
     intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
     attributes.reserve(numOldElements + nb::len(extras));

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3f1ee4ce343f..ee36659634fec 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2489,7 +2489,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
            "exist.")
       .def(
           "__iter__",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::iterator, nb::str> {
             nb::list keys;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2500,7 +2500,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Iterates over attribute names.")
       .def(
           "keys",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::list, nb::str> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2511,7 +2511,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of attribute names.")
       .def(
           "values",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
@@ -2523,7 +2523,9 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of attribute values.")
       .def(
           "items",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self)
+              -> nb::typed<nb::list,
+                           nb::typed<nb::tuple, nb::str, PyAttribute>> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(),
@@ -4139,6 +4141,9 @@ void populateIRCore(nb::module_ &m) {
       "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
       "attributes"_a = nb::none(), "successors"_a = nb::none(),
       "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
+      // clang-format off
+      nb::sig("def build_generic(cls, results: Sequence[Type] | None = None, operands: Sequence[Value] | None = None, attributes: dict[str, Attribute] | None = None, successors: Sequence[Block] | None = None, regions: int | None = None, loc: Location | None = None, ip: InsertionPoint | None = None) -> typing.Self"),
+      // clang-format on
       "Builds a specific, generated OpView based on class level attributes.");
   opViewClass.attr("parse") = classmethod(
       [](const nb::object &cls, const std::string &sourceStr,
@@ -4164,6 +4169,9 @@ void populateIRCore(nb::module_ &m) {
       },
       "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
       "context"_a = nb::none(),
+      // clang-format off
+      nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
+      // clang-format on
       "Parses a specific, generated OpView based on class level attributes.");
 
   PyOpAdaptor::bind(m);
@@ -4262,8 +4270,9 @@ void populateIRCore(nb::module_ &m) {
           "Returns a forward-optimized sequence of operations.")
       .def_static(
           "create_at_start",
-          [](PyRegion &parent, const nb::sequence &pyArgTypes,
-             const std::optional<nb::sequence> &pyArgLocs) {
+          [](PyRegion &parent, nb::typed<nb::sequence, PyType> pyArgTypes,
+             const std::optional<nb::typed<nb::sequence, PyLocation>>
+                 &pyArgLocs) {
             parent.checkValid();
             MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
             mlirRegionInsertOwnedBlock(parent, 0, block);
@@ -4291,7 +4300,8 @@ void populateIRCore(nb::module_ &m) {
       .def(
           "create_before",
           [](PyBlock &self, const nb::args &pyArgTypes,
-             const std::optional<nb::sequence> &pyArgLocs) {
+             const std::optional<nb::typed<nb::sequence, PyLocation>>
+                 &pyArgLocs) {
             self.checkValid();
             MlirBlock block =
                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
@@ -4305,7 +4315,8 @@ void populateIRCore(nb::module_ &m) {
       .def(
           "create_after",
           [](PyBlock &self, const nb::args &pyArgTypes,
-             const std::optional<nb::sequence> &pyArgLocs) {
+             const std::optional<nb::typed<nb::sequence, PyLocation>>
+                 &pyArgLocs) {
             self.checkValid();
             MlirBlock block =
                 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index e04ff99b6c5fc..340deb019eb20 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -790,7 +790,7 @@ void PyFunctionType::bindDerived(ClassTy &c) {
       "Gets a FunctionType from a list of input and result types");
   c.def_prop_ro(
       "inputs",
-      [](PyFunctionType &self) {
+      [](PyFunctionType &self) -> nb::typed<nb::list, PyType> {
         MlirType t = self;
         nb::list types;
         for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
@@ -802,7 +802,7 @@ void PyFunctionType::bindDerived(ClassTy &c) {
       "Returns the list of input types in the FunctionType.");
   c.def_prop_ro(
       "results",
-      [](PyFunctionType &self) {
+      [](PyFunctionType &self) -> nb::typed<nb::list, PyType> {
         nb::list types;
         for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
              ++i) {

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 3795f5cb2e036..c5b00a561831b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -9,6 +9,7 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
+from ._mlir_libs._mlir.ir import _OperationBase
 from ._mlir_libs._mlir import (
     register_type_caster,
     register_value_caster,

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9bfda1ec02303..141cf430f36ef 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -646,6 +646,21 @@ def SimpleOp : TestOp<"simple"> {
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
 // CHECK:   return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
 
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class SingleTypedResultOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.single_typed_result"
+def SingleTypedResultOp : TestOp<"single_typed_result"> {
+  // CHECK: @builtins.property
+  // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
+  // CHECK:   return self.operation.results[0]
+  let arguments = (ins AnyType:$in);
+  let results = (outs I64:$i64);
+}
+
+// CHECK: def single_typed_result(in_, *, results=None, loc=None, ip=None) ->
+// _ods_ir.OpResult[_ods_ir.IntegerType]: CHECK:   return
+// SingleTypedResultOp(in_=in_, results=results, loc=loc, ip=ip).result
+
 // CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
 def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 84dce9bdf0c6d..39e79e5631479 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -1261,6 +1261,10 @@ static void emitValueBuilder(const Operator &op,
       results = ".results";
     } else if (op.getNumResults() == 1) {
       type = "_ods_ir.OpResult";
+      if (StringRef pythonType =
+              getPythonType(op.getResult(0).constraint.getCppType());
+          !pythonType.empty())
+        type = llvm::formatv("{0}[{1}]", type, pythonType);
       results = ".result";
     }
     os << formatv(valueBuilderTemplate, nameWithoutDialect,


        


More information about the Mlir-commits mailing list