[Mlir-commits] [mlir] 2cc4d45 - [MLIR][Python] Add a DSL for defining dialects in Python bindings (#169045)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 25 07:08:49 PST 2026


Author: Twice
Date: 2026-01-25T23:08:45+08:00
New Revision: 2cc4d45715cdd5eb23df51a48d6fd95c37a2276a

URL: https://github.com/llvm/llvm-project/commit/2cc4d45715cdd5eb23df51a48d6fd95c37a2276a
DIFF: https://github.com/llvm/llvm-project/commit/2cc4d45715cdd5eb23df51a48d6fd95c37a2276a.diff

LOG: [MLIR][Python] Add a DSL for defining dialects in Python bindings (#169045)

Python bindings for the IRDL dialect were introduced in #158488. They
are already usable for constructing IR and for dynamically loading
modules that contain `irdl.dialect` into MLIR. However, there are still
several pain points when working with them:

* The IRDL IR-building interface is not very intuitive and tends to be
quite verbose.
* We do not yet have the corresponding `OpView` classes for IRDL-defined
operations.

To address these issues, I propose creating a wrapper (effectively a
small “DSL”) on top of the existing IRDL Python bindings. This wrapper
aims to simplify IR construction and automatically generate the
corresponding `OpView` types. A simple example is shown below.

Currently, using the IRDL bindings looks like this:

```python
from mlir.ir import *
from mlir.dialects import irdl

with Context(), Location.unknown():
    m = Module.create()
    with InsertionPoint(m.body):
        myint = irdl.dialect("myint")
        with InsertionPoint(myint.body):
            constant = irdl.operation_("constant")
            with InsertionPoint(constant.body):
                iattr = irdl.base(base_name="#builtin.integer")
                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
                irdl.attributes_([iattr], ["value"])
                irdl.results_([i32], ["cst"], [irdl.Variadicity.single])

            add = irdl.operation_("add")
            with InsertionPoint(add.body):
                i32 = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
                irdl.operands_(
                    [i32, i32],
                    ["lhs", "rhs"],
                    [irdl.Variadicity.single, irdl.Variadicity.single],
                )
                irdl.results_([i32], ["res"], [irdl.Variadicity.single])

    irdl.load_dialects(m)
```

With the proposed DSL (module name `mlir.dialects.ext`), the equivalent
implementation becomes:

```python
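from mlir.ir import Context, IntegerAttr, IntegerType, Location
from mlir.dialects.ext import Dialect, Operand, Result

class MyInt(Dialect, name="myint"):
    pass

i32 = IntegerType[32]

class ConstantOp(MyInt.Operation, name="constant"):
    value: IntegerAttr
    cst: Result[i32]

class AddOp(MyInt.Operation, name="add"):
    lhs: Operand[i32]
    rhs: Operand[i32]
    res: Result[i32]

# Defining the dialect needs no MLIR context; loading it does.
with Context(), Location.unknown():
    MyInt.load()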
```

Compared with the current IRDL Python bindings, this DSL mainly adds the
following:

* **A more intuitive interface** for constructing IRDL definitions (as
shown in the example).
* **Automatic generation of the corresponding `OpView`
classes**, including `__init__` methods and property getters for each
defined operation. Similar to TableGen’s `ins`, operands and attributes
can be interleaved in arbitrary order. Optional and variadic
operands/results also get special handling (such as computing segment
sizes), so that they feel as natural to use as native operations.
* **Lazy insertion of ops**: all ops are created and inserted only when
`Dialect.load()` is called, which makes it unnecessary to specify an
MLIR context immediately when defining an IRDL dialect.
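* **Basic type inference** in operation builders (i.e.
`OpViewCls.__init__`): when every result is single and constrained to a
concrete type such as `IntegerType[32]`, the result types are inferred
and omitted from the builder signature.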
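The generated builder signature follows a few fixed rules: results come first unless their types can be inferred, operands and attributes keep their declaration order, optional fields become keyword-only arguments defaulting to `None`, and `loc`/`ip` come last. A minimal sketch of those rules with `inspect.Signature`, modelled on `_generate_init_signature` in `ext.py` (the `Field` dataclass and `make_init_signature` function are hypothetical names for illustration):

```python
from dataclasses import dataclass
from inspect import Parameter, Signature


@dataclass
class Field:
    # Hypothetical stand-in for FieldDef and its subclasses.
    name: str
    kind: str  # "operand", "attribute" or "result"
    variadicity: str = "single"  # "single", "optional" or "variadic"


def make_init_signature(fields, can_infer_types=False):
    # Results first (unless inferable), then operands and attributes
    # in declaration order.
    results = [] if can_infer_types else [f for f in fields if f.kind == "result"]
    args = results + [f for f in fields if f.kind != "result"]

    params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
    # Single and variadic fields are required and may be passed positionally.
    params += [
        Parameter(f.name, Parameter.POSITIONAL_OR_KEYWORD)
        for f in args
        if f.variadicity != "optional"
    ]
    # Optional fields become keyword-only arguments defaulting to None.
    params += [
        Parameter(f.name, Parameter.KEYWORD_ONLY, default=None)
        for f in args
        if f.variadicity == "optional"
    ]
    params += [
        Parameter("loc", Parameter.KEYWORD_ONLY, default=None),
        Parameter("ip", Parameter.KEYWORD_ONLY, default=None),
    ]
    return Signature(params)


# The fields of OptionalOp from the test file, in declaration order.
optional_op = [
    Field("a", "operand"),
    Field("b", "operand", "optional"),
    Field("out1", "result"),
    Field("out2", "result", "optional"),
    Field("out3", "result"),
]
print(make_init_signature(optional_op))
# (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
```

Binding call arguments through `Signature.bind` (as the generated `__init__` does) then yields a single name-to-value mapping, regardless of whether each argument was passed positionally or by keyword.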

The current DSL does not yet cover all IRDL operations. Several features
are not supported at the moment:
* Defining new types or attributes
* Parametric constraints
* Adding regions to operations

---------

Co-authored-by: Rolf Morel <rolfmorel at gmail.com>

Added: 
    mlir/python/mlir/dialects/ext.py
    mlir/test/python/dialects/ext.py

Modified: 
    mlir/include/mlir/Bindings/Python/IRCore.h
    mlir/python/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 599771f8a3283..f9fc34e82c972 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -957,7 +957,7 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
   }
 
   static void bind(nanobind::module_ &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
     cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("cast_from_type"));
     cls.def_prop_ro_static(
@@ -1092,9 +1092,10 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
   static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
     ClassTy cls;
     if (slots) {
-      cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
+      cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots),
+                    nanobind::is_generic());
     } else {
-      cls = ClassTy(m, DerivedTy::pyClassName);
+      cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
     }
     cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
             nanobind::arg("cast_from_attr"));

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 003a06b16daac..8ab145ada85dd 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -30,6 +30,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    dialects/ext.py
 )
 
 declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras

diff  --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
new file mode 100644
index 0000000000000..237c27bf62f77
--- /dev/null
+++ b/mlir/python/mlir/dialects/ext.py
@@ -0,0 +1,471 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import (
+    Dict,
+    List,
+    Union,
+    Tuple,
+    Any,
+    Optional,
+    Callable,
+    TypeVar,
+    get_origin,
+    get_args,
+)
+from collections.abc import Sequence
+from dataclasses import dataclass
+from inspect import Parameter, Signature
+from types import UnionType
+from . import irdl
+from ._ods_common import _cext, segmented_accessor
+from .irdl import Variadicity
+from ..passmanager import PassManager
+
+ir = _cext.ir
+
+__all__ = [
+    "Dialect",
+    "Operand",
+    "Result",
+]
+
+Operand = ir.Value
+Result = ir.OpResult
+
+
+class ConstraintLoweringContext:
+    def __init__(self):
+        self._cache: Dict[str, ir.Value] = {}
+
+    def lower(self, type_) -> ir.Value:
+        """
+        Lower a type hint (e.g. `Any`, `IntegerType[32]`, `IntegerAttr | StringAttr`) into IRDL ops.
+        """
+
+        if type(type_) is TypeVar:
+            if type_.__name__ in self._cache:
+                return self._cache[type_.__name__]
+            v = self._lower(type_.__bound__ or Any)
+            self._cache[type_.__name__] = v
+        else:
+            v = self._lower(type_)
+        return v
+
+    def _lower(self, type_) -> ir.Value:
+        origin = get_origin(type_)
+        if origin is UnionType or origin is Union:
+            return irdl.any_of(self.lower(arg) for arg in get_args(type_))
+        elif type_ is Any:
+            return irdl.any()
+        elif isinstance(type_, TypeVar):
+            return self.lower(type_)
+        elif origin and issubclass(origin, ir.Type):
+            # `origin.get` constructs an instance of the MLIR type.
+            t = origin.get(*get_args(type_))
+            return irdl.is_(ir.TypeAttr.get(t))
+        elif origin and issubclass(origin, ir.Attribute):
+            # `origin.get` constructs an instance of the MLIR attribute.
+            attr = origin.get(*get_args(type_))
+            return irdl.is_(attr)
+        elif issubclass(type_, ir.Type):
+            return irdl.base(base_name=f"!{type_.type_name}")
+        elif issubclass(type_, ir.Attribute):
+            return irdl.base(base_name=f"#{type_.attr_name}")
+
+        raise TypeError(f"unsupported type in constraints: {type_}")
+
+
+def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
+    """
+    Infer an ir.Type from a type annotation.
+    Returns a callable that returns the inferred ir.Type,
+    or None if the type cannot be inferred.
+    We use callables so that MLIR contexts are not required
+    while calling this function.
+    """
+
+    origin = get_origin(type_)
+    if origin and issubclass(origin, ir.Type):
+        # `origin.get` constructs an instance of the MLIR type.
+        return lambda: origin.get(*get_args(type_))
+    elif isinstance(type_, TypeVar):
+        return infer_type(type_.__bound__)
+    return None
+
+
+@dataclass
+class FieldDef:
+    """
+    Base class for kinds of fields that can occur in an `Operation`'s definition.
+    """
+
+    name: str
+    constraint: Any
+    variadicity: Variadicity
+
+    @staticmethod
+    def from_type_hint(name, type_) -> "FieldDef":
+        variadicity = Variadicity.single
+        if inner := match_optional(type_):
+            variadicity = Variadicity.optional
+            type_ = inner
+        elif get_origin(type_) is Sequence:
+            variadicity = Variadicity.variadic
+            type_ = get_args(type_)[0]
+
+        origin = get_origin(type_)
+        if origin is ir.OpResult:
+            return ResultDef(name, get_args(type_)[0], variadicity)
+        elif origin is ir.Value:
+            return OperandDef(name, get_args(type_)[0], variadicity)
+        elif issubclass(origin or type_, ir.Attribute):
+            return AttributeDef(name, type_, variadicity)
+        raise TypeError(f"unsupported type in operation definition: {type_}")
+
+
+@dataclass
+class OperandDef(FieldDef):
+    pass
+
+
+@dataclass
+class ResultDef(FieldDef):
+    pass
+
+
+@dataclass
+class AttributeDef(FieldDef):
+    def __post_init__(self):
+        if self.variadicity != Variadicity.single:
+            raise ValueError("optional attribute is not supported in IRDL")
+
+
+def partition_fields(
+    fields: List[FieldDef],
+) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
+    operands = [i for i in fields if isinstance(i, OperandDef)]
+    attrs = [i for i in fields if isinstance(i, AttributeDef)]
+    results = [i for i in fields if isinstance(i, ResultDef)]
+    return operands, attrs, results
+
+
+def normalize_value_range(
+    value_range: Union[ir.OpOperandList, ir.OpResultList],
+    variadicity: Variadicity,
+) -> ir.Value | ir.OpOperandList | ir.OpResultList | None:
+    if variadicity == Variadicity.single:
+        return value_range[0]
+    if variadicity == Variadicity.optional:
+        return value_range[0] if len(value_range) > 0 else None
+    return value_range
+
+
+def match_optional(type_) -> Optional[Any]:
+    """
+    Try to match type hint like `Optional[T]`, `T | None` or `None | T`.
+    Returns the `T` inside `Optional[T]` if matched.
+    Returns `None` if not matched.
+    """
+
+    origin = get_origin(type_)
+    args = get_args(type_)
+    if (
+        (origin is Union or origin is UnionType)
+        and len(args) == 2
+        and type(None) in args
+    ):
+        return args[0] if args[1] is type(None) else args[1]
+
+    return None
+
+
+class Operation(ir.OpView):
+    """
+    Base class of Python-defined operation.
+
+    NOTE: Usually you don't need to use it directly.
+    Use `Dialect` and `.Operation` of `Dialect` subclasses instead.
+    """
+
+    @classmethod
+    def __init_subclass__(cls, *, name: str = None, **kwargs):
+        """
+        Performs all the magic that makes an `Operation` subclass work like a dataclass:
+        - generate the method to emit IRDL operations,
+        - generate `__init__` method as an operation builder function,
+        - generate operand, result and attribute accessors
+        """
+
+        super().__init_subclass__(**kwargs)
+
+        fields = []
+
+        for base in cls.__bases__:
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__annotations__.items():
+            field = FieldDef.from_type_hint(key, value)
+            fields.append(field)
+
+        cls._fields = fields
+
+        # for subclasses without "name" parameter,
+        # just treat them as normal classes
+        if not name:
+            return
+
+        op_name = name
+        cls._op_name = op_name
+        dialect_name = cls._dialect_name
+        dialect_obj = cls._dialect_obj
+
+        cls._generate_class_attributes(dialect_name, op_name, fields)
+        cls._generate_init_method(fields)
+        operands, attrs, results = partition_fields(fields)
+        cls._generate_attr_properties(attrs)
+        cls._generate_operand_properties(operands)
+        cls._generate_result_properties(results)
+
+        dialect_obj.operations.append(cls)
+
+    @staticmethod
+    def _variadicity_to_segment(variadicity: Variadicity) -> int:
+        return {Variadicity.variadic: -1, Variadicity.optional: 0}.get(variadicity, 1)
+
+    @staticmethod
+    def _generate_segments(
+        operands_or_results: List[Union[OperandDef, ResultDef]],
+    ) -> List[int]:
+        if any(i.variadicity != Variadicity.single for i in operands_or_results):
+            return [
+                Operation._variadicity_to_segment(i.variadicity)
+                for i in operands_or_results
+            ]
+        return None
+
+    @staticmethod
+    def _generate_init_signature(
+        fields: List[FieldDef], can_infer_types: bool
+    ) -> Signature:
+        result_args = (
+            [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
+        )
+        # results are placed at the beginning of the parameter list,
+        # but operands and attributes can appear in any relative order.
+        args = result_args + [i for i in fields if not isinstance(i, ResultDef)]
+        positional_args = [
+            i.name for i in args if i.variadicity != Variadicity.optional
+        ]
+        optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+
+        params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
+        for i in positional_args:
+            params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
+        for i in optional_args:
+            params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+        params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
+        params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
+
+        return Signature(params)
+
+    @classmethod
+    def _generate_init_method(cls, fields: List[FieldDef]) -> None:
+        operands, attrs, results = partition_fields(fields)
+        inferred_types = [infer_type(i.constraint) for i in results]
+
+        # we infer result types only when all result types can be inferred
+        # and all results are single (not optional or variadic)
+        can_infer_types = all(inferred_types) and all(
+            i.variadicity == Variadicity.single for i in results
+        )
+
+        init_sig = cls._generate_init_signature(fields, can_infer_types)
+
+        def __init__(*args, **kwargs):
+            bound = init_sig.bind(*args, **kwargs)
+            bound.apply_defaults()
+            args = bound.arguments
+
+            _operands = [args[operand.name] for operand in operands]
+            _results = (
+                [t() for t in inferred_types]
+                if can_infer_types
+                else [args[result.name] for result in results]
+            )
+            _attributes = dict(
+                (attr.name, args[attr.name])
+                for attr in attrs
+                if args[attr.name] is not None
+            )
+            _regions = None
+            _ods_successors = None
+            self = args["self"]
+            super(Operation, self).__init__(
+                self.OPERATION_NAME,
+                self._ODS_REGIONS,
+                self._ODS_OPERAND_SEGMENTS,
+                self._ODS_RESULT_SEGMENTS,
+                attributes=_attributes,
+                results=_results,
+                operands=_operands,
+                successors=_ods_successors,
+                regions=_regions,
+                loc=args["loc"],
+                ip=args["ip"],
+            )
+
+        __init__.__signature__ = init_sig
+        cls.__init__ = __init__
+
+    @classmethod
+    def _generate_class_attributes(
+        cls, dialect_name: str, op_name: str, fields: List[FieldDef]
+    ) -> None:
+        operands, attrs, results = partition_fields(fields)
+
+        operand_segments = cls._generate_segments(operands)
+        result_segments = cls._generate_segments(results)
+
+        cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
+        cls._ODS_REGIONS = (0, True)
+        cls._ODS_OPERAND_SEGMENTS = operand_segments
+        cls._ODS_RESULT_SEGMENTS = result_segments
+
+    @classmethod
+    def _generate_attr_properties(cls, attrs: List[AttributeDef]) -> None:
+        for attr in attrs:
+            setattr(
+                cls,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+    @classmethod
+    def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
+        for i, operand in enumerate(operands):
+            if cls._ODS_OPERAND_SEGMENTS:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = segmented_accessor(
+                        self.operation.operands,
+                        self.operation.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return normalize_value_range(operand_range, operand.variadicity)
+
+                setattr(cls, operand.name, property(getter))
+            else:
+                setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+    @classmethod
+    def _generate_result_properties(cls, results: List[ResultDef]) -> None:
+        for i, result in enumerate(results):
+            if cls._ODS_RESULT_SEGMENTS:
+
+                def getter(self, i=i, result=result):
+                    result_range = segmented_accessor(
+                        self.operation.results,
+                        self.operation.attributes["resultSegmentSizes"],
+                        i,
+                    )
+                    return normalize_value_range(result_range, result.variadicity)
+
+                setattr(cls, result.name, property(getter))
+            else:
+                setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
+
+    @classmethod
+    def _emit_operation(cls) -> None:
+        ctx = ConstraintLoweringContext()
+        operands, attrs, results = partition_fields(cls._fields)
+
+        op = irdl.operation_(cls._op_name)
+        with ir.InsertionPoint(op.body):
+            if operands:
+                irdl.operands_(
+                    [ctx.lower(i.constraint) for i in operands],
+                    [i.name for i in operands],
+                    [i.variadicity for i in operands],
+                )
+            if attrs:
+                irdl.attributes_(
+                    [ctx.lower(i.constraint) for i in attrs],
+                    [i.name for i in attrs],
+                )
+            if results:
+                irdl.results_(
+                    [ctx.lower(i.constraint) for i in results],
+                    [i.name for i in results],
+                    [i.variadicity for i in results],
+                )
+
+
+class Dialect(ir.Dialect):
+    """
+    Base class of a Python-defined dialect.
+
+    It can be used like the following example:
+    ```python
+    class MyInt(Dialect, name="myint"):
+        pass
+
+    i32 = IntegerType[32]
+
+    class ConstantOp(MyInt.Operation, name="constant"):
+        value: IntegerAttr
+        cst: Result[i32]
+
+    class AddOp(MyInt.Operation, name="add"):
+        lhs: Operand[i32]
+        rhs: Operand[i32]
+        res: Result[i32]
+    ```
+    """
+
+    @classmethod
+    def __init_subclass__(cls, name: str, **kwargs):
+        cls.name = name
+        cls.DIALECT_NAMESPACE = name
+        cls.operations = []
+        cls.Operation = type(
+            "Operation",
+            (Operation,),
+            {"_dialect_obj": cls, "_dialect_name": name},
+        )
+
+    @classmethod
+    def _emit_dialect(cls) -> None:
+        d = irdl.dialect(cls.name)
+        with ir.InsertionPoint(d.body):
+            for op in cls.operations:
+                op._emit_operation()
+
+    @classmethod
+    def _emit_module(cls) -> ir.Module:
+        m = ir.Module.create()
+        with ir.InsertionPoint(m.body):
+            cls._emit_dialect()
+
+        return m
+
+    @classmethod
+    def load(cls) -> None:
+        if hasattr(cls, "_mlir_module"):
+            raise RuntimeError(f"Dialect {cls.name} is already loaded.")
+
+        mlir_module = cls._emit_module()
+
+        pm = PassManager()
+        pm.add("canonicalize, cse")
+        pm.run(mlir_module.operation)
+
+        irdl.load_dialects(mlir_module)
+
+        _cext.register_dialect(cls)
+
+        for op in cls.operations:
+            _cext.register_operation(cls)(op)
+
+        cls._mlir_module = mlir_module

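The segment handling in `ext.py` boils down to three steps: a static `-1`/`0`/`1` spec per operand or result group (`_generate_segments`), concrete sizes worked out when an op is built (to our understanding, by the generic `OpView` builder, which fills in `operandSegmentSizes`/`resultSegmentSizes`), and prefix-sum slicing in the accessors (`segmented_accessor` plus `normalize_value_range`). A plain-Python sketch of those steps on strings instead of MLIR values; all function names here are hypothetical:

```python
from typing import List, Optional, Sequence

# Static per-group spec, as in Operation._variadicity_to_segment.
SPEC = {"single": 1, "optional": 0, "variadic": -1}


def static_segments(variadicities: Sequence[str]) -> Optional[List[int]]:
    # Like Operation._generate_segments: no segment attribute is needed
    # unless some group is optional or variadic.
    if all(v == "single" for v in variadicities):
        return None
    return [SPEC[v] for v in variadicities]


def concrete_sizes(groups) -> List[int]:
    # An absent optional (None) contributes 0 values, a variadic group
    # (a list) contributes len(list), and a single value contributes 1.
    sizes = []
    for g in groups:
        if g is None:
            sizes.append(0)
        elif isinstance(g, list):
            sizes.append(len(g))
        else:
            sizes.append(1)
    return sizes


def flatten(groups) -> list:
    # The flat operand (or result) list stored on the operation.
    flat = []
    for g in groups:
        if isinstance(g, list):
            flat.extend(g)
        elif g is not None:
            flat.append(g)
    return flat


def segment_slice(elements, sizes, idx):
    # Same idea as `segmented_accessor`: skip the values of earlier groups.
    start = sum(sizes[:idx])
    return elements[start : start + sizes[idx]]


def normalize(values, variadicity):
    # Like normalize_value_range: unwrap singles, map an empty optional to None.
    if variadicity == "single":
        return values[0]
    if variadicity == "optional":
        return values[0] if values else None
    return values


# Operands of VariadicOp from the test: a (single), b (optional), c (variadic).
kinds = ["single", "optional", "variadic"]
groups = ["x", None, ["y", "z"]]
sizes = concrete_sizes(groups)
flat = flatten(groups)

print(static_segments(kinds))  # [1, 0, -1]
print(sizes)  # [1, 0, 2]
print([normalize(segment_slice(flat, sizes, i), k) for i, k in enumerate(kinds)])
# ['x', None, ['y', 'z']]
```

Applied to the results of `VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])` in the test, `concrete_sizes` gives `[1, 2, 0, 1]`, matching the `resultSegmentSizes` checked there.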
diff  --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
new file mode 100644
index 0000000000000..483953ddfde51
--- /dev/null
+++ b/mlir/test/python/dialects/ext.py
@@ -0,0 +1,340 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import arith
+from mlir.dialects.ext import *
+from typing import Any, Optional, Sequence, TypeVar, Union
+import sys
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    f()
+
+
+# CHECK: TEST: testMyInt
+@run
+def testMyInt():
+    class MyInt(Dialect, name="myint"):
+        pass
+
+    i32 = IntegerType[32]
+
+    class ConstantOp(MyInt.Operation, name="constant"):
+        value: IntegerAttr
+        cst: Result[i32]
+
+    class AddOp(MyInt.Operation, name="add"):
+        lhs: Operand[i32]
+        rhs: Operand[i32]
+        res: Result[i32]
+
+    # CHECK: irdl.dialect @myint {
+    # CHECK:   irdl.operation @constant {
+    # CHECK:     %0 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"value" = %0}
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     irdl.results(cst: %1)
+    # CHECK:   }
+    # CHECK:   irdl.operation @add {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(lhs: %0, rhs: %0)
+    # CHECK:     irdl.results(res: %0)
+    # CHECK:   }
+    # CHECK: }
+    with Context(), Location.unknown():
+        MyInt.load()
+        print(MyInt._mlir_module)
+
+        # CHECK: ['constant', 'add']
+        print([i._op_name for i in MyInt.operations])
+        i32 = IntegerType.get_signless(32)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            two = ConstantOp(IntegerAttr.get(i32, 2))
+            three = ConstantOp(IntegerAttr.get(i32, 3))
+            add1 = AddOp(two, three)
+            add2 = AddOp(add1, two)
+            add3 = AddOp(add2, three)
+
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+        print(module)
+        assert module.operation.verify()
+
+        # CHECK: AddOp
+        print(type(add1).__name__)
+        # CHECK: ConstantOp
+        print(type(two).__name__)
+        # CHECK: myint.add
+        print(add1.OPERATION_NAME)
+        # CHECK: None
+        print(add1._ODS_OPERAND_SEGMENTS)
+        # CHECK: None
+        print(add1._ODS_RESULT_SEGMENTS)
+        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+        print(add1.lhs.owner)
+        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+        print(add1.rhs.owner)
+        # CHECK: 2 : i32
+        print(two.value)
+        # CHECK: OpResult(%0
+        print(two.cst)
+        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
+        print(AddOp.__init__.__signature__)
+        # CHECK: (self, /, value, *, loc=None, ip=None)
+        print(ConstantOp.__init__.__signature__)
+
+
+# CHECK: TEST: testExtDialect
+@run
+def testExtDialect():
+    class Test(Dialect, name="ext_test"):
+        pass
+
+    i32 = IntegerType[32]
+
+    class ConstraintOp(Test.Operation, name="constraint"):
+        a: Operand[i32 | IntegerType[64]]
+        b: Operand[Any]
+        # Here we use `F32Type[()]` instead of just `F32Type`
+        # because of an existing issue in the IRDL implementation
+        # where `irdl.base` cannot appear inside `irdl.any_of`.
+        c: Operand[F32Type[()] | i32]
+        d: Operand[Any]
+        x: IntegerAttr
+        y: FloatAttr
+
+    class OptionalOp(Test.Operation, name="optional"):
+        a: Operand[i32]
+        b: Optional[Operand[i32]]
+        out1: Result[i32]
+        out2: Result[i32] | None
+        out3: Result[i32]
+
+    class Optional2Op(Test.Operation, name="optional2"):
+        a: Optional[Operand[i32]]
+        b: Optional[Result[i32]]
+
+    class VariadicOp(Test.Operation, name="variadic"):
+        a: Operand[i32]
+        b: Optional[Operand[i32]]
+        c: Sequence[Operand[i32]]
+        out1: Sequence[Result[i32]]
+        out2: Sequence[Result[i32]]
+        out3: Optional[Result[i32]]
+        out4: Result[i32]
+
+    class Variadic2Op(Test.Operation, name="variadic2"):
+        a: Sequence[Operand[i32]]
+        b: Sequence[Result[i32]]
+
+    class MixedOpBase(Test.Operation):
+        out: Result[i32]
+        in1: Operand[i32]
+
+    class MixedOp(MixedOpBase, name="mixed"):
+        in2: IntegerAttr
+        in3: Optional[Operand[i32]]
+        in4: IntegerAttr
+        in5: Operand[i32]
+
+    T = TypeVar("T")
+    U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
+    V = TypeVar("V", bound=Union[IntegerType[8], IntegerType[16]])
+
+    class TypeVarOp(Test.Operation, name="type_var"):
+        in1: Operand[T]
+        in2: Operand[T]
+        in3: Operand[U]
+        in4: Operand[U | V]
+        in5: Operand[V]
+
+    # CHECK: irdl.dialect @ext_test {
+    # CHECK:   irdl.operation @constraint {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     %1 = irdl.is i64
+    # CHECK:     %2 = irdl.any_of(%0, %1)
+    # CHECK:     %3 = irdl.any
+    # CHECK:     %4 = irdl.is f32
+    # CHECK:     %5 = irdl.any_of(%4, %0)
+    # CHECK:     %6 = irdl.any
+    # CHECK:     irdl.operands(a: %2, b: %3, c: %5, d: %6)
+    # CHECK:     %7 = irdl.base "#builtin.integer"
+    # CHECK:     %8 = irdl.base "#builtin.float"
+    # CHECK:     irdl.attributes {"x" = %7, "y" = %8}
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0)
+    # CHECK:     irdl.results(out1: %0, out2: optional %0, out3: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @optional2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: optional %0)
+    # CHECK:     irdl.results(b: optional %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: %0, b: optional %0, c: variadic %0)
+    # CHECK:     irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @variadic2 {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(a: variadic %0)
+    # CHECK:     irdl.results(b: variadic %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @mixed {
+    # CHECK:     %0 = irdl.is i32
+    # CHECK:     irdl.operands(in1: %0, in3: optional %0, in5: %0)
+    # CHECK:     %1 = irdl.base "#builtin.integer"
+    # CHECK:     %2 = irdl.base "#builtin.integer"
+    # CHECK:     irdl.attributes {"in2" = %1, "in4" = %2}
+    # CHECK:     irdl.results(out: %0)
+    # CHECK:   }
+    # CHECK:   irdl.operation @type_var {
+    # CHECK:     %0 = irdl.any
+    # CHECK:     %1 = irdl.is i32
+    # CHECK:     %2 = irdl.is i64
+    # CHECK:     %3 = irdl.any_of(%1, %2)
+    # CHECK:     %4 = irdl.is i8
+    # CHECK:     %5 = irdl.is i16
+    # CHECK:     %6 = irdl.any_of(%4, %5)
+    # CHECK:     %7 = irdl.any_of(%3, %6)
+    # CHECK:     irdl.operands(in1: %0, in2: %0, in3: %3, in4: %7, in5: %6)
+    # CHECK:   }
+    # CHECK: }
+    with Context(), Location.unknown():
+        Test.load()
+        print(Test._mlir_module)
+
+        # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
+        print(ConstraintOp.__init__.__signature__)
+        # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+        print(OptionalOp.__init__.__signature__)
+        # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
+        print(Optional2Op.__init__.__signature__)
+        # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+        print(VariadicOp.__init__.__signature__)
+        # CHECK: (self, /, b, a, *, loc=None, ip=None)
+        print(Variadic2Op.__init__.__signature__)
+        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+        print(MixedOp.__init__.__signature__)
+
+        # CHECK: None None
+        print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [1, 0] [1, 0, 1]
+        print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [0] [0]
+        print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
+        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+        print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
+        # CHECK: [-1] [-1]
+        print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
+
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        f32 = F32Type.get()
+
+        iattr = IntegerAttr.get(i32, 2)
+        fattr = FloatAttr.get_f32(2.3)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            ione = arith.constant(i32, 1)
+            fone = arith.constant(f32, 1.2)
+
+            # CHECK: "ext_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
+            c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
+            # CHECK: "ext_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
+            ConstraintOp(ione, fone, fone, fone, iattr, fattr)
+            # CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
+            ConstraintOp(ione, fone, ione, fone, iattr, fattr)
+
+            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
+            o1 = OptionalOp(i32, i32, ione)
+            # CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
+            o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
+            # CHECK: ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            o3 = Optional2Op()
+            # CHECK: %2 = "ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
+            o4 = Optional2Op(b=i32)
+            # CHECK: "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
+            o5 = Optional2Op(a=ione)
+            # CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
+            o6 = Optional2Op(b=i32, a=ione)
+
+            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
+            v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
+            # CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
+            v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
+            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
+            v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
+            # CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
+            v4 = Variadic2Op([], [])
+            # CHECK: "ext_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
+            v5 = Variadic2Op([], [ione, ione, ione])
+            # CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
+            v6 = Variadic2Op([i32, i32], [ione])
+
+            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
+            m1 = MixedOp(ione, iattr, iattr, ione)
+            # CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
+            m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
+
+        print(module)
+        assert module.operation.verify()
+
+        # CHECK: OpResult(%c1_i32
+        print(c1.a)
+        # CHECK: 2 : i32
+        print(c1.x)
+        # CHECK: OpResult(%c1_i32
+        print(o1.a)
+        # CHECK: None
+        print(o1.b)
+        # CHECK: OpResult(%c1_i32
+        print(o2.b)
+        # CHECK: 0
+        print(o1.out1.result_number)
+        # CHECK: None
+        print(o1.out2)
+        # CHECK: 0
+        print(o2.out1.result_number)
+        # CHECK: 1
+        print(o2.out2.result_number)
+        # CHECK: None
+        print(o3.a)
+        # CHECK: OpResult(%c1_i32
+        print(o5.a)
+        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)', 'OpResult(%c1_i32 = arith.constant 1 : i32)']
+        print([str(i) for i in v1.c])
+        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)']
+        print([str(i) for i in v2.c])
+        # CHECK: []
+        print([str(i) for i in v3.c])
+        # CHECK: 0 0
+        print(len(v4.a), len(v4.b))
+        # CHECK: 3 0
+        print(len(v5.a), len(v5.b))
+        # CHECK: 1 2
+        print(len(v6.a), len(v6.b))
+
+        # Cases that violate constraints
+        module = Module.create()
+        with InsertionPoint(module.body):
+            try:
+                c1 = ConstraintOp(ione, ione, fone, ione, iattr)
+            except TypeError as e:
+                # CHECK: missing a required argument: 'y'
+                print(e)
+
+            try:
+                c2 = ConstraintOp(ione, ione, fone, ione, iattr, fattr, ione)
+            except TypeError as e:
+                # CHECK: too many positional arguments
+                print(e)

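The `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS` values checked in this test appear to encode each segment as 1 (single), 0 (optional) or -1 (variadic). The `operandSegmentSizes`/`resultSegmentSizes` attributes then record how many values each segment actually received. Below is a minimal sketch of that bookkeeping under that assumption. `segment_sizes`, `segment_start` and `optional_position` are hypothetical helpers written for illustration, not functions from the MLIR Python bindings or from this patch.

```python
from typing import Any, List, Optional, Sequence

# Assumption: the meaning of the entries in _ODS_OPERAND_SEGMENTS /
# _ODS_RESULT_SEGMENTS, inferred from the CHECK lines in the test.
SINGLE, OPTIONAL, VARIADIC = 1, 0, -1


def segment_sizes(spec: Sequence[int], values: Sequence[Any]) -> List[int]:
    """Count how many operands (or results) each segment receives.

    values[i] is what the caller passed for segment i: one value for a
    SINGLE segment, a value or None for an OPTIONAL one, and a list for a
    VARIADIC one.
    """
    if len(spec) != len(values):
        raise ValueError(f"expected {len(spec)} segments, got {len(values)}")
    sizes = []
    for kind, value in zip(spec, values):
        if kind == SINGLE:
            if value is None:
                raise TypeError("a single segment requires a value")
            sizes.append(1)
        elif kind == OPTIONAL:
            sizes.append(0 if value is None else 1)
        elif kind == VARIADIC:
            sizes.append(len(value))
        else:
            raise ValueError(f"unknown segment kind: {kind}")
    return sizes


def segment_start(sizes: Sequence[int], index: int) -> int:
    """Flat position of the first element of segment `index`."""
    return sum(sizes[:index])


def optional_position(sizes: Sequence[int], index: int) -> Optional[int]:
    """Flat position of an optional segment's element, or None if absent."""
    return segment_start(sizes, index) if sizes[index] else None


# Operands and results of `v1` in the test, with strings standing in for Values.
print(segment_sizes([1, 0, -1], ["a", None, ["c0", "c1"]]))  # [1, 0, 2]
print(segment_sizes([-1, -1, 0, 1], [["r0"], ["r1", "r2"], None, "r3"]))  # [1, 2, 0, 1]
```

With the arguments of `v1`, this reproduces `array<i32: 1, 0, 2>` and `array<i32: 1, 2, 0, 1>` from the CHECK line. Assuming `out2` is the second result segment of `OptionalOp` (as its `resultSegmentSizes` suggest), `optional_position([1, 0, 1], 1)` returns `None` and `optional_position([1, 1, 1], 1)` returns `1`, mirroring `o1.out2` and `o2.out2.result_number`.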

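The two `TypeError` messages expected at the end of the test match the wording produced by Python's `inspect.Signature.bind`. That suggests, though the hunk shown here does not confirm it, that the generated constructors validate their arguments against an `inspect.Signature` built from the operation's fields. A minimal sketch of that approach follows. `FIELDS` and `bind_fields` are hypothetical names that mirror the argument order of `ConstraintOp`.

```python
import inspect

# Hypothetical field list mirroring ConstraintOp in the test:
# four operands (a, b, c, d) followed by two attributes (x, y).
FIELDS = ["a", "b", "c", "d", "x", "y"]

# Assumption: one positional-or-keyword parameter per field, in declaration order.
SIGNATURE = inspect.Signature(
    [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in FIELDS]
)


def bind_fields(*args, **kwargs):
    """Map call arguments onto field names; raises TypeError on an arity mismatch."""
    return dict(SIGNATURE.bind(*args, **kwargs).arguments)


for call_args in [(1, 2, 3, 4, 5), (1, 2, 3, 4, 5, 6, 7)]:
    try:
        bind_fields(*call_args)
    except TypeError as e:
        print(e)
```

Run as-is, this prints `missing a required argument: 'y'` and then `too many positional arguments`, the same strings the CHECK lines expect. Binding through a `Signature` also accepts keyword arguments for free (for example `bind_fields(1, 2, 3, 4, x=5, y=6)`), which is one plausible way to support calls like `OptionalOp(i32, i32, ione, out2=i32, b=ione)`.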