[Mlir-commits] [mlir] [MLIR][Python] Add more field specifiers to Python-defined operations (PR #188064)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 26 09:11:42 PDT 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/188064
>From 17795c8f8e4833444166185fabbab0f6f5d35ec4 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 24 Mar 2026 00:14:13 +0800
Subject: [PATCH 1/3] [MLIR][Python] Add more field specifiers to
Python-defined operations
---
mlir/python/mlir/dialects/ext.py | 159 +++++++++++++++++++++++++------
mlir/test/python/dialects/ext.py | 100 ++++++++++++++++++-
2 files changed, 229 insertions(+), 30 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index dfcd7f2d641d0..7242c65d54492 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -18,6 +18,7 @@
from dataclasses import dataclass
from inspect import Parameter, Signature
from types import UnionType
+from enum import Enum
from . import irdl
from ._ods_common import _cext, segmented_accessor
from .irdl import Variadicity
@@ -35,6 +36,8 @@
"Type",
"Attribute",
"result",
+ "operand",
+ "attribute",
]
Operand = ir.Value
@@ -108,25 +111,70 @@ def _lower(self, type_) -> ir.Value:
@dataclass
class FieldSpecifier:
+ type_: Any = None
infer_type: bool = False
default_is_none: bool = False
+ default_factory: Optional[Callable[[], Any]] = None
+ kw_only: bool = False
- def __post_init__(self):
- if self.infer_type and self.default_is_none:
- raise ValueError(
- "a field cannot be marked with both infer_type and default_is_none"
- )
-
- def kw_only(self) -> bool:
- return self.default_is_none or self.infer_type
+ def param_kind(self):
+ if self.default_is_none or self.default_factory or self.infer_type:
+ return ParameterKind.KEYWORD_ONLY_WITH_DEFAULT
+ if self.kw_only:
+ return ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT
+ return ParameterKind.POSITIONAL_OR_KEYWORD
-def result(*, infer_type: bool = False) -> Any:
+def result(
+ *,
+ infer_type: bool = False,
+ default_factory: Optional[Callable[[], Any]] = None,
+ kw_only: bool = False,
+) -> Any:
"""
A field specifier for `Result` definitions.
"""
+ if infer_type and default_factory:
+ raise ValueError(
+ "a result field cannot have both infer_type and default_factory"
+ )
- return FieldSpecifier(infer_type=infer_type)
+ return FieldSpecifier(
+ type_=Result,
+ infer_type=infer_type,
+ default_factory=default_factory,
+ kw_only=kw_only,
+ )
+
+
+def operand(
+ *,
+ kw_only: bool = False,
+) -> Any:
+ """
+ A field specifier for `Operand` definitions.
+ """
+
+ return FieldSpecifier(
+ type_=Operand,
+ kw_only=kw_only,
+ )
+
+
+def attribute(
+ *,
+ default_factory: Optional[Callable[[], Any]] = None,
+ kw_only: bool = False,
+) -> Any:
+ """
+ A field specifier for attribute definitions.
+ """
+
+ return FieldSpecifier(
+ type_=Attribute,
+ default_factory=default_factory,
+ kw_only=kw_only,
+ )
def infer_type_impl(type_) -> Callable[[], ir.Type]:
@@ -149,6 +197,12 @@ def infer_type_impl(type_) -> Callable[[], ir.Type]:
raise TypeError(f"unsupported type for inferring: {type_}")
+class ParameterKind(Enum):
+ POSITIONAL_OR_KEYWORD = 1
+ KEYWORD_ONLY_WITHOUT_DEFAULT = 2
+ KEYWORD_ONLY_WITH_DEFAULT = 3
+
+
@dataclass
class FieldDef:
"""
@@ -159,7 +213,7 @@ class FieldDef:
variadicity: Variadicity
constraint: Any
- kw_only: bool = False
+ param_kind: ParameterKind = ParameterKind.POSITIONAL_OR_KEYWORD
@staticmethod
def from_type_hint(name, type_, specifier) -> "FieldDef":
@@ -173,46 +227,72 @@ def from_type_hint(name, type_, specifier) -> "FieldDef":
origin = get_origin(type_)
if origin is ir.OpResult:
+ if specifier.type_ and specifier.type_ is not Result:
+ raise TypeError(
+ f"only `result` field specifier can be used for result fields"
+ )
constraint = get_args(type_)[0]
return ResultDef(
name,
variadicity,
constraint,
- kw_only=specifier.kw_only(),
+ param_kind=specifier.param_kind(),
+ default_factory=specifier.default_factory,
+ default_is_none=specifier.default_is_none,
infer_type=(
infer_type_impl(constraint) if specifier.infer_type else None
),
)
elif origin is ir.Value:
+ if specifier.type_ and specifier.type_ is not Operand:
+ raise TypeError(
+ f"only `operand` field specifier can be used for operand fields"
+ )
return OperandDef(
name,
variadicity,
get_args(type_)[0],
- kw_only=specifier.kw_only(),
+ param_kind=specifier.param_kind(),
+ default_is_none=specifier.default_is_none,
)
elif type_ is ir.Region:
+ if specifier.type_ and specifier.type_ is not Region:
+ raise TypeError(
+ f"only `region` field specifier can be used for region fields"
+ )
return RegionDef(name, variadicity, Any)
- return AttributeDef(name, variadicity, type_)
+
+ if specifier.type_ and specifier.type_ is not Attribute:
+ raise TypeError(
+ f"only `attribute` field specifier can be used for attribute fields"
+ )
+ return AttributeDef(
+ name,
+ variadicity,
+ type_,
+ param_kind=specifier.param_kind(),
+ default_factory=specifier.default_factory,
+ )
@dataclass
class OperandDef(FieldDef):
+ default_is_none: bool = False
+
def __post_init__(self):
- if self.variadicity != Variadicity.optional and self.kw_only:
- raise ValueError(f"only optional operand can be a keyword parameter")
+ if self.variadicity != Variadicity.optional and self.default_is_none:
+ raise ValueError(f"only optional operand can be set to None")
@dataclass
class ResultDef(FieldDef):
infer_type: Callable[[], ir.Type] | None = None
+ default_factory: Optional[Callable[[], Any]] = None
+ default_is_none: bool = False
def __post_init__(self):
- if (
- self.variadicity != Variadicity.optional
- and not self.infer_type
- and self.kw_only
- ):
- raise ValueError(f"only optional result can be a keyword parameter")
+ if self.variadicity != Variadicity.optional and self.default_is_none:
+ raise ValueError(f"only optional result can be set to None")
if self.infer_type and self.variadicity != Variadicity.single:
raise ValueError(
@@ -226,15 +306,33 @@ def process_type(self, type_):
if self.infer_type:
return self.infer_type()
+ if self.default_factory:
+ return self.default_factory()
+
return None
@dataclass
class AttributeDef(FieldDef):
+ default_factory: Optional[Callable[[], Any]] = None
def __post_init__(self):
if self.variadicity != Variadicity.single:
raise ValueError("optional attribute is not currently supported")
+ if (
+ self.param_kind == ParameterKind.KEYWORD_ONLY_WITH_DEFAULT
+ and not self.default_factory
+ ):
+ raise ValueError(f"only optional attribute can be set to None")
+
+ def process_attr(self, attr):
+ if attr:
+ return attr
+
+ if self.default_factory:
+ return self.default_factory()
+
+ return None
@dataclass
@@ -415,10 +513,15 @@ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
for i in args:
- if i.kw_only:
- params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
- else:
- params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+ match i.param_kind:
+ case ParameterKind.POSITIONAL_OR_KEYWORD:
+ params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+ case ParameterKind.KEYWORD_ONLY_WITH_DEFAULT:
+ params.append(
+ Parameter(i.name, Parameter.KEYWORD_ONLY, default=None)
+ )
+ case ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT:
+ params.append(Parameter(i.name, Parameter.KEYWORD_ONLY))
params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
@@ -439,9 +542,7 @@ def __init__(*args, **kwargs):
_operands = [args[operand.name] for operand in operands]
_results = [result.process_type(args[result.name]) for result in results]
_attributes = dict(
- (attr.name, args[attr.name])
- for attr in attrs
- if args[attr.name] is not None
+ (attr.name, attr.process_attr(args[attr.name])) for attr in attrs
)
_regions = len(regions) or None
_ods_successors = None
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..f71d3cd794463 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -734,7 +734,49 @@ class AssignNoneOnNonOptionalOp(
a: Operand[IntegerType[32]] = None
except ValueError as e:
- # CHECK: only optional operand can be a keyword parameter
+ # CHECK: only optional operand can be set to None
+ print(e)
+
+ try:
+
+ class AssignNoneOnnAttributeOp(
+ TestInvalid.Operation, name="assign_none_on_attribute"
+ ):
+ a: IntegerAttr = None
+
+ except ValueError as e:
+ # CHECK: only optional attribute can be set to None
+ print(e)
+
+ try:
+
+ class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
+ a: Result[IntegerType] = result(infer_type=True)
+
+ except TypeError as e:
+ # CHECK: unsupported type for inferring
+ print(e)
+
+ try:
+
+ class WrongFieldSpecifierOp(
+ TestInvalid.Operation, name="wrong_field_specifier"
+ ):
+ a: Result[IntegerType] = operand()
+
+ except TypeError as e:
+ # CHECK: only `result` field specifier can be used for result fields
+ print(e)
+
+ try:
+
+ class WrongFieldSpecifierOp2(
+ TestInvalid.Operation, name="wrong_field_specifier2"
+ ):
+ a: IntegerAttr = operand()
+
+ except TypeError as e:
+ # CHECK: only `attribute` field specifier can be used for attribute fields
print(e)
@@ -777,3 +819,59 @@ class OpWithAttr(TestAttrInOp.Operation, name="op_with_attr"):
# CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> ()
# CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> ()
print(module)
+
+
+@run
+def testExtDialectFieldSpecifiers():
+ class TestFieldSpecifiers(Dialect, name="ext_field_specifiers"):
+ pass
+
+ class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"):
+ a: Operand[IntegerType[32]] = operand()
+ b: Optional[Operand[IntegerType[32]]] = None
+ c: Operand[IntegerType[32]] = operand(kw_only=True)
+
+ class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
+ a: Result[IntegerType[32]] = result()
+ b: Result[IntegerType[16]] = result(infer_type=True)
+ c: Result[IntegerType] = result(
+ default_factory=lambda: IntegerType.get_signless(8)
+ )
+ d: Sequence[Result[IntegerType]] = result(default_factory=list)
+ e: Result[IntegerType[32]] = result(kw_only=True)
+
+ class AttributeSpecifierOp(
+ TestFieldSpecifiers.Operation, name="attribute_specifier"
+ ):
+ a: IntegerAttr = attribute()
+ b: IntegerAttr = attribute(
+ default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42)
+ )
+ b: StringAttr["a"] | StringAttr["b"] = attribute(
+ default_factory=lambda: StringAttr.get("a")
+ )
+ c: IntegerAttr = attribute(kw_only=True)
+
+ with Context(), Location.unknown():
+ TestFieldSpecifiers.load()
+
+ # CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
+ print(OperandSpecifierOp.__init__.__signature__)
+ # CHECK: (self, /, a, *, b=None, c=None, d=None, e, loc=None, ip=None)
+ print(ResultSpecifierOp.__init__.__signature__)
+ # CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
+ print(AttributeSpecifierOp.__init__.__signature__)
+
+ module = Module.create()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+ one = arith.constant(i32, 1)
+ OperandSpecifierOp(one, c=one)
+ ResultSpecifierOp(i32, e=i32)
+ AttributeSpecifierOp(IntegerAttr.get(i32, 42), c=IntegerAttr.get(i32, 100))
+
+ assert module.operation.verify()
+ # CHECK: "ext_field_specifiers.operand_specifier"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> ()
+ # CHECK: %0:4 = "ext_field_specifiers.result_specifier"() {resultSegmentSizes = array<i32: 1, 1, 1, 0, 1>} : () -> (i32, i16, i8, i32)
+ # CHECK: "ext_field_specifiers.attribute_specifier"() {a = 42 : i32, b = "a", c = 100 : i32} : () -> ()
+ print(module)
>From 57cfb98cf0894cc30c6fb40fe279bb76906d53fb Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Wed, 25 Mar 2026 00:12:49 +0800
Subject: [PATCH 2/3] Apply suggestion from @PragmaTwice
---
mlir/python/mlir/dialects/ext.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 7242c65d54492..2c1237d3660aa 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -258,7 +258,7 @@ def from_type_hint(name, type_, specifier) -> "FieldDef":
elif type_ is ir.Region:
if specifier.type_ and specifier.type_ is not Region:
raise TypeError(
- f"only `region` field specifier can be used for region fields"
+ f"currently no field specifier can be used for region fields"
)
return RegionDef(name, variadicity, Any)
>From 8e7c80c72e660dbfdfcec13165a42ea65a25aac4 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 27 Mar 2026 00:11:29 +0800
Subject: [PATCH 3/3] Update mlir/python/mlir/dialects/ext.py
Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
---
mlir/python/mlir/dialects/ext.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 2c1237d3660aa..fe4d3290ffb0b 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -130,7 +130,7 @@ def result(
infer_type: bool = False,
default_factory: Optional[Callable[[], Any]] = None,
kw_only: bool = False,
-) -> Any:
+) -> Result:
"""
A field specifier for `Result` definitions.
"""
More information about the Mlir-commits
mailing list