[Mlir-commits] [mlir] [MLIR][Python] Support type definitions in Python-defined dialects (PR #182805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 22 21:38:25 PST 2026


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/182805

In this PR, we added basic support for type definitions in Python-defined dialects, including:
- IRDL codegen for type definitions
- Type builders like `MyType.get(...)` and type parameter accessors (e.g. `my_type.param1`)
- Use of Python-defined types in Python-defined operations

```python
class TestType(Dialect, name="ext_type"):
    pass

class Array(TestType.Type, name="array"):
    elem_type: IntegerType[32] | IntegerType[64]
    length: IntegerAttr

class MakeArrayOp(TestType.Operation, name="make_array"):
    arr: Result[Array]

class MakeArray3Op(TestType.Operation, name="make_array3"):
    arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
```

>From 625f86f03e5cdb89b63edddb852603b3623f9012 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 23 Feb 2026 13:34:46 +0800
Subject: [PATCH] [MLIR][Python] Support type definitions in Python-defined
 dialects

---
 mlir/python/mlir/dialects/ext.py | 117 +++++++++++++++++++++++++++++--
 mlir/test/python/dialects/ext.py |  72 +++++++++++++++++++
 2 files changed, 183 insertions(+), 6 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 79095658944e5..165ac9f119d7b 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -43,6 +43,20 @@
 register_operation = _cext.register_operation
 
 
+def construct_instance(origin, args):
+    # `origin.get` is to construct an instance of MLIR type or attribute.
+    return origin.get(
+        *(
+            (
+                construct_instance(get_origin(arg), get_args(arg))
+                if get_origin(arg)
+                else arg
+            )
+            for arg in args
+        )
+    )
+
+
 class ConstraintLoweringContext:
     def __init__(self):
         self._cache: Dict[str, ir.Value] = {}
@@ -70,14 +84,19 @@ def _lower(self, type_) -> ir.Value:
         elif isinstance(type_, TypeVar):
             return self.lower(type_)
         elif origin and issubclass(origin, ir.Type):
-            # `origin.get` is to construct an instance of MLIR type.
-            t = origin.get(*get_args(type_))
+            if issubclass(origin, Type):
+                return irdl.parametric(
+                    base_type=[origin._dialect_name, origin._name],
+                    args=[self.lower(arg) for arg in get_args(type_)],
+                )
+            t = construct_instance(origin, get_args(type_))
             return irdl.is_(ir.TypeAttr.get(t))
         elif origin and issubclass(origin, ir.Attribute):
-            # `origin.get` is to construct an instance of MLIR attribute.
-            attr = origin.get(*get_args(type_))
+            attr = construct_instance(origin, get_args(type_))
             return irdl.is_(attr)
         elif issubclass(type_, ir.Type):
+            if issubclass(type_, Type):
+                return irdl.base(base_ref=[type_._dialect_name, type_._name])
             return irdl.base(base_name=f"!{type_.type_name}")
         elif issubclass(type_, ir.Attribute):
             return irdl.base(base_name=f"#{type_.attr_name}")
@@ -96,8 +115,7 @@ def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
 
     origin = get_origin(type_)
     if origin and issubclass(origin, ir.Type):
-        # `origin.get` is to construct an instance of MLIR type/attribute.
-        return lambda: origin.get(*get_args(type_))
+        return lambda: construct_instance(origin, get_args(type_))
     elif isinstance(type_, TypeVar):
         return infer_type(type_.__bound__)
     return None
@@ -488,6 +506,84 @@ def _emit_operation(cls) -> None:
                 )
 
 
+@dataclass
+class ParamDef:
+    name: str
+    constraint: Any
+
+
+class Type(ir.DynamicType):
+    @classmethod
+    def __init_subclass__(
+        cls,
+        *,
+        name: str | None = None,
+        dialect: type | None = None,
+        **kwargs,
+    ):
+        super().__init_subclass__(**kwargs)
+
+        fields = []
+
+        for base in cls.__bases__:
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__annotations__.items():
+            field = ParamDef(key, value)
+            fields.append(field)
+
+        cls._fields = fields
+
+        if dialect:
+            if hasattr(cls, "_dialect_obj"):
+                raise RuntimeError(
+                    f"This type has already been attached to dialect '{cls._dialect_obj.DIALECT_NAMESPACE}'."
+                )
+            cls._dialect_obj = dialect
+
+        # for subclasses without "name" parameter,
+        # just treat them as normal classes
+        if not name:
+            return
+
+        if not hasattr(cls, "_dialect_obj"):
+            raise RuntimeError(
+                "Type subclasses must either inherit from a Dialect's Type subclass "
+                "or provide the dialect as a class keyword argument."
+            )
+
+        cls._name = name
+        cls._dialect_name = cls._dialect_obj.DIALECT_NAMESPACE
+        cls.type_name = f"{cls._dialect_name}.{name}"
+
+        for i, field in enumerate(cls._fields):
+            setattr(
+                cls,
+                field.name,
+                property(lambda self, i=i: self.params[i]),
+            )
+
+        cls._dialect_obj.types.append(cls)
+
+    @classmethod
+    def get(cls, *args, context=None):
+        args = [
+            ir.TypeAttr.get(arg) if isinstance(arg, ir.Type) else arg for arg in args
+        ]
+        return cls(ir.DynamicType.get(cls.type_name, args, context=context))
+
+    @classmethod
+    def _emit_type(cls) -> None:
+        ctx = ConstraintLoweringContext()
+
+        t = irdl.type_(cls._name)
+        with ir.InsertionPoint(t.body):
+            irdl.parameters(
+                [ctx.lower(f.constraint) for f in cls._fields],
+                [f.name for f in cls._fields],
+            )
+
+
 class Dialect(ir.Dialect):
     """
     Base class of a Python-defined dialect.
@@ -521,11 +617,20 @@ def __init_subclass__(cls, name: str, **kwargs):
             dict(),
             dialect=cls,
         )
+        cls.types = []
+        cls.Type = type(
+            "Type",
+            (Type,),
+            dict(),
+            dialect=cls,
+        )
 
     @classmethod
     def _emit_dialect(cls) -> None:
         d = irdl.dialect(cls.name)
         with ir.InsertionPoint(d.body):
+            for type_ in cls.types:
+                type_._emit_type()
             for op in cls.operations:
                 op._emit_operation()
 
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 196af91e511ec..78fe188b0ee58 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -519,3 +519,75 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
             # CHECK: Verification failed:
             # CHECK: result type mismatch
             print(e)
+
+
+# CHECK: TEST: testExtDialectWithType
+@run
+def testExtDialectWithType():
+    class TestType(Dialect, name="ext_type"):
+        pass
+
+    class Array(TestType.Type, name="array"):
+        elem_type: IntegerType[32] | IntegerType[64]
+        length: IntegerAttr
+
+    class MakeArrayOp(TestType.Operation, name="make_array"):
+        arr: Result[Array]
+
+    class MakeArray3Op(TestType.Operation, name="make_array3"):
+        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+
+    with Context(), Location.unknown():
+        TestType.load()
+        # CHECK: irdl.dialect @ext_type {
+        # CHECK:   irdl.type @array {
+        # CHECK:     %0 = irdl.is i32
+        # CHECK:     %1 = irdl.is i64
+        # CHECK:     %2 = irdl.any_of(%0, %1)
+        # CHECK:     %3 = irdl.base "#builtin.integer"
+        # CHECK:     irdl.parameters(elem_type: %2, length: %3)
+        # CHECK:   }
+        # CHECK:   irdl.operation @make_array {
+        # CHECK:     %0 = irdl.base @ext_type::@array
+        # CHECK:     irdl.results(arr: %0)
+        # CHECK:   }
+        # CHECK:   irdl.operation @make_array3 {
+        # CHECK:     %0 = irdl.is i32
+        # CHECK:     %1 = irdl.is 3 : i32
+        # CHECK:     %2 = irdl.parametric @ext_type::@array<%0, %1>
+        # CHECK:     irdl.results(arr: %2)
+        # CHECK:   }
+        # CHECK: }
+        print(TestType._mlir_module)
+
+        # CHECK: ext_type.array
+        print(Array.type_name)
+
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        a4 = Array.get(i32, IntegerAttr.get(i32, 4))
+        a6 = Array.get(i64, IntegerAttr.get(i32, 6))
+        # CHECK: !ext_type.array<i32, 4 : i32>
+        print(a4)
+        # CHECK: !ext_type.array<i64, 6 : i32>
+        print(a6)
+
+        # CHECK: i32
+        print(a4.elem_type)
+        # CHECK: 4 : i32
+        print(a4.length)
+        # CHECK: i64
+        print(a6.elem_type)
+        # CHECK: 6 : i32
+        print(a6.length)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            MakeArrayOp(a4)
+            MakeArrayOp(a6)
+            MakeArray3Op()
+
+        # CHECK: %0 = "ext_type.make_array"() : () -> !ext_type.array<i32, 4 : i32>
+        # CHECK: %1 = "ext_type.make_array"() : () -> !ext_type.array<i64, 6 : i32>
+        # CHECK: %2 = "ext_type.make_array3"() : () -> !ext_type.array<i32, 3 : i32>
+        print(module)



More information about the Mlir-commits mailing list