[Mlir-commits] [mlir] d818fa4 - [MLIR][Python] Make init parameters follow the field definition order (#186574)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 21 09:48:21 PDT 2026
Author: Twice
Date: 2026-03-22T00:48:14+08:00
New Revision: d818fa4c55c24b989eb925581e62b284c3c3a461
URL: https://github.com/llvm/llvm-project/commit/d818fa4c55c24b989eb925581e62b284c3c3a461
DIFF: https://github.com/llvm/llvm-project/commit/d818fa4c55c24b989eb925581e62b284c3c3a461.diff
LOG: [MLIR][Python] Make init parameters follow the field definition order (#186574)
Currently, Python-defined operations automatically generate an
`__init__` function to serve as the operation builder. Previously, the
parameters of this `__init__` function followed a fairly complex set of
rules. For example:
* All result fields were moved to the front to align with other op
builders.
* Fields of `Optional` type were automatically moved to the end and
treated as keyword parameters.
* If the types of all result fields could be inferred automatically,
then all result fields were removed from the parameter list.
* Other than that, the parameter order followed the field definition
order.
These rules may seem reasonable, and they have worked well in practice,
but they have one major drawback: users cannot easily tell what the
actual `__init__` parameter list will look like when writing code,
because the rules are simply too complicated. Users can inspect the
signature at runtime via `MyOp.__init__.__signature__`, but this is
clearly poor for the development experience.
After some offline discussion with Rolf, we decided to replace this with
the following rules, which is what this PR implements:
* The parameters of `__init__` now strictly follow the field definition
order.
* By default, all parameters are required. A field becomes a keyword-only
parameter (defaulting to `None`) only if it is declared with `= None`.
* By default, no type inference is performed. A result field becomes a
keyword-only parameter and participates in type inference only if it is
declared with `= result(infer_type=True)`.
These new rules give users full control over the `__init__` parameter
list, making it easy to understand the parameter order and explicitly
control optional parameters and type inference. In addition, this makes
it much easier in the future to use `dataclass_transform` so that type
checkers can understand these automatically generated `__init__`
methods, although there are still some issues to resolve at the moment.
NOTE that this is a **breaking change**. Even so, I still believe it is
worth making, mainly for two reasons:
- the current number of users of Python-defined dialects is still quite
limited;
- this change will bring long-term benefits.
Assisted by Copilot.
Added:
Modified:
mlir/python/mlir/dialects/ext.py
mlir/test/python/dialects/ext.py
mlir/test/python/dialects/transform_op_interface.py
mlir/test/python/integration/dialects/bf.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 867da6ee966370..1900b8c162456a 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -34,6 +34,7 @@
"Region",
"Type",
"Attribute",
+ "result",
]
Operand = ir.Value
@@ -42,6 +43,9 @@
def construct_instance(origin, args):
+ if not issubclass(origin, ir.Type | ir.Attribute):
+ raise TypeError(f"unsupported type in constraints: {origin}")
+
# `origin.get` is to construct an instance of MLIR type or attribute.
return origin.get(
*(
@@ -109,21 +113,47 @@ def _lower(self, type_) -> ir.Value:
raise TypeError(f"unsupported type in constraints: {type_}")
-def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
+@dataclass
+class FieldSpecifier:
+ infer_type: bool = False
+ default_is_none: bool = False
+
+ def __post_init__(self):
+ if self.infer_type and self.default_is_none:
+ raise ValueError(
+ "a field cannot be marked with both infer_type and default_is_none"
+ )
+
+ def kw_only(self) -> bool:
+ return self.default_is_none or self.infer_type
+
+
+def result(*, infer_type: bool = False) -> Any:
+ """
+ A field specifier for `Result` definitions.
+ """
+
+ return FieldSpecifier(infer_type=infer_type)
+
+
+def infer_type_impl(type_) -> Callable[[], ir.Type]:
"""
A function to infer ir.Type from type annotation.
- Returns a callable that returns the inferred ir.Type,
- or None if the type cannot be inferred.
+ Returns a callable that returns the inferred ir.Type.
We use callables so that MLIR contexts are not required
while calling this function.
"""
origin = get_origin(type_)
- if origin and issubclass(origin, ir.Type):
- return lambda: construct_instance(origin, get_args(type_))
+ if origin and issubclass(origin, ir.Type | ir.Attribute):
+ args = [
+ infer_type_impl(arg) if get_origin(arg) else lambda: arg
+ for arg in get_args(type_)
+ ]
+ return lambda: origin.get(*[arg() for arg in args])
elif isinstance(type_, TypeVar):
- return infer_type(type_.__bound__)
- return None
+ return infer_type_impl(type_.__bound__)
+ raise TypeError(f"unsupported type for inferring: {type_}")
@dataclass
@@ -134,9 +164,12 @@ class FieldDef:
name: str
variadicity: Variadicity
+ constraint: Any
+
+ kw_only: bool = False
@staticmethod
- def from_type_hint(name, type_) -> "FieldDef":
+ def from_type_hint(name, type_, specifier) -> "FieldDef":
variadicity = Variadicity.single
if inner := match_optional(type_):
variadicity = Variadicity.optional
@@ -147,29 +180,68 @@ def from_type_hint(name, type_) -> "FieldDef":
origin = get_origin(type_)
if origin is ir.OpResult:
- return ResultDef(name, variadicity, get_args(type_)[0])
+ constraint = get_args(type_)[0]
+ return ResultDef(
+ name,
+ variadicity,
+ constraint,
+ kw_only=specifier.kw_only(),
+ infer_type=(
+ infer_type_impl(constraint) if specifier.infer_type else None
+ ),
+ )
elif origin is ir.Value:
- return OperandDef(name, variadicity, get_args(type_)[0])
+ return OperandDef(
+ name,
+ variadicity,
+ get_args(type_)[0],
+ kw_only=specifier.kw_only(),
+ )
elif issubclass(origin or type_, ir.Attribute):
return AttributeDef(name, variadicity, type_)
elif type_ is ir.Region:
- return RegionDef(name, variadicity)
- raise TypeError(f"unsupported type in operation definition: {type_}")
+ return RegionDef(name, variadicity, Any)
+ raise TypeError(
+ f"unsupported type for field '{name}' in operation definition: {type_}"
+ )
@dataclass
class OperandDef(FieldDef):
- constraint: Any
+ def __post_init__(self):
+ if self.variadicity != Variadicity.optional and self.kw_only:
+ raise ValueError(f"only optional operand can be a keyword parameter")
@dataclass
class ResultDef(FieldDef):
- constraint: Any
+ infer_type: Callable[[], ir.Type] | None = None
+
+ def __post_init__(self):
+ if (
+ self.variadicity != Variadicity.optional
+ and not self.infer_type
+ and self.kw_only
+ ):
+ raise ValueError(f"only optional result can be a keyword parameter")
+
+ if self.infer_type and self.variadicity != Variadicity.single:
+ raise ValueError(
+ f"type of variadic or optional result '{self.name}' cannot be inferred"
+ )
+
+ def process_type(self, type_):
+ if type_:
+ return type_
+
+ if self.infer_type:
+ return self.infer_type()
+
+ return None
@dataclass
class AttributeDef(FieldDef):
- constraint: Any
def __post_init__(self):
if self.variadicity != Variadicity.single:
@@ -267,7 +339,19 @@ def __init_subclass__(
if hasattr(base, "_fields"):
fields.extend(base._fields)
for key, value in cls.__annotations__.items():
- field = FieldDef.from_type_hint(key, value)
+ # if the class variable is not defined, we treat it as a default specifier;
+ # if it is assigned with `None`, we treat it as a specifier with `default_is_none=True`.
+ # e.g. x : int # default specifier
+ # y : int = None # specifier with default_is_none=True
+ specifier = cls.__dict__.get(key, FieldSpecifier()) or FieldSpecifier(
+ default_is_none=True
+ )
+ # treat all other values as invalid
+ if not isinstance(specifier, FieldSpecifier):
+ raise TypeError(
+ f"the field specifier of field '{key}' is not supported"
+ )
+ field = FieldDef.from_type_hint(key, value, specifier)
fields.append(field)
cls._fields = fields
@@ -336,27 +420,17 @@ def _generate_segments(
return None
@staticmethod
- def _generate_init_signature(
- fields: List[FieldDef], can_infer_types: bool
- ) -> Signature:
- result_args = (
- [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
- )
- # results are placed at the beginning of the parameter list,
- # but operands and attributes can appear in any relative order.
- args = result_args + [
- i for i in fields if not isinstance(i, ResultDef | RegionDef)
- ]
- positional_args = [
- i.name for i in args if i.variadicity != Variadicity.optional
- ]
- optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
+ args = [i for i in fields if not isinstance(i, RegionDef)]
params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
- for i in positional_args:
- params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
- for i in optional_args:
- params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+
+ for i in args:
+ if i.kw_only:
+ params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
+ else:
+ params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+
params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
@@ -365,15 +439,8 @@ def _generate_init_signature(
@classmethod
def _generate_init_method(cls, fields: List[FieldDef]) -> None:
operands, attrs, results, regions = partition_fields(fields)
- inferred_types = [infer_type(i.constraint) for i in results]
- # we infer result types only when all result types can be inferred
- # and all results are single (not optional or variadic)
- can_infer_types = all(inferred_types) and all(
- i.variadicity == Variadicity.single for i in results
- )
-
- init_sig = cls._generate_init_signature(fields, can_infer_types)
+ init_sig = cls._generate_init_signature(fields)
def __init__(*args, **kwargs):
bound = init_sig.bind(*args, **kwargs)
@@ -381,11 +448,7 @@ def __init__(*args, **kwargs):
args = bound.arguments
_operands = [args[operand.name] for operand in operands]
- _results = (
- [t() for t in inferred_types]
- if can_infer_types
- else [args[result.name] for result in results]
- )
+ _results = [result.process_type(args[result.name]) for result in results]
_attributes = dict(
(attr.name, args[attr.name])
for attr in attrs
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30132f891faecc..a1593c35855eaf 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32]
+ cst: Result[i32] = result(infer_type=True)
class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32]
+ res: Result[i32] = result(infer_type=True)
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
@@ -88,9 +88,9 @@ class AddOp(Operation, dialect=MyInt, name="add"):
print(two.value)
# CHECK: OpResult(%0
print(two.cst)
- # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+ # CHECK: (self, /, lhs, rhs, *, res=None, loc=None, ip=None)
print(AddOp.__init__.__signature__)
- # CHECK: (self, /, value, *, loc=None, ip=None)
+ # CHECK: (self, /, value, *, cst=None, loc=None, ip=None)
print(ConstantOp.__init__.__signature__)
# CHECK: True
@@ -124,28 +124,28 @@ class ConstraintOp(Test.Operation, name="constraint"):
y: FloatAttr
class OptionalOp(Test.Operation, name="optional"):
- a: Operand[i32]
- b: Optional[Operand[i32]]
out1: Result[i32]
- out2: Result[i32] | None
out3: Result[i32]
+ a: Operand[i32]
+ out2: Result[i32] | None = None
+ b: Optional[Operand[i32]] = None
class Optional2Op(Test.Operation, name="optional2"):
- a: Optional[Operand[i32]]
- b: Optional[Result[i32]]
+ b: Optional[Result[i32]] = None
+ a: Optional[Operand[i32]] = None
class VariadicOp(Test.Operation, name="variadic"):
- a: Operand[i32]
- b: Optional[Operand[i32]]
- c: Sequence[Operand[i32]]
out1: Sequence[Result[i32]]
out2: Sequence[Result[i32]]
- out3: Optional[Result[i32]]
out4: Result[i32]
+ a: Operand[i32]
+ c: Sequence[Operand[i32]]
+ out3: Optional[Result[i32]] = None
+ b: Optional[Operand[i32]] = None
class Variadic2Op(Test.Operation, name="variadic2"):
- a: Sequence[Operand[i32]]
b: Sequence[Result[i32]]
+ a: Sequence[Operand[i32]]
class MixedOpBase(Test.Operation):
out: Result[i32]
@@ -153,9 +153,9 @@ class MixedOpBase(Test.Operation):
class MixedOp(MixedOpBase, name="mixed"):
in2: IntegerAttr
- in3: Optional[Operand[i32]]
in4: IntegerAttr
in5: Operand[i32]
+ in3: Optional[Operand[i32]] = None
T = TypeVar("T")
U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
@@ -168,6 +168,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
in4: Operand[U | V]
in5: Operand[V]
+ class OptionalButNotKeywordOp(Test.Operation, name="optional_but_not_keyword"):
+ a: Operand[i32]
+ b: Optional[Operand[i32]]
+ c: Operand[i32]
+
# CHECK: irdl.dialect @ext_test {
# CHECK: irdl.operation @constraint {
# CHECK: %0 = irdl.is i32
@@ -185,7 +190,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: irdl.operation @optional {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: %0, b: optional %0)
- # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
+ # CHECK: irdl.results(out1: %0, out3: %0, out2: optional %0)
# CHECK: }
# CHECK: irdl.operation @optional2 {
# CHECK: %0 = irdl.is i32
@@ -194,8 +199,8 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: }
# CHECK: irdl.operation @variadic {
# CHECK: %0 = irdl.is i32
- # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
- # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+ # CHECK: irdl.operands(a: %0, c: variadic %0, b: optional %0)
+ # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out4: %0, out3: optional %0)
# CHECK: }
# CHECK: irdl.operation @variadic2 {
# CHECK: %0 = irdl.is i32
@@ -204,7 +209,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: }
# CHECK: irdl.operation @mixed {
# CHECK: %0 = irdl.is i32
- # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
+ # CHECK: irdl.operands(in1: %0, in5: %0, in3: optional %0)
# CHECK: %1 = irdl.base "#builtin.integer"
# CHECK: %2 = irdl.base "#builtin.integer"
# CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
@@ -236,16 +241,20 @@ class TypeVarOp(Test.Operation, name="type_var"):
print(VariadicOp.__init__.__signature__)
# CHECK: (self, /, b, a, *, loc=None, ip=None)
print(Variadic2Op.__init__.__signature__)
- # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+ # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
print(MixedOp.__init__.__signature__)
+ # CHECK: (self, /, in1, in2, in3, in4, in5, *, loc=None, ip=None)
+ print(TypeVarOp.__init__.__signature__)
+ # CHECK: (self, /, a, b, c, *, loc=None, ip=None)
+ print(OptionalButNotKeywordOp.__init__.__signature__)
# CHECK: None None
print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0] [1, 0, 1]
+ # CHECK: [1, 0] [1, 1, 0]
print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
# CHECK: [0] [0]
print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+ # CHECK: [1, -1, 0] [-1, -1, 1, 0]
print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
# CHECK: [-1] [-1]
print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
@@ -269,7 +278,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
ConstraintOp(ione, fone, ione, fone, iattr, fattr)
- # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+ # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 1, 0>} : (i32) -> (i32, i32)
o1 = OptionalOp(i32, i32, ione)
# CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
@@ -282,11 +291,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
o6 = Optional2Op(b=i32, a=ione)
- # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+ # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 2, 0>, resultSegmentSizes = array<i32: 1, 2, 1, 0>} : (i32, i32, i32) -> (i32, i32, i32, i32)
v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
# CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
- # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+ # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 1, 0>} : (i32) -> (i32, i32, i32, i32)
v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
# CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
v4 = Variadic2Op([], [])
@@ -295,10 +304,10 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
v6 = Variadic2Op([i32, i32], [ione])
- # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
- m1 = MixedOp(ione, iattr, iattr, ione)
+ # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 0>} : (i32, i32) -> i32
+ m1 = MixedOp(i32, ione, iattr, iattr, ione)
# CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
- m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
+ m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
print(module)
assert module.operation.verify()
@@ -319,7 +328,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
print(o1.out2)
# CHECK: 0
print(o2.out1.result_number)
- # CHECK: 1
+ # CHECK: 2
print(o2.out2.result_number)
# CHECK: None
print(o3.a)
@@ -371,8 +380,8 @@ class TestRegion(Dialect, name="ext_region"):
pass
class IfOp(TestRegion.Operation, name="if"):
- cond: Operand[IntegerType[1]]
result: Result[Any]
+ cond: Operand[IntegerType[1]]
then: Region
else_: Region
@@ -547,7 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
- arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+ arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]] = result(
+ infer_type=True
+ )
with Context(), Location.unknown():
TestType.load()
@@ -695,3 +706,33 @@ class Op2(TestAttr.Operation, name="op2"):
# CHECK: "ext_attr.op1"() {pair = #ext_attr.pair<1 : i32, 2 : i32>} : () -> ()
# CHECK: "ext_attr.op2"() {pair = #ext_attr.str_pair<"hello", "world">, pair2 = #ext_attr.str_pair<"a", "b">} : () -> ()
print(module)
+
+
+# CHECK: TEST: testExtDialectWithInvalidOp
+@run
+def testExtDialectWithInvalidOp():
+ class TestInvalid(Dialect, name="ext_invalid"):
+ pass
+
+ try:
+
+ class InferTypeBeforePositionalOp(
+ TestInvalid.Operation, name="infer_before_pos"
+ ):
+ res: Result[IntegerType[32]] = result(infer_type=True)
+ a: Operand[IntegerType[32]]
+
+ except ValueError as e:
+ # CHECK: wrong parameter order
+ print(e)
+
+ try:
+
+ class AssignNoneOnNonOptionalOp(
+ TestInvalid.Operation, name="assign_none_on_non_optional"
+ ):
+ a: Operand[IntegerType[32]] = None
+
+ except ValueError as e:
+ # CHECK: only optional operand can be a keyword parameter
+ print(e)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index a6e2c6da453222..b9a1c43e111afc 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
- attr_as_param: ext.Result[transform.AnyParamType[()]]
+ attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
@classmethod
def attach_interface_impls(cls, ctx=None):
@@ -149,7 +149,7 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
# Syntax for an op with one op handle operand and one op handle result.
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
target: ext.Operand[transform.AnyOpType]
- res: ext.Result[transform.AnyOpType[()]]
+ res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@@ -277,9 +277,9 @@ class OpValParamInParamOpValOut(
val_arg: ext.Operand[transform.AnyValueType]
param_arg: ext.Operand[transform.AnyParamType]
# results
- param_res: ext.Result[transform.AnyParamType[()]]
- op_res: ext.Result[transform.AnyOpType[()]]
- value_res: ext.Result[transform.AnyValueType[()]]
+ param_res: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
+ op_res: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
+ value_res: ext.Result[transform.AnyValueType[()]] = ext.result(infer_type=True)
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@@ -376,12 +376,12 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
class OpsParamsInValuesParamOut(
MyTransform.Operation, name="ops_params_in_values_param_out"
):
- # operands
- ops: Sequence[ext.Operand[transform.AnyOpType]]
- params: Sequence[ext.Operand[transform.AnyParamType]]
# results
values: Sequence[ext.Result[transform.AnyValueType]]
param: ext.Result[transform.AnyParamType]
+ # operands
+ ops: Sequence[ext.Operand[transform.AnyOpType]]
+ params: Sequence[ext.Operand[transform.AnyParamType]]
# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
diff --git a/mlir/test/python/integration/dialects/bf.py b/mlir/test/python/integration/dialects/bf.py
index cbb815378ddf7c..e38155dcda973e 100644
--- a/mlir/test/python/integration/dialects/bf.py
+++ b/mlir/test/python/integration/dialects/bf.py
@@ -20,12 +20,12 @@ class PtrType(BfDialect.Type, name="ptr"):
class NextOp(BfDialect.Operation, name="next"):
in_: Operand[PtrType]
- out: Result[PtrType[()]]
+ out: Result[PtrType[()]] = result(infer_type=True)
class PrevOp(BfDialect.Operation, name="prev"):
in_: Operand[PtrType]
- out: Result[PtrType[()]]
+ out: Result[PtrType[()]] = result(infer_type=True)
class IncOp(BfDialect.Operation, name="inc"):
@@ -46,7 +46,7 @@ class OutputOp(BfDialect.Operation, name="output"):
class WhileOp(BfDialect.Operation, name="while"):
in_: Operand[PtrType]
- out: Result[PtrType[()]]
+ out: Result[PtrType[()]] = result(infer_type=True)
body: Region
More information about the Mlir-commits
mailing list