[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining dialects in Python bindings (PR #169045)
Rolf Morel
llvmlistbot at llvm.org
Thu Jan 22 08:29:02 PST 2026
================
@@ -0,0 +1,422 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import (
+ Dict,
+ List,
+ Union,
+ Tuple,
+ Any,
+ Optional,
+ Callable,
+ TypeVar,
+ get_origin,
+ get_args,
+)
+from collections.abc import Sequence
+from dataclasses import dataclass
+from inspect import Parameter, Signature
+from types import UnionType
+from . import irdl
+from ._ods_common import _cext, segmented_accessor
+from .irdl import Variadicity
+from ..passmanager import PassManager
+
+ir = _cext.ir
+
+__all__ = [
+ "Dialect",
+ "Operand",
+ "Result",
+]
+
+Operand = ir.Value
+Result = ir.OpResult
+
+
class ConstraintLoweringContext:
    """Translate Python type annotations into constraint values built via ``irdl``.

    Lowered ``TypeVar``s are memoized by name, so repeated uses of the same
    type variable within one context yield the same ``ir.Value``.
    """

    def __init__(self):
        # TypeVar name -> value already produced for that type variable.
        self._cache: Dict[str, ir.Value] = {}

    def lower(self, type_) -> ir.Value:
        """Lower ``type_``, consulting the per-context cache for ``TypeVar``s.

        A ``TypeVar`` is lowered through its bound; one without a bound is
        lowered as ``Any``.
        """
        if not isinstance(type_, TypeVar):
            return self._lower(type_)
        key = type_.__name__
        if key not in self._cache:
            self._cache[key] = self._lower(type_.__bound__ or Any)
        return self._cache[key]

    def _lower(self, type_) -> ir.Value:
        """Lower a single annotation.

        Raises:
            TypeError: if ``type_`` is not one of the supported annotation
                forms (union, ``Any``, ``TypeVar``, or an ``ir.Type`` /
                ``ir.Attribute`` subclass, optionally parameterized).
        """
        origin = get_origin(type_)
        if origin is UnionType or origin is Union:
            return irdl.any_of(self.lower(arg) for arg in get_args(type_))
        if type_ is Any:
            return irdl.any()
        if isinstance(type_, TypeVar):
            return self.lower(type_)

        # `issubclass` raises its own, unhelpful TypeError when handed a
        # non-class (e.g. `typing.Literal[...]` or `List[int]`), which would
        # mask the descriptive error below. Only probe genuine classes.
        if isinstance(origin, type):
            if issubclass(origin, ir.Type):
                t = origin.get(*get_args(type_))
                return irdl.is_(ir.TypeAttr.get(t))
            if issubclass(origin, ir.Attribute):
                attr = origin.get(*get_args(type_))
                return irdl.is_(attr)
        if isinstance(type_, type):
            if issubclass(type_, ir.Type):
                return irdl.base(base_name=f"!{type_.type_name}")
            if issubclass(type_, ir.Attribute):
                return irdl.base(base_name=f"#{type_.attr_name}")

        raise TypeError(f"unsupported type in constraints: {type_}")
+
+
def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
    """Infer an ``ir.Type`` from a type annotation.

    Returns a callable that produces the inferred ``ir.Type``, or ``None`` if
    the type cannot be inferred. A callable is returned (rather than the type
    itself) so that no MLIR context is required while calling this function.
    """
    origin = get_origin(type_)
    # `get_origin` may return a non-class object (e.g. `typing.Union` on
    # Python versions where it is a special form); `issubclass` would raise
    # TypeError on it instead of letting us report "cannot infer".
    if isinstance(origin, type) and issubclass(origin, ir.Type):
        return lambda: origin.get(*get_args(type_))
    if isinstance(type_, TypeVar):
        # An unbound TypeVar has `__bound__ is None`, which falls through to
        # the final `return None` in the recursive call.
        return infer_type(type_.__bound__)
    return None
+
+
class FieldDef:
    """Common base class of the operand, result and attribute field definitions."""
+
+
@dataclass
class OperandDef(FieldDef):
    """Field definition describing an operand of an operation."""

    # Annotation used as the constraint on the operand.
    constraint: Any
    # Whether the operand is single, optional or variadic.
    variadicity: Variadicity
+
+
@dataclass
class ResultDef(FieldDef):
    """Field definition describing a result of an operation."""

    # Annotation used as the constraint on the result.
    constraint: Any
    # Whether the result is single, optional or variadic.
    variadicity: Variadicity
+
+
@dataclass
class AttributeDef(FieldDef):
    """Field definition describing an attribute of an operation.

    Raises:
        ValueError: on construction, if ``variadicity`` is anything other
            than ``Variadicity.single``.
    """

    # Annotation used as the constraint on the attribute.
    constraint: Any
    # Must be `Variadicity.single`; validated in `__post_init__`.
    variadicity: Variadicity

    def __post_init__(self):
        # Every non-single variadicity is rejected here, not only "optional".
        if self.variadicity != Variadicity.single:
            raise ValueError("optional attribute is not supported in IRDL")
+
+
def partition_fields(
    fields: List[FieldDef],
) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
    """Split ``fields`` by kind, preserving relative order within each kind.

    Returns:
        A ``(operands, attributes, results)`` tuple of lists.
    """

    def only(kind):
        return [entry for entry in fields if isinstance(entry, kind)]

    return only(OperandDef), only(AttributeDef), only(ResultDef)
+
+
+def normalize_value_range(
+ value_range: Union[ir.OpOperandList, ir.OpResultList],
+ variadicity: Variadicity,
+):
+ if variadicity == Variadicity.single:
+ return value_range[0]
+ if variadicity == Variadicity.optional:
+ return value_range[0] if len(value_range) > 0 else None
+ return value_range
+
+
+def match_optional(type_):
+ origin = get_origin(type_)
+ args = get_args(type_)
+ if (
+ (origin is Union or origin is UnionType)
+ and len(args) == 2
+ and type(None) in args
+ ):
+ return args[0] if args[1] is type(None) else args[1]
+
+ return None
+
+
+class Operation(ir.OpView):
+ @staticmethod
+ def convert_type_to_field_def(type_) -> FieldDef:
+ variadicity = Variadicity.single
+ if inner := match_optional(type_):
+ variadicity = Variadicity.optional
+ type_ = inner
+ elif get_origin(type_) is Sequence:
+ variadicity = Variadicity.variadic
+ type_ = get_args(type_)[0]
+
+ origin = get_origin(type_)
+ if origin is ir.OpResult:
+ return ResultDef(get_args(type_)[0], variadicity)
+ elif origin is ir.Value:
+ return OperandDef(get_args(type_)[0], variadicity)
+ elif issubclass(origin or type_, ir.Attribute):
+ return AttributeDef(type_, variadicity)
+ raise TypeError(f"unsupported type in operation definition: {type_}")
+
+ @classmethod
+ def __init_subclass__(cls, *, name: str = None, **kwargs):
+ super().__init_subclass__(**kwargs)
+
+ fields = []
+ cls._fields = fields
+
+ for base in cls.__bases__:
+ if hasattr(base, "_fields"):
+ fields.extend(base._fields)
+ for key, value in cls.__annotations__.items():
+ field = Operation.convert_type_to_field_def(value)
+ setattr(field, "name", key)
+ fields.append(field)
+
+ # for subclasses without "name" parameter,
+ # just treat them as normal classes
+ if not name:
+ return
+
+ op_name = name
+ cls._op_name = op_name
+ dialect_name = cls._dialect_name
+ dialect_obj = cls._dialect_obj
+
+ cls._generate_class_attributes(dialect_name, op_name, fields)
+ cls._generate_init_method(fields)
+ operands, attrs, results = partition_fields(fields)
+ cls._generate_attr_properties(attrs)
+ cls._generate_operand_properties(operands)
+ cls._generate_result_properties(results)
+
+ dialect_obj.operations.append(cls)
+
+ @staticmethod
+ def _variadicity_to_segment(variadicity: Variadicity) -> int:
+ if variadicity == Variadicity.variadic:
+ return -1
+ if variadicity == Variadicity.optional:
+ return 0
+ return 1
+
+ @staticmethod
+ def _generate_segments(
+ operands_or_results: List[Union[OperandDef, ResultDef]],
+ ) -> List[int]:
+ if any(i.variadicity != Variadicity.single for i in operands_or_results):
+ return [
+ Operation._variadicity_to_segment(i.variadicity)
+ for i in operands_or_results
+ ]
+ return None
+
+ @staticmethod
+ def _generate_init_signature(
+ fields: List[FieldDef], can_infer_types: bool
+ ) -> Signature:
+ result_args = (
+ [] if can_infer_types else [i for i in fields if isinstance(i, ResultDef)]
+ )
+ # results are placed at the beginning of the parameter list,
+ # but operands and attributes can appear in any relative order.
+ args = result_args + [i for i in fields if not isinstance(i, ResultDef)]
+ positional_args = [
+ i.name for i in args if i.variadicity != Variadicity.optional
+ ]
+ optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
+
+ params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
+ for i in positional_args:
+ params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
+ for i in optional_args:
+ params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+ params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
+ params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
+
+ return Signature(params)
+
+ @classmethod
+ def _generate_init_method(cls, fields: List[FieldDef]) -> None:
+ operands, attrs, results = partition_fields(fields)
+ inferred_types = [infer_type(i.constraint) for i in results]
+
+ # we infer result types only when all result types can be inferred
+ # and all results are single (not optional or variadic)
+ can_infer_types = all(inferred_types) and all(
+ i.variadicity == Variadicity.single for i in results
+ )
+
+ init_sig = cls._generate_init_signature(fields, can_infer_types)
+
+ def __init__(*args, **kwargs):
+ bound = init_sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+ args = bound.arguments
+
+ _operands = [args[operand.name] for operand in operands]
+ _results = (
+ [t() for t in inferred_types]
+ if can_infer_types
+ else [args[result.name] for result in results]
+ )
+ _attributes = dict(
+ (attr.name, args[attr.name])
+ for attr in attrs
+ if args[attr.name] is not None
+ )
+ _regions = None
+ _ods_successors = None
+ self = args["self"]
+ super(Operation, self).__init__(
+ self.OPERATION_NAME,
+ self._ODS_REGIONS,
+ self._ODS_OPERAND_SEGMENTS,
+ self._ODS_RESULT_SEGMENTS,
+ attributes=_attributes,
+ results=_results,
+ operands=_operands,
+ successors=_ods_successors,
+ regions=_regions,
+ loc=args["loc"],
+ ip=args["ip"],
+ )
+
+ __init__.__signature__ = init_sig
+ cls.__init__ = __init__
+
+ @classmethod
+ def _generate_class_attributes(
+ cls, dialect_name: str, op_name: str, fields: List[FieldDef]
+ ) -> None:
+ operands, attrs, results = partition_fields(fields)
+
+ operand_segments = cls._generate_segments(operands)
+ result_segments = cls._generate_segments(results)
+
+ cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
+ cls._ODS_REGIONS = (0, True)
+ cls._ODS_OPERAND_SEGMENTS = operand_segments
+ cls._ODS_RESULT_SEGMENTS = result_segments
+
+ @classmethod
+ def _generate_attr_properties(cls, attrs: List[AttributeDef]) -> None:
+ for attr in attrs:
+ setattr(
+ cls,
+ attr.name,
+ property(lambda self, name=attr.name: self.attributes[name]),
+ )
+
+ @classmethod
+ def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
+ for i, operand in enumerate(operands):
+ if cls._ODS_OPERAND_SEGMENTS:
+
+ def getter(self, i=i, operand=operand):
+ operand_range = segmented_accessor(
+ self.operation.operands,
+ self.operation.attributes["operandSegmentSizes"],
+ i,
+ )
+ return normalize_value_range(operand_range, operand.variadicity)
+
+ setattr(cls, operand.name, property(getter))
+ else:
+ setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+ @classmethod
+ def _generate_result_properties(cls, results: List[ResultDef]) -> None:
+ for i, result in enumerate(results):
+ if cls._ODS_RESULT_SEGMENTS:
+
+ def getter(self, i=i, result=result):
+ result_range = segmented_accessor(
+ self.operation.results,
+ self.operation.attributes["resultSegmentSizes"],
+ i,
+ )
+ return normalize_value_range(result_range, result.variadicity)
+
+ setattr(cls, result.name, property(getter))
+ else:
+ setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
----------------
rolfmorel wrote:
Same here: why would I as a user be able to override the `i` index here? I.e. why have `i=i` as an arg?
https://github.com/llvm/llvm-project/pull/169045
More information about the Mlir-commits
mailing list