[Mlir-commits] [mlir] 8542514 - [MLIR][Python] Allow passing dialect as a class keyword argument (#182465)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 22 02:53:02 PST 2026
Author: Twice
Date: 2026-02-22T18:52:57+08:00
New Revision: 8542514e5cc39730cf40b939a08edc68374a94e5
URL: https://github.com/llvm/llvm-project/commit/8542514e5cc39730cf40b939a08edc68374a94e5
DIFF: https://github.com/llvm/llvm-project/commit/8542514e5cc39730cf40b939a08edc68374a94e5.diff
LOG: [MLIR][Python] Allow passing dialect as a class keyword argument (#182465)
Previously, new ops were defined using the pattern `class MyOp(MyInt.Operation)`.
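Now we’ve added a new pattern, `class MyOp(Operation, dialect=MyInt)`, which allows more flexible composition: an operation base class can be written without committing to any dialect and then attached to one in a subclass. For example: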
```python
class BinOpBase(Operation):  # it can be used in any dialect!
    # No `name=` keyword, so this is treated as a plain base class rather than
    # a concrete op; it only declares the shared result/operand fields.
    res: Result[Any]
    lhs: Operand[Any]
    rhs: Operand[Any]

class MyInt(Dialect, name="myint"):
    # A dialect named "myint"; ops are attached to it separately.
    pass

class AddOp(BinOpBase, dialect=MyInt, name="add"):
    # Reuses BinOpBase's fields; `dialect=MyInt` attaches the op to that
    # dialect and `name="add"` gives it its op name.
    ...
```
Added:
Modified:
mlir/python/mlir/dialects/ext.py
mlir/test/python/dialects/ext.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index ac1b3065336e8..79095658944e5 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -204,8 +204,14 @@ class Operation(ir.OpView):
"""
Base class of Python-defined operation.
- NOTE: Usually you don't need to use it directly.
- Use `Dialect` and `.Operation` of `Dialect` subclasses instead.
+ The following example shows two ways to define operations via this class:
+ ```python
+ class MyOp(MyDialect.Operation, name=..):
+ ...
+
+ class MyOp(Operation, dialect=MyDialect, name=..):
+ ...
+ ```
"""
def __init__(*args, **kwargs):
@@ -216,7 +222,12 @@ def __init__(*args, **kwargs):
@classmethod
def __init_subclass__(
- cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
+ cls,
+ *,
+ name: str | None = None,
+ traits: list[type] | None = None,
+ dialect: type | None = None,
+ **kwargs,
):
"""
This method is to perform all magic to make a `Operation` subclass works like a dataclass, like:
@@ -246,19 +257,27 @@ def __init_subclass__(
cls._traits = traits
+ if dialect:
+ if hasattr(cls, "_dialect_obj"):
+ raise RuntimeError(
+ f"This operation has already been attached to dialect '{cls._dialect_obj.DIALECT_NAMESPACE}'."
+ )
+ cls._dialect_obj = dialect
+
# for subclasses without "name" parameter,
# just treat them as normal classes
if not name:
return
- if not hasattr(cls, "_dialect_name") or not hasattr(cls, "_dialect_obj"):
+ if not hasattr(cls, "_dialect_obj"):
raise RuntimeError(
- "Operation subclasses must inherit from a Dialect's Operation subclass"
+ "Operation subclasses must either inherit from a Dialect's Operation subclass "
+ "or provide the dialect as a class keyword argument."
)
op_name = name
cls._op_name = op_name
- dialect_name = cls._dialect_name
+ dialect_name = cls._dialect_obj.DIALECT_NAMESPACE
dialect_obj = cls._dialect_obj
cls._generate_class_attributes(dialect_name, op_name, fields)
@@ -278,7 +297,7 @@ def _variadicity_to_segment(variadicity: Variadicity) -> int:
@staticmethod
def _generate_segments(
operands_or_results: List[Union[OperandDef, ResultDef]],
- ) -> List[int]:
+ ) -> List[int] | None:
if any(i.variadicity != Variadicity.single for i in operands_or_results):
return [
Operation._variadicity_to_segment(i.variadicity)
@@ -499,7 +518,8 @@ def __init_subclass__(cls, name: str, **kwargs):
cls.Operation = type(
"Operation",
(Operation,),
- {"_dialect_obj": cls, "_dialect_name": name},
+ dict(),
+ dialect=cls,
)
@classmethod
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index d300f0b0442ae..196af91e511ec 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -24,7 +24,7 @@ class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
cst: Result[i32]
- class AddOp(MyInt.Operation, name="add"):
+ class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
res: Result[i32]
More information about the Mlir-commits
mailing list