[Mlir-commits] [mlir] [MLIR] [Python] More improvements to type annotations (PR #188468)
Sergei Lebedev
llvmlistbot at llvm.org
Wed Mar 25 04:56:50 PDT 2026
https://github.com/superbobry created https://github.com/llvm/llvm-project/pull/188468
* `mlir.ir` now exports `_OperationBase`. It is handy to use when both `Operation` and `OpView` are accepted.
* Added type arguments where they were missing, e.g. `list[ir.Attribute]` instead of just `list`.
* Changed `OpView.build_generic` and `OpView.parse` to return `Self` instead of the supertype `Type`.
* Changed the bindings generator to emit a parameterized `OpResult` when the exact type is available.
From 75813f7b2a30c4863fec6d654875059a891859aa Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Wed, 25 Mar 2026 11:52:48 +0000
Subject: [PATCH] [MLIR] [Python] More improvements to type annotations
* `mlir.ir` now exports `_OperationBase`. It is handy to use when both
`Operation` and `OpView` are accepted.
* Added type arguments where they were missing, e.g. `list[ir.Attribute]`
instead of just `list`.
* Changed `OpView.build_generic` and `OpView.parse` to return `Self`
instead of the supertype `Type`.
* Changed the bindings generator to emit a parameterized `OpResult` when the
exact type is available.
---
mlir/lib/Bindings/Python/IRAffine.cpp | 17 +++++++------
mlir/lib/Bindings/Python/IRAttributes.cpp | 6 +++--
mlir/lib/Bindings/Python/IRCore.cpp | 25 +++++++++++++------
mlir/lib/Bindings/Python/IRTypes.cpp | 4 +--
mlir/python/mlir/ir.py | 1 +
mlir/test/mlir-tblgen/op-python-bindings.td | 15 +++++++++++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 4 +++
7 files changed, 53 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 116e20ee834e9..2ec13f10df380 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -733,7 +733,8 @@ void populateIRAffine(nb::module_ &m) {
})
.def_static(
"compress_unused_symbols",
- [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+ [](nb::typed<nb::list, PyAffineMap> affineMaps,
+ DefaultingPyMlirContext context) {
std::vector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
@@ -760,7 +761,8 @@ void populateIRAffine(nb::module_ &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs,
+ [](intptr_t dimCount, intptr_t symbolCount,
+ nb::typed<nb::list, PyAffineExpr> exprs,
DefaultingPyMlirContext context) {
std::vector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr, MlirAffineExpr>(
@@ -927,8 +929,9 @@ void populateIRAffine(nb::module_ &m) {
kDumpDocstring)
.def_static(
"get",
- [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs,
- std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+ [](intptr_t numDims, intptr_t numSymbols,
+ nb::typed<nb::list, PyAffineExpr> exprs, std::vector<bool> eqFlags,
+ DefaultingPyMlirContext context) {
if (exprs.size() != eqFlags.size())
throw nb::value_error(
"Expected the number of constraints to match "
@@ -960,9 +963,9 @@ void populateIRAffine(nb::module_ &m) {
nb::arg("context") = nb::none())
.def(
"get_replaced",
- [](PyIntegerSet &self, const nb::list &dimExprs,
- const nb::list &symbolExprs, intptr_t numResultDims,
- intptr_t numResultSymbols) {
+ [](PyIntegerSet &self, nb::typed<nb::list, PyAffineExpr> dimExprs,
+ nb::typed<nb::list, PyAffineExpr> symbolExprs,
+ intptr_t numResultDims, intptr_t numResultSymbols) {
if (static_cast<intptr_t>(dimExprs.size()) !=
mlirIntegerSetGetNumDims(self))
throw nb::value_error(
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 9f5602cc61b35..5aebfabf5bc18 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -280,7 +280,8 @@ MlirAttribute PyArrayAttribute::getItem(intptr_t i) const {
void PyArrayAttribute::bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](const nb::list &attributes, DefaultingPyMlirContext context) {
+ [](nb::typed<nb::list, PyAttribute> attributes,
+ DefaultingPyMlirContext context) {
std::vector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(nb::len(attributes));
for (auto attribute : attributes) {
@@ -306,7 +307,8 @@ void PyArrayAttribute::bindDerived(ClassTy &c) {
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
- c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) {
+ c.def("__add__", [](PyArrayAttribute arr,
+ nb::typed<nb::list, PyAttribute> extras) {
std::vector<MlirAttribute> attributes;
intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
attributes.reserve(numOldElements + nb::len(extras));
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3f1ee4ce343f..e1c9fadcd75f0 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2489,7 +2489,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"exist.")
.def(
"__iter__",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self) -> nb::typed<nb::iterator, nb::str> {
nb::list keys;
PyOpAttributeMap::forEachAttr(
self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2500,7 +2500,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Iterates over attribute names.")
.def(
"keys",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self) -> nb::typed<nb::list, nb::str> {
nb::list out;
PyOpAttributeMap::forEachAttr(
self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2511,7 +2511,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Returns a list of attribute names.")
.def(
"values",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
nb::list out;
PyOpAttributeMap::forEachAttr(
self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
@@ -2523,7 +2523,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Returns a list of attribute values.")
.def(
"items",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self) -> nb::typed<nb::list, nb::tuple> {
nb::list out;
PyOpAttributeMap::forEachAttr(
self.operation->get(),
@@ -4139,6 +4139,9 @@ void populateIRCore(nb::module_ &m) {
"cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(),
"attributes"_a = nb::none(), "successors"_a = nb::none(),
"regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(),
+ // clang-format off
+ nb::sig("def build_generic(cls, results: Sequence[Type] | None = None, operands: Sequence[Value] | None = None, attributes: dict[str, Attribute] | None = None, successors: Sequence[Block] | None = None, regions: int | None = None, loc: Location | None = None, ip: InsertionPoint | None = None) -> typing.Self"),
+ // clang-format on
"Builds a specific, generated OpView based on class level attributes.");
opViewClass.attr("parse") = classmethod(
[](const nb::object &cls, const std::string &sourceStr,
@@ -4164,6 +4167,9 @@ void populateIRCore(nb::module_ &m) {
},
"cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "",
"context"_a = nb::none(),
+ // clang-format off
+ nb::sig("def parse(cls, source: str, *, source_name: str = '', context: Context | None = None) -> typing.Self"),
+ // clang-format on
"Parses a specific, generated OpView based on class level attributes.");
PyOpAdaptor::bind(m);
@@ -4262,8 +4268,9 @@ void populateIRCore(nb::module_ &m) {
"Returns a forward-optimized sequence of operations.")
.def_static(
"create_at_start",
- [](PyRegion &parent, const nb::sequence &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
+ [](PyRegion &parent, nb::typed<nb::sequence, PyType> pyArgTypes,
+ const std::optional<nb::typed<nb::sequence, PyLocation>>
+ &pyArgLocs) {
parent.checkValid();
MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
mlirRegionInsertOwnedBlock(parent, 0, block);
@@ -4291,7 +4298,8 @@ void populateIRCore(nb::module_ &m) {
.def(
"create_before",
[](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
+ const std::optional<nb::typed<nb::sequence, PyLocation>>
+ &pyArgLocs) {
self.checkValid();
MlirBlock block =
createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
@@ -4305,7 +4313,8 @@ void populateIRCore(nb::module_ &m) {
.def(
"create_after",
[](PyBlock &self, const nb::args &pyArgTypes,
- const std::optional<nb::sequence> &pyArgLocs) {
+ const std::optional<nb::typed<nb::sequence, PyLocation>>
+ &pyArgLocs) {
self.checkValid();
MlirBlock block =
createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index e04ff99b6c5fc..340deb019eb20 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -790,7 +790,7 @@ void PyFunctionType::bindDerived(ClassTy &c) {
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
"inputs",
- [](PyFunctionType &self) {
+ [](PyFunctionType &self) -> nb::typed<nb::list, PyType> {
MlirType t = self;
nb::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
@@ -802,7 +802,7 @@ void PyFunctionType::bindDerived(ClassTy &c) {
"Returns the list of input types in the FunctionType.");
c.def_prop_ro(
"results",
- [](PyFunctionType &self) {
+ [](PyFunctionType &self) -> nb::typed<nb::list, PyType> {
nb::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 3795f5cb2e036..c5b00a561831b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -9,6 +9,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
+from ._mlir_libs._mlir.ir import _OperationBase
from ._mlir_libs._mlir import (
register_type_caster,
register_value_caster,
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 9bfda1ec02303..141cf430f36ef 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -646,6 +646,21 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results
+// CHECK: @_ods_cext.register_operation(_Dialect)
+// CHECK-LABEL: class SingleTypedResultOp(_ods_ir.OpView):
+// CHECK: OPERATION_NAME = "test.single_typed_result"
+def SingleTypedResultOp : TestOp<"single_typed_result"> {
+ // CHECK: @builtins.property
+ // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
+ // CHECK: return self.operation.results[0]
+ let arguments = (ins AnyType:$in);
+ let results = (outs I64:$i64);
+}
+
+// CHECK: def single_typed_result(in_, *, results=None, loc=None, ip=None) ->
+// _ods_ir.OpResult[_ods_ir.IntegerType]: CHECK: return
+// SingleTypedResultOp(in_=in_, results=results, loc=loc, ip=ip).result
+
// CHECK-LABEL: class VariadicAndNormalRegionOp(_ods_ir.OpView):
// CHECK: OPERATION_NAME = "test.variadic_and_normal_region"
def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 84dce9bdf0c6d..39e79e5631479 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -1261,6 +1261,10 @@ static void emitValueBuilder(const Operator &op,
results = ".results";
} else if (op.getNumResults() == 1) {
type = "_ods_ir.OpResult";
+ if (StringRef pythonType =
+ getPythonType(op.getResult(0).constraint.getCppType());
+ !pythonType.empty())
+ type = llvm::formatv("{0}[{1}]", type, pythonType);
results = ".result";
}
os << formatv(valueBuilderTemplate, nameWithoutDialect,
More information about the Mlir-commits
mailing list