[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings (PR #169045)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 13 20:35:11 PST 2025
================
@@ -0,0 +1,384 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter, Signature
+from types import SimpleNamespace
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from ...dialects import irdl
+from .._ods_common import _cext, segmented_accessor
+from . import Variadicity
+
+# Short alias for the native extension's `ir` module; used throughout this
+# file for ir.Value, ir.Type, ir.TypeAttr, ir.Attribute and ir.OpView.
+ir = _cext.ir
+
+# Public names exported by `from <this module> import *`. Some entries
+# (e.g. "Dialect") are presumably defined further down in this module and
+# are not visible in this hunk.
+__all__ = [
+ "Variadicity",
+ "Is",
+ "AnyOf",
+ "AllOf",
+ "Any",
+ "BaseName",
+ "BaseRef",
+ "Operand",
+ "Result",
+ "Attribute",
+ "Dialect",
+]
+
+
+class ConstraintExpr(ABC):
+ @abstractmethod
+ def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
+ pass
+
+ def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+ return AnyOf(self, other)
+
+ def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
+ return AllOf(self, other)
+
+
+class ConstraintLoweringContext:
+ def __init__(self):
+ # Cache so that the same ConstraintExpr instance reuses its SSA value.
+ self._cache: Dict[int, ir.Value] = {}
+
+ def lower(self, expr: ConstraintExpr) -> ir.Value:
+ key = id(expr)
+ if key in self._cache:
+ return self._cache[key]
+ v = expr._lower(self)
+ self._cache[key] = v
+ return v
+
+
+class Is(ConstraintExpr):
+ def __init__(self, val: Union[ir.Attribute, ir.Type]):
+ self.val = val
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.is_(
+ ir.TypeAttr.get(self.val) if isinstance(self.val, ir.Type) else self.val
+ )
+
+
+class AnyOf(ConstraintExpr):
+ def __init__(self, *exprs: ConstraintExpr):
+ self.exprs = exprs
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class AllOf(ConstraintExpr):
+ def __init__(self, *exprs: ConstraintExpr):
+ self.exprs = exprs
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+
+
+class Any(ConstraintExpr):
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.any()
+
+
+class BaseName(ConstraintExpr):
+ def __init__(self, name: str):
+ self.name = name
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.base(base_name=self.name)
+
+
+class BaseRef(ConstraintExpr):
+ def __init__(self, ref):
+ self.ref = ref
+
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.base(base_ref=self.ref)
+
+
+class FieldDef:
+ pass
+
+
+ at dataclass
+class Operand(FieldDef):
+ constraint: ConstraintExpr
+ variadicity: Variadicity = Variadicity.single
+
+
+ at dataclass
+class Result(FieldDef):
+ constraint: ConstraintExpr
+ variadicity: Variadicity = Variadicity.single
+
+
+ at dataclass
+class Attribute(FieldDef):
+ constraint: ConstraintExpr
+
+ def __post_init__(self):
+ # just for unified processing,
+ # currently optional attribute is not supported by IRDL
+ self.variadicity = Variadicity.single
+
+
+def partition_fields(
+ fields: List[FieldDef],
+) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+ operands = [i for i in fields if isinstance(i, Operand)]
+ attrs = [i for i in fields if isinstance(i, Attribute)]
+ results = [i for i in fields if isinstance(i, Result)]
+ return operands, attrs, results
+
+
+def normalize_value_range(
+ value_range: Union[ir.OpOperandList, ir.OpResultList],
+ variadicity: Variadicity,
+):
+ if variadicity == Variadicity.single:
+ return value_range[0]
+ if variadicity == Variadicity.optional:
+ return value_range[0] if len(value_range) > 0 else None
+ return value_range
+
+
+class Operation(ir.OpView):
+ @classmethod
+ def __init_subclass__(cls, *, name: str = None, **kwargs):
+ super().__init_subclass__(**kwargs)
+
+ # for subclasses without "name" parameter,
+ # just treat them as normal classes
+ if not name:
+ return
----------------
PragmaTwice wrote:
At the moment, the only thing that distinguishes a concrete op from an op base class is whether the `name` class keyword is given. After weighing the trade-offs, this seems like the better choice: additional distinguishing criteria would mostly introduce redundant syntax and more cases we need to handle.
That said, this is still open; if you have a better idea, feel free to let me know.
https://github.com/llvm/llvm-project/pull/169045
More information about the Mlir-commits
mailing list