[Mlir-commits] [mlir] [MLIR][Python] Make init parameters follow the field definition order (PR #186574)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 19 09:50:40 PDT 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/186574

>From 9c7d601c1836630988faf464fd19ed7eb9c8e1c8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 14 Mar 2026 15:35:48 +0800
Subject: [PATCH 1/5] [MLIR][Python] Make init parameters follow the field
 definition order

---
 mlir/python/mlir/dialects/ext.py              | 152 ++++++++++++------
 mlir/test/python/dialects/ext.py              |  73 +++++----
 .../python/dialects/transform_op_interface.py |  16 +-
 mlir/test/python/integration/dialects/bf.py   |   6 +-
 4 files changed, 156 insertions(+), 91 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 15651a1c4e858..736f3b757f38f 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -17,7 +17,7 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
 from inspect import Parameter, Signature
-from types import UnionType
+from types import UnionType, SimpleNamespace
 from . import irdl
 from ._ods_common import _cext, segmented_accessor
 from .irdl import Variadicity
@@ -34,6 +34,7 @@
     "Region",
     "Type",
     "Attribute",
+    "infer_type",
     "register_dialect",
     "register_operation",
 ]
@@ -59,6 +60,9 @@ def decorator(op_cls: type) -> type:
 
 
 def construct_instance(origin, args):
+    if not issubclass(origin, ir.Type | ir.Attribute):
+        raise TypeError(f"unsupported type in constraints: {origin}")
+
     # `origin.get` is to construct an instance of MLIR type or attribute.
     return origin.get(
         *(
@@ -126,21 +130,48 @@ def _lower(self, type_) -> ir.Value:
         raise TypeError(f"unsupported type in constraints: {type_}")
 
 
-def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
+@dataclass
+class Marker:
+    infer_type: bool = False
+    default_is_none: bool = False
+
+    def __post_init__(self):
+        if self.infer_type and self.default_is_none:
+            raise ValueError(
+                "a field cannot be marked with both infer_type and default_is_none"
+            )
+
+    def kw_only(self) -> bool:
+        return self.default_is_none or self.infer_type
+
+
+def infer_type() -> Any:
+    """
+    A marker to indicate that the type of a result should be inferred.
+    It can only be used in `Result` definitions.
+    """
+
+    return Marker(infer_type=True)
+
+
+def infer_type_impl(type_) -> Callable[[], ir.Type]:
     """
     A function to infer ir.Type from type annotation.
-    Returns a callable that returns the inferred ir.Type,
-    or None if the type cannot be inferred.
+    Returns a callable that returns the inferred ir.Type.
     We use callables so that MLIR contexts are not required
     while calling this function.
     """
 
     origin = get_origin(type_)
-    if origin and issubclass(origin, ir.Type):
-        return lambda: construct_instance(origin, get_args(type_))
+    if origin and issubclass(origin, ir.Type | ir.Attribute):
+        args = [  # bind `arg` per lambda to avoid late binding
+            infer_type_impl(arg) if get_origin(arg) else lambda arg=arg: arg
+            for arg in get_args(type_)
+        ]
+        return lambda: origin.get(*[arg() for arg in args])
     elif isinstance(type_, TypeVar):
-        return infer_type(type_.__bound__)
-    return None
+        return infer_type_impl(type_.__bound__)
+    raise TypeError(f"unsupported type for inferring: {type_}")
 
 
 @dataclass
@@ -151,9 +182,12 @@ class FieldDef:
 
     name: str
     variadicity: Variadicity
+    constraint: Any
+
+    kw_only: bool = False
 
     @staticmethod
-    def from_type_hint(name, type_) -> "FieldDef":
+    def from_type_hint(name, type_, marker) -> "FieldDef":
         variadicity = Variadicity.single
         if inner := match_optional(type_):
             variadicity = Variadicity.optional
@@ -164,29 +198,57 @@ def from_type_hint(name, type_) -> "FieldDef":
 
         origin = get_origin(type_)
         if origin is ir.OpResult:
-            return ResultDef(name, variadicity, get_args(type_)[0])
+            constraint = get_args(type_)[0]
+            return ResultDef(
+                name,
+                variadicity,
+                constraint,
+                kw_only=marker.kw_only(),
+                infer_type=infer_type_impl(constraint) if marker.infer_type else None,
+            )
         elif origin is ir.Value:
-            return OperandDef(name, variadicity, get_args(type_)[0])
+            return OperandDef(
+                name,
+                variadicity,
+                get_args(type_)[0],
+                kw_only=marker.kw_only(),
+            )
         elif issubclass(origin or type_, ir.Attribute):
             return AttributeDef(name, variadicity, type_)
         elif type_ is ir.Region:
-            return RegionDef(name, variadicity)
-        raise TypeError(f"unsupported type in operation definition: {type_}")
+            return RegionDef(name, variadicity, Any)
+        raise TypeError(
+            f"unsupported type for field '{name}' in operation definition: {type_}"
+        )
 
 
 @dataclass
 class OperandDef(FieldDef):
-    constraint: Any
+    def __post_init__(self):
+        if self.variadicity != Variadicity.optional and self.kw_only:
+            raise ValueError(f"only optional operand can be a keyword parameter")
 
 
 @dataclass
 class ResultDef(FieldDef):
-    constraint: Any
+    infer_type: Callable[[], ir.Type] | None = None
+
+    def __post_init__(self):
+        if (
+            self.variadicity != Variadicity.optional
+            and not self.infer_type
+            and self.kw_only
+        ):
+            raise ValueError(f"only optional result can be a keyword parameter")
+
+        if self.infer_type and self.variadicity != Variadicity.single:
+            raise ValueError(
+                f"type of variadic or optional result '{self.name}' cannot be inferred"
+            )
 
 
 @dataclass
 class AttributeDef(FieldDef):
-    constraint: Any
 
     def __post_init__(self):
         if self.variadicity != Variadicity.single:
@@ -284,7 +346,17 @@ def __init_subclass__(
             if hasattr(base, "_fields"):
                 fields.extend(base._fields)
         for key, value in cls.__annotations__.items():
-            field = FieldDef.from_type_hint(key, value)
+            # if the class variable is not defined, we treat it as a default marker;
+            # if it is assigned with `None`, we treat it as a marker with `default_is_none=True`.
+            # e.g. x : int         # default marker
+            #      y : int = None  # marker with default_is_none=True
+            marker = cls.__dict__.get(key, Marker()) or Marker(default_is_none=True)
+            # treat all other values as invalid
+            if not isinstance(marker, Marker):
+                raise TypeError(
+                    f"the field specifier of field '{key}' is not supported"
+                )
+            field = FieldDef.from_type_hint(key, value, marker)
             fields.append(field)
 
         cls._fields = fields
@@ -353,27 +425,17 @@ def _generate_segments(
         return None
 
     @staticmethod
-    def _generate_init_signature(
-        fields: List[FieldDef], can_infer_types: bool
-    ) -> Signature:
-        result_args = (
-            [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
-        )
-        # results are placed at the beginning of the parameter list,
-        # but operands and attributes can appear in any relative order.
-        args = result_args + [
-            i for i in fields if not isinstance(i, ResultDef | RegionDef)
-        ]
-        positional_args = [
-            i.name for i in args if i.variadicity != Variadicity.optional
-        ]
-        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+    def _generate_init_signature(fields: List[FieldDef]) -> Signature:
+        args = [i for i in fields if not isinstance(i, RegionDef)]
 
         params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
-        for i in positional_args:
-            params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
-        for i in optional_args:
-            params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+
+        for i in args:
+            if i.kw_only:
+                params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
+            else:
+                params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+
         params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
         params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
 
@@ -382,15 +444,8 @@ def _generate_init_signature(
     @classmethod
     def _generate_init_method(cls, fields: List[FieldDef]) -> None:
         operands, attrs, results, regions = partition_fields(fields)
-        inferred_types = [infer_type(i.constraint) for i in results]
-
-        # we infer result types only when all result types can be inferred
-        # and all results are single (not optional or variadic)
-        can_infer_types = all(inferred_types) and all(
-            i.variadicity == Variadicity.single for i in results
-        )
 
-        init_sig = cls._generate_init_signature(fields, can_infer_types)
+        init_sig = cls._generate_init_signature(fields)
 
         def __init__(*args, **kwargs):
             bound = init_sig.bind(*args, **kwargs)
@@ -398,11 +453,10 @@ def __init__(*args, **kwargs):
             args = bound.arguments
 
             _operands = [args[operand.name] for operand in operands]
-            _results = (
-                [t() for t in inferred_types]
-                if can_infer_types
-                else [args[result.name] for result in results]
-            )
+            _results = [
+                result.infer_type() if result.infer_type else args[result.name]
+                for result in results
+            ]
             _attributes = dict(
                 (attr.name, args[attr.name])
                 for attr in attrs
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30132f891faec..f1d920e68621e 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32]
+        cst: Result[i32] = infer_type()
 
     class AddOp(Operation, dialect=MyInt, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32]
+        res: Result[i32] = infer_type()
 
     # CHECK: irdl.dialect @myint {
     # CHECK:   irdl.operation @constant {
@@ -88,9 +88,9 @@ class AddOp(Operation, dialect=MyInt, name="add"):
         print(two.value)
         # CHECK: OpResult(%0
         print(two.cst)
-        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+        # CHECK: (self, /, lhs, rhs, *, res=None, loc=None, ip=None)
         print(AddOp.__init__.__signature__)
-        # CHECK: (self, /, value, *, loc=None, ip=None)
+        # CHECK: (self, /, value, *, cst=None, loc=None, ip=None)
         print(ConstantOp.__init__.__signature__)
 
         # CHECK: True
@@ -124,28 +124,28 @@ class ConstraintOp(Test.Operation, name="constraint"):
         y: FloatAttr
 
     class OptionalOp(Test.Operation, name="optional"):
-        a: Operand[i32]
-        b: Optional[Operand[i32]]
         out1: Result[i32]
-        out2: Result[i32] | None
         out3: Result[i32]
+        a: Operand[i32]
+        out2: Result[i32] | None = None
+        b: Optional[Operand[i32]] = None
 
     class Optional2Op(Test.Operation, name="optional2"):
-        a: Optional[Operand[i32]]
-        b: Optional[Result[i32]]
+        b: Optional[Result[i32]] = None
+        a: Optional[Operand[i32]] = None
 
     class VariadicOp(Test.Operation, name="variadic"):
-        a: Operand[i32]
-        b: Optional[Operand[i32]]
-        c: Sequence[Operand[i32]]
         out1: Sequence[Result[i32]]
         out2: Sequence[Result[i32]]
-        out3: Optional[Result[i32]]
         out4: Result[i32]
+        a: Operand[i32]
+        c: Sequence[Operand[i32]]
+        out3: Optional[Result[i32]] = None
+        b: Optional[Operand[i32]] = None
 
     class Variadic2Op(Test.Operation, name="variadic2"):
-        a: Sequence[Operand[i32]]
         b: Sequence[Result[i32]]
+        a: Sequence[Operand[i32]]
 
     class MixedOpBase(Test.Operation):
         out: Result[i32]
@@ -153,9 +153,9 @@ class MixedOpBase(Test.Operation):
 
     class MixedOp(MixedOpBase, name="mixed"):
         in2: IntegerAttr
-        in3: Optional[Operand[i32]]
         in4: IntegerAttr
         in5: Operand[i32]
+        in3: Optional[Operand[i32]] = None
 
     T = TypeVar("T")
     U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
@@ -168,6 +168,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
         in4: Operand[U | V]
         in5: Operand[V]
 
+    class OptionalButNotKeywordOp(Test.Operation, name="optional_but_not_keyword"):
+        a: Operand[i32]
+        b: Optional[Operand[i32]]
+        c: Operand[i32]
+
     # CHECK: irdl.dialect @ext_test {
     # CHECK:   irdl.operation @constraint {
     # CHECK:     %0 = irdl.is i32
@@ -185,7 +190,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
     # CHECK:   irdl.operation @optional {
     # CHECK:     %0 = irdl.is i32
     # CHECK:     irdl.operands(a: %0, b: optional %0)
-    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:     irdl.results(out1: %0, out3: %0, out2: optional %0)
     # CHECK:   }
     # CHECK:   irdl.operation @optional2 {
     # CHECK:     %0 = irdl.is i32
@@ -194,8 +199,8 @@ class TypeVarOp(Test.Operation, name="type_var"):
     # CHECK:   }
     # CHECK:   irdl.operation @variadic {
     # CHECK:     %0 = irdl.is i32
-    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
-    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK:     irdl.operands(a: %0, c: variadic %0, b: optional %0)
+    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out4: %0, out3: optional %0)
     # CHECK:   }
     # CHECK:   irdl.operation @variadic2 {
     # CHECK:     %0 = irdl.is i32
@@ -204,7 +209,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
     # CHECK:   }
     # CHECK:   irdl.operation @mixed {
     # CHECK:     %0 = irdl.is i32
-    # CHECK:     irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK:     irdl.operands(in1: %0, in5: %0, in3: optional %0)
     # CHECK:     %1 = irdl.base "#builtin.integer"
     # CHECK:     %2 = irdl.base "#builtin.integer"
     # CHECK:     irdl.attributes {"in2" = %1, "in4" = %2}
@@ -236,16 +241,20 @@ class TypeVarOp(Test.Operation, name="type_var"):
         print(VariadicOp.__init__.__signature__)
         # CHECK: (self, /, b, a, *, loc=None, ip=None)
         print(Variadic2Op.__init__.__signature__)
-        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+        # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
         print(MixedOp.__init__.__signature__)
+        # CHECK: (self, /, in1, in2, in3, in4, in5, *, loc=None, ip=None)
+        print(TypeVarOp.__init__.__signature__)
+        # CHECK: (self, /, a, b, c, *, loc=None, ip=None)
+        print(OptionalButNotKeywordOp.__init__.__signature__)
 
         # CHECK: None None
         print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
-        # CHECK: [1, 0] [1, 0, 1]
+        # CHECK: [1, 0] [1, 1, 0]
         print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
         # CHECK: [0] [0]
         print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
-        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+        # CHECK: [1, -1, 0] [-1, -1, 1, 0]
         print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
         # CHECK: [-1] [-1]
         print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
@@ -269,7 +278,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
             # CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
             ConstraintOp(ione, fone, ione, fone, iattr, fattr)
 
-            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 1, 0>} : (i32) -> (i32, i32)
             o1 = OptionalOp(i32, i32, ione)
             # CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
             o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
@@ -282,11 +291,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
             # CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
             o6 = Optional2Op(b=i32, a=ione)
 
-            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 2, 0>, resultSegmentSizes = array<i32: 1, 2, 1, 0>} : (i32, i32, i32) -> (i32, i32, i32, i32)
             v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
             # CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
             v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
-            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 1, 0>} : (i32) -> (i32, i32, i32, i32)
             v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
             # CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
             v4 = Variadic2Op([], [])
@@ -295,10 +304,10 @@ class TypeVarOp(Test.Operation, name="type_var"):
             # CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
             v6 = Variadic2Op([i32, i32], [ione])
 
-            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
-            m1 = MixedOp(ione, iattr, iattr, ione)
+            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 0>} : (i32, i32) -> i32
+            m1 = MixedOp(i32, ione, iattr, iattr, ione)
             # CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
-            m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
+            m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
 
         print(module)
         assert module.operation.verify()
@@ -319,7 +328,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
         print(o1.out2)
         # CHECK: 0
         print(o2.out1.result_number)
-        # CHECK: 1
+        # CHECK: 2
         print(o2.out2.result_number)
         # CHECK: None
         print(o3.a)
@@ -371,8 +380,8 @@ class TestRegion(Dialect, name="ext_region"):
         pass
 
     class IfOp(TestRegion.Operation, name="if"):
-        cond: Operand[IntegerType[1]]
         result: Result[Any]
+        cond: Operand[IntegerType[1]]
         then: Region
         else_: Region
 
@@ -547,7 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
+            infer_type()
+        )
 
     with Context(), Location.unknown():
         TestType.load()
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f58e0be13befd..c70082e892371 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -90,7 +90,7 @@ def get_effects(op: ir.Operation, effects):
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
-    attr_as_param: ext.Result[transform.AnyParamType[()]]
+    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_type()
 
     @classmethod
     def attach_interface_impls(cls, ctx=None):
@@ -153,7 +153,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
 @ext.register_operation(MyTransform)
 class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
     target: ext.Operand[transform.AnyOpType]
-    res: ext.Result[transform.AnyOpType[()]]
+    res: ext.Result[transform.AnyOpType[()]] = ext.infer_type()
 
 
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -282,9 +282,9 @@ class OpValParamInParamOpValOut(
     val_arg: ext.Operand[transform.AnyValueType]
     param_arg: ext.Operand[transform.AnyParamType]
     # results
-    param_res: ext.Result[transform.AnyParamType[()]]
-    op_res: ext.Result[transform.AnyOpType[()]]
-    value_res: ext.Result[transform.AnyValueType[()]]
+    param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_type()
+    op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_type()
+    value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_type()
 
 
 # CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@@ -382,12 +382,12 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
 class OpsParamsInValuesParamOut(
     MyTransform.Operation, name="ops_params_in_values_param_out"
 ):
-    # operands
-    ops: Sequence[ext.Operand[transform.AnyOpType]]
-    params: Sequence[ext.Operand[transform.AnyParamType]]
     # results
     values: Sequence[ext.Result[transform.AnyValueType]]
     param: ext.Result[transform.AnyParamType]
+    # operands
+    ops: Sequence[ext.Operand[transform.AnyOpType]]
+    params: Sequence[ext.Operand[transform.AnyParamType]]
 
 
 # CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index cbb815378ddf7..61e25c6c3d872 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
 
 class NextOp(BfDialect.Operation, name="next"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]]
+    out: Result[PtrType[()]] = infer_type()
 
 
 class PrevOp(BfDialect.Operation, name="prev"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]]
+    out: Result[PtrType[()]] = infer_type()
 
 
 class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
 
 class WhileOp(BfDialect.Operation, name="while"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]]
+    out: Result[PtrType[()]] = infer_type()
     body: Region
 
 

>From 78244a7ce5835eb02200c7808c9c473dd6535dbc Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 14 Mar 2026 16:16:26 +0800
Subject: [PATCH 2/5] format

---
 mlir/test/python/dialects/ext.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f1d920e68621e..abb50c21e8a5a 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -556,9 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
-            infer_type()
-        )
+        arr: Result[
+            Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+        ] = infer_type()
 
     with Context(), Location.unknown():
         TestType.load()

>From 0591acc938ad9a1f026bd52ce1153129de7595f9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 15 Mar 2026 12:16:19 +0800
Subject: [PATCH 3/5] remove useless import

---
 mlir/python/mlir/dialects/ext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index f1054489355c0..d658cf93a9d85 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -17,7 +17,7 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
 from inspect import Parameter, Signature
-from types import UnionType, SimpleNamespace
+from types import UnionType
 from . import irdl
 from ._ods_common import _cext, segmented_accessor
 from .irdl import Variadicity

>From fcfc82d182dafb4af73a8c8b8b2b6d986b4cfdca Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 15 Mar 2026 13:11:47 +0800
Subject: [PATCH 4/5] fix result type handling

---
 mlir/python/mlir/dialects/ext.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index d658cf93a9d85..57260e09971aa 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -229,6 +229,15 @@ def __post_init__(self):
                 f"type of variadic or optional result '{self.name}' cannot be inferred"
             )
 
+    def process_type(self, type_):
+        if type_:
+            return type_
+
+        if self.infer_type:
+            return self.infer_type()
+
+        return None
+
 
 @dataclass
 class AttributeDef(FieldDef):
@@ -436,10 +445,7 @@ def __init__(*args, **kwargs):
             args = bound.arguments
 
             _operands = [args[operand.name] for operand in operands]
-            _results = [
-                result.infer_type() if result.infer_type else args[result.name]
-                for result in results
-            ]
+            _results = [result.process_type(args[result.name]) for result in results]
             _attributes = dict(
                 (attr.name, args[attr.name])
                 for attr in attrs

>From 5c24970dbdbd96383c8fb7d67a68843f6508791b Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 20 Mar 2026 00:49:59 +0800
Subject: [PATCH 5/5] replace infer_type with result

---
 mlir/python/mlir/dialects/ext.py              | 37 +++++++++--------
 mlir/test/python/dialects/ext.py              | 40 ++++++++++++++++---
 .../python/dialects/transform_op_interface.py | 10 ++---
 mlir/test/python/integration/dialects/bf.py   |  6 +--
 4 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 57260e09971aa..1900b8c162456 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -34,7 +34,7 @@
     "Region",
     "Type",
     "Attribute",
-    "infer_type",
+    "result",
 ]
 
 Operand = ir.Value
@@ -114,7 +114,7 @@ def _lower(self, type_) -> ir.Value:
 
 
 @dataclass
-class Marker:
+class FieldSpecifier:
     infer_type: bool = False
     default_is_none: bool = False
 
@@ -128,13 +128,12 @@ def kw_only(self) -> bool:
         return self.default_is_none or self.infer_type
 
 
-def infer_type() -> Any:
+def result(*, infer_type: bool = False) -> Any:
     """
-    A marker to indicate that the type of a result should be inferred.
-    It can only be used in `Result` definitions.
+    A field specifier for `Result` definitions.
     """
 
-    return Marker(infer_type=True)
+    return FieldSpecifier(infer_type=infer_type)
 
 
 def infer_type_impl(type_) -> Callable[[], ir.Type]:
@@ -170,7 +169,7 @@ class FieldDef:
     kw_only: bool = False
 
     @staticmethod
-    def from_type_hint(name, type_, marker) -> "FieldDef":
+    def from_type_hint(name, type_, specifier) -> "FieldDef":
         variadicity = Variadicity.single
         if inner := match_optional(type_):
             variadicity = Variadicity.optional
@@ -186,15 +185,17 @@ def from_type_hint(name, type_, marker) -> "FieldDef":
                 name,
                 variadicity,
                 constraint,
-                kw_only=marker.kw_only(),
-                infer_type=infer_type_impl(constraint) if marker.infer_type else None,
+                kw_only=specifier.kw_only(),
+                infer_type=(
+                    infer_type_impl(constraint) if specifier.infer_type else None
+                ),
             )
         elif origin is ir.Value:
             return OperandDef(
                 name,
                 variadicity,
                 get_args(type_)[0],
-                kw_only=marker.kw_only(),
+                kw_only=specifier.kw_only(),
             )
         elif issubclass(origin or type_, ir.Attribute):
             return AttributeDef(name, variadicity, type_)
@@ -338,17 +339,19 @@ def __init_subclass__(
             if hasattr(base, "_fields"):
                 fields.extend(base._fields)
         for key, value in cls.__annotations__.items():
-            # if the class variable is not defined, we treat it as a default marker;
-            # if it is assigned with `None`, we treat it as a marker with `default_is_none=True`.
-            # e.g. x : int         # default marker
-            #      y : int = None  # marker with default_is_none=True
-            marker = cls.__dict__.get(key, Marker()) or Marker(default_is_none=True)
+            # if the class variable is not defined, we treat it as a default specifier;
+            # if it is assigned with `None`, we treat it as a specifier with `default_is_none=True`.
+            # e.g. x : int         # default specifier
+            #      y : int = None  # specifier with default_is_none=True
+            specifier = cls.__dict__.get(key, FieldSpecifier()) or FieldSpecifier(
+                default_is_none=True
+            )
             # treat all other values as invalid
-            if not isinstance(marker, Marker):
+            if not isinstance(specifier, FieldSpecifier):
                 raise TypeError(
                     f"the field specifier of field '{key}' is not supported"
                 )
-            field = FieldDef.from_type_hint(key, value, marker)
+            field = FieldDef.from_type_hint(key, value, specifier)
             fields.append(field)
 
         cls._fields = fields
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index abb50c21e8a5a..a1593c35855ea 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32] = infer_type()
+        cst: Result[i32] = result(infer_type=True)
 
     class AddOp(Operation, dialect=MyInt, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32] = infer_type()
+        res: Result[i32] = result(infer_type=True)
 
     # CHECK: irdl.dialect @myint {
     # CHECK:   irdl.operation @constant {
@@ -556,9 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[
-            Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
-        ] = infer_type()
+        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = result(
+            infer_type=True
+        )
 
     with Context(), Location.unknown():
         TestType.load()
@@ -706,3 +706,33 @@ class Op2(TestAttr.Operation, name="op2"):
         # CHECK: "ext_attr.op1"() {pair = #ext_attr.pair<1 : i32, 2 : i32>} : () -> ()
         # CHECK: "ext_attr.op2"() {pair = #ext_attr.str_pair<"hello", "world">, pair2 = #ext_attr.str_pair<"a", "b">} : () -> ()
         print(module)
+
+
+# CHECK: TEST: testExtDialectWithInvalidOp
+@run
+def testExtDialectWithInvalidOp():
+    class TestInvalid(Dialect, name="ext_invalid"):
+        pass
+
+    try:
+
+        class InferTypeBeforePositionalOp(
+            TestInvalid.Operation, name="infer_before_pos"
+        ):
+            res: Result[IntegerType[32]] = result(infer_type=True)
+            a: Operand[IntegerType[32]]
+
+    except ValueError as e:
+        # CHECK: wrong parameter order
+        print(e)
+
+    try:
+
+        class AssignNoneOnNonOptionalOp(
+            TestInvalid.Operation, name="assign_none_on_non_optional"
+        ):
+            a: Operand[IntegerType[32]] = None
+
+    except ValueError as e:
+        # CHECK: only optional operand can be a keyword parameter
+        print(e)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index 9bc17a8efdc43..b9a1c43e111af 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
-    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_type()
+    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
 
     @classmethod
     def attach_interface_impls(cls, ctx=None):
@@ -149,7 +149,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
 # Syntax for an op with one op handle operand and one op handle result.
 class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
     target: ext.Operand[transform.AnyOpType]
-    res: ext.Result[transform.AnyOpType[()]] = ext.infer_type()
+    res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
 
 
 # CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -277,9 +277,9 @@ class OpValParamInParamOpValOut(
     val_arg: ext.Operand[transform.AnyValueType]
     param_arg: ext.Operand[transform.AnyParamType]
     # results
-    param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_type()
-    op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_type()
-    value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_type()
+    param_res: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
+    op_res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
+    value_res: ext.Result[transform.AnyValueType[()]] = ext.result(infer_type=True)
 
 
 # CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index 61e25c6c3d872..e38155dcda973 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
 
 class NextOp(BfDialect.Operation, name="next"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = infer_type()
+    out: Result[PtrType[()]] = result(infer_type=True)
 
 
 class PrevOp(BfDialect.Operation, name="prev"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = infer_type()
+    out: Result[PtrType[()]] = result(infer_type=True)
 
 
 class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
 
 class WhileOp(BfDialect.Operation, name="while"):
     in_: Operand[PtrType]
-    out: Result[PtrType[()]] = infer_type()
+    out: Result[PtrType[()]] = result(infer_type=True)
     body: Region
 
 



More information about the Mlir-commits mailing list