[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings (PR #169045)

Rolf Morel llvmlistbot at llvm.org
Fri Dec 12 08:46:05 PST 2025


================
@@ -0,0 +1,384 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter, Signature
+from types import SimpleNamespace
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from ...dialects import irdl
+from .._ods_common import _cext, segmented_accessor
+from . import Variadicity
+
+ir = _cext.ir
+
+__all__ = [
+    "Variadicity",
+    "Is",
+    "AnyOf",
+    "AllOf",
+    "Any",
+    "BaseName",
+    "BaseRef",
+    "Operand",
+    "Result",
+    "Attribute",
+    "Dialect",
+]
+
+
+class ConstraintExpr(ABC):
+    @abstractmethod
+    def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
+        pass
+
+    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AnyOf(self, other)
+
+    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+        return AllOf(self, other)
+
+
+class ConstraintLoweringContext:
+    def __init__(self):
+        # Cache so that the same ConstraintExpr instance reuses its SSA value.
+        self._cache: Dict[int, ir.Value] = {}
+
+    def lower(self, expr: ConstraintExpr) -> ir.Value:
+        key = id(expr)
+        if key in self._cache:
+            return self._cache[key]
+        v = expr._lower(self)
+        self._cache[key] = v
+        return v
+
+
+class Is(ConstraintExpr):
+    def __init__(self, val: Union[ir.Attribute, ir.Type]):
+        self.val = val
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.is_(
+            ir.TypeAttr.get(self.val) if isinstance(self.val, ir.Type) else self.val
+        )
+
+
+class AnyOf(ConstraintExpr):
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class AllOf(ConstraintExpr):
+    def __init__(self, *exprs: ConstraintExpr):
+        self.exprs = exprs
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class Any(ConstraintExpr):
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.any()
+
+
+class BaseName(ConstraintExpr):
+    def __init__(self, name: str):
+        self.name = name
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.base(base_name=self.name)
+
+
+class BaseRef(ConstraintExpr):
+    def __init__(self, ref):
+        self.ref = ref
+
+    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+        return irdl.base(base_ref=self.ref)
+
+
+class FieldDef:
+    pass
+
+
+ at dataclass
+class Operand(FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+ at dataclass
+class Result(FieldDef):
+    constraint: ConstraintExpr
+    variadicity: Variadicity = Variadicity.single
+
+
+ at dataclass
+class Attribute(FieldDef):
+    constraint: ConstraintExpr
+
+    def __post_init__(self):
+        # just for unified processing,
+        # currently optional attribute is not supported by IRDL
+        self.variadicity = Variadicity.single
+
+
+def partition_fields(
+    fields: List[FieldDef],
+) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+    operands = [i for i in fields if isinstance(i, Operand)]
+    attrs = [i for i in fields if isinstance(i, Attribute)]
+    results = [i for i in fields if isinstance(i, Result)]
+    return operands, attrs, results
+
+
+def normalize_value_range(
+    value_range: Union[ir.OpOperandList, ir.OpResultList],
+    variadicity: Variadicity,
+):
+    if variadicity == Variadicity.single:
+        return value_range[0]
+    if variadicity == Variadicity.optional:
+        return value_range[0] if len(value_range) > 0 else None
+    return value_range
+
+
+class Operation(ir.OpView):
+    @classmethod
+    def __init_subclass__(cls, *, name: str = None, **kwargs):
+        super().__init_subclass__(**kwargs)
+
+        # for subclasses without "name" parameter,
+        # just treat them as normal classes
+        if not name:
+            return
+
+        op_name = name
+        cls._op_name = op_name
+        dialect_name = cls._dialect_name
+        dialect_obj = cls._dialect_obj
+
+        fields = []
+        cls._fields = fields
+
+        for key, value in cls.__dict__.items():
+            if isinstance(value, FieldDef):
+                setattr(value, "name", key)
+                fields.append(value)
+
+        cls._generate_class_attributes(dialect_name, op_name, fields)
+        cls._generate_init_method(fields)
+        operands, attrs, results = partition_fields(fields)
+        cls._generate_attr_properties(attrs)
+        cls._generate_operand_properties(operands)
+        cls._generate_result_properties(results)
+
+        dialect_obj.operations.append(cls)
----------------
rolfmorel wrote:

I guess the underlying question is how Python modules and dialects (should) relate. That is, dialects are very much modules as well, though currently in the Python bindings the "leading namespaces", e.g. `arith` in `arith.constant`, are Python modules and not dialect objects.

Not sure what this should look like for the DSL.

https://github.com/llvm/llvm-project/pull/169045


More information about the Mlir-commits mailing list