[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings (PR #169045)
Rolf Morel
llvmlistbot at llvm.org
Fri Dec 12 08:46:04 PST 2025
================
@@ -0,0 +1,384 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter, Signature
+from types import SimpleNamespace
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from ...dialects import irdl
+from .._ods_common import _cext, segmented_accessor
+from . import Variadicity
+
# Short alias for the native IR bindings so the rest of the module can say
# `ir.Value`, `ir.Type`, ... instead of reaching through `_cext` every time.
ir = _cext.ir

# Names exported by `from ... import *`, grouped by role.
__all__ = [
    "Variadicity",
    # Constraint expressions.
    "Is",
    "AnyOf",
    "AllOf",
    "Any",
    "BaseName",
    "BaseRef",
    # Operation field descriptors.
    "Operand",
    "Result",
    "Attribute",
    # Dialect definitions.
    "Dialect",
]
+
+
class ConstraintExpr(ABC):
    """Base class for constraint expressions that lower to IRDL constraint ops.

    Subclasses implement `_lower` to emit the IRDL op(s) describing the
    constraint. Expressions compose with `|` (builds an `AnyOf`) and `&`
    (builds an `AllOf`).
    """

    @abstractmethod
    def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
        """Emit the IRDL op(s) for this constraint and return the SSA result.

        Go through `ConstraintLoweringContext.lower` instead of calling this
        directly, so that the context's per-expression cache is honoured.
        """

    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
        """Combine with `other` into an `AnyOf` constraint."""
        # Reject non-constraints here so `expr | 42` raises TypeError right
        # away instead of failing obscurely later, during lowering.
        if not isinstance(other, ConstraintExpr):
            return NotImplemented
        return AnyOf(self, other)

    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
        """Combine with `other` into an `AllOf` constraint."""
        # Same rationale as `__or__`: fail fast on non-constraint operands.
        if not isinstance(other, ConstraintExpr):
            return NotImplemented
        return AllOf(self, other)
+
+
class ConstraintLoweringContext:
    """Lowers `ConstraintExpr` objects to IRDL SSA values, memoized per object.

    Lowering the same `ConstraintExpr` instance twice through one context
    returns the SSA value produced the first time, so a constraint object
    shared by several fields is emitted only once.
    """

    def __init__(self):
        # Cache so that the same ConstraintExpr instance reuses its SSA value.
        # Keyed by id(expr).
        self._cache: Dict[int, ir.Value] = {}
        # Strong references to every expression that has a cache entry.
        # Why: id() is only unique among *live* objects. Without this, a lowered
        # temporary could be garbage-collected, and a new, unrelated expression
        # allocated at the same address would hit its stale cache entry.
        self._keepalive: List[ConstraintExpr] = []

    def lower(self, expr: ConstraintExpr) -> ir.Value:
        """Return the SSA value for `expr`, emitting IRDL ops on first use."""
        key = id(expr)
        if key in self._cache:
            return self._cache[key]
        value = expr._lower(self)
        self._cache[key] = value
        self._keepalive.append(expr)
        return value
+
+
class Is(ConstraintExpr):
    """Constraint matching exactly one given attribute or type.

    A type value is wrapped in an `ir.TypeAttr` before being passed to
    `irdl.is_`; an attribute value is passed through unchanged.
    """

    def __init__(self, val: Union[ir.Attribute, ir.Type]):
        self.val = val

    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
        expected = self.val
        if isinstance(expected, ir.Type):
            # irdl.is_ is fed an attribute, so lift a bare type into TypeAttr.
            expected = ir.TypeAttr.get(expected)
        return irdl.is_(expected)
+
+
class AnyOf(ConstraintExpr):
    """Disjunction of constraints, lowered to `irdl.any_of`."""

    def __init__(self, *exprs: ConstraintExpr):
        self.exprs = exprs

    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
        # Each alternative goes through the shared context so a repeated
        # sub-expression object reuses its already-emitted SSA value.
        alternatives = [ctx.lower(sub) for sub in self.exprs]
        return irdl.any_of(alternatives)
+
+
class AllOf(ConstraintExpr):
    """Conjunction of constraints, lowered to `irdl.all_of`."""

    def __init__(self, *exprs: ConstraintExpr):
        self.exprs = exprs

    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
        # Lower through the shared context so repeated sub-expression objects
        # reuse their already-emitted SSA value.
        conjuncts = [ctx.lower(sub) for sub in self.exprs]
        return irdl.all_of(conjuncts)
+
+
class Any(ConstraintExpr):
    """Unconstrained: lowered to a single `irdl.any` op.

    Note: this name shadows `typing.Any` for anyone doing `from ... import *`.
    """

    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
        # No sub-expressions, so the lowering context is not consulted.
        return irdl.any()
+
+
class BaseName(ConstraintExpr):
    """Base constraint identified by name, lowered to `irdl.base(base_name=...)`.

    NOTE(review): IRDL presumably expects a `!`-prefixed (type) or
    `#`-prefixed (attribute) name here; nothing is validated on the Python
    side — confirm against the IRDL verifier.
    """

    def __init__(self, name: str):
        self.name = name

    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
        base_name = self.name
        return irdl.base(base_name=base_name)
+
+
class BaseRef(ConstraintExpr):
    """Base constraint identified by reference, lowered to `irdl.base(base_ref=...)`.

    `ref` is forwarded unchanged. NOTE(review): it is presumably a symbol
    reference attribute to an IRDL type/attribute definition — confirm with
    callers, since no type is enforced here.
    """

    def __init__(self, ref):
        self.ref = ref

    def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
        base_ref = self.ref
        return irdl.base(base_ref=base_ref)
+
+
class FieldDef:
    """Marker base class for the fields of an operation definition.

    `Operand`, `Result` and `Attribute` derive from it, and `partition_fields`
    sorts a sequence of such fields back into those three groups.
    """
+
+
@dataclass
class Operand(FieldDef):
    """Operand field of an operation definition.

    Attributes:
        constraint: constraint the operand must satisfy.
        variadicity: how many values the operand takes; `Variadicity.single`
            unless specified (see `normalize_value_range` for how single and
            optional operands are surfaced to users).
    """

    constraint: ConstraintExpr
    variadicity: Variadicity = Variadicity.single
+
+
@dataclass
class Result(FieldDef):
    """Result field of an operation definition.

    Attributes:
        constraint: constraint the result must satisfy.
        variadicity: how many values the result produces; `Variadicity.single`
            unless specified (see `normalize_value_range` for how single and
            optional results are surfaced to users).
    """

    constraint: ConstraintExpr
    variadicity: Variadicity = Variadicity.single
+
+
@dataclass
class Attribute(FieldDef):
    """Attribute field of an operation definition, restricted by `constraint`."""

    constraint: ConstraintExpr

    def __post_init__(self):
        # Not a dataclass field (so it stays out of __init__/__repr__/__eq__).
        # It exists only so attributes can be handled uniformly alongside
        # operands and results. IRDL does not currently support optional
        # attributes, so it is always single.
        self.variadicity = Variadicity.single
+
+
def partition_fields(
    fields: List[FieldDef],
) -> Tuple[List[Operand], List[Attribute], List[Result]]:
    """Split operation field definitions into operands, attributes and results.

    Relative order within each group is preserved. Entries that are none of
    the three kinds are ignored. `fields` is traversed exactly once, so an
    iterator or generator is handled as correctly as a list.

    Returns:
        A ``(operands, attributes, results)`` tuple of lists.
    """
    operands: List[Operand] = []
    attrs: List[Attribute] = []
    results: List[Result] = []
    for fdef in fields:
        # Independent checks rather than elif: each group is selected purely
        # by isinstance against its own kind.
        if isinstance(fdef, Operand):
            operands.append(fdef)
        if isinstance(fdef, Attribute):
            attrs.append(fdef)
        if isinstance(fdef, Result):
            results.append(fdef)
    return operands, attrs, results
+
+
def normalize_value_range(
    value_range: Union[ir.OpOperandList, ir.OpResultList],
    variadicity: Variadicity,
):
    """Present an operand/result range in the shape implied by its variadicity.

    - `Variadicity.single`: the first (only) value; indexing fails on an
      empty range.
    - `Variadicity.optional`: the first value, or None when the range is empty.
    - any other variadicity: the range itself, unchanged.
    """
    if variadicity == Variadicity.single:
        return value_range[0]
    if variadicity == Variadicity.optional:
        if len(value_range) == 0:
            return None
        return value_range[0]
    return value_range
+
+
+class Operation(ir.OpView):
+ @classmethod
+ def __init_subclass__(cls, *, name: str = None, **kwargs):
+ super().__init_subclass__(**kwargs)
+
+ # for subclasses without "name" parameter,
+ # just treat them as normal classes
+ if not name:
+ return
----------------
rolfmorel wrote:
Is there any way to detect whether we are dealing with an intermediary class (e.g. a `dsl.Dialect`'s subclass of `dsl.Operation`) or a concrete op definition? I'm worried that I might forget to pass `name = ...` as a parameter to `class MyOp(MyDialect.Operation)` and then spend a while scratching my head over what went wrong.
https://github.com/llvm/llvm-project/pull/169045
More information about the Mlir-commits
mailing list