[Mlir-commits] [mlir] [MLIR][Python] Support op adaptor for Python-defined operations (PR #183528)

llvmlistbot at llvm.org
Thu Feb 26 06:12:55 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

Previously, in #177782, we added support for dialect conversion and generated an `OpAdaptor` subtype for every ODS-defined operation. This PR also generates `OpAdaptor` subtypes for Python-defined operations, so that they can be used in dialect conversion as well.
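Generating one adaptor subclass per operation relies on a standard Python mechanism: keyword arguments passed to `type(name, bases, namespace, **kwargs)` are forwarded to the base class's `__init_subclass__`, which can then attach per-operation properties. The following is a minimal, self-contained sketch of that mechanism, not the code in `mlir.dialects.ext`. `FakeOpAdaptor`, `AdaptorBase`, `OPERAND_NAMES`, and `ATTRIBUTE_NAMES` are hypothetical stand-ins for `ir.OpAdaptor` and for the field metadata the real implementation derives from `_fields`.

```python
class FakeOpAdaptor:
    """Hypothetical stand-in for ir.OpAdaptor: just stores operands and attributes."""

    def __init__(self, operands, attributes):
        self.operands = operands
        self.attributes = attributes


class AdaptorBase(FakeOpAdaptor):
    # Receives the `operation=` keyword passed to type(...) below.
    def __init_subclass__(cls, *, operation, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.OPERATION_NAME = operation.OPERATION_NAME
        # One property per operand; `i=i` binds the current loop index,
        # avoiding Python's late-binding closure pitfall.
        for i, name in enumerate(operation.OPERAND_NAMES):
            setattr(cls, name, property(lambda self, i=i: self.operands[i]))
        # One property per attribute, looked up by name.
        for name in operation.ATTRIBUTE_NAMES:
            setattr(cls, name, property(lambda self, name=name: self.attributes[name]))


class AddOp:
    """Hypothetical operation description: two operands, one attribute."""

    OPERATION_NAME = "myint.add"
    OPERAND_NAMES = ("lhs", "rhs")
    ATTRIBUTE_NAMES = ("overflow",)


# Equivalent to `class Adaptor(AdaptorBase, operation=AddOp): pass`.
AddOp.Adaptor = type("Adaptor", (AdaptorBase,), {}, operation=AddOp)

adaptor = AddOp.Adaptor(["%a", "%b"], {"overflow": "nsw"})
print(adaptor.OPERATION_NAME, adaptor.lhs, adaptor.rhs, adaptor.overflow)
# → myint.add %a %b nsw
```

The `i=i` and `name=name` default arguments are what make each property return its own operand or attribute; without them, every property would see the loop variable's final value.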


---
Full diff: https://github.com/llvm/llvm-project/pull/183528.diff


2 Files Affected:

- (modified) mlir/python/mlir/dialects/ext.py (+49-1) 
- (modified) mlir/test/python/dialects/ext.py (+10) 


``````````diff
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 39aacf32dabb9..d88e25cced8f6 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -41,7 +41,17 @@
 Region = ir.Region
 
 register_dialect = _cext.register_dialect
-register_operation = _cext.register_operation
+
+
+def register_operation(dialect_cls: type) -> Callable[[type], type]:
+    register = _cext.register_operation(dialect_cls)
+
+    def decorator(op_cls: type) -> type:
+        register(op_cls)
+        _cext.register_op_adaptor(op_cls)(op_cls.Adaptor)
+        return op_cls
+
+    return decorator
 
 
 def construct_instance(origin, args):
@@ -307,6 +317,13 @@ def __init_subclass__(
         cls._generate_result_properties(results)
         cls._generate_region_properties(regions)
 
+        cls.Adaptor = type(
+            "Adaptor",
+            (OperationAdator,),
+            dict(),
+            operation=cls,
+        )
+
         dialect_obj.operations.append(cls)
 
     @staticmethod
@@ -507,6 +524,37 @@ def _emit_operation(cls) -> None:
                 )
 
 
+class OperationAdator(ir.OpAdaptor):
+    @classmethod
+    def __init_subclass__(cls, *, operation: type):
+        cls.OPERATION_NAME = operation.OPERATION_NAME
+        cls._operation_cls = operation
+
+        operands, attrs, results, regions = partition_fields(operation._fields)
+
+        for attr in attrs:
+            setattr(
+                cls,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+        for i, operand in enumerate(operands):
+            if operation._ODS_OPERAND_SEGMENTS:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = segmented_accessor(
+                        self.operands,
+                        self.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return normalize_value_range(operand_range, operand.variadicity)
+
+                setattr(cls, operand.name, property(getter))
+            else:
+                setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+
 @dataclass
 class ParamDef:
     name: str
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f9252bad37a39..2921615e75d54 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -91,6 +91,16 @@ class AddOp(Operation, dialect=MyInt, name="add"):
         # CHECK: (self, /, value, *, loc=None, ip=None)
         print(ConstantOp.__init__.__signature__)
 
+        # CHECK: True
+        print(issubclass(AddOp.Adaptor, OpAdaptor))
+        adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+        # CHECK: myint.add
+        print(adaptor1.OPERATION_NAME)
+        # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
+        print(adaptor1.lhs)
+        # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
+        print(adaptor1.rhs)
+
 
 # CHECK: TEST: testExtDialect
 @run

``````````
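The new `register_operation` wraps the extension's decorator factory so that one `register_operation(dialect_cls)(op_cls)` call registers both the operation and its generated `Adaptor`. Below is a minimal sketch of that wrapping pattern, assuming (as the diff suggests) that both `_cext.register_operation` and `_cext.register_op_adaptor` are decorator factories. The dictionaries, the `_register_*` helpers, and `MyInt` are hypothetical stand-ins, not the C extension's real behavior.

```python
from typing import Callable, Dict

# Hypothetical registries standing in for state kept by the C extension.
OP_REGISTRY: Dict[str, type] = {}
ADAPTOR_REGISTRY: Dict[str, type] = {}


def _register_operation(dialect_cls: type) -> Callable[[type], type]:
    """Hypothetical stand-in for _cext.register_operation (ignores dialect_cls)."""

    def register(op_cls: type) -> type:
        OP_REGISTRY[op_cls.OPERATION_NAME] = op_cls
        return op_cls

    return register


def _register_op_adaptor(op_cls: type) -> Callable[[type], type]:
    """Hypothetical stand-in for _cext.register_op_adaptor, keyed by the op."""

    def register(adaptor_cls: type) -> type:
        ADAPTOR_REGISTRY[op_cls.OPERATION_NAME] = adaptor_cls
        return adaptor_cls

    return register


def register_operation(dialect_cls: type) -> Callable[[type], type]:
    # Same shape as the wrapper in the diff: keep the original decorator
    # and additionally register the op's generated Adaptor class.
    register = _register_operation(dialect_cls)

    def decorator(op_cls: type) -> type:
        register(op_cls)
        _register_op_adaptor(op_cls)(op_cls.Adaptor)
        return op_cls

    return decorator


class MyInt:  # hypothetical dialect class
    pass


@register_operation(MyInt)
class AddOp:
    OPERATION_NAME = "myint.add"

    class Adaptor:  # placeholder for the generated adaptor subclass
        pass
```

Because the wrapper keeps the `register_operation(dialect_cls)` call shape, callers of the old alias should not need to change.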

</details>
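When an operation's operands are split into segments (for example, when it has more than one variadic or optional operand), the adaptor cannot index `self.operands` positionally. In that case the diff reads the `operandSegmentSizes` attribute and slices the flat operand list with `segmented_accessor`. The sketch below illustrates that slicing under assumed semantics. `segment`, `normalize`, and the `"single"`/`"optional"`/`"variadic"` labels are hypothetical and only approximate what `segmented_accessor` and `normalize_value_range` do.

```python
from typing import List, Sequence, Union


def segment(values: Sequence[str], segment_sizes: Sequence[int], index: int) -> List[str]:
    """Return the slice of `values` that belongs to segment `index` (hypothetical helper)."""
    start = sum(segment_sizes[:index])  # skip all earlier segments
    return list(values[start:start + segment_sizes[index]])


def normalize(group: List[str], variadicity: str) -> Union[str, None, List[str]]:
    """Hypothetical analogue of normalize_value_range."""
    if variadicity == "single":
        return group[0]  # exactly one value
    if variadicity == "optional":
        return group[0] if group else None  # zero or one value
    return group  # "variadic": keep the whole range


# An op with operands (single `a`, variadic `xs`, optional `b`) where `b` is absent.
operands = ["%a", "%x0", "%x1"]
sizes = [1, 2, 0]  # plays the role of the operandSegmentSizes attribute
print(normalize(segment(operands, sizes, 0), "single"))    # → %a
print(normalize(segment(operands, sizes, 1), "variadic"))  # → ['%x0', '%x1']
print(normalize(segment(operands, sizes, 2), "optional"))  # → None
```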


https://github.com/llvm/llvm-project/pull/183528

