[Mlir-commits] [mlir] [MLIR][Python] Improve and test python type inference (PR #175431)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 11 03:55:28 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (MaPePeR)

<details>
<summary>Changes</summary>

I improved the return types for `OpAttributeMap.{keys,values,items}` and added a test case that uses [ty](https://github.com/astral-sh/ty) and the `reveal_type` function to check the inferred return types of various `mlir.ir` objects and their properties/methods.

---
Full diff: https://github.com/llvm/llvm-project/pull/175431.diff


3 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+5-3) 
- (modified) mlir/python/requirements.txt (+1) 
- (added) mlir/test/python/ir/test_type_inference.py (+168) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f04f0b6271630..8dd2f297b0aee 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2476,7 +2476,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Iterates over attribute names.")
       .def(
           "keys",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::list, std::string> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef name, MlirAttribute) {
@@ -2487,7 +2487,7 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of attribute names.")
       .def(
           "values",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self) -> nb::typed<nb::list, PyAttribute> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(), [&](MlirStringRef, MlirAttribute attr) {
@@ -2499,7 +2499,9 @@ void PyOpAttributeMap::bind(nb::module_ &m) {
           "Returns a list of attribute values.")
       .def(
           "items",
-          [](PyOpAttributeMap &self) {
+          [](PyOpAttributeMap &self)
+              -> nb::typed<nb::list,
+                           nb::typed<nb::tuple, std::string, PyAttribute>> {
             nb::list out;
             PyOpAttributeMap::forEachAttr(
                 self.operation->get(),
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index d7b89d5ce6b92..735f265f92742 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -2,6 +2,7 @@
 nanobind>=2.9, <3.0
 PyYAML>=5.4.0, <=6.0.1
 typing_extensions>=4.12.2
+ty>=0.0.11
 # RUN dependencies
 numpy>=1.19.5, <=2.1.2
 ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13"   # provides several NumPy dtype extensions, including the bf16
diff --git a/mlir/test/python/ir/test_type_inference.py b/mlir/test/python/ir/test_type_inference.py
new file mode 100644
index 0000000000000..18433dce7c0a7
--- /dev/null
+++ b/mlir/test/python/ir/test_type_inference.py
@@ -0,0 +1,168 @@
+# RUN: %PYTHON -m ty check --output-format concise %s | FileCheck %s
+
+from mlir.ir import (
+    Module,
+    Context,
+    Region,
+    Block,
+    OpView,
+    Attribute,
+    OpOperandList,
+    RegionSequence,
+    OpResultList,
+    OpResult,
+    BlockList,
+    BlockArgumentList,
+    BlockPredecessors,
+    BlockSuccessors,
+    OperationList,
+    OpAttributeMap,
+    NamedAttribute,
+)
+from typing import reveal_type
+
+# This file is not a valid python program. It is only used to test type inference.
+
+module = Module.create()
+# CHECK: Revealed type: `Module`
+reveal_type(module)
+
+if True:  # Tests for Module
+    # CHECK: Revealed type: `Block`
+    reveal_type(module.body)
+
+    # CHECK: Revealed type: `Operation`
+    reveal_type(module.operation)
+
+if True:  # Tests for Block
+    block: Block = module.body
+
+    for block_iter in block:
+        # CHECK: Revealed type: `OpView`
+        reveal_type(block_iter)
+
+    # CHECK: Revealed type: `BlockArgumentList`
+    reveal_type(block.arguments)
+
+    if True:  # Tests for BlockArgumentList
+        block_arguments_list: BlockArgumentList = block.arguments
+
+        # CHECK: Revealed type: `BlockArgument`
+        reveal_type(block.arguments[0])
+        for block_arguments_iter in block.arguments:
+            # CHECK: Revealed type: `BlockArgument`
+            reveal_type(block_arguments_iter)
+
+    # CHECK: Revealed type: `OperationList`
+    reveal_type(block.operations)
+    if True:  # Tests for OperationList
+        operation_list: OperationList = block.operations
+
+        # CHECK: Revealed type: `OpView`
+        reveal_type(operation_list[0])
+        for operation_list_iter in operation_list:
+            # CHECK: Revealed type: `OpView`
+            reveal_type(operation_list_iter)
+
+    # CHECK: Revealed type: `BlockPredecessors`
+    reveal_type(block.predecessors)
+    if True:  # Tests for BlockPredecessors
+        block_predecessors: BlockPredecessors = block.predecessors
+
+        # CHECK: Revealed type: `Block`
+        reveal_type(block_predecessors[0])
+        for block_predecessors_iter in block_predecessors:
+            # CHECK: Revealed type: `Block`
+            reveal_type(block_predecessors_iter)
+
+    # CHECK: Revealed type: `BlockSuccessors`
+    reveal_type(block.successors)
+    if True:  # Tests for BlockSuccessors
+        block_successors: BlockSuccessors = block.successors
+
+        # CHECK: Revealed type: `Block`
+        reveal_type(block_successors[0])
+        for block_successors_iter in block_successors:
+            # CHECK: Revealed type: `Block`
+            reveal_type(block_successors_iter)
+
+if True:  # Tests for OpView
+    opview: OpView = module.body.operations[0]
+
+    # CHECK: Revealed type: `OpAttributeMap`
+    reveal_type(opview.attributes)
+    if True:  # Tests for OpAttributeMap
+        attribue_map: OpAttributeMap = opview.attributes
+
+        # CHECK: Revealed type: `NamedAttribute`
+        reveal_type(attribue_map[0])
+
+        # CHECK: Revealed type: `Attribute`
+        reveal_type(attribue_map["str"])
+
+        # This type hint is a lie, because `get` will also return any other default argument
+        # CHECK: Revealed type: `Attribute | None`
+        reveal_type(attribue_map.get("str"))
+
+        # CHECK: Revealed type: `list[tuple[str, Attribute]]`
+        reveal_type(attribue_map.items())
+
+        # CHECK: Revealed type: `list[str]`
+        reveal_type(attribue_map.keys())
+
+        # CHECK: Revealed type: `list[Attribute]`
+        reveal_type(attribue_map.values())
+
+    # CHECK: Revealed type: `OpOperandList`
+    reveal_type(opview.operands)
+    if True:  # Tests for OpOperandList
+        op_operands_list: OpOperandList = opview.operands
+
+        # CHECK: Revealed type: `Value`
+        reveal_type(op_operands_list[0])
+        for op_operands_list_iter in op_operands_list:
+            # CHECK: Revealed type: `Value`
+            reveal_type(op_operands_list_iter)
+
+    # CHECK: Revealed type: `RegionSequence`
+    reveal_type(opview.regions)
+    if True:  # Tests for RegionSequence
+        region_sequence: RegionSequence = opview.regions
+
+        # CHECK: Revealed type: `Region`
+        reveal_type(region_sequence[0])
+        for regions_sequence_iter in region_sequence:
+            # CHECK: Revealed type: `Region`
+            reveal_type(regions_sequence_iter)
+
+    # CHECK: Revealed type: `OpResultList`
+    reveal_type(opview.results)
+    if True:  # Tests for OpResultList
+        result_list: OpResultList = opview.results
+
+        # CHECK: Revealed type: `OpResult`
+        reveal_type(result_list[0])
+        for result_list_iter in result_list:
+            # CHECK: Revealed type: `OpResult`
+            reveal_type(result_list_iter)
+
+    # CHECK: Revealed type: `OpResult`
+    reveal_type(opview.result)
+
+if True:  # Tests for Region
+    region: Region = module.body.operations[0].regions[0]
+
+    for region_iter in region:
+        # CHECK: Revealed type: `Block`
+        reveal_type(region_iter)
+
+    # CHECK: Revealed type: `BlockList`
+    reveal_type(region.blocks)
+    if True:  # Tests for BlockList
+        blocklist: BlockList = region.blocks
+
+        # CHECK: Revealed type: `Block`
+        reveal_type(blocklist[0])
+        for blocklist_iter in blocklist:
+            # CHECK: Revealed type: `Block`
+            reveal_type(blocklist_iter)

``````````

</details>


https://github.com/llvm/llvm-project/pull/175431


More information about the Mlir-commits mailing list