[Mlir-commits] [mlir] 0447766 - [MLIR][Python] Refine the behavior of Python-defined dialect reloading (#186128)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 14 19:25:30 PDT 2026
Author: Twice
Date: 2026-03-15T10:25:24+08:00
New Revision: 044776691a67d769def35bdc4f4241a4b5fb23c6
URL: https://github.com/llvm/llvm-project/commit/044776691a67d769def35bdc4f4241a4b5fb23c6
DIFF: https://github.com/llvm/llvm-project/commit/044776691a67d769def35bdc4f4241a4b5fb23c6.diff
LOG: [MLIR][Python] Refine the behavior of Python-defined dialect reloading (#186128)
This includes several changes:
- `Dialect.load(reload=False)` now raises an error if the dialect was already
loaded in a different context, instead of letting the program abort later.
- `Dialect.load(reload=True)` implies `replace=True` when registering the
dialect, its operations, and its type/attribute casters.
- `PyGlobals::registerDialectImpl` now takes a `replace` parameter.
- `register_dialect` and `register_operation` are no longer exposed in
`mlir.dialects.ext`.
This should solve the registration problem that @rolfmorel encountered while
writing transform test cases.
Added:
Modified:
mlir/include/mlir/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/Globals.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/dialects/ext.py
mlir/test/python/dialects/transform_op_interface.py
mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h
index 8f7085f6024f5..8a7f30fd218dc 100644
--- a/mlir/include/mlir/Bindings/Python/Globals.h
+++ b/mlir/include/mlir/Bindings/Python/Globals.h
@@ -78,10 +78,10 @@ class MLIR_PYTHON_API_EXPORTED PyGlobals {
bool replace = false);
/// Adds a concrete implementation dialect class.
- /// Raises an exception if the mapping already exists.
+ /// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
- nanobind::object pyClass);
+ nanobind::object pyClass, bool replace = false);
/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp
index 411b8a6705f1c..82195acb9f4fb 100644
--- a/mlir/lib/Bindings/Python/Globals.cpp
+++ b/mlir/lib/Bindings/Python/Globals.cpp
@@ -130,10 +130,10 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
- nb::object pyClass) {
+ nb::object pyClass, bool replace) {
nb::ft_lock_guard lock(mutex);
nb::object &found = dialectClassMap[dialectNamespace];
- if (found) {
+ if (found && !replace) {
throw std::runtime_error(nanobind::detail::join(
"Dialect namespace '", dialectNamespace, "' is already registered."));
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7eb59d61b0d57..3d07e364b5c98 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2860,7 +2860,8 @@ void populateRoot(nb::module_ &m) {
},
"dialect_namespace"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
- "dialect_namespace"_a, "dialect_class"_a,
+ "dialect_namespace"_a, "dialect_class"_a, nb::kw_only(),
+ "replace"_a = false,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, nb::kw_only(),
diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py
index 15651a1c4e858..867da6ee96637 100644
--- a/mlir/python/mlir/dialects/ext.py
+++ b/mlir/python/mlir/dialects/ext.py
@@ -34,29 +34,12 @@
"Region",
"Type",
"Attribute",
- "register_dialect",
- "register_operation",
]
Operand = ir.Value
Result = ir.OpResult
Region = ir.Region
-register_dialect = _cext.register_dialect
-
-
-def register_operation(
- dialect_cls: type, *, replace: bool = False
-) -> Callable[[type], type]:
- register = _cext.register_operation(dialect_cls, replace=replace)
-
- def decorator(op_cls: type) -> type:
- register(op_cls)
- _cext.register_op_adaptor(op_cls, replace=replace)(op_cls.Adaptor)
- return op_cls
-
- return decorator
-
def construct_instance(origin, args):
# `origin.get` is to construct an instance of MLIR type or attribute.
@@ -816,11 +799,14 @@ def _emit_module(cls) -> ir.Module:
def load(
cls,
*,
- register: bool = True,
reload: bool = False,
- replace: bool = False,
) -> None:
if hasattr(cls, "_mlir_module") and not reload:
+ if cls._mlir_module.context is not ir.Context.current:
+ raise RuntimeError(
+ "This dialect was loaded in a
diff erent context. "
+ "Please set reload=True to reload the dialect in the current context."
+ )
return
cls._mlir_module = cls._emit_module()
@@ -833,17 +819,16 @@ def load(
for op in cls.operations:
op._attach_traits()
+ _cext.globals._register_dialect_impl(cls.DIALECT_NAMESPACE, cls, replace=reload)
+
for type_ in cls.types:
typeid = ir.DynamicType.lookup_typeid(type_.type_name)
- _cext.register_type_caster(typeid, replace=replace)(type_)
+ _cext.register_type_caster(typeid, replace=reload)(type_)
for attr in cls.attributes:
typeid = ir.DynamicAttr.lookup_typeid(attr.attr_name)
- _cext.register_type_caster(typeid, replace=replace)(attr)
-
- if register:
- register_dialect(cls)
+ _cext.register_type_caster(typeid, replace=reload)(attr)
- register_dialect_operation = register_operation(cls, replace=replace)
- for op in cls.operations:
- register_dialect_operation(op)
+ for op in cls.operations:
+ _cext.register_operation(cls, replace=reload)(op)
+ _cext.register_op_adaptor(op, replace=reload)(op.Adaptor)
diff --git a/mlir/test/python/dialects/transform_op_interface.py b/mlir/test/python/dialects/transform_op_interface.py
index f58e0be13befd..a6e2c6da45322 100644
--- a/mlir/test/python/dialects/transform_op_interface.py
+++ b/mlir/test/python/dialects/transform_op_interface.py
@@ -16,7 +16,6 @@
)
- at ext.register_dialect
class MyTransform(ext.Dialect, name="my_transform"):
pass
@@ -26,7 +25,7 @@ def run(emit_schedule):
with ir.Context() as ctx, ir.Location.unknown():
payload = emit_payload()
- MyTransform.load(register=False, reload=True)
+ MyTransform.load(reload=True)
GetNamedAttributeOp.attach_interface_impls(ctx)
PrintParamOp.attach_interface_impls(ctx)
@@ -86,7 +85,6 @@ def get_effects(op: ir.Operation, effects):
# Demonstration of a TransformOpInterface-implementing op that gets named attributes
# from target ops and produces them as param handles.
- at ext.register_operation(MyTransform)
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
@@ -120,7 +118,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
return False
- at ext.register_operation(MyTransform)
class PrintParamOp(MyTransform.Operation, name="print_param"):
target: ext.Operand[transform.AnyParamType]
name: ir.StringAttr
@@ -150,7 +147,6 @@ def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
# Syntax for an op with one op handle operand and one op handle result.
- at ext.register_operation(MyTransform)
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
target: ext.Operand[transform.AnyOpType]
res: ext.Result[transform.AnyOpType[()]]
@@ -273,7 +269,6 @@ def get_effects(op: ir.Operation, effects):
return schedule
- at ext.register_operation(MyTransform)
class OpValParamInParamOpValOut(
MyTransform.Operation, name="op_val_param_in_param_op_val_out"
):
@@ -378,7 +373,6 @@ def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
return schedule
- at ext.register_operation(MyTransform)
class OpsParamsInValuesParamOut(
MyTransform.Operation, name="ops_params_in_values_param_out"
):
diff --git a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
index 470c679179b03..9cd73331cfdea 100644
--- a/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
+++ b/mlir/test/python/dialects/transform_pattern_descriptor_op_interface.py
@@ -7,7 +7,6 @@
from mlir.dialects.transform import AnyOpType, structured
- at ext.register_dialect
class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
pass
@@ -17,7 +16,7 @@ def run(emit_schedule):
with ir.Context(), ir.Location.unknown():
payload = emit_payload()
- MyPatternDescriptors.load(register=False, reload=True)
+ MyPatternDescriptors.load(reload=True)
# NB: Pattern descriptor ops have their interfaces attached
# in their respective test functions.
@@ -58,7 +57,6 @@ def schedule_boilerplate():
yield schedule, named_sequence
- at ext.register_operation(MyPatternDescriptors)
class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
@classmethod
def attach_interface_impls(cls, ctx=None):
More information about the Mlir-commits
mailing list