[Mlir-commits] [mlir] [MLIR][Python] Support region in python-defined dialects (PR #179086)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 31 21:53:01 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/179086
This PR adds basic support for defining regions in Python-defined dialects. Example usage:
```python
# A Python-defined dialect named "ext_region" with one operation, "ext_region.if".
class TestRegion(Dialect, name="ext_region"):
    pass

# "if" takes a single i1 operand and owns two regions, exposed as `then` and `else_`.
class IfOp(TestRegion.Operation, name="if"):
    cond: Operand[IntegerType[1]]
    then: Region
    else_: Region
```
Current limitations:
* We can’t specify region constraints yet (e.g., number of blocks or block argument types). This will be addressed as a follow-up task.
* We can’t mark an op as a `Terminator` or `NoTerminator` yet. This depends on `DynamicOpTraits` (#177735) and Python-side trait API support, and will be implemented in a follow-up PR.
>From df631fd72564fcfe2ef47130cd4a0b2cbe9c02c5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 1 Feb 2026 13:07:23 +0800
Subject: [PATCH 1/3] [MLIR][Python] Support region in python-defined dialects
---
mlir/python/mlir/dialects/ext.py | 64 ++++++++++++++++++++++--------
mlir/test/python/dialects/ext.py | 67 ++++++++++++++++++++++++++++++++
2 files changed, 115 insertions(+), 16 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 237c27bf62f77..31378f74d049f 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -29,10 +29,12 @@
"Dialect",
"Operand",
"Result",
+ "Region",
]
Operand = ir.Value
Result = ir.OpResult
+Region = ir.Region
class ConstraintLoweringContext:
@@ -102,7 +104,6 @@ class FieldDef:
"""
name: str
- constraint: Any
variadicity: Variadicity
@staticmethod
@@ -117,38 +118,50 @@ def from_type_hint(name, type_) -> "FieldDef":
origin = get_origin(type_)
if origin is ir.OpResult:
- return ResultDef(name, get_args(type_)[0], variadicity)
+ return ResultDef(name, variadicity, get_args(type_)[0])
elif origin is ir.Value:
- return OperandDef(name, get_args(type_)[0], variadicity)
+ return OperandDef(name, variadicity, get_args(type_)[0])
elif issubclass(origin or type_, ir.Attribute):
- return AttributeDef(name, type_, variadicity)
+ return AttributeDef(name, variadicity, type_)
+ elif type_ is ir.Region:
+ return RegionDef(name, variadicity)
raise TypeError(f"unsupported type in operation definition: {type_}")
@dataclass
class OperandDef(FieldDef):
- pass
+ constraint: Any
@dataclass
class ResultDef(FieldDef):
- pass
+ constraint: Any
@dataclass
class AttributeDef(FieldDef):
+ constraint: Any
+
def __post_init__(self):
if self.variadicity != Variadicity.single:
- raise ValueError("optional attribute is not supported in IRDL")
+ raise ValueError("optional attribute is not currently supported")
+
+
+@dataclass
+class RegionDef(FieldDef):
+ """Region field of a Python-defined operation; only single (non-optional, non-variadic) regions are accepted."""
+
+ def __post_init__(self):
+ if self.variadicity != Variadicity.single:
+ raise ValueError("optional region is not currently supported")
def partition_fields(
fields: List[FieldDef],
-) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
+) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef], List[RegionDef]]:
operands = [i for i in fields if isinstance(i, OperandDef)]
attrs = [i for i in fields if isinstance(i, AttributeDef)]
results = [i for i in fields if isinstance(i, ResultDef)]
- return operands, attrs, results
+ regions = [i for i in fields if isinstance(i, RegionDef)]
+ return operands, attrs, results, regions
def normalize_value_range(
@@ -223,10 +236,11 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
cls._generate_class_attributes(dialect_name, op_name, fields)
cls._generate_init_method(fields)
- operands, attrs, results = partition_fields(fields)
+ operands, attrs, results, regions = partition_fields(fields)
cls._generate_attr_properties(attrs)
cls._generate_operand_properties(operands)
cls._generate_result_properties(results)
+ cls._generate_region_properties(regions)
dialect_obj.operations.append(cls)
@@ -254,7 +268,11 @@ def _generate_init_signature(
)
# results are placed at the beginning of the parameter list,
# but operands and attributes can appear in any relative order.
- args = result_args + [i for i in fields if not isinstance(i, ResultDef)]
+ args = result_args + [
+ i
+ for i in fields
+ if not isinstance(i, ResultDef) and not isinstance(i, RegionDef)
+ ]
positional_args = [
i.name for i in args if i.variadicity != Variadicity.optional
]
@@ -272,7 +290,7 @@ def _generate_init_signature(
@classmethod
def _generate_init_method(cls, fields: List[FieldDef]) -> None:
- operands, attrs, results = partition_fields(fields)
+ operands, attrs, results, regions = partition_fields(fields)
inferred_types = [infer_type(i.constraint) for i in results]
# we infer result types only when all result types can be inferred
@@ -299,7 +317,7 @@ def __init__(*args, **kwargs):
for attr in attrs
if args[attr.name] is not None
)
- _regions = None
+ _regions = len(regions) or None
_ods_successors = None
self = args["self"]
super(Operation, self).__init__(
@@ -323,13 +341,13 @@ def __init__(*args, **kwargs):
def _generate_class_attributes(
cls, dialect_name: str, op_name: str, fields: List[FieldDef]
) -> None:
- operands, attrs, results = partition_fields(fields)
+ operands, attrs, results, regions = partition_fields(fields)
operand_segments = cls._generate_segments(operands)
result_segments = cls._generate_segments(results)
cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
- cls._ODS_REGIONS = (0, True)
+ cls._ODS_REGIONS = (len(regions), True)
cls._ODS_OPERAND_SEGMENTS = operand_segments
cls._ODS_RESULT_SEGMENTS = result_segments
@@ -342,6 +360,15 @@ def _generate_attr_properties(cls, attrs: List[AttributeDef]) -> None:
property(lambda self, name=attr.name: self.attributes[name]),
)
+ @classmethod
+ def _generate_region_properties(cls, regions: List[RegionDef]) -> None:
+ """Attach one read-only property per region field: the i-th entry of `regions` is exposed as ``self.regions[i]`` under the field's name."""
+ for i, region in enumerate(regions):
+ setattr(
+ cls,
+ region.name,
+ property(lambda self, i=i: self.regions[i]),  # i=i binds the current index; a plain closure would late-bind and see only the last i
+ )
+
@classmethod
def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
for i, operand in enumerate(operands):
@@ -379,7 +406,7 @@ def getter(self, i=i, result=result):
@classmethod
def _emit_operation(cls) -> None:
ctx = ConstraintLoweringContext()
- operands, attrs, results = partition_fields(cls._fields)
+ operands, attrs, results, regions = partition_fields(cls._fields)
op = irdl.operation_(cls._op_name)
with ir.InsertionPoint(op.body):
@@ -400,6 +427,11 @@ def _emit_operation(cls) -> None:
[i.name for i in results],
[i.variadicity for i in results],
)
+ if regions:
+ irdl.regions_(
+ [irdl.region([]) for _ in regions],
+ [i.name for i in regions],
+ )
class Dialect(ir.Dialect):
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 483953ddfde51..1301a04840220 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -76,6 +76,8 @@ class AddOp(MyInt.Operation, name="add"):
print(add1._ODS_OPERAND_SEGMENTS)
# CHECK: None
print(add1._ODS_RESULT_SEGMENTS)
+ # CHECK: (0, True)
+ print(add1._ODS_REGIONS)
# CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
print(add1.lhs.owner)
# CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
@@ -338,3 +340,68 @@ class TypeVarOp(Test.Operation, name="type_var"):
except TypeError as e:
# CHECK:too many positional arguments
print(e)
+
+
+# CHECK: TEST: testExtDialectWithRegion
+@run
+def testExtDialectWithRegion():
+ """Define an op with two regions, check its IRDL, __init__ signature and ODS metadata, then build and verify an instance."""
+ class TestRegion(Dialect, name="ext_region"):
+ pass
+
+ class IfOp(TestRegion.Operation, name="if"):
+ cond: Operand[IntegerType[1]]
+ then: Region
+ else_: Region
+
+ with Context(), Location.unknown():
+ TestRegion.load()
+ # CHECK: irdl.dialect @ext_region {
+ # CHECK: irdl.operation @if {
+ # CHECK: %0 = irdl.is i1
+ # CHECK: irdl.operands(cond: %0)
+ # CHECK: %1 = irdl.region
+ # CHECK: %2 = irdl.region
+ # CHECK: irdl.regions(then: %1, else_: %2)
+ # CHECK: }
+ print(TestRegion._mlir_module)
+
+ # CHECK: (self, /, cond, *, loc=None, ip=None)
+ print(IfOp.__init__.__signature__)
+
+ # CHECK: None None
+ print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
+ # CHECK: (2, True)
+ print(IfOp._ODS_REGIONS)
+
+ from mlir.dialects import llvm
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i1 = IntegerType.get_signless(1)
+ i32 = IntegerType.get_signless(32)
+ cond = arith.constant(i1, 1)
+
+ if_ = IfOp(cond)
+ if_.then.blocks.append()
+ if_.else_.blocks.append()
+
+ with InsertionPoint(if_.then.blocks[0]):
+ arith.constant(i32, 2)
+ llvm.unreachable()
+
+ with InsertionPoint(if_.else_.blocks[0]):
+ arith.constant(i32, 3)
+ llvm.unreachable()
+
+ assert module.operation.verify()
+ # CHECK: module {
+ # CHECK: %true = arith.constant true
+ # CHECK: "ext_region.if"(%true) ({
+ # CHECK: %c2_i32 = arith.constant 2 : i32
+ # CHECK: llvm.unreachable
+ # CHECK: }, {
+ # CHECK: %c3_i32 = arith.constant 3 : i32
+ # CHECK: llvm.unreachable
+ # CHECK: }) : (i1) -> ()
+ # CHECK: }
+ print(module)
>From 72b6e842cd94b15f01735a8b49a97b2b6ec57c87 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 1 Feb 2026 13:11:27 +0800
Subject: [PATCH 2/3] add more checks
---
mlir/python/mlir/dialects/ext.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 31378f74d049f..a8701f0506e01 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -30,6 +30,7 @@
"Operand",
"Result",
"Region",
+ "Operation",
]
Operand = ir.Value
@@ -229,6 +230,11 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
if not name:
return
+ if not hasattr(cls, "_dialect_name") or not hasattr(cls, "_dialect_obj"):
+ raise RuntimeError(
+ "Operation subclasses must inherit from a Dialect's Operation subclass"
+ )
+
op_name = name
cls._op_name = op_name
dialect_name = cls._dialect_name
>From df7f5241ed648734dff88f7b69bbe2bf70606ec0 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Sun, 1 Feb 2026 13:28:03 +0800
Subject: [PATCH 3/3] add more tests
---
mlir/test/python/dialects/ext.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index 1301a04840220..30e705726756b 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -405,3 +405,8 @@ class IfOp(TestRegion.Operation, name="if"):
# CHECK: }) : (i1) -> ()
# CHECK: }
print(module)
+
+ # CHECK: %c2_i32 = arith.constant 2 : i32
+ print(if_.then.blocks[0])
+ # CHECK: %c3_i32 = arith.constant 3 : i32
+ print(if_.else_.blocks[0])
More information about the Mlir-commits
mailing list