[Mlir-commits] [mlir] [MLIR][Python] Support op adaptor for Python-defined operations (PR #183528)
llvmlistbot at llvm.org
Thu Feb 26 06:12:55 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Twice (PragmaTwice)
<details>
<summary>Changes</summary>
Previously, #177782 added support for dialect conversion and generated an `OpAdaptor` subtype for every ODS-defined operation. This PR also generates `OpAdaptor` subtypes for Python-defined operations (those declared with `mlir.dialects.ext`), so that they can be used in dialect conversion as well.
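Generating one adaptor subtype per operation relies on a standard Python feature: keyword arguments passed to the three-argument form of `type()` are forwarded to the base class's `__init_subclass__`, exactly as keywords in a `class` statement would be. The sketch below shows only that mechanism; `BaseAdaptor` and `MyOp` are hypothetical stand-ins, not the real `ir.OpAdaptor` or `mlir.dialects.ext` classes.

```python
class BaseAdaptor:
    """Hypothetical stand-in for an OpAdaptor base class."""

    def __init__(self, operands, attributes=None):
        self.operands = list(operands)
        self.attributes = dict(attributes or {})

    def __init_subclass__(cls, *, operation, **kwargs):
        # Keyword arguments given to type(...) (or to a class statement) are
        # forwarded here, so each subclass records which op it adapts.
        super().__init_subclass__(**kwargs)
        cls.OPERATION_NAME = operation.OPERATION_NAME
        cls._operation_cls = operation


class MyOp:
    """Hypothetical Python-defined operation."""

    OPERATION_NAME = "myint.add"


# Roughly equivalent to: class Adaptor(BaseAdaptor, operation=MyOp): pass
MyOp.Adaptor = type("Adaptor", (BaseAdaptor,), {}, operation=MyOp)

adaptor = MyOp.Adaptor(["%0", "%1"])
print(adaptor.OPERATION_NAME)  # myint.add
print(issubclass(MyOp.Adaptor, BaseAdaptor))  # True
```

The same keyword-forwarding path is what `cls.Adaptor = type("Adaptor", (OperationAdator,), dict(), operation=cls)` in the diff relies on.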
---
Full diff: https://github.com/llvm/llvm-project/pull/183528.diff
2 Files Affected:
- (modified) mlir/python/mlir/dialects/ext.py (+49-1)
- (modified) mlir/test/python/dialects/ext.py (+10)
``````````diff
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 39aacf32dabb9..d88e25cced8f6 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -41,7 +41,17 @@
 Region = ir.Region
 register_dialect = _cext.register_dialect
-register_operation = _cext.register_operation
+
+
+def register_operation(dialect_cls: type) -> Callable[[type], type]:
+    register = _cext.register_operation(dialect_cls)
+
+    def decorator(op_cls: type) -> type:
+        register(op_cls)
+        _cext.register_op_adaptor(op_cls)(op_cls.Adaptor)
+        return op_cls
+
+    return decorator


 def construct_instance(origin, args):
@@ -307,6 +317,13 @@ def __init_subclass__(
         cls._generate_result_properties(results)
         cls._generate_region_properties(regions)

+        cls.Adaptor = type(
+            "Adaptor",
+            (OperationAdator,),
+            dict(),
+            operation=cls,
+        )
+
         dialect_obj.operations.append(cls)

     @staticmethod
@@ -507,6 +524,37 @@ def _emit_operation(cls) -> None:
         )


+class OperationAdator(ir.OpAdaptor):
+    @classmethod
+    def __init_subclass__(cls, *, operation: type):
+        cls.OPERATION_NAME = operation.OPERATION_NAME
+        cls._operation_cls = operation
+
+        operands, attrs, results, regions = partition_fields(operation._fields)
+
+        for attr in attrs:
+            setattr(
+                cls,
+                attr.name,
+                property(lambda self, name=attr.name: self.attributes[name]),
+            )
+
+        for i, operand in enumerate(operands):
+            if operation._ODS_OPERAND_SEGMENTS:
+
+                def getter(self, i=i, operand=operand):
+                    operand_range = segmented_accessor(
+                        self.operands,
+                        self.attributes["operandSegmentSizes"],
+                        i,
+                    )
+                    return normalize_value_range(operand_range, operand.variadicity)
+
+                setattr(cls, operand.name, property(getter))
+            else:
+                setattr(cls, operand.name, property(lambda self, i=i: self.operands[i]))
+
+
 @dataclass
 class ParamDef:
     name: str
diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py
index f9252bad37a39..2921615e75d54 100644
--- a/mlir/test/python/dialects/ext.py
+++ b/mlir/test/python/dialects/ext.py
@@ -91,6 +91,16 @@ class AddOp(Operation, dialect=MyInt, name="add"):
     # CHECK: (self, /, value, *, loc=None, ip=None)
     print(ConstantOp.__init__.__signature__)

+    # CHECK: True
+    print(issubclass(AddOp.Adaptor, OpAdaptor))
+    adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
+    # CHECK: myint.add
+    print(adaptor1.OPERATION_NAME)
+    # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
+    print(adaptor1.lhs)
+    # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
+    print(adaptor1.rhs)
+

 # CHECK: TEST: testExtDialect
 @run
``````````
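The new `register_operation` in `ext.py` keeps the decorator-factory shape of `_cext.register_operation`, but its inner decorator also registers `op_cls.Adaptor` through `_cext.register_op_adaptor`. The sketch below reproduces only that wrapping pattern. The two `native_*` functions and their dictionaries are hypothetical stand-ins for the native hooks, whose real behavior is not shown in the diff.

```python
from typing import Callable

# Hypothetical registries standing in for the native (C++) side.
REGISTERED_OPS: dict = {}
REGISTERED_ADAPTORS: dict = {}


def native_register_operation(dialect_cls: type) -> Callable[[type], type]:
    # Hypothetical stand-in for _cext.register_operation(dialect_cls).
    def register(op_cls: type) -> type:
        REGISTERED_OPS[op_cls.OPERATION_NAME] = (dialect_cls, op_cls)
        return op_cls

    return register


def native_register_op_adaptor(op_cls: type) -> Callable[[type], type]:
    # Hypothetical stand-in for _cext.register_op_adaptor(op_cls).
    def register(adaptor_cls: type) -> type:
        REGISTERED_ADAPTORS[op_cls.OPERATION_NAME] = adaptor_cls
        return adaptor_cls

    return register


def register_operation(dialect_cls: type) -> Callable[[type], type]:
    # Same shape as the wrapper in the diff: bind the dialect first, then
    # return a decorator that registers both the op and its Adaptor.
    register = native_register_operation(dialect_cls)

    def decorator(op_cls: type) -> type:
        register(op_cls)
        native_register_op_adaptor(op_cls)(op_cls.Adaptor)
        return op_cls

    return decorator


class MyInt:  # hypothetical dialect class
    pass


@register_operation(MyInt)
class AddOp:  # hypothetical op with a pre-built Adaptor
    OPERATION_NAME = "myint.add"

    class Adaptor:
        pass


print(REGISTERED_OPS["myint.add"][1] is AddOp)  # True
print(REGISTERED_ADAPTORS["myint.add"] is AddOp.Adaptor)  # True
```

Because the inner decorator returns `op_cls` unchanged, the wrapper adds the adaptor registration without altering the class object itself.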
</details>
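`OperationAdator.__init_subclass__` creates one property per attribute and operand inside a loop, and each getter binds the loop variable through a default argument (`name=attr.name`, `i=i`, `operand=operand`). A plain closure would look the variable up only when the property is read, after the loop has finished, so every property would see the last value. The sketch below contrasts the two; `Base`, `Late`, `Early`, and `build_properties` are hypothetical illustrations.

```python
class Base:
    def __init__(self, operands):
        self.operands = operands


class Late(Base):
    pass


class Early(Base):
    pass


def build_properties(cls, names, early_binding):
    for i, name in enumerate(names):
        if early_binding:
            # The default value is evaluated now, once per iteration.
            def getter(self, i=i):
                return self.operands[i]
        else:
            # Closure: `i` is read when the property is accessed, by which
            # time the loop has left i == len(names) - 1.
            def getter(self):
                return self.operands[i]
        setattr(cls, name, property(getter))


build_properties(Late, ["lhs", "rhs"], early_binding=False)
build_properties(Early, ["lhs", "rhs"], early_binding=True)

late, early = Late(["%a", "%b"]), Early(["%a", "%b"])
print(late.lhs, late.rhs)  # %b %b
print(early.lhs, early.rhs)  # %a %b
```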
https://github.com/llvm/llvm-project/pull/183528
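When `operation._ODS_OPERAND_SEGMENTS` is set, the generated getter slices the flat `self.operands` list using the `operandSegmentSizes` attribute, whose entries give the operand count of each ODS operand group in declaration order. The sketch below shows that slicing arithmetic in plain Python. `segment` is a hypothetical helper, not MLIR's `segmented_accessor`, and the follow-up `normalize_value_range` step is not reproduced.

```python
from typing import List, Sequence


def segment(operands: Sequence, segment_sizes: Sequence[int], index: int) -> List:
    """Return the operands of ODS operand group `index` (hypothetical helper)."""
    if sum(segment_sizes) != len(operands):
        raise ValueError("segment sizes do not cover the operand list")
    # Group `index` starts after all operands of the earlier groups.
    start = sum(segment_sizes[:index])
    return list(operands[start : start + segment_sizes[index]])


# e.g. an op with operands (single lhs, variadic extras, single rhs)
operands = ["%a", "%b", "%c", "%d"]
sizes = [1, 2, 1]
print([segment(operands, sizes, i) for i in range(len(sizes))])
# [['%a'], ['%b', '%c'], ['%d']]
```

An absent optional operand simply has a size of 0, so its slice is empty.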