[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings (PR #169045)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 23 20:55:33 PST 2025


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/169045

>From 3ec4809f2b5a76d14eba1a0707e30791cc5cc805 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 19 Nov 2025 01:04:38 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add a DSL for defining IRDL dialects in
 Python bindings

---
 mlir/python/CMakeLists.txt                    |   4 +-
 .../dialects/{irdl.py => irdl/__init__.py}    |  13 +-
 mlir/python/mlir/dialects/irdl/dsl.py         | 343 ++++++++++++++++++
 mlir/test/python/dialects/irdsl.py            | 308 ++++++++++++++++
 4 files changed, 661 insertions(+), 7 deletions(-)
 rename mlir/python/mlir/dialects/{irdl.py => irdl/__init__.py} (91%)
 create mode 100644 mlir/python/mlir/dialects/irdl/dsl.py
 create mode 100644 mlir/test/python/dialects/irdsl.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..27b87a8b70144 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/IRDLOps.td
-  SOURCES dialects/irdl.py
+  SOURCES
+    dialects/irdl/__init__.py
+    dialects/irdl/dsl.py
   DIALECT_NAME irdl
   GEN_ENUM_BINDINGS
 )
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl.py
rename to mlir/python/mlir/dialects/irdl/__init__.py
index 1ec951b69b646..6b2787ed7966c 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl/__init__.py
@@ -2,13 +2,14 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ._irdl_ops_gen import *
-from ._irdl_ops_gen import _Dialect
-from ._irdl_enum_gen import *
-from .._mlir_libs._mlirDialectsIRDL import *
-from ..ir import register_attribute_builder
-from ._ods_common import _cext as _ods_cext
+from .._irdl_ops_gen import *
+from .._irdl_ops_gen import _Dialect
+from .._irdl_enum_gen import *
+from ..._mlir_libs._mlirDialectsIRDL import *
+from ...ir import register_attribute_builder
+from .._ods_common import _cext as _ods_cext
 from typing import Union, Sequence
+from . import dsl
 
 _ods_ir = _ods_cext.ir
 
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
new file mode 100644
index 0000000000000..3cc234503665a
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -0,0 +1,343 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...dialects import irdl as _irdl
+from .._ods_common import (
+    _cext as _ods_cext,
+    segmented_accessor as _ods_segmented_accessor,
+)
+from . import Variadicity
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter as _Parameter, Signature as _Signature
+from types import SimpleNamespace as _SimpleNameSpace
+
+_ods_ir = _ods_cext.ir
+
+
+class ConstraintExpr:  # Base class of the composable constraint-expression tree.
+    def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
+        raise NotImplementedError()  # every concrete constraint overrides this
+
+    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AnyOf(self, other)  # `a | b`
+
+    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AllOf(self, other)  # `a & b`
+
+
+class ConstraintLoweringContext:  # Lowers constraints for one irdl.operation.
+    def __init__(self):
+        # Cache so that the same ConstraintExpr instance reuses its SSA value.
+        self._cache: Dict[int, _ods_ir.Value] = {}
+
+    def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+        key = id(expr)  # identity: only the very same instance is reused
+        if key in self._cache:
+            return self._cache[key]
+        v = expr._lower(self)  # emits op(s) at the current insertion point
+        self._cache[key] = v
+        return v
+
+
+class Is(ConstraintExpr):  # Matches exactly `attr` (lowers to irdl.is).
+    def __init__(self, attr: _ods_ir.Attribute):
+        self.attr = attr
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.is_(self.attr)
+
+
+class IsType(Is):  # Matches exactly `typ` by wrapping it in a TypeAttr.
+    def __init__(self, typ: _ods_ir.Type):
+        super().__init__(_ods_ir.TypeAttr.get(typ))
+
+
+class AnyOf(ConstraintExpr):  # Holds if any sub-constraint holds (irdl.any_of).
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class AllOf(ConstraintExpr):  # Holds if all sub-constraints hold (irdl.all_of).
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class Any(ConstraintExpr):  # Accepts anything (irdl.any).
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.any()
+
+
+class BaseName(ConstraintExpr):  # irdl.base by name, e.g. "#builtin.integer".
+    def __init__(self, name: str):
+        self.name = name
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.base(base_name=self.name)
+
+
+class BaseRef(ConstraintExpr):  # Like BaseName, but via irdl.base's base_ref.
+    def __init__(self, ref):  # NOTE(review): presumably a symbol-ref attr - confirm
+        self.ref = ref
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        return _irdl.base(base_ref=self.ref)
+
+
+class FieldDef:
+    def __set_name__(self, owner, name: str):
+        self.name = name
+
+
+@dataclass
+class Operand(FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Result(FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Attribute(FieldDef):
+    constraint: ConstraintExpr
+
+    def __post_init__(self):
+        # Optional attributes are currently not supported by IRDL, so every
+        # attribute is `single`; storing it lets all fields be handled uniformly.
+        self.variadicity = Variadicity.single
+
+
+@dataclass
+class Operation:  # One op: emits its IRDL form and synthesizes its Python API.
+    dialect_name: str
+    name: str
+    # We store operands and attributes into one list to maintain relative orders
+    # among them for generating OpView class.
+    operands_and_attrs: List[Union[Operand, Attribute]]
+    results: List[Result]
+
+    def _emit(self) -> None:  # emits irdl.operation at the current insertion point
+        op = _irdl.operation_(self.name)
+        ctx = ConstraintLoweringContext()
+
+        operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
+        attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
+        # Operands, attributes and results are emitted as separate IRDL ops.
+        with _ods_ir.InsertionPoint(op.body):
+            if operands:
+                _irdl.operands_(
+                    [ctx.lower(i.constraint) for i in operands],
+                    [i.name for i in operands],
+                    [i.variadicity for i in operands],
+                )
+            if attrs:
+                _irdl.attributes_(
+                    [ctx.lower(i.constraint) for i in attrs],
+                    [i.name for i in attrs],
+                )
+            if self.results:
+                _irdl.results_(
+                    [ctx.lower(i.constraint) for i in self.results],
+                    [i.name for i in self.results],
+                    [i.variadicity for i in self.results],
+                )
+
+    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+        operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
+        attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
+
+        def variadicity_to_segment(variadicity: Variadicity) -> int:
+            if variadicity == Variadicity.variadic:
+                return -1
+            if variadicity == Variadicity.optional:
+                return 0
+            return 1
+
+        operand_segments = None  # stays None when every operand is `single`
+        if any(i.variadicity != Variadicity.single for i in operands):
+            operand_segments = [variadicity_to_segment(i.variadicity) for i in operands]
+        # Likewise for results: None unless some result is optional/variadic.
+        result_segments = None
+        if any(i.variadicity != Variadicity.single for i in self.results):
+            result_segments = [
+                variadicity_to_segment(i.variadicity) for i in self.results
+            ]
+        # Parameter order: results, then operands/attributes as declared.
+        args = self.results + self.operands_and_attrs
+        positional_args = [
+            i.name for i in args if i.variadicity != Variadicity.optional
+        ]
+        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+        # Optional fields are keyword-only (default None), as are loc and ip.
+        params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+        for i in positional_args:
+            params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+        for i in optional_args:
+            params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
+        # `sig` is enforced by _OpView.__init__; the builder forwards to it.
+        sig = _Signature(params)
+        op = self  # read by the _OpView class body below
+
+        class _OpView(_ods_ir.OpView):
+            OPERATION_NAME = f"{op.dialect_name}.{op.name}"
+            _ODS_REGIONS = (0, True)
+            _ODS_OPERAND_SEGMENTS = operand_segments
+            _ODS_RESULT_SEGMENTS = result_segments
+
+            def __init__(*args, **kwargs):
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                args = bound.arguments
+                # Pick each field's value out of the bound arguments by name.
+                _operands = [args[operand.name] for operand in operands]
+                _results = [args[result.name] for result in op.results]
+                _attributes = dict(
+                    (attr.name, args[attr.name])
+                    for attr in attrs
+                    if args[attr.name] is not None
+                )
+                _regions = None
+                _ods_successors = None
+                self = args["self"]
+                super(_OpView, self).__init__(
+                    self.OPERATION_NAME,
+                    self._ODS_REGIONS,
+                    self._ODS_OPERAND_SEGMENTS,
+                    self._ODS_RESULT_SEGMENTS,
+                    attributes=_attributes,
+                    results=_results,
+                    operands=_operands,
+                    successors=_ods_successors,
+                    regions=_regions,
+                    loc=args["loc"],
+                    ip=args["ip"],
+                )
+
+            __init__.__signature__ = sig  # exposes the synthesized signature
+
+        for attr in attrs:
+            setattr(
+                _OpView,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+        def value_range_getter(
+            value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
+            variadicity: Variadicity,
+        ):
+            if variadicity == Variadicity.single:
+                return value_range[0]
+            if variadicity == Variadicity.optional:
+                return value_range[0] if len(value_range) > 0 else None
+            return value_range
+
+        for i, operand in enumerate(operands):
+            if operand_segments:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = _ods_segmented_accessor(
+                        self.operation.operands,
+                        self.operation.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(operand_range, operand.variadicity)
+
+                setattr(_OpView, operand.name, property(getter))
+            else:
+                setattr(
+                    _OpView, operand.name, property(lambda self, i=i: self.operands[i])
+                )
+        for i, result in enumerate(self.results):
+            if result_segments:
+
+                def getter(self, i=i, result=result):
+                    result_range = _ods_segmented_accessor(
+                        self.operation.results,
+                        self.operation.attributes["resultSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(result_range, result.variadicity)
+
+                setattr(_OpView, result.name, property(getter))
+            else:
+                setattr(
+                    _OpView, result.name, property(lambda self, i=i: self.results[i])
+                )
+
+        def _builder(*args, **kwargs) -> _OpView:
+            return _OpView(*args, **kwargs)
+
+        _builder.__signature__ = _Signature(params[1:])
+        # Dialect.op publishes both the OpView class and the builder function.
+        return _OpView, _builder
+
+
+class Dialect:  # An IRDL dialect defined in Python; see op() and load().
+    def __init__(self, name: str):
+        self.name = name
+        self.operations: List[Operation] = []
+        self.namespace = _SimpleNameSpace()  # filled by op(), returned by load()
+
+    def _emit(self) -> None:  # emits irdl.dialect holding every declared op
+        d = _irdl.dialect(self.name)
+        with _ods_ir.InsertionPoint(d.body):
+            for op in self.operations:
+                op._emit()
+
+    def _make_module(self) -> _ods_ir.Module:  # standalone module with the IRDL
+        with _ods_ir.Location.unknown():
+            m = _ods_ir.Module.create()
+            with _ods_ir.InsertionPoint(m.body):
+                self._emit()
+        return m
+
+    def _make_dialect_class(self) -> type:  # Python-side dialect class
+        class _Dialect(_ods_ir.Dialect):
+            DIALECT_NAMESPACE = self.name
+
+        return _Dialect
+
+    def load(self) -> _SimpleNameSpace:  # load IRDL, then register Python classes
+        _irdl.load_dialects(self._make_module())
+        dialect_class = self._make_dialect_class()
+        _ods_cext.register_dialect(dialect_class)
+        for op in self.operations:
+            _ods_cext.register_operation(dialect_class)(op.op_view)
+        return self.namespace
+
+    def op(self, name: str) -> Callable[[type], type]:  # class decorator
+        def decorator(cls: type) -> type:
+            operands_and_attrs: List[Union[Operand, Attribute]] = []
+            results: List[Result] = []
+            # Walk fields in class-body (declaration) order.
+            for field in cls.__dict__.values():
+                if isinstance(field, Operand) or isinstance(field, Attribute):
+                    operands_and_attrs.append(field)
+                elif isinstance(field, Result):
+                    results.append(field)
+            # Build the op definition, then expose its OpView and builder.
+            op_def = Operation(self.name, name, operands_and_attrs, results)
+            op_view, builder = op_def._make_op_view_and_builder()
+            setattr(op_def, "op_view", op_view)
+            setattr(op_def, "builder", builder)
+            self.operations.append(op_def)
+            self.namespace.__dict__[cls.__name__] = op_view
+            op_view.__name__ = cls.__name__  # report the user's class name
+            self.namespace.__dict__[name.replace(".", "_")] = builder
+            return cls  # the declaration class itself is returned unchanged
+
+        return decorator
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
new file mode 100644
index 0000000000000..8ef30ae0a4c13
--- /dev/null
+++ b/mlir/test/python/dialects/irdsl.py
@@ -0,0 +1,308 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import dsl as irdsl
+from mlir.dialects import arith
+import sys
+
+
+def run(f):  # runs one test function inside a fresh Context
+    print("\nTEST:", f.__name__)  # stdout, so markers stay ordered with the IR
+    with Context():
+        f()
+
+
+# CHECK: TEST: testMyInt
+@run
+def testMyInt():  # define, load, build and inspect a two-op dialect
+    myint = irdsl.Dialect("myint")
+    iattr = irdsl.BaseName("#builtin.integer")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+
+    @myint.op("constant")
+    class ConstantOp:
+        value = irdsl.Attribute(iattr)
+        cst = irdsl.Result(i32)
+
+    @myint.op("add")
+    class AddOp:
+        lhs = irdsl.Operand(i32)
+        rhs = irdsl.Operand(i32)
+        res = irdsl.Result(i32)
+
+    # CHECK: irdl.dialect @myint {
+    # CHECK:   irdl.operation @constant {
+    # CHECK:     %0 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"value" = %0}
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     irdl.results(cst: %1)
+    # CHECK:   }
+    # CHECK:   irdl.operation @add {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(lhs: %0, rhs: %0)
+    # CHECK:     irdl.results(res: %0)
+    # CHECK:   }
+    # CHECK: }
+    print(myint._make_module())
+    myint = myint.load()
+
+    # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
+    print([i for i in myint.__dict__.keys()])
+
+    i32 = IntegerType.get_signless(32)
+    with Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            two = myint.constant(i32, IntegerAttr.get(i32, 2))
+            three = myint.constant(i32, IntegerAttr.get(i32, 3))
+            add1 = myint.add(i32, two, three)
+            add2 = myint.add(i32, add1, two)
+            add3 = myint.add(i32, add2, three)
+
+    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+    # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+    # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+    # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+    print(module)
+    assert module.operation.verify()
+
+    # CHECK: AddOp
+    print(type(add1).__name__)
+    # CHECK: ConstantOp
+    print(type(two).__name__)
+    # CHECK: myint.add
+    print(add1.OPERATION_NAME)
+    # CHECK: None
+    print(add1._ODS_OPERAND_SEGMENTS)
+    # CHECK: None
+    print(add1._ODS_RESULT_SEGMENTS)
+    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+    print(add1.lhs.owner)
+    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+    print(add1.rhs.owner)
+    # CHECK: 2 : i32
+    print(two.value)
+    # CHECK: Value(%0
+    print(two.cst)
+    # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
+    print(myint.add.__signature__)
+    # CHECK: (cst, value, *, loc=None, ip=None)
+    print(myint.constant.__signature__)
+
+
+@run
+def testIRDSL():  # constraints, optional/variadic segments, mixed field order
+    test = irdsl.Dialect("irdsl_test")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+    i64 = irdsl.IsType(IntegerType.get_signless(64))
+    i32or64 = i32 | i64
+    anything = irdsl.Any()
+    f32 = irdsl.IsType(F32Type.get())
+    iattr = irdsl.BaseName("#builtin.integer")
+    fattr = irdsl.BaseName("#builtin.float")
+
+    @test.op("constraint")
+    class ConstraintOp:
+        a = irdsl.Operand(i32or64)
+        b = irdsl.Operand(anything)
+        c = irdsl.Operand(f32 | i32)
+        d = irdsl.Operand(anything)
+        x = irdsl.Attribute(iattr)
+        y = irdsl.Attribute(fattr)
+
+    @test.op("optional")
+    class OptionalOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        out1 = irdsl.Result(i32)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out3 = irdsl.Result(i32)
+
+    @test.op("optional2")
+    class Optional2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        b = irdsl.Result(i32, irdsl.Variadicity.optional)
+
+    @test.op("variadic")
+    class VariadicOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out4 = irdsl.Result(i32)
+
+    @test.op("variadic2")
+    class Variadic2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+
+    @test.op("mixed")
+    class MixedOp:
+        out = irdsl.Result(i32)
+        in1 = irdsl.Operand(i32)
+        in2 = irdsl.Attribute(iattr)
+        in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        in4 = irdsl.Attribute(iattr)
+        in5 = irdsl.Operand(i32)
+
+    # CHECK: irdl.dialect @irdsl_test {
+    # CHECK:   irdl.operation @constraint {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     %1 = irdl.is i64
+    # CHECK:     %2 = irdl.any_of(%0, %1)
+    # CHECK:     %3 = irdl.any
+    # CHECK:     %4 = irdl.is f32
+    # CHECK:     %5 = irdl.any_of(%4, %0)
+    # CHECK:     irdl.operands(a: %2, b: %3, c: %5, d: %3)
+    # CHECK:     %6 = irdl.base "#builtin.integer"
+    # CHECK:     %7 = irdl.base "#builtin.float"
+    # CHECK:     irdl.attributes {"x" = %6, "y" = %7}
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0)
+    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: optional %0)
+    # CHECK:     irdl.results(b: optional %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
+    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: variadic %0)
+    # CHECK:     irdl.results(b: variadic %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @mixed {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK:     %1 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"in2" = %1, "in4" = %1}
+    # CHECK:     irdl.results(out: %0)
+    # CHECK:   }
+    # CHECK: }
+    print(test._make_module())
+    test = test.load()
+
+    # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None)
+    print(test.constraint.__signature__)
+    # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+    print(test.optional.__signature__)
+    # CHECK: (*, b=None, a=None, loc=None, ip=None)
+    print(test.optional2.__signature__)
+    # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+    print(test.variadic.__signature__)
+    # CHECK: (b, a, *, loc=None, ip=None)
+    print(test.variadic2.__signature__)
+    # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+    print(test.mixed.__signature__)
+
+    # CHECK: None None
+    print(
+        test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS
+    )
+    # CHECK: [1, 0] [1, 0, 1]
+    print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS)
+    # CHECK: [0] [0]
+    print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS)
+    # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+    print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS)
+    # CHECK: [-1] [-1]
+    print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS)
+
+    i32 = IntegerType.get_signless(32)
+    i64 = IntegerType.get_signless(64)
+    f32 = F32Type.get()
+
+    with Location.unknown():
+        iattr = IntegerAttr.get(i32, 2)
+        fattr = FloatAttr.get_f32(2.3)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            ione = arith.constant(i32, 1)
+            fone = arith.constant(f32, 1.2)
+
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
+            c1 = test.constraint(ione, ione, fone, ione, iattr, fattr)
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
+            test.constraint(ione, fone, fone, fone, iattr, fattr)
+            # CHECK: irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
+            test.constraint(ione, fone, ione, fone, iattr, fattr)
+
+            # CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+            o1 = test.optional(i32, i32, ione)
+            # CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
+            o2 = test.optional(i32, i32, ione, out2=i32, b=ione)
+            # CHECK: irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            o3 = test.optional2()
+            # CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
+            o4 = test.optional2(b=i32)
+            # CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
+            o5 = test.optional2(a=ione)
+            # CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
+            o6 = test.optional2(b=i32, a=ione)
+
+            # CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+            v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione])
+            # CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
+            v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
+            # CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+            v3 = test.variadic([i32, i32], [i32], i32, ione, [])
+            # CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            v4 = test.variadic2([], [])
+            # CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
+            v5 = test.variadic2([], [ione, ione, ione])
+            # CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
+            v6 = test.variadic2([i32, i32], [ione])
+
+            # CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
+            m1 = test.mixed(i32, ione, iattr, iattr, ione)
+            # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
+            m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione)
+
+    print(module)
+    assert module.operation.verify()
+
+    # CHECK: Value(%c1_i32
+    print(c1.a)
+    # CHECK: 2 : i32
+    print(c1.x)
+    # CHECK: Value(%c1_i32
+    print(o1.a)
+    # CHECK: None
+    print(o1.b)
+    # CHECK: Value(%c1_i32
+    print(o2.b)
+    # CHECK: 0
+    print(o1.out1.result_number)
+    # CHECK: None
+    print(o1.out2)
+    # CHECK: 0
+    print(o2.out1.result_number)
+    # CHECK: 1
+    print(o2.out2.result_number)
+    # CHECK: None
+    print(o3.a)
+    # CHECK: Value(%c1_i32
+    print(o5.a)
+    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
+    print([str(i) for i in v1.c])
+    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
+    print([str(i) for i in v2.c])
+    # CHECK: []
+    print([str(i) for i in v3.c])
+    # CHECK: 0 0
+    print(len(v4.a), len(v4.b))
+    # CHECK: 3 0
+    print(len(v5.a), len(v5.b))
+    # CHECK: 1 2
+    print(len(v6.a), len(v6.b))

>From c0b8ba70e97dd09179ba75c6ae5dde04d3be152e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 12:55:21 +0800
Subject: [PATCH 2/2] make FieldDef private

---
 mlir/python/mlir/dialects/irdl/dsl.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 3cc234503665a..b1916fdbd5f9b 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -91,25 +91,25 @@ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
         return _irdl.base(base_ref=self.ref)
 
 
-class FieldDef:
+class _FieldDef:  # private base: records each field's class-attribute name
     def __set_name__(self, owner, name: str):
         self.name = name
 
 
 @dataclass
-class Operand(FieldDef):
+class Operand(_FieldDef):
     constraint: ConstraintExpr
     variadicity: Variadicity = Variadicity.single
 
 
 @dataclass
-class Result(FieldDef):
+class Result(_FieldDef):
     constraint: ConstraintExpr
     variadicity: Variadicity = Variadicity.single
 
 
 @dataclass
-class Attribute(FieldDef):
+class Attribute(_FieldDef):
     constraint: ConstraintExpr
 
     def __post_init__(self):



More information about the Mlir-commits mailing list