[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining dialects in Python bindings (PR #169045)

Rolf Morel llvmlistbot at llvm.org
Thu Jan 22 08:28:59 PST 2026


================
@@ -0,0 +1,422 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import (
+    Dict,
+    List,
+    Union,
+    Tuple,
+    Any,
+    Optional,
+    Callable,
+    TypeVar,
+    get_origin,
+    get_args,
+)
+from collections.abc import Sequence
+from dataclasses import dataclass
+from inspect import Parameter, Signature
+from types import UnionType
+from . import irdl
+from ._ods_common import _cext, segmented_accessor
+from .irdl import Variadicity
+from ..passmanager import PassManager
+
ir = _cext.ir

# Names exported by `from ... import *`.
__all__ = ["Dialect", "Operand", "Result"]

# Annotation aliases used in op definitions: `Operand[T]` declares an operand
# and `Result[T]` declares a result (see Operation.convert_type_to_field_def).
Operand, Result = ir.Value, ir.OpResult
+
+
class ConstraintLoweringContext:
    """Lowers Python type annotations to IRDL constraint values.

    Each ``TypeVar`` is lowered at most once per context: the result is cached
    under the TypeVar's ``__name__`` so repeated uses of the same variable
    share one constraint value.
    """

    def __init__(self):
        # TypeVar name -> constraint value already produced for it.
        self._cache: Dict[str, ir.Value] = {}

    def lower(self, type_) -> ir.Value:
        """Return the IRDL constraint for the annotation ``type_``.

        A ``TypeVar`` is lowered through its bound (``Any`` when unbound) and
        memoized; every other annotation is lowered afresh.

        Raises:
            TypeError: if ``type_`` (or one of its union members) is not a
                supported constraint.
        """
        if not isinstance(type_, TypeVar):
            return self._lower(type_)
        key = type_.__name__
        if key not in self._cache:
            self._cache[key] = self._lower(type_.__bound__ or Any)
        return self._cache[key]

    def _lower(self, type_) -> ir.Value:
        """Build the constraint for ``type_`` without consulting the cache.

        - ``X | Y`` / ``Union[X, Y]``: ``irdl.any_of`` over the members.
        - ``Any``: ``irdl.any()``.
        - ``TypeVar``: delegated to :meth:`lower` (memoized).
        - parameterized ``ir.Type`` subclass (``T[args]``): the concrete type
          ``T.get(*args)`` wrapped in a ``TypeAttr`` and passed to ``irdl.is_``.
        - parameterized ``ir.Attribute`` subclass: ``irdl.is_`` of
          ``A.get(*args)``.
        - bare ``ir.Type`` / ``ir.Attribute`` subclass: ``irdl.base`` on its
          ``!type_name`` / ``#attr_name``.

        Raises:
            TypeError: for anything else.
        """
        origin = get_origin(type_)
        if origin is UnionType or origin is Union:
            return irdl.any_of(self.lower(arg) for arg in get_args(type_))
        if type_ is Any:
            return irdl.any()
        if isinstance(type_, TypeVar):
            return self.lower(type_)
        # `issubclass` itself raises on non-class arguments (special forms like
        # `Literal`, string forward references, generic aliases). Guarding with
        # `isinstance(..., type)` lets those reach the explicit error below.
        origin_is_class = isinstance(origin, type)
        if origin_is_class and issubclass(origin, ir.Type):
            concrete = origin.get(*get_args(type_))
            return irdl.is_(ir.TypeAttr.get(concrete))
        if origin_is_class and issubclass(origin, ir.Attribute):
            return irdl.is_(origin.get(*get_args(type_)))
        type_is_class = isinstance(type_, type)
        if type_is_class and issubclass(type_, ir.Type):
            return irdl.base(base_name=f"!{type_.type_name}")
        if type_is_class and issubclass(type_, ir.Attribute):
            return irdl.base(base_name=f"#{type_.attr_name}")

        raise TypeError(f"unsupported type in constraints: {type_}")
+
+
def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
    """Infer an ``ir.Type`` from a type annotation.

    Returns a zero-argument callable that builds the inferred ``ir.Type``, or
    ``None`` if the type cannot be inferred. A callable is returned so that no
    MLIR context is required while calling this function; the type is only
    constructed when the callable is invoked.

    Only a parameterized ``ir.Type`` subclass (``T[args]`` -> ``T.get(*args)``)
    is inferable; a ``TypeVar`` is inferred through its bound.
    """
    origin = get_origin(type_)
    # `issubclass` raises TypeError for non-class origins such as
    # `typing.Union` or `typing.Literal`; those are simply not inferable.
    if isinstance(origin, type) and issubclass(origin, ir.Type):
        args = get_args(type_)
        return lambda: origin.get(*args)
    if isinstance(type_, TypeVar):
        return infer_type(type_.__bound__)
    return None
+
+
class FieldDef:
    """Base class for the operand, attribute and result fields of an op.

    ``Operation.__init_subclass__`` assigns each field's ``name`` (the
    annotated class attribute it came from) after the field is constructed.
    """

    # Declared but not defaulted: this documents the attribute for readers and
    # type checkers without adding a dataclass field to the subclasses, so
    # their constructors are unchanged.
    name: str
+
+
@dataclass
class OperandDef(FieldDef):
    """Operand field of an op: a type constraint plus its variadicity."""

    # Type argument of the `Operand[...]` annotation.
    constraint: Any
    # Whether the operand is single, optional (`X | None`) or variadic
    # (`Sequence[X]`).
    variadicity: Variadicity
+
+
@dataclass
class ResultDef(FieldDef):
    """Result field of an op: a type constraint plus its variadicity."""

    # Type argument of the `Result[...]` annotation.
    constraint: Any
    # Whether the result is single, optional (`X | None`) or variadic
    # (`Sequence[X]`).
    variadicity: Variadicity
+
+
@dataclass
class AttributeDef(FieldDef):
    """Attribute field of an op.

    Only single (non-optional, non-variadic) attributes are accepted.

    Raises:
        ValueError: from ``__post_init__`` when ``variadicity`` is not
            ``Variadicity.single``.
    """

    # The attribute annotation itself (an `ir.Attribute` subclass, possibly
    # parameterized).
    constraint: Any
    variadicity: Variadicity

    def __post_init__(self):
        if self.variadicity == Variadicity.single:
            return
        # Name the actual rejected form: `X | None` gives optional and
        # `Sequence[X]` gives variadic.
        kind = (
            "optional"
            if self.variadicity == Variadicity.optional
            else "variadic"
        )
        raise ValueError(f"{kind} attribute is not supported in IRDL")
+
+
def partition_fields(
    fields: List[FieldDef],
) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
    """Split field definitions into operands, attributes and results.

    The relative order within each group follows ``fields``. The input is
    walked exactly once, so any iterable (not only a list) is accepted.

    Returns:
        A tuple ``(operands, attrs, results)``.
    """
    operands: List[OperandDef] = []
    attrs: List[AttributeDef] = []
    results: List[ResultDef] = []
    for fd in fields:
        # Separate `if`s (not `elif`) so an object matching several kinds is
        # placed in every matching group, as independent filters would do.
        if isinstance(fd, OperandDef):
            operands.append(fd)
        if isinstance(fd, AttributeDef):
            attrs.append(fd)
        if isinstance(fd, ResultDef):
            results.append(fd)
    return operands, attrs, results
+
+
+def normalize_value_range(
+    value_range: Union[ir.OpOperandList, ir.OpResultList],
+    variadicity: Variadicity,
+):
+    if variadicity == Variadicity.single:
+        return value_range[0]
+    if variadicity == Variadicity.optional:
+        return value_range[0] if len(value_range) > 0 else None
+    return value_range
+
+
+def match_optional(type_):
+    origin = get_origin(type_)
+    args = get_args(type_)
+    if (
+        (origin is Union or origin is UnionType)
+        and len(args) == 2
+        and type(None) in args
+    ):
+        return args[0] if args[1] is type(None) else args[1]
+
+    return None
+
+
+class Operation(ir.OpView):
+    @staticmethod
+    def convert_type_to_field_def(type_) -> FieldDef:
+        variadicity = Variadicity.single
+        if inner := match_optional(type_):
+            variadicity = Variadicity.optional
+            type_ = inner
+        elif get_origin(type_) is Sequence:
+            variadicity = Variadicity.variadic
+            type_ = get_args(type_)[0]
+
+        origin = get_origin(type_)
+        if origin is ir.OpResult:
+            return ResultDef(get_args(type_)[0], variadicity)
+        elif origin is ir.Value:
+            return OperandDef(get_args(type_)[0], variadicity)
+        elif issubclass(origin or type_, ir.Attribute):
+            return AttributeDef(type_, variadicity)
+        raise TypeError(f"unsupported type in operation definition: {type_}")
+
+    @classmethod
+    def __init_subclass__(cls, *, name: str = None, **kwargs):
+        super().__init_subclass__(**kwargs)
+
+        fields = []
+        cls._fields = fields
+
+        for base in cls.__bases__:
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__annotations__.items():
+            field = Operation.convert_type_to_field_def(value)
+            setattr(field, "name", key)
----------------
rolfmorel wrote:

If every FieldDef has a `name`, I would just add it to FieldDef's constructor (rather than monkeypatching here).

https://github.com/llvm/llvm-project/pull/169045


More information about the Mlir-commits mailing list