[Mlir-commits] [mlir] 0447766 - [MLIR][Python] Refine the behavior of Python-defined dialect reloading (#186128)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 14 19:25:30 PDT 2026


Author: Twice
Date: 2026-03-15T10:25:24+08:00
New Revision: 044776691a67d769def35bdc4f4241a4b5fb23c6

URL: https://github.com/llvm/llvm-project/commit/044776691a67d769def35bdc4f4241a4b5fb23c6
DIFF: https://github.com/llvm/llvm-project/commit/044776691a67d769def35bdc4f4241a4b5fb23c6.diff

LOG: [MLIR][Python] Refine the behavior of Python-defined dialect reloading (#186128)

This includes several changes:
- `Dialect.load(reload=False)` now fails if the dialect was already
loaded in a different context. This prevents the program from aborting later on.
- `Dialect.load(reload=True)` implies `replace=True` in
dialect/operation registering.
- `PyGlobals::registerDialectImpl` now has a parameter `replace`.
- `register_dialect` and `register_operation` are no longer exposed in
`mlir.dialects.ext`.

This should solve the registration problem found by @rolfmorel while
writing transform test cases.

Added: 
    

Modified: 
    mlir/include/mlir/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/Globals.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/dialects/ext.py
    mlir/test/python/dialects/transform_op_interface.py
    mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 8f7085f6024f5..8a7f30fd218dc 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -78,10 +78,10 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
                            bool replace = false);
 
   /// Adds a concrete implementation dialect class.
-  /// Raises an exception if the mapping already exists.
+  /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerDialectImpl(const std::string &dialectNamespace,
-                           nanobind::object pyClass);
+                           nanobind::object pyClass, bool replace = false);
 
   /// Adds a concrete implementation operation class.
   /// Raises an exception if the mapping already exists and replace == false.

diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 411b8a6705f1c..82195acb9f4fb 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -130,10 +130,10 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
-                                    nb::object pyClass) {
+                                    nb::object pyClass, bool replace) {
   nb::ft_lock_guard lock(mutex);
   nb::object &found = dialectClassMap[dialectNamespace];
-  if (found) {
+  if (found && !replace) {
     throw std::runtime_error(nanobind::detail::join(
         "Dialect namespace '", dialectNamespace, "' is already registered."));
   }

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7eb59d61b0d57..3d07e364b5c98 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2860,7 +2860,8 @@ void populateRoot(nb::module_ &m) {
           },
           "dialect_namespace"_a)
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
-           "dialect_namespace"_a, "dialect_class"_a,
+           "dialect_namespace"_a, "dialect_class"_a, nb::kw_only(),
+           "replace"_a = false,
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
            "operation_name"_a, "operation_class"_a, nb::kw_only(),

diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 15651a1c4e858..867da6ee96637 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -34,29 +34,12 @@
     "Region",
     "Type",
     "Attribute",
-    "register_dialect",
-    "register_operation",
 ]
 
 Operand = ir.Value
 Result = ir.OpResult
 Region = ir.Region
 
-register_dialect = _cext.register_dialect
-
-
-def register_operation(
-    dialect_cls: type, *, replace: bool = False
-) -> Callable[[type], type]:
-    register = _cext.register_operation(dialect_cls, replace=replace)
-
-    def decorator(op_cls: type) -> type:
-        register(op_cls)
-        _cext.register_op_adaptor(op_cls, replace=replace)(op_cls.Adaptor)
-        return op_cls
-
-    return decorator
-
 
 def construct_instance(origin, args):
     # `origin.get` is to construct an instance of MLIR type or attribute.
@@ -816,11 +799,14 @@ def _emit_module(cls) -> ir.Module:
     def load(
         cls,
         *,
-        register: bool = True,
         reload: bool = False,
-        replace: bool = False,
     ) -> None:
         if hasattr(cls, "_mlir_module") and not reload:
+            if cls._mlir_module.context is not ir.Context.current:
+                raise RuntimeError(
+                    "This dialect was loaded in a different context. "
+                    "Please set reload=True to reload the dialect in the current context."
+                )
             return
 
         cls._mlir_module = cls._emit_module()
@@ -833,17 +819,16 @@ def load(
         for op in cls.operations:
             op._attach_traits()
 
+        _cext.globals._register_dialect_impl(cls.DIALECT_NAMESPACE, cls, replace=reload)
+
         for type_ in cls.types:
             typeid = ir.DynamicType.lookup_typeid(type_.type_name)
-            _cext.register_type_caster(typeid, replace=replace)(type_)
+            _cext.register_type_caster(typeid, replace=reload)(type_)
 
         for attr in cls.attributes:
             typeid = ir.DynamicAttr.lookup_typeid(attr.attr_name)
-            _cext.register_type_caster(typeid, replace=replace)(attr)
-
-        if register:
-            register_dialect(cls)
+            _cext.register_type_caster(typeid, replace=reload)(attr)
 
-            register_dialect_operation = register_operation(cls, replace=replace)
-            for op in cls.operations:
-                register_dialect_operation(op)
+        for op in cls.operations:
+            _cext.register_operation(cls, replace=reload)(op)
+            _cext.register_op_adaptor(op, replace=reload)(op.Adaptor)

diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f58e0be13befd..a6e2c6da45322 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -16,7 +16,6 @@
 )
 
 
-@ext.register_dialect
 class MyTransform(ext.Dialect, name="my_transform"):
     pass
 
@@ -26,7 +25,7 @@ def run(emit_schedule):
     with ir.Context() as ctx, ir.Location.unknown():
         payload = emit_payload()
 
-        MyTransform.load(register=False, reload=True)
+        MyTransform.load(reload=True)
 
         GetNamedAttributeOp.attach_interface_impls(ctx)
         PrintParamOp.attach_interface_impls(ctx)
@@ -86,7 +85,6 @@ def get_effects(op: ir.Operation, effects):
 
 # Demonstration of a TransformOpInterface-implementing op that gets named attributes
 # from target ops and produces them as param handles.
-@ext.register_operation(MyTransform)
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
@@ -120,7 +118,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
             return False
 
 
-@ext.register_operation(MyTransform)
 class PrintParamOp(MyTransform.Operation, name="print_param"):
     target: ext.Operand[transform.AnyParamType]
     name: ir.StringAttr
@@ -150,7 +147,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
 
 
 # Syntax for an op with one op handle operand and one op handle result.
-@ext.register_operation(MyTransform)
 class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
     target: ext.Operand[transform.AnyOpType]
     res: ext.Result[transform.AnyOpType[()]]
@@ -273,7 +269,6 @@ def get_effects(op: ir.Operation, effects):
     return schedule
 
 
-@ext.register_operation(MyTransform)
 class OpValParamInParamOpValOut(
     MyTransform.Operation, name="op_val_param_in_param_op_val_out"
 ):
@@ -378,7 +373,6 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
     return schedule
 
 
-@ext.register_operation(MyTransform)
 class OpsParamsInValuesParamOut(
     MyTransform.Operation, name="ops_params_in_values_param_out"
 ):

diff --git a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
index 470c679179b03..9cd73331cfdea 100644
--- a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
+++ b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
@@ -7,7 +7,6 @@
 from mlir.dialects.transform import AnyOpType, structured
 
 
-@ext.register_dialect
 class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
     pass
 
@@ -17,7 +16,7 @@ def run(emit_schedule):
     with ir.Context(), ir.Location.unknown():
         payload = emit_payload()
 
-        MyPatternDescriptors.load(register=False, reload=True)
+        MyPatternDescriptors.load(reload=True)
 
         # NB: Pattern descriptor ops have their interfaces attached
         #     in their respective test functions.
@@ -58,7 +57,6 @@ def schedule_boilerplate():
             yield schedule, named_sequence
 
 
-@ext.register_operation(MyPatternDescriptors)
 class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
     @classmethod
     def attach_interface_impls(cls, ctx=None):


        


More information about the Mlir-commits mailing list