[Mlir-commits] [mlir] 8260db7 - [mlir][Python] Return and accept OpView for all functions.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Nov 3 22:49:06 PST 2020
Author: Stella Laurenzo
Date: 2020-11-03T22:48:34-08:00
New Revision: 8260db752c91e0c368b88607132be0a9cd9362ba
URL: https://github.com/llvm/llvm-project/commit/8260db752c91e0c368b88607132be0a9cd9362ba
DIFF: https://github.com/llvm/llvm-project/commit/8260db752c91e0c368b88607132be0a9cd9362ba.diff
LOG: [mlir][Python] Return and accept OpView for all functions.
* All functions that return an Operation now return an OpView.
* All functions that accept an Operation now accept an _OperationBase, which both Operation and OpView extend and which can be resolved to the backing Operation.
* Moves user-facing instance methods from Operation -> _OperationBase so that both can have the same API.
* Concretely, this means that if there are custom op classes defined (e.g. in Python), any iteration or creation will return the appropriate instance (e.g. if you get/create an std.addf, you will get an instance of the mlir.dialects.std.AddFOp class, with full access to any custom API it exposes).
* Refactors all __eq__ methods to use two pybind11 overloads (a typed comparison plus a fallback that returns False), which is the approach needed for _OperationBase.
Differential Revision: https://reviews.llvm.org/D90584
Added:
Modified:
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/Bindings/Python/MainModule.cpp
mlir/test/Bindings/Python/ir_operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 33ab4cd6722d..6613d2b6963c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -42,13 +42,17 @@ class PyGlobals {
dialectSearchPrefixes.swap(newValues);
}
+ /// Clears positive and negative caches regarding what implementations are
+ /// available. Future lookups will do more expensive existence checks.
+ void clearImportCache();
+
/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
/// an error on any evaluation issues.
/// Note that this returns void because it is expected that the module
/// contains calls to decorators and helpers that register the salient
/// entities.
- void loadDialectModule(const std::string &dialectNamespace);
+ void loadDialectModule(llvm::StringRef dialectNamespace);
/// Decorator for registering a custom Dialect class. The class object must
/// have a DIALECT_NAMESPACE attribute.
@@ -65,27 +69,39 @@ class PyGlobals {
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass,
- pybind11::object rawClass);
+ pybind11::object rawOpViewClass);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
llvm::Optional<pybind11::object>
lookupDialectClass(const std::string &dialectNamespace);
+ /// Looks up a registered raw OpView class by operation name. Note that this
+ /// may trigger a load of the dialect, which can arbitrarily re-enter.
+ llvm::Optional<pybind11::object>
+ lookupRawOpViewClass(llvm::StringRef operationName);
+
private:
static PyGlobals *instance;
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
- /// Map of dialect namespace to bool flag indicating whether the module has
- /// been successfully loaded or resolved to not found.
- llvm::StringSet<> loadedDialectModules;
/// Map of dialect namespace to external dialect class object.
llvm::StringMap<pybind11::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
/// Map of operation name to custom subclass that directly initializes
/// the OpView base class (bypassing the user class constructor).
- llvm::StringMap<pybind11::object> rawOperationClassMap;
+ llvm::StringMap<pybind11::object> rawOpViewClassMap;
+
+ /// Set of dialect namespaces that we have attempted to import implementation
+ /// modules for.
+ llvm::StringSet<> loadedDialectModulesCache;
+ /// Cache of operation name to custom OpView subclass that directly
+ /// initializes the OpView base class (or None for a negative lookup). This
+ /// is maintained on lookup as a shadow of rawOpViewClassMap in order for
+ /// repeat lookups of the OpView classes to only incur the cost of one
+ /// hashtable lookup.
+ llvm::StringMap<pybind11::object> rawOpViewClassMapCache;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8c17e8e6d933..a4862ee59b89 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -407,7 +407,7 @@ class PyOperationIterator {
PyOperationRef returnOperation =
PyOperation::forOperation(parentOperation->getContext(), next);
next = mlirOperationGetNextInBlock(next);
- return returnOperation.releaseObject();
+ return returnOperation->createOpView();
}
static void bind(py::module &m) {
@@ -457,7 +457,7 @@ class PyOperationList {
while (!mlirOperationIsNull(childOp)) {
if (index == 0) {
return PyOperation::forOperation(parentOperation->getContext(), childOp)
- .releaseObject();
+ ->createOpView();
}
childOp = mlirOperationGetNextInBlock(childOp);
index -= 1;
@@ -868,11 +868,12 @@ void PyOperation::checkValid() {
}
}
-void PyOperation::print(py::object fileObject, bool binary,
- llvm::Optional<int64_t> largeElementsLimit,
- bool enableDebugInfo, bool prettyDebugInfo,
- bool printGenericOpForm, bool useLocalScope) {
- checkValid();
+void PyOperationBase::print(py::object fileObject, bool binary,
+ llvm::Optional<int64_t> largeElementsLimit,
+ bool enableDebugInfo, bool prettyDebugInfo,
+ bool printGenericOpForm, bool useLocalScope) {
+ PyOperation &operation = getOperation();
+ operation.checkValid();
if (fileObject.is_none())
fileObject = py::module::import("sys").attr("stdout");
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
@@ -885,15 +886,16 @@ void PyOperation::print(py::object fileObject, bool binary,
PyFileAccumulator accum(fileObject, binary);
+ // NOTE(review): this unnamed gil_scoped_release is a temporary that is
+ // destroyed at the end of its statement, so the GIL is not actually
+ // released while printing.
py::gil_scoped_release();
- mlirOperationPrintWithFlags(get(), flags, accum.getCallback(),
+ mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
accum.getUserData());
mlirOpPrintingFlagsDestroy(flags);
}
-py::object PyOperation::getAsm(bool binary,
- llvm::Optional<int64_t> largeElementsLimit,
- bool enableDebugInfo, bool prettyDebugInfo,
- bool printGenericOpForm, bool useLocalScope) {
+py::object PyOperationBase::getAsm(bool binary,
+ llvm::Optional<int64_t> largeElementsLimit,
+ bool enableDebugInfo, bool prettyDebugInfo,
+ bool printGenericOpForm,
+ bool useLocalScope) {
py::object fileObject;
if (binary) {
fileObject = py::module::import("io").attr("BytesIO")();
@@ -1034,12 +1036,24 @@ py::object PyOperation::create(
ip->insert(*created.get());
}
- return created.releaseObject();
+ return created->createOpView();
}
-PyOpView::PyOpView(py::object operation)
- : operationObject(std::move(operation)),
- operation(py::cast<PyOperation *>(this->operationObject)) {}
+/// Wraps this operation in the raw OpView subclass registered for its name
+/// (via PyGlobals::lookupRawOpViewClass), or in a plain OpView when no class
+/// is registered.
+py::object PyOperation::createOpView() {
+ MlirIdentifier ident = mlirOperationGetName(get());
+ MlirStringRef identStr = mlirIdentifierStr(ident);
+ auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
+ llvm::StringRef(identStr.data, identStr.length));
+ if (opViewClass)
+ return (*opViewClass)(getRef().getObject());
+ return py::cast(PyOpView(getRef().getObject()));
+}
+
+PyOpView::PyOpView(py::object operationObject)
+ // Casting through the PyOperationBase base-class and then back to the
+ // Operation lets us accept any PyOperationBase subclass.
+ : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
+ operationObject(operation.getRef().getObject()) {}
py::object PyOpView::createRawSubclass(py::object userClass) {
// This is... a little gross. The typical pattern is to have a pure python
@@ -1082,11 +1096,12 @@ py::object PyOpView::createRawSubclass(py::object userClass) {
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
-PyInsertionPoint::PyInsertionPoint(PyOperation &beforeOperation)
- : block(beforeOperation.getBlock()),
- refOperation(beforeOperation.getRef()) {}
+PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
+ : refOperation(beforeOperationBase.getOperation().getRef()),
+ block((*refOperation)->getBlock()) {}
-void PyInsertionPoint::insert(PyOperation &operation) {
+void PyInsertionPoint::insert(PyOperationBase &operationBase) {
+ PyOperation &operation = operationBase.getOperation();
if (operation.isAttached())
throw SetPyError(PyExc_ValueError,
"Attempt to insert operation that is already attached");
@@ -2501,33 +2516,36 @@ void mlir::python::populateIRSubmodule(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of Operation.
//----------------------------------------------------------------------------
- py::class_<PyOperation>(m, "Operation")
- .def_static("create", &PyOperation::create, py::arg("name"),
- py::arg("operands") = py::none(),
- py::arg("results") = py::none(),
- py::arg("attributes") = py::none(),
- py::arg("successors") = py::none(), py::arg("regions") = 0,
- py::arg("loc") = py::none(), py::arg("ip") = py::none(),
- kOperationCreateDocstring)
- .def_property_readonly(
- "context",
- [](PyOperation &self) { return self.getContext().getObject(); },
- "Context that owns the Operation")
- .def_property_readonly(
- "operands",
- [](PyOperation &self) { return PyOpOperandList(self.getRef()); })
- .def_property_readonly(
- "regions",
- [](PyOperation &self) { return PyRegionList(self.getRef()); })
+ py::class_<PyOperationBase>(m, "_OperationBase")
+ .def("__eq__",
+ [](PyOperationBase &self, PyOperationBase &other) {
+ return &self.getOperation() == &other.getOperation();
+ })
+ // Fallback overload: only reached when 'other' is not an _OperationBase,
+ // in which case the objects compare unequal.
+ .def("__eq__",
+ [](PyOperationBase &self, py::object other) { return false; })
+ .def_property_readonly("operands",
+ [](PyOperationBase &self) {
+ return PyOpOperandList(
+ self.getOperation().getRef());
+ })
+ .def_property_readonly("regions",
+ [](PyOperationBase &self) {
+ return PyRegionList(
+ self.getOperation().getRef());
+ })
.def_property_readonly(
"results",
- [](PyOperation &self) { return PyOpResultList(self.getRef()); },
+ [](PyOperationBase &self) {
+ return PyOpResultList(self.getOperation().getRef());
+ },
"Returns the list of Operation results.")
.def("__iter__",
- [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
+ [](PyOperationBase &self) {
+ return PyRegionIterator(self.getOperation().getRef());
+ })
.def(
"__str__",
- [](PyOperation &self) {
+ [](PyOperationBase &self) {
return self.getAsm(/*binary=*/false,
/*largeElementsLimit=*/llvm::None,
/*enableDebugInfo=*/false,
@@ -2536,7 +2554,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
/*useLocalScope=*/false);
},
"Returns the assembly form of the operation.")
- .def("print", &PyOperation::print,
+ .def("print", &PyOperationBase::print,
// Careful: Lots of arguments must match up with print method.
py::arg("file") = py::none(), py::arg("binary") = false,
py::arg("large_elements_limit") = py::none(),
@@ -2544,7 +2562,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false, kOperationPrintDocstring)
- .def("get_asm", &PyOperation::getAsm,
+ .def("get_asm", &PyOperationBase::getAsm,
// Careful: Lots of arguments must match up with get_asm method.
py::arg("binary") = false,
py::arg("large_elements_limit") = py::none(),
@@ -2553,9 +2571,29 @@ void mlir::python::populateIRSubmodule(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
- py::class_<PyOpView>(m, "OpView")
+ py::class_<PyOperation, PyOperationBase>(m, "Operation")
+ .def_static("create", &PyOperation::create, py::arg("name"),
+ py::arg("operands") = py::none(),
+ py::arg("results") = py::none(),
+ py::arg("attributes") = py::none(),
+ py::arg("successors") = py::none(), py::arg("regions") = 0,
+ py::arg("loc") = py::none(), py::arg("ip") = py::none(),
+ kOperationCreateDocstring)
+ .def_property_readonly(
+ "context",
+ [](PyOperation &self) { return self.getContext().getObject(); },
+ "Context that owns the Operation")
+ .def_property_readonly("opview", &PyOperation::createOpView);
+
+ py::class_<PyOpView, PyOperationBase>(m, "OpView")
.def(py::init<py::object>())
.def_property_readonly("operation", &PyOpView::getOperationObject)
+ .def_property_readonly(
+ "context",
+ [](PyOpView &self) {
+ return self.getOperation().getContext().getObject();
+ },
+ "Context that owns the Operation")
.def("__str__",
[](PyOpView &self) { return py::str(self.getOperationObject()); });
@@ -2577,14 +2615,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return PyBlockIterator(self.getParentOperation(), firstBlock);
},
"Iterates over blocks in the region.")
- .def("__eq__", [](PyRegion &self, py::object &other) {
- try {
- PyRegion *otherRegion = other.cast<PyRegion *>();
- return self.get().ptr == otherRegion->get().ptr;
- } catch (std::exception &e) {
- return false;
- }
- });
+ .def("__eq__",
+ [](PyRegion &self, PyRegion &other) {
+ return self.get().ptr == other.get().ptr;
+ })
+ .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
//----------------------------------------------------------------------------
// Mapping of PyBlock.
@@ -2613,14 +2648,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
"Iterates over operations in the block.")
.def("__eq__",
- [](PyBlock &self, py::object &other) {
- try {
- PyBlock *otherBlock = other.cast<PyBlock *>();
- return self.get().ptr == otherBlock->get().ptr;
- } catch (std::exception &e) {
- return false;
- }
+ [](PyBlock &self, PyBlock &other) {
+ return self.get().ptr == other.get().ptr;
})
+ .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
.def(
"__str__",
[](PyBlock &self) {
@@ -2651,7 +2682,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
"Gets the InsertionPoint bound to the current thread or raises "
"ValueError if none has been set")
- .def(py::init<PyOperation &>(), py::arg("beforeOperation"),
+ .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
"Inserts before a referenced operation.")
.def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
py::arg("block"), "Inserts at the beginning of the block.")
@@ -2696,14 +2727,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
py::keep_alive<0, 1>(), "Binds a name to the attribute")
.def("__eq__",
- [](PyAttribute &self, py::object &other) {
- try {
- PyAttribute otherAttribute = other.cast<PyAttribute>();
- return self == otherAttribute;
- } catch (std::exception &e) {
- return false;
- }
- })
+ [](PyAttribute &self, PyAttribute &other) { return self == other; })
+ .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
.def(
"dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
kDumpDocstring)
@@ -2793,15 +2818,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(
"context", [](PyType &self) { return self.getContext().getObject(); },
"Context that owns the Type")
- .def("__eq__",
- [](PyType &self, py::object &other) {
- try {
- PyType otherType = other.cast<PyType>();
- return self == otherType;
- } catch (std::exception &e) {
- return false;
- }
- })
+ .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
+ .def("__eq__", [](PyType &self, py::object &other) { return false; })
.def(
"dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
.def(
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index e7fdbb9e7a5c..5236187c5b1f 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -366,6 +366,24 @@ class PyModule : public BaseContextObject {
pybind11::handle handle;
};
+/// Base class for PyOperation and PyOpView which exposes the primary, user
+/// visible methods for manipulating it.
+class PyOperationBase {
+public:
+ virtual ~PyOperationBase() = default;
+ /// Implements the bound 'print' method and helps with others.
+ void print(pybind11::object fileObject, bool binary,
+ llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+ bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+ pybind11::object getAsm(bool binary,
+ llvm::Optional<int64_t> largeElementsLimit,
+ bool enableDebugInfo, bool prettyDebugInfo,
+ bool printGenericOpForm, bool useLocalScope);
+
+ /// Each must provide access to the raw Operation.
+ virtual PyOperation &getOperation() = 0;
+};
+
/// Wrapper around PyOperation.
/// Operations exist in either an attached (dependent) or detached (top-level)
/// state. In the detached state (as on creation), an operation is owned by
@@ -374,9 +392,11 @@ class PyModule : public BaseContextObject {
/// is bounded by its top-level parent reference.
class PyOperation;
using PyOperationRef = PyObjectRef<PyOperation>;
-class PyOperation : public BaseContextObject {
+class PyOperation : public PyOperationBase, public BaseContextObject {
public:
~PyOperation();
+ PyOperation &getOperation() override { return *this; }
+
/// Returns a PyOperation for the given MlirOperation, optionally associating
/// it with a parentKeepAlive.
static PyOperationRef
@@ -407,15 +427,6 @@ class PyOperation : public BaseContextObject {
}
void checkValid();
- /// Implements the bound 'print' method and helps with others.
- void print(pybind11::object fileObject, bool binary,
- llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
- bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
- pybind11::object getAsm(bool binary,
- llvm::Optional<int64_t> largeElementsLimit,
- bool enableDebugInfo, bool prettyDebugInfo,
- bool printGenericOpForm, bool useLocalScope);
-
/// Gets the owning block or raises an exception if the operation has no
/// owning block.
PyBlock getBlock();
@@ -432,6 +443,9 @@ class PyOperation : public BaseContextObject {
llvm::Optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation location, pybind11::object ip);
+ /// Creates an OpView suitable for this operation.
+ pybind11::object createOpView();
+
private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,
@@ -456,17 +470,18 @@ class PyOperation : public BaseContextObject {
/// custom ODS-style operation classes. Since this class is subclass on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
-class PyOpView {
+class PyOpView : public PyOperationBase {
public:
- PyOpView(pybind11::object operation);
+ PyOpView(pybind11::object operationObject);
+ PyOperation &getOperation() override { return operation; }
static pybind11::object createRawSubclass(pybind11::object userClass);
pybind11::object getOperationObject() { return operationObject; }
private:
+ // Must be declared before operationObject: the constructor initializes
+ // operationObject from this member.
+ PyOperation &operation; // For efficient, cast-free access from C++
pybind11::object operationObject; // Holds the reference.
- PyOperation *operation; // For efficient, cast-free access from C++
};
/// Wrapper around an MlirRegion.
@@ -519,7 +534,7 @@ class PyInsertionPoint {
/// block, but still inside the block.
PyInsertionPoint(PyBlock &block);
/// Creates an insertion point positioned before a reference operation.
- PyInsertionPoint(PyOperation &beforeOperation);
+ PyInsertionPoint(PyOperationBase &beforeOperationBase);
/// Shortcut to create an insertion point at the beginning of the block.
static PyInsertionPoint atBlockBegin(PyBlock &block);
@@ -527,7 +542,7 @@ class PyInsertionPoint {
static PyInsertionPoint atBlockTerminator(PyBlock &block);
/// Inserts an operation.
- void insert(PyOperation &operation);
+ void insert(PyOperationBase &operationBase);
/// Enter and exit the context manager.
pybind11::object contextEnter();
@@ -540,10 +555,10 @@ class PyInsertionPoint {
// Trampoline constructor that avoids null initializing members while
// looking up parents.
PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
- : block(std::move(block)), refOperation(std::move(refOperation)) {}
+ : refOperation(std::move(refOperation)), block(std::move(block)) {}
- PyBlock block;
+ // Must be declared before 'block': the PyOperationBase& constructor
+ // initializes block from refOperation.
llvm::Optional<PyOperationRef> refOperation;
+ PyBlock block;
};
/// Wrapper around the generic MlirAttribute.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 1340468a8714..b2c1bafa5d69 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -30,17 +30,19 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
-void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
- if (loadedDialectModules.contains(dialectNamespace))
+void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
+ py::gil_scoped_acquire acquire; // Named: an unnamed guard dies immediately.
+ if (loadedDialectModulesCache.contains(dialectNamespace))
return;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
py::object loaded;
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
- moduleName.append(dialectNamespace);
+ moduleName.append(dialectNamespace.data(), dialectNamespace.size());
try {
+ // Keep the GIL held: py::module::import calls into the Python C API.
loaded = py::module::import(moduleName.c_str());
} catch (py::error_already_set &e) {
if (e.matches(PyExc_ModuleNotFoundError)) {
@@ -54,11 +56,12 @@ void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
- loadedDialectModules.insert(dialectNamespace);
+ loadedDialectModulesCache.insert(dialectNamespace);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
+ py::gil_scoped_acquire acquire;
py::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
@@ -69,7 +72,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
- py::object pyClass, py::object rawClass) {
+ py::object pyClass,
+ py::object rawOpViewClass) {
+ py::gil_scoped_acquire acquire;
py::object &found = operationClassMap[operationName];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
@@ -77,11 +82,12 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
"' is already registered.");
}
found = std::move(pyClass);
- rawOperationClassMap[operationName] = std::move(rawClass);
+ rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
+ py::gil_scoped_acquire acquire;
loadDialectModule(dialectNamespace);
// Fast match against the class map first (common case).
const auto foundIt = dialectClassMap.find(dialectNamespace);
@@ -97,6 +103,49 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
return llvm::None;
}
+llvm::Optional<pybind11::object>
+PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
+ {
+ py::gil_scoped_acquire acquire;
+ auto foundIt = rawOpViewClassMapCache.find(operationName);
+ if (foundIt != rawOpViewClassMapCache.end()) {
+ if (foundIt->second.is_none())
+ return llvm::None;
+ assert(foundIt->second && "py::object is defined");
+ return foundIt->second;
+ }
+ }
+
+ // Not found. Load the dialect namespace.
+ auto split = operationName.split('.');
+ llvm::StringRef dialectNamespace = split.first;
+ loadDialectModule(dialectNamespace);
+
+ // Attempt to find from the canonical map and cache.
+ {
+ py::gil_scoped_acquire acquire;
+ auto foundIt = rawOpViewClassMap.find(operationName);
+ if (foundIt != rawOpViewClassMap.end()) {
+ if (foundIt->second.is_none())
+ return llvm::None;
+ assert(foundIt->second && "py::object is defined");
+ // Positive cache.
+ rawOpViewClassMapCache[operationName] = foundIt->second;
+ return foundIt->second;
+ } else {
+ // Negative cache (never written into the canonical rawOpViewClassMap).
+ rawOpViewClassMapCache[operationName] = py::none();
+ return llvm::None;
+ }
+ }
+}
+
+void PyGlobals::clearImportCache() {
+ py::gil_scoped_acquire acquire;
+ loadedDialectModulesCache.clear();
+ rawOpViewClassMapCache.clear();
+}
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
@@ -111,6 +160,7 @@ PYBIND11_MODULE(_mlir, m) {
.def("append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
+ self.clearImportCache();
})
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
"Testing hook for directly registering a dialect")
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 9e0ba3071073..d5c5b3f121de 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -293,3 +293,34 @@ def testOperationPrint():
pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)
run(testOperationPrint)
+
+
+def testKnownOpView():
+ # Ops with a registered OpView class should come back as that class; others
+ # should come back as the generic OpView.
+ with Context(), Location.unknown():
+ Context.current.allow_unregistered_dialects = True
+ module = Module.parse(r"""
+ %1 = "custom.f32"() : () -> f32
+ %2 = "custom.f32"() : () -> f32
+ %3 = addf %1, %2 : f32
+ """)
+ print(module)
+
+ # addf should map to a known OpView class in the std dialect.
+ # We know the OpView for it defines an 'lhs' attribute.
+ addf = module.body.operations[2]
+ # CHECK: <mlir.dialects.std._AddFOp object
+ print(repr(addf))
+ # CHECK: "custom.f32"()
+ print(addf.lhs)
+
+ # One of the custom ops should resolve to the default OpView.
+ custom = module.body.operations[0]
+ # CHECK: <_mlir.ir.OpView object
+ print(repr(custom))
+
+ # Check again to make sure negative caching works.
+ custom = module.body.operations[0]
+ # CHECK: <_mlir.ir.OpView object
+ print(repr(custom))
+
+run(testKnownOpView)
More information about the Mlir-commits
mailing list