[Mlir-commits] [mlir] [mlir][python] remove various caching mechanisms (PR #70831)

Maksim Levental llvmlistbot at llvm.org
Thu Nov 2 08:51:23 PDT 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/70831

>From 07aeb5d432f8722a2b6c66370aedbdaac0012348 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 31 Oct 2023 11:58:46 -0500
Subject: [PATCH 1/2] [mlir][python] remove various caching mechanisms

---
 mlir/docs/Bindings/Python.md                  |   5 +-
 mlir/lib/Bindings/Python/Globals.h            |  24 +---
 mlir/lib/Bindings/Python/IRModule.cpp         | 131 +++++-------------
 mlir/lib/Bindings/Python/MainModule.cpp       |   5 +-
 .../test/python/ir/custom_dialect/__init__.py |   0
 mlir/test/python/ir/custom_dialect/custom.py  |   0
 .../python/ir/custom_dialect/lit.local.cfg    |   2 +
 mlir/test/python/ir/insertion_point.py        |   6 +
 8 files changed, 58 insertions(+), 115 deletions(-)
 create mode 100644 mlir/test/python/ir/custom_dialect/__init__.py
 create mode 100644 mlir/test/python/ir/custom_dialect/custom.py
 create mode 100644 mlir/test/python/ir/custom_dialect/lit.local.cfg

diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bc2e676a878c0f4..6e52c4deaad9aa9 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -945,10 +945,11 @@ When the python bindings need to locate a wrapper module, they consult the
 `dialect_search_path` and use it to find an appropriately named module. For the
 main repository, this search path is hard-coded to include the `mlir.dialects`
 module, which is where wrappers are emitted by the above build rule. Out of tree
-dialects and add their modules to the search path by calling:
+dialects can add their modules to the search path by calling:
 
 ```python
-mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects")
+from mlir.dialects._ods_common import _cext
+_cext.globals.append_dialect_search_prefix("myproject.mlir.dialects")
 ```
 
 ### Wrapper module code organization
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 21899bdce22e810..976297257ced06e 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,10 +9,6 @@
 #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
 #define MLIR_BINDINGS_PYTHON_GLOBALS_H
 
-#include <optional>
-#include <string>
-#include <vector>
-
 #include "PybindUtils.h"
 
 #include "mlir-c/IR.h"
@@ -21,6 +17,10 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 
+#include <optional>
+#include <string>
+#include <vector>
+
 namespace mlir {
 namespace python {
 
@@ -45,17 +45,13 @@ class PyGlobals {
     dialectSearchPrefixes.swap(newValues);
   }
 
-  /// Clears positive and negative caches regarding what implementations are
-  /// available. Future lookups will do more expensive existence checks.
-  void clearImportCache();
-
   /// Loads a python module corresponding to the given dialect namespace.
   /// No-ops if the module has already been loaded or is not found. Raises
   /// an error on any evaluation issues.
   /// Note that this returns void because it is expected that the module
   /// contains calls to decorators and helpers that register the salient
-  /// entities.
-  void loadDialectModule(llvm::StringRef dialectNamespace);
+  /// entities. Returns true if dialect is successfully loaded.
+  bool loadDialectModule(llvm::StringRef dialectNamespace);
 
   /// Adds a user-friendly Attribute builder.
   /// Raises an exception if the mapping already exists and replace == false.
@@ -113,16 +109,10 @@ class PyGlobals {
   llvm::StringMap<pybind11::object> attributeBuilderMap;
   /// Map of MlirTypeID to custom type caster.
   llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
-  /// Cache for map of MlirTypeID to custom type caster.
-  llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
-  llvm::StringSet<> loadedDialectModulesCache;
-  /// Cache of operation name to external operation class object. This is
-  /// maintained on lookup as a shadow of operationClassMap in order for repeat
-  /// lookups of the classes to only incur the cost of one hashtable lookup.
-  llvm::StringMap<pybind11::object> operationClassMapCache;
+  llvm::StringSet<> loadedDialectModules;
 };
 
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f8e22f7bb0c1ba7..6c5cde86236ce90 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -10,12 +10,12 @@
 #include "Globals.h"
 #include "PybindUtils.h"
 
-#include <optional>
-#include <vector>
-
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Support.h"
 
+#include <optional>
+#include <vector>
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
@@ -36,12 +36,12 @@ PyGlobals::PyGlobals() {
 
 PyGlobals::~PyGlobals() { instance = nullptr; }
 
-void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
-  if (loadedDialectModulesCache.contains(dialectNamespace))
-    return;
+bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
+  if (loadedDialectModules.contains(dialectNamespace))
+    return true;
   // Since re-entrancy is possible, make a copy of the search prefixes.
   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
-  py::object loaded;
+  py::object loaded = py::none();
   for (std::string moduleName : localSearchPrefixes) {
     moduleName.push_back('.');
     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
@@ -57,15 +57,18 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
     break;
   }
 
+  if (loaded.is_none())
+    return false;
   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
   // may have occurred, which may do anything.
-  loadedDialectModulesCache.insert(dialectNamespace);
+  loadedDialectModules.insert(dialectNamespace);
+  return true;
 }
 
 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
                                          py::function pyFunc, bool replace) {
   py::object &found = attributeBuilderMap[attributeKind];
-  if (found && !found.is_none() && !replace) {
+  if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
                               attributeKind +
                               "' is already registered with func: " +
@@ -79,13 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                    pybind11::function typeCaster,
                                    bool replace) {
   pybind11::object &found = typeCasterMap[mlirTypeID];
-  if (found && !found.is_none() && !replace)
-    throw std::runtime_error("Type caster is already registered");
+  if (found && !replace)
+    throw std::runtime_error("Type caster is already registered with caster: " +
+                             py::str(found).operator std::string());
   found = std::move(typeCaster);
-  const auto foundIt = typeCasterMapCache.find(mlirTypeID);
-  if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
-    typeCasterMapCache[mlirTypeID] = found;
-  }
 }
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -108,114 +108,59 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
                                  .str());
   }
   found = std::move(pyClass);
-  auto foundIt = operationClassMapCache.find(operationName);
-  if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
-    operationClassMapCache[operationName] = found;
-  }
 }
 
 std::optional<py::function>
 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
-  // Fast match against the class map first (common case).
   const auto foundIt = attributeBuilderMap.find(attributeKind);
   if (foundIt != attributeBuilderMap.end()) {
-    if (foundIt->second.is_none())
-      return std::nullopt;
-    assert(foundIt->second && "py::function is defined");
+    assert(foundIt->second && "attribute builder is defined");
     return foundIt->second;
   }
-
-  // Not found and loading did not yield a registration. Negative cache.
-  attributeBuilderMap[attributeKind] = py::none();
   return std::nullopt;
 }
 
 std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                         MlirDialect dialect) {
-  {
-    // Fast match against the class map first (common case).
-    const auto foundIt = typeCasterMapCache.find(mlirTypeID);
-    if (foundIt != typeCasterMapCache.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::function is defined");
-      return foundIt->second;
-    }
-  }
-
-  // Not found. Load the dialect namespace.
-  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-
-  // Attempt to find from the canonical map and cache.
-  {
-    const auto foundIt = typeCasterMap.find(mlirTypeID);
-    if (foundIt != typeCasterMap.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::object is defined");
-      // Positive cache.
-      typeCasterMapCache[mlirTypeID] = foundIt->second;
-      return foundIt->second;
-    }
-    // Negative cache.
-    typeCasterMap[mlirTypeID] = py::none();
+  // Make sure dialect module is loaded.
+  if (!loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))))
     return std::nullopt;
+
+  const auto foundIt = typeCasterMap.find(mlirTypeID);
+  if (foundIt != typeCasterMap.end()) {
+    assert(foundIt->second && "type caster is defined");
+    return foundIt->second;
   }
+  return std::nullopt;
 }
 
 std::optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
-  loadDialectModule(dialectNamespace);
-  // Fast match against the class map first (common case).
+  // Make sure dialect module is loaded.
+  if (!loadDialectModule(dialectNamespace))
+    return std::nullopt;
   const auto foundIt = dialectClassMap.find(dialectNamespace);
   if (foundIt != dialectClassMap.end()) {
-    if (foundIt->second.is_none())
-      return std::nullopt;
-    assert(foundIt->second && "py::object is defined");
+    assert(foundIt->second && "dialect class is defined");
     return foundIt->second;
   }
-
-  // Not found and loading did not yield a registration. Negative cache.
-  dialectClassMap[dialectNamespace] = py::none();
+  // Not found and loading did not yield a registration.
   return std::nullopt;
 }
 
 std::optional<pybind11::object>
 PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
-  {
-    auto foundIt = operationClassMapCache.find(operationName);
-    if (foundIt != operationClassMapCache.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::object is defined");
-      return foundIt->second;
-    }
-  }
-
-  // Not found. Load the dialect namespace.
+  // Make sure dialect module is loaded.
   auto split = operationName.split('.');
   llvm::StringRef dialectNamespace = split.first;
-  loadDialectModule(dialectNamespace);
-
-  // Attempt to find from the canonical map and cache.
-  {
-    auto foundIt = operationClassMap.find(operationName);
-    if (foundIt != operationClassMap.end()) {
-      if (foundIt->second.is_none())
-        return std::nullopt;
-      assert(foundIt->second && "py::object is defined");
-      // Positive cache.
-      operationClassMapCache[operationName] = foundIt->second;
-      return foundIt->second;
-    }
-    // Negative cache.
-    operationClassMap[operationName] = py::none();
+  if (!loadDialectModule(dialectNamespace))
     return std::nullopt;
-  }
-}
 
-void PyGlobals::clearImportCache() {
-  loadedDialectModulesCache.clear();
-  operationClassMapCache.clear();
-  typeCasterMapCache.clear();
+  auto foundIt = operationClassMap.find(operationName);
+  if (foundIt != operationClassMap.end()) {
+    assert(foundIt->second && "OpView is defined");
+    return foundIt->second;
+  }
+  // Not found and loading did not yield a registration.
+  return std::nullopt;
 }
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a936becf67bea75..2b6248321c1c110 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <tuple>
-
 #include "PybindUtils.h"
 
 #include "Globals.h"
 #include "IRModule.h"
 #include "Pass.h"
 
+#include <tuple>
+
 namespace py = pybind11;
 using namespace mlir;
 using namespace py::literals;
@@ -34,7 +34,6 @@ PYBIND11_MODULE(_mlir, m) {
           "append_dialect_search_prefix",
           [](PyGlobals &self, std::string moduleName) {
             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
-            self.clearImportCache();
           },
           "module_name"_a)
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
diff --git a/mlir/test/python/ir/custom_dialect/__init__.py b/mlir/test/python/ir/custom_dialect/__init__.py
new file mode 100644
index 000000000000000..e69de29bb2d1d64
diff --git a/mlir/test/python/ir/custom_dialect/custom.py b/mlir/test/python/ir/custom_dialect/custom.py
new file mode 100644
index 000000000000000..e69de29bb2d1d64
diff --git a/mlir/test/python/ir/custom_dialect/lit.local.cfg b/mlir/test/python/ir/custom_dialect/lit.local.cfg
new file mode 100644
index 000000000000000..26ea63660d6a3bf
--- /dev/null
+++ b/mlir/test/python/ir/custom_dialect/lit.local.cfg
@@ -0,0 +1,2 @@
+config.excludes.add("__init__.py")
+config.excludes.add("custom.py")
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index 268d2e77d036f5e..b9054afaba82432 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -3,7 +3,13 @@
 import gc
 import io
 import itertools
+import sys
+
 from mlir.ir import *
+from mlir.dialects._ods_common import _cext
+
+sys.path.append(".")
+_cext.globals.append_dialect_search_prefix("custom_dialect")
 
 
 def run(f):

>From 49e0dfb5a0909620c88db315d55064d8098d9960 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 2 Nov 2023 10:30:42 -0500
Subject: [PATCH 2/2] move append_dialect_search_prefix test

---
 mlir/lib/Bindings/Python/MainModule.cpp        |  6 ++++++
 mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi |  1 +
 mlir/test/python/ir/custom_dialect/custom.py   |  4 ++++
 mlir/test/python/ir/dialects.py                | 17 +++++++++++++++++
 mlir/test/python/ir/insertion_point.py         |  8 --------
 5 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 2b6248321c1c110..2ba3a3677198cbc 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -36,6 +36,12 @@ PYBIND11_MODULE(_mlir, m) {
             self.getDialectSearchPrefixes().push_back(std::move(moduleName));
           },
           "module_name"_a)
+      .def(
+          "_check_dialect_module_loaded",
+          [](PyGlobals &self, const std::string &dialectNamespace) {
+            return self.loadDialectModule(dialectNamespace);
+          },
+          "dialect_namespace"_a)
       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
            "dialect_namespace"_a, "dialect_class"_a,
            "Testing hook for directly registering a dialect")
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 93b98c4aa53fbd8..3ed1872f1cd5a21 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -7,6 +7,7 @@ class _Globals:
     def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
     def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
     def append_dialect_search_prefix(self, module_name: str) -> None: ...
+    def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
 
 def register_dialect(dialect_class: type) -> object: ...
 def register_operation(dialect_class: type) -> object: ...
diff --git a/mlir/test/python/ir/custom_dialect/custom.py b/mlir/test/python/ir/custom_dialect/custom.py
index e69de29bb2d1d64..388368ca6fe6bcd 100644
--- a/mlir/test/python/ir/custom_dialect/custom.py
+++ b/mlir/test/python/ir/custom_dialect/custom.py
@@ -0,0 +1,4 @@
+# The purpose of this empty dialect module is to enable successfully loading the "custom" dialect.
+# Without this file here (and a corresponding _cext.globals.append_dialect_search_prefix("custom_dialect")),
+# PyGlobals::loadDialectModule would search and fail to find the "custom" dialect for each Operation.create("custom.op")
+# (amongst other things).
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index eebf7c3e48989ff..d59c6a6bc424e68 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -1,7 +1,9 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+import sys
 from mlir.ir import *
+from mlir.dialects._ods_common import _cext
 
 
 def run(f):
@@ -104,3 +106,18 @@ def testIsRegisteredOperation():
     print(f"cf.cond_br: {ctx.is_registered_operation('cf.cond_br')}")
     # CHECK: func.not_existing: False
     print(f"func.not_existing: {ctx.is_registered_operation('func.not_existing')}")
+
+
+# CHECK-LABEL: TEST: testAppendPrefixSearchPath
+@run
+def testAppendPrefixSearchPath():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        assert not _cext.globals._check_dialect_module_loaded("custom")
+        Operation.create("custom.op")
+        assert not _cext.globals._check_dialect_module_loaded("custom")
+
+        sys.path.append(".")
+        _cext.globals.append_dialect_search_prefix("custom_dialect")
+        assert _cext.globals._check_dialect_module_loaded("custom")
diff --git a/mlir/test/python/ir/insertion_point.py b/mlir/test/python/ir/insertion_point.py
index b9054afaba82432..5eb861a2c089191 100644
--- a/mlir/test/python/ir/insertion_point.py
+++ b/mlir/test/python/ir/insertion_point.py
@@ -1,15 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import io
-import itertools
-import sys
-
 from mlir.ir import *
-from mlir.dialects._ods_common import _cext
-
-sys.path.append(".")
-_cext.globals.append_dialect_search_prefix("custom_dialect")
 
 
 def run(f):



More information about the Mlir-commits mailing list