[Mlir-commits] [mlir] [mlir][py] Enable loading only specified dialects during creation. (PR #121421)

Jacques Pienaar llvmlistbot at llvm.org
Thu Jan 2 12:02:26 PST 2025


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/121421

>From cebba8e9e795092fbc4a541395d97e704f94c2f4 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Wed, 1 Jan 2025 01:53:23 +0000
Subject: [PATCH 1/2] [mlir][py] Enable loading only specified dialects during
 creation.

Adds an option, settable either as a global list or as a Context
constructor argument, to control which dialects are loaded during context
creation. This enables either setting a good base set of dialects or
skipping dialect loading in individual cases.
---
 mlir/python/mlir/_mlir_libs/__init__.py | 35 ++++++++++++++++++++++---
 mlir/python/mlir/ir.py                  |  2 +-
 mlir/test/python/ir/dialects.py         | 28 ++++++++++++++++++++
 3 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c5cb22c6dccb8f..dbc458b887d671 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
 # needs.
 
 _dialect_registry = None
+_load_on_create_dialects = None
 
 
 def get_dialect_registry():
@@ -71,6 +72,21 @@ def get_dialect_registry():
     return _dialect_registry
 
 
+def append_load_on_create_dialect(dialect: str):
+    global _load_on_create_dialects
+    if _load_on_create_dialects is None:
+        _load_on_create_dialects = [dialect]
+    else:
+        _load_on_create_dialects.append(dialect)
+
+
+def get_load_on_create_dialects():
+    global _load_on_create_dialects
+    if _load_on_create_dialects is None:
+        _load_on_create_dialects = []
+    return _load_on_create_dialects
+
+
 def _site_initialize():
     import importlib
     import itertools
@@ -132,15 +148,28 @@ def process_initializer_module(module_name):
             break
 
     class Context(ir._BaseContext):
-        def __init__(self, *args, **kwargs):
+        def __init__(self, load_on_create_dialects=None, *args, **kwargs):
             super().__init__(*args, **kwargs)
             self.append_dialect_registry(get_dialect_registry())
             for hook in post_init_hooks:
                 hook(self)
             if not disable_multithreading:
                 self.enable_multithreading(True)
-            if not disable_load_all_available_dialects:
-                self.load_all_available_dialects()
+            if load_on_create_dialects is not None:
+                logger.debug("Loading all dialects from load_on_create_dialects arg %r", _load_on_create_dialects)
+                for dialect in load_on_create_dialects:
+                    # Load dialect.
+                    _ = self.dialects[dialect]
+            else:
+                if disable_load_all_available_dialects:
+                    if _load_on_create_dialects:
+                        logger.debug("Loading all dialects from global load_on_create_dialects %r", _load_on_create_dialects)
+                        for dialect in _load_on_create_dialects:
+                            # Load dialect.
+                            _ = self.dialects[dialect]
+                else:
+                    logger.debug("Loading all available dialects")
+                    self.load_all_available_dialects()
             if init_module:
                 logger.debug(
                     "Registering translations from initializer %r", init_module
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9a6ce462047ad2..6f1c0da8a4e5d6 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,7 +5,7 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster, register_value_caster
-from ._mlir_libs import get_dialect_registry
+from ._mlir_libs import get_dialect_registry, append_load_on_create_dialect, get_load_on_create_dialects
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index d59c6a6bc424e6..3742835208a5d9 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -121,3 +121,31 @@ def testAppendPrefixSearchPath():
         sys.path.append(".")
         _cext.globals.append_dialect_search_prefix("custom_dialect")
         assert _cext.globals._check_dialect_module_loaded("custom")
+
+
+# CHECK-LABEL: TEST: testDialectLoadOnCreate
+ at run
+def testDialectLoadOnCreate():
+    with Context(load_on_create_dialects=[]) as ctx:
+        ctx.emit_error_diagnostics = True
+        ctx.allow_unregistered_dialects = True
+        
+        def callback(d):
+            # CHECK: DIAGNOSTIC
+            # CHECK-SAME: op created with unregistered dialect
+            print(f"DIAGNOSTIC={d.message}")
+            return True
+
+        handler = ctx.attach_diagnostic_handler(callback)
+        loc = Location.unknown(ctx)
+        try:
+          op = Operation.create("arith.addi", loc=loc)
+          ctx.allow_unregistered_dialects = False
+          op.verify()
+        except MLIRError as e:
+          pass
+  
+    with Context(load_on_create_dialects=["func"]) as ctx:
+      loc = Location.unknown(ctx)
+      fn = Operation.create("func.func", loc=loc)
+

>From 22d36f863c77d1b288025f6bc76a2a27036b7f37 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Thu, 2 Jan 2025 20:02:02 +0000
Subject: [PATCH 2/2] Address review comments

---
 mlir/python/mlir/_mlir_libs/__init__.py | 19 ++++++++++++------
 mlir/python/mlir/ir.py                  |  6 +++++-
 mlir/test/python/ir/dialects.py         | 26 ++++++++++++++++---------
 3 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index dbc458b887d671..d021dde05dd871 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -156,16 +156,23 @@ def __init__(self, load_on_create_dialects=None, *args, **kwargs):
             if not disable_multithreading:
                 self.enable_multithreading(True)
             if load_on_create_dialects is not None:
-                logger.debug("Loading all dialects from load_on_create_dialects arg %r", _load_on_create_dialects)
+                logger.debug(
+                    "Loading all dialects from load_on_create_dialects arg %r",
+                    load_on_create_dialects,
+                )
                 for dialect in load_on_create_dialects:
-                    # Load dialect.
+                    # This triggers loading the dialect into the context.
                     _ = self.dialects[dialect]
             else:
                 if disable_load_all_available_dialects:
-                    if _load_on_create_dialects:
-                        logger.debug("Loading all dialects from global load_on_create_dialects %r", _load_on_create_dialects)
-                        for dialect in _load_on_create_dialects:
-                            # Load dialect.
+                    dialects = get_load_on_create_dialects()
+                    if dialects:
+                        logger.debug(
+                            "Loading all dialects from global load_on_create_dialects %r",
+                            dialects,
+                        )
+                        for dialect in dialects:
+                            # This triggers loading the dialect into the context.
                             _ = self.dialects[dialect]
                 else:
                     logger.debug("Loading all available dialects")
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6f1c0da8a4e5d6..6f37266d5bf395 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,7 +5,11 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster, register_value_caster
-from ._mlir_libs import get_dialect_registry, append_load_on_create_dialect, get_load_on_create_dialects
+from ._mlir_libs import (
+    get_dialect_registry,
+    append_load_on_create_dialect,
+    get_load_on_create_dialects,
+)
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index 3742835208a5d9..5a2ed684d298b3 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -129,7 +129,7 @@ def testDialectLoadOnCreate():
     with Context(load_on_create_dialects=[]) as ctx:
         ctx.emit_error_diagnostics = True
         ctx.allow_unregistered_dialects = True
-        
+
         def callback(d):
             # CHECK: DIAGNOSTIC
             # CHECK-SAME: op created with unregistered dialect
@@ -139,13 +139,21 @@ def callback(d):
         handler = ctx.attach_diagnostic_handler(callback)
         loc = Location.unknown(ctx)
         try:
-          op = Operation.create("arith.addi", loc=loc)
-          ctx.allow_unregistered_dialects = False
-          op.verify()
+            op = Operation.create("arith.addi", loc=loc)
+            ctx.allow_unregistered_dialects = False
+            op.verify()
         except MLIRError as e:
-          pass
-  
-    with Context(load_on_create_dialects=["func"]) as ctx:
-      loc = Location.unknown(ctx)
-      fn = Operation.create("func.func", loc=loc)
+            pass
 
+    with Context(load_on_create_dialects=["func"]) as ctx:
+        loc = Location.unknown(ctx)
+        fn = Operation.create("func.func", loc=loc)
+
+    # TODO: This may require an update if a site wide policy is set.
+    # CHECK: Load on create: []
+    print(f"Load on create: {get_load_on_create_dialects()}")
+    append_load_on_create_dialect("func")
+    # CHECK: Load on create:
+    # CHECK-SAME: func
+    print(f"Load on create: {get_load_on_create_dialects()}")
+    print(get_load_on_create_dialects())



More information about the Mlir-commits mailing list