[Mlir-commits] [mlir] [MLIR][Python] Make init parameters follow the field definition order (PR #186574)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 14 22:13:03 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

Currently, Python-defined operations automatically generate an `__init__` function to serve as the operation builder. Until now, the parameters of this `__init__` function followed a fairly complex set of rules:

* All result fields were moved to the front to align with other op builders.
* Fields of `Optional` type were automatically moved to the end and turned into keyword-only parameters defaulting to `None`.
* If the types of all result fields could be inferred automatically (and none of them was optional or variadic), all result fields were removed from the parameter list.
* Other than that, the parameter order followed the field definition order.
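To make the old rules concrete, here is a minimal pure-Python sketch that reapplies them to a list of field records. It is modeled on the removed `_generate_init_signature` and `can_infer_types` logic in the diff below; the `Field` record and `old_init_signature` are hypothetical names, and no MLIR bindings are involved.

```python
from dataclasses import dataclass
from inspect import Parameter, Signature


@dataclass
class Field:
    # Hypothetical stand-in for a field definition.
    name: str
    kind: str                    # "operand", "result", "attribute" or "region"
    variadicity: str = "single"  # "single", "optional" or "variadic"
    inferable: bool = False      # can the result type be derived from its constraint?


def old_init_signature(fields: list[Field]) -> Signature:
    results = [f for f in fields if f.kind == "result"]
    # Results disappear only if *every* result is single and inferable.
    can_infer = all(r.inferable and r.variadicity == "single" for r in results)
    # Results are hoisted to the front; regions never appear.
    args = ([] if can_infer else results) + [
        f for f in fields if f.kind not in ("result", "region")
    ]
    params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
    params += [
        Parameter(f.name, Parameter.POSITIONAL_OR_KEYWORD)
        for f in args
        if f.variadicity != "optional"
    ]
    # Optional fields are pushed to the end as keyword-only parameters.
    params += [
        Parameter(f.name, Parameter.KEYWORD_ONLY, default=None)
        for f in args
        if f.variadicity == "optional"
    ]
    params += [
        Parameter("loc", Parameter.KEYWORD_ONLY, default=None),
        Parameter("ip", Parameter.KEYWORD_ONLY, default=None),
    ]
    return Signature(params)


# OptionalOp as it was declared before this patch.
optional_op = [
    Field("a", "operand"),
    Field("b", "operand", "optional"),
    Field("out1", "result"),
    Field("out2", "result", "optional"),
    Field("out3", "result"),
]
print(old_init_signature(optional_op))
# (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
```

Fed the pre-patch declaration order `a, b, out1, out2, out3`, the sketch yields `out1, out3, a` followed by two keyword-only parameters, a parameter list that is hard to predict from the class body alone.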

These rules may seem reasonable, and they have worked well in practice, but they have one major drawback: because the rules are so complicated, users cannot easily tell what the actual `__init__` parameter list will look like while writing code. Users can inspect the signature at runtime via `MyOp.__init__.__signature__`, but that makes for a poor development experience.
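Runtime inspection of this kind works because a generated function can publish an `inspect.Signature` through its `__signature__` attribute. The sketch below shows that pattern in isolation. It assumes a builder shaped like the `init_sig.bind(*args, **kwargs)` code in the diff; `make_init` and `AddOp` are hypothetical names, and the builder only records its arguments instead of creating an MLIR operation.

```python
import inspect
from inspect import Parameter, Signature


def make_init(sig: Signature):
    # The generated function accepts anything and validates against `sig`.
    def __init__(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)  # raises TypeError on a bad call
        bound.apply_defaults()
        self = bound.arguments.pop("self")
        # A real builder would create the operation here; this sketch
        # just records the arguments it was given.
        self.arguments = dict(bound.arguments)

    # Publishing the signature is what makes it visible to inspect and help().
    __init__.__signature__ = sig
    return __init__


class AddOp:  # hypothetical stand-in for a Python-defined operation
    __init__ = make_init(
        Signature(
            [
                Parameter("self", Parameter.POSITIONAL_ONLY),
                Parameter("lhs", Parameter.POSITIONAL_OR_KEYWORD),
                Parameter("rhs", Parameter.POSITIONAL_OR_KEYWORD),
                Parameter("loc", Parameter.KEYWORD_ONLY, default=None),
                Parameter("ip", Parameter.KEYWORD_ONLY, default=None),
            ]
        )
    )


print(inspect.signature(AddOp.__init__))  # (self, /, lhs, rhs, *, loc=None, ip=None)
print(AddOp(1, rhs=2).arguments)          # {'lhs': 1, 'rhs': 2, 'loc': None, 'ip': None}
```

Since the generated `__init__` itself only declares `*args, **kwargs`, the attached `Signature` serves as both the documentation and the validator: `bind` rejects missing or unexpected arguments with a `TypeError`.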

After some offline discussion with Rolf, we decided to replace them with the following rules, which this PR implements:

* The parameters of `__init__` now strictly follow the field definition order.
* By default, all parameters are required. A field becomes a keyword-only parameter (defaulting to `None`) only if it is declared with `= None`.
* By default, no type inference is performed. A result field becomes a keyword-only parameter whose type is inferred when the argument is omitted only if it is declared with `= infer_type()`.

These new rules give users full control over the `__init__` parameter list, making it easy to understand the parameter order and explicitly control optional parameters and type inference. In addition, this makes it much easier in the future to use `dataclass_transform` so that type checkers can understand these automatically generated `__init__` methods, although there are still some issues to resolve at the moment.

NOTE: this is a **breaking change**. Even so, I believe it is worth making, for two main reasons:
- the current number of users of Python-defined dialects is still quite limited;
- this change will bring long-term benefits.

Assisted by Copilot.


---

Patch is 22.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/186574.diff


4 Files Affected:

- (modified) mlir/python/mlir/dialects/ext.py (+108-48) 
- (modified) mlir/test/python/dialects/ext.py (+42-31) 
- (modified) mlir/test/python/dialects/transform_op_interface.py (+8-8) 
- (modified) mlir/test/python/integration/dialects/bf.py (+3-3) 


``````````diff
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 867da6ee96637..57260e09971aa 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -34,6 +34,7 @@
     "Region",
     "Type",
     "Attribute",
+    "infer_type",
 ]
 
 Operand = ir.Value
@@ -42,6 +43,9 @@
 
 
 def construct_instance(origin, args):
+    if not issubclass(origin, ir.Type | ir.Attribute):
+        raise TypeError(f"unsupported type in constraints: {origin}")
+
     # `origin.get` is to construct an instance of MLIR type or attribute.
     return origin.get(
         *(
@@ -109,21 +113,48 @@ def _lower(self, type_) -> ir.Value:
         raise TypeError(f"unsupported type in constraints: {type_}")
 
 
-def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
+@dataclass
+class Marker:
+    infer_type: bool = False
+    default_is_none: bool = False
+
+    def __post_init__(self):
+        if self.infer_type and self.default_is_none:
+            raise ValueError(
+                "a field cannot be marked with both infer_type and default_is_none"
+            )
+
+    def kw_only(self) -> bool:
+        return self.default_is_none or self.infer_type
+
+
+def infer_type() -> Any:
+    """
+    A marker to indicate that the type of a result should be inferred.
+    It can only be used in `Result` definitions.
+    """
+
+    return Marker(infer_type=True)
+
+
+def infer_type_impl(type_) -> Callable[[], ir.Type]:
     """
     A function to infer ir.Type from type annotation.
-    Returns a callable that returns the inferred ir.Type,
-    or None if the type cannot be inferred.
+    Returns a callable that returns the inferred ir.Type.
     We use callables so that MLIR contexts are not required
     while calling this function.
     """
 
     origin = get_origin(type_)
-    if origin and issubclass(origin, ir.Type):
-        return lambda: construct_instance(origin, get_args(type_))
+    if origin and issubclass(origin, ir.Type | ir.Attribute):
+        args = [
+            infer_type_impl(arg) if get_origin(arg) else lambda: arg
+            for arg in get_args(type_)
+        ]
+        return lambda: origin.get(*[arg() for arg in args])
     elif isinstance(type_, TypeVar):
-        return infer_type(type_.__bound__)
-    return None
+        return infer_type_impl(type_.__bound__)
+    raise TypeError(f"unsupported type for inferring: {type_}")
 
 
 @dataclass
@@ -134,9 +165,12 @@ class FieldDef:
 
     name: str
     variadicity: Variadicity
+    constraint: Any
+
+    kw_only: bool = False
 
     @staticmethod
-    def from_type_hint(name, type_) -> "FieldDef":
+    def from_type_hint(name, type_, marker) -> "FieldDef":
         variadicity = Variadicity.single
         if inner := match_optional(type_):
             variadicity = Variadicity.optional
@@ -147,29 +181,66 @@ def from_type_hint(name, type_) -> "FieldDef":
 
         origin = get_origin(type_)
         if origin is ir.OpResult:
-            return ResultDef(name, variadicity, get_args(type_)[0])
+            constraint = get_args(type_)[0]
+            return ResultDef(
+                name,
+                variadicity,
+                constraint,
+                kw_only=marker.kw_only(),
+                infer_type=infer_type_impl(constraint) if marker.infer_type else None,
+            )
         elif origin is ir.Value:
-            return OperandDef(name, variadicity, get_args(type_)[0])
+            return OperandDef(
+                name,
+                variadicity,
+                get_args(type_)[0],
+                kw_only=marker.kw_only(),
+            )
         elif issubclass(origin or type_, ir.Attribute):
             return AttributeDef(name, variadicity, type_)
         elif type_ is ir.Region:
-            return RegionDef(name, variadicity)
-        raise TypeError(f"unsupported type in operation definition: {type_}")
+            return RegionDef(name, variadicity, Any)
+        raise TypeError(
+            f"unsupported type for field '{name}' in operation definition: {type_}"
+        )
 
 
 @dataclass
 class OperandDef(FieldDef):
-    constraint: Any
+    def __post_init__(self):
+        if self.variadicity != Variadicity.optional and self.kw_only:
+            raise ValueError(f"only optional operand can be a keyword parameter")
 
 
 @dataclass
 class ResultDef(FieldDef):
-    constraint: Any
+    infer_type: Callable[[], ir.Type] | None = None
+
+    def __post_init__(self):
+        if (
+            self.variadicity != Variadicity.optional
+            and not self.infer_type
+            and self.kw_only
+        ):
+            raise ValueError(f"only optional result can be a keyword parameter")
+
+        if self.infer_type and self.variadicity != Variadicity.single:
+            raise ValueError(
+                f"type of variadic or optional result '{self.name}' cannot be inferred"
+            )
+
+    def process_type(self, type_):
+        if type_:
+            return type_
+
+        if self.infer_type:
+            return self.infer_type()
+
+        return None
 
 
 @dataclass
 class AttributeDef(FieldDef):
-    constraint: Any
 
     def __post_init__(self):
         if self.variadicity != Variadicity.single:
@@ -267,7 +338,17 @@ def __init_subclass__(
             if hasattr(base, "_fields"):
                 fields.extend(base._fields)
         for key, value in cls.__annotations__.items():
-            field = FieldDef.from_type_hint(key, value)
+            # if the class variable is not defined, we treat it as a default marker;
+            # if it is assigned with `None`, we treat it as a marker with `default_is_none=True`.
+            # e.g. x : int         # default marker
+            #      y : int = None  # marker with default_is_none=True
+            marker = cls.__dict__.get(key, Marker()) or Marker(default_is_none=True)
+            # treat all other values as invalid
+            if not isinstance(marker, Marker):
+                raise TypeError(
+                    f"the field specifier of field '{key}' is not supported"
+                )
+            field = FieldDef.from_type_hint(key, value, marker)
             fields.append(field)
 
         cls._fields = fields
@@ -336,27 +417,17 @@ def _generate_segments(
         return None
 
     @staticmethod
-    def _generate_init_signature(
-        fields: List[FieldDef], can_infer_types: bool
-    ) -> Signature:
-        result_args = (
-            [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
-        )
-        # results are placed at the beginning of the parameter list,
-        # but operands and attributes can appear in any relative order.
-        args = result_args + [
-            i for i in fields if not isinstance(i, ResultDef | RegionDef)
-        ]
-        positional_args = [
-            i.name for i in args if i.variadicity != Variadicity.optional
-        ]
-        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+    def _generate_init_signature(fields: List[FieldDef]) -> Signature:
+        args = [i for i in fields if not isinstance(i, RegionDef)]
 
         params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
-        for i in positional_args:
-            params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
-        for i in optional_args:
-            params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+
+        for i in args:
+            if i.kw_only:
+                params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
+            else:
+                params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+
         params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
         params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
 
@@ -365,15 +436,8 @@ def _generate_init_signature(
     @classmethod
     def _generate_init_method(cls, fields: List[FieldDef]) -> None:
         operands, attrs, results, regions = partition_fields(fields)
-        inferred_types = [infer_type(i.constraint) for i in results]
 
-        # we infer result types only when all result types can be inferred
-        # and all results are single (not optional or variadic)
-        can_infer_types = all(inferred_types) and all(
-            i.variadicity == Variadicity.single for i in results
-        )
-
-        init_sig = cls._generate_init_signature(fields, can_infer_types)
+        init_sig = cls._generate_init_signature(fields)
 
         def __init__(*args, **kwargs):
             bound = init_sig.bind(*args, **kwargs)
@@ -381,11 +445,7 @@ def __init__(*args, **kwargs):
             args = bound.arguments
 
             _operands = [args[operand.name] for operand in operands]
-            _results = (
-                [t() for t in inferred_types]
-                if can_infer_types
-                else [args[result.name] for result in results]
-            )
+            _results = [result.process_type(args[result.name]) for result in results]
             _attributes = dict(
                 (attr.name, args[attr.name])
                 for attr in attrs
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30132f891faec..abb50c21e8a5a 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
 
     class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
-        cst: Result[i32]
+        cst: Result[i32] = infer_type()
 
     class AddOp(Operation, dialect=MyInt, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
-        res: Result[i32]
+        res: Result[i32] = infer_type()
 
     # CHECK: irdl.dialect @myint {
     # CHECK:   irdl.operation @constant {
@@ -88,9 +88,9 @@ class AddOp(Operation, dialect=MyInt, name="add"):
         print(two.value)
         # CHECK: OpResult(%0
         print(two.cst)
-        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+        # CHECK: (self, /, lhs, rhs, *, res=None, loc=None, ip=None)
         print(AddOp.__init__.__signature__)
-        # CHECK: (self, /, value, *, loc=None, ip=None)
+        # CHECK: (self, /, value, *, cst=None, loc=None, ip=None)
         print(ConstantOp.__init__.__signature__)
 
         # CHECK: True
@@ -124,28 +124,28 @@ class ConstraintOp(Test.Operation, name="constraint"):
         y: FloatAttr
 
     class OptionalOp(Test.Operation, name="optional"):
-        a: Operand[i32]
-        b: Optional[Operand[i32]]
         out1: Result[i32]
-        out2: Result[i32] | None
         out3: Result[i32]
+        a: Operand[i32]
+        out2: Result[i32] | None = None
+        b: Optional[Operand[i32]] = None
 
     class Optional2Op(Test.Operation, name="optional2"):
-        a: Optional[Operand[i32]]
-        b: Optional[Result[i32]]
+        b: Optional[Result[i32]] = None
+        a: Optional[Operand[i32]] = None
 
     class VariadicOp(Test.Operation, name="variadic"):
-        a: Operand[i32]
-        b: Optional[Operand[i32]]
-        c: Sequence[Operand[i32]]
         out1: Sequence[Result[i32]]
         out2: Sequence[Result[i32]]
-        out3: Optional[Result[i32]]
         out4: Result[i32]
+        a: Operand[i32]
+        c: Sequence[Operand[i32]]
+        out3: Optional[Result[i32]] = None
+        b: Optional[Operand[i32]] = None
 
     class Variadic2Op(Test.Operation, name="variadic2"):
-        a: Sequence[Operand[i32]]
         b: Sequence[Result[i32]]
+        a: Sequence[Operand[i32]]
 
     class MixedOpBase(Test.Operation):
         out: Result[i32]
@@ -153,9 +153,9 @@ class MixedOpBase(Test.Operation):
 
     class MixedOp(MixedOpBase, name="mixed"):
         in2: IntegerAttr
-        in3: Optional[Operand[i32]]
         in4: IntegerAttr
         in5: Operand[i32]
+        in3: Optional[Operand[i32]] = None
 
     T = TypeVar("T")
     U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
@@ -168,6 +168,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
         in4: Operand[U | V]
         in5: Operand[V]
 
+    class OptionalButNotKeywordOp(Test.Operation, name="optional_but_not_keyword"):
+        a: Operand[i32]
+        b: Optional[Operand[i32]]
+        c: Operand[i32]
+
     # CHECK: irdl.dialect @ext_test {
     # CHECK:   irdl.operation @constraint {
     # CHECK:     %0 = irdl.is i32
@@ -185,7 +190,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
     # CHECK:   irdl.operation @optional {
     # CHECK:     %0 = irdl.is i32
     # CHECK:     irdl.operands(a: %0, b: optional %0)
-    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:     irdl.results(out1: %0, out3: %0, out2: optional %0)
     # CHECK:   }
     # CHECK:   irdl.operation @optional2 {
     # CHECK:     %0 = irdl.is i32
@@ -194,8 +199,8 @@ class TypeVarOp(Test.Operation, name="type_var"):
     # CHECK:   }
     # CHECK:   irdl.operation @variadic {
     # CHECK:     %0 = irdl.is i32
-    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
-    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK:     irdl.operands(a: %0, c: variadic %0, b: optional %0)
+    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out4: %0, out3: optional %0)
     # CHECK:   }
     # CHECK:   irdl.operation @variadic2 {
     # CHECK:     %0 = irdl.is i32
@@ -204,7 +209,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
     # CHECK:   }
     # CHECK:   irdl.operation @mixed {
     # CHECK:     %0 = irdl.is i32
-    # CHECK:     irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK:     irdl.operands(in1: %0, in5: %0, in3: optional %0)
     # CHECK:     %1 = irdl.base "#builtin.integer"
     # CHECK:     %2 = irdl.base "#builtin.integer"
     # CHECK:     irdl.attributes {"in2" = %1, "in4" = %2}
@@ -236,16 +241,20 @@ class TypeVarOp(Test.Operation, name="type_var"):
         print(VariadicOp.__init__.__signature__)
         # CHECK: (self, /, b, a, *, loc=None, ip=None)
         print(Variadic2Op.__init__.__signature__)
-        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+        # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
         print(MixedOp.__init__.__signature__)
+        # CHECK: (self, /, in1, in2, in3, in4, in5, *, loc=None, ip=None)
+        print(TypeVarOp.__init__.__signature__)
+        # CHECK: (self, /, a, b, c, *, loc=None, ip=None)
+        print(OptionalButNotKeywordOp.__init__.__signature__)
 
         # CHECK: None None
         print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
-        # CHECK: [1, 0] [1, 0, 1]
+        # CHECK: [1, 0] [1, 1, 0]
         print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
         # CHECK: [0] [0]
         print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
-        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+        # CHECK: [1, -1, 0] [-1, -1, 1, 0]
         print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
         # CHECK: [-1] [-1]
         print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
@@ -269,7 +278,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
             # CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
             ConstraintOp(ione, fone, ione, fone, iattr, fattr)
 
-            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 1, 0>} : (i32) -> (i32, i32)
             o1 = OptionalOp(i32, i32, ione)
             # CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
             o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
@@ -282,11 +291,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
             # CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
             o6 = Optional2Op(b=i32, a=ione)
 
-            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 2, 0>, resultSegmentSizes = array<i32: 1, 2, 1, 0>} : (i32, i32, i32) -> (i32, i32, i32, i32)
             v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
             # CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
             v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
-            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 1, 0>} : (i32) -> (i32, i32, i32, i32)
             v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
             # CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
             v4 = Variadic2Op([], [])
@@ -295,10 +304,10 @@ class TypeVarOp(Test.Operation, name="type_var"):
             # CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
             v6 = Variadic2Op([i32, i32], [ione])
 
-            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
-            m1 = MixedOp(ione, iattr, iattr, ione)
+            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 0>} : (i32, i32) -> i32
+            m1 = MixedOp(i32, ione, iattr, iattr, ione)
             # CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
-            m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
+            m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
 
         print(module)
         assert module.operation.verify()
@@ -319,7 +328,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
         print(o1.out2)
         # CHECK: 0
         print(o2.out1.result_number)
-        # CHECK: 1
+        # CHECK: 2
         print(o2.out2.result_number)
         # CHECK: None
         print(o3.a)
@@ -371,8 +380,8 @@ class TestRegion(Dialect, name="ext_region"):
         pass
 
     class IfOp(TestRegion.Operation, name="if"):
-        cond: Operand[IntegerType[1]]
         result: Result[Any]
+        cond: Operand[IntegerType[1]]
         then: Region
         else_: Region
 
@@ -547,7 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
         arr: Result[Array]
 
     class MakeArray3Op(TestType.Operation, name="make_array3"):
-        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+        arr: Result[
+            Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+        ] = infer_type()
 
     with Context(), Location.unknown():
         TestType.load()
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index a6e2c6da45322..9bc17a8efdc43 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
-    attr_as_param: ext.Result[transform.AnyParamType[()]]
+    attr_as_param: e...
[truncated]

``````````
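`infer_type_impl` in the diff turns a subscripted annotation into a zero-argument callable, so nothing is constructed (and, per its docstring, no MLIR context is needed) until the builder actually runs. The same recursion can be sketched without MLIR. `FakeType`, `IntType`, `PairType` and `lazy_instance` are hypothetical stand-ins, and `get` plays the role of `origin.get(...)`. Note the `arg=arg` default: a bare `lambda: arg` inside a comprehension closes over the loop variable, so every such lambda would return the last argument.

```python
from types import GenericAlias
from typing import Any, Callable, TypeVar, get_args, get_origin


class FakeType:
    # Hypothetical stand-in for an MLIR type class: subscripting records
    # parameters, and `get` builds an "instance" from them.
    def __class_getitem__(cls, params):
        return GenericAlias(cls, params)

    @classmethod
    def get(cls, *params):
        return (cls.__name__, params)


class IntType(FakeType):
    pass


class PairType(FakeType):
    pass


def lazy_instance(hint) -> Callable[[], Any]:
    origin = get_origin(hint)
    if isinstance(origin, type) and issubclass(origin, FakeType):
        makers = [
            # Nested subscripted hints recurse; plain values are returned as-is.
            # `arg=arg` binds the current value instead of the loop variable.
            lazy_instance(arg) if get_origin(arg) else (lambda arg=arg: arg)
            for arg in get_args(hint)
        ]
        return lambda: origin.get(*(make() for make in makers))
    if isinstance(hint, TypeVar) and hint.__bound__ is not None:
        return lazy_instance(hint.__bound__)
    raise TypeError(f"cannot infer a type from {hint!r}")


make = lazy_instance(PairType[IntType[32], 3])  # nothing is built yet
print(make())  # ('PairType', (('IntType', (32,)), 3))
```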

</details>
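Segment sizes are listed per field in declaration order, so reordering fields in the tests also reorders `_ODS_OPERAND_SEGMENTS`/`_ODS_RESULT_SEGMENTS` and the printed `operandSegmentSizes`/`resultSegmentSizes`, which accounts for many of the changed CHECK lines. The sketch below reproduces those values. The encoding (1 for a single field, 0 for optional, -1 for variadic, and `None` when every field is single) is inferred from the test expectations rather than taken from `_generate_segments`, whose body is not shown; `Variadicity` is redefined locally, and `static_segments` and `segment_sizes` are hypothetical names.

```python
from enum import Enum


class Variadicity(Enum):
    # Local mirror of the single/optional/variadic distinction in the patch.
    single = "single"
    optional = "optional"
    variadic = "variadic"


# Static encoding as it appears in the test expectations.
STATIC_CODE = {Variadicity.single: 1, Variadicity.optional: 0, Variadicity.variadic: -1}


def static_segments(kinds):
    # No segment information is needed when every field is single.
    if all(k is Variadicity.single for k in kinds):
        return None
    return [STATIC_CODE[k] for k in kinds]


def segment_sizes(kinds, values):
    # Per-call sizes, as printed in operandSegmentSizes/resultSegmentSizes.
    sizes = []
    for kind, value in zip(kinds, values):
        if kind is Variadicity.single:
            sizes.append(1)
        elif kind is Variadicity.optional:
            sizes.append(0 if value is None else 1)
        else:
            sizes.append(len(value))
    return sizes


S, O, V = Variadicity.single, Variadicity.optional, Variadicity.variadic
# VariadicOp operands in the new order: a (single), c (variadic), b (optional).
print(static_segments([S, V, O]))                         # [1, -1, 0]
print(segment_sizes([S, V, O], ["x", ["y", "z"], None]))  # [1, 2, 0]
```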


https://github.com/llvm/llvm-project/pull/186574

