[Mlir-commits] [mlir] [MLIR][Python] Use infer_type() as a new field specifier (PR #191849)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 13 09:53:04 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
Currently, the signature of `result(...)` is:
```python
result(*, infer_type: bool = False, default_factory: Callable[[], Any] | None = None, kw_only: bool = False) -> Result
```
So when users write `result(infer_type=True)`, type checkers still see `kw_only=False` (taken from the signature), even though the effective `kw_only` should be `True`, because it is meant to follow the value of `infer_type`. Users can write `result(infer_type=True, kw_only=True)`, but that is unnecessarily verbose.
This mismatch would become an incompatibility once we start using `dataclass_transform`. It is harmless today only because `dataclass_transform` is not used yet. Once we adopt it, fixing the mismatch could require a breaking change.
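This PR migrates such uses to a new field specifier named `infer_result()` and removes the `infer_type` parameter from `result()`. (The `infer_type()` in the PR title refers to this new `infer_result()` specifier.)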
---
Full diff: https://github.com/llvm/llvm-project/pull/191849.diff
5 Files Affected:
- (modified) mlir/docs/Bindings/Python.md (+2-2)
- (modified) mlir/python/mlir/dialects/ext.py (+14-8)
- (modified) mlir/test/python/dialects/ext.py (+8-8)
- (modified) mlir/test/python/dialects/transform_op_interface.py (+5-5)
- (modified) mlir/test/python/integration/dialects/bf.py (+3-3)
``````````diff
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index c43181859f968..6ab08d435f4b2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1330,12 +1330,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[IntegerType[32]] = result(infer_type=True)
+ cst: Result[IntegerType[32]] = infer_result()
class AddOp(MyInt.Operation, name="add"):
lhs: Operand[IntegerType[32]]
rhs: Operand[IntegerType[32]]
- res: Result[IntegerType[32]] = result(infer_type=True)
+ res: Result[IntegerType[32]] = infer_result()
# The code below requires an available MLIR context and location.
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 89051bf9ec924..c2efa9bb773cc 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -36,6 +36,7 @@
"Type",
"Attribute",
"result",
+ "infer_result",
"operand",
"attribute",
]
@@ -128,26 +129,31 @@ def param_kind(self):
def result(
*,
- infer_type: bool = False,
default_factory: Optional[Callable[[], Any]] = None,
kw_only: bool = False,
) -> Result:
"""
A field specifier for `Result` definitions.
"""
- if infer_type and default_factory:
- raise ValueError(
- "a result field cannot have both infer_type and default_factory"
- )
return FieldSpecifier(
type_=Result,
- infer_type=infer_type,
default_factory=default_factory,
kw_only=kw_only,
)
+def infer_result() -> Result:
+ """
+ A field specifier for `Result` definitions with type inference enabled.
+ """
+
+ return FieldSpecifier(
+ type_=Result,
+ infer_type=True,
+ )
+
+
def operand(
*,
kw_only: bool = False,
@@ -894,12 +900,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32]
+ cst: Result[i32] = infer_result()
class AddOp(MyInt.Operation, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32]
+ res: Result[i32] = infer_result()
```
"""
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 8dfc74ad29d4f..cfdbeb3362735 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32] = result(infer_type=True)
+ cst: Result[i32] = infer_result()
class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32] = result(infer_type=True)
+ res: Result[i32] = infer_result()
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
@@ -575,9 +575,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
- arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = result(
- infer_type=True
- )
+ arr: Result[
+ Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+ ] = infer_result()
with Context(), Location.unknown():
TestType.load()
@@ -738,7 +738,7 @@ class TestInvalid(Dialect, name="ext_invalid"):
class InferTypeBeforePositionalOp(
TestInvalid.Operation, name="infer_before_pos"
):
- res: Result[IntegerType[32]] = result(infer_type=True)
+ res: Result[IntegerType[32]] = infer_result()
a: Operand[IntegerType[32]]
except ValueError as e:
@@ -770,7 +770,7 @@ class AssignNoneOnnAttributeOp(
try:
class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
- a: Result[IntegerType] = result(infer_type=True)
+ a: Result[IntegerType] = infer_result()
except TypeError as e:
# CHECK: unsupported type for inferring
@@ -852,7 +852,7 @@ class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"
class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
a: Result[IntegerType[32]] = result()
- b: Result[IntegerType[16]] = result(infer_type=True)
+ b: Result[IntegerType[16]] = infer_result()
c: Result[IntegerType] = result(
default_factory=lambda: IntegerType.get_signless(8)
)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index b9a1c43e111af..b0b416530eccc 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
- attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
+ attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
@classmethod
def attach_interface_impls(cls, ctx=None):
@@ -149,7 +149,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
# Syntax for an op with one op handle operand and one op handle result.
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
target: ext.Operand[transform.AnyOpType]
- res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
+ res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -277,9 +277,9 @@ class OpValParamInParamOpValOut(
val_arg: ext.Operand[transform.AnyValueType]
param_arg: ext.Operand[transform.AnyParamType]
# results
- param_res: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
- op_res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
- value_res: ext.Result[transform.AnyValueType[()]] = ext.result(infer_type=True)
+ param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
+ op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
+ value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_result()
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index e38155dcda973..23632b7f4b9c9 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
class NextOp(BfDialect.Operation, name="next"):
in_: Operand[PtrType]
- out: Result[PtrType[()]] = result(infer_type=True)
+ out: Result[PtrType[()]] = infer_result()
class PrevOp(BfDialect.Operation, name="prev"):
in_: Operand[PtrType]
- out: Result[PtrType[()]] = result(infer_type=True)
+ out: Result[PtrType[()]] = infer_result()
class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
class WhileOp(BfDialect.Operation, name="while"):
in_: Operand[PtrType]
- out: Result[PtrType[()]] = result(infer_type=True)
+ out: Result[PtrType[()]] = infer_result()
body: Region
``````````
</details>
https://github.com/llvm/llvm-project/pull/191849
More information about the Mlir-commits
mailing list