[Mlir-commits] [mlir] [MLIR][Python] Improve and test python type inference (PR #175431)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 11 03:55:01 PST 2026
https://github.com/MaPePeR created https://github.com/llvm/llvm-project/pull/175431
I improved the return types for `OpAttributeMap.{keys,values,items}` and added a test case that uses [ty](https://github.com/astral-sh/ty) and the `reveal_type` function to check the inferred return types of various `mlir.ir` objects and their properties/functions.
>From 220550df5588943ebcec75c27c720ac87cb50285 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Sun, 11 Jan 2026 11:38:35 +0000
Subject: [PATCH 1/2] [MLIR][Python] Add testcase that checks type inference
using `ty`.
---
mlir/python/requirements.txt | 1 +
mlir/test/python/ir/test_type_inference.py | 168 +++++++++++++++++++++
2 files changed, 169 insertions(+)
create mode 100644 mlir/test/python/ir/test_type_inference.py
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index d7b89d5ce6b92..735f265f92742 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -2,6 +2,7 @@
nanobind>=2.9, <3.0
PyYAML>=5.4.0, <=6.0.1
typing_extensions>=4.12.2
+ty>=0.0.11
# RUN dependencies
numpy>=1.19.5, <=2.1.2
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
diff --git a/mlir/test/python/ir/test_type_inference.py b/mlir/test/python/ir/test_type_inference.py
new file mode 100644
index 0000000000000..18433dce7c0a7
--- /dev/null
+++ b/mlir/test/python/ir/test_type_inference.py
@@ -0,0 +1,168 @@
+# RUN: %PYTHON -m ty check --output-format concise %s | FileCheck %s
+
+from mlir.ir import (
+    Module,
+    Context,
+    Region,
+    Block,
+    OpView,
+    Attribute,
+    OpOperandList,
+    RegionSequence,
+    OpResultList,
+    OpResult,
+    BlockList,
+    BlockArgumentList,
+    BlockPredecessors,
+    BlockSuccessors,
+    OperationList,
+    OpAttributeMap,
+    NamedAttribute,
+)
+from typing_extensions import reveal_type  # typing.reveal_type needs Python >= 3.11
+
+# This file is only type-checked by `ty` and never executed; it tests type inference.
+
+module = Module.create()
+# CHECK: Revealed type: `Module`
+reveal_type(module)
+
+if True:  # Tests for Module
+    # CHECK: Revealed type: `Block`
+    reveal_type(module.body)
+
+    # CHECK: Revealed type: `Operation`
+    reveal_type(module.operation)
+
+if True:  # Tests for Block
+    block: Block = module.body
+
+    for block_iter in block:
+        # CHECK: Revealed type: `OpView`
+        reveal_type(block_iter)
+
+    # CHECK: Revealed type: `BlockArgumentList`
+    reveal_type(block.arguments)
+
+    if True:  # Tests for BlockArgumentList
+        block_arguments_list: BlockArgumentList = block.arguments
+
+        # CHECK: Revealed type: `BlockArgument`
+        reveal_type(block_arguments_list[0])
+        for block_arguments_iter in block_arguments_list:
+            # CHECK: Revealed type: `BlockArgument`
+            reveal_type(block_arguments_iter)
+
+    # CHECK: Revealed type: `OperationList`
+    reveal_type(block.operations)
+    if True:  # Tests for OperationList
+        operation_list: OperationList = block.operations
+
+        # CHECK: Revealed type: `OpView`
+        reveal_type(operation_list[0])
+        for operation_list_iter in operation_list:
+            # CHECK: Revealed type: `OpView`
+            reveal_type(operation_list_iter)
+
+    # CHECK: Revealed type: `BlockPredecessors`
+    reveal_type(block.predecessors)
+    if True:  # Tests for BlockPredecessors
+        block_predecessors: BlockPredecessors = block.predecessors
+
+        # CHECK: Revealed type: `Block`
+        reveal_type(block_predecessors[0])
+        for block_predecessors_iter in block_predecessors:
+            # CHECK: Revealed type: `Block`
+            reveal_type(block_predecessors_iter)
+
+    # CHECK: Revealed type: `BlockSuccessors`
+    reveal_type(block.successors)
+    if True:  # Tests for BlockSuccessors
+        block_successors: BlockSuccessors = block.successors
+
+        # CHECK: Revealed type: `Block`
+        reveal_type(block_successors[0])
+        for block_successors_iter in block_successors:
+            # CHECK: Revealed type: `Block`
+            reveal_type(block_successors_iter)
+
+if True:  # Tests for OpView
+    opview: OpView = module.body.operations[0]
+
+    # CHECK: Revealed type: `OpAttributeMap`
+    reveal_type(opview.attributes)
+    if True:  # Tests for OpAttributeMap
+        attribute_map: OpAttributeMap = opview.attributes
+
+        # CHECK: Revealed type: `NamedAttribute`
+        reveal_type(attribute_map[0])
+
+        # CHECK: Revealed type: `Attribute`
+        reveal_type(attribute_map["str"])
+
+        # Imprecise hint: `get` really returns the passed default, whatever its type.
+        # CHECK: Revealed type: `Attribute | None`
+        reveal_type(attribute_map.get("str"))
+
+        # CHECK: Revealed type: `list[tuple[str, Attribute]]`
+        reveal_type(attribute_map.items())
+
+        # CHECK: Revealed type: `list[str]`
+        reveal_type(attribute_map.keys())
+
+        # CHECK: Revealed type: `list[Attribute]`
+        reveal_type(attribute_map.values())
+
+    # CHECK: Revealed type: `OpOperandList`
+    reveal_type(opview.operands)
+    if True:  # Tests for OpOperandList
+        op_operands_list: OpOperandList = opview.operands
+
+        # CHECK: Revealed type: `Value`
+        reveal_type(op_operands_list[0])
+        for op_operands_list_iter in op_operands_list:
+            # CHECK: Revealed type: `Value`
+            reveal_type(op_operands_list_iter)
+
+    # CHECK: Revealed type: `RegionSequence`
+    reveal_type(opview.regions)
+    if True:  # Tests for RegionSequence
+        region_sequence: RegionSequence = opview.regions
+
+        # CHECK: Revealed type: `Region`
+        reveal_type(region_sequence[0])
+        for region_sequence_iter in region_sequence:
+            # CHECK: Revealed type: `Region`
+            reveal_type(region_sequence_iter)
+
+    # CHECK: Revealed type: `OpResultList`
+    reveal_type(opview.results)
+    if True:  # Tests for OpResultList
+        result_list: OpResultList = opview.results
+
+        # CHECK: Revealed type: `OpResult`
+        reveal_type(result_list[0])
+        for result_list_iter in result_list:
+            # CHECK: Revealed type: `OpResult`
+            reveal_type(result_list_iter)
+
+    # CHECK: Revealed type: `OpResult`
+    reveal_type(opview.result)
+
+if True:  # Tests for Region
+    region: Region = module.body.operations[0].regions[0]
+
+    for region_iter in region:
+        # CHECK: Revealed type: `Block`
+        reveal_type(region_iter)
+
+    # CHECK: Revealed type: `BlockList`
+    reveal_type(region.blocks)
+    if True:  # Tests for BlockList
+        blocklist: BlockList = region.blocks
+
+        # CHECK: Revealed type: `Block`
+        reveal_type(blocklist[0])
+        for blocklist_iter in blocklist:
+            # CHECK: Revealed type: `Block`
+            reveal_type(blocklist_iter)
>From d4738798aab005aeed20b6ad4d21e4262b3c1f55 Mon Sep 17 00:00:00 2001
From: MaPePeR <MaPePeR at users.noreply.github.com>
Date: Sun, 11 Jan 2026 11:39:19 +0000
Subject: [PATCH 2/2] [MLIR][Python] Set return type for
`OpAttributeMap.{keys,values,items}`
---
mlir/lib/Bindings/Python/IRCore.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f04f0b6271630..8dd2f297b0aee 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2476,7 +2476,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Iterates over attribute names.")
.def(
"keys",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self) -> nb::typed<nb::list, std::string> {
nb::list out;
PyOpAttributeMap::forEachAttr(
self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2487,7 +2487,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Returns a list of attribute names.")
.def(
"values",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
nb::list out;
PyOpAttributeMap::forEachAttr(
self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
@@ -2499,7 +2499,9 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
"Returns a list of attribute values.")
.def(
"items",
- [](PyOpAttributeMap &self) {
+ [](PyOpAttributeMap &self)
+ -> nb::typed<nb::list,
+ nb::typed<nb::tuple, std::string, PyAttribute>> {
nb::list out;
PyOpAttributeMap::forEachAttr(
self.operation->get(),
More information about the Mlir-commits
mailing list