[Mlir-commits] [mlir] [MLIR][Python] Make traits declarative in python-defined operations (PR #180748)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 10 07:33:14 PST 2026
https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/180748
>From 0e8b599be7e4e511b5da2437acf619afa2caffa9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 10 Feb 2026 22:32:35 +0800
Subject: [PATCH 1/3] [MLIR][Python] Make traits declarative in python-defined
operations
---
mlir/lib/Bindings/Python/IRCore.cpp | 11 ++---
mlir/python/mlir/dialects/ext.py | 21 +++++++++-
mlir/test/python/dialects/ext.py | 64 +++++++++++++++++++++--------
3 files changed, 73 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9edc981220f8b..fb25f3191962a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2543,10 +2543,11 @@ static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
bool PyDynamicOpTrait::attach(const nb::object &opName,
const nb::object &target,
PyMlirContext &context) {
- if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region"))
+ if (!nb::hasattr(target, "verify_trait") &&
+ !nb::hasattr(target, "verify_region_trait"))
throw nb::type_error(
- "the target object must have at least one of 'verify' or "
- "'verify_region' methods");
+ "the target object must have at least one of 'verify_trait' or "
+ "'verify_region_trait' methods");
MlirDynamicOpTraitCallbacks callbacks;
callbacks.construct = [](void *userData) {
@@ -2558,11 +2559,11 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
callbacks.verifyTrait = [](MlirOperation op,
void *userData) -> MlirLogicalResult {
- return verifyTraitByMethod(op, userData, "verify");
+ return verifyTraitByMethod(op, userData, "verify_trait");
};
callbacks.verifyRegionTrait = [](MlirOperation op,
void *userData) -> MlirLogicalResult {
- return verifyTraitByMethod(op, userData, "verify_region");
+ return verifyTraitByMethod(op, userData, "verify_region_trait");
};
// To ensure that the same dynamic trait gets the same TypeID despite how many
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 251d4831cb331..bb8c61b59aaf3 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -204,7 +204,9 @@ class Operation(ir.OpView):
"""
@classmethod
- def __init_subclass__(cls, *, name: str = None, **kwargs):
+ def __init_subclass__(
+ cls, *, name: str | None = None, traits: list | None = None, **kwargs
+ ):
"""
This method is to perform all magic to make a `Operation` subclass works like a dataclass, like:
- generate the method to emit IRDL operations,
@@ -225,6 +227,14 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
cls._fields = fields
+ traits = traits or []
+
+ for base in cls.__bases__:
+ if hasattr(base, "_traits"):
+ traits.extend(base._traits)
+
+ cls._traits = traits
+
# for subclasses without "name" parameter,
# just treat them as normal classes
if not name:
@@ -407,6 +417,14 @@ def getter(self, i=i, result=result):
else:
setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
+ @classmethod
+ def _attach_trait(cls) -> None:
+ for trait in cls._traits:
+ trait.attach(cls.OPERATION_NAME)
+
+ if hasattr(cls, "verify_trait") or hasattr(cls, "verify_region_trait"):
+ ir.DynamicOpTrait.attach(cls.OPERATION_NAME, cls)
+
@classmethod
def _emit_operation(cls) -> None:
ctx = ConstraintLoweringContext()
@@ -502,6 +520,7 @@ def load(cls) -> None:
_cext.register_dialect(cls)
for op in cls.operations:
+ op._attach_trait()
_cext.register_operation(cls)(op)
cls._mlir_module = mlir_module
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index d24a94bc8baf8..f850a517f5f1f 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -345,6 +345,16 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: TEST: testExtDialectWithRegion
@run
def testExtDialectWithRegion():
+ class ParentIsIfTrait(DynamicOpTrait):
+ @staticmethod
+ def verify_trait(op) -> bool:
+ if not isinstance(op.parent.opview, IfOp):
+ op.location.emit_error(
+ f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
+ )
+ return False
+ return True
+
class TestRegion(Dialect, name="ext_region"):
pass
@@ -354,10 +364,20 @@ class IfOp(TestRegion.Operation, name="if"):
then: Region
else_: Region
- class YieldOp(TestRegion.Operation, name="yield"):
+ class YieldOp(
+ TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
+ ):
value: Operand[Any]
- class NoTermOp(TestRegion.Operation, name="no_term"):
+ def verify_trait(self) -> bool:
+ if self.parent.results[0].type != self.value.type:
+ self.location.emit_error(
+ "result type mismatch between YieldOp and its parent IfOp"
+ )
+ return False
+ return True
+
+ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
body: Region
with Context(), Location.unknown():
@@ -383,21 +403,6 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
# CHECK: }
print(TestRegion._mlir_module)
- IsTerminatorTrait.attach(YieldOp)
- NoTerminatorTrait.attach(NoTermOp)
-
- class ParentIsIfTrait(DynamicOpTrait):
- @staticmethod
- def verify(op) -> bool:
- if not isinstance(op.parent.opview, IfOp):
- op.location.emit_error(
- f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
- )
- return False
- return True
-
- ParentIsIfTrait.attach(YieldOp)
-
# CHECK: (self, /, result, cond, *, loc=None, ip=None)
print(IfOp.__init__.__signature__)
@@ -489,3 +494,28 @@ def verify(op) -> bool:
# CHECK: Verification failed:
# CHECK: ext_region.yield should be put inside ext_region.if
print(e)
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i1 = IntegerType.get_signless(1)
+ i32 = IntegerType.get_signless(32)
+ cond = arith.constant(i1, 1)
+
+ if_ = IfOp(i1, cond)
+ if_.then.blocks.append()
+ if_.else_.blocks.append()
+
+ with InsertionPoint(if_.then.blocks[0]):
+ v = arith.constant(i32, 2)
+ YieldOp(v)
+
+ with InsertionPoint(if_.else_.blocks[0]):
+ v = arith.constant(i32, 3)
+ YieldOp(v)
+
+ try:
+ module.operation.verify()
+ except Exception as e:
+ # CHECK: Verification failed:
+ # CHECK: result type mismatch
+ print(e)
>From 0f13e5ccaac577f4bf517edb702efc68caff8fca Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 10 Feb 2026 23:08:04 +0800
Subject: [PATCH 2/3] Rename _attach_trait to _attach_traits
---
mlir/python/mlir/dialects/ext.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index bb8c61b59aaf3..372ebbdd67762 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -418,7 +418,7 @@ def getter(self, i=i, result=result):
setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
@classmethod
- def _attach_trait(cls) -> None:
+ def _attach_traits(cls) -> None:
for trait in cls._traits:
trait.attach(cls.OPERATION_NAME)
@@ -520,7 +520,7 @@ def load(cls) -> None:
_cext.register_dialect(cls)
for op in cls.operations:
- op._attach_trait()
+ op._attach_traits()
_cext.register_operation(cls)(op)
cls._mlir_module = mlir_module
>From 02a1f73a33c98a3b7fa9c0683d05c6ed5388997a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 10 Feb 2026 23:32:48 +0800
Subject: [PATCH 3/3] Skip YieldOp result-type check when parent op has no results
---
mlir/test/python/dialects/ext.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f850a517f5f1f..95a37faa075fb 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -370,7 +370,10 @@ class YieldOp(
value: Operand[Any]
def verify_trait(self) -> bool:
- if self.parent.results[0].type != self.value.type:
+ parent_results = self.parent.results
+ if len(parent_results) == 0:
+ return True
+ if parent_results[0].type != self.value.type:
self.location.emit_error(
"result type mismatch between YieldOp and its parent IfOp"
)
More information about the Mlir-commits
mailing list