[Mlir-commits] [mlir] [MLIR][Python] Add more field specifiers to Python-defined operations (PR #188064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 28 01:47:46 PDT 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/188064

>From 17795c8f8e4833444166185fabbab0f6f5d35ec4 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 24 Mar 2026 00:14:13 +0800
Subject: [PATCH 1/6] [MLIR][Python] Add more field specifiers to
 Python-defined operations

---
 mlir/python/mlir/dialects/ext.py | 159 +++++++++++++++++++++++++------
 mlir/test/python/dialects/ext.py | 100 ++++++++++++++++++-
 2 files changed, 229 insertions(+), 30 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index dfcd7f2d641d0..7242c65d54492 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -18,6 +18,7 @@
 from dataclasses import dataclass
 from inspect import Parameter, Signature
 from types import UnionType
+from enum import Enum
 from . import irdl
 from ._ods_common import _cext, segmented_accessor
 from .irdl import Variadicity
@@ -35,6 +36,8 @@
     "Type",
     "Attribute",
     "result",
+    "operand",
+    "attribute",
 ]
 
 Operand = ir.Value
@@ -108,25 +111,70 @@ def _lower(self, type_) -> ir.Value:
 
 @dataclass
 class FieldSpecifier:
+    type_: Any = None
     infer_type: bool = False
     default_is_none: bool = False
+    default_factory: Optional[Callable[[], Any]] = None
+    kw_only: bool = False
 
-    def __post_init__(self):
-        if self.infer_type and self.default_is_none:
-            raise ValueError(
-                "a field cannot be marked with both infer_type and default_is_none"
-            )
-
-    def kw_only(self) -> bool:
-        return self.default_is_none or self.infer_type
+    def param_kind(self):
+        if self.default_is_none or self.default_factory or self.infer_type:
+            return ParameterKind.KEYWORD_ONLY_WITH_DEFAULT
+        if self.kw_only:
+            return ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT
+        return ParameterKind.POSITIONAL_OR_KEYWORD
 
 
-def result(*, infer_type: bool = False) -> Any:
+def result(
+    *,
+    infer_type: bool = False,
+    default_factory: Optional[Callable[[], Any]] = None,
+    kw_only: bool = False,
+) -> Any:
     """
     A field specifier for `Result` definitions.
     """
+    if infer_type and default_factory:
+        raise ValueError(
+            "a result field cannot have both infer_type and default_factory"
+        )
 
-    return FieldSpecifier(infer_type=infer_type)
+    return FieldSpecifier(
+        type_=Result,
+        infer_type=infer_type,
+        default_factory=default_factory,
+        kw_only=kw_only,
+    )
+
+
+def operand(
+    *,
+    kw_only: bool = False,
+) -> Any:
+    """
+    A field specifier for `Operand` definitions.
+    """
+
+    return FieldSpecifier(
+        type_=Operand,
+        kw_only=kw_only,
+    )
+
+
+def attribute(
+    *,
+    default_factory: Optional[Callable[[], Any]] = None,
+    kw_only: bool = False,
+) -> Any:
+    """
+    A field specifier for attribute definitions.
+    """
+
+    return FieldSpecifier(
+        type_=Attribute,
+        default_factory=default_factory,
+        kw_only=kw_only,
+    )
 
 
 def infer_type_impl(type_) -> Callable[[], ir.Type]:
@@ -149,6 +197,12 @@ def infer_type_impl(type_) -> Callable[[], ir.Type]:
     raise TypeError(f"unsupported type for inferring: {type_}")
 
 
+class ParameterKind(Enum):
+    POSITIONAL_OR_KEYWORD = 1
+    KEYWORD_ONLY_WITHOUT_DEFAULT = 2
+    KEYWORD_ONLY_WITH_DEFAULT = 3
+
+
 @dataclass
 class FieldDef:
     """
@@ -159,7 +213,7 @@ class FieldDef:
     variadicity: Variadicity
     constraint: Any
 
-    kw_only: bool = False
+    param_kind: ParameterKind = ParameterKind.POSITIONAL_OR_KEYWORD
 
     @staticmethod
     def from_type_hint(name, type_, specifier) -> "FieldDef":
@@ -173,46 +227,72 @@ def from_type_hint(name, type_, specifier) -> "FieldDef":
 
         origin = get_origin(type_)
         if origin is ir.OpResult:
+            if specifier.type_ and specifier.type_ is not Result:
+                raise TypeError(
+                    f"only `result` field specifier can be used for result fields"
+                )
             constraint = get_args(type_)[0]
             return ResultDef(
                 name,
                 variadicity,
                 constraint,
-                kw_only=specifier.kw_only(),
+                param_kind=specifier.param_kind(),
+                default_factory=specifier.default_factory,
+                default_is_none=specifier.default_is_none,
                 infer_type=(
                     infer_type_impl(constraint) if specifier.infer_type else None
                 ),
             )
         elif origin is ir.Value:
+            if specifier.type_ and specifier.type_ is not Operand:
+                raise TypeError(
+                    f"only `operand` field specifier can be used for operand fields"
+                )
             return OperandDef(
                 name,
                 variadicity,
                 get_args(type_)[0],
-                kw_only=specifier.kw_only(),
+                param_kind=specifier.param_kind(),
+                default_is_none=specifier.default_is_none,
             )
         elif type_ is ir.Region:
+            if specifier.type_ and specifier.type_ is not Region:
+                raise TypeError(
+                    f"only `region` field specifier can be used for region fields"
+                )
             return RegionDef(name, variadicity, Any)
-        return AttributeDef(name, variadicity, type_)
+
+        if specifier.type_ and specifier.type_ is not Attribute:
+            raise TypeError(
+                f"only `attribute` field specifier can be used for attribute fields"
+            )
+        return AttributeDef(
+            name,
+            variadicity,
+            type_,
+            param_kind=specifier.param_kind(),
+            default_factory=specifier.default_factory,
+        )
 
 
 @dataclass
 class OperandDef(FieldDef):
+    default_is_none: bool = False
+
     def __post_init__(self):
-        if self.variadicity != Variadicity.optional and self.kw_only:
-            raise ValueError(f"only optional operand can be a keyword parameter")
+        if self.variadicity != Variadicity.optional and self.default_is_none:
+            raise ValueError(f"only optional operand can be set to None")
 
 
 @dataclass
 class ResultDef(FieldDef):
     infer_type: Callable[[], ir.Type] | None = None
+    default_factory: Optional[Callable[[], Any]] = None
+    default_is_none: bool = False
 
     def __post_init__(self):
-        if (
-            self.variadicity != Variadicity.optional
-            and not self.infer_type
-            and self.kw_only
-        ):
-            raise ValueError(f"only optional result can be a keyword parameter")
+        if self.variadicity != Variadicity.optional and self.default_is_none:
+            raise ValueError(f"only optional result can be set to None")
 
         if self.infer_type and self.variadicity != Variadicity.single:
             raise ValueError(
@@ -226,15 +306,33 @@ def process_type(self, type_):
         if self.infer_type:
             return self.infer_type()
 
+        if self.default_factory:
+            return self.default_factory()
+
         return None
 
 
 @dataclass
 class AttributeDef(FieldDef):
+    default_factory: Optional[Callable[[], Any]] = None
 
     def __post_init__(self):
         if self.variadicity != Variadicity.single:
             raise ValueError("optional attribute is not currently supported")
+        if (
+            self.param_kind == ParameterKind.KEYWORD_ONLY_WITH_DEFAULT
+            and not self.default_factory
+        ):
+            raise ValueError(f"only optional attribute can be set to None")
+
+    def process_attr(self, attr):
+        if attr:
+            return attr
+
+        if self.default_factory:
+            return self.default_factory()
+
+        return None
 
 
 @dataclass
@@ -415,10 +513,15 @@ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
         params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
 
         for i in args:
-            if i.kw_only:
-                params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
-            else:
-                params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+            match i.param_kind:
+                case ParameterKind.POSITIONAL_OR_KEYWORD:
+                    params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
+                case ParameterKind.KEYWORD_ONLY_WITH_DEFAULT:
+                    params.append(
+                        Parameter(i.name, Parameter.KEYWORD_ONLY, default=None)
+                    )
+                case ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT:
+                    params.append(Parameter(i.name, Parameter.KEYWORD_ONLY))
 
         params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
         params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
@@ -439,9 +542,7 @@ def __init__(*args, **kwargs):
             _operands = [args[operand.name] for operand in operands]
             _results = [result.process_type(args[result.name]) for result in results]
             _attributes = dict(
-                (attr.name, args[attr.name])
-                for attr in attrs
-                if args[attr.name] is not None
+                (attr.name, attr.process_attr(args[attr.name])) for attr in attrs
             )
             _regions = len(regions) or None
             _ods_successors = None
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 78c74684cef77..f71d3cd794463 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -734,7 +734,49 @@ class AssignNoneOnNonOptionalOp(
             a: Operand[IntegerType[32]] = None
 
     except ValueError as e:
-        # CHECK: only optional operand can be a keyword parameter
+        # CHECK: only optional operand can be set to None
+        print(e)
+
+    try:
+
+        class AssignNoneOnnAttributeOp(
+            TestInvalid.Operation, name="assign_none_on_attribute"
+        ):
+            a: IntegerAttr = None
+
+    except ValueError as e:
+        # CHECK: only optional attribute can be set to None
+        print(e)
+
+    try:
+
+        class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
+            a: Result[IntegerType] = result(infer_type=True)
+
+    except TypeError as e:
+        # CHECK: unsupported type for inferring
+        print(e)
+
+    try:
+
+        class WrongFieldSpecifierOp(
+            TestInvalid.Operation, name="wrong_field_specifier"
+        ):
+            a: Result[IntegerType] = operand()
+
+    except TypeError as e:
+        # CHECK: only `result` field specifier can be used for result fields
+        print(e)
+
+    try:
+
+        class WrongFieldSpecifierOp2(
+            TestInvalid.Operation, name="wrong_field_specifier2"
+        ):
+            a: IntegerAttr = operand()
+
+    except TypeError as e:
+        # CHECK: only `attribute` field specifier can be used for attribute fields
         print(e)
 
 
@@ -777,3 +819,59 @@ class OpWithAttr(TestAttrInOp.Operation, name="op_with_attr"):
         # CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> ()
         # CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> ()
         print(module)
+
+
+@run
+def testExtDialectFieldSpecifiers():
+    class TestFieldSpecifiers(Dialect, name="ext_field_specifiers"):
+        pass
+
+    class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"):
+        a: Operand[IntegerType[32]] = operand()
+        b: Optional[Operand[IntegerType[32]]] = None
+        c: Operand[IntegerType[32]] = operand(kw_only=True)
+
+    class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
+        a: Result[IntegerType[32]] = result()
+        b: Result[IntegerType[16]] = result(infer_type=True)
+        c: Result[IntegerType] = result(
+            default_factory=lambda: IntegerType.get_signless(8)
+        )
+        d: Sequence[Result[IntegerType]] = result(default_factory=list)
+        e: Result[IntegerType[32]] = result(kw_only=True)
+
+    class AttributeSpecifierOp(
+        TestFieldSpecifiers.Operation, name="attribute_specifier"
+    ):
+        a: IntegerAttr = attribute()
+        b: IntegerAttr = attribute(
+            default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42)
+        )
+        b: StringAttr["a"] | StringAttr["b"] = attribute(
+            default_factory=lambda: StringAttr.get("a")
+        )
+        c: IntegerAttr = attribute(kw_only=True)
+
+    with Context(), Location.unknown():
+        TestFieldSpecifiers.load()
+
+        # CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
+        print(OperandSpecifierOp.__init__.__signature__)
+        # CHECK: (self, /, a, *, b=None, c=None, d=None, e, loc=None, ip=None)
+        print(ResultSpecifierOp.__init__.__signature__)
+        # CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
+        print(AttributeSpecifierOp.__init__.__signature__)
+
+        module = Module.create()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+            one = arith.constant(i32, 1)
+            OperandSpecifierOp(one, c=one)
+            ResultSpecifierOp(i32, e=i32)
+            AttributeSpecifierOp(IntegerAttr.get(i32, 42), c=IntegerAttr.get(i32, 100))
+
+        assert module.operation.verify()
+        # CHECK: "ext_field_specifiers.operand_specifier"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> ()
+        # CHECK: %0:4 = "ext_field_specifiers.result_specifier"() {resultSegmentSizes = array<i32: 1, 1, 1, 0, 1>} : () -> (i32, i16, i8, i32)
+        # CHECK: "ext_field_specifiers.attribute_specifier"() {a = 42 : i32, b = "a", c = 100 : i32} : () -> ()
+        print(module)

>From 57cfb98cf0894cc30c6fb40fe279bb76906d53fb Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Wed, 25 Mar 2026 00:12:49 +0800
Subject: [PATCH 2/6] Apply suggestion from @PragmaTwice

---
 mlir/python/mlir/dialects/ext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 7242c65d54492..2c1237d3660aa 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -258,7 +258,7 @@ def from_type_hint(name, type_, specifier) -> "FieldDef":
         elif type_ is ir.Region:
             if specifier.type_ and specifier.type_ is not Region:
                 raise TypeError(
-                    f"only `region` field specifier can be used for region fields"
+                    f"currently no field specifier can be used for region fields"
                 )
             return RegionDef(name, variadicity, Any)
 

>From 8e7c80c72e660dbfdfcec13165a42ea65a25aac4 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 27 Mar 2026 00:11:29 +0800
Subject: [PATCH 3/6] Update mlir/python/mlir/dialects/ext.py

Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
---
 mlir/python/mlir/dialects/ext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 2c1237d3660aa..fe4d3290ffb0b 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -130,7 +130,7 @@ def result(
     infer_type: bool = False,
     default_factory: Optional[Callable[[], Any]] = None,
     kw_only: bool = False,
-) -> Any:
+) -> Result:
     """
     A field specifier for `Result` definitions.
     """

>From 740f96b56df8fcd3df142fef30ab37e80156a5dd Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 27 Mar 2026 00:11:41 +0800
Subject: [PATCH 4/6] Update mlir/python/mlir/dialects/ext.py

Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
---
 mlir/python/mlir/dialects/ext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index fe4d3290ffb0b..362cea6097080 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -150,7 +150,7 @@ def result(
 def operand(
     *,
     kw_only: bool = False,
-) -> Any:
+) -> Operand:
     """
     A field specifier for `Operand` definitions.
     """

>From fc681e9d7a2afc0548701d1fd4d631e0b9fffa47 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Fri, 27 Mar 2026 00:13:01 +0800
Subject: [PATCH 5/6] Update mlir/python/mlir/dialects/ext.py

Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
---
 mlir/python/mlir/dialects/ext.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 362cea6097080..8537b3d36f993 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -165,7 +165,7 @@ def attribute(
     *,
     default_factory: Optional[Callable[[], Any]] = None,
     kw_only: bool = False,
-) -> Any:
+) -> ir.Attribute:
     """
     A field specifier for attribute definitions.
     """

>From c1623f9372e8185177fdb6967a6883922c25c277 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 28 Mar 2026 16:47:24 +0800
Subject: [PATCH 6/6] fix typo

---
 mlir/test/python/dialects/ext.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f71d3cd794463..0d42f6c325c26 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -847,10 +847,10 @@ class AttributeSpecifierOp(
         b: IntegerAttr = attribute(
             default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42)
         )
-        b: StringAttr["a"] | StringAttr["b"] = attribute(
+        c: StringAttr["a"] | StringAttr["b"] = attribute(
             default_factory=lambda: StringAttr.get("a")
         )
-        c: IntegerAttr = attribute(kw_only=True)
+        d: IntegerAttr = attribute(kw_only=True)
 
     with Context(), Location.unknown():
         TestFieldSpecifiers.load()
@@ -859,7 +859,7 @@ class AttributeSpecifierOp(
         print(OperandSpecifierOp.__init__.__signature__)
         # CHECK: (self, /, a, *, b=None, c=None, d=None, e, loc=None, ip=None)
         print(ResultSpecifierOp.__init__.__signature__)
-        # CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
+        # CHECK: (self, /, a, *, b=None, c=None, d, loc=None, ip=None)
         print(AttributeSpecifierOp.__init__.__signature__)
 
         module = Module.create()
@@ -868,10 +868,10 @@ class AttributeSpecifierOp(
             one = arith.constant(i32, 1)
             OperandSpecifierOp(one, c=one)
             ResultSpecifierOp(i32, e=i32)
-            AttributeSpecifierOp(IntegerAttr.get(i32, 42), c=IntegerAttr.get(i32, 100))
+            AttributeSpecifierOp(IntegerAttr.get(i32, 43), d=IntegerAttr.get(i32, 100))
 
         assert module.operation.verify()
         # CHECK: "ext_field_specifiers.operand_specifier"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> ()
         # CHECK: %0:4 = "ext_field_specifiers.result_specifier"() {resultSegmentSizes = array<i32: 1, 1, 1, 0, 1>} : () -> (i32, i16, i8, i32)
-        # CHECK: "ext_field_specifiers.attribute_specifier"() {a = 42 : i32, b = "a", c = 100 : i32} : () -> ()
+        # CHECK: "ext_field_specifiers.attribute_specifier"() {a = 43 : i32, b = 42 : i32, c = "a", d = 100 : i32} : () -> ()
         print(module)



More information about the Mlir-commits mailing list