[Mlir-commits] [mlir] 972aa59 - [MLIR][Python] Make traits declarative in python-defined operations (#180748)

llvmlistbot at llvm.org
Wed Feb 11 04:40:03 PST 2026


Author: Twice
Date: 2026-02-11T20:39:58+08:00
New Revision: 972aa597de3a39f457207546af93ab3bfd1027f0

URL: https://github.com/llvm/llvm-project/commit/972aa597de3a39f457207546af93ab3bfd1027f0
DIFF: https://github.com/llvm/llvm-project/commit/972aa597de3a39f457207546af93ab3bfd1027f0.diff

LOG: [MLIR][Python] Make traits declarative in python-defined operations (#180748)

This adds support for two new syntaxes in Python-defined dialects.

First, traits can now be declared as class parameters, e.g.:
```python
class ParentIsIfTrait(DynamicOpTrait):  # define a Python-side trait
    @staticmethod
    def verify_invariants(op) -> bool:
        if not isinstance(op.parent.opview, IfOp):
            op.location.emit_error(
                f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
            )
            return False
        return True

class YieldOp( # attach two traits: IsTerminatorTrait, ParentIsIfTrait
    TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
):
    ...
```
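The trait collection that `__init_subclass__` gains in this patch (the `ext.py` hunk further down) can be modeled in plain Python. This is a hedged, MLIR-free sketch: `Operation`, `TraitA`, `TraitB`, `BaseOp` and `DerivedOp` are hypothetical stand-ins, not the real `mlir.dialects.ext` classes. It shows that traits declared on a base class are prepended to the subclass's own `traits=[...]` list.

```python
from __future__ import annotations


class Operation:
    # Hypothetical stand-in for mlir.dialects.ext.Operation.
    _traits: list[type] = []

    def __init_subclass__(
        cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
    ):
        super().__init_subclass__(**kwargs)
        # `name` is accepted but unused in this sketch.
        traits = list(traits or [])
        # Inherit traits declared on base classes, placing base traits first.
        for base in cls.__bases__:
            traits = list(getattr(base, "_traits", [])) + traits
        cls._traits = traits


class TraitA:
    pass


class TraitB:
    pass


class BaseOp(Operation, traits=[TraitA]):
    pass


class DerivedOp(BaseOp, name="derived", traits=[TraitB]):
    pass


print([t.__name__ for t in DerivedOp._traits])  # ['TraitA', 'TraitB']
```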

Second, users can define `verify_invariants`/`verify_region_invariants`
methods directly on the operation class to add custom verification
logic. This is implemented via traits.
```python
class YieldOp(TestRegion.Operation, name="yield", ...):
    value: Operand[Any]

    def verify_invariants(self) -> bool: # define a method directly
        if self.parent.results[0].type != self.value.type:
            self.location.emit_error(
                "result type mismatch between YieldOp and its parent IfOp"
            )
            return False
        return True
```
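To illustrate how an operation class that defines `verify_invariants` itself ends up acting as a trait, here is a minimal MLIR-free model of `_attach_traits` plus name-based dispatch in the spirit of `verifyTraitByMethod`. Everything in it (`REGISTRY`, `attach`, `attach_traits`, `verify_op`, `PositiveTrait`, `MyOp`) is hypothetical: the real binding registers C++ callbacks with the MLIR context rather than appending to a Python list, and this sketch dispatches only `verify_invariants`.

```python
from __future__ import annotations

from collections import defaultdict

# Hypothetical registry: operation name -> attached trait targets.
REGISTRY: dict[str, list[object]] = defaultdict(list)
METHOD_NAMES = ("verify_invariants", "verify_region_invariants")


def attach(op_name: str, target: object) -> None:
    # Mirrors the check in PyDynamicOpTrait::attach.
    if not any(hasattr(target, m) for m in METHOD_NAMES):
        raise TypeError(
            "the target object must have at least one of "
            "'verify_invariants' or 'verify_region_invariants' methods"
        )
    REGISTRY[op_name].append(target)


def attach_traits(op_cls) -> None:
    # Declared traits first, then the op class itself if it defines a hook.
    for trait in op_cls._traits:
        attach(op_cls.OPERATION_NAME, trait)
    if any(hasattr(op_cls, m) for m in METHOD_NAMES):
        attach(op_cls.OPERATION_NAME, op_cls)


def verify_op(op) -> bool:
    # Look the hook up by name on each target and stop at the first failure.
    for target in REGISTRY[op.OPERATION_NAME]:
        method = getattr(target, "verify_invariants", None)
        if method is not None and not method(op):
            return False
    return True


class PositiveTrait:
    @staticmethod
    def verify_invariants(op) -> bool:
        return op.value > 0


class MyOp:
    OPERATION_NAME = "demo.my_op"
    _traits = [PositiveTrait]

    def __init__(self, value: int):
        self.value = value

    def verify_invariants(self) -> bool:
        # Called as MyOp.verify_invariants(op), so `self` is the op.
        return self.value % 2 == 0


attach_traits(MyOp)
print(verify_op(MyOp(4)))   # True
print(verify_op(MyOp(3)))   # False (odd)
print(verify_op(MyOp(-2)))  # False (not positive)
```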

Previously we used `verify`/`verify_region` as the method names (in
yesterday's PR #179705). This PR renames them to
`verify_invariants`/`verify_region_invariants` because the newly added
`verify` method conflicts with the existing `ir.OpView.verify`
method:
- `verify_invariants` only attaches **additional** verification
logic, whereas `OpView.verify` constructs an OperationVerifier and runs
full verification of an operation, so the two have different semantics.
We should not shadow `OpView.verify` with a new, semantically
different `verify` method.
- Two `verify` methods with different meanings would confuse users.
- If users did not define a `verify` method in their Python-defined
operation, `DynamicOpTrait.attach(opname, MyOpCls)` would still perform
the attachment (because `hasattr(MyOpCls, "verify")` returns `True` for
the inherited `OpView.verify`) and then segfault (because
`OpView.verify` cannot be attached).
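The last point comes down to ordinary Python attribute lookup: `hasattr` also finds inherited methods. A minimal sketch with hypothetical stand-in classes (not the real `mlir.ir.OpView`):

```python
class OpView:
    # Hypothetical stand-in for the built-in full-operation verifier.
    def verify(self) -> bool:
        return True


class MyOp(OpView):
    # The user never defines `verify` (or `verify_invariants`).
    pass


# The old check is fooled by the inherited method...
print(hasattr(MyOp, "verify"))             # True
# ...while the distinct new name is absent unless the user defines it.
print(hasattr(MyOp, "verify_invariants"))  # False
```

Assuming `OpView` itself defines no `verify_invariants`, the renamed check only succeeds when the user or a trait actually provides the hook.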

---------

Co-authored-by: Rolf Morel <rolfmorel at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/dialects/ext.py
    mlir/test/python/dialects/ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9edc981220f8b..bffe5da45f6dc 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2543,10 +2543,11 @@ static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
 bool PyDynamicOpTrait::attach(const nb::object &opName,
                               const nb::object &target,
                               PyMlirContext &context) {
-  if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region"))
+  if (!nb::hasattr(target, "verify_invariants") &&
+      !nb::hasattr(target, "verify_region_invariants"))
     throw nb::type_error(
-        "the target object must have at least one of 'verify' or "
-        "'verify_region' methods");
+        "the target object must have at least one of 'verify_invariants' or "
+        "'verify_region_invariants' methods");
 
   MlirDynamicOpTraitCallbacks callbacks;
   callbacks.construct = [](void *userData) {
@@ -2558,11 +2559,11 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
 
   callbacks.verifyTrait = [](MlirOperation op,
                              void *userData) -> MlirLogicalResult {
-    return verifyTraitByMethod(op, userData, "verify");
+    return verifyTraitByMethod(op, userData, "verify_invariants");
   };
   callbacks.verifyRegionTrait = [](MlirOperation op,
                                    void *userData) -> MlirLogicalResult {
-    return verifyTraitByMethod(op, userData, "verify_region");
+    return verifyTraitByMethod(op, userData, "verify_region_invariants");
   };
 
   // To ensure that the same dynamic trait gets the same TypeID despite how many

diff  --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 251d4831cb331..07d52a5a28d14 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -204,7 +204,9 @@ class Operation(ir.OpView):
     """
 
     @classmethod
-    def __init_subclass__(cls, *, name: str = None, **kwargs):
+    def __init_subclass__(
+        cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
+    ):
         """
         This method is to perform all magic to make a `Operation` subclass works like a dataclass, like:
         - generate the method to emit IRDL operations,
@@ -225,6 +227,14 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
 
         cls._fields = fields
 
+        traits = traits or []
+
+        for base in cls.__bases__:
+            if hasattr(base, "_traits"):
+                traits = base._traits + traits
+
+        cls._traits = traits
+
         # for subclasses without "name" parameter,
         # just treat them as normal classes
         if not name:
@@ -407,6 +417,16 @@ def getter(self, i=i, result=result):
             else:
                 setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
 
+    @classmethod
+    def _attach_traits(cls) -> None:
+        for trait in cls._traits:
+            trait.attach(cls.OPERATION_NAME)
+
+        if hasattr(cls, "verify_invariants") or hasattr(
+            cls, "verify_region_invariants"
+        ):
+            ir.DynamicOpTrait.attach(cls.OPERATION_NAME, cls)
+
     @classmethod
     def _emit_operation(cls) -> None:
         ctx = ConstraintLoweringContext()
@@ -502,6 +522,7 @@ def load(cls) -> None:
         _cext.register_dialect(cls)
 
         for op in cls.operations:
+            op._attach_traits()
             _cext.register_operation(cls)(op)
 
         cls._mlir_module = mlir_module

diff  --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index d24a94bc8baf8..d300f0b0442ae 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -345,6 +345,16 @@ class TypeVarOp(Test.Operation, name="type_var"):
 # CHECK: TEST: testExtDialectWithRegion
 @run
 def testExtDialectWithRegion():
+    class ParentIsIfTrait(DynamicOpTrait):
+        @staticmethod
+        def verify_invariants(op) -> bool:
+            if not isinstance(op.parent.opview, IfOp):
+                op.location.emit_error(
+                    f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
+                )
+                return False
+            return True
+
     class TestRegion(Dialect, name="ext_region"):
         pass
 
@@ -354,10 +364,20 @@ class IfOp(TestRegion.Operation, name="if"):
         then: Region
         else_: Region
 
-    class YieldOp(TestRegion.Operation, name="yield"):
+    class YieldOp(
+        TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
+    ):
         value: Operand[Any]
 
-    class NoTermOp(TestRegion.Operation, name="no_term"):
+        def verify_invariants(self) -> bool:
+            if self.parent.results[0].type != self.value.type:
+                self.location.emit_error(
+                    "result type mismatch between YieldOp and its parent IfOp"
+                )
+                return False
+            return True
+
+    class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
         body: Region
 
     with Context(), Location.unknown():
@@ -383,21 +403,6 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
         # CHECK: }
         print(TestRegion._mlir_module)
 
-        IsTerminatorTrait.attach(YieldOp)
-        NoTerminatorTrait.attach(NoTermOp)
-
-        class ParentIsIfTrait(DynamicOpTrait):
-            @staticmethod
-            def verify(op) -> bool:
-                if not isinstance(op.parent.opview, IfOp):
-                    op.location.emit_error(
-                        f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
-                    )
-                    return False
-                return True
-
-        ParentIsIfTrait.attach(YieldOp)
-
         # CHECK: (self, /, result, cond, *, loc=None, ip=None)
         print(IfOp.__init__.__signature__)
 
@@ -489,3 +494,28 @@ def verify(op) -> bool:
             # CHECK: Verification failed:
             # CHECK: ext_region.yield should be put inside ext_region.if
             print(e)
+
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i1 = IntegerType.get_signless(1)
+            i32 = IntegerType.get_signless(32)
+            cond = arith.constant(i1, 1)
+
+            if_ = IfOp(i1, cond)
+            if_.then.blocks.append()
+            if_.else_.blocks.append()
+
+            with InsertionPoint(if_.then.blocks[0]):
+                v = arith.constant(i32, 2)
+                YieldOp(v)
+
+            with InsertionPoint(if_.else_.blocks[0]):
+                v = arith.constant(i32, 3)
+                YieldOp(v)
+
+        try:
+            module.operation.verify()
+        except Exception as e:
+            # CHECK: Verification failed:
+            # CHECK: result type mismatch
+            print(e)


        

