[Mlir-commits] [mlir] [MLIR][Python] Use infer_result() as a new field specifier (PR #191849)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 13 09:50:17 PDT 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/191849
From 5266a8930e13dcf99517eb66ca14c459edf8d25d Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 14 Apr 2026 00:46:06 +0800
Subject: [PATCH 1/2] [MLIR][Python] Use infer_result() as a new field specifier
---
mlir/docs/Bindings/Python.md | 4 ++--
mlir/python/mlir/dialects/ext.py | 22 ++++++++++++-------
mlir/test/python/dialects/ext.py | 14 ++++++------
.../python/dialects/transform_op_interface.py | 10 ++++-----
mlir/test/python/integration/dialects/bf.py | 6 ++---
5 files changed, 31 insertions(+), 25 deletions(-)
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index c43181859f968..6ab08d435f4b2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1330,12 +1330,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[IntegerType[32]] = result(infer_type=True)
+ cst: Result[IntegerType[32]] = infer_result()
class AddOp(MyInt.Operation, name="add"):
lhs: Operand[IntegerType[32]]
rhs: Operand[IntegerType[32]]
- res: Result[IntegerType[32]] = result(infer_type=True)
+ res: Result[IntegerType[32]] = infer_result()
# The code below requires an available MLIR context and location.
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 89051bf9ec924..c2efa9bb773cc 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -36,6 +36,7 @@
"Type",
"Attribute",
"result",
+ "infer_result",
"operand",
"attribute",
]
@@ -128,26 +129,31 @@ def param_kind(self):
def result(
*,
- infer_type: bool = False,
default_factory: Optional[Callable[[], Any]] = None,
kw_only: bool = False,
) -> Result:
"""
A field specifier for `Result` definitions.
"""
- if infer_type and default_factory:
- raise ValueError(
- "a result field cannot have both infer_type and default_factory"
- )
return FieldSpecifier(
type_=Result,
- infer_type=infer_type,
default_factory=default_factory,
kw_only=kw_only,
)
+def infer_result() -> Result:
+ """
+ A field specifier for `Result` definitions with type inference enabled.
+ """
+
+ return FieldSpecifier(
+ type_=Result,
+ infer_type=True,
+ )
+
+
def operand(
*,
kw_only: bool = False,
@@ -894,12 +900,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32]
+ cst: Result[i32] = infer_result()
class AddOp(MyInt.Operation, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32]
+ res: Result[i32] = infer_result()
```
"""
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 8dfc74ad29d4f..733f18de402c8 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32] = result(infer_type=True)
+ cst: Result[i32] = infer_result()
class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32] = result(infer_type=True)
+ res: Result[i32] = infer_result()
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
@@ -575,8 +575,8 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
- arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = result(
- infer_type=True
+ arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
+ infer_result()
)
with Context(), Location.unknown():
@@ -738,7 +738,7 @@ class TestInvalid(Dialect, name="ext_invalid"):
class InferTypeBeforePositionalOp(
TestInvalid.Operation, name="infer_before_pos"
):
- res: Result[IntegerType[32]] = result(infer_type=True)
+ res: Result[IntegerType[32]] = infer_result()
a: Operand[IntegerType[32]]
except ValueError as e:
@@ -770,7 +770,7 @@ class AssignNoneOnnAttributeOp(
try:
class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
- a: Result[IntegerType] = result(infer_type=True)
+ a: Result[IntegerType] = infer_result()
except TypeError as e:
# CHECK: unsupported type for inferring
@@ -852,7 +852,7 @@ class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"
class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
a: Result[IntegerType[32]] = result()
- b: Result[IntegerType[16]] = result(infer_type=True)
+ b: Result[IntegerType[16]] = infer_result()
c: Result[IntegerType] = result(
default_factory=lambda: IntegerType.get_signless(8)
)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index b9a1c43e111af..b0b416530eccc 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
- attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
+ attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
@classmethod
def attach_interface_impls(cls, ctx=None):
@@ -149,7 +149,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
# Syntax for an op with one op handle operand and one op handle result.
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
target: ext.Operand[transform.AnyOpType]
- res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
+ res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -277,9 +277,9 @@ class OpValParamInParamOpValOut(
val_arg: ext.Operand[transform.AnyValueType]
param_arg: ext.Operand[transform.AnyParamType]
# results
- param_res: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
- op_res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
- value_res: ext.Result[transform.AnyValueType[()]] = ext.result(infer_type=True)
+ param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
+ op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
+ value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_result()
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index e38155dcda973..23632b7f4b9c9 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
class NextOp(BfDialect.Operation, name="next"):
in_: Operand[PtrType]
- out: Result[PtrType[()]] = result(infer_type=True)
+ out: Result[PtrType[()]] = infer_result()
class PrevOp(BfDialect.Operation, name="prev"):
in_: Operand[PtrType]
- out: Result[PtrType[()]] = result(infer_type=True)
+ out: Result[PtrType[()]] = infer_result()
class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
class WhileOp(BfDialect.Operation, name="while"):
in_: Operand[PtrType]
- out: Result[PtrType[()]] = result(infer_type=True)
+ out: Result[PtrType[()]] = infer_result()
body: Region
From 41b98cc7e04b5868b48c921eedcf11bffd896917 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 14 Apr 2026 00:50:04 +0800
Subject: [PATCH 2/2] Reformat MakeArray3Op result annotation in ext.py test
---
mlir/test/python/dialects/ext.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 733f18de402c8..cfdbeb3362735 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -575,9 +575,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
- arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
- infer_result()
- )
+ arr: Result[
+ Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+ ] = infer_result()
with Context(), Location.unknown():
TestType.load()
More information about the Mlir-commits
mailing list