[Mlir-commits] [mlir] 90ba731 - [MLIR][Python] Add `replace` parameter to `Dialect.load` (#184604)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 18:21:20 PST 2026


Author: Twice
Date: 2026-03-05T10:21:15+08:00
New Revision: 90ba731c496b9ccf7fd3954756f5d096add3d0b5

URL: https://github.com/llvm/llvm-project/commit/90ba731c496b9ccf7fd3954756f5d096add3d0b5
DIFF: https://github.com/llvm/llvm-project/commit/90ba731c496b9ccf7fd3954756f5d096add3d0b5.diff

LOG: [MLIR][Python] Add `replace` parameter to `Dialect.load` (#184604)

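This PR adds a `replace` keyword parameter to `Dialect.load(...)`
in `mlir.dialects.ext`. It allows already-registered
operations/types/attributes to be replaced when the dialect is loaded.
The same `replace` flag is also added to `register_operation`, and the
parameters of `Dialect.load` (`register`, `reload`, `replace`) are now
keyword-only.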

Added: 
    

Modified: 
    mlir/python/mlir/dialects/ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index c2bd33b4203fe..5bcc595220f69 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -44,12 +44,14 @@
 register_dialect = _cext.register_dialect
 
 
-def register_operation(dialect_cls: type) -> Callable[[type], type]:
-    register = _cext.register_operation(dialect_cls)
+def register_operation(
+    dialect_cls: type, *, replace: bool = False
+) -> Callable[[type], type]:
+    register = _cext.register_operation(dialect_cls, replace=replace)
 
     def decorator(op_cls: type) -> type:
         register(op_cls)
-        _cext.register_op_adaptor(op_cls)(op_cls.Adaptor)
+        _cext.register_op_adaptor(op_cls, replace=replace)(op_cls.Adaptor)
         return op_cls
 
     return decorator
@@ -809,7 +811,13 @@ def _emit_module(cls) -> ir.Module:
         return m
 
     @classmethod
-    def load(cls, register=True, reload=False) -> None:
+    def load(
+        cls,
+        *,
+        register: bool = True,
+        reload: bool = False,
+        replace: bool = False,
+    ) -> None:
         if hasattr(cls, "_mlir_module") and not reload:
             return
 
@@ -825,15 +833,15 @@ def load(cls, register=True, reload=False) -> None:
 
         for type_ in cls.types:
             typeid = ir.DynamicType.lookup_typeid(type_.type_name)
-            _cext.register_type_caster(typeid)(type_)
+            _cext.register_type_caster(typeid, replace=replace)(type_)
 
         for attr in cls.attributes:
             typeid = ir.DynamicAttr.lookup_typeid(attr.attr_name)
-            _cext.register_type_caster(typeid)(attr)
+            _cext.register_type_caster(typeid, replace=replace)(attr)
 
         if register:
             register_dialect(cls)
 
-            register_dialect_operation = register_operation(cls)
+            register_dialect_operation = register_operation(cls, replace=replace)
             for op in cls.operations:
                 register_dialect_operation(op)


        


More information about the Mlir-commits mailing list