[Mlir-commits] [mlir] [MLIR][Python] Refine the behavior of Python-defined dialect reloading (PR #186128)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 12 07:39:30 PDT 2026


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/186128

This includes several changes:
- `Dialect.load(reload=False)` now raises an error if the dialect was already loaded in a different context, instead of letting the program abort later.
- `Dialect.load(reload=True)` implies `replace=True` when registering the dialect, its operations (and their adaptors), and its type/attribute casters.
- `PyGlobals::registerDialectImpl` now takes a `replace` parameter.
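- `register_dialect` and `register_operation` are no longer exposed in `mlir.dialects.ext`.
- `Dialect.load` no longer accepts the `register` and `replace` parameters; it always registers the dialect and its operations.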

This should solve the registration problem that @rolfmorel encountered while writing transform test cases.


>From ec5a6159ced2bdf05964f0f46ce463364d316273 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Thu, 12 Mar 2026 22:31:52 +0800
Subject: [PATCH] [MLIR][Python] Refine the behavior of Python-defined dialect
 reloading

---
 mlir/include/mlir/Bindings/Python/Globals.h   |  4 +-
 mlir/lib/Bindings/Python/Globals.cpp          |  4 +-
 mlir/lib/Bindings/Python/IRCore.cpp           |  3 +-
 mlir/python/mlir/dialects/ext.py              | 39 ++++++-------------
 .../python/dialects/transform_op_interface.py |  8 +---
 ...ansform_pattern_descriptor_op_interface.py |  4 +-
 6 files changed, 20 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 8f7085f6024f5..8a7f30fd218dc 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -78,10 +78,10 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
                            bool replace = false);
 
   /// Adds a concrete implementation dialect class.
-  /// Raises an exception if the mapping already exists.
+  /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerDialectImpl(const std::string &dialectNamespace,
-                           nanobind::object pyClass);
+                           nanobind::object pyClass, bool replace = false);
 
   /// Adds a concrete implementation operation class.
   /// Raises an exception if the mapping already exists and replace == false.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 411b8a6705f1c..82195acb9f4fb 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -130,10 +130,10 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
-                                    nb::object pyClass) {
+                                    nb::object pyClass, bool replace) {
   nb::ft_lock_guard lock(mutex);
   nb::object &found = dialectClassMap[dialectNamespace];
-  if (found) {
+  if (found && !replace) {
     throw std::runtime_error(nanobind::detail::join(
         "Dialect namespace '", dialectNamespace, "' is already registered."));
   }
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b8637c57a3f48..7341e7218c962 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2791,7 +2791,8 @@ void populateRoot(nb::module_ &m) {
           },
           "dialect_namespace"_a)
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
-           "dialect_namespace"_a, "dialect_class"_a,
+           "dialect_namespace"_a, "dialect_class"_a, nb::kw_only(),
+           "replace"_a = false,
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
            "operation_name"_a, "operation_class"_a, nb::kw_only(),
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 5bcc595220f69..45eb218af3448 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -33,29 +33,12 @@
     "Region",
     "Type",
     "Attribute",
-    "register_dialect",
-    "register_operation",
 ]
 
 Operand = ir.Value
 Result = ir.OpResult
 Region = ir.Region
 
-register_dialect = _cext.register_dialect
-
-
-def register_operation(
-    dialect_cls: type, *, replace: bool = False
-) -> Callable[[type], type]:
-    register = _cext.register_operation(dialect_cls, replace=replace)
-
-    def decorator(op_cls: type) -> type:
-        register(op_cls)
-        _cext.register_op_adaptor(op_cls, replace=replace)(op_cls.Adaptor)
-        return op_cls
-
-    return decorator
-
 
 def construct_instance(origin, args):
     # `origin.get` is to construct an instance of MLIR type or attribute.
@@ -814,11 +797,14 @@ def _emit_module(cls) -> ir.Module:
     def load(
         cls,
         *,
-        register: bool = True,
         reload: bool = False,
-        replace: bool = False,
     ) -> None:
         if hasattr(cls, "_mlir_module") and not reload:
+            if cls._mlir_module.context is not ir.Context.current:
+                raise RuntimeError(
+                    "This dialect was loaded in a different context. "
+                    "Please set reload=True to reload the dialect in the current context."
+                )
             return
 
         cls._mlir_module = cls._emit_module()
@@ -831,17 +817,16 @@ def load(
         for op in cls.operations:
             op._attach_traits()
 
+        _cext.globals._register_dialect_impl(cls.DIALECT_NAMESPACE, cls, replace=reload)
+
         for type_ in cls.types:
             typeid = ir.DynamicType.lookup_typeid(type_.type_name)
-            _cext.register_type_caster(typeid, replace=replace)(type_)
+            _cext.register_type_caster(typeid, replace=reload)(type_)
 
         for attr in cls.attributes:
             typeid = ir.DynamicAttr.lookup_typeid(attr.attr_name)
-            _cext.register_type_caster(typeid, replace=replace)(attr)
-
-        if register:
-            register_dialect(cls)
+            _cext.register_type_caster(typeid, replace=reload)(attr)
 
-            register_dialect_operation = register_operation(cls, replace=replace)
-            for op in cls.operations:
-                register_dialect_operation(op)
+        for op in cls.operations:
+            _cext.register_operation(cls, replace=reload)(op)
+            _cext.register_op_adaptor(op, replace=reload)(op.Adaptor)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f58e0be13befd..a6e2c6da45322 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -16,7 +16,6 @@
 )
 
 
-@ext.register_dialect
 class MyTransform(ext.Dialect, name="my_transform"):
     pass
 
@@ -26,7 +25,7 @@ def run(emit_schedule):
     with ir.Context() as ctx, ir.Location.unknown():
         payload = emit_payload()
 
-        MyTransform.load(register=False, reload=True)
+        MyTransform.load(reload=True)
 
         GetNamedAttributeOp.attach_interface_impls(ctx)
         PrintParamOp.attach_interface_impls(ctx)
@@ -86,7 +85,6 @@ def get_effects(op: ir.Operation, effects):
 
 # Demonstration of a TransformOpInterface-implementing op that gets named attributes
 # from target ops and produces them as param handles.
-@ext.register_operation(MyTransform)
 class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
     target: ext.Operand[transform.AnyOpType]
     attr_name: ir.StringAttr
@@ -120,7 +118,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
             return False
 
 
-@ext.register_operation(MyTransform)
 class PrintParamOp(MyTransform.Operation, name="print_param"):
     target: ext.Operand[transform.AnyParamType]
     name: ir.StringAttr
@@ -150,7 +147,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
 
 
 # Syntax for an op with one op handle operand and one op handle result.
-@ext.register_operation(MyTransform)
 class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
     target: ext.Operand[transform.AnyOpType]
     res: ext.Result[transform.AnyOpType[()]]
@@ -273,7 +269,6 @@ def get_effects(op: ir.Operation, effects):
     return schedule
 
 
-@ext.register_operation(MyTransform)
 class OpValParamInParamOpValOut(
     MyTransform.Operation, name="op_val_param_in_param_op_val_out"
 ):
@@ -378,7 +373,6 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
     return schedule
 
 
-@ext.register_operation(MyTransform)
 class OpsParamsInValuesParamOut(
     MyTransform.Operation, name="ops_params_in_values_param_out"
 ):
diff --git a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
index 470c679179b03..9cd73331cfdea 100644
--- a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
+++ b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
@@ -7,7 +7,6 @@
 from mlir.dialects.transform import AnyOpType, structured
 
 
-@ext.register_dialect
 class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
     pass
 
@@ -17,7 +16,7 @@ def run(emit_schedule):
     with ir.Context(), ir.Location.unknown():
         payload = emit_payload()
 
-        MyPatternDescriptors.load(register=False, reload=True)
+        MyPatternDescriptors.load(reload=True)
 
         # NB: Pattern descriptor ops have their interfaces attached
         #     in their respective test functions.
@@ -58,7 +57,6 @@ def schedule_boilerplate():
             yield schedule, named_sequence
 
 
-@ext.register_operation(MyPatternDescriptors)
 class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
     @classmethod
     def attach_interface_impls(cls, ctx=None):



More information about the Mlir-commits mailing list