[Mlir-commits] [mlir] [MLIR][Python] Use infer_result() as a new field specifier (PR #191849)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 13 09:53:04 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

Currently the signature of `result(..)` is:
```python
result(*, infer_type: bool = False, default_factory: Callable[[], Any] | None = None, kw_only: bool = False) -> Result
```

So when users write `result(infer_type=True)`, type checkers still see `kw_only=False` (from the signature), whereas `kw_only` should actually be `True` (it should follow the value of `infer_type`). Users can write `result(infer_type=True, kw_only=True)`, but that is unnecessarily verbose.

This may introduce an incompatibility once we start using `dataclass_transform`. Currently it's fine because we don't use `dataclass_transform` yet, but once we do, a breaking change may be required.

This PR migrates such use to a new field specifier named `infer_result()`.

---
Full diff: https://github.com/llvm/llvm-project/pull/191849.diff


5 Files Affected:

- (modified) mlir/docs/Bindings/Python.md (+2-2) 
- (modified) mlir/python/mlir/dialects/ext.py (+14-8) 
- (modified) mlir/test/python/dialects/ext.py (+8-8) 
- (modified) mlir/test/python/dialects/transform_op_interface.py (+5-5) 
- (modified) mlir/test/python/integration/dialects/bf.py (+3-3) 


``````````diff
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index c43181859f968..6ab08d435f4b2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1330,12 +1330,12 @@ class MyInt(Dialect, name="myint"):
 
 class ConstantOp(MyInt.Operation, name="constant"):
     value: IntegerAttr
-    cst: Result[IntegerType[32]] = result(infer_type=True)
+    cst: Result[IntegerType[32]] = infer_result()
 
 class AddOp(MyInt.Operation, name="add"):
     lhs: Operand[IntegerType[32]]
     rhs: Operand[IntegerType[32]]
-    res: Result[IntegerType[32]] = result(infer_type=True)
+    res: Result[IntegerType[32]] = infer_result()
 
 # The code below requires an available MLIR context and location.
 
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 89051bf9ec924..c2efa9bb773cc 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -36,6 +36,7 @@
     "Type",
     "Attribute",
     "result",
+    "infer_result",
     "operand",
     "attribute",
 ]
@@ -128,26 +129,31 @@ def param_kind(self):
 
 def result(
     *,
-    infer_type: bool = False,
     default_factory: Optional[Callable[[], Any]] = None,
     kw_only: bool = False,
 ) -> Result:
     """
     A field specifier for `Result` definitions.
     """
-    if infer_type and default_factory:
-        raise ValueError(
-            "a result field cannot have both infer_type and default_factory"
-        )
 
     return FieldSpecifier(
         type_=Result,
-        infer_type=infer_type,
         default_factory=default_factory,
         kw_only=kw_only,
     )
 
 
+def infer_result() -> Result:
+    """
+    A field specifier for `Result` definitions with type inference enabled.
+    """
+
+    return FieldSpecifier(
+        type_=Result,
+        infer_type=True,
+    )
+
+
 def operand(
     *,
     kw_only: bool = False,
@@ -894,12 +900,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32]
+        cst: Result[i32] = infer_result()
 
     class AddOp(MyInt.Operation, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32]
+        res: Result[i32] = infer_result()
     ```
     """
 
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 8dfc74ad29d4f..cfdbeb3362735 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32] = result(infer_type=True)
+        cst: Result[i32] = infer_result()
 
     class AddOp(Operation, dialect=MyInt, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32] = result(infer_type=True)
+        res: Result[i32] = infer_result()
 
     # CHECK: irdl.dialect @myint {
     # CHECK:   irdl.operation @constant {
@@ -575,9 +575,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = result(
-            infer_type=True
-        )
+        arr: Result[
+            Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+        ] = infer_result()
 
     with Context(), Location.unknown():
         TestType.load()
@@ -738,7 +738,7 @@ class TestInvalid(Dialect, name="ext_invalid"):
         class InferTypeBeforePositionalOp(
             TestInvalid.Operation, name="infer_before_pos"
         ):
-            res: Result[IntegerType[32]] = result(infer_type=True)
+            res: Result[IntegerType[32]] = infer_result()
             a: Operand[IntegerType[32]]
 
     except ValueError as e:
@@ -770,7 +770,7 @@ class AssignNoneOnnAttributeOp(
     try:
 
         class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
-            a: Result[IntegerType] = result(infer_type=True)
+            a: Result[IntegerType] = infer_result()
 
     except TypeError as e:
         # CHECK: unsupported type for inferring
@@ -852,7 +852,7 @@ class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"
 
     class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
         a: Result[IntegerType[32]] = result()
-        b: Result[IntegerType[16]] = result(infer_type=True)
+        b: Result[IntegerType[16]] = infer_result()
         c: Result[IntegerType] = result(
             default_factory=lambda: IntegerType.get_signless(8)
         )
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index b9a1c43e111af..b0b416530eccc 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
-    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
+    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
 
     @classmethod
     def attach_interface_impls(cls, ctx=None):
@@ -149,7 +149,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
 # Syntax for an op with one op handle operand and one op handle result.
 class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
     target: ext.Operand[transform.AnyOpType]
-    res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
+    res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
 
 
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -277,9 +277,9 @@ class OpValParamInParamOpValOut(
     val_arg: ext.Operand[transform.AnyValueType]
     param_arg: ext.Operand[transform.AnyParamType]
     # results
-    param_res: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
-    op_res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
-    value_res: ext.Result[transform.AnyValueType[()]] = ext.result(infer_type=True)
+    param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
+    op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
+    value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_result()
 
 
 # CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index e38155dcda973..23632b7f4b9c9 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
 
 class NextOp(BfDialect.Operation, name="next"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = result(infer_type=True)
+    out: Result[PtrType[()]] = infer_result()
 
 
 class PrevOp(BfDialect.Operation, name="prev"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = result(infer_type=True)
+    out: Result[PtrType[()]] = infer_result()
 
 
 class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
 
 class WhileOp(BfDialect.Operation, name="while"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = result(infer_type=True)
+    out: Result[PtrType[()]] = infer_result()
     body: Region
 
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/191849


More information about the Mlir-commits mailing list