[Mlir-commits] [mlir] [MLIR][Python] Support type definitions in Python-defined dialects (PR #182805)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 22 21:38:25 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/182805
In this PR, we added basic support for type definitions in Python-defined dialects, including:
- IRDL codegen for type definitions
- Type builders like `MyType.get(..)` and type parameter accessors (e.g. `my_type.param1`)
- Using Python-defined types in Python-defined operations
```python
# A Python-defined dialect whose namespace is "ext_type".
class TestType(Dialect, name="ext_type"):
    pass

# Type `!ext_type.array<elem_type, length>`: each annotation declares one
# type parameter and its constraint, in declaration order.
class Array(TestType.Type, name="array"):
    elem_type: IntegerType[32] | IntegerType[64]  # i32 or i64
    length: IntegerAttr  # any builtin integer attribute

# The result may be any `!ext_type.array`.
class MakeArrayOp(TestType.Operation, name="make_array"):
    arr: Result[Array]

# The result type is fully specified (`!ext_type.array<i32, 3 : i32>`),
# so the op can be built without passing a result type.
class MakeArray3Op(TestType.Operation, name="make_array3"):
    arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
```
>From 625f86f03e5cdb89b63edddb852603b3623f9012 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Mon, 23 Feb 2026 13:34:46 +0800
Subject: [PATCH] [MLIR][Python] Support type definitions in Python-defined
dialects
---
mlir/python/mlir/dialects/ext.py | 117 +++++++++++++++++++++++++++++--
mlir/test/python/dialects/ext.py | 72 +++++++++++++++++++
2 files changed, 183 insertions(+), 6 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 79095658944e5..165ac9f119d7b 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -43,6 +43,20 @@
register_operation = _cext.register_operation
+def construct_instance(origin, args):
+    # Build an MLIR type/attribute via `origin.get(...)`, constructing nested parameterized args first.
+    return origin.get(
+        *(
+            (
+                construct_instance(get_origin(arg), get_args(arg))  # nested alias, e.g. `IntegerType[32]`
+                if get_origin(arg)
+                else arg  # plain value such as `32` or `3`, passed through unchanged
+            )
+            for arg in args
+        )
+    )  # NOTE(review): assumes each nested origin has `.get`; a Union arg would raise AttributeError here
+
+
class ConstraintLoweringContext:
def __init__(self):
self._cache: Dict[str, ir.Value] = {}
@@ -70,14 +84,19 @@ def _lower(self, type_) -> ir.Value:
elif isinstance(type_, TypeVar):
return self.lower(type_)
elif origin and issubclass(origin, ir.Type):
- # `origin.get` is to construct an instance of MLIR type.
- t = origin.get(*get_args(type_))
+ if issubclass(origin, Type):
+ return irdl.parametric(
+ base_type=[origin._dialect_name, origin._name],
+ args=[self.lower(arg) for arg in get_args(type_)],
+ )
+ t = construct_instance(origin, get_args(type_))
return irdl.is_(ir.TypeAttr.get(t))
elif origin and issubclass(origin, ir.Attribute):
- # `origin.get` is to construct an instance of MLIR attribute.
- attr = origin.get(*get_args(type_))
+ attr = construct_instance(origin, get_args(type_))
return irdl.is_(attr)
elif issubclass(type_, ir.Type):
+ if issubclass(type_, Type):
+ return irdl.base(base_ref=[type_._dialect_name, type_._name])
return irdl.base(base_name=f"!{type_.type_name}")
elif issubclass(type_, ir.Attribute):
return irdl.base(base_name=f"#{type_.attr_name}")
@@ -96,8 +115,7 @@ def infer_type(type_) -> Optional[Callable[[], ir.Type]]:
origin = get_origin(type_)
if origin and issubclass(origin, ir.Type):
- # `origin.get` is to construct an instance of MLIR type/attribute.
- return lambda: origin.get(*get_args(type_))
+ return lambda: construct_instance(origin, get_args(type_))
elif isinstance(type_, TypeVar):
return infer_type(type_.__bound__)
return None
@@ -488,6 +506,84 @@ def _emit_operation(cls) -> None:
)
+@dataclass
+class ParamDef:  # one declared parameter of a Python-defined Type
+    name: str  # parameter name (the class annotation key)
+    constraint: Any  # the annotation value; lowered to an IRDL constraint in Type._emit_type
+
+
+class Type(ir.DynamicType):  # base class for types of Python-defined dialects, emitted as `irdl.type`
+    @classmethod
+    def __init_subclass__(  # collects parameters; registers subclasses that pass `name=`
+        cls,
+        *,
+        name: str | None = None,
+        dialect: type | None = None,
+        **kwargs,
+    ):
+        super().__init_subclass__(**kwargs)
+
+        fields = []
+
+        for base in cls.__bases__:  # inherited parameters come first
+            if hasattr(base, "_fields"):
+                fields.extend(base._fields)
+        for key, value in cls.__annotations__.items():  # then own params (NOTE(review): strings under postponed annotations)
+            field = ParamDef(key, value)
+            fields.append(field)
+
+        cls._fields = fields
+
+        if dialect:
+            if hasattr(cls, "_dialect_obj"):  # also true when a base class already bound a dialect
+                raise RuntimeError(
+                    f"This type has already been attached to dialect '{cls._dialect_obj.DIALECT_NAMESPACE}'."
+                )
+            cls._dialect_obj = dialect
+
+        # Subclasses without a "name" keyword (e.g. a dialect's generated `Type` base)
+        # only record fields/dialect; they are not registered as concrete types.
+        if not name:
+            return
+
+        if not hasattr(cls, "_dialect_obj"):
+            raise RuntimeError(
+                "Type subclasses must either inherit from a Dialect's Type subclass "
+                "or provide the dialect as a class keyword argument."
+            )
+
+        cls._name = name
+        cls._dialect_name = cls._dialect_obj.DIALECT_NAMESPACE
+        cls.type_name = f"{cls._dialect_name}.{name}"  # e.g. "ext_type.array"
+
+        for i, field in enumerate(cls._fields):  # expose each parameter as a read-only property
+            setattr(
+                cls,
+                field.name,  # NOTE(review): a param named e.g. "params" or "get" would shadow the base attribute
+                property(lambda self, i=i: self.params[i]),  # i=i binds the index now (avoids late binding)
+            )
+
+        cls._dialect_obj.types.append(cls)  # picked up by the dialect's _emit_dialect
+
+    @classmethod
+    def get(cls, *args, context=None):  # positional args map to parameters in _fields order
+        args = [
+            ir.TypeAttr.get(arg) if isinstance(arg, ir.Type) else arg for arg in args  # Type params are wrapped as TypeAttr
+        ]
+        return cls(ir.DynamicType.get(cls.type_name, args, context=context))  # NOTE(review): arity not checked here
+
+    @classmethod
+    def _emit_type(cls) -> None:  # emits `irdl.type @<name>` with one constraint per parameter
+        ctx = ConstraintLoweringContext()
+
+        t = irdl.type_(cls._name)
+        with ir.InsertionPoint(t.body):
+            irdl.parameters(
+                [ctx.lower(f.constraint) for f in cls._fields],
+                [f.name for f in cls._fields],
+            )
+
+
class Dialect(ir.Dialect):
"""
Base class of a Python-defined dialect.
@@ -521,11 +617,20 @@ def __init_subclass__(cls, name: str, **kwargs):
dict(),
dialect=cls,
)
+ cls.types = []
+ cls.Type = type(
+ "Type",
+ (Type,),
+ dict(),
+ dialect=cls,
+ )
@classmethod
def _emit_dialect(cls) -> None:
d = irdl.dialect(cls.name)
with ir.InsertionPoint(d.body):
+ for type_ in cls.types:
+ type_._emit_type()
for op in cls.operations:
op._emit_operation()
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 196af91e511ec..78fe188b0ee58 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -519,3 +519,75 @@ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait])
# CHECK: Verification failed:
# CHECK: result type mismatch
print(e)
+
+
+# CHECK: TEST: testExtDialectWithType
+@run
+def testExtDialectWithType():  # end-to-end: IRDL emission, builders, accessors, use in ops
+    class TestType(Dialect, name="ext_type"):
+        pass
+
+    class Array(TestType.Type, name="array"):
+        elem_type: IntegerType[32] | IntegerType[64]
+        length: IntegerAttr
+
+    class MakeArrayOp(TestType.Operation, name="make_array"):
+        arr: Result[Array]
+
+    class MakeArray3Op(TestType.Operation, name="make_array3"):
+        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]
+
+    with Context(), Location.unknown():
+        TestType.load()
+        # CHECK: irdl.dialect @ext_type {
+        # CHECK: irdl.type @array {
+        # CHECK: %0 = irdl.is i32
+        # CHECK: %1 = irdl.is i64
+        # CHECK: %2 = irdl.any_of(%0, %1)
+        # CHECK: %3 = irdl.base "#builtin.integer"
+        # CHECK: irdl.parameters(elem_type: %2, length: %3)
+        # CHECK: }
+        # CHECK: irdl.operation @make_array {
+        # CHECK: %0 = irdl.base @ext_type::@array
+        # CHECK: irdl.results(arr: %0)
+        # CHECK: }
+        # CHECK: irdl.operation @make_array3 {
+        # CHECK: %0 = irdl.is i32
+        # CHECK: %1 = irdl.is 3 : i32
+        # CHECK: %2 = irdl.parametric @ext_type::@array<%0, %1>
+        # CHECK: irdl.results(arr: %2)
+        # CHECK: }
+        # CHECK: }
+        print(TestType._mlir_module)
+
+        # CHECK: ext_type.array
+        print(Array.type_name)
+
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        a4 = Array.get(i32, IntegerAttr.get(i32, 4))  # positional args follow parameter order
+        a6 = Array.get(i64, IntegerAttr.get(i32, 6))
+        # CHECK: !ext_type.array<i32, 4 : i32>
+        print(a4)
+        # CHECK: !ext_type.array<i64, 6 : i32>
+        print(a6)
+
+        # CHECK: i32
+        print(a4.elem_type)
+        # CHECK: 4 : i32
+        print(a4.length)
+        # CHECK: i64
+        print(a6.elem_type)
+        # CHECK: 6 : i32
+        print(a6.length)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            MakeArrayOp(a4)
+            MakeArrayOp(a6)
+            MakeArray3Op()  # result type inferred from the fully-specified annotation
+
+        # CHECK: %0 = "ext_type.make_array"() : () -> !ext_type.array<i32, 4 : i32>
+        # CHECK: %1 = "ext_type.make_array"() : () -> !ext_type.array<i64, 6 : i32>
+        # CHECK: %2 = "ext_type.make_array3"() : () -> !ext_type.array<i32, 3 : i32>
+        print(module)
More information about the Mlir-commits
mailing list