[Mlir-commits] [mlir] [MLIR][Python] Add a DSL for defining dialects in Python bindings (PR #169045)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 9 05:41:28 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/169045
>From 3ec4809f2b5a76d14eba1a0707e30791cc5cc805 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 19 Nov 2025 01:04:38 +0800
Subject: [PATCH 01/20] [MLIR][Python] Add a DSL for defining IRDL dialects in
Python bindings
---
mlir/python/CMakeLists.txt | 4 +-
.../dialects/{irdl.py => irdl/__init__.py} | 13 +-
mlir/python/mlir/dialects/irdl/dsl.py | 343 ++++++++++++++++++
mlir/test/python/dialects/irdsl.py | 308 ++++++++++++++++
4 files changed, 661 insertions(+), 7 deletions(-)
rename mlir/python/mlir/dialects/{irdl.py => irdl/__init__.py} (91%)
create mode 100644 mlir/python/mlir/dialects/irdl/dsl.py
create mode 100644 mlir/test/python/dialects/irdsl.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..27b87a8b70144 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IRDLOps.td
- SOURCES dialects/irdl.py
+ SOURCES
+ dialects/irdl/__init__.py
+ dialects/irdl/dsl.py
DIALECT_NAME irdl
GEN_ENUM_BINDINGS
)
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl.py
rename to mlir/python/mlir/dialects/irdl/__init__.py
index 1ec951b69b646..6b2787ed7966c 100644
--- a/mlir/python/mlir/dialects/irdl.py
+++ b/mlir/python/mlir/dialects/irdl/__init__.py
@@ -2,13 +2,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from ._irdl_ops_gen import *
-from ._irdl_ops_gen import _Dialect
-from ._irdl_enum_gen import *
-from .._mlir_libs._mlirDialectsIRDL import *
-from ..ir import register_attribute_builder
-from ._ods_common import _cext as _ods_cext
+from .._irdl_ops_gen import *
+from .._irdl_ops_gen import _Dialect
+from .._irdl_enum_gen import *
+from ..._mlir_libs._mlirDialectsIRDL import *
+from ...ir import register_attribute_builder
+from .._ods_common import _cext as _ods_cext
from typing import Union, Sequence
+from . import dsl
_ods_ir = _ods_cext.ir
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
new file mode 100644
index 0000000000000..3cc234503665a
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -0,0 +1,343 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ...dialects import irdl as _irdl
+from .._ods_common import (
+ _cext as _ods_cext,
+ segmented_accessor as _ods_segmented_accessor,
+)
+from . import Variadicity
+from typing import Dict, List, Union, Callable, Tuple
+from dataclasses import dataclass
+from inspect import Parameter as _Parameter, Signature as _Signature
+from types import SimpleNamespace as _SimpleNameSpace
+
+_ods_ir = _ods_cext.ir
+
+
class ConstraintExpr:
    """Base class of the constraint expressions used by the IRDL DSL.

    Subclasses override ``_lower`` to emit the IRDL operation(s) for the
    constraint. Two expressions combine with ``|`` (building an ``AnyOf``)
    and ``&`` (building an ``AllOf``).
    """

    def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
        """Emit IRDL for this constraint and return its SSA value.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
        # `a | b` builds an AnyOf over both expressions (lowered to irdl.any_of).
        combined = AnyOf(self, other)
        return combined

    def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
        # `a & b` builds an AllOf over both expressions (lowered to irdl.all_of).
        combined = AllOf(self, other)
        return combined
+
+
class ConstraintLoweringContext:
    """Memoizes the lowering of constraint expressions to IRDL SSA values.

    Lowering the same ``ConstraintExpr`` instance more than once through one
    context returns the SSA value produced the first time. Shared
    sub-expressions are therefore emitted only once.
    """

    def __init__(self):
        # Cache so that the same ConstraintExpr instance reuses its SSA value.
        # Maps id(expr) -> (expr, value). The expression is stored alongside
        # its value because `id()` is only unique among objects alive at the
        # same time. Holding a reference keeps a cached expression from being
        # collected and its id recycled by a different expression.
        self._cache: Dict[int, Tuple[ConstraintExpr, _ods_ir.Value]] = {}

    def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
        """Return the SSA value for ``expr``, lowering it on first use.

        Args:
            expr: the constraint expression; it is lowered via ``expr._lower``.

        Returns:
            The SSA value produced for ``expr`` in this context.
        """
        key = id(expr)
        cached = self._cache.get(key)
        if cached is not None:
            return cached[1]
        value = expr._lower(self)
        self._cache[key] = (expr, value)
        return value
+
+
class Is(ConstraintExpr):
    """Constraint on one specific attribute value, lowered to ``irdl.is``."""

    def __init__(self, attr: _ods_ir.Attribute):
        # The attribute handed to irdl.is when lowering.
        self.attr = attr

    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
        # A leaf constraint: there are no sub-expressions to lower via `ctx`.
        expected = self.attr
        return _irdl.is_(expected)
+
+
class IsType(Is):
    """``Is`` constraint for a type; the type is wrapped in a ``TypeAttr``."""

    def __init__(self, typ: _ods_ir.Type):
        type_attr = _ods_ir.TypeAttr.get(typ)
        super().__init__(type_attr)
+
+
class AnyOf(ConstraintExpr):
    """Combination of constraints lowered to ``irdl.any_of``."""

    def __init__(self, *exprs: ConstraintExpr):
        # Kept as the tuple of positional arguments, in the order given.
        self.exprs = exprs

    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
        # Each member is lowered through `ctx`, so repeated members are reused.
        return _irdl.any_of(map(ctx.lower, self.exprs))
+
+
class AllOf(ConstraintExpr):
    """Combination of constraints lowered to ``irdl.all_of``."""

    def __init__(self, *exprs: ConstraintExpr):
        # Kept as the tuple of positional arguments, in the order given.
        self.exprs = exprs

    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
        # Each member is lowered through `ctx`, so repeated members are reused.
        return _irdl.all_of(map(ctx.lower, self.exprs))
+
+
class Any(ConstraintExpr):
    """Unrestricted constraint, lowered to ``irdl.any``."""

    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
        del ctx  # leaf constraint; nothing to lower recursively
        return _irdl.any()
+
+
class BaseName(ConstraintExpr):
    """Base constraint identified by name, lowered to ``irdl.base``.

    ``name`` is passed as ``base_name`` (the tests use e.g.
    ``"#builtin.integer"``).
    """

    def __init__(self, name: str):
        self.name = name

    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
        del ctx  # leaf constraint
        base_name = self.name
        return _irdl.base(base_name=base_name)
+
+
class BaseRef(ConstraintExpr):
    """Base constraint identified by reference, lowered to ``irdl.base``.

    ``ref`` is passed as ``base_ref``. It is presumably a symbol reference to
    an IRDL type/attribute definition (TODO confirm).
    """

    def __init__(self, ref):
        self.ref = ref

    def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
        del ctx  # leaf constraint
        base_ref = self.ref
        return _irdl.base(base_ref=base_ref)
+
+
class FieldDef:
    """Base class of the fields (operands, results, attributes) of an op class.

    Python invokes ``__set_name__`` while creating the owning class. The
    attribute name the field was assigned to is recorded as ``self.name``.
    """

    def __set_name__(self, owner, name: str):
        del owner  # only the attribute name is needed
        self.name = name
+
+
@dataclass
class Operand(FieldDef):
    """Operand field of a DSL operation.

    Attributes:
        constraint: constraint expression the operand is checked against.
        variadicity: ``Variadicity.single`` (default), ``optional`` or
            ``variadic``.
    """

    constraint: ConstraintExpr
    variadicity: Variadicity = Variadicity.single
+
+
@dataclass
class Result(FieldDef):
    """Result field of a DSL operation.

    Attributes:
        constraint: constraint expression the result is checked against.
        variadicity: ``Variadicity.single`` (default), ``optional`` or
            ``variadic``.
    """

    constraint: ConstraintExpr
    variadicity: Variadicity = Variadicity.single
+
+
@dataclass
class Attribute(FieldDef):
    """Attribute field of a DSL operation.

    Attributes:
        constraint: constraint expression the attribute is checked against.
    """

    constraint: ConstraintExpr

    def __post_init__(self):
        # Optional attributes are currently not supported by IRDL, so every
        # attribute is single. The field is still set so attributes can be
        # processed uniformly with operands and results (e.g. when building
        # the constructor signature).
        self.variadicity = Variadicity.single
+
+
@dataclass
class Operation:
    """Definition of a single operation of a DSL-defined dialect.

    Emits the corresponding ``irdl.operation``. Also synthesizes an
    ``OpView`` subclass and a builder function for creating instances from
    Python.
    """

    dialect_name: str
    name: str
    # We store operands and attributes into one list to maintain relative orders
    # among them for generating OpView class.
    operands_and_attrs: List[Union[Operand, Attribute]]
    results: List[Result]

    def _partition_operands_and_attrs(
        self,
    ) -> Tuple[List[Operand], List[Attribute]]:
        """Split ``operands_and_attrs`` into (operands, attributes).

        Each list keeps declaration order.
        """
        operands = [f for f in self.operands_and_attrs if isinstance(f, Operand)]
        attrs = [f for f in self.operands_and_attrs if isinstance(f, Attribute)]
        return operands, attrs

    def _emit(self) -> None:
        """Emit this operation's ``irdl.operation`` at the current insertion point.

        All constraints are lowered through one ``ConstraintLoweringContext``.
        A constraint object shared by several fields therefore becomes a single
        SSA value.
        """
        op = _irdl.operation_(self.name)
        ctx = ConstraintLoweringContext()
        operands, attrs = self._partition_operands_and_attrs()

        with _ods_ir.InsertionPoint(op.body):
            if operands:
                _irdl.operands_(
                    [ctx.lower(i.constraint) for i in operands],
                    [i.name for i in operands],
                    [i.variadicity for i in operands],
                )
            if attrs:
                _irdl.attributes_(
                    [ctx.lower(i.constraint) for i in attrs],
                    [i.name for i in attrs],
                )
            if self.results:
                _irdl.results_(
                    [ctx.lower(i.constraint) for i in self.results],
                    [i.name for i in self.results],
                    [i.variadicity for i in self.results],
                )

    @staticmethod
    def _variadicity_to_segment(variadicity: Variadicity) -> int:
        """Encode a variadicity as a segment spec.

        Variadic maps to -1, optional to 0 and single to 1.
        """
        if variadicity == Variadicity.variadic:
            return -1
        if variadicity == Variadicity.optional:
            return 0
        return 1

    @staticmethod
    def _generate_segments(
        fields: List[Union[Operand, Result]],
    ) -> Union[List[int], None]:
        """Return per-field segment specs, or None if every field is single.

        When None is returned, the generated accessors index the
        operand/result list directly instead of using segment sizes.
        """
        if any(f.variadicity != Variadicity.single for f in fields):
            return [Operation._variadicity_to_segment(f.variadicity) for f in fields]
        return None

    def _generate_init_params(self) -> List[_Parameter]:
        """Build the parameter list of the generated ``OpView.__init__``.

        The layout is:
          * positional-only ``self``;
          * every non-optional field as positional-or-keyword (results first,
            then operands/attributes in declaration order);
          * optional fields, ``loc`` and ``ip`` as keyword-only, defaulting to
            None.

        Raises:
            ValueError: if a field name collides with ``self``, ``loc`` or
                ``ip``. ``inspect.Signature`` would reject the duplicate anyway;
                this reports it with the operation name.
        """
        fields = self.results + self.operands_and_attrs

        clashes = [f.name for f in fields if f.name in ("self", "loc", "ip")]
        if clashes:
            raise ValueError(
                f"operation '{self.dialect_name}.{self.name}' declares field(s) "
                f"{clashes} whose names are reserved for the generated builder"
            )

        required = [f.name for f in fields if f.variadicity != Variadicity.optional]
        optional = [f.name for f in fields if f.variadicity == Variadicity.optional]

        params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
        params += [_Parameter(n, _Parameter.POSITIONAL_OR_KEYWORD) for n in required]
        params += [
            _Parameter(n, _Parameter.KEYWORD_ONLY, default=None)
            for n in optional + ["loc", "ip"]
        ]
        return params

    def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
        """Create the ``OpView`` subclass and a builder function for this op.

        Returns:
            ``(op_view_class, builder)``. ``builder`` takes the same arguments
            as the class constructor, minus ``self``.
        """
        operands, attrs = self._partition_operands_and_attrs()
        operand_segments = Operation._generate_segments(operands)
        result_segments = Operation._generate_segments(self.results)

        params = self._generate_init_params()
        init_sig = _Signature(params)
        op = self

        class _OpView(_ods_ir.OpView):
            OPERATION_NAME = f"{op.dialect_name}.{op.name}"
            _ODS_REGIONS = (0, True)
            _ODS_OPERAND_SEGMENTS = operand_segments
            _ODS_RESULT_SEGMENTS = result_segments

            def __init__(*args, **kwargs):
                # `self` is declared positional-only in `init_sig` and bound
                # here, so user field names never clash with Python's own
                # `self` handling.
                bound = init_sig.bind(*args, **kwargs)
                bound.apply_defaults()
                values = bound.arguments

                self = values["self"]
                _operands = [values[o.name] for o in operands]
                _results = [values[r.name] for r in op.results]
                # Attributes passed as None are omitted.
                _attributes = {
                    a.name: values[a.name]
                    for a in attrs
                    if values[a.name] is not None
                }
                super(_OpView, self).__init__(
                    self.OPERATION_NAME,
                    self._ODS_REGIONS,
                    self._ODS_OPERAND_SEGMENTS,
                    self._ODS_RESULT_SEGMENTS,
                    attributes=_attributes,
                    results=_results,
                    operands=_operands,
                    successors=None,
                    regions=None,
                    loc=values["loc"],
                    ip=values["ip"],
                )

            __init__.__signature__ = init_sig

        # All accessors go through `self.operation`. A field named e.g.
        # `operands`, `results` or `attributes` therefore cannot shadow the
        # container being indexed.
        for attr in attrs:
            setattr(
                _OpView,
                attr.name,
                property(
                    lambda self, name=attr.name: self.operation.attributes[name]
                ),
            )

        def value_range_getter(value_range, variadicity: Variadicity):
            # Single fields yield the value, optional fields the value or None,
            # and variadic fields the whole range.
            if variadicity == Variadicity.single:
                return value_range[0]
            if variadicity == Variadicity.optional:
                return value_range[0] if len(value_range) > 0 else None
            return value_range

        def install_value_properties(fields, segments, range_name, sizes_attr):
            # Define one read-only property per operand/result field.
            for index, field in enumerate(fields):
                if segments:

                    def getter(self, i=index, field=field):
                        value_range = _ods_segmented_accessor(
                            getattr(self.operation, range_name),
                            self.operation.attributes[sizes_attr],
                            i,
                        )
                        return value_range_getter(value_range, field.variadicity)

                else:

                    def getter(self, i=index):
                        return getattr(self.operation, range_name)[i]

                setattr(_OpView, field.name, property(getter))

        install_value_properties(
            operands, operand_segments, "operands", "operandSegmentSizes"
        )
        install_value_properties(
            self.results, result_segments, "results", "resultSegmentSizes"
        )

        def _builder(*args, **kwargs) -> _OpView:
            return _OpView(*args, **kwargs)

        _builder.__signature__ = _Signature(params[1:])

        return _OpView, _builder
+
+
class Dialect:
    """A dialect whose operations are declared in Python via the IRDL DSL.

    Operations are declared with the ``op`` class decorator. ``load`` then
    emits and loads the IRDL definition and registers the generated
    ``OpView`` classes.
    """

    def __init__(self, name: str):
        self.name = name
        # Operation definitions, in declaration order.
        self.operations: List[Operation] = []
        # For each op: its OpView class under the decorated class's name, and
        # its builder under the op name with '.' replaced by '_'.
        self.namespace = _SimpleNameSpace()

    def _emit(self) -> None:
        """Emit an ``irdl.dialect`` holding every declared operation."""
        dialect_op = _irdl.dialect(self.name)
        with _ods_ir.InsertionPoint(dialect_op.body):
            for operation in self.operations:
                operation._emit()

    def _make_module(self) -> _ods_ir.Module:
        """Build a new module containing this dialect's IRDL definition.

        Created with an unknown location in the ambient MLIR context.
        """
        with _ods_ir.Location.unknown():
            module = _ods_ir.Module.create()
            with _ods_ir.InsertionPoint(module.body):
                self._emit()
        return module

    def _make_dialect_class(self) -> type:
        """Create the ``ir.Dialect`` subclass registered for this dialect."""
        dialect_namespace = self.name

        class _Dialect(_ods_ir.Dialect):
            DIALECT_NAMESPACE = dialect_namespace

        return _Dialect

    def load(self) -> _SimpleNameSpace:
        """Load the IRDL definition and register the dialect and its ops.

        Returns:
            ``self.namespace`` with the generated OpView classes and builders.
        """
        _irdl.load_dialects(self._make_module())
        dialect_class = self._make_dialect_class()
        _ods_cext.register_dialect(dialect_class)
        for operation in self.operations:
            _ods_cext.register_operation(dialect_class)(operation.op_view)
        return self.namespace

    def op(self, name: str) -> Callable[[type], type]:
        """Return a class decorator that declares an operation called ``name``.

        The decorated class body lists ``Operand``, ``Attribute`` and
        ``Result`` fields. Their declaration order is kept. The class itself
        is returned unchanged.
        """

        def decorator(cls: type) -> type:
            declared = list(cls.__dict__.values())
            operands_and_attrs: List[Union[Operand, Attribute]] = [
                f for f in declared if isinstance(f, (Operand, Attribute))
            ]
            results: List[Result] = [f for f in declared if isinstance(f, Result)]

            op_def = Operation(self.name, name, operands_and_attrs, results)
            op_view, builder = op_def._make_op_view_and_builder()
            op_def.op_view = op_view
            op_def.builder = builder
            self.operations.append(op_def)

            op_view.__name__ = cls.__name__
            setattr(self.namespace, cls.__name__, op_view)
            setattr(self.namespace, name.replace(".", "_"), builder)
            return cls

        return decorator
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
new file mode 100644
index 0000000000000..8ef30ae0a4c13
--- /dev/null
+++ b/mlir/test/python/dialects/irdsl.py
@@ -0,0 +1,308 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects.irdl import dsl as irdsl
+from mlir.dialects import arith
+import sys
+
+
def run(f):
    """Execute test function ``f`` immediately inside a fresh ``Context``.

    Used as a decorator: it runs at definition time and returns None. A
    ``TEST:`` banner goes to stderr first; the RUN line merges stderr into
    stdout for FileCheck.
    """
    banner = ("\nTEST:", f.__name__)
    print(*banner, file=sys.stderr)
    with Context():
        f()
+
+
+# CHECK: TEST: testMyInt
@run
def testMyInt():
    # Declare a small dialect with an i32 constant op and an i32 add op.
    myint = irdsl.Dialect("myint")
    iattr = irdsl.BaseName("#builtin.integer")
    i32 = irdsl.IsType(IntegerType.get_signless(32))

    @myint.op("constant")
    class ConstantOp:
        value = irdsl.Attribute(iattr)
        cst = irdsl.Result(i32)

    @myint.op("add")
    class AddOp:
        lhs = irdsl.Operand(i32)
        rhs = irdsl.Operand(i32)
        res = irdsl.Result(i32)

    # The shared `i32` constraint object is lowered once per operation.
    # CHECK: irdl.dialect @myint {
    # CHECK: irdl.operation @constant {
    # CHECK: %0 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"value" = %0}
    # CHECK: %1 = irdl.is i32
    # CHECK: irdl.results(cst: %1)
    # CHECK: }
    # CHECK: irdl.operation @add {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(lhs: %0, rhs: %0)
    # CHECK: irdl.results(res: %0)
    # CHECK: }
    # CHECK: }
    print(myint._make_module())
    # From here on `myint` is the namespace of generated classes and builders.
    myint = myint.load()

    # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
    print([i for i in myint.__dict__.keys()])

    # Rebind `i32` from the DSL constraint to the concrete MLIR type.
    i32 = IntegerType.get_signless(32)
    with Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            two = myint.constant(i32, IntegerAttr.get(i32, 2))
            three = myint.constant(i32, IntegerAttr.get(i32, 3))
            add1 = myint.add(i32, two, three)
            add2 = myint.add(i32, add1, two)
            add3 = myint.add(i32, add2, three)

    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
    # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
    # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
    # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
    print(module)
    assert module.operation.verify()

    # Generated OpView classes: names, ODS metadata, accessors, signatures.
    # CHECK: AddOp
    print(type(add1).__name__)
    # CHECK: ConstantOp
    print(type(two).__name__)
    # CHECK: myint.add
    print(add1.OPERATION_NAME)
    # CHECK: None
    print(add1._ODS_OPERAND_SEGMENTS)
    # CHECK: None
    print(add1._ODS_RESULT_SEGMENTS)
    # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
    print(add1.lhs.owner)
    # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
    print(add1.rhs.owner)
    # CHECK: 2 : i32
    print(two.value)
    # CHECK: Value(%0
    print(two.cst)
    # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
    print(myint.add.__signature__)
    # CHECK: (cst, value, *, loc=None, ip=None)
    print(myint.constant.__signature__)
+
+
@run
def testIRDSL():
    # Exercise constraint combinators, optional/variadic fields and mixed
    # operand/attribute ordering.
    test = irdsl.Dialect("irdsl_test")
    i32 = irdsl.IsType(IntegerType.get_signless(32))
    i64 = irdsl.IsType(IntegerType.get_signless(64))
    i32or64 = i32 | i64
    # Named so as not to shadow the builtin `any`.
    any_constraint = irdsl.Any()
    f32 = irdsl.IsType(F32Type.get())
    iattr = irdsl.BaseName("#builtin.integer")
    fattr = irdsl.BaseName("#builtin.float")

    @test.op("constraint")
    class ConstraintOp:
        a = irdsl.Operand(i32or64)
        b = irdsl.Operand(any_constraint)
        c = irdsl.Operand(f32 | i32)
        d = irdsl.Operand(any_constraint)
        x = irdsl.Attribute(iattr)
        y = irdsl.Attribute(fattr)

    @test.op("optional")
    class OptionalOp:
        a = irdsl.Operand(i32)
        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
        out1 = irdsl.Result(i32)
        out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
        out3 = irdsl.Result(i32)

    @test.op("optional2")
    class Optional2Op:
        a = irdsl.Operand(i32, irdsl.Variadicity.optional)
        b = irdsl.Result(i32, irdsl.Variadicity.optional)

    @test.op("variadic")
    class VariadicOp:
        a = irdsl.Operand(i32)
        b = irdsl.Operand(i32, irdsl.Variadicity.optional)
        c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
        out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
        out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
        out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
        out4 = irdsl.Result(i32)

    @test.op("variadic2")
    class Variadic2Op:
        a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
        b = irdsl.Result(i32, irdsl.Variadicity.variadic)

    @test.op("mixed")
    class MixedOp:
        out = irdsl.Result(i32)
        in1 = irdsl.Operand(i32)
        in2 = irdsl.Attribute(iattr)
        in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
        in4 = irdsl.Attribute(iattr)
        in5 = irdsl.Operand(i32)

    # CHECK: irdl.dialect @irdsl_test {
    # CHECK: irdl.operation @constraint {
    # CHECK: %0 = irdl.is i32
    # CHECK: %1 = irdl.is i64
    # CHECK: %2 = irdl.any_of(%0, %1)
    # CHECK: %3 = irdl.any
    # CHECK: %4 = irdl.is f32
    # CHECK: %5 = irdl.any_of(%4, %0)
    # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %3)
    # CHECK: %6 = irdl.base "#builtin.integer"
    # CHECK: %7 = irdl.base "#builtin.float"
    # CHECK: irdl.attributes {"x" = %6, "y" = %7}
    # CHECK: }
    # CHECK: irdl.operation @optional {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0)
    # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
    # CHECK: }
    # CHECK: irdl.operation @optional2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: optional %0)
    # CHECK: irdl.results(b: optional %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
    # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: variadic %0)
    # CHECK: irdl.results(b: variadic %0)
    # CHECK: }
    # CHECK: irdl.operation @mixed {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
    # CHECK: %1 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"in2" = %1, "in4" = %1}
    # CHECK: irdl.results(out: %0)
    # CHECK: }
    # CHECK: }
    print(test._make_module())
    # From here on `test` is the namespace of generated classes and builders.
    test = test.load()

    # Builder signatures: results first, optional fields keyword-only.
    # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None)
    print(test.constraint.__signature__)
    # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
    print(test.optional.__signature__)
    # CHECK: (*, b=None, a=None, loc=None, ip=None)
    print(test.optional2.__signature__)
    # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
    print(test.variadic.__signature__)
    # CHECK: (b, a, *, loc=None, ip=None)
    print(test.variadic2.__signature__)
    # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
    print(test.mixed.__signature__)

    # Segment specs: -1 variadic, 0 optional, 1 single; None if all single.
    # CHECK: None None
    print(
        test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS
    )
    # CHECK: [1, 0] [1, 0, 1]
    print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS)
    # CHECK: [0] [0]
    print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS)
    # CHECK: [1, 0, -1] [-1, -1, 0, 1]
    print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS)
    # CHECK: [-1] [-1]
    print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS)

    # Rebind the names from DSL constraints to concrete MLIR types.
    i32 = IntegerType.get_signless(32)
    i64 = IntegerType.get_signless(64)
    f32 = F32Type.get()

    with Location.unknown():
        iattr = IntegerAttr.get(i32, 2)
        fattr = FloatAttr.get_f32(2.3)

        module = Module.create()
        with InsertionPoint(module.body):
            ione = arith.constant(i32, 1)
            fone = arith.constant(f32, 1.2)

            # CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
            c1 = test.constraint(ione, ione, fone, ione, iattr, fattr)
            # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
            test.constraint(ione, fone, fone, fone, iattr, fattr)
            # CHECK: irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
            test.constraint(ione, fone, ione, fone, iattr, fattr)

            # CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
            o1 = test.optional(i32, i32, ione)
            # CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
            o2 = test.optional(i32, i32, ione, out2=i32, b=ione)
            # CHECK: irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            o3 = test.optional2()
            # CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
            o4 = test.optional2(b=i32)
            # CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
            o5 = test.optional2(a=ione)
            # CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
            o6 = test.optional2(b=i32, a=ione)

            # CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
            v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione])
            # CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
            v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
            # CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
            v3 = test.variadic([i32, i32], [i32], i32, ione, [])
            # CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            v4 = test.variadic2([], [])
            # CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
            v5 = test.variadic2([], [ione, ione, ione])
            # CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
            v6 = test.variadic2([i32, i32], [ione])

            # CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
            m1 = test.mixed(i32, ione, iattr, iattr, ione)
            # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
            m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione)

    print(module)
    assert module.operation.verify()

    # Generated accessors: single -> value, optional -> value or None,
    # variadic -> range.
    # CHECK: Value(%c1_i32
    print(c1.a)
    # CHECK: 2 : i32
    print(c1.x)
    # CHECK: Value(%c1_i32
    print(o1.a)
    # CHECK: None
    print(o1.b)
    # CHECK: Value(%c1_i32
    print(o2.b)
    # CHECK: 0
    print(o1.out1.result_number)
    # CHECK: None
    print(o1.out2)
    # CHECK: 0
    print(o2.out1.result_number)
    # CHECK: 1
    print(o2.out2.result_number)
    # CHECK: None
    print(o3.a)
    # CHECK: Value(%c1_i32
    print(o5.a)
    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
    print([str(i) for i in v1.c])
    # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
    print([str(i) for i in v2.c])
    # CHECK: []
    print([str(i) for i in v3.c])
    # CHECK: 0 0
    print(len(v4.a), len(v4.b))
    # CHECK: 3 0
    print(len(v5.a), len(v5.b))
    # CHECK: 1 2
    print(len(v6.a), len(v6.b))
>From c0b8ba70e97dd09179ba75c6ae5dde04d3be152e Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 12:55:21 +0800
Subject: [PATCH 02/20] make FieldDef private
---
mlir/python/mlir/dialects/irdl/dsl.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 3cc234503665a..b1916fdbd5f9b 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -91,25 +91,25 @@ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
return _irdl.base(base_ref=self.ref)
-class FieldDef:
+class _FieldDef:
def __set_name__(self, owner, name: str):
self.name = name
@dataclass
-class Operand(FieldDef):
+class Operand(_FieldDef):
constraint: ConstraintExpr
variadicity: Variadicity = Variadicity.single
@dataclass
-class Result(FieldDef):
+class Result(_FieldDef):
constraint: ConstraintExpr
variadicity: Variadicity = Variadicity.single
@dataclass
-class Attribute(FieldDef):
+class Attribute(_FieldDef):
constraint: ConstraintExpr
def __post_init__(self):
>From 7de9ad40b6c293efc2bbb8f51140b82ad17522b2 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 13:18:07 +0800
Subject: [PATCH 03/20] refactor some methods
---
mlir/python/mlir/dialects/irdl/dsl.py | 63 ++++++++++++++++-----------
1 file changed, 37 insertions(+), 26 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index b1916fdbd5f9b..307609c1b342e 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -127,13 +127,16 @@ class Operation:
operands_and_attrs: List[Union[Operand, Attribute]]
results: List[Result]
- def _emit(self) -> None:
- op = _irdl.operation_(self.name)
- ctx = ConstraintLoweringContext()
-
+ def _partition_operands_and_attrs(self) -> Tuple[List[Operand], List[Attribute]]:
operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
+ return operands, attrs
+
+ def _emit(self) -> None:
+ ctx = ConstraintLoweringContext()
+ operands, attrs = self._partition_operands_and_attrs()
+ op = _irdl.operation_(self.name)
with _ods_ir.InsertionPoint(op.body):
if operands:
_irdl.operands_(
@@ -153,27 +156,26 @@ def _emit(self) -> None:
[i.variadicity for i in self.results],
)
- def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
- operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
- attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
-
- def variadicity_to_segment(variadicity: Variadicity) -> int:
- if variadicity == Variadicity.variadic:
- return -1
- if variadicity == Variadicity.optional:
- return 0
- return 1
-
- operand_segments = None
- if any(i.variadicity != Variadicity.single for i in operands):
- operand_segments = [variadicity_to_segment(i.variadicity) for i in operands]
-
- result_segments = None
- if any(i.variadicity != Variadicity.single for i in self.results):
- result_segments = [
- variadicity_to_segment(i.variadicity) for i in self.results
+ @staticmethod
+ def _variadicity_to_segment(variadicity: Variadicity) -> int:
+ if variadicity == Variadicity.variadic:
+ return -1
+ if variadicity == Variadicity.optional:
+ return 0
+ return 1
+
+ @staticmethod
+ def _generate_segments(
+ operands_or_results: List[Union[Operand, Result]],
+ ) -> List[int]:
+ if any(i.variadicity != Variadicity.single for i in operands_or_results):
+ return [
+ Operation._variadicity_to_segment(i.variadicity)
+ for i in operands_or_results
]
+ return None
+ def _generate_init_params(self) -> List[_Parameter]:
args = self.results + self.operands_and_attrs
positional_args = [
i.name for i in args if i.variadicity != Variadicity.optional
@@ -188,7 +190,16 @@ def variadicity_to_segment(variadicity: Variadicity) -> int:
params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
- sig = _Signature(params)
+ return params
+
+ def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
+ operands, attrs = self._partition_operands_and_attrs()
+
+ operand_segments = Operation._generate_segments(operands)
+ result_segments = Operation._generate_segments(self.results)
+
+ params = self._generate_init_params()
+ init_sig = _Signature(params)
op = self
class _OpView(_ods_ir.OpView):
@@ -198,7 +209,7 @@ class _OpView(_ods_ir.OpView):
_ODS_RESULT_SEGMENTS = result_segments
def __init__(*args, **kwargs):
- bound = sig.bind(*args, **kwargs)
+ bound = init_sig.bind(*args, **kwargs)
bound.apply_defaults()
args = bound.arguments
@@ -226,7 +237,7 @@ def __init__(*args, **kwargs):
ip=args["ip"],
)
- __init__.__signature__ = sig
+ __init__.__signature__ = init_sig
for attr in attrs:
setattr(
>From 985cd25e1486ea4e8a7a67f76daf835172c58112 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:05:26 +0800
Subject: [PATCH 04/20] more refactor
---
mlir/python/mlir/dialects/irdl/dsl.py | 47 ++++++++++++++-------------
1 file changed, 24 insertions(+), 23 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 307609c1b342e..0a30678e1318c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -124,17 +124,17 @@ class Operation:
name: str
# We store operands and attributes into one list to maintain relative orders
# among them for generating OpView class.
- operands_and_attrs: List[Union[Operand, Attribute]]
- results: List[Result]
+ fields: List[Union[Operand, Attribute, Result]]
- def _partition_operands_and_attrs(self) -> Tuple[List[Operand], List[Attribute]]:
- operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
- attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
- return operands, attrs
+ def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+ operands = [i for i in self.fields if isinstance(i, Operand)]
+ attrs = [i for i in self.fields if isinstance(i, Attribute)]
+ results = [i for i in self.fields if isinstance(i, Result)]
+ return operands, attrs, results
def _emit(self) -> None:
ctx = ConstraintLoweringContext()
- operands, attrs = self._partition_operands_and_attrs()
+ operands, attrs, results = self._partition_fields()
op = _irdl.operation_(self.name)
with _ods_ir.InsertionPoint(op.body):
@@ -149,11 +149,11 @@ def _emit(self) -> None:
[ctx.lower(i.constraint) for i in attrs],
[i.name for i in attrs],
)
- if self.results:
+ if results:
_irdl.results_(
- [ctx.lower(i.constraint) for i in self.results],
- [i.name for i in self.results],
- [i.variadicity for i in self.results],
+ [ctx.lower(i.constraint) for i in results],
+ [i.name for i in results],
+ [i.variadicity for i in results],
)
@staticmethod
@@ -176,7 +176,11 @@ def _generate_segments(
return None
def _generate_init_params(self) -> List[_Parameter]:
- args = self.results + self.operands_and_attrs
+ # results are placed at the beginning of the parameter list,
+ # but operands and attributes can appear in any relative order.
+ args = [i for i in self.fields if isinstance(i, Result)] + [
+ i for i in self.fields if not isinstance(i, Result)
+ ]
positional_args = [
i.name for i in args if i.variadicity != Variadicity.optional
]
@@ -193,10 +197,10 @@ def _generate_init_params(self) -> List[_Parameter]:
return params
def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
- operands, attrs = self._partition_operands_and_attrs()
+ operands, attrs, results = self._partition_fields()
operand_segments = Operation._generate_segments(operands)
- result_segments = Operation._generate_segments(self.results)
+ result_segments = Operation._generate_segments(results)
params = self._generate_init_params()
init_sig = _Signature(params)
@@ -214,7 +218,7 @@ def __init__(*args, **kwargs):
args = bound.arguments
_operands = [args[operand.name] for operand in operands]
- _results = [args[result.name] for result in op.results]
+ _results = [args[result.name] for result in results]
_attributes = dict(
(attr.name, args[attr.name])
for attr in attrs
@@ -272,7 +276,7 @@ def getter(self, i=i, operand=operand):
setattr(
_OpView, operand.name, property(lambda self, i=i: self.operands[i])
)
- for i, result in enumerate(self.results):
+ for i, result in enumerate(results):
if result_segments:
def getter(self, i=i, result=result):
@@ -332,16 +336,13 @@ def load(self) -> _SimpleNameSpace:
def op(self, name: str) -> Callable[[type], type]:
def decorator(cls: type) -> type:
- operands_and_attrs: List[Union[Operand, Attribute]] = []
- results: List[Result] = []
+ fields: List[Union[Operand, Attribute, Result]] = []
for field in cls.__dict__.values():
- if isinstance(field, Operand) or isinstance(field, Attribute):
- operands_and_attrs.append(field)
- elif isinstance(field, Result):
- results.append(field)
+ if isinstance(field, _FieldDef):
+ fields.append(field)
- op_def = Operation(self.name, name, operands_and_attrs, results)
+ op_def = Operation(self.name, name, fields)
op_view, builder = op_def._make_op_view_and_builder()
setattr(op_def, "op_view", op_view)
setattr(op_def, "builder", builder)
>From 2405fe4bd522c21ca7e71bd35645d49b938da570 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:10:48 +0800
Subject: [PATCH 05/20] more refactor
---
mlir/python/mlir/dialects/irdl/dsl.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 0a30678e1318c..d98d82d1ba4c9 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -336,20 +336,20 @@ def load(self) -> _SimpleNameSpace:
def op(self, name: str) -> Callable[[type], type]:
def decorator(cls: type) -> type:
- fields: List[Union[Operand, Attribute, Result]] = []
-
- for field in cls.__dict__.values():
- if isinstance(field, _FieldDef):
- fields.append(field)
-
+ fields = [
+ field for field in cls.__dict__.values() if isinstance(field, _FieldDef)
+ ]
op_def = Operation(self.name, name, fields)
+
op_view, builder = op_def._make_op_view_and_builder()
+ op_view.__name__ = cls.__name__
setattr(op_def, "op_view", op_view)
setattr(op_def, "builder", builder)
self.operations.append(op_def)
+
self.namespace.__dict__[cls.__name__] = op_view
- op_view.__name__ = cls.__name__
self.namespace.__dict__[name.replace(".", "_")] = builder
+
return cls
return decorator
>From d74e6603052698a35b3efc84c303d91d28d9a9df Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:16:18 +0800
Subject: [PATCH 06/20] add a method to convert op name
---
mlir/python/mlir/dialects/irdl/dsl.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index d98d82d1ba4c9..2092c2417145a 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -334,6 +334,10 @@ def load(self) -> _SimpleNameSpace:
_ods_cext.register_operation(dialect_class)(op.op_view)
return self.namespace
+ @staticmethod
+ def op_name_to_python_id(name):
+ return name.replace(".", "_").replace("$", "_")
+
def op(self, name: str) -> Callable[[type], type]:
def decorator(cls: type) -> type:
fields = [
@@ -348,7 +352,7 @@ def decorator(cls: type) -> type:
self.operations.append(op_def)
self.namespace.__dict__[cls.__name__] = op_view
- self.namespace.__dict__[name.replace(".", "_")] = builder
+ self.namespace.__dict__[Dialect.op_name_to_python_id(name)] = builder
return cls
>From bff9fb79903788ad4d2115def90ac2b1771fa81a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:16:58 +0800
Subject: [PATCH 07/20] make it private
---
mlir/python/mlir/dialects/irdl/dsl.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 2092c2417145a..0db7cbbed81ca 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -335,7 +335,7 @@ def load(self) -> _SimpleNameSpace:
return self.namespace
@staticmethod
- def op_name_to_python_id(name):
+ def _op_name_to_python_id(name):
return name.replace(".", "_").replace("$", "_")
def op(self, name: str) -> Callable[[type], type]:
@@ -352,7 +352,7 @@ def decorator(cls: type) -> type:
self.operations.append(op_def)
self.namespace.__dict__[cls.__name__] = op_view
- self.namespace.__dict__[Dialect.op_name_to_python_id(name)] = builder
+ self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder
return cls
>From 4c4c81cf9f4b9f99d52053b4267ce1db8689d083 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 24 Nov 2025 14:17:25 +0800
Subject: [PATCH 08/20] add type annotation
---
mlir/python/mlir/dialects/irdl/dsl.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 0db7cbbed81ca..41520f0e0f68c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -335,7 +335,7 @@ def load(self) -> _SimpleNameSpace:
return self.namespace
@staticmethod
- def _op_name_to_python_id(name):
+ def _op_name_to_python_id(name: str) -> str:
return name.replace(".", "_").replace("$", "_")
def op(self, name: str) -> Callable[[type], type]:
>From 145ca87c2dce074b5eb9a76457646e57c51e9cb7 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 9 Dec 2025 23:33:38 +0800
Subject: [PATCH 09/20] Refactor IRDSL via __init_subclass__
---
mlir/python/mlir/dialects/irdl/dsl.py | 431 ++++++++++++++------------
mlir/test/python/dialects/irdsl.py | 136 ++++----
2 files changed, 291 insertions(+), 276 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 41520f0e0f68c..630740c0103ab 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -2,23 +2,37 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from ...dialects import irdl as _irdl
-from .._ods_common import (
- _cext as _ods_cext,
- segmented_accessor as _ods_segmented_accessor,
-)
-from . import Variadicity
from typing import Dict, List, Union, Callable, Tuple
from dataclasses import dataclass
-from inspect import Parameter as _Parameter, Signature as _Signature
-from types import SimpleNamespace as _SimpleNameSpace
+from inspect import Parameter, Signature
+from types import SimpleNamespace
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from ...dialects import irdl
+from .._ods_common import _cext, segmented_accessor
+from . import Variadicity
+
+ir = _cext.ir
-_ods_ir = _ods_cext.ir
+__all__ = [
+ "Variadicity",
+ "Is",
+ "AnyOf",
+ "AllOf",
+ "Any",
+ "BaseName",
+ "BaseRef",
+ "Operand",
+ "Result",
+ "Attribute",
+ "Dialect",
+]
-class ConstraintExpr:
- def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
- raise NotImplementedError()
+class ConstraintExpr(ABC):
+ @abstractmethod
+ def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
+ pass
def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
return AnyOf(self, other)
@@ -30,9 +44,9 @@ def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
class ConstraintLoweringContext:
def __init__(self):
# Cache so that the same ConstraintExpr instance reuses its SSA value.
- self._cache: Dict[int, _ods_ir.Value] = {}
+ self._cache: Dict[int, ir.Value] = {}
- def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
+ def lower(self, expr: ConstraintExpr) -> ir.Value:
key = id(expr)
if key in self._cache:
return self._cache[key]
@@ -42,74 +56,70 @@ def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
class Is(ConstraintExpr):
- def __init__(self, attr: _ods_ir.Attribute):
- self.attr = attr
-
- def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
- return _irdl.is_(self.attr)
-
+ def __init__(self, val: Union[ir.Attribute, ir.Type]):
+ self.val = val
-class IsType(Is):
- def __init__(self, typ: _ods_ir.Type):
- super().__init__(_ods_ir.TypeAttr.get(typ))
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.is_(
+ ir.TypeAttr.get(self.val) if isinstance(self.val, ir.Type) else self.val
+ )
class AnyOf(ConstraintExpr):
def __init__(self, *exprs: ConstraintExpr):
self.exprs = exprs
- def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
- return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.any_of(ctx.lower(expr) for expr in self.exprs)
class AllOf(ConstraintExpr):
def __init__(self, *exprs: ConstraintExpr):
self.exprs = exprs
- def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
- return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.all_of(ctx.lower(expr) for expr in self.exprs)
class Any(ConstraintExpr):
- def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
- return _irdl.any()
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.any()
class BaseName(ConstraintExpr):
def __init__(self, name: str):
self.name = name
- def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
- return _irdl.base(base_name=self.name)
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.base(base_name=self.name)
class BaseRef(ConstraintExpr):
def __init__(self, ref):
self.ref = ref
- def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
- return _irdl.base(base_ref=self.ref)
+ def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
+ return irdl.base(base_ref=self.ref)
-class _FieldDef:
- def __set_name__(self, owner, name: str):
- self.name = name
+class FieldDef:
+ pass
@dataclass
-class Operand(_FieldDef):
+class Operand(FieldDef):
constraint: ConstraintExpr
variadicity: Variadicity = Variadicity.single
@dataclass
-class Result(_FieldDef):
+class Result(FieldDef):
constraint: ConstraintExpr
variadicity: Variadicity = Variadicity.single
@dataclass
-class Attribute(_FieldDef):
+class Attribute(FieldDef):
constraint: ConstraintExpr
def __post_init__(self):
@@ -118,43 +128,57 @@ def __post_init__(self):
self.variadicity = Variadicity.single
- at dataclass
-class Operation:
- dialect_name: str
- name: str
- # We store operands and attributes into one list to maintain relative orders
- # among them for generating OpView class.
- fields: List[Union[Operand, Attribute, Result]]
-
- def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]:
- operands = [i for i in self.fields if isinstance(i, Operand)]
- attrs = [i for i in self.fields if isinstance(i, Attribute)]
- results = [i for i in self.fields if isinstance(i, Result)]
- return operands, attrs, results
-
- def _emit(self) -> None:
- ctx = ConstraintLoweringContext()
- operands, attrs, results = self._partition_fields()
+def partition_fields(
+ fields: List[FieldDef],
+) -> Tuple[List[Operand], List[Attribute], List[Result]]:
+ operands = [i for i in fields if isinstance(i, Operand)]
+ attrs = [i for i in fields if isinstance(i, Attribute)]
+ results = [i for i in fields if isinstance(i, Result)]
+ return operands, attrs, results
- op = _irdl.operation_(self.name)
- with _ods_ir.InsertionPoint(op.body):
- if operands:
- _irdl.operands_(
- [ctx.lower(i.constraint) for i in operands],
- [i.name for i in operands],
- [i.variadicity for i in operands],
- )
- if attrs:
- _irdl.attributes_(
- [ctx.lower(i.constraint) for i in attrs],
- [i.name for i in attrs],
- )
- if results:
- _irdl.results_(
- [ctx.lower(i.constraint) for i in results],
- [i.name for i in results],
- [i.variadicity for i in results],
- )
+
+def normalize_value_range(
+ value_range: Union[ir.OpOperandList, ir.OpResultList],
+ variadicity: Variadicity,
+):
+ if variadicity == Variadicity.single:
+ return value_range[0]
+ if variadicity == Variadicity.optional:
+ return value_range[0] if len(value_range) > 0 else None
+ return value_range
+
+
+class Operation(ir.OpView):
+ @classmethod
+ def __init_subclass__(cls, *, name=None, **kwargs):
+ super().__init_subclass__(**kwargs)
+
+ # for subclasses without "name" parameter,
+ # just treat them as normal classes
+ if not name:
+ return
+
+ op_name = name
+ cls._op_name = op_name
+ dialect_name = cls._dialect_name
+ dialect_obj = cls._dialect_obj
+
+ fields = []
+ cls._fields = fields
+
+ for key, value in cls.__dict__.items():
+ if isinstance(value, FieldDef):
+ setattr(value, "name", key)
+ fields.append(value)
+
+ cls._generate_class_attributes(dialect_name, op_name, fields)
+ cls._generate_init_method(fields)
+ operands, attrs, results = partition_fields(fields)
+ cls._generate_attr_properties(attrs)
+ cls._generate_operand_properties(operands)
+ cls._generate_result_properties(results)
+
+ dialect_obj.operations.append(cls)
@staticmethod
def _variadicity_to_segment(variadicity: Variadicity) -> int:
@@ -175,185 +199,184 @@ def _generate_segments(
]
return None
- def _generate_init_params(self) -> List[_Parameter]:
+ @staticmethod
+ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
# results are placed at the beginning of the parameter list,
# but operands and attributes can appear in any relative order.
- args = [i for i in self.fields if isinstance(i, Result)] + [
- i for i in self.fields if not isinstance(i, Result)
+ args = [i for i in fields if isinstance(i, Result)] + [
+ i for i in fields if not isinstance(i, Result)
]
positional_args = [
i.name for i in args if i.variadicity != Variadicity.optional
]
optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
- params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
+ params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
for i in positional_args:
- params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
+ params.append(Parameter(i, Parameter.POSITIONAL_OR_KEYWORD))
for i in optional_args:
- params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
- params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
- params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
-
- return params
-
- def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
- operands, attrs, results = self._partition_fields()
-
- operand_segments = Operation._generate_segments(operands)
- result_segments = Operation._generate_segments(results)
-
- params = self._generate_init_params()
- init_sig = _Signature(params)
- op = self
-
- class _OpView(_ods_ir.OpView):
- OPERATION_NAME = f"{op.dialect_name}.{op.name}"
- _ODS_REGIONS = (0, True)
- _ODS_OPERAND_SEGMENTS = operand_segments
- _ODS_RESULT_SEGMENTS = result_segments
-
- def __init__(*args, **kwargs):
- bound = init_sig.bind(*args, **kwargs)
- bound.apply_defaults()
- args = bound.arguments
-
- _operands = [args[operand.name] for operand in operands]
- _results = [args[result.name] for result in results]
- _attributes = dict(
- (attr.name, args[attr.name])
- for attr in attrs
- if args[attr.name] is not None
- )
- _regions = None
- _ods_successors = None
- self = args["self"]
- super(_OpView, self).__init__(
- self.OPERATION_NAME,
- self._ODS_REGIONS,
- self._ODS_OPERAND_SEGMENTS,
- self._ODS_RESULT_SEGMENTS,
- attributes=_attributes,
- results=_results,
- operands=_operands,
- successors=_ods_successors,
- regions=_regions,
- loc=args["loc"],
- ip=args["ip"],
- )
+ params.append(Parameter(i, Parameter.KEYWORD_ONLY, default=None))
+ params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
+ params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
+
+ return Signature(params)
+
+ @classmethod
+ def _generate_init_method(cls, fields):
+ init_sig = cls._generate_init_signature(fields)
+ operands, attrs, results = partition_fields(fields)
+
+ def __init__(*args, **kwargs):
+ bound = init_sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+ args = bound.arguments
+
+ _operands = [args[operand.name] for operand in operands]
+ _results = [args[result.name] for result in results]
+ _attributes = dict(
+ (attr.name, args[attr.name])
+ for attr in attrs
+ if args[attr.name] is not None
+ )
+ _regions = None
+ _ods_successors = None
+ self = args["self"]
+ super(Operation, self).__init__(
+ self.OPERATION_NAME,
+ self._ODS_REGIONS,
+ self._ODS_OPERAND_SEGMENTS,
+ self._ODS_RESULT_SEGMENTS,
+ attributes=_attributes,
+ results=_results,
+ operands=_operands,
+ successors=_ods_successors,
+ regions=_regions,
+ loc=args["loc"],
+ ip=args["ip"],
+ )
+
+ __init__.__signature__ = init_sig
+ cls.__init__ = __init__
+
+ @classmethod
+ def _generate_class_attributes(cls, dialect_name, op_name, fields):
+ operands, attrs, results = partition_fields(fields)
+
+ operand_segments = cls._generate_segments(operands)
+ result_segments = cls._generate_segments(results)
- __init__.__signature__ = init_sig
+ cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
+ cls._ODS_REGIONS = (0, True)
+ cls._ODS_OPERAND_SEGMENTS = operand_segments
+ cls._ODS_RESULT_SEGMENTS = result_segments
+ @classmethod
+ def _generate_attr_properties(cls, attrs):
for attr in attrs:
setattr(
- _OpView,
+ cls,
attr.name,
property(lambda self, name=attr.name: self.attributes[name]),
)
- def value_range_getter(
- value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
- variadicity: Variadicity,
- ):
- if variadicity == Variadicity.single:
- return value_range[0]
- if variadicity == Variadicity.optional:
- return value_range[0] if len(value_range) > 0 else None
- return value_range
-
+ @classmethod
+ def _generate_operand_properties(cls, operands):
for i, operand in enumerate(operands):
- if operand_segments:
+ if cls._ODS_OPERAND_SEGMENTS:
def getter(self, i=i, operand=operand):
- operand_range = _ods_segmented_accessor(
+ operand_range = segmented_accessor(
self.operation.operands,
self.operation.attributes["operandSegmentSizes"],
i,
)
- return value_range_getter(operand_range, operand.variadicity)
+ return normalize_value_range(operand_range, operand.variadicity)
- setattr(_OpView, operand.name, property(getter))
+ setattr(cls, operand.name, property(getter))
else:
- setattr(
- _OpView, operand.name, property(lambda self, i=i: self.operands[i])
- )
+ setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+ @classmethod
+ def _generate_result_properties(cls, results):
for i, result in enumerate(results):
- if result_segments:
+ if cls._ODS_RESULT_SEGMENTS:
def getter(self, i=i, result=result):
- result_range = _ods_segmented_accessor(
+ result_range = segmented_accessor(
self.operation.results,
self.operation.attributes["resultSegmentSizes"],
i,
)
- return value_range_getter(result_range, result.variadicity)
+ return normalize_value_range(result_range, result.variadicity)
- setattr(_OpView, result.name, property(getter))
+ setattr(cls, result.name, property(getter))
else:
- setattr(
- _OpView, result.name, property(lambda self, i=i: self.results[i])
- )
+ setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
- def _builder(*args, **kwargs) -> _OpView:
- return _OpView(*args, **kwargs)
-
- _builder.__signature__ = _Signature(params[1:])
+ @classmethod
+ def _emit_operation(cls) -> None:
+ ctx = ConstraintLoweringContext()
+ operands, attrs, results = partition_fields(cls._fields)
- return _OpView, _builder
+ op = irdl.operation_(cls._op_name)
+ with ir.InsertionPoint(op.body):
+ if operands:
+ irdl.operands_(
+ [ctx.lower(i.constraint) for i in operands],
+ [i.name for i in operands],
+ [i.variadicity for i in operands],
+ )
+ if attrs:
+ irdl.attributes_(
+ [ctx.lower(i.constraint) for i in attrs],
+ [i.name for i in attrs],
+ )
+ if results:
+ irdl.results_(
+ [ctx.lower(i.constraint) for i in results],
+ [i.name for i in results],
+ [i.variadicity for i in results],
+ )
class Dialect:
def __init__(self, name: str):
self.name = name
- self.operations: List[Operation] = []
- self.namespace = _SimpleNameSpace()
-
- def _emit(self) -> None:
- d = _irdl.dialect(self.name)
- with _ods_ir.InsertionPoint(d.body):
+ self.operations = []
+ self.Operation = type(
+ "Operation",
+ (Operation,),
+ {"_dialect_obj": self, "_dialect_name": name},
+ )
+
+ def _emit_dialect(self) -> None:
+ d = irdl.dialect(self.name)
+ with ir.InsertionPoint(d.body):
for op in self.operations:
- op._emit()
+ op._emit_operation()
+
+ def _emit_module(self) -> ir.Module:
+ m = ir.Module.create()
+ with ir.InsertionPoint(m.body):
+ self._emit_dialect()
- def _make_module(self) -> _ods_ir.Module:
- with _ods_ir.Location.unknown():
- m = _ods_ir.Module.create()
- with _ods_ir.InsertionPoint(m.body):
- self._emit()
return m
def _make_dialect_class(self) -> type:
- class _Dialect(_ods_ir.Dialect):
- DIALECT_NAMESPACE = self.name
-
- return _Dialect
-
- def load(self) -> _SimpleNameSpace:
- _irdl.load_dialects(self._make_module())
- dialect_class = self._make_dialect_class()
- _ods_cext.register_dialect(dialect_class)
- for op in self.operations:
- _ods_cext.register_operation(dialect_class)(op.op_view)
- return self.namespace
+ return type("Dialect", (ir.Dialect,), {"DIALECT_NAMESPACE": self.name})
- @staticmethod
- def _op_name_to_python_id(name: str) -> str:
- return name.replace(".", "_").replace("$", "_")
-
- def op(self, name: str) -> Callable[[type], type]:
- def decorator(cls: type) -> type:
- fields = [
- field for field in cls.__dict__.values() if isinstance(field, _FieldDef)
- ]
- op_def = Operation(self.name, name, fields)
+ def load(self) -> None:
+ if hasattr(self, "mlir_module"):
+ raise RuntimeError(f"Dialect {self.name} is already loaded.")
- op_view, builder = op_def._make_op_view_and_builder()
- op_view.__name__ = cls.__name__
- setattr(op_def, "op_view", op_view)
- setattr(op_def, "builder", builder)
- self.operations.append(op_def)
+ mlir_module = self._emit_module()
+ irdl.load_dialects(mlir_module)
- self.namespace.__dict__[cls.__name__] = op_view
- self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder
+ dialect_class = self._make_dialect_class()
+ _cext.register_dialect(dialect_class)
- return cls
+ # for op in self.operations:
+ # _cext.register_operation(dialect_class)(op)
- return decorator
+ self.mlir_module = mlir_module
+ self.dialect_class = dialect_class
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index 8ef30ae0a4c13..f3a80655efe80 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -17,15 +17,13 @@ def run(f):
def testMyInt():
myint = irdsl.Dialect("myint")
iattr = irdsl.BaseName("#builtin.integer")
- i32 = irdsl.IsType(IntegerType.get_signless(32))
+ i32 = irdsl.Is(IntegerType.get_signless(32))
- @myint.op("constant")
- class ConstantOp:
+ class ConstantOp(myint.Operation, name="constant"):
value = irdsl.Attribute(iattr)
cst = irdsl.Result(i32)
- @myint.op("add")
- class AddOp:
+ class AddOp(myint.Operation, name="add"):
lhs = irdsl.Operand(i32)
rhs = irdsl.Operand(i32)
res = irdsl.Result(i32)
@@ -43,21 +41,22 @@ class AddOp:
# CHECK: irdl.results(res: %0)
# CHECK: }
# CHECK: }
- print(myint._make_module())
- myint = myint.load()
+ with Location.unknown():
+ myint.load()
+ print(myint.mlir_module)
- # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add']
- print([i for i in myint.__dict__.keys()])
+ # CHECK: ['constant', 'add']
+ print([i._op_name for i in myint.operations])
i32 = IntegerType.get_signless(32)
with Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
- two = myint.constant(i32, IntegerAttr.get(i32, 2))
- three = myint.constant(i32, IntegerAttr.get(i32, 3))
- add1 = myint.add(i32, two, three)
- add2 = myint.add(i32, add1, two)
- add3 = myint.add(i32, add2, three)
+ two = ConstantOp(i32, IntegerAttr.get(i32, 2))
+ three = ConstantOp(i32, IntegerAttr.get(i32, 3))
+ add1 = AddOp(i32, two, three)
+ add2 = AddOp(i32, add1, two)
+ add3 = AddOp(i32, add2, three)
# CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
# CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
@@ -85,25 +84,24 @@ class AddOp:
print(two.value)
# CHECK: Value(%0
print(two.cst)
- # CHECK: (res, lhs, rhs, *, loc=None, ip=None)
- print(myint.add.__signature__)
- # CHECK: (cst, value, *, loc=None, ip=None)
- print(myint.constant.__signature__)
+ # CHECK: (self, /, res, lhs, rhs, *, loc=None, ip=None)
+ print(AddOp.__init__.__signature__)
+ # CHECK: (self, /, cst, value, *, loc=None, ip=None)
+ print(ConstantOp.__init__.__signature__)
@run
def testIRDSL():
test = irdsl.Dialect("irdsl_test")
- i32 = irdsl.IsType(IntegerType.get_signless(32))
- i64 = irdsl.IsType(IntegerType.get_signless(64))
+ i32 = irdsl.Is(IntegerType.get_signless(32))
+ i64 = irdsl.Is(IntegerType.get_signless(64))
i32or64 = i32 | i64
any = irdsl.Any()
- f32 = irdsl.IsType(F32Type.get())
+ f32 = irdsl.Is(F32Type.get())
iattr = irdsl.BaseName("#builtin.integer")
fattr = irdsl.BaseName("#builtin.float")
- @test.op("constraint")
- class ConstraintOp:
+ class ConstraintOp(test.Operation, name="constraint"):
a = irdsl.Operand(i32or64)
b = irdsl.Operand(any)
c = irdsl.Operand(f32 | i32)
@@ -111,21 +109,18 @@ class ConstraintOp:
x = irdsl.Attribute(iattr)
y = irdsl.Attribute(fattr)
- @test.op("optional")
- class OptionalOp:
+ class OptionalOp(test.Operation, name="optional"):
a = irdsl.Operand(i32)
b = irdsl.Operand(i32, irdsl.Variadicity.optional)
out1 = irdsl.Result(i32)
out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
out3 = irdsl.Result(i32)
- @test.op("optional2")
- class Optional2Op:
+ class Optional2Op(test.Operation, name="optional2"):
a = irdsl.Operand(i32, irdsl.Variadicity.optional)
b = irdsl.Result(i32, irdsl.Variadicity.optional)
- @test.op("variadic")
- class VariadicOp:
+ class VariadicOp(test.Operation, name="variadic"):
a = irdsl.Operand(i32)
b = irdsl.Operand(i32, irdsl.Variadicity.optional)
c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
@@ -134,13 +129,11 @@ class VariadicOp:
out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
out4 = irdsl.Result(i32)
- @test.op("variadic2")
- class Variadic2Op:
+ class Variadic2Op(test.Operation, name="variadic2"):
a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
b = irdsl.Result(i32, irdsl.Variadicity.variadic)
- @test.op("mixed")
- class MixedOp:
+ class MixedOp(test.Operation, name="mixed"):
out = irdsl.Result(i32)
in1 = irdsl.Operand(i32)
in2 = irdsl.Attribute(iattr)
@@ -189,34 +182,33 @@ class MixedOp:
# CHECK: irdl.results(out: %0)
# CHECK: }
# CHECK: }
- print(test._make_module())
- test = test.load()
-
- # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None)
- print(test.constraint.__signature__)
- # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
- print(test.optional.__signature__)
- # CHECK: (*, b=None, a=None, loc=None, ip=None)
- print(test.optional2.__signature__)
- # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
- print(test.variadic.__signature__)
- # CHECK: (b, a, *, loc=None, ip=None)
- print(test.variadic2.__signature__)
- # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
- print(test.mixed.__signature__)
+ with Location.unknown():
+ test.load()
+ print(test.mlir_module)
+
+ # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
+ print(ConstraintOp.__init__.__signature__)
+ # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+ print(OptionalOp.__init__.__signature__)
+ # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
+ print(Optional2Op.__init__.__signature__)
+ # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+ print(VariadicOp.__init__.__signature__)
+ # CHECK: (self, /, b, a, *, loc=None, ip=None)
+ print(Variadic2Op.__init__.__signature__)
+ # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+ print(MixedOp.__init__.__signature__)
# CHECK: None None
- print(
- test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS
- )
+ print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
# CHECK: [1, 0] [1, 0, 1]
- print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS)
+ print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
# CHECK: [0] [0]
- print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS)
+ print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
# CHECK: [1, 0, -1] [-1, -1, 0, 1]
- print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS)
+ print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
# CHECK: [-1] [-1]
- print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS)
+ print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
i32 = IntegerType.get_signless(32)
i64 = IntegerType.get_signless(64)
@@ -232,42 +224,42 @@ class MixedOp:
fone = arith.constant(f32, 1.2)
# CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
- c1 = test.constraint(ione, ione, fone, ione, iattr, fattr)
+ c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
# CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
- test.constraint(ione, fone, fone, fone, iattr, fattr)
+ ConstraintOp(ione, fone, fone, fone, iattr, fattr)
# CHECK: irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
- test.constraint(ione, fone, ione, fone, iattr, fattr)
+ ConstraintOp(ione, fone, ione, fone, iattr, fattr)
# CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
- o1 = test.optional(i32, i32, ione)
+ o1 = OptionalOp(i32, i32, ione)
# CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
- o2 = test.optional(i32, i32, ione, out2=i32, b=ione)
+ o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
# CHECK: irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
- o3 = test.optional2()
+ o3 = Optional2Op()
# CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
- o4 = test.optional2(b=i32)
+ o4 = Optional2Op(b=i32)
# CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
- o5 = test.optional2(a=ione)
+ o5 = Optional2Op(a=ione)
# CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
- o6 = test.optional2(b=i32, a=ione)
+ o6 = Optional2Op(b=i32, a=ione)
# CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
- v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione])
+ v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
# CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
- v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
+ v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
# CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
- v3 = test.variadic([i32, i32], [i32], i32, ione, [])
+ v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
# CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
- v4 = test.variadic2([], [])
+ v4 = Variadic2Op([], [])
# CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
- v5 = test.variadic2([], [ione, ione, ione])
+ v5 = Variadic2Op([], [ione, ione, ione])
# CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
- v6 = test.variadic2([i32, i32], [ione])
+ v6 = Variadic2Op([i32, i32], [ione])
# CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
- m1 = test.mixed(i32, ione, iattr, iattr, ione)
+ m1 = MixedOp(i32, ione, iattr, iattr, ione)
# CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
- m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione)
+ m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
print(module)
assert module.operation.verify()
>From 1d7e70788546a2b86fbc719a052a03fca0e0097f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 11 Dec 2025 10:17:59 +0800
Subject: [PATCH 10/20] fix
---
mlir/python/mlir/dialects/irdl/dsl.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 630740c0103ab..97df1f78ac16c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -375,8 +375,8 @@ def load(self) -> None:
dialect_class = self._make_dialect_class()
_cext.register_dialect(dialect_class)
- # for op in self.operations:
- # _cext.register_operation(dialect_class)(op)
+ for op in self.operations:
+ _cext.register_operation(dialect_class)(op)
self.mlir_module = mlir_module
self.dialect_class = dialect_class
>From c8364487a1625754b27b75885db59f5087bcba53 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 11 Dec 2025 10:28:06 +0800
Subject: [PATCH 11/20] add more type annotations
---
mlir/python/mlir/dialects/irdl/dsl.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 97df1f78ac16c..56869e7ae4e3c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -150,7 +150,7 @@ def normalize_value_range(
class Operation(ir.OpView):
@classmethod
- def __init_subclass__(cls, *, name=None, **kwargs):
+ def __init_subclass__(cls, *, name: str = None, **kwargs):
super().__init_subclass__(**kwargs)
# for subclasses without "name" parameter,
@@ -222,7 +222,7 @@ def _generate_init_signature(fields: List[FieldDef]) -> Signature:
return Signature(params)
@classmethod
- def _generate_init_method(cls, fields):
+ def _generate_init_method(cls, fields: List[FieldDef]) -> None:
init_sig = cls._generate_init_signature(fields)
operands, attrs, results = partition_fields(fields)
@@ -259,7 +259,9 @@ def __init__(*args, **kwargs):
cls.__init__ = __init__
@classmethod
- def _generate_class_attributes(cls, dialect_name, op_name, fields):
+ def _generate_class_attributes(
+ cls, dialect_name: str, op_name: str, fields: List[FieldDef]
+ ) -> None:
operands, attrs, results = partition_fields(fields)
operand_segments = cls._generate_segments(operands)
@@ -271,7 +273,7 @@ def _generate_class_attributes(cls, dialect_name, op_name, fields):
cls._ODS_RESULT_SEGMENTS = result_segments
@classmethod
- def _generate_attr_properties(cls, attrs):
+ def _generate_attr_properties(cls, attrs: List[Attribute]) -> None:
for attr in attrs:
setattr(
cls,
@@ -280,7 +282,7 @@ def _generate_attr_properties(cls, attrs):
)
@classmethod
- def _generate_operand_properties(cls, operands):
+ def _generate_operand_properties(cls, operands: List[Operand]) -> None:
for i, operand in enumerate(operands):
if cls._ODS_OPERAND_SEGMENTS:
@@ -297,7 +299,7 @@ def getter(self, i=i, operand=operand):
setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
@classmethod
- def _generate_result_properties(cls, results):
+ def _generate_result_properties(cls, results: List[Result]) -> None:
for i, result in enumerate(results):
if cls._ODS_RESULT_SEGMENTS:
>From 6e3eef309f4d56e6e8c64351184040db2205c309 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 15:25:15 +0800
Subject: [PATCH 12/20] make sure that context is not required for defining
dialects
---
mlir/python/mlir/dialects/irdl/dsl.py | 28 +++-
mlir/test/python/dialects/irdsl.py | 220 +++++++++++++-------------
2 files changed, 133 insertions(+), 115 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 56869e7ae4e3c..35237c848fbeb 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -56,13 +56,33 @@ def lower(self, expr: ConstraintExpr) -> ir.Value:
class Is(ConstraintExpr):
- def __init__(self, val: Union[ir.Attribute, ir.Type]):
+ def __init__(self, val: Callable[..., Union[ir.Attribute, ir.Type]]):
self.val = val
+ self.args = []
+ self.kwargs = {}
+
+ def __call__(self, *args, **kwargs) -> "Is":
+ self.args.extend(args)
+ self.kwargs.update(kwargs)
+ return self
+
+ def __class_getitem__(
+ cls, val: Callable[..., Union[ir.Attribute, ir.Type]]
+ ) -> "Is":
+ return cls(val)
def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- return irdl.is_(
- ir.TypeAttr.get(self.val) if isinstance(self.val, ir.Type) else self.val
- )
+ # for most attributes and types, they are created via `.get` method,
+ # here we can just omit the `.get`
+ if isinstance(self.val, type) and hasattr(self.val, "get"):
+ self.val = self.val.get
+
+ val = self.val(*self.args, **self.kwargs)
+
+ if isinstance(val, ir.Type):
+ val = ir.TypeAttr.get(val)
+
+ return irdl.is_(val)
class AnyOf(ConstraintExpr):
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index f3a80655efe80..46511b8e1bfeb 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -8,8 +8,7 @@
def run(f):
print("\nTEST:", f.__name__, file=sys.stderr)
- with Context():
- f()
+ f()
# CHECK: TEST: testMyInt
@@ -17,7 +16,7 @@ def run(f):
def testMyInt():
myint = irdsl.Dialect("myint")
iattr = irdsl.BaseName("#builtin.integer")
- i32 = irdsl.Is(IntegerType.get_signless(32))
+ i32 = irdsl.Is[IntegerType.get_signless](32)
class ConstantOp(myint.Operation, name="constant"):
value = irdsl.Attribute(iattr)
@@ -41,15 +40,15 @@ class AddOp(myint.Operation, name="add"):
# CHECK: irdl.results(res: %0)
# CHECK: }
# CHECK: }
- with Location.unknown():
+ with Context(), Location.unknown():
myint.load()
- print(myint.mlir_module)
+ print(myint.mlir_module)
- # CHECK: ['constant', 'add']
- print([i._op_name for i in myint.operations])
+ # CHECK: ['constant', 'add']
+ print([i._op_name for i in myint.operations])
+
+ i32 = IntegerType.get_signless(32)
- i32 = IntegerType.get_signless(32)
- with Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
two = ConstantOp(i32, IntegerAttr.get(i32, 2))
@@ -58,46 +57,46 @@ class AddOp(myint.Operation, name="add"):
add2 = AddOp(i32, add1, two)
add3 = AddOp(i32, add2, three)
- # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
- # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
- # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
- # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
- # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
- print(module)
- assert module.operation.verify()
-
- # CHECK: AddOp
- print(type(add1).__name__)
- # CHECK: ConstantOp
- print(type(two).__name__)
- # CHECK: myint.add
- print(add1.OPERATION_NAME)
- # CHECK: None
- print(add1._ODS_OPERAND_SEGMENTS)
- # CHECK: None
- print(add1._ODS_RESULT_SEGMENTS)
- # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
- print(add1.lhs.owner)
- # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
- print(add1.rhs.owner)
- # CHECK: 2 : i32
- print(two.value)
- # CHECK: Value(%0
- print(two.cst)
- # CHECK: (self, /, res, lhs, rhs, *, loc=None, ip=None)
- print(AddOp.__init__.__signature__)
- # CHECK: (self, /, cst, value, *, loc=None, ip=None)
- print(ConstantOp.__init__.__signature__)
+ # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+ # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+ # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
+ # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
+ # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
+ print(module)
+ assert module.operation.verify()
+
+ # CHECK: AddOp
+ print(type(add1).__name__)
+ # CHECK: ConstantOp
+ print(type(two).__name__)
+ # CHECK: myint.add
+ print(add1.OPERATION_NAME)
+ # CHECK: None
+ print(add1._ODS_OPERAND_SEGMENTS)
+ # CHECK: None
+ print(add1._ODS_RESULT_SEGMENTS)
+ # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
+ print(add1.lhs.owner)
+ # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
+ print(add1.rhs.owner)
+ # CHECK: 2 : i32
+ print(two.value)
+ # CHECK: Value(%0
+ print(two.cst)
+ # CHECK: (self, /, res, lhs, rhs, *, loc=None, ip=None)
+ print(AddOp.__init__.__signature__)
+ # CHECK: (self, /, cst, value, *, loc=None, ip=None)
+ print(ConstantOp.__init__.__signature__)
@run
def testIRDSL():
test = irdsl.Dialect("irdsl_test")
- i32 = irdsl.Is(IntegerType.get_signless(32))
- i64 = irdsl.Is(IntegerType.get_signless(64))
+ i32 = irdsl.Is[IntegerType.get_signless](32)
+ i64 = irdsl.Is[IntegerType.get_signless](64)
i32or64 = i32 | i64
any = irdsl.Any()
- f32 = irdsl.Is(F32Type.get())
+ f32 = irdsl.Is[F32Type]
iattr = irdsl.BaseName("#builtin.integer")
fattr = irdsl.BaseName("#builtin.float")
@@ -182,39 +181,38 @@ class MixedOp(test.Operation, name="mixed"):
# CHECK: irdl.results(out: %0)
# CHECK: }
# CHECK: }
- with Location.unknown():
+ with Context(), Location.unknown():
test.load()
- print(test.mlir_module)
-
- # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
- print(ConstraintOp.__init__.__signature__)
- # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
- print(OptionalOp.__init__.__signature__)
- # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
- print(Optional2Op.__init__.__signature__)
- # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
- print(VariadicOp.__init__.__signature__)
- # CHECK: (self, /, b, a, *, loc=None, ip=None)
- print(Variadic2Op.__init__.__signature__)
- # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
- print(MixedOp.__init__.__signature__)
-
- # CHECK: None None
- print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0] [1, 0, 1]
- print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
- # CHECK: [0] [0]
- print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
- # CHECK: [1, 0, -1] [-1, -1, 0, 1]
- print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
- # CHECK: [-1] [-1]
- print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
-
- i32 = IntegerType.get_signless(32)
- i64 = IntegerType.get_signless(64)
- f32 = F32Type.get()
-
- with Location.unknown():
+ print(test.mlir_module)
+
+ # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
+ print(ConstraintOp.__init__.__signature__)
+ # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
+ print(OptionalOp.__init__.__signature__)
+ # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
+ print(Optional2Op.__init__.__signature__)
+ # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
+ print(VariadicOp.__init__.__signature__)
+ # CHECK: (self, /, b, a, *, loc=None, ip=None)
+ print(Variadic2Op.__init__.__signature__)
+ # CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
+ print(MixedOp.__init__.__signature__)
+
+ # CHECK: None None
+ print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
+ # CHECK: [1, 0] [1, 0, 1]
+ print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
+ # CHECK: [0] [0]
+ print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
+ # CHECK: [1, 0, -1] [-1, -1, 0, 1]
+ print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
+ # CHECK: [-1] [-1]
+ print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
+
+ i32 = IntegerType.get_signless(32)
+ i64 = IntegerType.get_signless(64)
+ f32 = F32Type.get()
+
iattr = IntegerAttr.get(i32, 2)
fattr = FloatAttr.get_f32(2.3)
@@ -261,40 +259,40 @@ class MixedOp(test.Operation, name="mixed"):
# CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
- print(module)
- assert module.operation.verify()
-
- # CHECK: Value(%c1_i32
- print(c1.a)
- # CHECK: 2 : i32
- print(c1.x)
- # CHECK: Value(%c1_i32
- print(o1.a)
- # CHECK: None
- print(o1.b)
- # CHECK: Value(%c1_i32
- print(o2.b)
- # CHECK: 0
- print(o1.out1.result_number)
- # CHECK: None
- print(o1.out2)
- # CHECK: 0
- print(o2.out1.result_number)
- # CHECK: 1
- print(o2.out2.result_number)
- # CHECK: None
- print(o3.a)
- # CHECK: Value(%c1_i32
- print(o5.a)
- # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
- print([str(i) for i in v1.c])
- # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
- print([str(i) for i in v2.c])
- # CHECK: []
- print([str(i) for i in v3.c])
- # CHECK: 0 0
- print(len(v4.a), len(v4.b))
- # CHECK: 3 0
- print(len(v5.a), len(v5.b))
- # CHECK: 1 2
- print(len(v6.a), len(v6.b))
+ print(module)
+ assert module.operation.verify()
+
+ # CHECK: Value(%c1_i32
+ print(c1.a)
+ # CHECK: 2 : i32
+ print(c1.x)
+ # CHECK: Value(%c1_i32
+ print(o1.a)
+ # CHECK: None
+ print(o1.b)
+ # CHECK: Value(%c1_i32
+ print(o2.b)
+ # CHECK: 0
+ print(o1.out1.result_number)
+ # CHECK: None
+ print(o1.out2)
+ # CHECK: 0
+ print(o2.out1.result_number)
+ # CHECK: 1
+ print(o2.out2.result_number)
+ # CHECK: None
+ print(o3.a)
+ # CHECK: Value(%c1_i32
+ print(o5.a)
+ # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)']
+ print([str(i) for i in v1.c])
+ # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)']
+ print([str(i) for i in v2.c])
+ # CHECK: []
+ print([str(i) for i in v3.c])
+ # CHECK: 0 0
+ print(len(v4.a), len(v4.b))
+ # CHECK: 3 0
+ print(len(v5.a), len(v5.b))
+ # CHECK: 1 2
+ print(len(v6.a), len(v6.b))
>From 4ff653d192b8837fd8484eba8be492e4124501b1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 15:27:05 +0800
Subject: [PATCH 13/20] adjust the comment
---
mlir/python/mlir/dialects/irdl/dsl.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 35237c848fbeb..59e667ab1ba8b 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -73,7 +73,7 @@ def __class_getitem__(
def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
# for most attributes and types, they are created via `.get` method,
- # here we can just omit the `.get`
+ # here we can just omit the `.get` suffix for convenience
if isinstance(self.val, type) and hasattr(self.val, "get"):
self.val = self.val.get
>From c9fcf2473844f09a6d11fef54b7651c16ac9b1c9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 16:34:17 +0800
Subject: [PATCH 14/20] make irdsl.Dialect a subclass of ir.Dialect
---
mlir/python/mlir/dialects/irdl/dsl.py | 48 +++++++++++++--------------
mlir/test/python/dialects/irdsl.py | 35 ++++++++++---------
2 files changed, 43 insertions(+), 40 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 59e667ab1ba8b..c5679ee587e8c 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -361,44 +361,44 @@ def _emit_operation(cls) -> None:
)
-class Dialect:
- def __init__(self, name: str):
- self.name = name
- self.operations = []
- self.Operation = type(
+class Dialect(ir.Dialect):
+ @classmethod
+ def __init_subclass__(cls, name: str, **kwargs):
+ cls.name = name
+ cls.DIALECT_NAMESPACE = name
+ cls.operations = []
+ cls.Operation = type(
"Operation",
(Operation,),
- {"_dialect_obj": self, "_dialect_name": name},
+ {"_dialect_obj": cls, "_dialect_name": name},
)
- def _emit_dialect(self) -> None:
- d = irdl.dialect(self.name)
+ @classmethod
+ def _emit_dialect(cls) -> None:
+ d = irdl.dialect(cls.name)
with ir.InsertionPoint(d.body):
- for op in self.operations:
+ for op in cls.operations:
op._emit_operation()
- def _emit_module(self) -> ir.Module:
+ @classmethod
+ def _emit_module(cls) -> ir.Module:
m = ir.Module.create()
with ir.InsertionPoint(m.body):
- self._emit_dialect()
+ cls._emit_dialect()
return m
- def _make_dialect_class(self) -> type:
- return type("Dialect", (ir.Dialect,), {"DIALECT_NAMESPACE": self.name})
-
- def load(self) -> None:
- if hasattr(self, "mlir_module"):
- raise RuntimeError(f"Dialect {self.name} is already loaded.")
+ @classmethod
+ def load(cls) -> None:
+ if hasattr(cls, "mlir_module"):
+ raise RuntimeError(f"Dialect {cls.name} is already loaded.")
- mlir_module = self._emit_module()
+ mlir_module = cls._emit_module()
irdl.load_dialects(mlir_module)
- dialect_class = self._make_dialect_class()
- _cext.register_dialect(dialect_class)
+ _cext.register_dialect(cls)
- for op in self.operations:
- _cext.register_operation(dialect_class)(op)
+ for op in cls.operations:
+ _cext.register_operation(cls)(op)
- self.mlir_module = mlir_module
- self.dialect_class = dialect_class
+ cls.mlir_module = mlir_module
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index 46511b8e1bfeb..de7949a7d376a 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -14,15 +14,17 @@ def run(f):
# CHECK: TEST: testMyInt
@run
def testMyInt():
- myint = irdsl.Dialect("myint")
+ class MyInt(irdsl.Dialect, name="myint"):
+ pass
+
iattr = irdsl.BaseName("#builtin.integer")
i32 = irdsl.Is[IntegerType.get_signless](32)
- class ConstantOp(myint.Operation, name="constant"):
+ class ConstantOp(MyInt.Operation, name="constant"):
value = irdsl.Attribute(iattr)
cst = irdsl.Result(i32)
- class AddOp(myint.Operation, name="add"):
+ class AddOp(MyInt.Operation, name="add"):
lhs = irdsl.Operand(i32)
rhs = irdsl.Operand(i32)
res = irdsl.Result(i32)
@@ -41,12 +43,11 @@ class AddOp(myint.Operation, name="add"):
# CHECK: }
# CHECK: }
with Context(), Location.unknown():
- myint.load()
- print(myint.mlir_module)
+ MyInt.load()
+ print(MyInt.mlir_module)
# CHECK: ['constant', 'add']
- print([i._op_name for i in myint.operations])
-
+ print([i._op_name for i in MyInt.operations])
i32 = IntegerType.get_signless(32)
module = Module.create()
@@ -91,7 +92,9 @@ class AddOp(myint.Operation, name="add"):
@run
def testIRDSL():
- test = irdsl.Dialect("irdsl_test")
+ class Test(irdsl.Dialect, name="irdsl_test"):
+ pass
+
i32 = irdsl.Is[IntegerType.get_signless](32)
i64 = irdsl.Is[IntegerType.get_signless](64)
i32or64 = i32 | i64
@@ -100,7 +103,7 @@ def testIRDSL():
iattr = irdsl.BaseName("#builtin.integer")
fattr = irdsl.BaseName("#builtin.float")
- class ConstraintOp(test.Operation, name="constraint"):
+ class ConstraintOp(Test.Operation, name="constraint"):
a = irdsl.Operand(i32or64)
b = irdsl.Operand(any)
c = irdsl.Operand(f32 | i32)
@@ -108,18 +111,18 @@ class ConstraintOp(test.Operation, name="constraint"):
x = irdsl.Attribute(iattr)
y = irdsl.Attribute(fattr)
- class OptionalOp(test.Operation, name="optional"):
+ class OptionalOp(Test.Operation, name="optional"):
a = irdsl.Operand(i32)
b = irdsl.Operand(i32, irdsl.Variadicity.optional)
out1 = irdsl.Result(i32)
out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
out3 = irdsl.Result(i32)
- class Optional2Op(test.Operation, name="optional2"):
+ class Optional2Op(Test.Operation, name="optional2"):
a = irdsl.Operand(i32, irdsl.Variadicity.optional)
b = irdsl.Result(i32, irdsl.Variadicity.optional)
- class VariadicOp(test.Operation, name="variadic"):
+ class VariadicOp(Test.Operation, name="variadic"):
a = irdsl.Operand(i32)
b = irdsl.Operand(i32, irdsl.Variadicity.optional)
c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
@@ -128,11 +131,11 @@ class VariadicOp(test.Operation, name="variadic"):
out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
out4 = irdsl.Result(i32)
- class Variadic2Op(test.Operation, name="variadic2"):
+ class Variadic2Op(Test.Operation, name="variadic2"):
a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
b = irdsl.Result(i32, irdsl.Variadicity.variadic)
- class MixedOp(test.Operation, name="mixed"):
+ class MixedOp(Test.Operation, name="mixed"):
out = irdsl.Result(i32)
in1 = irdsl.Operand(i32)
in2 = irdsl.Attribute(iattr)
@@ -182,8 +185,8 @@ class MixedOp(test.Operation, name="mixed"):
# CHECK: }
# CHECK: }
with Context(), Location.unknown():
- test.load()
- print(test.mlir_module)
+ Test.load()
+ print(Test.mlir_module)
# CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
print(ConstraintOp.__init__.__signature__)
>From 59d7ae71288af54ae76d53f2976fe4299e638308 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 17:12:53 +0800
Subject: [PATCH 15/20] add example for base class
---
mlir/python/mlir/dialects/irdl/dsl.py | 9 +++++----
mlir/test/python/dialects/irdsl.py | 4 +++-
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index c5679ee587e8c..d1c706169ffd8 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -186,10 +186,11 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
fields = []
cls._fields = fields
- for key, value in cls.__dict__.items():
- if isinstance(value, FieldDef):
- setattr(value, "name", key)
- fields.append(value)
+ for base in reversed(cls.__mro__):
+ for key, value in base.__dict__.items():
+ if isinstance(value, FieldDef):
+ setattr(value, "name", key)
+ fields.append(value)
cls._generate_class_attributes(dialect_name, op_name, fields)
cls._generate_init_method(fields)
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py
index de7949a7d376a..e80bb412139b0 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/irdsl.py
@@ -135,9 +135,11 @@ class Variadic2Op(Test.Operation, name="variadic2"):
a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
b = irdsl.Result(i32, irdsl.Variadicity.variadic)
- class MixedOp(Test.Operation, name="mixed"):
+ class MixedOpBase(Test.Operation):
out = irdsl.Result(i32)
in1 = irdsl.Operand(i32)
+
+ class MixedOp(MixedOpBase, name="mixed"):
in2 = irdsl.Attribute(iattr)
in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
in4 = irdsl.Attribute(iattr)
>From 2cb35816f9025ecdbe38cd5f8d8abb041938bf03 Mon Sep 17 00:00:00 2001
From: Twice <twice at apache.org>
Date: Mon, 22 Dec 2025 17:32:32 +0800
Subject: [PATCH 16/20] Update mlir/python/mlir/dialects/irdl/dsl.py
Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
---
mlir/python/mlir/dialects/irdl/dsl.py | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index d1c706169ffd8..075926a08872a 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -141,11 +141,7 @@ class Result(FieldDef):
@dataclass
class Attribute(FieldDef):
constraint: ConstraintExpr
-
- def __post_init__(self):
- # just for unified processing,
- # currently optional attribute is not supported by IRDL
- self.variadicity = Variadicity.single
+ variadicity: typing.ClassVar[Variadicity] = Variadicity.single
def partition_fields(
>From 36104c4836a77cc0ca2228b22bbf2d904407534f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 22 Dec 2025 17:34:12 +0800
Subject: [PATCH 17/20] fix import
---
mlir/python/mlir/dialects/irdl/dsl.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 075926a08872a..0b6fdfc268c47 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Dict, List, Union, Callable, Tuple
+from typing import Dict, List, Union, Callable, Tuple, ClassVar
from dataclasses import dataclass
from inspect import Parameter, Signature
from types import SimpleNamespace
@@ -141,7 +141,7 @@ class Result(FieldDef):
@dataclass
class Attribute(FieldDef):
constraint: ConstraintExpr
- variadicity: typing.ClassVar[Variadicity] = Variadicity.single
+ variadicity: ClassVar[Variadicity] = Variadicity.single
def partition_fields(
>From 4fbe94384b689d3f765ecdddc192ded86d2f493f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 8 Jan 2026 21:12:29 +0800
Subject: [PATCH 18/20] fix order of base fields
---
mlir/python/mlir/dialects/irdl/dsl.py | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py
index 0b6fdfc268c47..31a2a1bf98b14 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/irdl/dsl.py
@@ -169,6 +169,17 @@ class Operation(ir.OpView):
def __init_subclass__(cls, *, name: str = None, **kwargs):
super().__init_subclass__(**kwargs)
+ fields = []
+ cls._fields = fields
+
+ for base in cls.__bases__:
+ if hasattr(base, "_fields"):
+ fields.extend(base._fields)
+ for key, value in cls.__dict__.items():
+ if isinstance(value, FieldDef):
+ setattr(value, "name", key)
+ fields.append(value)
+
# for subclasses without "name" parameter,
# just treat them as normal classes
if not name:
@@ -179,15 +190,6 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
dialect_name = cls._dialect_name
dialect_obj = cls._dialect_obj
- fields = []
- cls._fields = fields
-
- for base in reversed(cls.__mro__):
- for key, value in base.__dict__.items():
- if isinstance(value, FieldDef):
- setattr(value, "name", key)
- fields.append(value)
-
cls._generate_class_attributes(dialect_name, op_name, fields)
cls._generate_init_method(fields)
operands, attrs, results = partition_fields(fields)
>From 84b99bbc0a6f1a07dfac481f4a323dd2c39cbb5b Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 9 Jan 2026 19:50:32 +0800
Subject: [PATCH 19/20] huge refactor
---
mlir/include/mlir-c/BuiltinAttributes.h | 2 +
.../mlir/Bindings/Python/IRAttributes.h | 1 +
mlir/include/mlir/Bindings/Python/IRCore.h | 7 +-
mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 2 +
mlir/python/CMakeLists.txt | 5 +-
.../mlir/dialects/{irdl/dsl.py => ext.py} | 209 ++++++++----------
.../dialects/{irdl/__init__.py => irdl.py} | 13 +-
.../test/python/dialects/{irdsl.py => ext.py} | 88 ++++----
9 files changed, 147 insertions(+), 182 deletions(-)
rename mlir/python/mlir/dialects/{irdl/dsl.py => ext.py} (69%)
rename mlir/python/mlir/dialects/{irdl/__init__.py => irdl.py} (91%)
rename mlir/test/python/dialects/{irdsl.py => ext.py} (84%)
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 17c73f44cfc74..eab732365f6b8 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -115,6 +115,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void);
/// Checks whether the given attribute is a floating point attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsAFloat(MlirAttribute attr);
+MLIR_CAPI_EXPORTED MlirStringRef mlirFloatAttrGetName(void);
+
/// Creates a floating point attribute in the given context with the given
/// double value and double-precision FP semantics.
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx,
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 05d64b0d91b1b..6175710d76dd0 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -324,6 +324,7 @@ class MLIR_PYTHON_API_EXPORTED PyFloatAttribute
using PyConcreteAttribute::PyConcreteAttribute;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloatAttrGetTypeID;
+ static inline const MlirStringRef name = mlirFloatAttrGetName();
static void bindDerived(ClassTy &c);
};
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 59dc496c9e206..766eb9290ad58 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -957,7 +957,7 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
}
static void bind(nanobind::module_ &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
+ auto cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_type"));
cls.def_prop_ro_static(
@@ -1092,9 +1092,10 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
ClassTy cls;
if (slots) {
- cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
+ cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots),
+ nanobind::is_generic());
} else {
- cls = ClassTy(m, DerivedTy::pyClassName);
+ cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
}
cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_attr"));
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2..a9f228f33a0ab 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -403,7 +403,7 @@ size_t PyOpOperand::getOperandNumber() const {
}
void PyOpOperand::bind(nb::module_ &m) {
- nb::class_<PyOpOperand>(m, "OpOperand")
+ nb::class_<PyOpOperand>(m, "OpOperand", nb::is_generic())
.def_prop_ro("owner", &PyOpOperand::getOwner,
"Returns the operation that owns this operand.")
.def_prop_ro("operand_number", &PyOpOperand::getOperandNumber,
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index eebf82215eab0..f7172c21a0cb9 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -131,6 +131,8 @@ bool mlirAttributeIsAFloat(MlirAttribute attr) {
return llvm::isa<FloatAttr>(unwrap(attr));
}
+MlirStringRef mlirFloatAttrGetName(void) { return wrap(FloatAttr::name); }
+
MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
double value) {
return wrap(FloatAttr::get(unwrap(type), value));
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d97066657e070..8ab145ada85dd 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -30,6 +30,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
passmanager.py
rewrite.py
dialects/_ods_common.py
+ dialects/ext.py
)
declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
@@ -513,9 +514,7 @@ declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IRDLOps.td
- SOURCES
- dialects/irdl/__init__.py
- dialects/irdl/dsl.py
+ SOURCES dialects/irdl.py
DIALECT_NAME irdl
GEN_ENUM_BINDINGS
)
diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/ext.py
similarity index 69%
rename from mlir/python/mlir/dialects/irdl/dsl.py
rename to mlir/python/mlir/dialects/ext.py
index 31a2a1bf98b14..18ab977d164de 100644
--- a/mlir/python/mlir/dialects/irdl/dsl.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -2,124 +2,62 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Dict, List, Union, Callable, Tuple, ClassVar
+from typing import (
+ Dict,
+ List,
+ Union,
+ Tuple,
+ Any,
+ TypeVar,
+ get_origin,
+ get_args,
+)
+from collections.abc import Sequence
from dataclasses import dataclass
from inspect import Parameter, Signature
-from types import SimpleNamespace
-from abc import ABC, abstractmethod
-from contextlib import nullcontext
-from ...dialects import irdl
-from .._ods_common import _cext, segmented_accessor
-from . import Variadicity
+from types import UnionType
+from . import irdl
+from ._ods_common import _cext, segmented_accessor
+from .irdl import Variadicity
ir = _cext.ir
__all__ = [
- "Variadicity",
- "Is",
- "AnyOf",
- "AllOf",
- "Any",
- "BaseName",
- "BaseRef",
- "Operand",
- "Result",
- "Attribute",
"Dialect",
]
-class ConstraintExpr(ABC):
- @abstractmethod
- def _lower(self, ctx: "ConstraintLoweringContext") -> ir.Value:
- pass
-
- def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
- return AnyOf(self, other)
-
- def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
- return AllOf(self, other)
-
-
class ConstraintLoweringContext:
def __init__(self):
# Cache so that the same ConstraintExpr instance reuses its SSA value.
self._cache: Dict[int, ir.Value] = {}
- def lower(self, expr: ConstraintExpr) -> ir.Value:
- key = id(expr)
+ def lower(self, type_) -> ir.Value:
+ key = id(type_)
if key in self._cache:
return self._cache[key]
- v = expr._lower(self)
+ v = self._lower(type_)
self._cache[key] = v
return v
-
-class Is(ConstraintExpr):
- def __init__(self, val: Callable[..., Union[ir.Attribute, ir.Type]]):
- self.val = val
- self.args = []
- self.kwargs = {}
-
- def __call__(self, *args, **kwargs) -> "Is":
- self.args.extend(args)
- self.kwargs.update(kwargs)
- return self
-
- def __class_getitem__(
- cls, val: Callable[..., Union[ir.Attribute, ir.Type]]
- ) -> "Is":
- return cls(val)
-
- def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- # for most attributes and types, they are created via `.get` method,
- # here we can just omit the `.get` suffix for convenience
- if isinstance(self.val, type) and hasattr(self.val, "get"):
- self.val = self.val.get
-
- val = self.val(*self.args, **self.kwargs)
-
- if isinstance(val, ir.Type):
- val = ir.TypeAttr.get(val)
-
- return irdl.is_(val)
-
-
-class AnyOf(ConstraintExpr):
- def __init__(self, *exprs: ConstraintExpr):
- self.exprs = exprs
-
- def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- return irdl.any_of(ctx.lower(expr) for expr in self.exprs)
-
-
-class AllOf(ConstraintExpr):
- def __init__(self, *exprs: ConstraintExpr):
- self.exprs = exprs
-
- def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- return irdl.all_of(ctx.lower(expr) for expr in self.exprs)
-
-
-class Any(ConstraintExpr):
- def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- return irdl.any()
-
-
-class BaseName(ConstraintExpr):
- def __init__(self, name: str):
- self.name = name
-
- def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- return irdl.base(base_name=self.name)
-
-
-class BaseRef(ConstraintExpr):
- def __init__(self, ref):
- self.ref = ref
-
- def _lower(self, ctx: ConstraintLoweringContext) -> ir.Value:
- return irdl.base(base_ref=self.ref)
+ def _lower(self, type_) -> ir.Value:
+ origin = get_origin(type_)
+ if origin and issubclass(origin, ir.Type):
+ t = origin.get(*get_args(type_))
+ return irdl.is_(ir.TypeAttr.get(t))
+ elif origin and issubclass(origin, ir.Attribute):
+ attr = origin.get(*get_args(type_))
+ return irdl.is_(attr)
+ elif origin is UnionType:
+ return irdl.any_of(self.lower(arg) for arg in get_args(type_))
+ elif type_ is Any or isinstance(type_, TypeVar):
+ return irdl.any()
+ elif issubclass(type_, ir.Type):
+ return irdl.base(base_name=f"!{type_.type_name}")
+ elif issubclass(type_, ir.Attribute):
+ return irdl.base(base_name=f"#{type_.attr_name}")
+
+ raise TypeError(f"unsupported type in constraints: {type_}")
class FieldDef:
@@ -127,29 +65,33 @@ class FieldDef:
@dataclass
-class Operand(FieldDef):
- constraint: ConstraintExpr
- variadicity: Variadicity = Variadicity.single
+class OperandDef(FieldDef):
+ constraint: Any
+ variadicity: Variadicity
@dataclass
-class Result(FieldDef):
- constraint: ConstraintExpr
- variadicity: Variadicity = Variadicity.single
+class ResultDef(FieldDef):
+ constraint: Any
+ variadicity: Variadicity
@dataclass
-class Attribute(FieldDef):
- constraint: ConstraintExpr
- variadicity: ClassVar[Variadicity] = Variadicity.single
+class AttributeDef(FieldDef):
+ constraint: Any
+ variadicity: Variadicity
+
+ def __post_init__(self):
+ if self.variadicity != Variadicity.single:
+ raise ValueError("optional attribute is not supported in IRDL")
def partition_fields(
fields: List[FieldDef],
-) -> Tuple[List[Operand], List[Attribute], List[Result]]:
- operands = [i for i in fields if isinstance(i, Operand)]
- attrs = [i for i in fields if isinstance(i, Attribute)]
- results = [i for i in fields if isinstance(i, Result)]
+) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
+ operands = [i for i in fields if isinstance(i, OperandDef)]
+ attrs = [i for i in fields if isinstance(i, AttributeDef)]
+ results = [i for i in fields if isinstance(i, ResultDef)]
return operands, attrs, results
@@ -165,6 +107,30 @@ def normalize_value_range(
class Operation(ir.OpView):
+ @staticmethod
+ def convert_type_to_field_def(type_) -> FieldDef:
+ variadicity = Variadicity.single
+ origin = get_origin(type_)
+ if (
+ origin is Union
+ and len(get_args(type_)) == 2
+ and get_args(type_)[1] is type(None)
+ ):
+ variadicity = Variadicity.optional
+ type_ = get_args(type_)[0]
+ elif origin is Sequence:
+ variadicity = Variadicity.variadic
+ type_ = get_args(type_)[0]
+
+ origin = get_origin(type_)
+ if origin is ir.OpOperand:
+ return OperandDef(get_args(type_)[0], variadicity)
+ elif origin is ir.OpResult:
+ return ResultDef(get_args(type_)[0], variadicity)
+ elif issubclass(origin or type_, ir.Attribute):
+ return AttributeDef(type_, variadicity)
+ raise TypeError(f"unsupported type in operation definition: {type_}")
+
@classmethod
def __init_subclass__(cls, *, name: str = None, **kwargs):
super().__init_subclass__(**kwargs)
@@ -175,10 +141,10 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
for base in cls.__bases__:
if hasattr(base, "_fields"):
fields.extend(base._fields)
- for key, value in cls.__dict__.items():
- if isinstance(value, FieldDef):
- setattr(value, "name", key)
- fields.append(value)
+ for key, value in cls.__annotations__.items():
+ field = Operation.convert_type_to_field_def(value)
+ setattr(field, "name", key)
+ fields.append(field)
# for subclasses without "name" parameter,
# just treat them as normal classes
@@ -209,7 +175,7 @@ def _variadicity_to_segment(variadicity: Variadicity) -> int:
@staticmethod
def _generate_segments(
- operands_or_results: List[Union[Operand, Result]],
+ operands_or_results: List[Union[OperandDef, ResultDef]],
) -> List[int]:
if any(i.variadicity != Variadicity.single for i in operands_or_results):
return [
@@ -222,8 +188,8 @@ def _generate_segments(
def _generate_init_signature(fields: List[FieldDef]) -> Signature:
# results are placed at the beginning of the parameter list,
# but operands and attributes can appear in any relative order.
- args = [i for i in fields if isinstance(i, Result)] + [
- i for i in fields if not isinstance(i, Result)
+ args = [i for i in fields if isinstance(i, ResultDef)] + [
+ i for i in fields if not isinstance(i, ResultDef)
]
positional_args = [
i.name for i in args if i.variadicity != Variadicity.optional
@@ -292,7 +258,7 @@ def _generate_class_attributes(
cls._ODS_RESULT_SEGMENTS = result_segments
@classmethod
- def _generate_attr_properties(cls, attrs: List[Attribute]) -> None:
+ def _generate_attr_properties(cls, attrs: List[AttributeDef]) -> None:
for attr in attrs:
setattr(
cls,
@@ -301,7 +267,7 @@ def _generate_attr_properties(cls, attrs: List[Attribute]) -> None:
)
@classmethod
- def _generate_operand_properties(cls, operands: List[Operand]) -> None:
+ def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
for i, operand in enumerate(operands):
if cls._ODS_OPERAND_SEGMENTS:
@@ -318,7 +284,7 @@ def getter(self, i=i, operand=operand):
setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
@classmethod
- def _generate_result_properties(cls, results: List[Result]) -> None:
+ def _generate_result_properties(cls, results: List[ResultDef]) -> None:
for i, result in enumerate(results):
if cls._ODS_RESULT_SEGMENTS:
@@ -393,6 +359,7 @@ def load(cls) -> None:
raise RuntimeError(f"Dialect {cls.name} is already loaded.")
mlir_module = cls._emit_module()
+ print(mlir_module)
irdl.load_dialects(mlir_module)
_cext.register_dialect(cls)
diff --git a/mlir/python/mlir/dialects/irdl/__init__.py b/mlir/python/mlir/dialects/irdl.py
similarity index 91%
rename from mlir/python/mlir/dialects/irdl/__init__.py
rename to mlir/python/mlir/dialects/irdl.py
index 6b2787ed7966c..1ec951b69b646 100644
--- a/mlir/python/mlir/dialects/irdl/__init__.py
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -2,14 +2,13 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from .._irdl_ops_gen import *
-from .._irdl_ops_gen import _Dialect
-from .._irdl_enum_gen import *
-from ..._mlir_libs._mlirDialectsIRDL import *
-from ...ir import register_attribute_builder
-from .._ods_common import _cext as _ods_cext
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import _cext as _ods_cext
from typing import Union, Sequence
-from . import dsl
_ods_ir = _ods_cext.ir
diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/ext.py
similarity index 84%
rename from mlir/test/python/dialects/irdsl.py
rename to mlir/test/python/dialects/ext.py
index e80bb412139b0..f903bb8be978c 100644
--- a/mlir/test/python/dialects/irdsl.py
+++ b/mlir/test/python/dialects/ext.py
@@ -1,33 +1,32 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.ir import *
-from mlir.dialects.irdl import dsl as irdsl
-from mlir.dialects import arith
+from mlir.dialects import ext, arith
+from typing import Any, Optional, Sequence
import sys
def run(f):
- print("\nTEST:", f.__name__, file=sys.stderr)
+ print("\nTEST:", f.__name__)
f()
# CHECK: TEST: testMyInt
@run
def testMyInt():
- class MyInt(irdsl.Dialect, name="myint"):
+ class MyInt(ext.Dialect, name="myint"):
pass
- iattr = irdsl.BaseName("#builtin.integer")
- i32 = irdsl.Is[IntegerType.get_signless](32)
+ i32 = IntegerType[32]
class ConstantOp(MyInt.Operation, name="constant"):
- value = irdsl.Attribute(iattr)
- cst = irdsl.Result(i32)
+ value: IntegerAttr
+ cst: OpResult[i32]
class AddOp(MyInt.Operation, name="add"):
- lhs = irdsl.Operand(i32)
- rhs = irdsl.Operand(i32)
- res = irdsl.Result(i32)
+ lhs: OpOperand[i32]
+ rhs: OpOperand[i32]
+ res: OpResult[i32]
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
@@ -90,60 +89,55 @@ class AddOp(MyInt.Operation, name="add"):
print(ConstantOp.__init__.__signature__)
+# CHECK: TEST: testIRDSL
@run
def testIRDSL():
- class Test(irdsl.Dialect, name="irdsl_test"):
+ class Test(ext.Dialect, name="irdsl_test"):
pass
- i32 = irdsl.Is[IntegerType.get_signless](32)
- i64 = irdsl.Is[IntegerType.get_signless](64)
- i32or64 = i32 | i64
- any = irdsl.Any()
- f32 = irdsl.Is[F32Type]
- iattr = irdsl.BaseName("#builtin.integer")
- fattr = irdsl.BaseName("#builtin.float")
+ i32 = IntegerType[32]
class ConstraintOp(Test.Operation, name="constraint"):
- a = irdsl.Operand(i32or64)
- b = irdsl.Operand(any)
- c = irdsl.Operand(f32 | i32)
- d = irdsl.Operand(any)
- x = irdsl.Attribute(iattr)
- y = irdsl.Attribute(fattr)
+ a: OpOperand[i32 | IntegerType[64]]
+ b: OpOperand[Any]
+ c: OpOperand[F32Type[()] | i32]
+ d: OpOperand[Any]
+ x: IntegerAttr
+ y: FloatAttr
class OptionalOp(Test.Operation, name="optional"):
- a = irdsl.Operand(i32)
- b = irdsl.Operand(i32, irdsl.Variadicity.optional)
- out1 = irdsl.Result(i32)
- out2 = irdsl.Result(i32, irdsl.Variadicity.optional)
- out3 = irdsl.Result(i32)
+ a: OpOperand[i32]
+ b: Optional[OpOperand[i32]]
+ out1: OpResult[i32]
+ out2: Optional[OpResult[i32]]
+ out3: OpResult[i32]
class Optional2Op(Test.Operation, name="optional2"):
- a = irdsl.Operand(i32, irdsl.Variadicity.optional)
- b = irdsl.Result(i32, irdsl.Variadicity.optional)
+ a: Optional[OpOperand[i32]]
+ b: Optional[OpResult[i32]]
class VariadicOp(Test.Operation, name="variadic"):
- a = irdsl.Operand(i32)
- b = irdsl.Operand(i32, irdsl.Variadicity.optional)
- c = irdsl.Operand(i32, irdsl.Variadicity.variadic)
- out1 = irdsl.Result(i32, irdsl.Variadicity.variadic)
- out2 = irdsl.Result(i32, irdsl.Variadicity.variadic)
- out3 = irdsl.Result(i32, irdsl.Variadicity.optional)
- out4 = irdsl.Result(i32)
+ a: OpOperand[i32]
+ b: Optional[OpOperand[i32]]
+ c: Sequence[OpOperand[i32]]
+ out1: Sequence[OpResult[i32]]
+ out2: Sequence[OpResult[i32]]
+ out3: Optional[OpResult[i32]]
+ out4: OpResult[i32]
class Variadic2Op(Test.Operation, name="variadic2"):
- a = irdsl.Operand(i32, irdsl.Variadicity.variadic)
- b = irdsl.Result(i32, irdsl.Variadicity.variadic)
+ a: Sequence[OpOperand[i32]]
+ b: Sequence[OpResult[i32]]
class MixedOpBase(Test.Operation):
- out = irdsl.Result(i32)
- in1 = irdsl.Operand(i32)
+ out: OpResult[i32]
+ in1: OpOperand[i32]
class MixedOp(MixedOpBase, name="mixed"):
- in2 = irdsl.Attribute(iattr)
- in3 = irdsl.Operand(i32, irdsl.Variadicity.optional)
- in4 = irdsl.Attribute(iattr)
- in5 = irdsl.Operand(i32)
+ in2: IntegerAttr
+ in3: Optional[OpOperand[i32]]
+ in4: IntegerAttr
+ in5: OpOperand[i32]
# CHECK: irdl.dialect @irdsl_test {
# CHECK: irdl.operation @constraint {
>From 7e7b5187904a0733877e729a5bcd31961d9762b8 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 9 Jan 2026 21:41:01 +0800
Subject: [PATCH 20/20] add match_optional
---
mlir/python/mlir/dialects/ext.py | 26 +++++++++++++++++---------
1 file changed, 17 insertions(+), 9 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 18ab977d164de..14b25119946d2 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -48,7 +48,7 @@ def _lower(self, type_) -> ir.Value:
elif origin and issubclass(origin, ir.Attribute):
attr = origin.get(*get_args(type_))
return irdl.is_(attr)
- elif origin is UnionType:
+ elif origin is UnionType or origin is Union:
return irdl.any_of(self.lower(arg) for arg in get_args(type_))
elif type_ is Any or isinstance(type_, TypeVar):
return irdl.any()
@@ -106,19 +106,27 @@ def normalize_value_range(
return value_range
+def match_optional(type_):
+ origin = get_origin(type_)
+ args = get_args(type_)
+ if (
+ (origin is Union or origin is UnionType)
+ and len(args) == 2
+ and type(None) in args
+ ):
+ return args[0] if args[1] is type(None) else args[1]
+
+ return None
+
+
class Operation(ir.OpView):
@staticmethod
def convert_type_to_field_def(type_) -> FieldDef:
variadicity = Variadicity.single
- origin = get_origin(type_)
- if (
- origin is Union
- and len(get_args(type_)) == 2
- and get_args(type_)[1] is type(None)
- ):
+ if inner := match_optional(type_):
variadicity = Variadicity.optional
- type_ = get_args(type_)[0]
- elif origin is Sequence:
+ type_ = inner
+ elif get_origin(type_) is Sequence:
variadicity = Variadicity.variadic
type_ = get_args(type_)[0]
More information about the Mlir-commits
mailing list