[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining dialects in Python bindings (PR #169045)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 9 05:41:28 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/169045

>From 3ec4809f2b5a76d14eba1a0707e30791cc5cc805 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 19 Nov 2025 01:04:38 +0800
Subject: [PATCH 01/20] [MLIR][Python] Add a DSL for defining IRDL dialects in
 Python bindings

---
 mlir/python/CMakeLists.txt                    |   4 +-
 .../dialects/{irdl.py => irdl/__init__.py}    |  13 +-
 mlir/python/mlir/dialects/irdl/dsl.py         | 343 ++++++++++++++++++
 mlir/test/python/dialects/irdsl.py            | 308 ++++++++++++++++
 4 files changed, 661 insertions(+), 7 deletions(-)
 rename mlir/python/mlir/dialects/{irdl.py => irdl/__init__.py} (91%)
 create mode 100644 mlir/python/mlir/dialects/irdl/dsl.py
 create mode 100644 mlir/test/python/dialects/irdsl.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..27b87a8b70144 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/IRDLOps.td
-  SOURCES dialects/irdl.py
+  SOURCES
+    dialects/irdl/__init__.py
+    dialects/irdl/dsl.py
   DIALECT_NAME irdl
   GEN_ENUM_BINDINGS
 )
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl.py
rename to mlir/python/mlir/dialects/irdl/__init__.py
index 1ec951b69b646..6b2787ed7966c 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl/__init__.py
@@ -2,13 +2,14 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ._irdl_ops_gen import *
-from ._irdl_ops_gen import _Dialect
-from ._irdl_enum_gen import *
-from .._mlir_libs._mlirDialectsIRDL import *
-from ..ir import register_attribute_builder
-from ._ods_common import _cext as _ods_cext
+from .._irdl_ops_gen import *
+from .._irdl_ops_gen import _Dialect
+from .._irdl_enum_gen import *
+from ..._mlir_libs._mlirDialectsIRDL import *
+from ...ir import register_attribute_builder
+from .._ods_common import _cext as _ods_cext
 from typing import Union, Sequence
+from . import dsl
 
 _ods_ir = _ods_cext.ir
 
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
new file mode 100644
index 0000000000000..3cc234503665a
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -0,0 +1,343 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...dialects import irdl as _irdl
+from .._ods_common import (
+    _cext as _ods_cext,
+    segmented_accessor as _ods_segmented_accessor,
+)
+from . import Variadicity
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter as _Parameter, Signature as _Signature
+from types import SimpleNamespace as _SimpleNameSpace
+
+_ods_ir = _ods_cext.ir
+
+
+class ConstraintExpr:
+    def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
+        raise NotImplementedError()
+
+    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AnyOf(self, other)
+
+    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AllOf(self, other)
+
+
+class ConstraintLoweringContext:
+    def __init__(self):
+        # Cache so that the same ConstraintExpr instance reuses its SSA value.
+        self._cache: Dict[int, _ods_ir.Value] = {}
+
+    def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+        key = id(expr)
+        if key in self._cache:
+            return self._cache[key]
+        v = expr._lower(self)
+        self._cache[key] = v
+        return v
+
+
+class Is(ConstraintExpr):
+    def __init__(self, attr: _ods_ir.Attribute):
+        self.attr = attr
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.is_(self.attr)
+
+
+class IsType(Is):
+    def __init__(self, typ: _ods_ir.Type):
+        super().__init__(_ods_ir.TypeAttr.get(typ))
+
+
+class AnyOf(ConstraintExpr):
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class AllOf(ConstraintExpr):
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class Any(ConstraintExpr):
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.any()
+
+
+class BaseName(ConstraintExpr):
+    def __init__(self, name: str):
+        self.name = name
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.base(base_name=self.name)
+
+
+class BaseRef(ConstraintExpr):
+    def __init__(self, ref):
+        self.ref = ref
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.base(base_ref=self.ref)
+
+
+class FieldDef:
+    def __set_name__(self, owner, name: str):
+        self.name = name
+
+
+ at dataclass
+class Operand(FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+ at dataclass
+class Result(FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+ at dataclass
+class Attribute(FieldDef):
+    constraint: ConstraintExpr
+
+    def __post_init__(self):
+        # just for unified processing,
+        # currently optional attribute is not supported by IRDL
+        self.variadicity = Variadicity.single
+
+
+ at dataclass
+class Operation:
+    dialect_name: str
+    name: str
+    # We store operands and attributes into one list to maintain relative orders
+    # among them for generating OpView class.
+    operands_and_attrs: List[Union[Operand, Attribute]]
+    results: List[Result]
+
+    def _emit(self) -> None:
+        op = _irdl.operation_(self.name)
+        ctx = ConstraintLoweringContext()
+
+        operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
+        attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
+
+        with _ods_ir.InsertionPoint(op.body):
+            if operands:
+                _irdl.operands_(
+                    [ctx.lower(i.constraint) for i in operands],
+                    [i.name for i in operands],
+                    [i.variadicity for i in operands],
+                )
+            if attrs:
+                _irdl.attributes_(
+                    [ctx.lower(i.constraint) for i in attrs],
+                    [i.name for i in attrs],
+                )
+            if self.results:
+                _irdl.results_(
+                    [ctx.lower(i.constraint) for i in self.results],
+                    [i.name for i in self.results],
+                    [i.variadicity for i in self.results],
+                )
+
+    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+        operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
+        attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
+
+        def variadicity_to_segment(variadicity: Variadicity) -> int:
+            if variadicity == Variadicity.variadic:
+                return -1
+            if variadicity == Variadicity.optional:
+                return 0
+            return 1
+
+        operand_segments = None
+        if any(i.variadicity != Variadicity.single for i in operands):
+            operand_segments = [variadicity_to_segment(i.variadicity) for i in operands]
+
+        result_segments = None
+        if any(i.variadicity != Variadicity.single for i in self.results):
+            result_segments = [
+                variadicity_to_segment(i.variadicity) for i in self.results
+            ]
+
+        args = self.results + self.operands_and_attrs
+        positional_args = [
+            i.name for i in args if i.variadicity != Variadicity.optional
+        ]
+        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+
+        params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+        for i in positional_args:
+            params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+        for i in optional_args:
+            params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
+
+        sig = _Signature(params)
+        op = self
+
+        class _OpView(_ods_ir.OpView):
+            OPERATION_NAME = f"{op.dialect_name}.{op.name}"
+            _ODS_REGIONS = (0, True)
+            _ODS_OPERAND_SEGMENTS = operand_segments
+            _ODS_RESULT_SEGMENTS = result_segments
+
+            def __init__(*args, **kwargs):
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                args = bound.arguments
+
+                _operands = [args[operand.name] for operand in operands]
+                _results = [args[result.name] for result in op.results]
+                _attributes = dict(
+                    (attr.name, args[attr.name])
+                    for attr in attrs
+                    if args[attr.name] is not None
+                )
+                _regions = None
+                _ods_successors = None
+                self = args["self"]
+                super(_OpView, self).__init__(
+                    self.OPERATION_NAME,
+                    self._ODS_REGIONS,
+                    self._ODS_OPERAND_SEGMENTS,
+                    self._ODS_RESULT_SEGMENTS,
+                    attributes=_attributes,
+                    results=_results,
+                    operands=_operands,
+                    successors=_ods_successors,
+                    regions=_regions,
+                    loc=args["loc"],
+                    ip=args["ip"],
+                )
+
+            __init__.__signature__ = sig
+
+        for attr in attrs:
+            setattr(
+                _OpView,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+        def value_range_getter(
+            value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
+            variadicity: Variadicity,
+        ):
+            if variadicity == Variadicity.single:
+                return value_range[0]
+            if variadicity == Variadicity.optional:
+                return value_range[0] if len(value_range) > 0 else None
+            return value_range
+
+        for i, operand in enumerate(operands):
+            if operand_segments:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = _ods_segmented_accessor(
+                        self.operation.operands,
+                        self.operation.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(operand_range, operand.variadicity)
+
+                setattr(_OpView, operand.name, property(getter))
+            else:
+                setattr(
+                    _OpView, operand.name, property(lambda self, i=i: self.operands[i])
+                )
+        for i, result in enumerate(self.results):
+            if result_segments:
+
+                def getter(self, i=i, result=result):
+                    result_range = _ods_segmented_accessor(
+                        self.operation.results,
+                        self.operation.attributes["resultSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(result_range, result.variadicity)
+
+                setattr(_OpView, result.name, property(getter))
+            else:
+                setattr(
+                    _OpView, result.name, property(lambda self, i=i: self.results[i])
+                )
+
+        def _builder(*args, **kwargs) -> _OpView:
+            return _OpView(*args, **kwargs)
+
+        _builder.__signature__ = _Signature(params[1:])
+
+        return _OpView, _builder
+
+
+class Dialect:
+    def __init__(self, name: str):
+        self.name = name
+        self.operations: List[Operation] = []
+        self.namespace = _SimpleNameSpace()
+
+    def _emit(self) -> None:
+        d = _irdl.dialect(self.name)
+        with _ods_ir.InsertionPoint(d.body):
+            for op in self.operations:
+                op._emit()
+
+    def _make_module(self) -> _ods_ir.Module:
+        with _ods_ir.Location.unknown():
+            m = _ods_ir.Module.create()
+            with _ods_ir.InsertionPoint(m.body):
+                self._emit()
+        return m
+
+    def _make_dialect_class(self) -> type:
+        class _Dialect(_ods_ir.Dialect):
+            DIALECT_NAMESPACE = self.name
+
+        return _Dialect
+
+    def load(self) -> _SimpleNameSpace:
+        _irdl.load_dialects(self._make_module())
+        dialect_class = self._make_dialect_class()
+        _ods_cext.register_dialect(dialect_class)
+        for op in self.operations:
+            _ods_cext.register_operation(dialect_class)(op.op_view)
+        return self.namespace
+
+    def op(self, name: str) -> Callable[[type], type]:
+        def decorator(cls: type) -> type:
+            operands_and_attrs: List[Union[Operand, Attribute]] = []
+            results: List[Result] = []
+
+            for field in cls.__dict__.values():
+                if isinstance(field, Operand) or isinstance(field, Attribute):
+                    operands_and_attrs.append(field)
+                elif isinstance(field, Result):
+                    results.append(field)
+
+            op_def = Operation(self.name, name, operands_and_attrs, results)
+            op_view, builder = op_def._make_op_view_and_builder()
+            setattr(op_def, "op_view", op_view)
+            setattr(op_def, "builder", builder)
+            self.operations.append(op_def)
+            self.namespace.__dict__[cls.__name__] = op_view
+            op_view.__name__ = cls.__name__
+            self.namespace.__dict__[name.replace(".", "_")] = builder
+            return cls
+
+        return decorator
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
new file mode 100644
index 0000000000000..8ef30ae0a4c13
--- /dev/null
+++ b/mlir/test/python/dialects/irdsl.py
@@ -0,0 +1,308 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import dsl as irdsl
+from mlir.dialects import arith
+import sys
+
+
+def run(f):
+    print("\nTEST:", f.__name__, file=sys.stderr)
+    with Context():
+        f()
+
+
+# CHECK: TEST: testMyInt
+ at run
+def testMyInt():
+    myint = irdsl.Dialect("myint")
+    iattr = irdsl.BaseName("#builtin.integer")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+
+    @myint.op("constant")
+    class ConstantOp:
+        value = irdsl.Attribute(iattr)
+        cst = irdsl.Result(i32)
+
+    @myint.op("add")
+    class AddOp:
+        lhs = irdsl.Operand(i32)
+        rhs = irdsl.Operand(i32)
+        res = irdsl.Result(i32)
+
+    # CHECK: irdl.dialect @myint {
+    # CHECK:   irdl.operation @constant {
+    # CHECK:     %0 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"value" = %0}
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     irdl.results(cst: %1)
+    # CHECK:   }
+    # CHECK:   irdl.operation @add {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(lhs: %0, rhs: %0)
+    # CHECK:     irdl.results(res: %0)
+    # CHECK:   }
+    # CHECK: }
+    print(myint._make_module())
+    myint = myint.load()
+
+    # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
+    print([i for i in myint.__dict__.keys()])
+
+    i32 = IntegerType.get_signless(32)
+    with Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            two = myint.constant(i32, IntegerAttr.get(i32, 2))
+            three = myint.constant(i32, IntegerAttr.get(i32, 3))
+            add1 = myint.add(i32, two, three)
+            add2 = myint.add(i32, add1, two)
+            add3 = myint.add(i32, add2, three)
+
+    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+    # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+    # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+    # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+    print(module)
+    assert module.operation.verify()
+
+    # CHECK: AddOp
+    print(type(add1).__name__)
+    # CHECK: ConstantOp
+    print(type(two).__name__)
+    # CHECK: myint.add
+    print(add1.OPERATION_NAME)
+    # CHECK: None
+    print(add1._ODS_OPERAND_SEGMENTS)
+    # CHECK: None
+    print(add1._ODS_RESULT_SEGMENTS)
+    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+    print(add1.lhs.owner)
+    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+    print(add1.rhs.owner)
+    # CHECK: 2 : i32
+    print(two.value)
+    # CHECK: Value(%0
+    print(two.cst)
+    # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
+    print(myint.add.__signature__)
+    # CHECK: (cst, value, *, loc=None, ip=None)
+    print(myint.constant.__signature__)
+
+
+ at run
+def testIRDSL():
+    test = irdsl.Dialect("irdsl_test")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+    i64 = irdsl.IsType(IntegerType.get_signless(64))
+    i32or64 = i32 | i64
+    any = irdsl.Any()
+    f32 = irdsl.IsType(F32Type.get())
+    iattr = irdsl.BaseName("#builtin.integer")
+    fattr = irdsl.BaseName("#builtin.float")
+
+    @test.op("constraint")
+    class ConstraintOp:
+        a = irdsl.Operand(i32or64)
+        b = irdsl.Operand(any)
+        c = irdsl.Operand(f32 | i32)
+        d = irdsl.Operand(any)
+        x = irdsl.Attribute(iattr)
+        y = irdsl.Attribute(fattr)
+
+    @test.op("optional")
+    class OptionalOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        out1 = irdsl.Result(i32)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out3 = irdsl.Result(i32)
+
+    @test.op("optional2")
+    class Optional2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        b = irdsl.Result(i32, irdsl.Variadicity.optional)
+
+    @test.op("variadic")
+    class VariadicOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out4 = irdsl.Result(i32)
+
+    @test.op("variadic2")
+    class Variadic2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+
+    @test.op("mixed")
+    class MixedOp:
+        out = irdsl.Result(i32)
+        in1 = irdsl.Operand(i32)
+        in2 = irdsl.Attribute(iattr)
+        in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        in4 = irdsl.Attribute(iattr)
+        in5 = irdsl.Operand(i32)
+
+    # CHECK: irdl.dialect @irdsl_test {
+    # CHECK:   irdl.operation @constraint {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     %1 = irdl.is i64
+    # CHECK:     %2 = irdl.any_of(%0, %1)
+    # CHECK:     %3 = irdl.any
+    # CHECK:     %4 = irdl.is f32
+    # CHECK:     %5 = irdl.any_of(%4, %0)
+    # CHECK:     irdl.operands(a: %2, b: %3, c: %5, d: %3)
+    # CHECK:     %6 = irdl.base "#builtin.integer"
+    # CHECK:     %7 = irdl.base "#builtin.float"
+    # CHECK:     irdl.attributes {"x" = %6, "y" = %7}
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0)
+    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: optional %0)
+    # CHECK:     irdl.results(b: optional %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
+    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: variadic %0)
+    # CHECK:     irdl.results(b: variadic %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @mixed {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK:     %1 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"in2" = %1, "in4" = %1}
+    # CHECK:     irdl.results(out: %0)
+    # CHECK:   }
+    # CHECK: }
+    print(test._make_module())
+    test = test.load()
+
+    # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None)
+    print(test.constraint.__signature__)
+    # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+    print(test.optional.__signature__)
+    # CHECK: (*, b=None, a=None, loc=None, ip=None)
+    print(test.optional2.__signature__)
+    # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+    print(test.variadic.__signature__)
+    # CHECK: (b, a, *, loc=None, ip=None)
+    print(test.variadic2.__signature__)
+    # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+    print(test.mixed.__signature__)
+
+    # CHECK: None None
+    print(
+        test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS
+    )
+    # CHECK: [1, 0] [1, 0, 1]
+    print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS)
+    # CHECK: [0] [0]
+    print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS)
+    # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+    print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS)
+    # CHECK: [-1] [-1]
+    print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS)
+
+    i32 = IntegerType.get_signless(32)
+    i64 = IntegerType.get_signless(64)
+    f32 = F32Type.get()
+
+    with Location.unknown():
+        iattr = IntegerAttr.get(i32, 2)
+        fattr = FloatAttr.get_f32(2.3)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            ione = arith.constant(i32, 1)
+            fone = arith.constant(f32, 1.2)
+
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
+            c1 = test.constraint(ione, ione, fone, ione, iattr, fattr)
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
+            test.constraint(ione, fone, fone, fone, iattr, fattr)
+            # CHECK: irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
+            test.constraint(ione, fone, ione, fone, iattr, fattr)
+
+            # CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+            o1 = test.optional(i32, i32, ione)
+            # CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
+            o2 = test.optional(i32, i32, ione, out2=i32, b=ione)
+            # CHECK: irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            o3 = test.optional2()
+            # CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
+            o4 = test.optional2(b=i32)
+            # CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
+            o5 = test.optional2(a=ione)
+            # CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
+            o6 = test.optional2(b=i32, a=ione)
+
+            # CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+            v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione])
+            # CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
+            v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
+            # CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+            v3 = test.variadic([i32, i32], [i32], i32, ione, [])
+            # CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            v4 = test.variadic2([], [])
+            # CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
+            v5 = test.variadic2([], [ione, ione, ione])
+            # CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
+            v6 = test.variadic2([i32, i32], [ione])
+
+            # CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
+            m1 = test.mixed(i32, ione, iattr, iattr, ione)
+            # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
+            m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione)
+
+    print(module)
+    assert module.operation.verify()
+
+    # CHECK: Value(%c1_i32
+    print(c1.a)
+    # CHECK: 2 : i32
+    print(c1.x)
+    # CHECK: Value(%c1_i32
+    print(o1.a)
+    # CHECK: None
+    print(o1.b)
+    # CHECK: Value(%c1_i32
+    print(o2.b)
+    # CHECK: 0
+    print(o1.out1.result_number)
+    # CHECK: None
+    print(o1.out2)
+    # CHECK: 0
+    print(o2.out1.result_number)
+    # CHECK: 1
+    print(o2.out2.result_number)
+    # CHECK: None
+    print(o3.a)
+    # CHECK: Value(%c1_i32
+    print(o5.a)
+    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
+    print([str(i) for i in v1.c])
+    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
+    print([str(i) for i in v2.c])
+    # CHECK: []
+    print([str(i) for i in v3.c])
+    # CHECK: 0 0
+    print(len(v4.a), len(v4.b))
+    # CHECK: 3 0
+    print(len(v5.a), len(v5.b))
+    # CHECK: 1 2
+    print(len(v6.a), len(v6.b))

>From c0b8ba70e97dd09179ba75c6ae5dde04d3be152e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 12:55:21 +0800
Subject: [PATCH 02/20] make FieldDef private

---
 mlir/python/mlir/dialects/irdl/dsl.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 3cc234503665a..b1916fdbd5f9b 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -91,25 +91,25 @@ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
         return _irdl.base(base_ref=self.ref)
 
 
-class FieldDef:
+class _FieldDef:
     def __set_name__(self, owner, name: str):
         self.name = name
 
 
 @dataclass
-class Operand(FieldDef):
+class Operand(_FieldDef):
     constraint: ConstraintExpr
     variadicity: Variadicity = Variadicity.single
 
 
 @dataclass
-class Result(FieldDef):
+class Result(_FieldDef):
     constraint: ConstraintExpr
     variadicity: Variadicity = Variadicity.single
 
 
 @dataclass
-class Attribute(FieldDef):
+class Attribute(_FieldDef):
     constraint: ConstraintExpr
 
     def __post_init__(self):

>From 7de9ad40b6c293efc2bbb8f51140b82ad17522b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 13:18:07 +0800
Subject: [PATCH 03/20] refactor some methods

---
 mlir/python/mlir/dialects/irdl/dsl.py | 63 ++++++++++++++++-----------
 1 file changed, 37 insertions(+), 26 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index b1916fdbd5f9b..307609c1b342e 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -127,13 +127,16 @@ class Operation:
     operands_and_attrs: List[Union[Operand, Attribute]]
     results: List[Result]
 
-    def _emit(self) -> None:
-        op = _irdl.operation_(self.name)
-        ctx = ConstraintLoweringContext()
-
+    def _partition_operands_and_attrs(self) -> Tuple[List[Operand], List[Attribute]]:
         operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
         attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
+        return operands, attrs
+
+    def _emit(self) -> None:
+        ctx = ConstraintLoweringContext()
+        operands, attrs = self._partition_operands_and_attrs()
 
+        op = _irdl.operation_(self.name)
         with _ods_ir.InsertionPoint(op.body):
             if operands:
                 _irdl.operands_(
@@ -153,27 +156,26 @@ def _emit(self) -> None:
                     [i.variadicity for i in self.results],
                 )
 
-    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
-        operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
-        attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
-
-        def variadicity_to_segment(variadicity: Variadicity) -> int:
-            if variadicity == Variadicity.variadic:
-                return -1
-            if variadicity == Variadicity.optional:
-                return 0
-            return 1
-
-        operand_segments = None
-        if any(i.variadicity != Variadicity.single for i in operands):
-            operand_segments = [variadicity_to_segment(i.variadicity) for i in operands]
-
-        result_segments = None
-        if any(i.variadicity != Variadicity.single for i in self.results):
-            result_segments = [
-                variadicity_to_segment(i.variadicity) for i in self.results
+    @staticmethod
+    def _variadicity_to_segment(variadicity: Variadicity) -> int:
+        if variadicity == Variadicity.variadic:
+            return -1
+        if variadicity == Variadicity.optional:
+            return 0
+        return 1
+
+    @staticmethod
+    def _generate_segments(
+        operands_or_results: List[Union[Operand, Result]],
+    ) -> List[int]:
+        if any(i.variadicity != Variadicity.single for i in operands_or_results):
+            return [
+                Operation._variadicity_to_segment(i.variadicity)
+                for i in operands_or_results
             ]
+        return None
 
+    def _generate_init_params(self) -> List[_Parameter]:
         args = self.results + self.operands_and_attrs
         positional_args = [
             i.name for i in args if i.variadicity != Variadicity.optional
@@ -188,7 +190,16 @@ def variadicity_to_segment(variadicity: Variadicity) -> int:
         params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
         params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
 
-        sig = _Signature(params)
+        return params
+
+    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+        operands, attrs = self._partition_operands_and_attrs()
+
+        operand_segments = Operation._generate_segments(operands)
+        result_segments = Operation._generate_segments(self.results)
+
+        params = self._generate_init_params()
+        init_sig = _Signature(params)
         op = self
 
         class _OpView(_ods_ir.OpView):
@@ -198,7 +209,7 @@ class _OpView(_ods_ir.OpView):
             _ODS_RESULT_SEGMENTS = result_segments
 
             def __init__(*args, **kwargs):
-                bound = sig.bind(*args, **kwargs)
+                bound = init_sig.bind(*args, **kwargs)
                 bound.apply_defaults()
                 args = bound.arguments
 
@@ -226,7 +237,7 @@ def __init__(*args, **kwargs):
                     ip=args["ip"],
                 )
 
-            __init__.__signature__ = sig
+            __init__.__signature__ = init_sig
 
         for attr in attrs:
             setattr(

>From 985cd25e1486ea4e8a7a67f76daf835172c58112 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:05:26 +0800
Subject: [PATCH 04/20] more refactor

---
 mlir/python/mlir/dialects/irdl/dsl.py | 47 ++++++++++++++-------------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 307609c1b342e..0a30678e1318c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -124,17 +124,17 @@ class Operation:
     name: str
     # We store operands and attributes into one list to maintain relative orders
     # among them for generating OpView class.
-    operands_and_attrs: List[Union[Operand, Attribute]]
-    results: List[Result]
+    fields: List[Union[Operand, Attribute, Result]]
 
-    def _partition_operands_and_attrs(self) -> Tuple[List[Operand], List[Attribute]]:
-        operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
-        attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
-        return operands, attrs
+    def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+        operands = [i for i in self.fields if isinstance(i, Operand)]
+        attrs = [i for i in self.fields if isinstance(i, Attribute)]
+        results = [i for i in self.fields if isinstance(i, Result)]
+        return operands, attrs, results
 
     def _emit(self) -> None:
         ctx = ConstraintLoweringContext()
-        operands, attrs = self._partition_operands_and_attrs()
+        operands, attrs, results = self._partition_fields()
 
         op = _irdl.operation_(self.name)
         with _ods_ir.InsertionPoint(op.body):
@@ -149,11 +149,11 @@ def _emit(self) -> None:
                     [ctx.lower(i.constraint) for i in attrs],
                     [i.name for i in attrs],
                 )
-            if self.results:
+            if results:
                 _irdl.results_(
-                    [ctx.lower(i.constraint) for i in self.results],
-                    [i.name for i in self.results],
-                    [i.variadicity for i in self.results],
+                    [ctx.lower(i.constraint) for i in results],
+                    [i.name for i in results],
+                    [i.variadicity for i in results],
                 )
 
     @staticmethod
@@ -176,7 +176,11 @@ def _generate_segments(
         return None
 
     def _generate_init_params(self) -> List[_Parameter]:
-        args = self.results + self.operands_and_attrs
+        # results are placed at the beginning of the parameter list,
+        # but operands and attributes can appear in any relative order.
+        args = [i for i in self.fields if isinstance(i, Result)] + [
+            i for i in self.fields if not isinstance(i, Result)
+        ]
         positional_args = [
             i.name for i in args if i.variadicity != Variadicity.optional
         ]
@@ -193,10 +197,10 @@ def _generate_init_params(self) -> List[_Parameter]:
         return params
 
     def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
-        operands, attrs = self._partition_operands_and_attrs()
+        operands, attrs, results = self._partition_fields()
 
         operand_segments = Operation._generate_segments(operands)
-        result_segments = Operation._generate_segments(self.results)
+        result_segments = Operation._generate_segments(results)
 
         params = self._generate_init_params()
         init_sig = _Signature(params)
@@ -214,7 +218,7 @@ def __init__(*args, **kwargs):
                 args = bound.arguments
 
                 _operands = [args[operand.name] for operand in operands]
-                _results = [args[result.name] for result in op.results]
+                _results = [args[result.name] for result in results]
                 _attributes = dict(
                     (attr.name, args[attr.name])
                     for attr in attrs
@@ -272,7 +276,7 @@ def getter(self, i=i, operand=operand):
                 setattr(
                     _OpView, operand.name, property(lambda self, i=i: self.operands[i])
                 )
-        for i, result in enumerate(self.results):
+        for i, result in enumerate(results):
             if result_segments:
 
                 def getter(self, i=i, result=result):
@@ -332,16 +336,13 @@ def load(self) -> _SimpleNameSpace:
 
     def op(self, name: str) -> Callable[[type], type]:
         def decorator(cls: type) -> type:
-            operands_and_attrs: List[Union[Operand, Attribute]] = []
-            results: List[Result] = []
+            fields: List[Union[Operand, Attribute, Result]] = []
 
             for field in cls.__dict__.values():
-                if isinstance(field, Operand) or isinstance(field, Attribute):
-                    operands_and_attrs.append(field)
-                elif isinstance(field, Result):
-                    results.append(field)
+                if isinstance(field, _FieldDef):
+                    fields.append(field)
 
-            op_def = Operation(self.name, name, operands_and_attrs, results)
+            op_def = Operation(self.name, name, fields)
             op_view, builder = op_def._make_op_view_and_builder()
             setattr(op_def, "op_view", op_view)
             setattr(op_def, "builder", builder)

>From 2405fe4bd522c21ca7e71bd35645d49b938da570 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:10:48 +0800
Subject: [PATCH 05/20] more refactor

---
 mlir/python/mlir/dialects/irdl/dsl.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 0a30678e1318c..d98d82d1ba4c9 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -336,20 +336,20 @@ def load(self) -> _SimpleNameSpace:
 
     def op(self, name: str) -> Callable[[type], type]:
         def decorator(cls: type) -> type:
-            fields: List[Union[Operand, Attribute, Result]] = []
-
-            for field in cls.__dict__.values():
-                if isinstance(field, _FieldDef):
-                    fields.append(field)
-
+            fields = [
+                field for field in cls.__dict__.values() if isinstance(field, _FieldDef)
+            ]
             op_def = Operation(self.name, name, fields)
+
             op_view, builder = op_def._make_op_view_and_builder()
+            op_view.__name__ = cls.__name__
             setattr(op_def, "op_view", op_view)
             setattr(op_def, "builder", builder)
             self.operations.append(op_def)
+
             self.namespace.__dict__[cls.__name__] = op_view
-            op_view.__name__ = cls.__name__
             self.namespace.__dict__[name.replace(".", "_")] = builder
+
             return cls
 
         return decorator

>From d74e6603052698a35b3efc84c303d91d28d9a9df Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:16:18 +0800
Subject: [PATCH 06/20] add a method to convert op name

---
 mlir/python/mlir/dialects/irdl/dsl.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index d98d82d1ba4c9..2092c2417145a 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -334,6 +334,10 @@ def load(self) -> _SimpleNameSpace:
             _ods_cext.register_operation(dialect_class)(op.op_view)
         return self.namespace
 
+    @staticmethod
+    def op_name_to_python_id(name):
+        return name.replace(".", "_").replace("$", "_")
+
     def op(self, name: str) -> Callable[[type], type]:
         def decorator(cls: type) -> type:
             fields = [
@@ -348,7 +352,7 @@ def decorator(cls: type) -> type:
             self.operations.append(op_def)
 
             self.namespace.__dict__[cls.__name__] = op_view
-            self.namespace.__dict__[name.replace(".", "_")] = builder
+            self.namespace.__dict__[Dialect.op_name_to_python_id(name)] = builder
 
             return cls
 

>From bff9fb79903788ad4d2115def90ac2b1771fa81a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:16:58 +0800
Subject: [PATCH 07/20] make it private

---
 mlir/python/mlir/dialects/irdl/dsl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 2092c2417145a..0db7cbbed81ca 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -335,7 +335,7 @@ def load(self) -> _SimpleNameSpace:
         return self.namespace
 
     @staticmethod
-    def op_name_to_python_id(name):
+    def _op_name_to_python_id(name):
         return name.replace(".", "_").replace("$", "_")
 
     def op(self, name: str) -> Callable[[type], type]:
@@ -352,7 +352,7 @@ def decorator(cls: type) -> type:
             self.operations.append(op_def)
 
             self.namespace.__dict__[cls.__name__] = op_view
-            self.namespace.__dict__[Dialect.op_name_to_python_id(name)] = builder
+            self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder
 
             return cls
 

>From 4c4c81cf9f4b9f99d52053b4267ce1db8689d083 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:17:25 +0800
Subject: [PATCH 08/20] add type annotation

---
 mlir/python/mlir/dialects/irdl/dsl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 0db7cbbed81ca..41520f0e0f68c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -335,7 +335,7 @@ def load(self) -> _SimpleNameSpace:
         return self.namespace
 
     @staticmethod
-    def _op_name_to_python_id(name):
+    def _op_name_to_python_id(name: str) -> str:
         return name.replace(".", "_").replace("$", "_")
 
     def op(self, name: str) -> Callable[[type], type]:

>From 145ca87c2dce074b5eb9a76457646e57c51e9cb7 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Dec 2025 23:33:38 +0800
Subject: [PATCH 09/20] Refactor IRDSL via __init_subclass__

---
 mlir/python/mlir/dialects/irdl/dsl.py | 431 ++++++++++++++------------
 mlir/test/python/dialects/irdsl.py    | 136 ++++----
 2 files changed, 291 insertions(+), 276 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 41520f0e0f68c..630740c0103ab 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -2,23 +2,37 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...dialects import irdl as _irdl
-from .._ods_common import (
-    _cext as _ods_cext,
-    segmented_accessor as _ods_segmented_accessor,
-)
-from . import Variadicity
 from typing import Dict, List, Union, Callable, Tuple
 from dataclasses import dataclass
-from inspect import Parameter as _Parameter, Signature as _Signature
-from types import SimpleNamespace as _SimpleNameSpace
+from inspect import Parameter, Signature
+from types import SimpleNamespace
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from ...dialects import irdl
+from .._ods_common import _cext, segmented_accessor
+from . import Variadicity
+
+ir = _cext.ir
 
-_ods_ir = _ods_cext.ir
+__all__ = [
+    "Variadicity",
+    "Is",
+    "AnyOf",
+    "AllOf",
+    "Any",
+    "BaseName",
+    "BaseRef",
+    "Operand",
+    "Result",
+    "Attribute",
+    "Dialect",
+]
 
 
-class ConstraintExpr:
-    def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
-        raise NotImplementedError()
+class ConstraintExpr(ABC):
+    @abstractmethod
+    def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
+        pass
 
     def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
         return AnyOf(self, other)
@@ -30,9 +44,9 @@ def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
 class ConstraintLoweringContext:
     def __init__(self):
         # Cache so that the same ConstraintExpr instance reuses its SSA value.
-        self._cache: Dict[int, _ods_ir.Value] = {}
+        self._cache: Dict[int, ir.Value] = {}
 
-    def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+    def lower(self, expr: ConstraintExpr) -> ir.Value:
         key = id(expr)
         if key in self._cache:
             return self._cache[key]
@@ -42,74 +56,70 @@ def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
 
 
 class Is(ConstraintExpr):
-    def __init__(self, attr: _ods_ir.Attribute):
-        self.attr = attr
-
-    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
-        return _irdl.is_(self.attr)
-
+    def __init__(self, val: Union[ir.Attribute, ir.Type]):
+        self.val = val
 
-class IsType(Is):
-    def __init__(self, typ: _ods_ir.Type):
-        super().__init__(_ods_ir.TypeAttr.get(typ))
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.is_(
+            ir.TypeAttr.get(self.val) if isinstance(self.val, ir.Type) else self.val
+        )
 
 
 class AnyOf(ConstraintExpr):
     def __init__(self, *exprs: ConstraintExpr):
         self.exprs = exprs
 
-    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
-        return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.any_of(ctx.lower(expr) for expr in self.exprs)
 
 
 class AllOf(ConstraintExpr):
     def __init__(self, *exprs: ConstraintExpr):
         self.exprs = exprs
 
-    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
-        return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.all_of(ctx.lower(expr) for expr in self.exprs)
 
 
 class Any(ConstraintExpr):
-    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
-        return _irdl.any()
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.any()
 
 
 class BaseName(ConstraintExpr):
     def __init__(self, name: str):
         self.name = name
 
-    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
-        return _irdl.base(base_name=self.name)
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.base(base_name=self.name)
 
 
 class BaseRef(ConstraintExpr):
     def __init__(self, ref):
         self.ref = ref
 
-    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
-        return _irdl.base(base_ref=self.ref)
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.base(base_ref=self.ref)
 
 
-class _FieldDef:
-    def __set_name__(self, owner, name: str):
-        self.name = name
+class FieldDef:
+    pass
 
 
 @dataclass
-class Operand(_FieldDef):
+class Operand(FieldDef):
     constraint: ConstraintExpr
     variadicity: Variadicity = Variadicity.single
 
 
 @dataclass
-class Result(_FieldDef):
+class Result(FieldDef):
     constraint: ConstraintExpr
     variadicity: Variadicity = Variadicity.single
 
 
 @dataclass
-class Attribute(_FieldDef):
+class Attribute(FieldDef):
     constraint: ConstraintExpr
 
     def __post_init__(self):
@@ -118,43 +128,57 @@ def __post_init__(self):
         self.variadicity = Variadicity.single
 
 
-@dataclass
-class Operation:
-    dialect_name: str
-    name: str
-    # We store operands and attributes into one list to maintain relative orders
-    # among them for generating OpView class.
-    fields: List[Union[Operand, Attribute, Result]]
-
-    def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]:
-        operands = [i for i in self.fields if isinstance(i, Operand)]
-        attrs = [i for i in self.fields if isinstance(i, Attribute)]
-        results = [i for i in self.fields if isinstance(i, Result)]
-        return operands, attrs, results
-
-    def _emit(self) -> None:
-        ctx = ConstraintLoweringContext()
-        operands, attrs, results = self._partition_fields()
+def partition_fields(
+    fields: List[FieldDef],
+) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+    operands = [i for i in fields if isinstance(i, Operand)]
+    attrs = [i for i in fields if isinstance(i, Attribute)]
+    results = [i for i in fields if isinstance(i, Result)]
+    return operands, attrs, results
 
-        op = _irdl.operation_(self.name)
-        with _ods_ir.InsertionPoint(op.body):
-            if operands:
-                _irdl.operands_(
-                    [ctx.lower(i.constraint) for i in operands],
-                    [i.name for i in operands],
-                    [i.variadicity for i in operands],
-                )
-            if attrs:
-                _irdl.attributes_(
-                    [ctx.lower(i.constraint) for i in attrs],
-                    [i.name for i in attrs],
-                )
-            if results:
-                _irdl.results_(
-                    [ctx.lower(i.constraint) for i in results],
-                    [i.name for i in results],
-                    [i.variadicity for i in results],
-                )
+
+def normalize_value_range(
+    value_range: Union[ir.OpOperandList, ir.OpResultList],
+    variadicity: Variadicity,
+):
+    if variadicity == Variadicity.single:
+        return value_range[0]
+    if variadicity == Variadicity.optional:
+        return value_range[0] if len(value_range) > 0 else None
+    return value_range
+
+
+class Operation(ir.OpView):
+    @classmethod
+    def __init_subclass__(cls, *, name=None, **kwargs):
+        super().__init_subclass__(**kwargs)
+
+        # for subclasses without "name" parameter,
+        # just treat them as normal classes
+        if not name:
+            return
+
+        op_name = name
+        cls._op_name = op_name
+        dialect_name = cls._dialect_name
+        dialect_obj = cls._dialect_obj
+
+        fields = []
+        cls._fields = fields
+
+        for key, value in cls.__dict__.items():
+            if isinstance(value, FieldDef):
+                setattr(value, "name", key)
+                fields.append(value)
+
+        cls._generate_class_attributes(dialect_name, op_name, fields)
+        cls._generate_init_method(fields)
+        operands, attrs, results = partition_fields(fields)
+        cls._generate_attr_properties(attrs)
+        cls._generate_operand_properties(operands)
+        cls._generate_result_properties(results)
+
+        dialect_obj.operations.append(cls)
 
     @staticmethod
     def _variadicity_to_segment(variadicity: Variadicity) -> int:
@@ -175,185 +199,184 @@ def _generate_segments(
             ]
         return None
 
-    def _generate_init_params(self) -> List[_Parameter]:
+    @staticmethod
+    def _generate_init_signature(fields: List[FieldDef]) -> Signature:
         # results are placed at the beginning of the parameter list,
         # but operands and attributes can appear in any relative order.
-        args = [i for i in self.fields if isinstance(i, Result)] + [
-            i for i in self.fields if not isinstance(i, Result)
+        args = [i for i in fields if isinstance(i, Result)] + [
+            i for i in fields if not isinstance(i, Result)
         ]
         positional_args = [
             i.name for i in args if i.variadicity != Variadicity.optional
         ]
         optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
 
-        params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+        params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
         for i in positional_args:
-            params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+            params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
         for i in optional_args:
-            params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
-        params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
-        params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
-
-        return params
-
-    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
-        operands, attrs, results = self._partition_fields()
-
-        operand_segments = Operation._generate_segments(operands)
-        result_segments = Operation._generate_segments(results)
-
-        params = self._generate_init_params()
-        init_sig = _Signature(params)
-        op = self
-
-        class _OpView(_ods_ir.OpView):
-            OPERATION_NAME = f"{op.dialect_name}.{op.name}"
-            _ODS_REGIONS = (0, True)
-            _ODS_OPERAND_SEGMENTS = operand_segments
-            _ODS_RESULT_SEGMENTS = result_segments
-
-            def __init__(*args, **kwargs):
-                bound = init_sig.bind(*args, **kwargs)
-                bound.apply_defaults()
-                args = bound.arguments
-
-                _operands = [args[operand.name] for operand in operands]
-                _results = [args[result.name] for result in results]
-                _attributes = dict(
-                    (attr.name, args[attr.name])
-                    for attr in attrs
-                    if args[attr.name] is not None
-                )
-                _regions = None
-                _ods_successors = None
-                self = args["self"]
-                super(_OpView, self).__init__(
-                    self.OPERATION_NAME,
-                    self._ODS_REGIONS,
-                    self._ODS_OPERAND_SEGMENTS,
-                    self._ODS_RESULT_SEGMENTS,
-                    attributes=_attributes,
-                    results=_results,
-                    operands=_operands,
-                    successors=_ods_successors,
-                    regions=_regions,
-                    loc=args["loc"],
-                    ip=args["ip"],
-                )
+            params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+        params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
+        params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
+
+        return Signature(params)
+
+    @classmethod
+    def _generate_init_method(cls, fields):
+        init_sig = cls._generate_init_signature(fields)
+        operands, attrs, results = partition_fields(fields)
+
+        def __init__(*args, **kwargs):
+            bound = init_sig.bind(*args, **kwargs)
+            bound.apply_defaults()
+            args = bound.arguments
+
+            _operands = [args[operand.name] for operand in operands]
+            _results = [args[result.name] for result in results]
+            _attributes = dict(
+                (attr.name, args[attr.name])
+                for attr in attrs
+                if args[attr.name] is not None
+            )
+            _regions = None
+            _ods_successors = None
+            self = args["self"]
+            super(Operation, self).__init__(
+                self.OPERATION_NAME,
+                self._ODS_REGIONS,
+                self._ODS_OPERAND_SEGMENTS,
+                self._ODS_RESULT_SEGMENTS,
+                attributes=_attributes,
+                results=_results,
+                operands=_operands,
+                successors=_ods_successors,
+                regions=_regions,
+                loc=args["loc"],
+                ip=args["ip"],
+            )
+
+        __init__.__signature__ = init_sig
+        cls.__init__ = __init__
+
+    @classmethod
+    def _generate_class_attributes(cls, dialect_name, op_name, fields):
+        operands, attrs, results = partition_fields(fields)
+
+        operand_segments = cls._generate_segments(operands)
+        result_segments = cls._generate_segments(results)
 
-            __init__.__signature__ = init_sig
+        cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
+        cls._ODS_REGIONS = (0, True)
+        cls._ODS_OPERAND_SEGMENTS = operand_segments
+        cls._ODS_RESULT_SEGMENTS = result_segments
 
+    @classmethod
+    def _generate_attr_properties(cls, attrs):
         for attr in attrs:
             setattr(
-                _OpView,
+                cls,
                 attr.name,
                 property(lambda self, name=attr.name: self.attributes[name]),
             )
 
-        def value_range_getter(
-            value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
-            variadicity: Variadicity,
-        ):
-            if variadicity == Variadicity.single:
-                return value_range[0]
-            if variadicity == Variadicity.optional:
-                return value_range[0] if len(value_range) > 0 else None
-            return value_range
-
+    @classmethod
+    def _generate_operand_properties(cls, operands):
         for i, operand in enumerate(operands):
-            if operand_segments:
+            if cls._ODS_OPERAND_SEGMENTS:
 
                 def getter(self, i=i, operand=operand):
-                    operand_range = _ods_segmented_accessor(
+                    operand_range = segmented_accessor(
                         self.operation.operands,
                         self.operation.attributes["operandSegmentSizes"],
                         i,
                     )
-                    return value_range_getter(operand_range, operand.variadicity)
+                    return normalize_value_range(operand_range, operand.variadicity)
 
-                setattr(_OpView, operand.name, property(getter))
+                setattr(cls, operand.name, property(getter))
             else:
-                setattr(
-                    _OpView, operand.name, property(lambda self, i=i: self.operands[i])
-                )
+                setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+    @classmethod
+    def _generate_result_properties(cls, results):
         for i, result in enumerate(results):
-            if result_segments:
+            if cls._ODS_RESULT_SEGMENTS:
 
                 def getter(self, i=i, result=result):
-                    result_range = _ods_segmented_accessor(
+                    result_range = segmented_accessor(
                         self.operation.results,
                         self.operation.attributes["resultSegmentSizes"],
                         i,
                     )
-                    return value_range_getter(result_range, result.variadicity)
+                    return normalize_value_range(result_range, result.variadicity)
 
-                setattr(_OpView, result.name, property(getter))
+                setattr(cls, result.name, property(getter))
             else:
-                setattr(
-                    _OpView, result.name, property(lambda self, i=i: self.results[i])
-                )
+                setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
 
-        def _builder(*args, **kwargs) -> _OpView:
-            return _OpView(*args, **kwargs)
-
-        _builder.__signature__ = _Signature(params[1:])
+    @classmethod
+    def _emit_operation(cls) -> None:
+        ctx = ConstraintLoweringContext()
+        operands, attrs, results = partition_fields(cls._fields)
 
-        return _OpView, _builder
+        op = irdl.operation_(cls._op_name)
+        with ir.InsertionPoint(op.body):
+            if operands:
+                irdl.operands_(
+                    [ctx.lower(i.constraint) for i in operands],
+                    [i.name for i in operands],
+                    [i.variadicity for i in operands],
+                )
+            if attrs:
+                irdl.attributes_(
+                    [ctx.lower(i.constraint) for i in attrs],
+                    [i.name for i in attrs],
+                )
+            if results:
+                irdl.results_(
+                    [ctx.lower(i.constraint) for i in results],
+                    [i.name for i in results],
+                    [i.variadicity for i in results],
+                )
 
 
 class Dialect:
     def __init__(self, name: str):
         self.name = name
-        self.operations: List[Operation] = []
-        self.namespace = _SimpleNameSpace()
-
-    def _emit(self) -> None:
-        d = _irdl.dialect(self.name)
-        with _ods_ir.InsertionPoint(d.body):
+        self.operations = []
+        self.Operation = type(
+            "Operation",
+            (Operation,),
+            {"_dialect_obj": self, "_dialect_name": name},
+        )
+
+    def _emit_dialect(self) -> None:
+        d = irdl.dialect(self.name)
+        with ir.InsertionPoint(d.body):
             for op in self.operations:
-                op._emit()
+                op._emit_operation()
+
+    def _emit_module(self) -> ir.Module:
+        m = ir.Module.create()
+        with ir.InsertionPoint(m.body):
+            self._emit_dialect()
 
-    def _make_module(self) -> _ods_ir.Module:
-        with _ods_ir.Location.unknown():
-            m = _ods_ir.Module.create()
-            with _ods_ir.InsertionPoint(m.body):
-                self._emit()
         return m
 
     def _make_dialect_class(self) -> type:
-        class _Dialect(_ods_ir.Dialect):
-            DIALECT_NAMESPACE = self.name
-
-        return _Dialect
-
-    def load(self) -> _SimpleNameSpace:
-        _irdl.load_dialects(self._make_module())
-        dialect_class = self._make_dialect_class()
-        _ods_cext.register_dialect(dialect_class)
-        for op in self.operations:
-            _ods_cext.register_operation(dialect_class)(op.op_view)
-        return self.namespace
+        return type("Dialect", (ir.Dialect,), {"DIALECT_NAMESPACE": self.name})
 
-    @staticmethod
-    def _op_name_to_python_id(name: str) -> str:
-        return name.replace(".", "_").replace("$", "_")
-
-    def op(self, name: str) -> Callable[[type], type]:
-        def decorator(cls: type) -> type:
-            fields = [
-                field for field in cls.__dict__.values() if isinstance(field, _FieldDef)
-            ]
-            op_def = Operation(self.name, name, fields)
+    def load(self) -> None:
+        if hasattr(self, "mlir_module"):
+            raise RuntimeError(f"Dialect {self.name} is already loaded.")
 
-            op_view, builder = op_def._make_op_view_and_builder()
-            op_view.__name__ = cls.__name__
-            setattr(op_def, "op_view", op_view)
-            setattr(op_def, "builder", builder)
-            self.operations.append(op_def)
+        mlir_module = self._emit_module()
+        irdl.load_dialects(mlir_module)
 
-            self.namespace.__dict__[cls.__name__] = op_view
-            self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder
+        dialect_class = self._make_dialect_class()
+        _cext.register_dialect(dialect_class)
 
-            return cls
+        # for op in self.operations:
+        # _cext.register_operation(dialect_class)(op)
 
-        return decorator
+        self.mlir_module = mlir_module
+        self.dialect_class = dialect_class
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index 8ef30ae0a4c13..f3a80655efe80 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -17,15 +17,13 @@ def run(f):
 def testMyInt():
     myint = irdsl.Dialect("myint")
     iattr = irdsl.BaseName("#builtin.integer")
-    i32 = irdsl.IsType(IntegerType.get_signless(32))
+    i32 = irdsl.Is(IntegerType.get_signless(32))
 
-    @myint.op("constant")
-    class ConstantOp:
+    class ConstantOp(myint.Operation, name="constant"):
         value = irdsl.Attribute(iattr)
         cst = irdsl.Result(i32)
 
-    @myint.op("add")
-    class AddOp:
+    class AddOp(myint.Operation, name="add"):
         lhs = irdsl.Operand(i32)
         rhs = irdsl.Operand(i32)
         res = irdsl.Result(i32)
@@ -43,21 +41,22 @@ class AddOp:
     # CHECK:     irdl.results(res: %0)
     # CHECK:   }
     # CHECK: }
-    print(myint._make_module())
-    myint = myint.load()
+    with Location.unknown():
+        myint.load()
+    print(myint.mlir_module)
 
-    # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
-    print([i for i in myint.__dict__.keys()])
+    # CHECK: ['constant', 'add']
+    print([i._op_name for i in myint.operations])
 
     i32 = IntegerType.get_signless(32)
     with Location.unknown():
         module = Module.create()
         with InsertionPoint(module.body):
-            two = myint.constant(i32, IntegerAttr.get(i32, 2))
-            three = myint.constant(i32, IntegerAttr.get(i32, 3))
-            add1 = myint.add(i32, two, three)
-            add2 = myint.add(i32, add1, two)
-            add3 = myint.add(i32, add2, three)
+            two = ConstantOp(i32, IntegerAttr.get(i32, 2))
+            three = ConstantOp(i32, IntegerAttr.get(i32, 3))
+            add1 = AddOp(i32, two, three)
+            add2 = AddOp(i32, add1, two)
+            add3 = AddOp(i32, add2, three)
 
     # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
     # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
@@ -85,25 +84,24 @@ class AddOp:
     print(two.value)
     # CHECK: Value(%0
     print(two.cst)
-    # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
-    print(myint.add.__signature__)
-    # CHECK: (cst, value, *, loc=None, ip=None)
-    print(myint.constant.__signature__)
+    # CHECK: (self, /, res, lhs, rhs, *, loc=None, ip=None)
+    print(AddOp.__init__.__signature__)
+    # CHECK: (self, /, cst, value, *, loc=None, ip=None)
+    print(ConstantOp.__init__.__signature__)
 
 
 @run
 def testIRDSL():
     test = irdsl.Dialect("irdsl_test")
-    i32 = irdsl.IsType(IntegerType.get_signless(32))
-    i64 = irdsl.IsType(IntegerType.get_signless(64))
+    i32 = irdsl.Is(IntegerType.get_signless(32))
+    i64 = irdsl.Is(IntegerType.get_signless(64))
     i32or64 = i32 | i64
     any = irdsl.Any()
-    f32 = irdsl.IsType(F32Type.get())
+    f32 = irdsl.Is(F32Type.get())
     iattr = irdsl.BaseName("#builtin.integer")
     fattr = irdsl.BaseName("#builtin.float")
 
-    @test.op("constraint")
-    class ConstraintOp:
+    class ConstraintOp(test.Operation, name="constraint"):
         a = irdsl.Operand(i32or64)
         b = irdsl.Operand(any)
         c = irdsl.Operand(f32 | i32)
@@ -111,21 +109,18 @@ class ConstraintOp:
         x = irdsl.Attribute(iattr)
         y = irdsl.Attribute(fattr)
 
-    @test.op("optional")
-    class OptionalOp:
+    class OptionalOp(test.Operation, name="optional"):
         a = irdsl.Operand(i32)
         b = irdsl.Operand(i32, irdsl.Variadicity.optional)
         out1 = irdsl.Result(i32)
         out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
         out3 = irdsl.Result(i32)
 
-    @test.op("optional2")
-    class Optional2Op:
+    class Optional2Op(test.Operation, name="optional2"):
         a = irdsl.Operand(i32, irdsl.Variadicity.optional)
         b = irdsl.Result(i32, irdsl.Variadicity.optional)
 
-    @test.op("variadic")
-    class VariadicOp:
+    class VariadicOp(test.Operation, name="variadic"):
         a = irdsl.Operand(i32)
         b = irdsl.Operand(i32, irdsl.Variadicity.optional)
         c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
@@ -134,13 +129,11 @@ class VariadicOp:
         out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
         out4 = irdsl.Result(i32)
 
-    @test.op("variadic2")
-    class Variadic2Op:
+    class Variadic2Op(test.Operation, name="variadic2"):
         a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
         b = irdsl.Result(i32, irdsl.Variadicity.variadic)
 
-    @test.op("mixed")
-    class MixedOp:
+    class MixedOp(test.Operation, name="mixed"):
         out = irdsl.Result(i32)
         in1 = irdsl.Operand(i32)
         in2 = irdsl.Attribute(iattr)
@@ -189,34 +182,33 @@ class MixedOp:
     # CHECK:     irdl.results(out: %0)
     # CHECK:   }
     # CHECK: }
-    print(test._make_module())
-    test = test.load()
-
-    # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None)
-    print(test.constraint.__signature__)
-    # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
-    print(test.optional.__signature__)
-    # CHECK: (*, b=None, a=None, loc=None, ip=None)
-    print(test.optional2.__signature__)
-    # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
-    print(test.variadic.__signature__)
-    # CHECK: (b, a, *, loc=None, ip=None)
-    print(test.variadic2.__signature__)
-    # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
-    print(test.mixed.__signature__)
+    with Location.unknown():
+        test.load()
+    print(test.mlir_module)
+
+    # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
+    print(ConstraintOp.__init__.__signature__)
+    # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+    print(OptionalOp.__init__.__signature__)
+    # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
+    print(Optional2Op.__init__.__signature__)
+    # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+    print(VariadicOp.__init__.__signature__)
+    # CHECK: (self, /, b, a, *, loc=None, ip=None)
+    print(Variadic2Op.__init__.__signature__)
+    # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+    print(MixedOp.__init__.__signature__)
 
     # CHECK: None None
-    print(
-        test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS
-    )
+    print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
     # CHECK: [1, 0] [1, 0, 1]
-    print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS)
+    print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
     # CHECK: [0] [0]
-    print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS)
+    print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
     # CHECK: [1, 0, -1] [-1, -1, 0, 1]
-    print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS)
+    print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
     # CHECK: [-1] [-1]
-    print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS)
+    print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
 
     i32 = IntegerType.get_signless(32)
     i64 = IntegerType.get_signless(64)
@@ -232,42 +224,42 @@ class MixedOp:
             fone = arith.constant(f32, 1.2)
 
             # CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
-            c1 = test.constraint(ione, ione, fone, ione, iattr, fattr)
+            c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
             # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
-            test.constraint(ione, fone, fone, fone, iattr, fattr)
+            ConstraintOp(ione, fone, fone, fone, iattr, fattr)
             # CHECK: irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
-            test.constraint(ione, fone, ione, fone, iattr, fattr)
+            ConstraintOp(ione, fone, ione, fone, iattr, fattr)
 
             # CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
-            o1 = test.optional(i32, i32, ione)
+            o1 = OptionalOp(i32, i32, ione)
             # CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
-            o2 = test.optional(i32, i32, ione, out2=i32, b=ione)
+            o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
             # CHECK: irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
-            o3 = test.optional2()
+            o3 = Optional2Op()
             # CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
-            o4 = test.optional2(b=i32)
+            o4 = Optional2Op(b=i32)
             # CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
-            o5 = test.optional2(a=ione)
+            o5 = Optional2Op(a=ione)
             # CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
-            o6 = test.optional2(b=i32, a=ione)
+            o6 = Optional2Op(b=i32, a=ione)
 
             # CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
-            v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione])
+            v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
             # CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
-            v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
+            v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
             # CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
-            v3 = test.variadic([i32, i32], [i32], i32, ione, [])
+            v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
             # CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
-            v4 = test.variadic2([], [])
+            v4 = Variadic2Op([], [])
             # CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
-            v5 = test.variadic2([], [ione, ione, ione])
+            v5 = Variadic2Op([], [ione, ione, ione])
             # CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
-            v6 = test.variadic2([i32, i32], [ione])
+            v6 = Variadic2Op([i32, i32], [ione])
 
             # CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
-            m1 = test.mixed(i32, ione, iattr, iattr, ione)
+            m1 = MixedOp(i32, ione, iattr, iattr, ione)
             # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
-            m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione)
+            m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
 
     print(module)
     assert module.operation.verify()

>From 1d7e70788546a2b86fbc719a052a03fca0e0097f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 11 Dec 2025 10:17:59 +0800
Subject: [PATCH 10/20] fix: re-enable operation registration in Dialect.load

---
 mlir/python/mlir/dialects/irdl/dsl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 630740c0103ab..97df1f78ac16c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -375,8 +375,8 @@ def load(self) -> None:
         dialect_class = self._make_dialect_class()
         _cext.register_dialect(dialect_class)
 
-        # for op in self.operations:
-        # _cext.register_operation(dialect_class)(op)
+        for op in self.operations:
+            _cext.register_operation(dialect_class)(op)
 
         self.mlir_module = mlir_module
         self.dialect_class = dialect_class

>From c8364487a1625754b27b75885db59f5087bcba53 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 11 Dec 2025 10:28:06 +0800
Subject: [PATCH 11/20] add more type annotations

---
 mlir/python/mlir/dialects/irdl/dsl.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 97df1f78ac16c..56869e7ae4e3c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -150,7 +150,7 @@ def normalize_value_range(
 
 class Operation(ir.OpView):
     @classmethod
-    def __init_subclass__(cls, *, name=None, **kwargs):
+    def __init_subclass__(cls, *, name: str = None, **kwargs):
         super().__init_subclass__(**kwargs)
 
         # for subclasses without "name" parameter,
@@ -222,7 +222,7 @@ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
         return Signature(params)
 
     @classmethod
-    def _generate_init_method(cls, fields):
+    def _generate_init_method(cls, fields: List[FieldDef]) -> None:
         init_sig = cls._generate_init_signature(fields)
         operands, attrs, results = partition_fields(fields)
 
@@ -259,7 +259,9 @@ def __init__(*args, **kwargs):
         cls.__init__ = __init__
 
     @classmethod
-    def _generate_class_attributes(cls, dialect_name, op_name, fields):
+    def _generate_class_attributes(
+        cls, dialect_name: str, op_name: str, fields: List[FieldDef]
+    ) -> None:
         operands, attrs, results = partition_fields(fields)
 
         operand_segments = cls._generate_segments(operands)
@@ -271,7 +273,7 @@ def _generate_class_attributes(cls, dialect_name, op_name, fields):
         cls._ODS_RESULT_SEGMENTS = result_segments
 
     @classmethod
-    def _generate_attr_properties(cls, attrs):
+    def _generate_attr_properties(cls, attrs: List[Attribute]) -> None:
         for attr in attrs:
             setattr(
                 cls,
@@ -280,7 +282,7 @@ def _generate_attr_properties(cls, attrs):
             )
 
     @classmethod
-    def _generate_operand_properties(cls, operands):
+    def _generate_operand_properties(cls, operands: List[Operand]) -> None:
         for i, operand in enumerate(operands):
             if cls._ODS_OPERAND_SEGMENTS:
 
@@ -297,7 +299,7 @@ def getter(self, i=i, operand=operand):
                 setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
 
     @classmethod
-    def _generate_result_properties(cls, results):
+    def _generate_result_properties(cls, results: List[Result]) -> None:
         for i, result in enumerate(results):
             if cls._ODS_RESULT_SEGMENTS:
 

>From 6e3eef309f4d56e6e8c64351184040db2205c309 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 15:25:15 +0800
Subject: [PATCH 12/20] make sure that context is not required for defining
 dialects

---
 mlir/python/mlir/dialects/irdl/dsl.py |  28 +++-
 mlir/test/python/dialects/irdsl.py    | 220 +++++++++++++-------------
 2 files changed, 133 insertions(+), 115 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 56869e7ae4e3c..35237c848fbeb 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -56,13 +56,33 @@ def lower(self, expr: ConstraintExpr) -> ir.Value:
 
 
 class Is(ConstraintExpr):
-    def __init__(self, val: Union[ir.Attribute, ir.Type]):
+    def __init__(self, val: Callable[..., Union[ir.Attribute, ir.Type]]):
         self.val = val
+        self.args = []
+        self.kwargs = {}
+
+    def __call__(self, *args, **kwargs) -> "Is":
+        self.args.extend(args)
+        self.kwargs.update(kwargs)
+        return self
+
+    def __class_getitem__(
+        cls, val: Callable[..., Union[ir.Attribute, ir.Type]]
+    ) -> "Is":
+        return cls(val)
 
     def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        return irdl.is_(
-            ir.TypeAttr.get(self.val) if isinstance(self.val, ir.Type) else self.val
-        )
+        # for most attributes and types, they are created via `.get` method,
+        # here we can just omit the `.get`
+        if isinstance(self.val, type) and hasattr(self.val, "get"):
+            self.val = self.val.get
+
+        val = self.val(*self.args, **self.kwargs)
+
+        if isinstance(val, ir.Type):
+            val = ir.TypeAttr.get(val)
+
+        return irdl.is_(val)
 
 
 class AnyOf(ConstraintExpr):
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index f3a80655efe80..46511b8e1bfeb 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -8,8 +8,7 @@
 
 def run(f):
     print("\nTEST:", f.__name__, file=sys.stderr)
-    with Context():
-        f()
+    f()
 
 
 # CHECK: TEST: testMyInt
@@ -17,7 +16,7 @@ def run(f):
 def testMyInt():
     myint = irdsl.Dialect("myint")
     iattr = irdsl.BaseName("#builtin.integer")
-    i32 = irdsl.Is(IntegerType.get_signless(32))
+    i32 = irdsl.Is[IntegerType.get_signless](32)
 
     class ConstantOp(myint.Operation, name="constant"):
         value = irdsl.Attribute(iattr)
@@ -41,15 +40,15 @@ class AddOp(myint.Operation, name="add"):
     # CHECK:     irdl.results(res: %0)
     # CHECK:   }
     # CHECK: }
-    with Location.unknown():
+    with Context(), Location.unknown():
         myint.load()
-    print(myint.mlir_module)
+        print(myint.mlir_module)
 
-    # CHECK: ['constant', 'add']
-    print([i._op_name for i in myint.operations])
+        # CHECK: ['constant', 'add']
+        print([i._op_name for i in myint.operations])
+
+        i32 = IntegerType.get_signless(32)
 
-    i32 = IntegerType.get_signless(32)
-    with Location.unknown():
         module = Module.create()
         with InsertionPoint(module.body):
             two = ConstantOp(i32, IntegerAttr.get(i32, 2))
@@ -58,46 +57,46 @@ class AddOp(myint.Operation, name="add"):
             add2 = AddOp(i32, add1, two)
             add3 = AddOp(i32, add2, three)
 
-    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
-    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
-    # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
-    # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
-    # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
-    print(module)
-    assert module.operation.verify()
-
-    # CHECK: AddOp
-    print(type(add1).__name__)
-    # CHECK: ConstantOp
-    print(type(two).__name__)
-    # CHECK: myint.add
-    print(add1.OPERATION_NAME)
-    # CHECK: None
-    print(add1._ODS_OPERAND_SEGMENTS)
-    # CHECK: None
-    print(add1._ODS_RESULT_SEGMENTS)
-    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
-    print(add1.lhs.owner)
-    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
-    print(add1.rhs.owner)
-    # CHECK: 2 : i32
-    print(two.value)
-    # CHECK: Value(%0
-    print(two.cst)
-    # CHECK: (self, /, res, lhs, rhs, *, loc=None, ip=None)
-    print(AddOp.__init__.__signature__)
-    # CHECK: (self, /, cst, value, *, loc=None, ip=None)
-    print(ConstantOp.__init__.__signature__)
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+        print(module)
+        assert module.operation.verify()
+
+        # CHECK: AddOp
+        print(type(add1).__name__)
+        # CHECK: ConstantOp
+        print(type(two).__name__)
+        # CHECK: myint.add
+        print(add1.OPERATION_NAME)
+        # CHECK: None
+        print(add1._ODS_OPERAND_SEGMENTS)
+        # CHECK: None
+        print(add1._ODS_RESULT_SEGMENTS)
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        print(add1.lhs.owner)
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        print(add1.rhs.owner)
+        # CHECK: 2 : i32
+        print(two.value)
+        # CHECK: Value(%0
+        print(two.cst)
+        # CHECK: (self, /, res, lhs, rhs, *, loc=None, ip=None)
+        print(AddOp.__init__.__signature__)
+        # CHECK: (self, /, cst, value, *, loc=None, ip=None)
+        print(ConstantOp.__init__.__signature__)
 
 
 @run
 def testIRDSL():
     test = irdsl.Dialect("irdsl_test")
-    i32 = irdsl.Is(IntegerType.get_signless(32))
-    i64 = irdsl.Is(IntegerType.get_signless(64))
+    i32 = irdsl.Is[IntegerType.get_signless](32)
+    i64 = irdsl.Is[IntegerType.get_signless](64)
     i32or64 = i32 | i64
     any = irdsl.Any()
-    f32 = irdsl.Is(F32Type.get())
+    f32 = irdsl.Is[F32Type]
     iattr = irdsl.BaseName("#builtin.integer")
     fattr = irdsl.BaseName("#builtin.float")
 
@@ -182,39 +181,38 @@ class MixedOp(test.Operation, name="mixed"):
     # CHECK:     irdl.results(out: %0)
     # CHECK:   }
     # CHECK: }
-    with Location.unknown():
+    with Context(), Location.unknown():
         test.load()
-    print(test.mlir_module)
-
-    # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
-    print(ConstraintOp.__init__.__signature__)
-    # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
-    print(OptionalOp.__init__.__signature__)
-    # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
-    print(Optional2Op.__init__.__signature__)
-    # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
-    print(VariadicOp.__init__.__signature__)
-    # CHECK: (self, /, b, a, *, loc=None, ip=None)
-    print(Variadic2Op.__init__.__signature__)
-    # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
-    print(MixedOp.__init__.__signature__)
-
-    # CHECK: None None
-    print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
-    # CHECK: [1, 0] [1, 0, 1]
-    print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
-    # CHECK: [0] [0]
-    print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
-    # CHECK: [1, 0, -1] [-1, -1, 0, 1]
-    print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
-    # CHECK: [-1] [-1]
-    print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
-
-    i32 = IntegerType.get_signless(32)
-    i64 = IntegerType.get_signless(64)
-    f32 = F32Type.get()
-
-    with Location.unknown():
+        print(test.mlir_module)
+
+        # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
+        print(ConstraintOp.__init__.__signature__)
+        # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+        print(OptionalOp.__init__.__signature__)
+        # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
+        print(Optional2Op.__init__.__signature__)
+        # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+        print(VariadicOp.__init__.__signature__)
+        # CHECK: (self, /, b, a, *, loc=None, ip=None)
+        print(Variadic2Op.__init__.__signature__)
+        # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+        print(MixedOp.__init__.__signature__)
+
+        # CHECK: None None
+        print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [1, 0] [1, 0, 1]
+        print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [0] [0]
+        print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
+        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+        print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [-1] [-1]
+        print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
+
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        f32 = F32Type.get()
+
         iattr = IntegerAttr.get(i32, 2)
         fattr = FloatAttr.get_f32(2.3)
 
@@ -261,40 +259,40 @@ class MixedOp(test.Operation, name="mixed"):
             # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
             m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
 
-    print(module)
-    assert module.operation.verify()
-
-    # CHECK: Value(%c1_i32
-    print(c1.a)
-    # CHECK: 2 : i32
-    print(c1.x)
-    # CHECK: Value(%c1_i32
-    print(o1.a)
-    # CHECK: None
-    print(o1.b)
-    # CHECK: Value(%c1_i32
-    print(o2.b)
-    # CHECK: 0
-    print(o1.out1.result_number)
-    # CHECK: None
-    print(o1.out2)
-    # CHECK: 0
-    print(o2.out1.result_number)
-    # CHECK: 1
-    print(o2.out2.result_number)
-    # CHECK: None
-    print(o3.a)
-    # CHECK: Value(%c1_i32
-    print(o5.a)
-    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
-    print([str(i) for i in v1.c])
-    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
-    print([str(i) for i in v2.c])
-    # CHECK: []
-    print([str(i) for i in v3.c])
-    # CHECK: 0 0
-    print(len(v4.a), len(v4.b))
-    # CHECK: 3 0
-    print(len(v5.a), len(v5.b))
-    # CHECK: 1 2
-    print(len(v6.a), len(v6.b))
+        print(module)
+        assert module.operation.verify()
+
+        # CHECK: Value(%c1_i32
+        print(c1.a)
+        # CHECK: 2 : i32
+        print(c1.x)
+        # CHECK: Value(%c1_i32
+        print(o1.a)
+        # CHECK: None
+        print(o1.b)
+        # CHECK: Value(%c1_i32
+        print(o2.b)
+        # CHECK: 0
+        print(o1.out1.result_number)
+        # CHECK: None
+        print(o1.out2)
+        # CHECK: 0
+        print(o2.out1.result_number)
+        # CHECK: 1
+        print(o2.out2.result_number)
+        # CHECK: None
+        print(o3.a)
+        # CHECK: Value(%c1_i32
+        print(o5.a)
+        # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
+        print([str(i) for i in v1.c])
+        # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
+        print([str(i) for i in v2.c])
+        # CHECK: []
+        print([str(i) for i in v3.c])
+        # CHECK: 0 0
+        print(len(v4.a), len(v4.b))
+        # CHECK: 3 0
+        print(len(v5.a), len(v5.b))
+        # CHECK: 1 2
+        print(len(v6.a), len(v6.b))

>From 4ff653d192b8837fd8484eba8be492e4124501b1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 15:27:05 +0800
Subject: [PATCH 13/20] adjust the comment

---
 mlir/python/mlir/dialects/irdl/dsl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 35237c848fbeb..59e667ab1ba8b 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -73,7 +73,7 @@ def __class_getitem__(
 
     def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
         # for most attributes and types, they are created via `.get` method,
-        # here we can just omit the `.get`
+        # here we can just omit the `.get` suffix for convenience
         if isinstance(self.val, type) and hasattr(self.val, "get"):
             self.val = self.val.get
 

>From c9fcf2473844f09a6d11fef54b7651c16ac9b1c9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 16:34:17 +0800
Subject: [PATCH 14/20] make irdsl.Dialect a subclass of ir.Dialect

---
 mlir/python/mlir/dialects/irdl/dsl.py | 48 +++++++++++++--------------
 mlir/test/python/dialects/irdsl.py    | 35 ++++++++++---------
 2 files changed, 43 insertions(+), 40 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 59e667ab1ba8b..c5679ee587e8c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -361,44 +361,44 @@ def _emit_operation(cls) -> None:
                 )
 
 
-class Dialect:
-    def __init__(self, name: str):
-        self.name = name
-        self.operations = []
-        self.Operation = type(
+class Dialect(ir.Dialect):
+    @classmethod
+    def __init_subclass__(cls, name: str, **kwargs):
+        cls.name = name
+        cls.DIALECT_NAMESPACE = name
+        cls.operations = []
+        cls.Operation = type(
             "Operation",
             (Operation,),
-            {"_dialect_obj": self, "_dialect_name": name},
+            {"_dialect_obj": cls, "_dialect_name": name},
         )
 
-    def _emit_dialect(self) -> None:
-        d = irdl.dialect(self.name)
+    @classmethod
+    def _emit_dialect(cls) -> None:
+        d = irdl.dialect(cls.name)
         with ir.InsertionPoint(d.body):
-            for op in self.operations:
+            for op in cls.operations:
                 op._emit_operation()
 
-    def _emit_module(self) -> ir.Module:
+    @classmethod
+    def _emit_module(cls) -> ir.Module:
         m = ir.Module.create()
         with ir.InsertionPoint(m.body):
-            self._emit_dialect()
+            cls._emit_dialect()
 
         return m
 
-    def _make_dialect_class(self) -> type:
-        return type("Dialect", (ir.Dialect,), {"DIALECT_NAMESPACE": self.name})
-
-    def load(self) -> None:
-        if hasattr(self, "mlir_module"):
-            raise RuntimeError(f"Dialect {self.name} is already loaded.")
+    @classmethod
+    def load(cls) -> None:
+        if hasattr(cls, "mlir_module"):
+            raise RuntimeError(f"Dialect {cls.name} is already loaded.")
 
-        mlir_module = self._emit_module()
+        mlir_module = cls._emit_module()
         irdl.load_dialects(mlir_module)
 
-        dialect_class = self._make_dialect_class()
-        _cext.register_dialect(dialect_class)
+        _cext.register_dialect(cls)
 
-        for op in self.operations:
-            _cext.register_operation(dialect_class)(op)
+        for op in cls.operations:
+            _cext.register_operation(cls)(op)
 
-        self.mlir_module = mlir_module
-        self.dialect_class = dialect_class
+        cls.mlir_module = mlir_module
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index 46511b8e1bfeb..de7949a7d376a 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -14,15 +14,17 @@ def run(f):
 # CHECK: TEST: testMyInt
 @run
 def testMyInt():
-    myint = irdsl.Dialect("myint")
+    class MyInt(irdsl.Dialect, name="myint"):
+        pass
+
     iattr = irdsl.BaseName("#builtin.integer")
     i32 = irdsl.Is[IntegerType.get_signless](32)
 
-    class ConstantOp(myint.Operation, name="constant"):
+    class ConstantOp(MyInt.Operation, name="constant"):
         value = irdsl.Attribute(iattr)
         cst = irdsl.Result(i32)
 
-    class AddOp(myint.Operation, name="add"):
+    class AddOp(MyInt.Operation, name="add"):
         lhs = irdsl.Operand(i32)
         rhs = irdsl.Operand(i32)
         res = irdsl.Result(i32)
@@ -41,12 +43,11 @@ class AddOp(myint.Operation, name="add"):
     # CHECK:   }
     # CHECK: }
     with Context(), Location.unknown():
-        myint.load()
-        print(myint.mlir_module)
+        MyInt.load()
+        print(MyInt.mlir_module)
 
         # CHECK: ['constant', 'add']
-        print([i._op_name for i in myint.operations])
-
+        print([i._op_name for i in MyInt.operations])
         i32 = IntegerType.get_signless(32)
 
         module = Module.create()
@@ -91,7 +92,9 @@ class AddOp(myint.Operation, name="add"):
 
 @run
 def testIRDSL():
-    test = irdsl.Dialect("irdsl_test")
+    class Test(irdsl.Dialect, name="irdsl_test"):
+        pass
+
     i32 = irdsl.Is[IntegerType.get_signless](32)
     i64 = irdsl.Is[IntegerType.get_signless](64)
     i32or64 = i32 | i64
@@ -100,7 +103,7 @@ def testIRDSL():
     iattr = irdsl.BaseName("#builtin.integer")
     fattr = irdsl.BaseName("#builtin.float")
 
-    class ConstraintOp(test.Operation, name="constraint"):
+    class ConstraintOp(Test.Operation, name="constraint"):
         a = irdsl.Operand(i32or64)
         b = irdsl.Operand(any)
         c = irdsl.Operand(f32 | i32)
@@ -108,18 +111,18 @@ class ConstraintOp(test.Operation, name="constraint"):
         x = irdsl.Attribute(iattr)
         y = irdsl.Attribute(fattr)
 
-    class OptionalOp(test.Operation, name="optional"):
+    class OptionalOp(Test.Operation, name="optional"):
         a = irdsl.Operand(i32)
         b = irdsl.Operand(i32, irdsl.Variadicity.optional)
         out1 = irdsl.Result(i32)
         out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
         out3 = irdsl.Result(i32)
 
-    class Optional2Op(test.Operation, name="optional2"):
+    class Optional2Op(Test.Operation, name="optional2"):
         a = irdsl.Operand(i32, irdsl.Variadicity.optional)
         b = irdsl.Result(i32, irdsl.Variadicity.optional)
 
-    class VariadicOp(test.Operation, name="variadic"):
+    class VariadicOp(Test.Operation, name="variadic"):
         a = irdsl.Operand(i32)
         b = irdsl.Operand(i32, irdsl.Variadicity.optional)
         c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
@@ -128,11 +131,11 @@ class VariadicOp(test.Operation, name="variadic"):
         out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
         out4 = irdsl.Result(i32)
 
-    class Variadic2Op(test.Operation, name="variadic2"):
+    class Variadic2Op(Test.Operation, name="variadic2"):
         a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
         b = irdsl.Result(i32, irdsl.Variadicity.variadic)
 
-    class MixedOp(test.Operation, name="mixed"):
+    class MixedOp(Test.Operation, name="mixed"):
         out = irdsl.Result(i32)
         in1 = irdsl.Operand(i32)
         in2 = irdsl.Attribute(iattr)
@@ -182,8 +185,8 @@ class MixedOp(test.Operation, name="mixed"):
     # CHECK:   }
     # CHECK: }
     with Context(), Location.unknown():
-        test.load()
-        print(test.mlir_module)
+        Test.load()
+        print(Test.mlir_module)
 
         # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
         print(ConstraintOp.__init__.__signature__)

>From 59d7ae71288af54ae76d53f2976fe4299e638308 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 17:12:53 +0800
Subject: [PATCH 15/20] add example for base class

---
 mlir/python/mlir/dialects/irdl/dsl.py | 9 +++++----
 mlir/test/python/dialects/irdsl.py    | 4 +++-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index c5679ee587e8c..d1c706169ffd8 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -186,10 +186,11 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
         fields = []
         cls._fields = fields
 
-        for key, value in cls.__dict__.items():
-            if isinstance(value, FieldDef):
-                setattr(value, "name", key)
-                fields.append(value)
+        for base in reversed(cls.__mro__):
+            for key, value in base.__dict__.items():
+                if isinstance(value, FieldDef):
+                    setattr(value, "name", key)
+                    fields.append(value)
 
         cls._generate_class_attributes(dialect_name, op_name, fields)
         cls._generate_init_method(fields)
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index de7949a7d376a..e80bb412139b0 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -135,9 +135,11 @@ class Variadic2Op(Test.Operation, name="variadic2"):
         a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
         b = irdsl.Result(i32, irdsl.Variadicity.variadic)
 
-    class MixedOp(Test.Operation, name="mixed"):
+    class MixedOpBase(Test.Operation):
         out = irdsl.Result(i32)
         in1 = irdsl.Operand(i32)
+
+    class MixedOp(MixedOpBase, name="mixed"):
         in2 = irdsl.Attribute(iattr)
         in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
         in4 = irdsl.Attribute(iattr)

>From 2cb35816f9025ecdbe38cd5f8d8abb041938bf03 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Mon, 22 Dec 2025 17:32:32 +0800
Subject: [PATCH 16/20] Update mlir/python/mlir/dialects/irdl/dsl.py

Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
---
 mlir/python/mlir/dialects/irdl/dsl.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index d1c706169ffd8..075926a08872a 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -141,11 +141,7 @@ class Result(FieldDef):
 @dataclass
 class Attribute(FieldDef):
     constraint: ConstraintExpr
-
-    def __post_init__(self):
-        # just for unified processing,
-        # currently optional attribute is not supported by IRDL
-        self.variadicity = Variadicity.single
+    variadicity: typing.ClassVar[Variadicity] = Variadicity.single
 
 
 def partition_fields(

>From 36104c4836a77cc0ca2228b22bbf2d904407534f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 17:34:12 +0800
Subject: [PATCH 17/20] fix import

---
 mlir/python/mlir/dialects/irdl/dsl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 075926a08872a..0b6fdfc268c47 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -2,7 +2,7 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict, List, Union, Callable, Tuple
+from typing import Dict, List, Union, Callable, Tuple, ClassVar
 from dataclasses import dataclass
 from inspect import Parameter, Signature
 from types import SimpleNamespace
@@ -141,7 +141,7 @@ class Result(FieldDef):
 @dataclass
 class Attribute(FieldDef):
     constraint: ConstraintExpr
-    variadicity: typing.ClassVar[Variadicity] = Variadicity.single
+    variadicity: ClassVar[Variadicity] = Variadicity.single
 
 
 def partition_fields(

>From 4fbe94384b689d3f765ecdddc192ded86d2f493f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 8 Jan 2026 21:12:29 +0800
Subject: [PATCH 18/20] fix order of base fields

---
 mlir/python/mlir/dialects/irdl/dsl.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 0b6fdfc268c47..31a2a1bf98b14 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -169,6 +169,17 @@ class Operation(ir.OpView):
     def __init_subclass__(cls, *, name: str = None, **kwargs):
         super().__init_subclass__(**kwargs)
 
+        fields = []
+        cls._fields = fields
+
+        for base in cls.__bases__:
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__dict__.items():
+            if isinstance(value, FieldDef):
+                setattr(value, "name", key)
+                fields.append(value)
+
         # for subclasses without "name" parameter,
         # just treat them as normal classes
         if not name:
@@ -179,15 +190,6 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
         dialect_name = cls._dialect_name
         dialect_obj = cls._dialect_obj
 
-        fields = []
-        cls._fields = fields
-
-        for base in reversed(cls.__mro__):
-            for key, value in base.__dict__.items():
-                if isinstance(value, FieldDef):
-                    setattr(value, "name", key)
-                    fields.append(value)
-
         cls._generate_class_attributes(dialect_name, op_name, fields)
         cls._generate_init_method(fields)
         operands, attrs, results = partition_fields(fields)

>From 84b99bbc0a6f1a07dfac481f4a323dd2c39cbb5b Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 9 Jan 2026 19:50:32 +0800
Subject: [PATCH 19/20] huge refactor

---
 mlir/include/mlir-c/BuiltinAttributes.h       |   2 +
 .../mlir/Bindings/Python/IRAttributes.h       |   1 +
 mlir/include/mlir/Bindings/Python/IRCore.h    |   7 +-
 mlir/lib/Bindings/Python/IRCore.cpp           |   2 +-
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp        |   2 +
 mlir/python/CMakeLists.txt                    |   5 +-
 .../mlir/dialects/{irdl/dsl.py => ext.py}     | 209 ++++++++----------
 .../dialects/{irdl/__init__.py => irdl.py}    |  13 +-
 .../test/python/dialects/{irdsl.py => ext.py} |  88 ++++----
 9 files changed, 147 insertions(+), 182 deletions(-)
 rename mlir/python/mlir/dialects/{irdl/dsl.py => ext.py} (69%)
 rename mlir/python/mlir/dialects/{irdl/__init__.py => irdl.py} (91%)
 rename mlir/test/python/dialects/{irdsl.py => ext.py} (84%)

diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 17c73f44cfc74..eab732365f6b8 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -115,6 +115,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void);
 /// Checks whether the given attribute is a floating point attribute.
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAFloat(MlirAttribute attr);
 
+MLIR_CAPI_EXPORTED MlirStringRef mlirFloatAttrGetName(void);
+
 /// Creates a floating point attribute in the given context with the given
 /// double value and double-precision FP semantics.
 MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx,
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 05d64b0d91b1b..6175710d76dd0 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -324,6 +324,7 @@ class MLIR_PYTHON_API_EXPORTED PyFloatAttribute
   using PyConcreteAttribute::PyConcreteAttribute;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
       mlirFloatAttrGetTypeID;
+  static inline const MlirStringRef name = mlirFloatAttrGetName();
 
   static void bindDerived(ClassTy &c);
 };
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 59dc496c9e206..766eb9290ad58 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -957,7 +957,7 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
   }
 
   static void bind(nanobind::module_ &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
     cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("cast_from_type"));
     cls.def_prop_ro_static(
@@ -1092,9 +1092,10 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
   static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
     ClassTy cls;
     if (slots) {
-      cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
+      cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots),
+                    nanobind::is_generic());
     } else {
-      cls = ClassTy(m, DerivedTy::pyClassName);
+      cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
     }
     cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("cast_from_attr"));
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2..a9f228f33a0ab 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -403,7 +403,7 @@ size_t PyOpOperand::getOperandNumber() const {
 }
 
 void PyOpOperand::bind(nb::module_ &m) {
-  nb::class_<PyOpOperand>(m, "OpOperand")
+  nb::class_<PyOpOperand>(m, "OpOperand", nb::is_generic())
       .def_prop_ro("owner", &PyOpOperand::getOwner,
                    "Returns the operation that owns this operand.")
       .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index eebf82215eab0..f7172c21a0cb9 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -131,6 +131,8 @@ bool mlirAttributeIsAFloat(MlirAttribute attr) {
   return llvm::isa<FloatAttr>(unwrap(attr));
 }
 
+MlirStringRef mlirFloatAttrGetName(void) { return wrap(FloatAttr::name); }
+
 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
                                      double value) {
   return wrap(FloatAttr::get(unwrap(type), value));
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d97066657e070..8ab145ada85dd 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -30,6 +30,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    dialects/ext.py
 )
 
 declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
@@ -513,9 +514,7 @@ declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/IRDLOps.td
-  SOURCES
-    dialects/irdl/__init__.py
-    dialects/irdl/dsl.py
+  SOURCES dialects/irdl.py
   DIALECT_NAME irdl
   GEN_ENUM_BINDINGS
 )
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/ext.py
similarity index 69%
rename from mlir/python/mlir/dialects/irdl/dsl.py
rename to mlir/python/mlir/dialects/ext.py
index 31a2a1bf98b14..18ab977d164de 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -2,124 +2,62 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict, List, Union, Callable, Tuple, ClassVar
+from typing import (
+    Dict,
+    List,
+    Union,
+    Tuple,
+    Any,
+    TypeVar,
+    get_origin,
+    get_args,
+)
+from collections.abc import Sequence
 from dataclasses import dataclass
 from inspect import Parameter, Signature
-from types import SimpleNamespace
-from abc import ABC, abstractmethod
-from contextlib import nullcontext
-from ...dialects import irdl
-from .._ods_common import _cext, segmented_accessor
-from . import Variadicity
+from types import UnionType
+from . import irdl
+from ._ods_common import _cext, segmented_accessor
+from .irdl import Variadicity
 
 ir = _cext.ir
 
 __all__ = [
-    "Variadicity",
-    "Is",
-    "AnyOf",
-    "AllOf",
-    "Any",
-    "BaseName",
-    "BaseRef",
-    "Operand",
-    "Result",
-    "Attribute",
     "Dialect",
 ]
 
 
-class ConstraintExpr(ABC):
-    @abstractmethod
-    def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
-        pass
-
-    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
-        return AnyOf(self, other)
-
-    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
-        return AllOf(self, other)
-
-
 class ConstraintLoweringContext:
     def __init__(self):
         # Cache so that the same ConstraintExpr instance reuses its SSA value.
         self._cache: Dict[int, ir.Value] = {}
 
-    def lower(self, expr: ConstraintExpr) -> ir.Value:
-        key = id(expr)
+    def lower(self, type_) -> ir.Value:
+        key = id(type_)
         if key in self._cache:
             return self._cache[key]
-        v = expr._lower(self)
+        v = self._lower(type_)
         self._cache[key] = v
         return v
 
-
-class Is(ConstraintExpr):
-    def __init__(self, val: Callable[..., Union[ir.Attribute, ir.Type]]):
-        self.val = val
-        self.args = []
-        self.kwargs = {}
-
-    def __call__(self, *args, **kwargs) -> "Is":
-        self.args.extend(args)
-        self.kwargs.update(kwargs)
-        return self
-
-    def __class_getitem__(
-        cls, val: Callable[..., Union[ir.Attribute, ir.Type]]
-    ) -> "Is":
-        return cls(val)
-
-    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        # for most attributes and types, they are created via `.get` method,
-        # here we can just omit the `.get` suffix for convenience
-        if isinstance(self.val, type) and hasattr(self.val, "get"):
-            self.val = self.val.get
-
-        val = self.val(*self.args, **self.kwargs)
-
-        if isinstance(val, ir.Type):
-            val = ir.TypeAttr.get(val)
-
-        return irdl.is_(val)
-
-
-class AnyOf(ConstraintExpr):
-    def __init__(self, *exprs: ConstraintExpr):
-        self.exprs = exprs
-
-    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        return irdl.any_of(ctx.lower(expr) for expr in self.exprs)
-
-
-class AllOf(ConstraintExpr):
-    def __init__(self, *exprs: ConstraintExpr):
-        self.exprs = exprs
-
-    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        return irdl.all_of(ctx.lower(expr) for expr in self.exprs)
-
-
-class Any(ConstraintExpr):
-    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        return irdl.any()
-
-
-class BaseName(ConstraintExpr):
-    def __init__(self, name: str):
-        self.name = name
-
-    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        return irdl.base(base_name=self.name)
-
-
-class BaseRef(ConstraintExpr):
-    def __init__(self, ref):
-        self.ref = ref
-
-    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
-        return irdl.base(base_ref=self.ref)
+    def _lower(self, type_) -> ir.Value:
+        origin = get_origin(type_)
+        if origin and issubclass(origin, ir.Type):
+            t = origin.get(*get_args(type_))
+            return irdl.is_(ir.TypeAttr.get(t))
+        elif origin and issubclass(origin, ir.Attribute):
+            attr = origin.get(*get_args(type_))
+            return irdl.is_(attr)
+        elif origin is UnionType:
+            return irdl.any_of(self.lower(arg) for arg in get_args(type_))
+        elif type_ is Any or isinstance(type_, TypeVar):
+            return irdl.any()
+        elif issubclass(type_, ir.Type):
+            return irdl.base(base_name=f"!{type_.type_name}")
+        elif issubclass(type_, ir.Attribute):
+            return irdl.base(base_name=f"#{type_.attr_name}")
+
+        raise TypeError(f"unsupported type in constraints: {type_}")
 
 
 class FieldDef:
@@ -127,29 +65,33 @@ class FieldDef:
 
 
 @dataclass
-class Operand(FieldDef):
-    constraint: ConstraintExpr
-    variadicity: Variadicity = Variadicity.single
+class OperandDef(FieldDef):
+    constraint: Any
+    variadicity: Variadicity
 
 
 @dataclass
-class Result(FieldDef):
-    constraint: ConstraintExpr
-    variadicity: Variadicity = Variadicity.single
+class ResultDef(FieldDef):
+    constraint: Any
+    variadicity: Variadicity
 
 
 @dataclass
-class Attribute(FieldDef):
-    constraint: ConstraintExpr
-    variadicity: ClassVar[Variadicity] = Variadicity.single
+class AttributeDef(FieldDef):
+    constraint: Any
+    variadicity: Variadicity
+
+    def __post_init__(self):
+        if self.variadicity != Variadicity.single:
+            raise ValueError("optional attribute is not supported in IRDL")
 
 
 def partition_fields(
     fields: List[FieldDef],
-) -> Tuple[List[Operand], List[Attribute], List[Result]]:
-    operands = [i for i in fields if isinstance(i, Operand)]
-    attrs = [i for i in fields if isinstance(i, Attribute)]
-    results = [i for i in fields if isinstance(i, Result)]
+) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
+    operands = [i for i in fields if isinstance(i, OperandDef)]
+    attrs = [i for i in fields if isinstance(i, AttributeDef)]
+    results = [i for i in fields if isinstance(i, ResultDef)]
     return operands, attrs, results
 
 
@@ -165,6 +107,30 @@ def normalize_value_range(
 
 
 class Operation(ir.OpView):
+    @staticmethod
+    def convert_type_to_field_def(type_) -> FieldDef:
+        variadicity = Variadicity.single
+        origin = get_origin(type_)
+        if (
+            origin is Union
+            and len(get_args(type_)) == 2
+            and get_args(type_)[1] is type(None)
+        ):
+            variadicity = Variadicity.optional
+            type_ = get_args(type_)[0]
+        elif origin is Sequence:
+            variadicity = Variadicity.variadic
+            type_ = get_args(type_)[0]
+
+        origin = get_origin(type_)
+        if origin is ir.OpOperand:
+            return OperandDef(get_args(type_)[0], variadicity)
+        elif origin is ir.OpResult:
+            return ResultDef(get_args(type_)[0], variadicity)
+        elif issubclass(origin or type_, ir.Attribute):
+            return AttributeDef(type_, variadicity)
+        raise TypeError(f"unsupported type in operation definition: {type_}")
+
     @classmethod
     def __init_subclass__(cls, *, name: str = None, **kwargs):
         super().__init_subclass__(**kwargs)
@@ -175,10 +141,10 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
         for base in cls.__bases__:
             if hasattr(base, "_fields"):
                 fields.extend(base._fields)
-        for key, value in cls.__dict__.items():
-            if isinstance(value, FieldDef):
-                setattr(value, "name", key)
-                fields.append(value)
+        for key, value in cls.__annotations__.items():
+            field = Operation.convert_type_to_field_def(value)
+            setattr(field, "name", key)
+            fields.append(field)
 
         # for subclasses without "name" parameter,
         # just treat them as normal classes
@@ -209,7 +175,7 @@ def _variadicity_to_segment(variadicity: Variadicity) -> int:
 
     @staticmethod
     def _generate_segments(
-        operands_or_results: List[Union[Operand, Result]],
+        operands_or_results: List[Union[OperandDef, ResultDef]],
     ) -> List[int]:
         if any(i.variadicity != Variadicity.single for i in operands_or_results):
             return [
@@ -222,8 +188,8 @@ def _generate_segments(
     def _generate_init_signature(fields: List[FieldDef]) -> Signature:
         # results are placed at the beginning of the parameter list,
         # but operands and attributes can appear in any relative order.
-        args = [i for i in fields if isinstance(i, Result)] + [
-            i for i in fields if not isinstance(i, Result)
+        args = [i for i in fields if isinstance(i, ResultDef)] + [
+            i for i in fields if not isinstance(i, ResultDef)
         ]
         positional_args = [
             i.name for i in args if i.variadicity != Variadicity.optional
@@ -292,7 +258,7 @@ def _generate_class_attributes(
         cls._ODS_RESULT_SEGMENTS = result_segments
 
     @classmethod
-    def _generate_attr_properties(cls, attrs: List[Attribute]) -> None:
+    def _generate_attr_properties(cls, attrs: List[AttributeDef]) -> None:
         for attr in attrs:
             setattr(
                 cls,
@@ -301,7 +267,7 @@ def _generate_attr_properties(cls, attrs: List[Attribute]) -> None:
             )
 
     @classmethod
-    def _generate_operand_properties(cls, operands: List[Operand]) -> None:
+    def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
         for i, operand in enumerate(operands):
             if cls._ODS_OPERAND_SEGMENTS:
 
@@ -318,7 +284,7 @@ def getter(self, i=i, operand=operand):
                 setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
 
     @classmethod
-    def _generate_result_properties(cls, results: List[Result]) -> None:
+    def _generate_result_properties(cls, results: List[ResultDef]) -> None:
         for i, result in enumerate(results):
             if cls._ODS_RESULT_SEGMENTS:
 
@@ -393,6 +359,7 @@ def load(cls) -> None:
             raise RuntimeError(f"Dialect {cls.name} is already loaded.")
 
         mlir_module = cls._emit_module()
+        print(mlir_module)
         irdl.load_dialects(mlir_module)
 
         _cext.register_dialect(cls)
diff --git a/mlir/python/mlir/dialects/irdl/__init__.py b/mlir/python/mlir/dialects/irdl.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl/__init__.py
rename to mlir/python/mlir/dialects/irdl.py
index 6b2787ed7966c..1ec951b69b646 100644
--- a/mlir/python/mlir/dialects/irdl/__init__.py
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -2,14 +2,13 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .._irdl_ops_gen import *
-from .._irdl_ops_gen import _Dialect
-from .._irdl_enum_gen import *
-from ..._mlir_libs._mlirDialectsIRDL import *
-from ...ir import register_attribute_builder
-from .._ods_common import _cext as _ods_cext
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import _cext as _ods_cext
 from typing import Union, Sequence
-from . import dsl
 
 _ods_ir = _ods_cext.ir
 
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/ext.py
similarity index 84%
rename from mlir/test/python/dialects/irdsl.py
rename to mlir/test/python/dialects/ext.py
index e80bb412139b0..f903bb8be978c 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/ext.py
@@ -1,33 +1,32 @@
 # RUN: %PYTHON %s 2>&1 | FileCheck %s
 
 from mlir.ir import *
-from mlir.dialects.irdl import dsl as irdsl
-from mlir.dialects import arith
+from mlir.dialects import ext, arith
+from typing import Any, Optional, Sequence
 import sys
 
 
 def run(f):
-    print("\nTEST:", f.__name__, file=sys.stderr)
+    print("\nTEST:", f.__name__)
     f()
 
 
 # CHECK: TEST: testMyInt
 @run
 def testMyInt():
-    class MyInt(irdsl.Dialect, name="myint"):
+    class MyInt(ext.Dialect, name="myint"):
         pass
 
-    iattr = irdsl.BaseName("#builtin.integer")
-    i32 = irdsl.Is[IntegerType.get_signless](32)
+    i32 = IntegerType[32]
 
     class ConstantOp(MyInt.Operation, name="constant"):
-        value = irdsl.Attribute(iattr)
-        cst = irdsl.Result(i32)
+        value: IntegerAttr
+        cst: OpResult[i32]
 
     class AddOp(MyInt.Operation, name="add"):
-        lhs = irdsl.Operand(i32)
-        rhs = irdsl.Operand(i32)
-        res = irdsl.Result(i32)
+        lhs: OpOperand[i32]
+        rhs: OpOperand[i32]
+        res: OpResult[i32]
 
     # CHECK: irdl.dialect @myint {
     # CHECK:   irdl.operation @constant {
@@ -90,60 +89,55 @@ class AddOp(MyInt.Operation, name="add"):
         print(ConstantOp.__init__.__signature__)
 
 
+# CHECK: TEST: testIRDSL
 @run
 def testIRDSL():
-    class Test(irdsl.Dialect, name="irdsl_test"):
+    class Test(ext.Dialect, name="irdsl_test"):
         pass
 
-    i32 = irdsl.Is[IntegerType.get_signless](32)
-    i64 = irdsl.Is[IntegerType.get_signless](64)
-    i32or64 = i32 | i64
-    any = irdsl.Any()
-    f32 = irdsl.Is[F32Type]
-    iattr = irdsl.BaseName("#builtin.integer")
-    fattr = irdsl.BaseName("#builtin.float")
+    i32 = IntegerType[32]
 
     class ConstraintOp(Test.Operation, name="constraint"):
-        a = irdsl.Operand(i32or64)
-        b = irdsl.Operand(any)
-        c = irdsl.Operand(f32 | i32)
-        d = irdsl.Operand(any)
-        x = irdsl.Attribute(iattr)
-        y = irdsl.Attribute(fattr)
+        a: OpOperand[i32 | IntegerType[64]]
+        b: OpOperand[Any]
+        c: OpOperand[F32Type[()] | i32]
+        d: OpOperand[Any]
+        x: IntegerAttr
+        y: FloatAttr
 
     class OptionalOp(Test.Operation, name="optional"):
-        a = irdsl.Operand(i32)
-        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
-        out1 = irdsl.Result(i32)
-        out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
-        out3 = irdsl.Result(i32)
+        a: OpOperand[i32]
+        b: Optional[OpOperand[i32]]
+        out1: OpResult[i32]
+        out2: Optional[OpResult[i32]]
+        out3: OpResult[i32]
 
     class Optional2Op(Test.Operation, name="optional2"):
-        a = irdsl.Operand(i32, irdsl.Variadicity.optional)
-        b = irdsl.Result(i32, irdsl.Variadicity.optional)
+        a: Optional[OpOperand[i32]]
+        b: Optional[OpResult[i32]]
 
     class VariadicOp(Test.Operation, name="variadic"):
-        a = irdsl.Operand(i32)
-        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
-        c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
-        out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
-        out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
-        out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
-        out4 = irdsl.Result(i32)
+        a: OpOperand[i32]
+        b: Optional[OpOperand[i32]]
+        c: Sequence[OpOperand[i32]]
+        out1: Sequence[OpResult[i32]]
+        out2: Sequence[OpResult[i32]]
+        out3: Optional[OpResult[i32]]
+        out4: OpResult[i32]
 
     class Variadic2Op(Test.Operation, name="variadic2"):
-        a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
-        b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        a: Sequence[OpOperand[i32]]
+        b: Sequence[OpResult[i32]]
 
     class MixedOpBase(Test.Operation):
-        out = irdsl.Result(i32)
-        in1 = irdsl.Operand(i32)
+        out: OpResult[i32]
+        in1: OpOperand[i32]
 
     class MixedOp(MixedOpBase, name="mixed"):
-        in2 = irdsl.Attribute(iattr)
-        in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
-        in4 = irdsl.Attribute(iattr)
-        in5 = irdsl.Operand(i32)
+        in2: IntegerAttr
+        in3: Optional[OpOperand[i32]]
+        in4: IntegerAttr
+        in5: OpOperand[i32]
 
     # CHECK: irdl.dialect @irdsl_test {
     # CHECK:   irdl.operation @constraint {

>From 7e7b5187904a0733877e729a5bcd31961d9762b8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 9 Jan 2026 21:41:01 +0800
Subject: [PATCH 20/20] add match_optional

---
 mlir/python/mlir/dialects/ext.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 18ab977d164de..14b25119946d2 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -48,7 +48,7 @@ def _lower(self, type_) -> ir.Value:
         elif origin and issubclass(origin, ir.Attribute):
             attr = origin.get(*get_args(type_))
             return irdl.is_(attr)
-        elif origin is UnionType:
+        elif origin is UnionType or origin is Union:
             return irdl.any_of(self.lower(arg) for arg in get_args(type_))
         elif type_ is Any or isinstance(type_, TypeVar):
             return irdl.any()
@@ -106,19 +106,27 @@ def normalize_value_range(
     return value_range
 
 
+def match_optional(type_):
+    origin = get_origin(type_)
+    args = get_args(type_)
+    if (
+        (origin is Union or origin is UnionType)
+        and len(args) == 2
+        and type(None) in args
+    ):
+        return args[0] if args[1] is type(None) else args[1]
+
+    return None
+
+
 class Operation(ir.OpView):
     @staticmethod
     def convert_type_to_field_def(type_) -> FieldDef:
         variadicity = Variadicity.single
-        origin = get_origin(type_)
-        if (
-            origin is Union
-            and len(get_args(type_)) == 2
-            and get_args(type_)[1] is type(None)
-        ):
+        if inner := match_optional(type_):
             variadicity = Variadicity.optional
-            type_ = get_args(type_)[0]
-        elif origin is Sequence:
+            type_ = inner
+        elif get_origin(type_) is Sequence:
             variadicity = Variadicity.variadic
             type_ = get_args(type_)[0]
 



More information about the Mlir-commits mailing list