[Mlir-commits] [mlir] [MLIR][Python] Make traits declarative in python-defined operations (PR #180748)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 10 07:33:14 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/180748

From 0e8b599be7e4e511b5da2437acf619afa2caffa9 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 10 Feb 2026 22:32:35 +0800
Subject: [PATCH 1/3] [MLIR][Python] Make traits declarative in python-defined
 operations

---
 mlir/lib/Bindings/Python/IRCore.cpp | 11 ++---
 mlir/python/mlir/dialects/ext.py    | 21 +++++++++-
 mlir/test/python/dialects/ext.py    | 64 +++++++++++++++++++++--------
 3 files changed, 73 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9edc981220f8b..fb25f3191962a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2543,10 +2543,11 @@ static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
 bool PyDynamicOpTrait::attach(const nb::object &opName,
                               const nb::object &target,
                               PyMlirContext &context) {
-  if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region"))
+  if (!nb::hasattr(target, "verify_trait") &&
+      !nb::hasattr(target, "verify_region_trait"))
     throw nb::type_error(
-        "the target object must have at least one of 'verify' or "
-        "'verify_region' methods");
+        "the target object must have at least one of 'verify_trait' or "
+        "'verify_region_trait' methods");
 
   MlirDynamicOpTraitCallbacks callbacks;
   callbacks.construct = [](void *userData) {
@@ -2558,11 +2559,11 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
 
   callbacks.verifyTrait = [](MlirOperation op,
                              void *userData) -> MlirLogicalResult {
-    return verifyTraitByMethod(op, userData, "verify");
+    return verifyTraitByMethod(op, userData, "verify_trait");
   };
   callbacks.verifyRegionTrait = [](MlirOperation op,
                                    void *userData) -> MlirLogicalResult {
-    return verifyTraitByMethod(op, userData, "verify_region");
+    return verifyTraitByMethod(op, userData, "verify_region_trait");
   };
 
   // To ensure that the same dynamic trait gets the same TypeID despite how many
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 251d4831cb331..bb8c61b59aaf3 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -204,7 +204,9 @@ class Operation(ir.OpView):
     """
 
     @classmethod
-    def __init_subclass__(cls, *, name: str = None, **kwargs):
+    def __init_subclass__(
+        cls, *, name: str | None = None, traits: list | None = None, **kwargs
+    ):
         """
         This method is to perform all magic to make a `Operation` subclass works like a dataclass, like:
         - generate the method to emit IRDL operations,
@@ -225,6 +227,14 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
 
         cls._fields = fields
 
+        traits = traits or []
+
+        for base in cls.__bases__:
+            if hasattr(base, "_traits"):
+                traits.extend(base._traits)
+
+        cls._traits = traits
+
         # for subclasses without "name" parameter,
         # just treat them as normal classes
         if not name:
@@ -407,6 +417,14 @@ def getter(self, i=i, result=result):
             else:
                 setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
 
+    @classmethod
+    def _attach_trait(cls) -> None:
+        for trait in cls._traits:
+            trait.attach(cls.OPERATION_NAME)
+
+        if hasattr(cls, "verify_trait") or hasattr(cls, "verify_region_trait"):
+            ir.DynamicOpTrait.attach(cls.OPERATION_NAME, cls)
+
     @classmethod
     def _emit_operation(cls) -> None:
         ctx = ConstraintLoweringContext()
@@ -502,6 +520,7 @@ def load(cls) -> None:
         _cext.register_dialect(cls)
 
         for op in cls.operations:
+            op._attach_trait()
             _cext.register_operation(cls)(op)
 
         cls._mlir_module = mlir_module
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index d24a94bc8baf8..f850a517f5f1f 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -345,6 +345,16 @@ class TypeVarOp(Test.Operation, name="type_var"):
 # CHECK: TEST: testExtDialectWithRegion
 @run
 def testExtDialectWithRegion():
+    class ParentIsIfTrait(DynamicOpTrait):
+        @staticmethod
+        def verify_trait(op) -> bool:
+            if not isinstance(op.parent.opview, IfOp):
+                op.location.emit_error(
+                    f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
+                )
+                return False
+            return True
+
     class TestRegion(Dialect, name="ext_region"):
         pass
 
@@ -354,10 +364,20 @@ class IfOp(TestRegion.Operation, name="if"):
         then: Region
         else_: Region
 
-    class YieldOp(TestRegion.Operation, name="yield"):
+    class YieldOp(
+        TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
+    ):
         value: Operand[Any]
 
-    class NoTermOp(TestRegion.Operation, name="no_term"):
+        def verify_trait(self) -> bool:
+            if self.parent.results[0].type != self.value.type:
+                self.location.emit_error(
+                    "result type mismatch between YieldOp and its parent IfOp"
+                )
+                return False
+            return True
+
+    class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
         body: Region
 
     with Context(), Location.unknown():
@@ -383,21 +403,6 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
         # CHECK: }
         print(TestRegion._mlir_module)
 
-        IsTerminatorTrait.attach(YieldOp)
-        NoTerminatorTrait.attach(NoTermOp)
-
-        class ParentIsIfTrait(DynamicOpTrait):
-            @staticmethod
-            def verify(op) -> bool:
-                if not isinstance(op.parent.opview, IfOp):
-                    op.location.emit_error(
-                        f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
-                    )
-                    return False
-                return True
-
-        ParentIsIfTrait.attach(YieldOp)
-
         # CHECK: (self, /, result, cond, *, loc=None, ip=None)
         print(IfOp.__init__.__signature__)
 
@@ -489,3 +494,28 @@ def verify(op) -> bool:
             # CHECK: Verification failed:
             # CHECK: ext_region.yield should be put inside ext_region.if
             print(e)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i1 = IntegerType.get_signless(1)
+            i32 = IntegerType.get_signless(32)
+            cond = arith.constant(i1, 1)
+
+            if_ = IfOp(i1, cond)
+            if_.then.blocks.append()
+            if_.else_.blocks.append()
+
+            with InsertionPoint(if_.then.blocks[0]):
+                v = arith.constant(i32, 2)
+                YieldOp(v)
+
+            with InsertionPoint(if_.else_.blocks[0]):
+                v = arith.constant(i32, 3)
+                YieldOp(v)
+
+        try:
+            module.operation.verify()
+        except Exception as e:
+            # CHECK: Verification failed:
+            # CHECK: result type mismatch
+            print(e)

From 0f13e5ccaac577f4bf517edb702efc68caff8fca Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 10 Feb 2026 23:08:04 +0800
Subject: [PATCH 2/3] fix typo: rename _attach_trait to _attach_traits

---
 mlir/python/mlir/dialects/ext.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index bb8c61b59aaf3..372ebbdd67762 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -418,7 +418,7 @@ def getter(self, i=i, result=result):
                 setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
 
     @classmethod
-    def _attach_trait(cls) -> None:
+    def _attach_traits(cls) -> None:
         for trait in cls._traits:
             trait.attach(cls.OPERATION_NAME)
 
@@ -520,7 +520,7 @@ def load(cls) -> None:
         _cext.register_dialect(cls)
 
         for op in cls.operations:
-            op._attach_trait()
+            op._attach_traits()
             _cext.register_operation(cls)(op)
 
         cls._mlir_module = mlir_module

From 02a1f73a33c98a3b7fa9c0683d05c6ed5388997a Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Tue, 10 Feb 2026 23:32:48 +0800
Subject: [PATCH 3/3] fix: skip YieldOp result-type check when the parent op has no results

---
 mlir/test/python/dialects/ext.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f850a517f5f1f..95a37faa075fb 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -370,7 +370,10 @@ class YieldOp(
         value: Operand[Any]
 
         def verify_trait(self) -> bool:
-            if self.parent.results[0].type != self.value.type:
+            parent_results = self.parent.results
+            if len(parent_results) == 0:
+                return True
+            if parent_results[0].type != self.value.type:
                 self.location.emit_error(
                     "result type mismatch between YieldOp and its parent IfOp"
                 )



More information about the Mlir-commits mailing list