[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings (PR #169045)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 23 20:55:33 PST 2025
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/169045
>From 3ec4809f2b5a76d14eba1a0707e30791cc5cc805 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 19 Nov 2025 01:04:38 +0800
Subject: [PATCH 1/2] [MLIR][Python] Add a DSL for defining IRDL dialects in
Python bindings
---
mlir/python/CMakeLists.txt | 4 +-
.../dialects/{irdl.py => irdl/__init__.py} | 13 +-
mlir/python/mlir/dialects/irdl/dsl.py | 343 ++++++++++++++++++
mlir/test/python/dialects/irdsl.py | 308 ++++++++++++++++
4 files changed, 661 insertions(+), 7 deletions(-)
rename mlir/python/mlir/dialects/{irdl.py => irdl/__init__.py} (91%)
create mode 100644 mlir/python/mlir/dialects/irdl/dsl.py
create mode 100644 mlir/test/python/dialects/irdsl.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..27b87a8b70144 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IRDLOps.td
- SOURCES dialects/irdl.py
+ SOURCES
+ dialects/irdl/__init__.py
+ dialects/irdl/dsl.py
DIALECT_NAME irdl
GEN_ENUM_BINDINGS
)
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl.py
rename to mlir/python/mlir/dialects/irdl/__init__.py
index 1ec951b69b646..6b2787ed7966c 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl/__init__.py
@@ -2,13 +2,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from ._irdl_ops_gen import *
-from ._irdl_ops_gen import _Dialect
-from ._irdl_enum_gen import *
-from .._mlir_libs._mlirDialectsIRDL import *
-from ..ir import register_attribute_builder
-from ._ods_common import _cext as _ods_cext
+from .._irdl_ops_gen import *
+from .._irdl_ops_gen import _Dialect
+from .._irdl_enum_gen import *
+from ..._mlir_libs._mlirDialectsIRDL import *
+from ...ir import register_attribute_builder
+from .._ods_common import _cext as _ods_cext
from typing import Union, Sequence
+from . import dsl
_ods_ir = _ods_cext.ir
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
new file mode 100644
index 0000000000000..3cc234503665a
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -0,0 +1,343 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...dialects import irdl as _irdl
+from .._ods_common import (
+ _cext as _ods_cext,
+ segmented_accessor as _ods_segmented_accessor,
+)
+from . import Variadicity
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter as _Parameter, Signature as _Signature
+from types import SimpleNamespace as _SimpleNameSpace
+
+_ods_ir = _ods_cext.ir
+
+
+class ConstraintExpr:
+    """Base class of every IRDL constraint expression in the DSL.
+
+    Subclasses override ``_lower`` to materialize themselves as IRDL
+    constraint operations. ``|`` and ``&`` compose two expressions into an
+    ``AnyOf`` / ``AllOf`` combinator respectively.
+    """
+
+    def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
+        # Every concrete constraint kind must supply its own lowering.
+        raise NotImplementedError()
+
+    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        # `a | b` builds a disjunction (lowered to `irdl.any_of`).
+        return AnyOf(self, other)
+
+    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        # `a & b` builds a conjunction (lowered to `irdl.all_of`).
+        return AllOf(self, other)
+
+
+class ConstraintLoweringContext:
+    """Lowers ``ConstraintExpr`` trees into IRDL SSA values.
+
+    Lowering is memoized per expression *instance*: asking for the same
+    object twice returns the SSA value produced the first time, so a
+    constraint shared by several fields is emitted only once.
+    """
+
+    def __init__(self):
+        # Maps id(expr) -> (expr, value). The expression is stored alongside
+        # its value so it stays alive as long as the context does: id() is
+        # only unique among live objects, and a collected expression's id
+        # could otherwise be reused by a different expression and wrongly
+        # hit this cache entry.
+        self._cache: Dict[int, Tuple[ConstraintExpr, _ods_ir.Value]] = {}
+
+    def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+        """Return the SSA value for ``expr``, emitting IRDL ops on first use."""
+        key = id(expr)
+        entry = self._cache.get(key)
+        if entry is not None:
+            return entry[1]
+        value = expr._lower(self)
+        self._cache[key] = (expr, value)
+        return value
+
+
+class Is(ConstraintExpr):
+    """Constraint matching one specific attribute, lowered to ``irdl.is``."""
+
+    def __init__(self, attr: _ods_ir.Attribute):
+        # The exact attribute the constrained value is compared against.
+        self.attr = attr
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        # Leaf constraint: nothing to lower recursively through `ctx`.
+        expected = self.attr
+        return _irdl.is_(expected)
+
+
+class IsType(Is):
+    """``Is`` constraint on a type; the type is wrapped in a ``TypeAttr``."""
+
+    def __init__(self, typ: _ods_ir.Type):
+        wrapped = _ods_ir.TypeAttr.get(typ)
+        super().__init__(wrapped)
+
+
+class AnyOf(ConstraintExpr):
+    """Disjunction of constraints, lowered to ``irdl.any_of``."""
+
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        # Lower every alternative into a list *before* building the
+        # combinator. This guarantees the operand ops are emitted ahead of
+        # `irdl.any_of` at the current insertion point, and hands the
+        # builder a real sequence instead of a one-shot generator whose
+        # consumption timing is up to the builder.
+        operands = [ctx.lower(expr) for expr in self.exprs]
+        return _irdl.any_of(operands)
+
+
+class AllOf(ConstraintExpr):
+    """Conjunction of constraints, lowered to ``irdl.all_of``."""
+
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        # Lower every member into a list *before* building the combinator
+        # so the operand ops are emitted ahead of `irdl.all_of` and the
+        # builder receives a real sequence rather than a generator.
+        operands = [ctx.lower(expr) for expr in self.exprs]
+        return _irdl.all_of(operands)
+
+
+class Any(ConstraintExpr):
+    """Unrestricted constraint, lowered to ``irdl.any``."""
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        # `ctx` is unused: this constraint has no sub-expressions.
+        return _irdl.any()
+
+
+class BaseName(ConstraintExpr):
+    """Constraint on a base name, lowered to ``irdl.base`` with ``base_name``.
+
+    ``name`` is the qualified base name, e.g. ``"#builtin.integer"``.
+    """
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        base_name = self.name
+        return _irdl.base(base_name=base_name)
+
+
+class BaseRef(ConstraintExpr):
+    """Constraint on a base given by reference, lowered to ``irdl.base``.
+
+    ``ref`` is forwarded verbatim as ``base_ref``; presumably a symbol
+    reference to an IRDL type/attribute definition (TODO confirm the
+    accepted forms).
+    """
+
+    def __init__(self, ref):
+        self.ref = ref
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
+        base_ref = self.ref
+        return _irdl.base(base_ref=base_ref)
+
+
+class FieldDef:
+    """Base for op fields; records the class-attribute name it is bound to.
+
+    Python invokes ``__set_name__`` when an instance is assigned in a class
+    body, so an ``Operand``/``Result``/``Attribute`` declared in an op class
+    learns its own field name automatically.
+    """
+
+    def __set_name__(self, owner, name: str):
+        # `owner` (the class being defined) is not needed here.
+        self.name = name
+
+
+@dataclass
+class Operand(FieldDef):
+    """An operand declared in an op class body."""
+
+    # Constraint the operand must satisfy.
+    constraint: ConstraintExpr
+    # single / optional / variadic; defaults to exactly one value.
+    variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Result(FieldDef):
+    """A result declared in an op class body."""
+
+    # Constraint the result must satisfy.
+    constraint: ConstraintExpr
+    # single / optional / variadic; defaults to exactly one value.
+    variadicity: Variadicity = Variadicity.single
+
+
+@dataclass
+class Attribute(FieldDef):
+    """An attribute declared in an op class body."""
+
+    # Constraint the attribute value must satisfy.
+    constraint: ConstraintExpr
+
+    def __post_init__(self):
+        # IRDL has no optional attributes yet, so every attribute counts as
+        # `single`; carrying `variadicity` lets operands, results and
+        # attributes be handled by the same code paths.
+        self.variadicity = Variadicity.single
+
+
+@dataclass
+class Operation:
+    """Definition of one operation of a DSL-declared dialect.
+
+    Emits itself as an ``irdl.operation`` and builds a matching ``OpView``
+    subclass plus a builder function for constructing instances.
+    """
+
+    dialect_name: str
+    name: str
+    # We store operands and attributes into one list to maintain relative orders
+    # among them for generating OpView class.
+    operands_and_attrs: List[Union[Operand, Attribute]]
+    results: List[Result]
+
+    # Field names that cannot be used: `self`, `loc` and `ip` are parameters
+    # of the generated builder signature, and `operation` is what the
+    # generated accessors go through to reach operands/results/attributes.
+    # (Not annotated, so it is a plain class constant, not a dataclass field.)
+    _RESERVED_FIELD_NAMES = frozenset({"self", "loc", "ip", "operation"})
+
+    def __post_init__(self):
+        # Fail early with a clear message instead of a duplicate-parameter
+        # error from `inspect.Signature` or a broken accessor later on.
+        for fld in self.results + self.operands_and_attrs:
+            if fld.name in self._RESERVED_FIELD_NAMES:
+                raise ValueError(
+                    f"field name '{fld.name}' of operation "
+                    f"'{self.dialect_name}.{self.name}' is reserved"
+                )
+
+    def _split_fields(self) -> Tuple[List[Operand], List[Attribute]]:
+        """Split ``operands_and_attrs`` into (operands, attributes).
+
+        Declaration order is preserved within each list.
+        """
+        operands = [f for f in self.operands_and_attrs if isinstance(f, Operand)]
+        attrs = [f for f in self.operands_and_attrs if isinstance(f, Attribute)]
+        return operands, attrs
+
+    def _emit(self) -> None:
+        """Emit this op as an ``irdl.operation`` at the current insertion point."""
+        op = _irdl.operation_(self.name)
+        # A fresh context per operation: cached SSA values are defined inside
+        # this operation's body, so they are not reused across operations.
+        ctx = ConstraintLoweringContext()
+        operands, attrs = self._split_fields()
+
+        with _ods_ir.InsertionPoint(op.body):
+            if operands:
+                _irdl.operands_(
+                    [ctx.lower(i.constraint) for i in operands],
+                    [i.name for i in operands],
+                    [i.variadicity for i in operands],
+                )
+            if attrs:
+                _irdl.attributes_(
+                    [ctx.lower(i.constraint) for i in attrs],
+                    [i.name for i in attrs],
+                )
+            if self.results:
+                _irdl.results_(
+                    [ctx.lower(i.constraint) for i in self.results],
+                    [i.name for i in self.results],
+                    [i.variadicity for i in self.results],
+                )
+
+    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+        """Create the ``OpView`` subclass and builder function for this op.
+
+        Builder parameters: results, then operands and attributes in
+        declaration order; optional fields become keyword-only with a
+        ``None`` default, followed by keyword-only ``loc`` and ``ip``.
+        """
+        operands, attrs = self._split_fields()
+
+        def variadicity_to_segment(variadicity: Variadicity) -> int:
+            # Segment encoding stored in _ODS_*_SEGMENTS:
+            # -1 = variadic, 0 = optional, 1 = single.
+            if variadicity == Variadicity.variadic:
+                return -1
+            if variadicity == Variadicity.optional:
+                return 0
+            return 1
+
+        # Segment specs are only produced when some field is not `single`.
+        operand_segments = None
+        if any(i.variadicity != Variadicity.single for i in operands):
+            operand_segments = [variadicity_to_segment(i.variadicity) for i in operands]
+
+        result_segments = None
+        if any(i.variadicity != Variadicity.single for i in self.results):
+            result_segments = [
+                variadicity_to_segment(i.variadicity) for i in self.results
+            ]
+
+        args = self.results + self.operands_and_attrs
+        positional_args = [
+            i.name for i in args if i.variadicity != Variadicity.optional
+        ]
+        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+
+        params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+        for i in positional_args:
+            params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+        for i in optional_args:
+            params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
+        params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
+
+        sig = _Signature(params)
+        op = self
+
+        class _OpView(_ods_ir.OpView):
+            OPERATION_NAME = f"{op.dialect_name}.{op.name}"
+            _ODS_REGIONS = (0, True)
+            _ODS_OPERAND_SEGMENTS = operand_segments
+            _ODS_RESULT_SEGMENTS = result_segments
+
+            def __init__(*args, **kwargs):
+                # Bind against the generated signature (which includes
+                # `self`) to map positional/keyword arguments to field names.
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                args = bound.arguments
+
+                _operands = [args[operand.name] for operand in operands]
+                _results = [args[result.name] for result in op.results]
+                # Attributes left as None are simply not passed on.
+                _attributes = {
+                    attr.name: args[attr.name]
+                    for attr in attrs
+                    if args[attr.name] is not None
+                }
+                _regions = None
+                _ods_successors = None
+                self = args["self"]
+                super(_OpView, self).__init__(
+                    self.OPERATION_NAME,
+                    self._ODS_REGIONS,
+                    self._ODS_OPERAND_SEGMENTS,
+                    self._ODS_RESULT_SEGMENTS,
+                    attributes=_attributes,
+                    results=_results,
+                    operands=_operands,
+                    successors=_ods_successors,
+                    regions=_regions,
+                    loc=args["loc"],
+                    ip=args["ip"],
+                )
+
+            __init__.__signature__ = sig
+
+        # All accessors below read through `self.operation`, so a field named
+        # e.g. `operands`, `results` or `attributes` (which replaces the
+        # OpView property of that name) cannot break them.
+        for attr in attrs:
+            setattr(
+                _OpView,
+                attr.name,
+                property(
+                    lambda self, name=attr.name: self.operation.attributes[name]
+                ),
+            )
+
+        def value_range_getter(
+            value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
+            variadicity: Variadicity,
+        ):
+            # single -> the value; optional -> the value or None;
+            # variadic -> the whole range.
+            if variadicity == Variadicity.single:
+                return value_range[0]
+            if variadicity == Variadicity.optional:
+                return value_range[0] if len(value_range) > 0 else None
+            return value_range
+
+        for i, operand in enumerate(operands):
+            if operand_segments:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = _ods_segmented_accessor(
+                        self.operation.operands,
+                        self.operation.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(operand_range, operand.variadicity)
+
+                setattr(_OpView, operand.name, property(getter))
+            else:
+                setattr(
+                    _OpView,
+                    operand.name,
+                    property(lambda self, i=i: self.operation.operands[i]),
+                )
+        for i, result in enumerate(self.results):
+            if result_segments:
+
+                def getter(self, i=i, result=result):
+                    result_range = _ods_segmented_accessor(
+                        self.operation.results,
+                        self.operation.attributes["resultSegmentSizes"],
+                        i,
+                    )
+                    return value_range_getter(result_range, result.variadicity)
+
+                setattr(_OpView, result.name, property(getter))
+            else:
+                setattr(
+                    _OpView,
+                    result.name,
+                    property(lambda self, i=i: self.operation.results[i]),
+                )
+
+        def _builder(*args, **kwargs) -> _OpView:
+            return _OpView(*args, **kwargs)
+
+        # The builder exposes the OpView signature minus `self`.
+        _builder.__signature__ = _Signature(params[1:])
+
+        return _OpView, _builder
+
+
+class Dialect:
+    """A dialect declared through the DSL.
+
+    Operations are declared with the ``op`` decorator. ``load`` emits the
+    dialect as IRDL, loads it via ``irdl.load_dialects``, registers the
+    generated Python classes and returns a namespace holding the OpView
+    classes and builder functions.
+    """
+
+    def __init__(self, name: str):
+        self.name = name
+        self.operations: List[Operation] = []
+        # Filled by `op`: OpView classes under the decorated class names and
+        # builders under the op names (with '.' replaced by '_').
+        self.namespace = _SimpleNameSpace()
+
+    def _emit(self) -> None:
+        """Emit an ``irdl.dialect`` containing every declared operation."""
+        d = _irdl.dialect(self.name)
+        with _ods_ir.InsertionPoint(d.body):
+            for op in self.operations:
+                op._emit()
+
+    def _make_module(self) -> _ods_ir.Module:
+        """Return a new module holding this dialect's IRDL definition."""
+        with _ods_ir.Location.unknown():
+            m = _ods_ir.Module.create()
+            with _ods_ir.InsertionPoint(m.body):
+                self._emit()
+            return m
+
+    def _make_dialect_class(self) -> type:
+        """Create the Python ``Dialect`` subclass for this namespace."""
+
+        class _Dialect(_ods_ir.Dialect):
+            DIALECT_NAMESPACE = self.name
+
+        return _Dialect
+
+    def load(self) -> _SimpleNameSpace:
+        """Load the dialect and register its Python classes.
+
+        Returns the namespace of generated OpView classes and builders.
+        """
+        _irdl.load_dialects(self._make_module())
+        dialect_class = self._make_dialect_class()
+        _ods_cext.register_dialect(dialect_class)
+        for op in self.operations:
+            _ods_cext.register_operation(dialect_class)(op.op_view)
+        return self.namespace
+
+    def op(self, name: str) -> Callable[[type], type]:
+        """Decorator declaring an operation called ``name`` in this dialect.
+
+        Class attributes of the decorated class that are ``Operand``,
+        ``Attribute`` or ``Result`` instances become the op's fields, in
+        declaration order. The decorated class itself is returned unchanged.
+
+        Raises:
+            ValueError: if an operation called ``name`` is already declared
+                in this dialect.
+        """
+
+        def decorator(cls: type) -> type:
+            # Redeclaring a name would silently overwrite its builder in the
+            # namespace and emit two IRDL operations with the same name.
+            if any(existing.name == name for existing in self.operations):
+                raise ValueError(
+                    f"operation '{name}' is already defined in dialect '{self.name}'"
+                )
+
+            operands_and_attrs: List[Union[Operand, Attribute]] = []
+            results: List[Result] = []
+
+            # `cls.__dict__` keeps definition order, which fixes the field
+            # order in the emitted IRDL and in the builder signature.
+            for field in cls.__dict__.values():
+                if isinstance(field, (Operand, Attribute)):
+                    operands_and_attrs.append(field)
+                elif isinstance(field, Result):
+                    results.append(field)
+
+            op_def = Operation(self.name, name, operands_and_attrs, results)
+            op_view, builder = op_def._make_op_view_and_builder()
+            # `op_view` is registered with the dialect later, in `load`.
+            op_def.op_view = op_view
+            op_def.builder = builder
+            self.operations.append(op_def)
+            # Make the generated class present itself as the user's class.
+            op_view.__name__ = cls.__name__
+            op_view.__qualname__ = cls.__qualname__
+            op_view.__module__ = cls.__module__
+            setattr(self.namespace, cls.__name__, op_view)
+            setattr(self.namespace, name.replace(".", "_"), builder)
+            return cls
+
+        return decorator
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
new file mode 100644
index 0000000000000..8ef30ae0a4c13
--- /dev/null
+++ b/mlir/test/python/dialects/irdsl.py
@@ -0,0 +1,308 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import dsl as irdsl
+from mlir.dialects import arith
+import sys
+
+
+def run(f):
+    """Run test function `f` inside a fresh MLIR Context.
+
+    A `TEST:` banner goes to stderr first; the RUN line merges stderr into
+    stdout so FileCheck can anchor on it.
+    """
+    banner = "\nTEST:"
+    print(banner, f.__name__, file=sys.stderr)
+    ctx = Context()
+    with ctx:
+        f()
+
+
+# CHECK: TEST: testMyInt
+@run
+def testMyInt():
+    """Declare a small integer dialect, load it and build IR with it.
+
+    Checks the emitted IRDL, the generated namespace contents, the built
+    and verified IR, OpView accessors, segment metadata and builder
+    signatures.
+    """
+    myint = irdsl.Dialect("myint")
+    iattr = irdsl.BaseName("#builtin.integer")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+
+    @myint.op("constant")
+    class ConstantOp:
+        value = irdsl.Attribute(iattr)
+        cst = irdsl.Result(i32)
+
+    @myint.op("add")
+    class AddOp:
+        lhs = irdsl.Operand(i32)
+        rhs = irdsl.Operand(i32)
+        res = irdsl.Result(i32)
+
+    # CHECK: irdl.dialect @myint {
+    # CHECK: irdl.operation @constant {
+    # CHECK: %0 = irdl.base "#builtin.integer"
+    # CHECK: irdl.attributes {"value" = %0}
+    # CHECK: %1 = irdl.is i32
+    # CHECK: irdl.results(cst: %1)
+    # CHECK: }
+    # CHECK: irdl.operation @add {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: irdl.operands(lhs: %0, rhs: %0)
+    # CHECK: irdl.results(res: %0)
+    # CHECK: }
+    # CHECK: }
+    print(myint._make_module())
+    # From here on `myint` is the loaded namespace, not the Dialect object.
+    myint = myint.load()
+
+    # Namespace order: each OpView class name followed by its builder name.
+    # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
+    print([i for i in myint.__dict__.keys()])
+
+    # `i32` is rebound from the DSL constraint to the concrete IR type.
+    i32 = IntegerType.get_signless(32)
+    with Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            # Builders take results first, then operands/attributes.
+            two = myint.constant(i32, IntegerAttr.get(i32, 2))
+            three = myint.constant(i32, IntegerAttr.get(i32, 3))
+            add1 = myint.add(i32, two, three)
+            add2 = myint.add(i32, add1, two)
+            add3 = myint.add(i32, add2, three)
+
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+        print(module)
+        assert module.operation.verify()
+
+        # CHECK: AddOp
+        print(type(add1).__name__)
+        # CHECK: ConstantOp
+        print(type(two).__name__)
+        # CHECK: myint.add
+        print(add1.OPERATION_NAME)
+        # All fields are `single`, so no segment specs are generated.
+        # CHECK: None
+        print(add1._ODS_OPERAND_SEGMENTS)
+        # CHECK: None
+        print(add1._ODS_RESULT_SEGMENTS)
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        print(add1.lhs.owner)
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        print(add1.rhs.owner)
+        # CHECK: 2 : i32
+        print(two.value)
+        # CHECK: Value(%0
+        print(two.cst)
+        # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
+        print(myint.add.__signature__)
+        # CHECK: (cst, value, *, loc=None, ip=None)
+        print(myint.constant.__signature__)
+
+
+# CHECK: TEST: testIRDSL
+@run
+def testIRDSL():
+    """Exercise constraint combinators, optional/variadic fields and
+    interleaved operands/attributes.
+
+    Checks the emitted IRDL, builder signatures, segment specs, the built
+    and verified IR, and the generated OpView accessors.
+    """
+    test = irdsl.Dialect("irdsl_test")
+    i32 = irdsl.IsType(IntegerType.get_signless(32))
+    i64 = irdsl.IsType(IntegerType.get_signless(64))
+    i32or64 = i32 | i64
+    # Named `anything` rather than `any` to avoid shadowing the builtin.
+    anything = irdsl.Any()
+    f32 = irdsl.IsType(F32Type.get())
+    iattr = irdsl.BaseName("#builtin.integer")
+    fattr = irdsl.BaseName("#builtin.float")
+
+    @test.op("constraint")
+    class ConstraintOp:
+        a = irdsl.Operand(i32or64)
+        b = irdsl.Operand(anything)
+        c = irdsl.Operand(f32 | i32)
+        d = irdsl.Operand(anything)
+        x = irdsl.Attribute(iattr)
+        y = irdsl.Attribute(fattr)
+
+    @test.op("optional")
+    class OptionalOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        out1 = irdsl.Result(i32)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out3 = irdsl.Result(i32)
+
+    @test.op("optional2")
+    class Optional2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        b = irdsl.Result(i32, irdsl.Variadicity.optional)
+
+    @test.op("variadic")
+    class VariadicOp:
+        a = irdsl.Operand(i32)
+        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
+        out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
+        out4 = irdsl.Result(i32)
+
+    @test.op("variadic2")
+    class Variadic2Op:
+        a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
+        b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+
+    @test.op("mixed")
+    class MixedOp:
+        out = irdsl.Result(i32)
+        in1 = irdsl.Operand(i32)
+        in2 = irdsl.Attribute(iattr)
+        in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
+        in4 = irdsl.Attribute(iattr)
+        in5 = irdsl.Operand(i32)
+
+    # CHECK: irdl.dialect @irdsl_test {
+    # CHECK: irdl.operation @constraint {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: %1 = irdl.is i64
+    # CHECK: %2 = irdl.any_of(%0, %1)
+    # CHECK: %3 = irdl.any
+    # CHECK: %4 = irdl.is f32
+    # CHECK: %5 = irdl.any_of(%4, %0)
+    # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %3)
+    # CHECK: %6 = irdl.base "#builtin.integer"
+    # CHECK: %7 = irdl.base "#builtin.float"
+    # CHECK: irdl.attributes {"x" = %6, "y" = %7}
+    # CHECK: }
+    # CHECK: irdl.operation @optional {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: irdl.operands(a: %0, b: optional %0)
+    # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK: }
+    # CHECK: irdl.operation @optional2 {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: irdl.operands(a: optional %0)
+    # CHECK: irdl.results(b: optional %0)
+    # CHECK: }
+    # CHECK: irdl.operation @variadic {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
+    # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK: }
+    # CHECK: irdl.operation @variadic2 {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: irdl.operands(a: variadic %0)
+    # CHECK: irdl.results(b: variadic %0)
+    # CHECK: }
+    # CHECK: irdl.operation @mixed {
+    # CHECK: %0 = irdl.is i32
+    # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK: %1 = irdl.base "#builtin.integer"
+    # CHECK: irdl.attributes {"in2" = %1, "in4" = %1}
+    # CHECK: irdl.results(out: %0)
+    # CHECK: }
+    # CHECK: }
+    print(test._make_module())
+    # From here on `test` is the loaded namespace, not the Dialect object.
+    test = test.load()
+
+    # Optional fields move to keyword-only parameters defaulting to None.
+    # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None)
+    print(test.constraint.__signature__)
+    # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+    print(test.optional.__signature__)
+    # CHECK: (*, b=None, a=None, loc=None, ip=None)
+    print(test.optional2.__signature__)
+    # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+    print(test.variadic.__signature__)
+    # CHECK: (b, a, *, loc=None, ip=None)
+    print(test.variadic2.__signature__)
+    # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+    print(test.mixed.__signature__)
+
+    # Segment encoding: 1 = single, 0 = optional, -1 = variadic.
+    # CHECK: None None
+    print(
+        test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS
+    )
+    # CHECK: [1, 0] [1, 0, 1]
+    print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS)
+    # CHECK: [0] [0]
+    print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS)
+    # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+    print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS)
+    # CHECK: [-1] [-1]
+    print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS)
+
+    # Rebind the names to concrete IR types for building operations.
+    i32 = IntegerType.get_signless(32)
+    i64 = IntegerType.get_signless(64)
+    f32 = F32Type.get()
+
+    with Location.unknown():
+        iattr = IntegerAttr.get(i32, 2)
+        fattr = FloatAttr.get_f32(2.3)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            ione = arith.constant(i32, 1)
+            fone = arith.constant(f32, 1.2)
+
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
+            c1 = test.constraint(ione, ione, fone, ione, iattr, fattr)
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
+            test.constraint(ione, fone, fone, fone, iattr, fattr)
+            # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
+            test.constraint(ione, fone, ione, fone, iattr, fattr)
+
+            # CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+            o1 = test.optional(i32, i32, ione)
+            # CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
+            o2 = test.optional(i32, i32, ione, out2=i32, b=ione)
+            # CHECK: "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            o3 = test.optional2()
+            # CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
+            o4 = test.optional2(b=i32)
+            # CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
+            o5 = test.optional2(a=ione)
+            # CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
+            o6 = test.optional2(b=i32, a=ione)
+
+            # CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+            v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione])
+            # CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
+            v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
+            # CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+            v3 = test.variadic([i32, i32], [i32], i32, ione, [])
+            # CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            v4 = test.variadic2([], [])
+            # CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
+            v5 = test.variadic2([], [ione, ione, ione])
+            # CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
+            v6 = test.variadic2([i32, i32], [ione])
+
+            # CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
+            m1 = test.mixed(i32, ione, iattr, iattr, ione)
+            # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
+            m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione)
+
+        print(module)
+        assert module.operation.verify()
+
+        # CHECK: Value(%c1_i32
+        print(c1.a)
+        # CHECK: 2 : i32
+        print(c1.x)
+        # CHECK: Value(%c1_i32
+        print(o1.a)
+        # CHECK: None
+        print(o1.b)
+        # CHECK: Value(%c1_i32
+        print(o2.b)
+        # CHECK: 0
+        print(o1.out1.result_number)
+        # CHECK: None
+        print(o1.out2)
+        # CHECK: 0
+        print(o2.out1.result_number)
+        # CHECK: 1
+        print(o2.out2.result_number)
+        # CHECK: None
+        print(o3.a)
+        # CHECK: Value(%c1_i32
+        print(o5.a)
+        # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
+        print([str(i) for i in v1.c])
+        # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
+        print([str(i) for i in v2.c])
+        # CHECK: []
+        print([str(i) for i in v3.c])
+        # CHECK: 0 0
+        print(len(v4.a), len(v4.b))
+        # CHECK: 3 0
+        print(len(v5.a), len(v5.b))
+        # CHECK: 1 2
+        print(len(v6.a), len(v6.b))
>From c0b8ba70e97dd09179ba75c6ae5dde04d3be152e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 12:55:21 +0800
Subject: [PATCH 2/2] make FieldDef private
---
mlir/python/mlir/dialects/irdl/dsl.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 3cc234503665a..b1916fdbd5f9b 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -91,25 +91,25 @@ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
return _irdl.base(base_ref=self.ref)
-class FieldDef:
+class _FieldDef:
def __set_name__(self, owner, name: str):
self.name = name
@dataclass
-class Operand(FieldDef):
+class Operand(_FieldDef):
constraint: ConstraintExpr
variadicity: Variadicity = Variadicity.single
@dataclass
-class Result(FieldDef):
+class Result(_FieldDef):
constraint: ConstraintExpr
variadicity: Variadicity = Variadicity.single
@dataclass
-class Attribute(FieldDef):
+class Attribute(_FieldDef):
constraint: ConstraintExpr
def __post_init__(self):
More information about the Mlir-commits
mailing list