[Mlir-commits] [mlir] [mlir][python] enable registering dialects with the default `Context` (PR #72488)
Maksim Levental
llvmlistbot at llvm.org
Mon Nov 27 13:27:33 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/72488
>From 829c627e825c21f3a888bd9b974f828ffa18e81f Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 16 Nov 2023 01:40:07 -0600
Subject: [PATCH 1/2] [mlir][python] allow people to register dialects with the
default context
---
mlir/python/mlir/_mlir_libs/__init__.py | 16 +++++++++++++---
mlir/python/mlir/dialects/python_test.py | 4 ++--
mlir/python/mlir/ir.py | 1 +
mlir/test/python/dialects/python_test.py | 16 ++--------------
mlir/test/python/lib/PythonTestModule.cpp | 9 +++++++++
5 files changed, 27 insertions(+), 19 deletions(-)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..0761579da15fb94 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,6 +56,17 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.
+
+
+def get_registry():
+ if not hasattr(get_registry, "__registry"):
+ from ._mlir import ir
+
+ get_registry.__registry = ir.DialectRegistry()
+
+ return get_registry.__registry
+
+
def _site_initialize():
import importlib
import itertools
@@ -63,7 +74,6 @@ def _site_initialize():
from ._mlir import ir
logger = logging.getLogger(__name__)
- registry = ir.DialectRegistry()
post_init_hooks = []
disable_multithreading = False
@@ -84,7 +94,7 @@ def process_initializer_module(module_name):
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(registry)
+ m.register_dialects(get_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
@@ -110,7 +120,7 @@ def process_initializer_module(module_name):
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.append_dialect_registry(registry)
+ self.append_dialect_registry(get_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..b5baa80bc767fb3 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -11,7 +11,7 @@
)
-def register_python_test_dialect(context, load=True):
+def register_python_test_dialect(registry):
from .._mlir_libs import _mlirPythonTest
- _mlirPythonTest.register_python_test_dialect(context, load)
+ _mlirPythonTest.register_dialect(registry)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 18526ab8c3c02dc..82403c0b8d5fed1 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
+from ._mlir_libs import get_registry
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index f313a400b73c0a5..562190c6fcdf5d5 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -6,6 +6,8 @@
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
+test.register_python_test_dialect(get_registry())
+
def run(f):
print("\nTEST:", f.__name__)
@@ -17,7 +19,6 @@ def run(f):
@run
def testAttributes():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
#
# Check op construction with attributes.
#
@@ -138,7 +139,6 @@ def testAttributes():
@run
def attrBuilder():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
# CHECK: python_test.attributes_op
op = test.AttributesOp(
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
@@ -215,7 +215,6 @@ def attrBuilder():
@run
def inferReturnTypes():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op = test.InferResultsOp()
@@ -260,7 +259,6 @@ def inferReturnTypes():
@run
def resultTypesDefinedByTraits():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
inferred = test.InferResultsOp()
@@ -295,8 +293,6 @@ def resultTypesDefinedByTraits():
@run
def testOptionalOperandOp():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
-
module = Module.create()
with InsertionPoint(module.body):
op1 = test.OptionalOperandOp()
@@ -312,7 +308,6 @@ def testOptionalOperandOp():
@run
def testCustomAttribute():
with Context() as ctx:
- test.register_python_test_dialect(ctx)
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
@@ -350,7 +345,6 @@ def testCustomAttribute():
@run
def testCustomType():
with Context() as ctx:
- test.register_python_test_dialect(ctx)
a = test.TestType.get()
# CHECK: !python_test.test_type
print(a)
@@ -397,8 +391,6 @@ def testCustomType():
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
-
i8 = IntegerType.get_signless(8)
class Tensor(test.TestTensorValue):
@@ -436,7 +428,6 @@ def __str__(self):
@run
def inferReturnTypeComponents():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
@@ -488,8 +479,6 @@ def inferReturnTypeComponents():
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
- test.register_python_test_dialect(ctx)
-
a = test.TestType.get()
assert a.typeid is not None
@@ -542,7 +531,6 @@ def type_caster(pytype):
@run
def testInferTypeOpInterface():
with Context() as ctx, Location.unknown(ctx):
- test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
i64 = IntegerType.get_signless(64)
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index aff414894cb825a..f81b851f8759bf7 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
},
py::arg("context"), py::arg("load") = true);
+ m.def(
+ "register_dialect",
+ [](MlirDialectRegistry registry) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+ },
+ py::arg("registry"));
+
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute)
.def_classmethod(
>From 2eae3382039246916e410156ecfccd10fb5b1109 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 27 Nov 2023 15:27:20 -0600
Subject: [PATCH 2/2] incorporate review comments: rename get_registry to get_dialect_registry and keep the registry in a module-level global
---
mlir/python/mlir/_mlir_libs/__init__.py | 16 ++++++++++------
mlir/python/mlir/ir.py | 2 +-
mlir/test/python/dialects/python_test.py | 2 +-
3 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 0761579da15fb94..32f46d24cc7392b 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -57,14 +57,18 @@ def get_include_dirs() -> Sequence[str]:
# This facility allows downstreams to customize Context creation to their
# needs.
+_dialect_registry = None
-def get_registry():
- if not hasattr(get_registry, "__registry"):
+
+def get_dialect_registry():
+ global _dialect_registry
+
+ if _dialect_registry is None:
from ._mlir import ir
- get_registry.__registry = ir.DialectRegistry()
+ _dialect_registry = ir.DialectRegistry()
- return get_registry.__registry
+ return _dialect_registry
def _site_initialize():
@@ -94,7 +98,7 @@ def process_initializer_module(module_name):
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(get_registry())
+ m.register_dialects(get_dialect_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
@@ -120,7 +124,7 @@ def process_initializer_module(module_name):
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.append_dialect_registry(get_registry())
+ self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 82403c0b8d5fed1..6d21da3b4179fdf 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,7 +5,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
-from ._mlir_libs import get_registry
+from ._mlir_libs import get_dialect_registry
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 562190c6fcdf5d5..88761c9d08fe07c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -6,7 +6,7 @@
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
-test.register_python_test_dialect(get_registry())
+test.register_python_test_dialect(get_dialect_registry())
def run(f):
More information about the Mlir-commits mailing list