[Mlir-commits] [mlir] [MLIR][Python] Use infer_type() as a new field specifier (PR #191849)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 13 09:50:17 PDT 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/191849

From 5266a8930e13dcf99517eb66ca14c459edf8d25d Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 14 Apr 2026 00:46:06 +0800
Subject: [PATCH 1/2] [MLIR][Python] Use infer_type() as a new field specifier

---
 mlir/docs/Bindings/Python.md                  |  4 ++--
 mlir/python/mlir/dialects/ext.py              | 22 ++++++++++++-------
 mlir/test/python/dialects/ext.py              | 14 ++++++------
 .../python/dialects/transform_op_interface.py | 10 ++++-----
 mlir/test/python/integration/dialects/bf.py   |  6 ++---
 5 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index c43181859f968..6ab08d435f4b2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1330,12 +1330,12 @@ class MyInt(Dialect, name="myint"):
 
 class ConstantOp(MyInt.Operation, name="constant"):
     value: IntegerAttr
-    cst: Result[IntegerType[32]] = result(infer_type=True)
+    cst: Result[IntegerType[32]] = infer_result()
 
 class AddOp(MyInt.Operation, name="add"):
     lhs: Operand[IntegerType[32]]
     rhs: Operand[IntegerType[32]]
-    res: Result[IntegerType[32]] = result(infer_type=True)
+    res: Result[IntegerType[32]] = infer_result()
 
 # The code below requires an available MLIR context and location.
 
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 89051bf9ec924..c2efa9bb773cc 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -36,6 +36,7 @@
     "Type",
     "Attribute",
     "result",
+    "infer_result",
     "operand",
     "attribute",
 ]
@@ -128,26 +129,31 @@ def param_kind(self):
 
 def result(
     *,
-    infer_type: bool = False,
     default_factory: Optional[Callable[[], Any]] = None,
     kw_only: bool = False,
 ) -> Result:
     """
     A field specifier for `Result` definitions.
     """
-    if infer_type and default_factory:
-        raise ValueError(
-            "a result field cannot have both infer_type and default_factory"
-        )
 
     return FieldSpecifier(
         type_=Result,
-        infer_type=infer_type,
         default_factory=default_factory,
         kw_only=kw_only,
     )
 
 
+def infer_result() -> Result:
+    """
+    A field specifier for `Result` definitions with type inference enabled.
+    """
+
+    return FieldSpecifier(
+        type_=Result,
+        infer_type=True,
+    )
+
+
 def operand(
     *,
     kw_only: bool = False,
@@ -894,12 +900,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32]
+        cst: Result[i32] = infer_result()
 
     class AddOp(MyInt.Operation, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32]
+        res: Result[i32] = infer_result()
     ```
     """
 
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 8dfc74ad29d4f..733f18de402c8 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32] = result(infer_type=True)
+        cst: Result[i32] = infer_result()
 
     class AddOp(Operation, dialect=MyInt, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32] = result(infer_type=True)
+        res: Result[i32] = infer_result()
 
     # CHECK: irdl.dialect @myint {
     # CHECK:   irdl.operation @constant {
@@ -575,8 +575,8 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = result(
-            infer_type=True
+        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
+            infer_result()
         )
 
     with Context(), Location.unknown():
@@ -738,7 +738,7 @@ class TestInvalid(Dialect, name="ext_invalid"):
         class InferTypeBeforePositionalOp(
             TestInvalid.Operation, name="infer_before_pos"
         ):
-            res: Result[IntegerType[32]] = result(infer_type=True)
+            res: Result[IntegerType[32]] = infer_result()
             a: Operand[IntegerType[32]]
 
     except ValueError as e:
@@ -770,7 +770,7 @@ class AssignNoneOnnAttributeOp(
     try:
 
         class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
-            a: Result[IntegerType] = result(infer_type=True)
+            a: Result[IntegerType] = infer_result()
 
     except TypeError as e:
         # CHECK: unsupported type for inferring
@@ -852,7 +852,7 @@ class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"
 
     class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
         a: Result[IntegerType[32]] = result()
-        b: Result[IntegerType[16]] = result(infer_type=True)
+        b: Result[IntegerType[16]] = infer_result()
         c: Result[IntegerType] = result(
             default_factory=lambda: IntegerType.get_signless(8)
         )
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index b9a1c43e111af..b0b416530eccc 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
-    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
+    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
 
     @classmethod
     def attach_interface_impls(cls, ctx=None):
@@ -149,7 +149,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
 # Syntax for an op with one op handle operand and one op handle result.
 class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
     target: ext.Operand[transform.AnyOpType]
-    res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
+    res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
 
 
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -277,9 +277,9 @@ class OpValParamInParamOpValOut(
     val_arg: ext.Operand[transform.AnyValueType]
     param_arg: ext.Operand[transform.AnyParamType]
     # results
-    param_res: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
-    op_res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
-    value_res: ext.Result[transform.AnyValueType[()]] = ext.result(infer_type=True)
+    param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
+    op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
+    value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_result()
 
 
 # CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index e38155dcda973..23632b7f4b9c9 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
 
 class NextOp(BfDialect.Operation, name="next"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = result(infer_type=True)
+    out: Result[PtrType[()]] = infer_result()
 
 
 class PrevOp(BfDialect.Operation, name="prev"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = result(infer_type=True)
+    out: Result[PtrType[()]] = infer_result()
 
 
 class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
 
 class WhileOp(BfDialect.Operation, name="while"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = result(infer_type=True)
+    out: Result[PtrType[()]] = infer_result()
     body: Region
 
 

From 41b98cc7e04b5868b48c921eedcf11bffd896917 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 14 Apr 2026 00:50:04 +0800
Subject: [PATCH 2/2] format

---
 mlir/test/python/dialects/ext.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 733f18de402c8..cfdbeb3362735 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -575,9 +575,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
-            infer_result()
-        )
+        arr: Result[
+            Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+        ] = infer_result()
 
     with Context(), Location.unknown():
         TestType.load()



More information about the Mlir-commits mailing list