[Mlir-commits] [mlir] [MLIR][Python] Make init parameters follow the field definition order (PR #186574)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 14 01:10:36 PDT 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/186574
Python-defined operations automatically get a generated `__init__` method that serves as the operation builder. Until now, the parameters of this `__init__` method followed a fairly complex set of rules. For example:
* All result fields were moved to the front to align with other op builders.
* Fields of `Optional` type were automatically moved to the end and turned into keyword-only parameters defaulting to `None`.
* If the types of all result fields could be inferred automatically (and none of the results was optional or variadic), all result fields were removed from the parameter list.
* Other than that, the parameter order followed the field definition order.
These rules may seem reasonable, and they have worked well in practice, but they have one major drawback: users cannot easily tell what the actual `__init__` parameter list will look like when writing code, because the rules are simply too complicated. Users can inspect the signature at runtime via `MyOp.__init__.__signature__`, but that makes for a poor development experience.
After some offline discussion with Rolf, we decided to replace them with the following rules, which this PR implements:
* The parameters of `__init__` now strictly follow the field definition order.
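* By default, all parameters are required (positional-or-keyword). A field becomes a keyword-only parameter defaulting to `None` only if it is declared with `= None`; this is supported for optional operands and optional results.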
* By default, no type inference is performed. A result field participates in type inference only if it is declared with `= infer_type()`, which also makes it a keyword-only parameter; only single (non-optional, non-variadic) results can be inferred.
These new rules give users full control over the `__init__` parameter list, making it easy to understand the parameter order and to explicitly control optional parameters and type inference. In addition, they will make it much easier to adopt `dataclass_transform` in the future, so that type checkers can understand these automatically generated `__init__` methods, although some issues still remain to be resolved.
>From 9c7d601c1836630988faf464fd19ed7eb9c8e1c8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 14 Mar 2026 15:35:48 +0800
Subject: [PATCH] [MLIR][Python] Make init parameters follow the field
definition order
---
mlir/python/mlir/dialects/ext.py | 152 ++++++++++++------
mlir/test/python/dialects/ext.py | 73 +++++----
.../python/dialects/transform_op_interface.py | 16 +-
mlir/test/python/integration/dialects/bf.py | 6 +-
4 files changed, 156 insertions(+), 91 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 15651a1c4e858..736f3b757f38f 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -17,7 +17,7 @@
from collections.abc import Sequence
from dataclasses import dataclass
from inspect import Parameter, Signature
-from types import UnionType
+from types import UnionType, SimpleNamespace
from . import irdl
from ._ods_common import _cext, segmented_accessor
from .irdl import Variadicity
@@ -34,6 +34,7 @@
"Region",
"Type",
"Attribute",
+ "infer_type",
"register_dialect",
"register_operation",
]
@@ -59,6 +60,9 @@ def decorator(op_cls: type) -> type:
def construct_instance(origin, args):
+ if not issubclass(origin, ir.Type | ir.Attribute):
+ raise TypeError(f"unsupported type in constraints: {origin}")
+
# `origin.get` is to construct an instance of MLIR type or attribute.
return origin.get(
*(
@@ -126,21 +130,48 @@ def _lower(self, type_) -> ir.Value:
raise TypeError(f"unsupported type in constraints: {type_}")
-def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
+ at dataclass
+class Marker:
+ infer_type: bool = False
+ default_is_none: bool = False
+
+ def __post_init__(self):
+ if self.infer_type and self.default_is_none:
+ raise ValueError(
+ "a field cannot be marked with both infer_type and default_is_none"
+ )
+
+ def kw_only(self) -> bool:
+ return self.default_is_none or self.infer_type
+
+
+def infer_type() -> Any:
+ """
+ A marker to indicate that the type of a result should be inferred.
+ It can only be used in `Result` definitions.
+ """
+
+ return Marker(infer_type=True)
+
+
+def infer_type_impl(type_) -> Callable[[], ir.Type]:
"""
A function to infer ir.Type from type annotation.
- Returns a callable that returns the inferred ir.Type,
- or None if the type cannot be inferred.
+ Returns a callable that returns the inferred ir.Type.
We use callables so that MLIR contexts are not required
while calling this function.
"""
origin = get_origin(type_)
- if origin and issubclass(origin, ir.Type):
- return lambda: construct_instance(origin, get_args(type_))
+ if origin and issubclass(origin, ir.Type | ir.Attribute):
+ args = [
+ infer_type_impl(arg) if get_origin(arg) else lambda: arg
+ for arg in get_args(type_)
+ ]
+ return lambda: origin.get(*[arg() for arg in args])
elif isinstance(type_, TypeVar):
- return infer_type(type_.__bound__)
- return None
+ return infer_type_impl(type_.__bound__)
+ raise TypeError(f"unsupported type for inferring: {type_}")
@dataclass
@@ -151,9 +182,12 @@ class FieldDef:
name: str
variadicity: Variadicity
+ constraint: Any
+
+ kw_only: bool = False
@staticmethod
- def from_type_hint(name, type_) -> "FieldDef":
+ def from_type_hint(name, type_, marker) -> "FieldDef":
variadicity = Variadicity.single
if inner := match_optional(type_):
variadicity = Variadicity.optional
@@ -164,29 +198,57 @@ def from_type_hint(name, type_) -> "FieldDef":
origin = get_origin(type_)
if origin is ir.OpResult:
- return ResultDef(name, variadicity, get_args(type_)[0])
+ constraint = get_args(type_)[0]
+ return ResultDef(
+ name,
+ variadicity,
+ constraint,
+ kw_only=marker.kw_only(),
+ infer_type=infer_type_impl(constraint) if marker.infer_type else None,
+ )
elif origin is ir.Value:
- return OperandDef(name, variadicity, get_args(type_)[0])
+ return OperandDef(
+ name,
+ variadicity,
+ get_args(type_)[0],
+ kw_only=marker.kw_only(),
+ )
elif issubclass(origin or type_, ir.Attribute):
return AttributeDef(name, variadicity, type_)
elif type_ is ir.Region:
- return RegionDef(name, variadicity)
- raise TypeError(f"unsupported type in operation definition: {type_}")
+ return RegionDef(name, variadicity, Any)
+ raise TypeError(
+ f"unsupported type for field '{name}' in operation definition: {type_}"
+ )
@dataclass
class OperandDef(FieldDef):
- constraint: Any
+ def __post_init__(self):
+ if self.variadicity != Variadicity.optional and self.kw_only:
+ raise ValueError(f"only optional operand can be a keyword parameter")
@dataclass
class ResultDef(FieldDef):
- constraint: Any
+ infer_type: Callable[[], ir.Type] | None = None
+
+ def __post_init__(self):
+ if (
+ self.variadicity != Variadicity.optional
+ and not self.infer_type
+ and self.kw_only
+ ):
+ raise ValueError(f"only optional result can be a keyword parameter")
+
+ if self.infer_type and self.variadicity != Variadicity.single:
+ raise ValueError(
+ f"type of variadic or optional result '{self.name}' cannot be inferred"
+ )
@dataclass
class AttributeDef(FieldDef):
- constraint: Any
def __post_init__(self):
if self.variadicity != Variadicity.single:
@@ -284,7 +346,17 @@ def __init_subclass__(
if hasattr(base, "_fields"):
fields.extend(base._fields)
for key, value in cls.__annotations__.items():
- field = FieldDef.from_type_hint(key, value)
+ # if the class variable is not defined, we treat it as a default marker;
+ # if it is assigned with `None`, we treat it as a marker with `default_is_none=True`.
+ # e.g. x : int # default marker
+ # y : int = None # marker with default_is_none=True
+ marker = cls.__dict__.get(key, Marker()) or Marker(default_is_none=True)
+ # treat all other values as invalid
+ if not isinstance(marker, Marker):
+ raise TypeError(
+ f"the field specifier of field '{key}' is not supported"
+ )
+ field = FieldDef.from_type_hint(key, value, marker)
fields.append(field)
cls._fields = fields
@@ -353,27 +425,17 @@ def _generate_segments(
return None
@staticmethod
- def _generate_init_signature(
- fields: List[FieldDef], can_infer_types: bool
- ) -> Signature:
- result_args = (
- [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
- )
- # results are placed at the beginning of the parameter list,
- # but operands and attributes can appear in any relative order.
- args = result_args + [
- i for i in fields if not isinstance(i, ResultDef | RegionDef)
- ]
- positional_args = [
- i.name for i in args if i.variadicity != Variadicity.optional
- ]
- optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
+ args = [i for i in fields if not isinstance(i, RegionDef)]
params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
- for i in positional_args:
- params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
- for i in optional_args:
- params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+
+ for i in args:
+ if i.kw_only:
+ params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
+ else:
+ params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+
params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
@@ -382,15 +444,8 @@ def _generate_init_signature(
@classmethod
def _generate_init_method(cls, fields: List[FieldDef]) -> None:
operands, attrs, results, regions = partition_fields(fields)
- inferred_types = [infer_type(i.constraint) for i in results]
-
- # we infer result types only when all result types can be inferred
- # and all results are single (not optional or variadic)
- can_infer_types = all(inferred_types) and all(
- i.variadicity == Variadicity.single for i in results
- )
- init_sig = cls._generate_init_signature(fields, can_infer_types)
+ init_sig = cls._generate_init_signature(fields)
def __init__(*args, **kwargs):
bound = init_sig.bind(*args, **kwargs)
@@ -398,11 +453,10 @@ def __init__(*args, **kwargs):
args = bound.arguments
_operands = [args[operand.name] for operand in operands]
- _results = (
- [t() for t in inferred_types]
- if can_infer_types
- else [args[result.name] for result in results]
- )
+ _results = [
+ result.infer_type() if result.infer_type else args[result.name]
+ for result in results
+ ]
_attributes = dict(
(attr.name, args[attr.name])
for attr in attrs
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30132f891faec..f1d920e68621e 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32]
+ cst: Result[i32] = infer_type()
class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32]
+ res: Result[i32] = infer_type()
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
@@ -88,9 +88,9 @@ class AddOp(Operation, dialect=MyInt, name="add"):
print(two.value)
# CHECK: OpResult(%0
print(two.cst)
- # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+ # CHECK: (self, /, lhs, rhs, *, res=None, loc=None, ip=None)
print(AddOp.__init__.__signature__)
- # CHECK: (self, /, value, *, loc=None, ip=None)
+ # CHECK: (self, /, value, *, cst=None, loc=None, ip=None)
print(ConstantOp.__init__.__signature__)
# CHECK: True
@@ -124,28 +124,28 @@ class ConstraintOp(Test.Operation, name="constraint"):
y: FloatAttr
class OptionalOp(Test.Operation, name="optional"):
- a: Operand[i32]
- b: Optional[Operand[i32]]
out1: Result[i32]
- out2: Result[i32] | None
out3: Result[i32]
+ a: Operand[i32]
+ out2: Result[i32] | None = None
+ b: Optional[Operand[i32]] = None
class Optional2Op(Test.Operation, name="optional2"):
- a: Optional[Operand[i32]]
- b: Optional[Result[i32]]
+ b: Optional[Result[i32]] = None
+ a: Optional[Operand[i32]] = None
class VariadicOp(Test.Operation, name="variadic"):
- a: Operand[i32]
- b: Optional[Operand[i32]]
- c: Sequence[Operand[i32]]
out1: Sequence[Result[i32]]
out2: Sequence[Result[i32]]
- out3: Optional[Result[i32]]
out4: Result[i32]
+ a: Operand[i32]
+ c: Sequence[Operand[i32]]
+ out3: Optional[Result[i32]] = None
+ b: Optional[Operand[i32]] = None
class Variadic2Op(Test.Operation, name="variadic2"):
- a: Sequence[Operand[i32]]
b: Sequence[Result[i32]]
+ a: Sequence[Operand[i32]]
class MixedOpBase(Test.Operation):
out: Result[i32]
@@ -153,9 +153,9 @@ class MixedOpBase(Test.Operation):
class MixedOp(MixedOpBase, name="mixed"):
in2: IntegerAttr
- in3: Optional[Operand[i32]]
in4: IntegerAttr
in5: Operand[i32]
+ in3: Optional[Operand[i32]] = None
T = TypeVar("T")
U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
@@ -168,6 +168,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
in4: Operand[U | V]
in5: Operand[V]
+ class OptionalButNotKeywordOp(Test.Operation, name="optional_but_not_keyword"):
+ a: Operand[i32]
+ b: Optional[Operand[i32]]
+ c: Operand[i32]
+
# CHECK: irdl.dialect @ext_test {
# CHECK: irdl.operation @constraint {
# CHECK: %0 = irdl.is i32
@@ -185,7 +190,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: irdl.operation @optional {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: %0, b: optional %0)
- # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
+ # CHECK: irdl.results(out1: %0, out3: %0, out2: optional %0)
# CHECK: }
# CHECK: irdl.operation @optional2 {
# CHECK: %0 = irdl.is i32
@@ -194,8 +199,8 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: }
# CHECK: irdl.operation @variadic {
# CHECK: %0 = irdl.is i32
- # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
- # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+ # CHECK: irdl.operands(a: %0, c: variadic %0, b: optional %0)
+ # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out4: %0, out3: optional %0)
# CHECK: }
# CHECK: irdl.operation @variadic2 {
# CHECK: %0 = irdl.is i32
@@ -204,7 +209,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: }
# CHECK: irdl.operation @mixed {
# CHECK: %0 = irdl.is i32
- # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
+ # CHECK: irdl.operands(in1: %0, in5: %0, in3: optional %0)
# CHECK: %1 = irdl.base "#builtin.integer"
# CHECK: %2 = irdl.base "#builtin.integer"
# CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
@@ -236,16 +241,20 @@ class TypeVarOp(Test.Operation, name="type_var"):
print(VariadicOp.__init__.__signature__)
# CHECK: (self, /, b, a, *, loc=None, ip=None)
print(Variadic2Op.__init__.__signature__)
- # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+ # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
print(MixedOp.__init__.__signature__)
+ # CHECK: (self, /, in1, in2, in3, in4, in5, *, loc=None, ip=None)
+ print(TypeVarOp.__init__.__signature__)
+ # CHECK: (self, /, a, b, c, *, loc=None, ip=None)
+ print(OptionalButNotKeywordOp.__init__.__signature__)
# CHECK: None None
print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0] [1, 0, 1]
+ # CHECK: [1, 0] [1, 1, 0]
print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
# CHECK: [0] [0]
print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+ # CHECK: [1, -1, 0] [-1, -1, 1, 0]
print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
# CHECK: [-1] [-1]
print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
@@ -269,7 +278,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
ConstraintOp(ione, fone, ione, fone, iattr, fattr)
- # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+ # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 1, 0>} : (i32) -> (i32, i32)
o1 = OptionalOp(i32, i32, ione)
# CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
@@ -282,11 +291,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
o6 = Optional2Op(b=i32, a=ione)
- # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+ # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 2, 0>, resultSegmentSizes = array<i32: 1, 2, 1, 0>} : (i32, i32, i32) -> (i32, i32, i32, i32)
v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
# CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
- # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+ # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 1, 0>} : (i32) -> (i32, i32, i32, i32)
v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
# CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
v4 = Variadic2Op([], [])
@@ -295,10 +304,10 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
v6 = Variadic2Op([i32, i32], [ione])
- # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
- m1 = MixedOp(ione, iattr, iattr, ione)
+ # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 0>} : (i32, i32) -> i32
+ m1 = MixedOp(i32, ione, iattr, iattr, ione)
# CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
- m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
+ m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
print(module)
assert module.operation.verify()
@@ -319,7 +328,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
print(o1.out2)
# CHECK: 0
print(o2.out1.result_number)
- # CHECK: 1
+ # CHECK: 2
print(o2.out2.result_number)
# CHECK: None
print(o3.a)
@@ -371,8 +380,8 @@ class TestRegion(Dialect, name="ext_region"):
pass
class IfOp(TestRegion.Operation, name="if"):
- cond: Operand[IntegerType[1]]
result: Result[Any]
+ cond: Operand[IntegerType[1]]
then: Region
else_: Region
@@ -547,7 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
- arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+ arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = (
+ infer_type()
+ )
with Context(), Location.unknown():
TestType.load()
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f58e0be13befd..c70082e892371 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -90,7 +90,7 @@ def get_effects(op: ir.Operation, effects):
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
- attr_as_param: ext.Result[transform.AnyParamType[()]]
+ attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_type()
@classmethod
def attach_interface_impls(cls, ctx=None):
@@ -153,7 +153,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
@ext.register_operation(MyTransform)
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
target: ext.Operand[transform.AnyOpType]
- res: ext.Result[transform.AnyOpType[()]]
+ res: ext.Result[transform.AnyOpType[()]] = ext.infer_type()
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -282,9 +282,9 @@ class OpValParamInParamOpValOut(
val_arg: ext.Operand[transform.AnyValueType]
param_arg: ext.Operand[transform.AnyParamType]
# results
- param_res: ext.Result[transform.AnyParamType[()]]
- op_res: ext.Result[transform.AnyOpType[()]]
- value_res: ext.Result[transform.AnyValueType[()]]
+ param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_type()
+ op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_type()
+ value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_type()
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@@ -382,12 +382,12 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
class OpsParamsInValuesParamOut(
MyTransform.Operation, name="ops_params_in_values_param_out"
):
- # operands
- ops: Sequence[ext.Operand[transform.AnyOpType]]
- params: Sequence[ext.Operand[transform.AnyParamType]]
# results
values: Sequence[ext.Result[transform.AnyValueType]]
param: ext.Result[transform.AnyParamType]
+ # operands
+ ops: Sequence[ext.Operand[transform.AnyOpType]]
+ params: Sequence[ext.Operand[transform.AnyParamType]]
# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index cbb815378ddf7..61e25c6c3d872 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
class NextOp(BfDialect.Operation, name="next"):
in_: Operand[PtrType]
- out: Result[PtrType[()]]
+ out: Result[PtrType[()]] = infer_type()
class PrevOp(BfDialect.Operation, name="prev"):
in_: Operand[PtrType]
- out: Result[PtrType[()]]
+ out: Result[PtrType[()]] = infer_type()
class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
class WhileOp(BfDialect.Operation, name="while"):
in_: Operand[PtrType]
- out: Result[PtrType[()]]
+ out: Result[PtrType[()]] = infer_type()
body: Region
More information about the Mlir-commits
mailing list