[Mlir-commits] [mlir] c703b46 - [mlir][py] Enable loading only specified dialects during creation. (#121421)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 2 14:40:18 PST 2025
Author: Jacques Pienaar
Date: 2025-01-02T14:40:15-08:00
New Revision: c703b4645c79e889fd6a0f3f64f01f957d981aa4
URL: https://github.com/llvm/llvm-project/commit/c703b4645c79e889fd6a0f3f64f01f957d981aa4
DIFF: https://github.com/llvm/llvm-project/commit/c703b4645c79e889fd6a0f3f64f01f957d981aa4.diff
LOG: [mlir][py] Enable loading only specified dialects during creation. (#121421)
Gives the option, both as a global list and as an argument, to control which
dialects are loaded during context creation. This enables either setting
a good base set or skipping loading in individual cases.
Added:
Modified:
mlir/python/mlir/_mlir_libs/__init__.py
mlir/python/mlir/ir.py
mlir/test/python/ir/dialects.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c5cb22c6dccb8f..d021dde05dd871 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
# needs.
_dialect_registry = None
+_load_on_create_dialects = None  # Lazily-created list[str] of dialect names.
def get_dialect_registry():
@@ -71,6 +72,21 @@ def get_dialect_registry():
return _dialect_registry
+def append_load_on_create_dialect(dialect: str):
+    """Append `dialect` to the global list of dialects loaded on creation."""
+    # Going through the getter lazily creates the list on first use.
+    get_load_on_create_dialects().append(dialect)
+
+
+def get_load_on_create_dialects():
+    """Return the global list of dialect names to load on Context creation."""
+    global _load_on_create_dialects
+    if _load_on_create_dialects is None:
+        # Created on first use; later calls return this same mutable list.
+        _load_on_create_dialects = []
+    return _load_on_create_dialects
+
+
def _site_initialize():
import importlib
import itertools
@@ -132,15 +148,45 @@ def process_initializer_module(module_name):
break
class Context(ir._BaseContext):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, load_on_create_dialects=None, *args, **kwargs):
+        """Initialize the context and eagerly load dialects.
+
+        Args:
+          load_on_create_dialects: Optional iterable of dialect names to load
+            into this context. When None, the global list returned by
+            get_load_on_create_dialects() is loaded if
+            disable_load_all_available_dialects is set; otherwise all
+            available dialects are loaded.
+        """
+        # NOTE(review): this parameter precedes *args, so a positional first
+        # argument binds here instead of being forwarded to _BaseContext;
+        # pass it by keyword.
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
self.enable_multithreading(True)
-        if not disable_load_all_available_dialects:
-            self.load_all_available_dialects()
+        if load_on_create_dialects is not None:
+            logger.debug(
+                "Loading all dialects from load_on_create_dialects arg %r",
+                load_on_create_dialects,
+            )
+            to_load = load_on_create_dialects
+        elif disable_load_all_available_dialects:
+            to_load = get_load_on_create_dialects()
+            if to_load:
+                logger.debug(
+                    "Loading all dialects from global load_on_create_dialects %r",
+                    to_load,
+                )
+        else:
+            logger.debug("Loading all available dialects")
+            self.load_all_available_dialects()
+            to_load = ()
+        for dialect in to_load:
+            # Indexing self.dialects triggers loading the dialect into the context.
+            _ = self.dialects[dialect]
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9a6ce462047ad2..6f37266d5bf395 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,7 +5,11 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
-from ._mlir_libs import get_dialect_registry
+from ._mlir_libs import (
+ get_dialect_registry,
+ append_load_on_create_dialect,
+ get_load_on_create_dialects,
+)
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index d59c6a6bc424e6..5a2ed684d298b3 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -121,3 +121,42 @@ def testAppendPrefixSearchPath():
sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")
+
+
+# CHECK-LABEL: TEST: testDialectLoadOnCreate
+@run
+def testDialectLoadOnCreate():
+    # An empty list means no dialects are loaded when the context is created.
+    with Context(load_on_create_dialects=[]) as ctx:
+        ctx.emit_error_diagnostics = True
+        ctx.allow_unregistered_dialects = True
+
+        def callback(d):
+            # CHECK: DIAGNOSTIC
+            # CHECK-SAME: op created with unregistered dialect
+            print(f"DIAGNOSTIC={d.message}")
+            return True
+
+        handler = ctx.attach_diagnostic_handler(callback)
+        loc = Location.unknown(ctx)
+        try:
+            op = Operation.create("arith.addi", loc=loc)
+            ctx.allow_unregistered_dialects = False
+            op.verify()
+        except MLIRError:
+            # The diagnostic itself is matched through `callback` above.
+            pass
+
+    with Context(load_on_create_dialects=["func"]) as ctx:
+        loc = Location.unknown(ctx)
+        fn = Operation.create("func.func", loc=loc)
+
+    # TODO: This may require an update if a site-wide policy is set.
+    # CHECK: Load on create: []
+    print(f"Load on create: {get_load_on_create_dialects()}")
+    append_load_on_create_dialect("func")
+    # CHECK: Load on create:
+    # CHECK-SAME: func
+    print(f"Load on create: {get_load_on_create_dialects()}")
+    # Restore the global list so code running after this test is unaffected.
+    get_load_on_create_dialects().pop()
More information about the Mlir-commits mailing list