[Mlir-commits] [mlir] [MLIR][Python] Make init parameters follow the field definition order (PR #186574)
llvmlistbot at llvm.org
Sat Mar 14 22:13:03 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
Currently, Python-defined operations automatically generate an `__init__` function that serves as the operation builder. Until now, the parameters of this `__init__` function followed a fairly complex set of rules. For example:
* All result fields were moved to the front to align with other op builders.
* Fields of `Optional` type were automatically moved to the end and treated as keyword parameters.
* If the types of all result fields could be inferred automatically, then all result fields were removed from the parameter list.
* Other than that, the parameter order followed the field definition order.
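To make the old behavior concrete, the ordering rules above can be approximated in a few lines of plain Python. This is only a sketch, not the removed `_generate_init_signature` code: `Field` and `old_init_params` are hypothetical names, fields are simple tuples, and regions are skipped as before. Applied to fields modeled on the `MixedOp` test class, it reproduces the old signature from the test diff, where the inferable `out` result silently disappears and the optional `in3` moves to the end.

```python
from typing import NamedTuple


class Field(NamedTuple):
    """Hypothetical stand-in for the real FieldDef classes."""

    name: str
    kind: str  # "operand", "attribute", "result" or "region"
    optional: bool = False
    variadic: bool = False
    inferable: bool = False  # could the result type be inferred?


def old_init_params(fields):
    """Approximate the old rules; returns (positional, keyword_only) names."""
    results = [f for f in fields if f.kind == "result"]
    # Results were dropped only if *every* result was inferable and single.
    can_infer = all(
        f.inferable and not f.optional and not f.variadic for f in results
    )
    # Kept results went first, then operands/attributes in definition order.
    args = ([] if can_infer else results) + [
        f for f in fields if f.kind not in ("result", "region")
    ]
    positional = [f.name for f in args if not f.optional]
    # Optional fields were moved to the end as keyword-only parameters.
    keyword_only = [f.name for f in args if f.optional]
    return positional, keyword_only


# Modeled on the MixedOp test class (fields in definition order).
mixed = [
    Field("out", "result", inferable=True),
    Field("in1", "operand"),
    Field("in2", "attribute"),
    Field("in3", "operand", optional=True),
    Field("in4", "attribute"),
    Field("in5", "operand"),
]
print(old_init_params(mixed))
# (['in1', 'in2', 'in4', 'in5'], ['in3'])

# Modeled on the old OptionalOp: one optional result disables inference,
# so all results come back and jump ahead of the operands.
optional_op = [
    Field("a", "operand"),
    Field("b", "operand", optional=True),
    Field("out1", "result", inferable=True),
    Field("out2", "result", optional=True, inferable=True),
    Field("out3", "result", inferable=True),
]
print(old_init_params(optional_op))
# (['out1', 'out3', 'a'], ['out2', 'b'])
```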
These rules may seem reasonable, and they have worked well in practice, but they have one major drawback: when writing code, users cannot easily tell what the actual `__init__` parameter list will look like, because the rules are simply too complicated. Users can inspect the signature at runtime via `MyOp.__init__.__signature__`, but that makes for a poor development experience.
After some offline discussion with Rolf, we decided to replace this with the following rules, which is what this PR implements:
* The parameters of `__init__` now strictly follow the field definition order.
* By default, all parameters are required. A field becomes a keyword-only parameter (defaulting to `None`) only if it is declared with `= None`.
* By default, no type inference is performed. A result field becomes a keyword-only parameter and participates in type inference only if it is declared with `= infer_type()`; this is allowed only for single (non-optional, non-variadic) results.
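Under the new rules, the parameter list can be read directly off the class body. The sketch below shows one way such a signature can be derived with `inspect.Signature`; `OpBase`, `Marker`, `Operand`, `Result` and the `infer_type()` helper here are simplified, hypothetical stand-ins for the classes in `ext.py`, and the sketch ignores regions and type constraints. Because `inspect.Signature` rejects a positional parameter that follows a keyword-only one, fields declared with `= None` or `= infer_type()` have to come after the required fields in this sketch; the reordered test classes in the patch follow the same pattern.

```python
from dataclasses import dataclass
from inspect import Parameter, Signature


@dataclass
class Marker:
    """Simplified stand-in for the patch's field marker."""

    infer_type: bool = False
    default_is_none: bool = False


def infer_type():
    # Hypothetical stand-in for ext.infer_type(): just returns a marker.
    return Marker(infer_type=True)


class Operand:  # placeholder annotation types
    pass


class Result:
    pass


class OpBase:
    """Hypothetical base class that derives an __init__ signature."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
        # Class annotations preserve the field definition order.
        for name in cls.__annotations__:
            value = cls.__dict__.get(name, Marker())
            marker = Marker(default_is_none=True) if value is None else value
            if not isinstance(marker, Marker):
                raise TypeError(f"unsupported field specifier for '{name}'")
            if marker.infer_type or marker.default_is_none:
                param = Parameter(name, Parameter.KEYWORD_ONLY, default=None)
            else:
                param = Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            params.append(param)
        params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
        params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
        # Signature() raises ValueError if a positional field follows a
        # keyword-only one.
        cls.init_signature = Signature(params)


class AddOp(OpBase):
    lhs: Operand
    rhs: Operand
    res: Result = infer_type()


print(AddOp.init_signature)
# (self, /, lhs, rhs, *, res=None, loc=None, ip=None)

try:
    class BadOp(OpBase):
        b: Operand = None  # keyword-only parameter ...
        c: Operand  # ... followed by a required one
    bad_op_rejected = False
except ValueError:
    bad_op_rejected = True
print(bad_op_rejected)
# True
```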
These new rules give users full control over the `__init__` parameter list: the parameter order is easy to predict, and optional parameters and type inference are controlled explicitly. They will also make it much easier to adopt `dataclass_transform` in the future, so that type checkers can understand these automatically generated `__init__` methods, although some issues still need to be resolved first.
Note that this is a **breaking change**. Even so, I believe it is worth making, for two main reasons:
- the current number of users of Python-defined dialects is still quite limited;
- this change will bring long-term benefits.
Assisted by Copilot.
---
Patch is 22.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/186574.diff
4 Files Affected:
- (modified) mlir/python/mlir/dialects/ext.py (+108-48)
- (modified) mlir/test/python/dialects/ext.py (+42-31)
- (modified) mlir/test/python/dialects/transform_op_interface.py (+8-8)
- (modified) mlir/test/python/integration/dialects/bf.py (+3-3)
``````````diff
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 867da6ee96637..57260e09971aa 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -34,6 +34,7 @@
"Region",
"Type",
"Attribute",
+ "infer_type",
]
Operand = ir.Value
@@ -42,6 +43,9 @@
def construct_instance(origin, args):
+ if not issubclass(origin, ir.Type | ir.Attribute):
+ raise TypeError(f"unsupported type in constraints: {origin}")
+
# `origin.get` is to construct an instance of MLIR type or attribute.
return origin.get(
*(
@@ -109,21 +113,48 @@ def _lower(self, type_) -> ir.Value:
raise TypeError(f"unsupported type in constraints: {type_}")
-def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
+@dataclass
+class Marker:
+ infer_type: bool = False
+ default_is_none: bool = False
+
+ def __post_init__(self):
+ if self.infer_type and self.default_is_none:
+ raise ValueError(
+ "a field cannot be marked with both infer_type and default_is_none"
+ )
+
+ def kw_only(self) -> bool:
+ return self.default_is_none or self.infer_type
+
+
+def infer_type() -> Any:
+ """
+ A marker to indicate that the type of a result should be inferred.
+ It can only be used in `Result` definitions.
+ """
+
+ return Marker(infer_type=True)
+
+
+def infer_type_impl(type_) -> Callable[[], ir.Type]:
"""
A function to infer ir.Type from type annotation.
- Returns a callable that returns the inferred ir.Type,
- or None if the type cannot be inferred.
+ Returns a callable that returns the inferred ir.Type.
We use callables so that MLIR contexts are not required
while calling this function.
"""
origin = get_origin(type_)
- if origin and issubclass(origin, ir.Type):
- return lambda: construct_instance(origin, get_args(type_))
+ if origin and issubclass(origin, ir.Type | ir.Attribute):
+ args = [
+ infer_type_impl(arg) if get_origin(arg) else lambda: arg
+ for arg in get_args(type_)
+ ]
+ return lambda: origin.get(*[arg() for arg in args])
elif isinstance(type_, TypeVar):
- return infer_type(type_.__bound__)
- return None
+ return infer_type_impl(type_.__bound__)
+ raise TypeError(f"unsupported type for inferring: {type_}")
@dataclass
@@ -134,9 +165,12 @@ class FieldDef:
name: str
variadicity: Variadicity
+ constraint: Any
+
+ kw_only: bool = False
@staticmethod
- def from_type_hint(name, type_) -> "FieldDef":
+ def from_type_hint(name, type_, marker) -> "FieldDef":
variadicity = Variadicity.single
if inner := match_optional(type_):
variadicity = Variadicity.optional
@@ -147,29 +181,66 @@ def from_type_hint(name, type_) -> "FieldDef":
origin = get_origin(type_)
if origin is ir.OpResult:
- return ResultDef(name, variadicity, get_args(type_)[0])
+ constraint = get_args(type_)[0]
+ return ResultDef(
+ name,
+ variadicity,
+ constraint,
+ kw_only=marker.kw_only(),
+ infer_type=infer_type_impl(constraint) if marker.infer_type else None,
+ )
elif origin is ir.Value:
- return OperandDef(name, variadicity, get_args(type_)[0])
+ return OperandDef(
+ name,
+ variadicity,
+ get_args(type_)[0],
+ kw_only=marker.kw_only(),
+ )
elif issubclass(origin or type_, ir.Attribute):
return AttributeDef(name, variadicity, type_)
elif type_ is ir.Region:
- return RegionDef(name, variadicity)
- raise TypeError(f"unsupported type in operation definition: {type_}")
+ return RegionDef(name, variadicity, Any)
+ raise TypeError(
+ f"unsupported type for field '{name}' in operation definition: {type_}"
+ )
@dataclass
class OperandDef(FieldDef):
- constraint: Any
+ def __post_init__(self):
+ if self.variadicity != Variadicity.optional and self.kw_only:
+ raise ValueError(f"only optional operand can be a keyword parameter")
@dataclass
class ResultDef(FieldDef):
- constraint: Any
+ infer_type: Callable[[], ir.Type] | None = None
+
+ def __post_init__(self):
+ if (
+ self.variadicity != Variadicity.optional
+ and not self.infer_type
+ and self.kw_only
+ ):
+ raise ValueError(f"only optional result can be a keyword parameter")
+
+ if self.infer_type and self.variadicity != Variadicity.single:
+ raise ValueError(
+ f"type of variadic or optional result '{self.name}' cannot be inferred"
+ )
+
+ def process_type(self, type_):
+ if type_:
+ return type_
+
+ if self.infer_type:
+ return self.infer_type()
+
+ return None
@dataclass
class AttributeDef(FieldDef):
- constraint: Any
def __post_init__(self):
if self.variadicity != Variadicity.single:
@@ -267,7 +338,17 @@ def __init_subclass__(
if hasattr(base, "_fields"):
fields.extend(base._fields)
for key, value in cls.__annotations__.items():
- field = FieldDef.from_type_hint(key, value)
+ # if the class variable is not defined, we treat it as a default marker;
+ # if it is assigned with `None`, we treat it as a marker with `default_is_none=True`.
+ # e.g. x : int # default marker
+ # y : int = None # marker with default_is_none=True
+ marker = cls.__dict__.get(key, Marker()) or Marker(default_is_none=True)
+ # treat all other values as invalid
+ if not isinstance(marker, Marker):
+ raise TypeError(
+ f"the field specifier of field '{key}' is not supported"
+ )
+ field = FieldDef.from_type_hint(key, value, marker)
fields.append(field)
cls._fields = fields
@@ -336,27 +417,17 @@ def _generate_segments(
return None
@staticmethod
- def _generate_init_signature(
- fields: List[FieldDef], can_infer_types: bool
- ) -> Signature:
- result_args = (
- [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
- )
- # results are placed at the beginning of the parameter list,
- # but operands and attributes can appear in any relative order.
- args = result_args + [
- i for i in fields if not isinstance(i, ResultDef | RegionDef)
- ]
- positional_args = [
- i.name for i in args if i.variadicity != Variadicity.optional
- ]
- optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
+ args = [i for i in fields if not isinstance(i, RegionDef)]
params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
- for i in positional_args:
- params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
- for i in optional_args:
- params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+
+ for i in args:
+ if i.kw_only:
+ params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
+ else:
+ params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+
params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
@@ -365,15 +436,8 @@ def _generate_init_signature(
@classmethod
def _generate_init_method(cls, fields: List[FieldDef]) -> None:
operands, attrs, results, regions = partition_fields(fields)
- inferred_types = [infer_type(i.constraint) for i in results]
- # we infer result types only when all result types can be inferred
- # and all results are single (not optional or variadic)
- can_infer_types = all(inferred_types) and all(
- i.variadicity == Variadicity.single for i in results
- )
-
- init_sig = cls._generate_init_signature(fields, can_infer_types)
+ init_sig = cls._generate_init_signature(fields)
def __init__(*args, **kwargs):
bound = init_sig.bind(*args, **kwargs)
@@ -381,11 +445,7 @@ def __init__(*args, **kwargs):
args = bound.arguments
_operands = [args[operand.name] for operand in operands]
- _results = (
- [t() for t in inferred_types]
- if can_infer_types
- else [args[result.name] for result in results]
- )
+ _results = [result.process_type(args[result.name]) for result in results]
_attributes = dict(
(attr.name, args[attr.name])
for attr in attrs
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 30132f891faec..abb50c21e8a5a 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -23,12 +23,12 @@ class MyInt(Dialect, name="myint"):
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
- cst: Result[i32]
+ cst: Result[i32] = infer_type()
class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
- res: Result[i32]
+ res: Result[i32] = infer_type()
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
@@ -88,9 +88,9 @@ class AddOp(Operation, dialect=MyInt, name="add"):
print(two.value)
# CHECK: OpResult(%0
print(two.cst)
- # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+ # CHECK: (self, /, lhs, rhs, *, res=None, loc=None, ip=None)
print(AddOp.__init__.__signature__)
- # CHECK: (self, /, value, *, loc=None, ip=None)
+ # CHECK: (self, /, value, *, cst=None, loc=None, ip=None)
print(ConstantOp.__init__.__signature__)
# CHECK: True
@@ -124,28 +124,28 @@ class ConstraintOp(Test.Operation, name="constraint"):
y: FloatAttr
class OptionalOp(Test.Operation, name="optional"):
- a: Operand[i32]
- b: Optional[Operand[i32]]
out1: Result[i32]
- out2: Result[i32] | None
out3: Result[i32]
+ a: Operand[i32]
+ out2: Result[i32] | None = None
+ b: Optional[Operand[i32]] = None
class Optional2Op(Test.Operation, name="optional2"):
- a: Optional[Operand[i32]]
- b: Optional[Result[i32]]
+ b: Optional[Result[i32]] = None
+ a: Optional[Operand[i32]] = None
class VariadicOp(Test.Operation, name="variadic"):
- a: Operand[i32]
- b: Optional[Operand[i32]]
- c: Sequence[Operand[i32]]
out1: Sequence[Result[i32]]
out2: Sequence[Result[i32]]
- out3: Optional[Result[i32]]
out4: Result[i32]
+ a: Operand[i32]
+ c: Sequence[Operand[i32]]
+ out3: Optional[Result[i32]] = None
+ b: Optional[Operand[i32]] = None
class Variadic2Op(Test.Operation, name="variadic2"):
- a: Sequence[Operand[i32]]
b: Sequence[Result[i32]]
+ a: Sequence[Operand[i32]]
class MixedOpBase(Test.Operation):
out: Result[i32]
@@ -153,9 +153,9 @@ class MixedOpBase(Test.Operation):
class MixedOp(MixedOpBase, name="mixed"):
in2: IntegerAttr
- in3: Optional[Operand[i32]]
in4: IntegerAttr
in5: Operand[i32]
+ in3: Optional[Operand[i32]] = None
T = TypeVar("T")
U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
@@ -168,6 +168,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
in4: Operand[U | V]
in5: Operand[V]
+ class OptionalButNotKeywordOp(Test.Operation, name="optional_but_not_keyword"):
+ a: Operand[i32]
+ b: Optional[Operand[i32]]
+ c: Operand[i32]
+
# CHECK: irdl.dialect @ext_test {
# CHECK: irdl.operation @constraint {
# CHECK: %0 = irdl.is i32
@@ -185,7 +190,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: irdl.operation @optional {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: %0, b: optional %0)
- # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
+ # CHECK: irdl.results(out1: %0, out3: %0, out2: optional %0)
# CHECK: }
# CHECK: irdl.operation @optional2 {
# CHECK: %0 = irdl.is i32
@@ -194,8 +199,8 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: }
# CHECK: irdl.operation @variadic {
# CHECK: %0 = irdl.is i32
- # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
- # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+ # CHECK: irdl.operands(a: %0, c: variadic %0, b: optional %0)
+ # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out4: %0, out3: optional %0)
# CHECK: }
# CHECK: irdl.operation @variadic2 {
# CHECK: %0 = irdl.is i32
@@ -204,7 +209,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: }
# CHECK: irdl.operation @mixed {
# CHECK: %0 = irdl.is i32
- # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
+ # CHECK: irdl.operands(in1: %0, in5: %0, in3: optional %0)
# CHECK: %1 = irdl.base "#builtin.integer"
# CHECK: %2 = irdl.base "#builtin.integer"
# CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
@@ -236,16 +241,20 @@ class TypeVarOp(Test.Operation, name="type_var"):
print(VariadicOp.__init__.__signature__)
# CHECK: (self, /, b, a, *, loc=None, ip=None)
print(Variadic2Op.__init__.__signature__)
- # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+ # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
print(MixedOp.__init__.__signature__)
+ # CHECK: (self, /, in1, in2, in3, in4, in5, *, loc=None, ip=None)
+ print(TypeVarOp.__init__.__signature__)
+ # CHECK: (self, /, a, b, c, *, loc=None, ip=None)
+ print(OptionalButNotKeywordOp.__init__.__signature__)
# CHECK: None None
print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0] [1, 0, 1]
+ # CHECK: [1, 0] [1, 1, 0]
print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
# CHECK: [0] [0]
print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+ # CHECK: [1, -1, 0] [-1, -1, 1, 0]
print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
# CHECK: [-1] [-1]
print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
@@ -269,7 +278,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
ConstraintOp(ione, fone, ione, fone, iattr, fattr)
- # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+ # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 1, 0>} : (i32) -> (i32, i32)
o1 = OptionalOp(i32, i32, ione)
# CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
@@ -282,11 +291,11 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
o6 = Optional2Op(b=i32, a=ione)
- # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+ # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 2, 0>, resultSegmentSizes = array<i32: 1, 2, 1, 0>} : (i32, i32, i32) -> (i32, i32, i32, i32)
v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
# CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
- # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+ # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 1, 0>} : (i32) -> (i32, i32, i32, i32)
v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
# CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
v4 = Variadic2Op([], [])
@@ -295,10 +304,10 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
v6 = Variadic2Op([i32, i32], [ione])
- # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
- m1 = MixedOp(ione, iattr, iattr, ione)
+ # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 0>} : (i32, i32) -> i32
+ m1 = MixedOp(i32, ione, iattr, iattr, ione)
# CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
- m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
+ m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
print(module)
assert module.operation.verify()
@@ -319,7 +328,7 @@ class TypeVarOp(Test.Operation, name="type_var"):
print(o1.out2)
# CHECK: 0
print(o2.out1.result_number)
- # CHECK: 1
+ # CHECK: 2
print(o2.out2.result_number)
# CHECK: None
print(o3.a)
@@ -371,8 +380,8 @@ class TestRegion(Dialect, name="ext_region"):
pass
class IfOp(TestRegion.Operation, name="if"):
- cond: Operand[IntegerType[1]]
result: Result[Any]
+ cond: Operand[IntegerType[1]]
then: Region
else_: Region
@@ -547,7 +556,9 @@ class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
- arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+ arr: Result[
+ Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
+ ] = infer_type()
with Context(), Location.unknown():
TestType.load()
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index a6e2c6da45322..9bc17a8efdc43 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -88,7 +88,7 @@ def get_effects(op: ir.Operation, effects):
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
- attr_as_param: ext.Result[transform.AnyParamType[()]]
+ attr_as_param: e...
[truncated]
``````````
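The new `infer_type_impl` defers all MLIR calls by returning nested callables built from the annotation, so no context is needed when the class is defined. The sketch below mimics that idea with hypothetical stand-in classes (`FakeType`, `IntegerType`, `IntegerAttr`, `Array`) whose `get` just returns a tuple instead of building a real MLIR object; `infer_type_lazily` is an illustrative name, not part of `ext.py`.

```python
import types
from typing import Callable, TypeVar, get_args, get_origin


class FakeType:
    """Hypothetical stand-in for an MLIR type/attribute class."""

    def __class_getitem__(cls, params):
        # Return a generic alias so typing.get_origin/get_args understand it.
        return types.GenericAlias(cls, params)

    @classmethod
    def get(cls, *args):
        # A real MLIR `get` needs an active Context; this records the call.
        return (cls.__name__, *args)


class IntegerType(FakeType):
    pass


class IntegerAttr(FakeType):
    pass


class Array(FakeType):
    pass


def infer_type_lazily(type_) -> Callable[[], object]:
    """Turn an annotation such as IntegerType[32] into a deferred builder."""
    origin = get_origin(type_)
    if isinstance(origin, type) and issubclass(origin, FakeType):
        builders = [
            # Nested generics recurse; literals are bound via a default arg.
            infer_type_lazily(arg) if get_origin(arg) else (lambda value=arg: value)
            for arg in get_args(type_)
        ]
        return lambda: origin.get(*(build() for build in builders))
    if isinstance(type_, TypeVar) and type_.__bound__ is not None:
        # Fall back to the TypeVar's bound.
        return infer_type_lazily(type_.__bound__)
    raise TypeError(f"unsupported type for inferring: {type_}")


# Nothing is constructed until the returned callable is invoked.
build_array = infer_type_lazily(
    Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
)
print(build_array())
# ('Array', ('IntegerType', 32), ('IntegerAttr', ('IntegerType', 32), 3))

U = TypeVar("U", bound=IntegerType[64])
print(infer_type_lazily(U)())
# ('IntegerType', 64)
```

One detail worth noting: the sketch binds each literal argument through a default parameter (`lambda value=arg: value`). A bare `lambda: arg` created inside a comprehension closes over the loop variable and returns its final value, so a literal followed by another argument (as in `IntegerAttr[3, IntegerType[32]]`) would be resolved incorrectly.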
</details>
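Because results and operands now keep their declaration order, segment information is reordered too, which explains most of the `CHECK` updates in the test diff. Judging only from those `CHECK` lines (the body of `_generate_segments` is not shown in the truncated patch), the static `_ODS_*_SEGMENTS` lists appear to use `1` for a single field, `0` for an optional one and `-1` for a variadic one, and to be `None` when no field is optional or variadic, while the `operandSegmentSizes`/`resultSegmentSizes` attributes hold the number of values actually passed per field. The sketch below reproduces the `VariadicOp` numbers under that reading; `Variadicity` here is a local stand-in, and `static_segments`/`segment_sizes` are hypothetical helpers.

```python
from enum import Enum


class Variadicity(Enum):
    """Local stand-in for the Variadicity enum used in ext.py."""

    single = "single"
    optional = "optional"
    variadic = "variadic"


# Encoding read off the _ODS_*_SEGMENTS CHECK lines in the test diff.
SEGMENT_CODE = {
    Variadicity.single: 1,
    Variadicity.optional: 0,
    Variadicity.variadic: -1,
}


def static_segments(kinds):
    """Per-field descriptor list, or None if every field is single."""
    if all(kind is Variadicity.single for kind in kinds):
        return None
    return [SEGMENT_CODE[kind] for kind in kinds]


def segment_sizes(kinds, values):
    """Number of values actually supplied for each field, in field order."""
    sizes = []
    for kind, value in zip(kinds, values):
        if kind is Variadicity.single:
            sizes.append(1)
        elif kind is Variadicity.optional:
            sizes.append(0 if value is None else 1)
        else:  # variadic: a sequence of values
            sizes.append(len(value))
    return sizes


S, O, V = Variadicity.single, Variadicity.optional, Variadicity.variadic

# VariadicOp after the patch: operands a, c, b; results out1, out2, out4, out3.
operand_kinds = [S, V, O]
result_kinds = [V, V, S, O]
print(static_segments(operand_kinds), static_segments(result_kinds))
# [1, -1, 0] [-1, -1, 1, 0]

# v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
print(segment_sizes(operand_kinds, ["ione", ["ione", "ione"], None]))
# [1, 2, 0]
print(segment_sizes(result_kinds, [["i32"], ["i32", "i32"], "i32", None]))
# [1, 2, 1, 0]
```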
https://github.com/llvm/llvm-project/pull/186574