[Mlir-commits] [mlir] cb7b038 - [mlir][python] Simplify python extension loading.

Stella Laurenzo llvmlistbot at llvm.org
Thu Sep 2 17:48:01 PDT 2021


Author: Stella Laurenzo
Date: 2021-09-03T00:43:28Z
New Revision: cb7b03819ae667a87e49fa2546498dcf6248d99c

URL: https://github.com/llvm/llvm-project/commit/cb7b03819ae667a87e49fa2546498dcf6248d99c
DIFF: https://github.com/llvm/llvm-project/commit/cb7b03819ae667a87e49fa2546498dcf6248d99c.diff

LOG: [mlir][python] Simplify python extension loading.

* Now that packaging has stabilized, this patch removes the old mechanisms for loading extensions in favor of direct imports.
* Removes _cext_loader.py, _dlloader.py as unnecessary.
* Fixes the path where the CAPI DLL is written on Windows: it is now placed in the same output directory as the extension modules. This lets the default (path-of-least-resistance) loading behavior work without further workarounds (see: https://bugs.python.org/issue36085).
* With this patch, `ninja check-mlir` on Windows with Python bindings works for me, modulo some failures that are actually due to a couple of pre-existing Windows bugs. I think this is the first time the Windows Python bindings have worked upstream.
* Downstream changes needed:
  * Downstream projects that use the now-removed `load_extension`, `_reexport_cext`, etc. should replace those calls with normal import statements, as done in this patch.

Reviewed By: jdd, aartbik

Differential Revision: https://reviews.llvm.org/D108489

Added: 
    

Modified: 
    mlir/cmake/modules/AddMLIRPython.cmake
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/python/CMakeLists.txt
    mlir/python/mlir/_mlir_libs/__init__.py
    mlir/python/mlir/all_passes_registration/__init__.py
    mlir/python/mlir/conversions/__init__.py
    mlir/python/mlir/dialects/_linalg_ops_ext.py
    mlir/python/mlir/dialects/_ods_common.py
    mlir/python/mlir/dialects/async_dialect/passes/__init__.py
    mlir/python/mlir/dialects/gpu/passes/__init__.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/passes/__init__.py
    mlir/python/mlir/dialects/sparse_tensor.py
    mlir/python/mlir/execution_engine.py
    mlir/python/mlir/ir.py
    mlir/python/mlir/passmanager.py
    mlir/python/mlir/transforms/__init__.py

Removed: 
    mlir/python/mlir/_cext_loader.py
    mlir/python/mlir/_dlloader.py


################################################################################
diff  --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index b1e2f0b3f5559..d67820152ef94 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -371,6 +371,9 @@ function(add_mlir_python_common_capi_library name)
   set_target_properties(${name} PROPERTIES
     LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
     BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+    # Needed for windows (and don't hurt others).
+    RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
+    ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}"
   )
   mlir_python_setup_extension_rpath(${name}
     RELATIVE_INSTALL_ROOT "${ARG_RELATIVE_INSTALL_ROOT}"

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 08ce06da8783e..9f853eb92df18 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -12,6 +12,8 @@
 
 #include <vector>
 
+#include "mlir-c/Bindings/Python/Interop.h"
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
@@ -25,6 +27,9 @@ PyGlobals *PyGlobals::instance = nullptr;
 PyGlobals::PyGlobals() {
   assert(!instance && "PyGlobals already constructed");
   instance = this;
+  // The default search path include {mlir.}dialects, where {mlir.} is the
+  // package prefix configured at compile time.
+  dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
 }
 
 PyGlobals::~PyGlobals() { instance = nullptr; }

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9f66aa9c25596..506d8ead221dd 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -20,8 +20,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   ADD_TO_PARENT MLIRPythonSources
   SOURCES
-    _cext_loader.py
-    _dlloader.py
     _mlir_libs/__init__.py
     ir.py
     passmanager.py

diff  --git a/mlir/python/mlir/_cext_loader.py b/mlir/python/mlir/_cext_loader.py
deleted file mode 100644
index 5f2de7f006960..0000000000000
--- a/mlir/python/mlir/_cext_loader.py
+++ /dev/null
@@ -1,57 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""Common module for looking up and manipulating C-Extensions."""
-
-# The normal layout is to have a nested _mlir_libs package that contains
-# all native libraries and extensions. If that exists, use it, but also fallback
-# to old behavior where extensions were at the top level as loose libraries.
-# TODO: Remove the fallback once downstreams adapt.
-try:
-  from ._mlir_libs import *
-  # TODO: Remove these aliases once everything migrates
-  _preload_dependency = preload_dependency
-  _load_extension = load_extension
-except ModuleNotFoundError:
-  # Assume that we are in-tree.
-  # The _dlloader takes care of platform specific setup before we try to
-  # load a shared library.
-  # TODO: Remove _dlloader once all consolidated on the _mlir_libs approach.
-  from ._dlloader import preload_dependency
-
-  def load_extension(name):
-    import importlib
-    return importlib.import_module(name)  # i.e. '_mlir' at the top level
-
-preload_dependency("MLIRPythonCAPI")
-
-# Expose the corresponding C-Extension module with a well-known name at this
-# top-level module. This allows relative imports like the following to
-# function:
-#   from .._cext_loader import _cext
-# This reduces coupling, allowing embedding of the python sources into another
-# project that can just vary based on this top-level loader module.
-_cext = load_extension("_mlir")
-
-
-def _reexport_cext(cext_module_name, target_module_name):
-  """Re-exports a named sub-module of the C-Extension into another module.
-
-  Typically:
-    from ._cext_loader import _reexport_cext
-    _reexport_cext("ir", __name__)
-    del _reexport_cext
-  """
-  import sys
-  target_module = sys.modules[target_module_name]
-  submodule_names = cext_module_name.split(".")
-  source_module = _cext
-  for submodule_name in submodule_names:
-    source_module = getattr(source_module, submodule_name)
-  for attr_name in dir(source_module):
-    if not attr_name.startswith("__"):
-      setattr(target_module, attr_name, getattr(source_module, attr_name))
-
-
-# Add our 'dialects' parent module to the search path for implementations.
-_cext.globals.append_dialect_search_prefix("mlir.dialects")

diff  --git a/mlir/python/mlir/_dlloader.py b/mlir/python/mlir/_dlloader.py
deleted file mode 100644
index 454a7b7f137f7..0000000000000
--- a/mlir/python/mlir/_dlloader.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import os
-import platform
-
-_is_windows = platform.system() == "Windows"
-_this_directory = os.path.dirname(__file__)
-
-# The standard LLVM build/install tree for Windows is laid out as:
-#   bin/
-#     MLIRPublicAPI.dll
-#   python/
-#     _mlir.*.pyd (dll extension)
-#     mlir/
-#       _dlloader.py (this file)
-# First check the python/ directory level for DLLs co-located with the pyd
-# file, and then fall back to searching the bin/ directory.
-# TODO: This should be configurable at some point.
-_dll_search_path = [
-  os.path.join(_this_directory, ".."),
-  os.path.join(_this_directory, "..", "..", "bin"),
-]
-
-# Stash loaded DLLs to keep them alive.
-_loaded_dlls = []
-
-def preload_dependency(public_name):
-  """Preloads a dylib by its soname or DLL name.
-
-  On Windows and Linux, doing this prior to loading a dependency will populate
-  the library in the flat namespace so that a subsequent library that depend
-  on it will resolve to this preloaded version.
-
-  On OSX, resolution is completely path based so this facility no-ops. On
-  Linux, as long as RPATHs are setup properly, resolution is path based but
-  this facility can still act as an escape hatch for relocatable distributions.
-  """
-  if _is_windows:
-    _preload_dependency_windows(public_name)
-
-
-def _preload_dependency_windows(public_name):
-  dll_basename = public_name + ".dll"
-  found_path = None
-  for search_dir in _dll_search_path:
-    candidate_path = os.path.join(search_dir, dll_basename)
-    if os.path.exists(candidate_path):
-      found_path = candidate_path
-      break
-
-  if found_path is None:
-    raise RuntimeError(
-      f"Unable to find dependency DLL {dll_basename} in search "
-      f"path {_dll_search_path}")
-
-  import ctypes
-  _loaded_dlls.append(ctypes.CDLL(found_path))

diff  --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 55139b2a84876..4e2e5f453bc58 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -4,24 +4,10 @@
 
 from typing import Sequence
 
-import importlib
 import os
 
-__all__ = [
-  "load_extension",
-  "preload_dependency",
-]
-
 _this_dir = os.path.dirname(__file__)
 
-def load_extension(name):
-  return importlib.import_module(f".{name}", __package__)
-
-
-def preload_dependency(public_name):
-  # TODO: Implement this hook to pre-load DLLs with ctypes on Windows.
-  pass
-
 
 def get_lib_dirs() -> Sequence[str]:
   """Gets the lib directory for linking to shared libraries.

diff  --git a/mlir/python/mlir/all_passes_registration/__init__.py b/mlir/python/mlir/all_passes_registration/__init__.py
index cf3367cfe92ff..aca557ab9c70d 100644
--- a/mlir/python/mlir/all_passes_registration/__init__.py
+++ b/mlir/python/mlir/all_passes_registration/__init__.py
@@ -2,7 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .._cext_loader import _load_extension
-
-_cextAllPasses = _load_extension("_mlirAllPassesRegistration")
-del _load_extension
+from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses

diff  --git a/mlir/python/mlir/conversions/__init__.py b/mlir/python/mlir/conversions/__init__.py
index 0989449a447b7..a6a9eb8213557 100644
--- a/mlir/python/mlir/conversions/__init__.py
+++ b/mlir/python/mlir/conversions/__init__.py
@@ -4,5 +4,4 @@
 
 # Expose the corresponding C-Extension module with a well-known name at this
 # level.
-from .._cext_loader import _load_extension
-_cextConversions = _load_extension("_mlirConversions")
+from .._mlir_libs import _mlirConversions as _cextConversions

diff  --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index 656992cac26d7..5360967492d5c 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -6,10 +6,7 @@
   from typing import Optional, Sequence, Union
   from ..ir import *
   from ._ods_common import get_default_loc_context
-  # TODO: resolve name collision for Linalg functionality that is injected inside
-  # the _mlir.dialects.linalg directly via pybind.
-  from .._cext_loader import _cext
-  fill_builtin_region = _cext.dialects.linalg.fill_builtin_region
+  from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
@@ -29,12 +26,11 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
     results = []
     if isa(RankedTensorType, output.type):
       results = [output.type]
-    op = self.build_generic(
-        results=results,
-        operands=[value, output],
-        attributes=None,
-        loc=loc,
-        ip=ip)
+    op = self.build_generic(results=results,
+                            operands=[value, output],
+                            attributes=None,
+                            loc=loc,
+                            ip=ip)
     OpView.__init__(self, op)
     linalgDialect = Context.current.get_dialect_descriptor("linalg")
     fill_builtin_region(linalgDialect, self.operation)
@@ -78,12 +74,11 @@ def __init__(self,
     attributes["static_sizes"] = ArrayAttr.get(
         [IntegerAttr.get(i64_type, s) for s in static_size_ints],
         context=context)
-    op = self.build_generic(
-        results=[result_type],
-        operands=operands,
-        attributes=attributes,
-        loc=loc,
-        ip=ip)
+    op = self.build_generic(results=[result_type],
+                            operands=operands,
+                            attributes=attributes,
+                            loc=loc,
+                            ip=ip)
     OpView.__init__(self, op)
 
 
@@ -92,11 +87,10 @@ class StructuredOpMixin:
 
   def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
     super().__init__(
-        self.build_generic(
-            results=list(results),
-            operands=[list(inputs), list(outputs)],
-            loc=loc,
-            ip=ip))
+        self.build_generic(results=list(results),
+                           operands=[list(inputs), list(outputs)],
+                           loc=loc,
+                           ip=ip))
 
 
 def select_opview_mixin(parent_opview_cls):

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d030440887414..2fbf3545f46d4 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -2,8 +2,9 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-# Re-export the parent _cext so that every level of the API can get it locally.
-from .._cext_loader import _cext
+# Provide a convenient name for sub-packages to resolve the main C-extension
+# with a relative import.
+from .._mlir_libs import _mlir as _cext
 
 __all__ = [
     "equally_sized_accessor",

diff  --git a/mlir/python/mlir/dialects/async_dialect/passes/__init__.py b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py
index 88a7b539c9c7e..851d5614881ed 100644
--- a/mlir/python/mlir/dialects/async_dialect/passes/__init__.py
+++ b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py
@@ -2,5 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...._cext_loader import _load_extension
-_cextAsyncPasses = _load_extension("_mlirAsyncPasses")
+from ...._mlir_libs import _mlirAsyncPasses as _cextAsyncPasses

diff  --git a/mlir/python/mlir/dialects/gpu/passes/__init__.py b/mlir/python/mlir/dialects/gpu/passes/__init__.py
index dd28e91a4646a..9b1ef076aa3c5 100644
--- a/mlir/python/mlir/dialects/gpu/passes/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/passes/__init__.py
@@ -2,5 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...._cext_loader import _load_extension
-_cextGPUPasses = _load_extension("_mlirGPUPasses")
+from ...._mlir_libs import _mlirGPUPasses as _cextGPUPasses

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index ea2da7151beac..b151a9ba9f39f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -5,13 +5,11 @@
 from typing import Dict, Sequence
 
 from .....ir import *
+from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
+
 from .... import linalg
 from .... import std
 from .... import math
-# TODO: resolve name collision for Linalg functionality that is injected inside
-# the _mlir.dialects.linalg directly via pybind.
-from ....._cext_loader import _cext
-fill_builtin_region = _cext.dialects.linalg.fill_builtin_region
 
 from .scalar_expr import *
 from .config import *
@@ -216,8 +214,8 @@ def expression(self, expr: ScalarExpression) -> Value:
       value_attr = Attribute.parse(expr.scalar_const.value)
       return std.ConstantOp(value_attr.type, value_attr).result
     elif expr.scalar_index:
-      dim_attr = IntegerAttr.get(
-          IntegerType.get_signless(64), expr.scalar_index.dim)
+      dim_attr = IntegerAttr.get(IntegerType.get_signless(64),
+                                 expr.scalar_index.dim)
       return linalg.IndexOp(IndexType.get(), dim_attr).result
     elif expr.scalar_apply:
       try:

diff  --git a/mlir/python/mlir/dialects/linalg/passes/__init__.py b/mlir/python/mlir/dialects/linalg/passes/__init__.py
index 6555ad69a5231..0920e8ef490fb 100644
--- a/mlir/python/mlir/dialects/linalg/passes/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/passes/__init__.py
@@ -2,5 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from ...._cext_loader import _load_extension
-_cextLinalgPasses = _load_extension("_mlirLinalgPasses")
+from ...._mlir_libs import _mlirLinalgPasses as _cextLinalgPasses

diff  --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 59fd86021dc6e..4a89ef8ae0532 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -2,11 +2,5 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .._cext_loader import _reexport_cext
-from .._cext_loader import _load_extension
-
-_reexport_cext("dialects.sparse_tensor", __name__)
-_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses")
-
-del _reexport_cext
-del _load_extension
+from .._mlir_libs._mlir.dialects.sparse_tensor import *
+from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses

diff  --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py
index f3bcd0e0d78a7..1c516ae5a12e4 100644
--- a/mlir/python/mlir/execution_engine.py
+++ b/mlir/python/mlir/execution_engine.py
@@ -3,8 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 # Simply a wrapper around the extension module of the same name.
-from ._cext_loader import load_extension
-_execution_engine = load_extension("_mlirExecutionEngine")
+from ._mlir_libs import _mlirExecutionEngine as _execution_engine
 import ctypes
 
 __all__ = [

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 2b420511d1c03..99e88ff743848 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -2,8 +2,5 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-# Simply a wrapper around the extension module of the same name.
-from ._cext_loader import _reexport_cext
-_reexport_cext("ir", __name__)
-del _reexport_cext
-
+from ._mlir_libs._mlir.ir import *
+from ._mlir_libs._mlir.ir import _GlobalDebug

diff  --git a/mlir/python/mlir/passmanager.py b/mlir/python/mlir/passmanager.py
index 6b267b76eb7d4..22e86b8798dea 100644
--- a/mlir/python/mlir/passmanager.py
+++ b/mlir/python/mlir/passmanager.py
@@ -2,7 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-# Simply a wrapper around the extension module of the same name.
-from ._cext_loader import _reexport_cext
-_reexport_cext("passmanager", __name__)
-del _reexport_cext
+from ._mlir_libs._mlir.passmanager import *

diff  --git a/mlir/python/mlir/transforms/__init__.py b/mlir/python/mlir/transforms/__init__.py
index 2149933d0848e..71ea17d7f1b6f 100644
--- a/mlir/python/mlir/transforms/__init__.py
+++ b/mlir/python/mlir/transforms/__init__.py
@@ -4,5 +4,4 @@
 
 # Expose the corresponding C-Extension module with a well-known name at this
 # level.
-from .._cext_loader import _load_extension
-_cextTransforms = _load_extension("_mlirTransforms")
+from .._mlir_libs import _mlirTransforms as _cextTransforms


        


More information about the Mlir-commits mailing list