[Mlir-commits] [mlir] [MLIR][Python] Support attribute definitions in Python-defined dialects (PR #183907)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 28 04:31:53 PST 2026


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/183907

This PR is quite similar to https://github.com/llvm/llvm-project/pull/182805.

This PR adds basic support for attribute definitions in Python-defined dialects, including:

- IRDL codegen for attribute definitions
- Attribute builders like `MyAttr.get(...)` and attribute parameter accessors (e.g. `my_attr.param1`)
- Using Python-defined attributes in Python-defined operations

>From f948a8db7ac6bec98752d56511fc633ab2fd17e1 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sat, 28 Feb 2026 00:21:26 +0800
Subject: [PATCH] [MLIR][Python] Support attribute definitions in
 Python-defined dialects

---
 mlir/python/mlir/dialects/ext.py | 103 +++++++++++++++++++++++++++++++
 mlir/test/python/dialects/ext.py |  77 +++++++++++++++++++++++
 2 files changed, 180 insertions(+)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index d88e25cced8f6..c601eb29d4b10 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -32,6 +32,7 @@
     "Result",
     "Region",
     "Type",
+    "Attribute",
     "register_dialect",
     "register_operation",
 ]
@@ -103,6 +104,11 @@ def _lower(self, type_) -> ir.Value:
             t = construct_instance(origin, get_args(type_))
             return irdl.is_(ir.TypeAttr.get(t))
         elif origin and issubclass(origin, ir.Attribute):
+            if issubclass(origin, Attribute):  # ext.Attribute, not ir.Attribute
+                return irdl.parametric(  # by IRDL symbol, one constraint per arg
+                    base_type=[origin._dialect_name, origin._name],
+                    args=[self.lower(arg) for arg in get_args(type_)],
+                )
             attr = construct_instance(origin, get_args(type_))
             return irdl.is_(attr)
         elif issubclass(type_, ir.Type):
@@ -110,6 +116,8 @@ def _lower(self, type_) -> ir.Value:
                 return irdl.base(base_ref=[type_._dialect_name, type_._name])
             return irdl.base(base_name=f"!{type_.type_name}")
         elif issubclass(type_, ir.Attribute):
+            if issubclass(type_, Attribute):  # ext.Attribute, not ir.Attribute
+                return irdl.base(base_ref=[type_._dialect_name, type_._name])
             return irdl.base(base_name=f"#{type_.attr_name}")
 
         raise TypeError(f"unsupported type in constraints: {type_}")
@@ -647,6 +655,92 @@ def _emit_type(cls) -> None:
             )
 
 
+class Attribute(ir.DynamicAttr):
+    """
+    Base class of Python-defined attributes.
+
+    Each annotated field becomes a parameter with a read-only accessor. Two forms:
+    ```python
+    class MyAttr(MyDialect.Attribute, name="my_attr"):
+      param1: IntegerAttr
+
+    class MyAttr(Attribute, dialect=MyDialect, name="my_attr"):
+      param1: IntegerAttr
+    ```
+    """
+
+    @classmethod
+    def __init_subclass__(
+        cls,
+        *,
+        name: str | None = None,
+        dialect: type | None = None,
+        **kwargs,
+    ):
+        super().__init_subclass__(**kwargs)
+
+        fields = []
+
+        for base in cls.__bases__:  # inherit fields of parent classes first
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__annotations__.items():  # then own fields
+            field = ParamDef(key, value)
+            fields.append(field)
+
+        cls._fields = fields
+
+        if dialect:
+            if hasattr(cls, "_dialect_obj"):  # already bound by a base class
+                raise RuntimeError(
+                    f"This attribute has already been attached to dialect '{cls._dialect_obj.DIALECT_NAMESPACE}'."
+                )
+            cls._dialect_obj = dialect
+
+        # Subclasses without a "name" keyword (e.g. MyDialect.Attribute) are
+        # plain bases: they collect fields/dialect but are not registered.
+        if not name:
+            return
+
+        if not hasattr(cls, "_dialect_obj"):
+            raise RuntimeError(
+                "Attribute subclasses must either inherit from a Dialect's Attribute subclass "
+                "or provide the dialect as a class keyword argument."
+            )
+
+        cls._name = name
+        cls._dialect_name = cls._dialect_obj.DIALECT_NAMESPACE
+        cls.attr_name = f"{cls._dialect_name}.{name}"  # e.g. "ext_attr.pair"
+
+        for i, field in enumerate(cls._fields):  # field i reads params[i]
+            setattr(
+                cls,
+                field.name,  # NOTE(review): a field named "params" would recurse
+                property(lambda self, i=i: self.params[i]),  # i=i: bind per loop
+            )
+
+        cls._dialect_obj.attributes.append(cls)  # emitted by _emit_dialect
+
+    @classmethod
+    def get(cls, *args, context=None):
+        args = [  # wrap ir.Type arguments as TypeAttr; others pass through
+            ir.TypeAttr.get(arg, context) if isinstance(arg, ir.Type) else arg
+            for arg in args
+        ]
+        return cls(ir.DynamicAttr.get(cls.attr_name, args, context=context))
+
+    @classmethod
+    def _emit_attr(cls) -> None:
+        ctx = ConstraintLoweringContext()
+
+        t = irdl.attribute(cls._name)  # irdl.attribute op for this class
+        with ir.InsertionPoint(t.body):
+            irdl.parameters(  # one (constraint, name) pair per field
+                [ctx.lower(f.constraint) for f in cls._fields],
+                [f.name for f in cls._fields],
+            )
+
+
 class Dialect(ir.Dialect):
     """
     Base class of a Python-defined dialect.
@@ -687,6 +781,13 @@ def __init_subclass__(cls, name: str, **kwargs):
             dict(),
             dialect=cls,
         )
+        cls.attributes = []  # named Attribute subclasses register here
+        cls.Attribute = type(  # unnamed base bound to this dialect
+            "Attribute",
+            (Attribute,),
+            dict(),
+            dialect=cls,
+        )
 
     @classmethod
     def _emit_dialect(cls) -> None:
@@ -694,6 +795,8 @@ def _emit_dialect(cls) -> None:
         with ir.InsertionPoint(d.body):
             for type_ in cls.types:
                 type_._emit_type()
+            for attr in cls.attributes:  # after types, before operations
+                attr._emit_attr()
             for op in cls.operations:
                 op._emit_operation()
 
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 2921615e75d54..01d6aa92c99c0 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -605,3 +605,80 @@ class MakeArray3Op(TestType.Operation, name="make_array3"):
         # CHECK: %2 = "ext_type.make_array3"() : () -> !ext_type.array<i32, 3 : i32>
         assert module.operation.verify()
         print(module)
+
+
+# CHECK: TEST: testExtDialectWithAttr
+@run
+def testExtDialectWithAttr():
+    class TestAttr(Dialect, name="ext_attr"):
+        pass
+
+    class IntPair(TestAttr.Attribute, name="pair"):
+        first: IntegerAttr
+        second: IntegerAttr
+
+    class StrPair(TestAttr.Attribute, name="str_pair"):
+        first: StringAttr
+        second: StringAttr
+
+    class Op1(TestAttr.Operation, name="op1"):
+        pair: IntPair
+
+    class Op2(TestAttr.Operation, name="op2"):
+        pair: StrPair
+        pair2: StrPair[StringAttr["a"], StringAttr["b"]]
+
+    with Context(), Location.unknown():
+        TestAttr.load()
+        # CHECK: irdl.dialect @ext_attr {
+        # CHECK:   irdl.attribute @pair {
+        # CHECK:     %0 = irdl.base "#builtin.integer"
+        # CHECK:     %1 = irdl.base "#builtin.integer"
+        # CHECK:     irdl.parameters(first: %0, second: %1)
+        # CHECK:   }
+        # CHECK:   irdl.attribute @str_pair {
+        # CHECK:     %0 = irdl.base "#builtin.string"
+        # CHECK:     %1 = irdl.base "#builtin.string"
+        # CHECK:     irdl.parameters(first: %0, second: %1)
+        # CHECK:   }
+        # CHECK:   irdl.operation @op1 {
+        # CHECK:     %0 = irdl.base @ext_attr::@pair
+        # CHECK:     irdl.attributes {"pair" = %0}
+        # CHECK:   }
+        # CHECK:   irdl.operation @op2 {
+        # CHECK:     %0 = irdl.base @ext_attr::@str_pair
+        # CHECK:     %1 = irdl.is "a"
+        # CHECK:     %2 = irdl.is "b"
+        # CHECK:     %3 = irdl.parametric @ext_attr::@str_pair<%1, %2>
+        # CHECK:     irdl.attributes {"pair" = %0, "pair2" = %3}
+        # CHECK:   }
+        # CHECK: }
+        print(TestAttr._mlir_module)
+
+        # CHECK: ext_attr.pair
+        print(IntPair.attr_name)
+
+        # CHECK: ext_attr.str_pair
+        print(StrPair.attr_name)
+
+        ip = IntPair.get(
+            IntegerAttr.get(IntegerType.get_signless(32), 1),
+            IntegerAttr.get(IntegerType.get_signless(32), 2),
+        )
+        sp = StrPair.get(StringAttr.get("hello"), StringAttr.get("world"))
+        # CHECK: #ext_attr.pair<1 : i32, 2 : i32>
+        print(ip)
+        # CHECK: #ext_attr.str_pair<"hello", "world">
+        print(sp)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            Op1(ip)
+            p2 = StrPair.get(StringAttr.get("a"), StringAttr.get("b"))
+            Op2(sp, p2)
+
+        assert module.operation.verify()
+
+        # CHECK: "ext_attr.op1"() {pair = #ext_attr.pair<1 : i32, 2 : i32>} : () -> ()
+        # CHECK: "ext_attr.op2"() {pair = #ext_attr.str_pair<"hello", "world">, pair2 = #ext_attr.str_pair<"a", "b">} : () -> ()
+        print(module)



More information about the Mlir-commits mailing list