[Mlir-commits] [mlir] [mlir][python] remove various caching mechanisms (PR #70831)
Maksim Levental
llvmlistbot at llvm.org
Tue Oct 31 15:52:52 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/70831
>From c5aa0249bbd1e9ba2e231bc19661b1c439d30742 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 31 Oct 2023 17:50:12 -0500
Subject: [PATCH 1/2] [mlir][python] fix python_test dialect and
I32/I64ElementsBuilder
---
mlir/python/mlir/ir.py | 4 ++--
mlir/test/python/dialects/python_test.py | 19 +++++++++----------
mlir/test/python/python_test_ops.td | 4 +++-
3 files changed, 14 insertions(+), 13 deletions(-)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..cf4228c2a63a91b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -277,7 +277,7 @@ def _f64ElementsAttr(x, context):
def _i32ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int32),
- type=IntegerType.get_signed(32, context=context),
+ type=IntegerType.get_signless(32, context=context),
context=context,
)
@@ -285,7 +285,7 @@ def _i32ElementsAttr(x, context):
def _i64ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
- type=IntegerType.get_signed(64, context=context),
+ type=IntegerType.get_signless(64, context=context),
context=context,
)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 3d4cd087fbfed8f..472db7e5124dbed 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -17,8 +17,7 @@ def run(f):
@run
def testAttributes():
with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
-
+ test.register_python_test_dialect(ctx)
#
# Check op construction with attributes.
#
@@ -28,7 +27,7 @@ def testAttributes():
two = IntegerAttr.get(i32, 2)
unit = UnitAttr.get()
- # CHECK: "python_test.attributed_op"() {
+ # CHECK: python_test.attributed_op {
# CHECK-DAG: mandatory_i32 = 1 : i32
# CHECK-DAG: optional_i32 = 2 : i32
# CHECK-DAG: unit
@@ -36,7 +35,7 @@ def testAttributes():
op = test.AttributedOp(one, optional_i32=two, unit=unit)
print(f"{op}")
- # CHECK: "python_test.attributed_op"() {
+ # CHECK: python_test.attributed_op {
# CHECK: mandatory_i32 = 2 : i32
# CHECK: }
op2 = test.AttributedOp(two)
@@ -48,21 +47,21 @@ def testAttributes():
assert "additional" not in op.attributes
- # CHECK: "python_test.attributed_op"() {
+ # CHECK: python_test.attributed_op {
# CHECK-DAG: additional = 1 : i32
# CHECK-DAG: mandatory_i32 = 2 : i32
# CHECK: }
op2.attributes["additional"] = one
print(f"{op2}")
- # CHECK: "python_test.attributed_op"() {
+ # CHECK: python_test.attributed_op {
# CHECK-DAG: additional = 2 : i32
# CHECK-DAG: mandatory_i32 = 2 : i32
# CHECK: }
op2.attributes["additional"] = two
print(f"{op2}")
- # CHECK: "python_test.attributed_op"() {
+ # CHECK: python_test.attributed_op {
# CHECK-NOT: additional = 2 : i32
# CHECK: mandatory_i32 = 2 : i32
# CHECK: }
@@ -139,7 +138,7 @@ def testAttributes():
@run
def attrBuilder():
with Context() as ctx, Location.unknown():
- ctx.allow_unregistered_dialects = True
+ test.register_python_test_dialect(ctx)
# CHECK: python_test.attributes_op
op = test.AttributesOp(
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
@@ -177,10 +176,10 @@ def attrBuilder():
x_i16=42, # CHECK-DAG: x_i16 = 42 : i16
x_i32=6, # CHECK-DAG: x_i32 = 6 : i32
x_i32arr=[4, 5], # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
- x_i32elems=[5, 6], # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xsi32>
+ x_i32elems=[5, 6], # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
x_i64=9, # CHECK-DAG: x_i64 = 9 : i64
x_i64arr=[7, 8], # CHECK-DAG: x_i64arr = [7, 8]
- x_i64elems=[8, 9], # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xsi64>
+ x_i64elems=[8, 9], # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
x_i64svecarr=[10, 11], # CHECK-DAG: x_i64svecarr = [10, 11]
x_i8=11, # CHECK-DAG: x_i8 = 11 : i8
x_idx=10, # CHECK-DAG: x_idx = 10 : index
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index d79714301ae951e..95301985e3fde03 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -32,7 +32,9 @@ class TestAttr<string name, string attrMnemonic>
}
class TestOp<string mnemonic, list<Trait> traits = []>
- : Op<Python_Test_Dialect, mnemonic, traits>;
+ : Op<Python_Test_Dialect, mnemonic, traits> {
+ let assemblyFormat = "operands attr-dict functional-type(operands, results)";
+}
//===----------------------------------------------------------------------===//
// Type definitions.
>From 384b222ea78007aff0dd4b65f585c3ea03a7ddd1 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 31 Oct 2023 11:58:46 -0500
Subject: [PATCH 2/2] [mlir][python] remove various caching mechanisms
---
mlir/docs/Bindings/Python.md | 2 +-
mlir/lib/Bindings/Python/Globals.h | 24 ++---
mlir/lib/Bindings/Python/IRModule.cpp | 131 +++++++-----------------
mlir/lib/Bindings/Python/MainModule.cpp | 5 +-
4 files changed, 48 insertions(+), 114 deletions(-)
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bc2e676a878c0f4..ef984e2bed7ea3a 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -945,7 +945,7 @@ When the python bindings need to locate a wrapper module, they consult the
`dialect_search_path` and use it to find an appropriately named module. For the
main repository, this search path is hard-coded to include the `mlir.dialects`
module, which is where wrappers are emitted by the above build rule. Out of tree
-dialects and add their modules to the search path by calling:
+dialects can add their modules to the search path by calling:
```python
mlir._cext.append_dialect_search_prefix("myproject.mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 21899bdce22e810..976297257ced06e 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -9,10 +9,6 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
-#include <optional>
-#include <string>
-#include <vector>
-
#include "PybindUtils.h"
#include "mlir-c/IR.h"
@@ -21,6 +17,10 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
+#include <optional>
+#include <string>
+#include <vector>
+
namespace mlir {
namespace python {
@@ -45,17 +45,13 @@ class PyGlobals {
dialectSearchPrefixes.swap(newValues);
}
- /// Clears positive and negative caches regarding what implementations are
- /// available. Future lookups will do more expensive existence checks.
- void clearImportCache();
-
/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
/// an error on any evaluation issues.
- /// Note that this returns void because it is expected that the module
- /// contains calls to decorators and helpers that register the salient
- /// entities.
- void loadDialectModule(llvm::StringRef dialectNamespace);
+ /// The module is expected to contain calls to decorators and helpers that
+ /// register the salient entities. Returns true if the module is loaded (now
+ /// or on a previous call) and false if no module could be found.
+ bool loadDialectModule(llvm::StringRef dialectNamespace);
/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists and replace == false.
@@ -113,16 +109,10 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
- /// Cache for map of MlirTypeID to custom type caster.
- llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
- /// Set of dialect namespaces that we have attempted to import implementation
- /// modules for.
- llvm::StringSet<> loadedDialectModulesCache;
- /// Cache of operation name to external operation class object. This is
- /// maintained on lookup as a shadow of operationClassMap in order for repeat
- /// lookups of the classes to only incur the cost of one hashtable lookup.
- llvm::StringMap<pybind11::object> operationClassMapCache;
+ /// Set of dialect namespaces whose implementation modules were successfully
+ /// imported. Failed imports are not recorded and are retried on each call.
+ llvm::StringSet<> loadedDialectModules;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f8e22f7bb0c1ba7..6c5cde86236ce90 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -10,12 +10,12 @@
#include "Globals.h"
#include "PybindUtils.h"
-#include <optional>
-#include <vector>
-
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
+#include <optional>
+#include <vector>
+
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
@@ -36,12 +36,12 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
-void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
- if (loadedDialectModulesCache.contains(dialectNamespace))
- return;
+bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
+ if (loadedDialectModules.contains(dialectNamespace))
+ return true;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
- py::object loaded;
+ py::object loaded = py::none();
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
@@ -57,15 +57,18 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
break;
}
+ if (loaded.is_none())
+ return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
- loadedDialectModulesCache.insert(dialectNamespace);
+ loadedDialectModules.insert(dialectNamespace);
+ return true;
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc, bool replace) {
py::object &found = attributeBuilderMap[attributeKind];
- if (found && !found.is_none() && !replace) {
+ if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind +
"' is already registered with func: " +
@@ -79,13 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
pybind11::function typeCaster,
bool replace) {
pybind11::object &found = typeCasterMap[mlirTypeID];
- if (found && !found.is_none() && !replace)
- throw std::runtime_error("Type caster is already registered");
+ if (found && !replace)
+ throw std::runtime_error("Type caster is already registered with caster: " +
+ py::str(found).operator std::string());
found = std::move(typeCaster);
- const auto foundIt = typeCasterMapCache.find(mlirTypeID);
- if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) {
- typeCasterMapCache[mlirTypeID] = found;
- }
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
@@ -108,114 +108,59 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
.str());
}
found = std::move(pyClass);
- auto foundIt = operationClassMapCache.find(operationName);
- if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) {
- operationClassMapCache[operationName] = found;
- }
}
std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
- // Fast match against the class map first (common case).
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::function is defined");
+ assert(foundIt->second && "attribute builder is defined");
return foundIt->second;
}
-
- // Not found and loading did not yield a registration. Negative cache.
- attributeBuilderMap[attributeKind] = py::none();
return std::nullopt;
}
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
- {
- // Fast match against the class map first (common case).
- const auto foundIt = typeCasterMapCache.find(mlirTypeID);
- if (foundIt != typeCasterMapCache.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::function is defined");
- return foundIt->second;
- }
- }
-
- // Not found. Load the dialect namespace.
- loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-
- // Attempt to find from the canonical map and cache.
- {
- const auto foundIt = typeCasterMap.find(mlirTypeID);
- if (foundIt != typeCasterMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
- // Positive cache.
- typeCasterMapCache[mlirTypeID] = foundIt->second;
- return foundIt->second;
- }
- // Negative cache.
- typeCasterMap[mlirTypeID] = py::none();
+ // Make sure dialect module is loaded.
+ if (!loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))))
return std::nullopt;
+
+ const auto foundIt = typeCasterMap.find(mlirTypeID);
+ if (foundIt != typeCasterMap.end()) {
+ assert(foundIt->second && "type caster is defined");
+ return foundIt->second;
}
+ return std::nullopt;
}
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
- loadDialectModule(dialectNamespace);
- // Fast match against the class map first (common case).
+ // Make sure dialect module is loaded.
+ if (!loadDialectModule(dialectNamespace))
+ return std::nullopt;
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
+ assert(foundIt->second && "dialect class is defined");
return foundIt->second;
}
-
- // Not found and loading did not yield a registration. Negative cache.
- dialectClassMap[dialectNamespace] = py::none();
+ // Not found and loading did not yield a registration.
return std::nullopt;
}
std::optional<pybind11::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
- {
- auto foundIt = operationClassMapCache.find(operationName);
- if (foundIt != operationClassMapCache.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
- return foundIt->second;
- }
- }
-
- // Not found. Load the dialect namespace.
+ // Make sure dialect module is loaded.
auto split = operationName.split('.');
llvm::StringRef dialectNamespace = split.first;
- loadDialectModule(dialectNamespace);
-
- // Attempt to find from the canonical map and cache.
- {
- auto foundIt = operationClassMap.find(operationName);
- if (foundIt != operationClassMap.end()) {
- if (foundIt->second.is_none())
- return std::nullopt;
- assert(foundIt->second && "py::object is defined");
- // Positive cache.
- operationClassMapCache[operationName] = foundIt->second;
- return foundIt->second;
- }
- // Negative cache.
- operationClassMap[operationName] = py::none();
+ if (!loadDialectModule(dialectNamespace))
return std::nullopt;
- }
-}
-void PyGlobals::clearImportCache() {
- loadedDialectModulesCache.clear();
- operationClassMapCache.clear();
- typeCasterMapCache.clear();
+ auto foundIt = operationClassMap.find(operationName);
+ if (foundIt != operationClassMap.end()) {
+ assert(foundIt->second && "OpView is defined");
+ return foundIt->second;
+ }
+ // Not found and loading did not yield a registration.
+ return std::nullopt;
}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index a936becf67bea75..2b6248321c1c110 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#include <tuple>
-
#include "PybindUtils.h"
#include "Globals.h"
#include "IRModule.h"
#include "Pass.h"
+#include <tuple>
+
namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
@@ -34,7 +34,6 @@ PYBIND11_MODULE(_mlir, m) {
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
- self.clearImportCache();
},
"module_name"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
More information about the Mlir-commits
mailing list