[Mlir-commits] [mlir] [MLIR][Python] Support attribute definitions in Python-defined dialects (PR #183907)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Feb 28 04:31:53 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/183907
This PR is quite similar to https://github.com/llvm/llvm-project/pull/182805.
It adds basic support for attribute definitions in Python-defined dialects, including:
- IRDL codegen for attribute definitions
- Attribute builders like `MyAttr.get(...)` and attribute parameter accessors (e.g. `my_attr.param1`)
- Using Python-defined attributes in Python-defined operations
>From f948a8db7ac6bec98752d56511fc633ab2fd17e1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 28 Feb 2026 00:21:26 +0800
Subject: [PATCH] [MLIR][Python] Support attribute definitions in
Python-defined dialects
---
mlir/python/mlir/dialects/ext.py | 103 +++++++++++++++++++++++++++++++
mlir/test/python/dialects/ext.py | 77 +++++++++++++++++++++++
2 files changed, 180 insertions(+)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index d88e25cced8f6..c601eb29d4b10 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -32,6 +32,7 @@
"Result",
"Region",
"Type",
+ "Attribute",
"register_dialect",
"register_operation",
]
@@ -103,6 +104,11 @@ def _lower(self, type_) -> ir.Value:
t = construct_instance(origin, get_args(type_))
return irdl.is_(ir.TypeAttr.get(t))
elif origin and issubclass(origin, ir.Attribute):
+ if issubclass(origin, Attribute):
+ return irdl.parametric(
+ base_type=[origin._dialect_name, origin._name],
+ args=[self.lower(arg) for arg in get_args(type_)],
+ )
attr = construct_instance(origin, get_args(type_))
return irdl.is_(attr)
elif issubclass(type_, ir.Type):
@@ -110,6 +116,8 @@ def _lower(self, type_) -> ir.Value:
return irdl.base(base_ref=[type_._dialect_name, type_._name])
return irdl.base(base_name=f"!{type_.type_name}")
elif issubclass(type_, ir.Attribute):
+ if issubclass(type_, Attribute):
+ return irdl.base(base_ref=[type_._dialect_name, type_._name])
return irdl.base(base_name=f"#{type_.attr_name}")
raise TypeError(f"unsupported type in constraints: {type_}")
@@ -647,6 +655,92 @@ def _emit_type(cls) -> None:
)
+class Attribute(ir.DynamicAttr):
+    """
+    Base class of Python-defined attributes.
+
+    The following example shows two ways to define attributes via this class:
+    ```python
+    class MyAttr(MyDialect.Attribute, name=..):
+        ...
+
+    class MyAttr(Attribute, dialect=MyDialect, name=..):
+        ...
+    ```
+
+    Each class-level annotation of a subclass becomes an attribute parameter;
+    parameters declared on base classes come first. Subclasses that pass
+    ``name`` are registered with their dialect and get one read-only
+    property per parameter.
+    """
+
+    @classmethod
+    def __init_subclass__(
+        cls,
+        *,
+        name: str | None = None,
+        dialect: type | None = None,
+        **kwargs,
+    ):
+        """
+        Collect parameter definitions and register named subclasses.
+
+        Args:
+            name: Attribute name within the dialect. When omitted, the
+                subclass is treated as an ordinary (unregistered) base class.
+            dialect: Dialect class to attach this attribute to; only needed
+                when not inheriting from a dialect's ``Attribute`` subclass.
+
+        Raises:
+            RuntimeError: if ``dialect`` is given but a dialect is already
+                attached (including one inherited via ``_dialect_obj``), or
+                if ``name`` is given but no dialect is attached.
+        """
+        super().__init_subclass__(**kwargs)
+
+        # Parameter order: every base's fields first, then this class's own
+        # annotations. This order defines the index into `self.params` used
+        # by the accessors installed below.
+        fields = []
+
+        for base in cls.__bases__:
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__annotations__.items():
+            field = ParamDef(key, value)
+            fields.append(field)
+
+        cls._fields = fields
+
+        if dialect:
+            # `_dialect_obj` may already be inherited from a dialect's
+            # `Attribute` base; attaching a second dialect is an error.
+            if hasattr(cls, "_dialect_obj"):
+                raise RuntimeError(
+                    f"This attribute has already been attached to dialect '{cls._dialect_obj.DIALECT_NAMESPACE}'."
+                )
+            cls._dialect_obj = dialect
+
+        # for subclasses without "name" parameter,
+        # just treat them as normal classes
+        if not name:
+            return
+
+        if not hasattr(cls, "_dialect_obj"):
+            raise RuntimeError(
+                "Attribute subclasses must either inherit from a Dialect's Attribute subclass "
+                "or provide the dialect as a class keyword argument."
+            )
+
+        cls._name = name
+        cls._dialect_name = cls._dialect_obj.DIALECT_NAMESPACE
+        # Fully-qualified name, e.g. "ext_attr.pair".
+        cls.attr_name = f"{cls._dialect_name}.{name}"
+
+        # One read-only accessor per parameter; `i=i` binds the current
+        # index at definition time (avoids the late-binding closure pitfall).
+        for i, field in enumerate(cls._fields):
+            setattr(
+                cls,
+                field.name,
+                property(lambda self, i=i: self.params[i]),
+            )
+
+        # Registered attributes are emitted later: `Dialect._emit_dialect`
+        # calls `_emit_attr` on every entry of `attributes`.
+        cls._dialect_obj.attributes.append(cls)
+
+    @classmethod
+    def get(cls, *args, context=None):
+        """
+        Create an instance of this attribute from its parameter values.
+
+        ``ir.Type`` arguments are wrapped in ``ir.TypeAttr`` first; other
+        arguments are passed through unchanged. The accessors read values
+        back by position (``self.params[i]``), so ``args`` are expected in
+        field declaration order.
+        """
+        args = [
+            ir.TypeAttr.get(arg, context) if isinstance(arg, ir.Type) else arg
+            for arg in args
+        ]
+        return cls(ir.DynamicAttr.get(cls.attr_name, args, context=context))
+
+    @classmethod
+    def _emit_attr(cls) -> None:
+        """
+        Emit the ``irdl.attribute`` definition for this attribute.
+
+        Expected to be called with an insertion point inside the dialect's
+        ``irdl.dialect`` body, as ``Dialect._emit_dialect`` does.
+        """
+        ctx = ConstraintLoweringContext()
+
+        t = irdl.attribute(cls._name)
+        with ir.InsertionPoint(t.body):
+            # One lowered constraint per field, paired with the field name.
+            irdl.parameters(
+                [ctx.lower(f.constraint) for f in cls._fields],
+                [f.name for f in cls._fields],
+            )
+
+
class Dialect(ir.Dialect):
"""
Base class of a Python-defined dialect.
@@ -687,6 +781,13 @@ def __init_subclass__(cls, name: str, **kwargs):
dict(),
dialect=cls,
)
+ cls.attributes = []
+ cls.Attribute = type(
+ "Attribute",
+ (Attribute,),
+ dict(),
+ dialect=cls,
+ )
@classmethod
def _emit_dialect(cls) -> None:
@@ -694,6 +795,8 @@ def _emit_dialect(cls) -> None:
with ir.InsertionPoint(d.body):
for type_ in cls.types:
type_._emit_type()
+ for attr in cls.attributes:
+ attr._emit_attr()
for op in cls.operations:
op._emit_operation()
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 2921615e75d54..01d6aa92c99c0 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -605,3 +605,80 @@ class MakeArray3Op(TestType.Operation, name="make_array3"):
# CHECK: %2 = "ext_type.make_array3"() : () -> !ext_type.array<i32, 3 : i32>
assert module.operation.verify()
print(module)
+
+
+# CHECK: TEST: testExtDialectWithAttr
+@run
+def testExtDialectWithAttr():
+    # Python-defined attributes: IRDL emission, construction via `get`,
+    # parameter accessors, and use as operation attributes.
+    class TestAttr(Dialect, name="ext_attr"):
+        pass
+
+    class IntPair(TestAttr.Attribute, name="pair"):
+        first: IntegerAttr
+        second: IntegerAttr
+
+    class StrPair(TestAttr.Attribute, name="str_pair"):
+        first: StringAttr
+        second: StringAttr
+
+    class Op1(TestAttr.Operation, name="op1"):
+        pair: IntPair
+
+    class Op2(TestAttr.Operation, name="op2"):
+        pair: StrPair
+        pair2: StrPair[StringAttr["a"], StringAttr["b"]]
+
+    with Context(), Location.unknown():
+        TestAttr.load()
+        # CHECK: irdl.dialect @ext_attr {
+        # CHECK: irdl.attribute @pair {
+        # CHECK: %0 = irdl.base "#builtin.integer"
+        # CHECK: %1 = irdl.base "#builtin.integer"
+        # CHECK: irdl.parameters(first: %0, second: %1)
+        # CHECK: }
+        # CHECK: irdl.attribute @str_pair {
+        # CHECK: %0 = irdl.base "#builtin.string"
+        # CHECK: %1 = irdl.base "#builtin.string"
+        # CHECK: irdl.parameters(first: %0, second: %1)
+        # CHECK: }
+        # CHECK: irdl.operation @op1 {
+        # CHECK: %0 = irdl.base @ext_attr::@pair
+        # CHECK: irdl.attributes {"pair" = %0}
+        # CHECK: }
+        # CHECK: irdl.operation @op2 {
+        # CHECK: %0 = irdl.base @ext_attr::@str_pair
+        # CHECK: %1 = irdl.is "a"
+        # CHECK: %2 = irdl.is "b"
+        # CHECK: %3 = irdl.parametric @ext_attr::@str_pair<%1, %2>
+        # CHECK: irdl.attributes {"pair" = %0, "pair2" = %3}
+        # CHECK: }
+        # CHECK: }
+        print(TestAttr._mlir_module)
+
+        # CHECK: ext_attr.pair
+        print(IntPair.attr_name)
+
+        # CHECK: ext_attr.str_pair
+        print(StrPair.attr_name)
+
+        ip = IntPair.get(
+            IntegerAttr.get(IntegerType.get_signless(32), 1),
+            IntegerAttr.get(IntegerType.get_signless(32), 2),
+        )
+        sp = StrPair.get(StringAttr.get("hello"), StringAttr.get("world"))
+        # CHECK: #ext_attr.pair<1 : i32, 2 : i32>
+        print(ip)
+        # CHECK: #ext_attr.str_pair<"hello", "world">
+        print(sp)
+
+        # Parameter accessors must read back the values passed to `get`,
+        # in declaration order.
+        assert ip.first == IntegerAttr.get(IntegerType.get_signless(32), 1)
+        assert ip.second == IntegerAttr.get(IntegerType.get_signless(32), 2)
+        assert sp.first == StringAttr.get("hello")
+        assert sp.second == StringAttr.get("world")
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            Op1(ip)
+            p2 = StrPair.get(StringAttr.get("a"), StringAttr.get("b"))
+            Op2(sp, p2)
+
+        assert module.operation.verify()
+
+        # CHECK: "ext_attr.op1"() {pair = #ext_attr.pair<1 : i32, 2 : i32>} : () -> ()
+        # CHECK: "ext_attr.op2"() {pair = #ext_attr.str_pair<"hello", "world">, pair2 = #ext_attr.str_pair<"a", "b">} : () -> ()
+        print(module)
More information about the Mlir-commits
mailing list