[Mlir-commits] [mlir] 972aa59 - [MLIR][Python] Make traits declarative in python-defined operations (#180748)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 11 04:40:03 PST 2026
Author: Twice
Date: 2026-02-11T20:39:58+08:00
New Revision: 972aa597de3a39f457207546af93ab3bfd1027f0
URL: https://github.com/llvm/llvm-project/commit/972aa597de3a39f457207546af93ab3bfd1027f0
DIFF: https://github.com/llvm/llvm-project/commit/972aa597de3a39f457207546af93ab3bfd1027f0.diff
LOG: [MLIR][Python] Make traits declarative in python-defined operations (#180748)
This adds support for two new syntaxes in Python-defined dialects.
First, traits can now be declared as class parameters, e.g.
```python
class ParentIsIfTrait(DynamicOpTrait): #define a python-side trait
@staticmethod
def verify_invariants(op) -> bool:
if not isinstance(op.parent.opview, IfOp):
op.location.emit_error(
f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
)
return False
return True
class YieldOp( # attach two traits: IsTerminatorTrait, ParentIsIfTrait
TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
):
...
```
Second, users can directly define
`verify_invariants`/`verify_region_invariants` methods on the operation
class to add custom verification logic. This is implemented via
traits.
```python
class YieldOp(TestRegion.Operation, name="yield", ...):
value: Operand[Any]
def verify_invariants(self) -> bool: # define a method directly
if self.parent.results[0].type != self.value.type:
self.location.emit_error(
"result type mismatch between YieldOp and its parent IfOp"
)
return False
return True
```
Previously we used `verify`/`verify_region` as the method names (introduced
in the preceding PR #179705), but this PR renames them to
`verify_invariants`/`verify_region_invariants` because the newly added
`verify` method conflicts with the existing `ir.OpView.verify`
method:
- `verify_invariants` only attaches **additional** verification
logic, whereas `OpView.verify` constructs an OperationVerifier and performs
full verification of an operation, so the two have different
semantics. We should not shadow the `OpView.verify` method with a
semantically different `verify` method.
- Users would confuse the two `verify` methods, since
they have different meanings.
- If users did not define a `verify` method in their Python-defined
operation, `DynamicOpTrait.attach(opname, MyOpCls)` would still perform the
attachment (because `hasattr(MyOpCls, "verify")` returns `True` via the
inherited `OpView.verify`) and then segfault (because `OpView.verify`
cannot be attached as a trait verifier).
---------
Co-authored-by: Rolf Morel <rolfmorel at gmail.com>
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/dialects/ext.py
mlir/test/python/dialects/ext.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9edc981220f8b..bffe5da45f6dc 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2543,10 +2543,11 @@ static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait,
bool PyDynamicOpTrait::attach(const nb::object &opName,
const nb::object &target,
PyMlirContext &context) {
- if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region"))
+ if (!nb::hasattr(target, "verify_invariants") &&
+ !nb::hasattr(target, "verify_region_invariants"))
throw nb::type_error(
- "the target object must have at least one of 'verify' or "
- "'verify_region' methods");
+ "the target object must have at least one of 'verify_invariants' or "
+ "'verify_region_invariants' methods");
MlirDynamicOpTraitCallbacks callbacks;
callbacks.construct = [](void *userData) {
@@ -2558,11 +2559,11 @@ bool PyDynamicOpTrait::attach(const nb::object &opName,
callbacks.verifyTrait = [](MlirOperation op,
void *userData) -> MlirLogicalResult {
- return verifyTraitByMethod(op, userData, "verify");
+ return verifyTraitByMethod(op, userData, "verify_invariants");
};
callbacks.verifyRegionTrait = [](MlirOperation op,
void *userData) -> MlirLogicalResult {
- return verifyTraitByMethod(op, userData, "verify_region");
+ return verifyTraitByMethod(op, userData, "verify_region_invariants");
};
// To ensure that the same dynamic trait gets the same TypeID despite how many
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 251d4831cb331..07d52a5a28d14 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -204,7 +204,9 @@ class Operation(ir.OpView):
"""
@classmethod
- def __init_subclass__(cls, *, name: str = None, **kwargs):
+ def __init_subclass__(
+ cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
+ ):
"""
This method is to perform all magic to make a `Operation` subclass works like a dataclass, like:
- generate the method to emit IRDL operations,
@@ -225,6 +227,14 @@ def __init_subclass__(cls, *, name: str = None, **kwargs):
cls._fields = fields
+ traits = traits or []
+
+ for base in cls.__bases__:
+ if hasattr(base, "_traits"):
+ traits = base._traits + traits
+
+ cls._traits = traits
+
# for subclasses without "name" parameter,
# just treat them as normal classes
if not name:
@@ -407,6 +417,16 @@ def getter(self, i=i, result=result):
else:
setattr(cls, result.name, property(lambda self, i=i: self.results[i]))
+ @classmethod
+ def _attach_traits(cls) -> None:
+ for trait in cls._traits:
+ trait.attach(cls.OPERATION_NAME)
+
+ if hasattr(cls, "verify_invariants") or hasattr(
+ cls, "verify_region_invariants"
+ ):
+ ir.DynamicOpTrait.attach(cls.OPERATION_NAME, cls)
+
@classmethod
def _emit_operation(cls) -> None:
ctx = ConstraintLoweringContext()
@@ -502,6 +522,7 @@ def load(cls) -> None:
_cext.register_dialect(cls)
for op in cls.operations:
+ op._attach_traits()
_cext.register_operation(cls)(op)
cls._mlir_module = mlir_module
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index d24a94bc8baf8..d300f0b0442ae 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -345,6 +345,16 @@ class TypeVarOp(Test.Operation, name="type_var"):
# CHECK: TEST: testExtDialectWithRegion
@run
def testExtDialectWithRegion():
+ class ParentIsIfTrait(DynamicOpTrait):
+ @staticmethod
+ def verify_invariants(op) -> bool:
+ if not isinstance(op.parent.opview, IfOp):
+ op.location.emit_error(
+ f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
+ )
+ return False
+ return True
+
class TestRegion(Dialect, name="ext_region"):
pass
@@ -354,10 +364,20 @@ class IfOp(TestRegion.Operation, name="if"):
then: Region
else_: Region
- class YieldOp(TestRegion.Operation, name="yield"):
+ class YieldOp(
+ TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
+ ):
value: Operand[Any]
- class NoTermOp(TestRegion.Operation, name="no_term"):
+ def verify_invariants(self) -> bool:
+ if self.parent.results[0].type != self.value.type:
+ self.location.emit_error(
+ "result type mismatch between YieldOp and its parent IfOp"
+ )
+ return False
+ return True
+
+ class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
body: Region
with Context(), Location.unknown():
@@ -383,21 +403,6 @@ class NoTermOp(TestRegion.Operation, name="no_term"):
# CHECK: }
print(TestRegion._mlir_module)
- IsTerminatorTrait.attach(YieldOp)
- NoTerminatorTrait.attach(NoTermOp)
-
- class ParentIsIfTrait(DynamicOpTrait):
- @staticmethod
- def verify(op) -> bool:
- if not isinstance(op.parent.opview, IfOp):
- op.location.emit_error(
- f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
- )
- return False
- return True
-
- ParentIsIfTrait.attach(YieldOp)
-
# CHECK: (self, /, result, cond, *, loc=None, ip=None)
print(IfOp.__init__.__signature__)
@@ -489,3 +494,28 @@ def verify(op) -> bool:
# CHECK: Verification failed:
# CHECK: ext_region.yield should be put inside ext_region.if
print(e)
+
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i1 = IntegerType.get_signless(1)
+ i32 = IntegerType.get_signless(32)
+ cond = arith.constant(i1, 1)
+
+ if_ = IfOp(i1, cond)
+ if_.then.blocks.append()
+ if_.else_.blocks.append()
+
+ with InsertionPoint(if_.then.blocks[0]):
+ v = arith.constant(i32, 2)
+ YieldOp(v)
+
+ with InsertionPoint(if_.else_.blocks[0]):
+ v = arith.constant(i32, 3)
+ YieldOp(v)
+
+ try:
+ module.operation.verify()
+ except Exception as e:
+ # CHECK: Verification failed:
+ # CHECK: result type mismatch
+ print(e)
More information about the Mlir-commits
mailing list