[Mlir-commits] [mlir] [MLIR][Python] Allow passing dialect as a class keyword argument (PR #182465)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 20 04:11:34 PST 2026


https://github.com/PragmaTwice updated https://github.com/llvm/llvm-project/pull/182465

>From 4947372264e3f080c5afdaba0def3415cbb6d378 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 20 Feb 2026 17:46:36 +0800
Subject: [PATCH 1/3] [MLIR][Python] Allow passing dialect as a class keyword
 argument

---
 mlir/python/mlir/dialects/ext.py | 20 +++++++++++++++++---
 mlir/test/python/dialects/ext.py |  2 +-
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index ac1b3065336e8..b7dd1af8afb96 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -216,7 +216,12 @@ def __init__(*args, **kwargs):
 
     @classmethod
     def __init_subclass__(
-        cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
+        cls,
+        *,
+        name: str | None = None,
+        traits: list[type] | None = None,
+        dialect: type | None = None,
+        **kwargs,
     ):
         """
         This method is to perform all magic to make a `Operation` subclass works like a dataclass, like:
@@ -251,9 +256,18 @@ def __init_subclass__(
         if not name:
             return
 
+        if dialect:
+            if hasattr(cls, "_dialect_name") or hasattr(cls, "_dialect_obj"):
+                raise RuntimeError(
+                    f"This operation has already been attached to dialect '{cls._dialect_name}'."
+                )
+            cls._dialect_obj = dialect
+            cls._dialect_name = dialect.DIALECT_NAMESPACE
+
         if not hasattr(cls, "_dialect_name") or not hasattr(cls, "_dialect_obj"):
             raise RuntimeError(
-                "Operation subclasses must inherit from a Dialect's Operation subclass"
+                "Operation subclasses must either inherit from a Dialect's Operation subclass "
+                "or provide the dialect as a class keyword argument."
             )
 
         op_name = name
@@ -278,7 +292,7 @@ def _variadicity_to_segment(variadicity: Variadicity) -> int:
     @staticmethod
     def _generate_segments(
         operands_or_results: List[Union[OperandDef, ResultDef]],
-    ) -> List[int]:
+    ) -> List[int] | None:
         if any(i.variadicity != Variadicity.single for i in operands_or_results):
             return [
                 Operation._variadicity_to_segment(i.variadicity)
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index d300f0b0442ae..196af91e511ec 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -24,7 +24,7 @@ class ConstantOp(MyInt.Operation, name="constant"):
         value: IntegerAttr
         cst: Result[i32]
 
-    class AddOp(MyInt.Operation, name="add"):
+    class AddOp(Operation, dialect=MyInt, name="add"):
         lhs: Operand[i32]
         rhs: Operand[i32]
         res: Result[i32]

>From f60e1aa719e5280fdfc1737bb1e841b845ace71f Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 20 Feb 2026 18:10:20 +0800
Subject: [PATCH 2/3] Document both ways of defining a Python-side Operation

---
 mlir/python/mlir/dialects/ext.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index b7dd1af8afb96..ea7f70a42ea94 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -204,8 +204,14 @@ class Operation(ir.OpView):
     """
     Base class of Python-defined operation.
 
-    NOTE: Usually you don't need to use it directly.
-    Use `Dialect` and `.Operation` of `Dialect` subclasses instead.
+    An operation can be defined with this class in either of two ways:
+    ```python
+    class MyOp(MyDialect.Operation, name="my_op"):
+        ...
+
+    class MyOp(Operation, dialect=MyDialect, name="my_op"):
+        ...
+    ```
     """
 
     def __init__(*args, **kwargs):

>From 25cf9ceba093b3debacdb236ee090f0aee2c26bb Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Fri, 20 Feb 2026 20:11:20 +0800
Subject: [PATCH 3/3] Drop _dialect_name and build Dialect.Operation via the dialect keyword

---
 mlir/python/mlir/dialects/ext.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index ea7f70a42ea94..79095658944e5 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -257,20 +257,19 @@ def __init_subclass__(
 
         cls._traits = traits
 
-        # for subclasses without "name" parameter,
-        # just treat them as normal classes
-        if not name:
-            return
-
         if dialect:
-            if hasattr(cls, "_dialect_name") or hasattr(cls, "_dialect_obj"):
+            if hasattr(cls, "_dialect_obj"):
                 raise RuntimeError(
-                    f"This operation has already been attached to dialect '{cls._dialect_name}'."
+                    f"This operation has already been attached to dialect '{cls._dialect_obj.DIALECT_NAMESPACE}'."
                 )
             cls._dialect_obj = dialect
-            cls._dialect_name = dialect.DIALECT_NAMESPACE
 
-        if not hasattr(cls, "_dialect_name") or not hasattr(cls, "_dialect_obj"):
+        # for subclasses without "name" parameter,
+        # just treat them as normal classes
+        if not name:
+            return
+
+        if not hasattr(cls, "_dialect_obj"):
             raise RuntimeError(
                 "Operation subclasses must either inherit from a Dialect's Operation subclass "
                 "or provide the dialect as a class keyword argument."
@@ -278,7 +277,7 @@ def __init_subclass__(
 
         op_name = name
         cls._op_name = op_name
-        dialect_name = cls._dialect_name
+        dialect_name = cls._dialect_obj.DIALECT_NAMESPACE
         dialect_obj = cls._dialect_obj
 
         cls._generate_class_attributes(dialect_name, op_name, fields)
@@ -519,7 +518,8 @@ def __init_subclass__(cls, name: str, **kwargs):
         cls.Operation = type(
             "Operation",
             (Operation,),
-            {"_dialect_obj": cls, "_dialect_name": name},
+            dict(),
+            dialect=cls,
         )
 
     @classmethod



More information about the Mlir-commits mailing list