[Mlir-commits] [mlir] 013b932 - [mlir][Python] Custom python op view wrappers for building and traversing.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Oct 27 12:29:14 PDT 2020
Author: Stella Laurenzo
Date: 2020-10-27T12:23:34-07:00
New Revision: 013b9322dea9564ac85c4082fb0f07ff093eef63
URL: https://github.com/llvm/llvm-project/commit/013b9322dea9564ac85c4082fb0f07ff093eef63
DIFF: https://github.com/llvm/llvm-project/commit/013b9322dea9564ac85c4082fb0f07ff093eef63.diff
LOG: [mlir][Python] Custom python op view wrappers for building and traversing.
* Still rough edges that need more sugar but the bones are there. Notes left in the test case for things that can be improved.
* Does not actually yield custom OpViews yet for traversing. Will rework that in a followup.
Differential Revision: https://reviews.llvm.org/D89932
Added:
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/mlir/dialects/__init__.py
mlir/lib/Bindings/Python/mlir/dialects/std.py
mlir/lib/Bindings/Python/mlir/ir.py
mlir/test/Bindings/Python/dialects.py
Modified:
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/mlir/__init__.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index f7b04fff7f63..d4913bb43947 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -4,6 +4,9 @@
set(PY_SRC_FILES
mlir/__init__.py
+ mlir/ir.py
+ mlir/dialects/__init__.py
+ mlir/dialects/std.py
)
add_custom_target(MLIRBindingsPythonSources ALL
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
new file mode 100644
index 000000000000..33ab4cd6722d
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -0,0 +1,94 @@
+//===- Globals.h - MLIR Python extension globals --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
+#define MLIR_BINDINGS_PYTHON_GLOBALS_H
+
+#include <string>
+#include <vector>
+
+#include "PybindUtils.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+
+namespace mlir {
+namespace python {
+
+/// Globals that are always accessible once the extension has been initialized. Constructing an instance installs it as the singleton returned by get().
+class PyGlobals {
+public:
+ PyGlobals();
+ ~PyGlobals();
+
+ /// Most code should get the globals via this static accessor. Asserts that an instance currently exists.
+ static PyGlobals &get() {
+ assert(instance && "PyGlobals is null");
+ return *instance;
+ }
+
+ /// Get and set the list of parent modules to search for dialect
+ /// implementation classes. The getter returns a mutable reference to the stored list.
+ std::vector<std::string> &getDialectSearchPrefixes() {
+ return dialectSearchPrefixes;
+ }
+ void setDialectSearchPrefixes(std::vector<std::string> newValues) {
+ dialectSearchPrefixes.swap(newValues);
+ }
+
+ /// Loads a python module corresponding to the given dialect namespace.
+ /// No-ops if the module has already been loaded or is not found. Raises
+ /// an error on any evaluation issues.
+ /// Note that this returns void because it is expected that the module
+ /// contains calls to decorators and helpers that register the salient
+ /// entities.
+ void loadDialectModule(const std::string &dialectNamespace);
+
+ /// Decorator for registering a custom Dialect class. The class object must
+ /// have a DIALECT_NAMESPACE attribute. NOTE(review): no definition of this method is visible in this patch -- confirm it is defined or drop it.
+ pybind11::object registerDialectDecorator(pybind11::object pyClass);
+
+ /// Adds a concrete implementation dialect class.
+ /// Raises an exception if the mapping already exists.
+ /// This is intended to be called by implementation code.
+ void registerDialectImpl(const std::string &dialectNamespace,
+ pybind11::object pyClass);
+
+ /// Adds a concrete implementation operation class.
+ /// Raises an exception if the mapping already exists.
+ /// This is intended to be called by implementation code. `rawClass` is stored under the same operation name in a separate map.
+ void registerOperationImpl(const std::string &operationName,
+ pybind11::object pyClass,
+ pybind11::object rawClass);
+
+ /// Looks up a registered dialect class by namespace. Note that this may
+ /// trigger loading of the defining module and can arbitrarily re-enter. Returns llvm::None when no class is registered.
+ llvm::Optional<pybind11::object>
+ lookupDialectClass(const std::string &dialectNamespace);
+
+private:
+ static PyGlobals *instance;
+ /// Module name prefixes to search under for dialect implementation modules.
+ std::vector<std::string> dialectSearchPrefixes;
+ /// Set of dialect namespaces for which a module load has already been
+ /// attempted, whether it loaded successfully or resolved to not found.
+ llvm::StringSet<> loadedDialectModules;
+ /// Map of dialect namespace to external dialect class object.
+ llvm::StringMap<pybind11::object> dialectClassMap;
+ /// Map of full operation name to external operation class object.
+ llvm::StringMap<pybind11::object> operationClassMap;
+ /// Map of operation name to custom subclass that directly initializes
+ /// the OpView base class (bypassing the user class constructor).
+ llvm::StringMap<pybind11::object> rawOperationClassMap;
+};
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 014b312971b7..2fba7fa5e283 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "IRModules.h"
+
+#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
@@ -209,19 +211,27 @@ struct PySinglePartStringAccumulator {
} // namespace
//------------------------------------------------------------------------------
-// Type-checking utilities.
+// Utilities.
//------------------------------------------------------------------------------
-namespace {
-
/// Checks whether the given type is an integer or float type.
-int mlirTypeIsAIntegerOrFloat(MlirType type) {
+static int mlirTypeIsAIntegerOrFloat(MlirType type) {
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
-} // namespace
+/// Wraps `dialectDescriptor` in the Python-visible dialect object for
+/// `dialectNamespace`: the registered extension class when one is found via
+/// PyGlobals, otherwise the generic PyDialect base wrapper.
+static py::object
+createCustomDialectWrapper(const std::string &dialectNamespace,
+                           py::object dialectDescriptor) {
+  llvm::Optional<py::object> customClass =
+      PyGlobals::get().lookupDialectClass(dialectNamespace);
+  if (customClass) {
+    // A custom implementation is registered: instantiate it by calling the
+    // class object with the descriptor.
+    return (*customClass)(std::move(dialectDescriptor));
+  }
+  // Nothing registered for this namespace: fall back to the base class.
+  return py::cast(PyDialect(std::move(dialectDescriptor)));
+}
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@@ -567,9 +577,11 @@ size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
py::object PyMlirContext::createOperation(
std::string name, PyLocation location,
+ llvm::Optional<std::vector<PyValue *>> operands,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<py::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -578,6 +590,16 @@ py::object PyMlirContext::createOperation(
if (regions < 0)
throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
+ // Unpack/validate operands.
+ if (operands) {
+ mlirOperands.reserve(operands->size());
+ for (PyValue *operand : *operands) {
+ if (!operand)
+ throw SetPyError(PyExc_ValueError, "operand value cannot be None");
+ mlirOperands.push_back(operand->get());
+ }
+ }
+
// Unpack/validate results.
if (results) {
mlirResults.reserve(results->size());
@@ -614,6 +636,9 @@ py::object PyMlirContext::createOperation(
// Apply unpacked/validated to the operation state. Beyond this
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
+ if (!mlirOperands.empty())
+ mlirOperationStateAddOperands(&state, mlirOperands.size(),
+ mlirOperands.data());
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
mlirResults.data());
@@ -646,6 +671,24 @@ py::object PyMlirContext::createOperation(
return PyOperation::createDetached(getRef(), operation).releaseObject();
}
+//------------------------------------------------------------------------------
+// PyDialect, PyDialectDescriptor, PyDialects
+//------------------------------------------------------------------------------
+
+/// Resolves `key` to a dialect of this object's context, loading it if needed.
+/// Throws AttributeError when `attrError` is set and IndexError otherwise if
+/// the context reports a null dialect for the key.
+MlirDialect PyDialects::getDialectForKey(const std::string &key,
+                                         bool attrError) {
+  // The "std" dialect is looked up under the empty namespace, so substitute
+  // it here while still reporting errors with the name the caller used.
+  static const std::string emptyNamespace;
+  const std::string &lookupName = (key == "std") ? emptyNamespace : key;
+
+  MlirDialect found = mlirContextGetOrLoadDialect(
+      getContext()->get(), {lookupName.data(), lookupName.size()});
+  if (!mlirDialectIsNull(found))
+    return found;
+
+  PyObject *errorType = attrError ? PyExc_AttributeError : PyExc_IndexError;
+  throw SetPyError(errorType, llvm::Twine("Dialect '") + key + "' not found");
+}
+
//------------------------------------------------------------------------------
// PyModule
//------------------------------------------------------------------------------
@@ -815,6 +858,45 @@ py::object PyOperation::getAsm(bool binary,
return fileObject.attr("getvalue")();
}
+PyOpView::PyOpView(py::object operation) // `operation` must be castable to PyOperation *.
+ : operationObject(std::move(operation)), // Keeps the Python reference held by this view.
+ operation(py::cast<PyOperation *>(this->operationObject)) {} // Cast from the member, not the moved-from parameter.
+
+/// Synthesizes and returns a new Python class named `_<userClass.__name__>`
+/// that derives from `userClass` but whose `__init__` is the one bound on the
+/// OpView base class.
+py::object PyOpView::createRawSubclass(py::object userClass) {
+  // This is... a little gross. The typical pattern is to have a pure python
+  // class that extends OpView like:
+  //   class AddFOp(_cext.ir.OpView):
+  //     def __init__(self, loc, lhs, rhs):
+  //       operation = loc.context.create_operation(
+  //           "addf", lhs, rhs, results=[lhs.type])
+  //       super().__init__(operation)
+  //
+  // I.e. The goal of the user facing type is to provide a nice constructor
+  // that has complete freedom for the op under construction. This is at odds
+  // with our other desire to sometimes create this object by just passing an
+  // operation (to initialize the base class). We could do *arg and **kwargs
+  // munging to try to make it work, but instead, we synthesize a new class
+  // on the fly which extends this user class (AddFOp in this example) and
+  // *give it* the base class's __init__ method, thus bypassing the
+  // intermediate subclass's __init__ method entirely. While slightly
+  // underhanded, this is safe/legal because the type hierarchy has not changed
+  // (we just added a new leaf) and we aren't mucking around with __new__.
+  // Typically, this new class will be stored on the original as "_Raw" and will
+  // be used for casts and other things that need a variant of the class that
+  // is initialized purely from an operation.
+
+  // The builtin `type` metaclass, called below as type(name, bases, dict).
+  py::object typeMetaclass = py::reinterpret_borrow<py::object>(
+      reinterpret_cast<PyObject *>(&PyType_Type));
+
+  // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
+  // now.
+  //   auto opViewType = py::type::of<PyOpView>();
+  py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
+
+  // The only attribute of the new leaf class is the borrowed base __init__.
+  py::dict classDict;
+  classDict["__init__"] = opViewType.attr("__init__");
+
+  py::str userName = userClass.attr("__name__");
+  py::str rawName = py::str("_") + userName;
+  return typeMetaclass(rawName, py::make_tuple(userClass), classDict);
+}
+
//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
@@ -966,6 +1048,41 @@ class PyBlockArgumentList {
MlirBlock block;
};
+/// A list of operation operands. Internally, these are stored as consecutive
+/// elements, random access is cheap. The operand list is associated with the
+/// operation whose operands these are, and extends the lifetime of this
+/// operation.
+class PyOpOperandList {
+public:
+  PyOpOperandList(PyOperationRef operation) : operation(operation) {}
+
+  /// Returns the number of operands of the operation, after checking that
+  /// the operation is still valid.
+  intptr_t dunderLen() {
+    operation->checkValid();
+    return mlirOperationGetNumOperands(operation->get());
+  }
+
+  /// Returns the `index`-th operand. Throws IndexError when `index` is
+  /// negative or not smaller than the operand count.
+  /// NOTE(review): the operand is wrapped as PyOpResult; confirm this is
+  /// intended for operands that are not results of another operation (e.g.
+  /// block arguments), where a plain PyValue may be the right wrapper.
+  PyOpResult dunderGetItem(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds operand");
+    }
+    PyValue value(operation, mlirOperationGetOperand(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  /// Defines the Python class `OpOperandList` in the bindings.
+  static void bind(py::module &m) {
+    py::class_<PyOpOperandList>(m, "OpOperandList")
+        .def("__len__", &PyOpOperandList::dunderLen)
+        .def("__getitem__", &PyOpOperandList::dunderGetItem);
+  }
+
+private:
+  /// Reference to the operation whose operands are exposed.
+  PyOperationRef operation;
+};
+
/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The result list is associated with the
/// operation whose results these are, and extends the lifetime of this
@@ -1914,7 +2031,9 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
//------------------------------------------------------------------------------
void mlir::python::populateIRSubmodule(py::module &m) {
+ //----------------------------------------------------------------------------
// Mapping of MlirContext
+ //----------------------------------------------------------------------------
py::class_<PyMlirContext>(m, "Context")
.def(py::init<>(&PyMlirContext::createNewContextForInit))
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
@@ -1928,6 +2047,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+ .def_property_readonly(
+ "dialects",
+ [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Gets a container for accessing dialects by name")
+ .def_property_readonly(
+ "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Alias for 'dialect'")
+ .def(
+ "get_dialect_descriptor",
+ [=](PyMlirContext &self, std::string &name) {
+ MlirDialect dialect = mlirContextGetOrLoadDialect(
+ self.get(), {name.data(), name.size()});
+ if (mlirDialectIsNull(dialect)) {
+ throw SetPyError(PyExc_ValueError,
+ llvm::Twine("Dialect '") + name + "' not found");
+ }
+ return PyDialectDescriptor(self.getRef(), dialect);
+ },
+ "Gets or loads a dialect by name, returning its descriptor object")
.def_property(
"allow_unregistered_dialects",
[](PyMlirContext &self) -> bool {
@@ -1937,8 +2075,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
.def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
- py::arg("location"), py::arg("results") = py::none(),
- py::arg("attributes") = py::none(),
+ py::arg("location"), py::arg("operands") = py::none(),
+ py::arg("results") = py::none(), py::arg("attributes") = py::none(),
py::arg("successors") = py::none(), py::arg("regions") = 0,
kContextCreateOperationDocstring)
.def(
@@ -2009,6 +2147,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
kContextGetFileLocationDocstring, py::arg("filename"),
py::arg("line"), py::arg("col"));
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectDescriptor
+ //----------------------------------------------------------------------------
+ py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_property_readonly("namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns =
+ mlirDialectGetNamespace(self.get());
+ return py::str(ns.data, ns.length);
+ })
+ .def("__repr__", [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ std::string repr("<DialectDescriptor ");
+ repr.append(ns.data, ns.length);
+ repr.append(">");
+ return repr;
+ });
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialects
+ //----------------------------------------------------------------------------
+ py::class_<PyDialects>(m, "Dialects")
+ .def("__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ py::object descriptor =
+ py::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ })
+ .def("__getattr__", [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ py::object descriptor =
+ py::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ });
+
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialect
+ //----------------------------------------------------------------------------
+ py::class_<PyDialect>(m, "Dialect")
+ .def(py::init<py::object>(), "descriptor")
+ .def_property_readonly(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
+ .def("__repr__", [](py::object self) {
+ auto clazz = self.attr("__class__");
+ return py::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") + py::str(" (class ") +
+ clazz.attr("__module__") + py::str(".") +
+ clazz.attr("__name__") + py::str(")>");
+ });
+
+ //----------------------------------------------------------------------------
+ // Mapping of Location
+ //----------------------------------------------------------------------------
py::class_<PyLocation>(m, "Location")
.def_property_readonly(
"context",
@@ -2021,7 +2215,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return printAccum.join();
});
+ //----------------------------------------------------------------------------
// Mapping of Module
+ //----------------------------------------------------------------------------
py::class_<PyModule>(m, "Module")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
@@ -2055,12 +2251,17 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
kOperationStrDunderDocstring);
+ //----------------------------------------------------------------------------
// Mapping of Operation.
+ //----------------------------------------------------------------------------
py::class_<PyOperation>(m, "Operation")
.def_property_readonly(
"context",
[](PyOperation &self) { return self.getContext().getObject(); },
"Context that owns the Operation")
+ .def_property_readonly(
+ "operands",
+ [](PyOperation &self) { return PyOpOperandList(self.getRef()); })
.def_property_readonly(
"regions",
[](PyOperation &self) { return PyRegionList(self.getRef()); })
@@ -2098,7 +2299,15 @@ void mlir::python::populateIRSubmodule(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
+ py::class_<PyOpView>(m, "OpView")
+ .def(py::init<py::object>())
+ .def_property_readonly("operation", &PyOpView::getOperationObject)
+ .def("__str__",
+ [](PyOpView &self) { return py::str(self.getOperationObject()); });
+
+ //----------------------------------------------------------------------------
// Mapping of PyRegion.
+ //----------------------------------------------------------------------------
py::class_<PyRegion>(m, "Region")
.def_property_readonly(
"blocks",
@@ -2123,7 +2332,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
}
});
+ //----------------------------------------------------------------------------
// Mapping of PyBlock.
+ //----------------------------------------------------------------------------
py::class_<PyBlock>(m, "Block")
.def_property_readonly(
"arguments",
@@ -2167,7 +2378,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
"Returns the assembly form of the block.");
+ //----------------------------------------------------------------------------
// Mapping of PyAttribute.
+ //----------------------------------------------------------------------------
py::class_<PyAttribute>(m, "Attribute")
.def_property_readonly(
"context",
@@ -2219,6 +2432,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return printAccum.join();
});
+ //----------------------------------------------------------------------------
+ // Mapping of PyNamedAttribute
+ //----------------------------------------------------------------------------
py::class_<PyNamedAttribute>(m, "NamedAttribute")
.def("__repr__",
[](PyNamedAttribute &self) {
@@ -2257,7 +2473,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyStringAttribute::bind(m);
PyDenseElementsAttribute::bind(m);
+ //----------------------------------------------------------------------------
// Mapping of PyType.
+ //----------------------------------------------------------------------------
py::class_<PyType>(m, "Type")
.def_property_readonly(
"context", [](PyType &self) { return self.getContext().getObject(); },
@@ -2313,7 +2531,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyTupleType::bind(m);
PyFunctionType::bind(m);
+ //----------------------------------------------------------------------------
// Mapping of Value.
+ //----------------------------------------------------------------------------
py::class_<PyValue>(m, "Value")
.def_property_readonly(
"context",
@@ -2346,6 +2566,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
+ PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index b438e8ac408d..89cca9c1c85f 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -132,6 +132,7 @@ class PyMlirContext {
/// Creates an operation. See corresponding python docstring.
pybind11::object
createOperation(std::string name, PyLocation location,
+ llvm::Optional<std::vector<PyValue *>> operands,
llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<pybind11::dict> attributes,
llvm::Optional<std::vector<PyBlock *>> successors,
@@ -187,6 +188,45 @@ class BaseContextObject {
PyMlirContextRef contextRef;
};
+/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
+/// order to differentiate it from the `Dialect` base class which is extended by
+/// plugins which extend dialect functionality through extension python code.
+/// This should be seen as the "low-level" object and `Dialect` as the
+/// high-level, user facing object.
+class PyDialectDescriptor : public BaseContextObject {
+public:
+ PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
+ : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
+
+ MlirDialect get() { return dialect; } // Returns the wrapped C API dialect handle.
+
+private:
+ MlirDialect dialect; // Handle supplied at construction.
+};
+
+/// User-level object for accessing dialects with dotted syntax such as:
+/// ctx.dialects.std
+class PyDialects : public BaseContextObject {
+public:
+ PyDialects(PyMlirContextRef contextRef)
+ : BaseContextObject(std::move(contextRef)) {}
+
+ MlirDialect getDialectForKey(const std::string &key, bool attrError); // Throws AttributeError if attrError, else IndexError, when the dialect is not found.
+};
+
+/// User-level dialect object. For dialects that have a registered extension,
+/// this will be the base class of the extension dialect type. For un-extended,
+/// objects of this type will be returned directly.
+class PyDialect {
+public:
+ PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} // Stores the given Python object as the descriptor.
+
+ pybind11::object getDescriptor() { return descriptor; } // Returns a new reference to the stored descriptor object.
+
+private:
+ pybind11::object descriptor; // Python object passed at construction (a DialectDescriptor when created by the bindings).
+};
+
/// Wrapper around an MlirLocation.
class PyLocation : public BaseContextObject {
public:
@@ -305,6 +345,24 @@ class PyOperation : public BaseContextObject {
bool valid = true;
};
+/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
+/// providing more instance-specific accessors and serve as the base class for
+/// custom ODS-style operation classes. Since this class is subclassed on the
+/// python side, it must present an __init__ method that operates in pure
+/// python types.
+class PyOpView {
+public:
+ PyOpView(pybind11::object operation); // `operation` must be castable to PyOperation *.
+
+ static pybind11::object createRawSubclass(pybind11::object userClass); // Builds a `_<Name>` subclass of userClass whose __init__ is OpView's.
+
+ pybind11::object getOperationObject() { return operationObject; } // Returns the wrapped operation as a Python object.
+
+private:
+ pybind11::object operationObject; // Holds the reference. Must stay declared before `operation`.
+ PyOperation *operation; // For efficient, cast-free access from C++; initialized from operationObject.
+};
+
/// Wrapper around an MlirRegion.
/// Regions are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached regions.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 7dd525b4b340..1340468a8714 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -8,17 +8,155 @@
#include <tuple>
-#include <pybind11/pybind11.h>
+#include "PybindUtils.h"
+#include "Globals.h"
#include "IRModules.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
+// -----------------------------------------------------------------------------
+// PyGlobals
+// -----------------------------------------------------------------------------
+
+PyGlobals *PyGlobals::instance = nullptr; // Set by the constructor, cleared by the destructor.
+
+PyGlobals::PyGlobals() {
+ assert(!instance && "PyGlobals already constructed"); // At most one instance may exist at a time.
+ instance = this;
+}
+
+PyGlobals::~PyGlobals() { instance = nullptr; } // After this, PyGlobals::get() asserts.
+
+/// Attempts to import `<prefix>.<dialectNamespace>` for each registered search
+/// prefix, in order, stopping at the first import that succeeds. The namespace
+/// is then recorded so later calls return immediately, whether or not a module
+/// was found. Any Python exception other than ModuleNotFoundError propagates
+/// to the caller, in which case the namespace is not recorded.
+void PyGlobals::loadDialectModule(const std::string &dialectNamespace) {
+  if (loadedDialectModules.contains(dialectNamespace))
+    return;
+  // Since re-entrancy is possible, make a copy of the search prefixes.
+  const std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
+  for (const std::string &prefix : localSearchPrefixes) {
+    const std::string moduleName = prefix + "." + dialectNamespace;
+
+    try {
+      // Only the import's side effects (registration decorators) matter; the
+      // module object itself is not needed here.
+      py::module::import(moduleName.c_str());
+    } catch (py::error_already_set &e) {
+      // NOTE(review): this also matches a ModuleNotFoundError raised by an
+      // import nested inside the dialect module, which is then treated the
+      // same as the dialect module being absent -- confirm that is intended.
+      if (e.matches(PyExc_ModuleNotFoundError))
+        continue;
+      throw;
+    }
+    break;
+  }
+
+  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
+  // may have occurred, which may do anything.
+  loadedDialectModules.insert(dialectNamespace);
+}
+
+/// Records `pyClass` as the implementation class for `dialectNamespace`.
+/// Throws RuntimeError if a class is already registered for that namespace.
+void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
+                                    py::object pyClass) {
+  py::object &slot = dialectClassMap[dialectNamespace];
+  // lookupDialectClass() stores None as a negative-cache marker. A None entry
+  // is a non-null py::object, so it must be excluded explicitly; otherwise a
+  // registration following a failed lookup would be rejected as a duplicate.
+  if (slot && !slot.is_none()) {
+    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
+                                             dialectNamespace +
+                                             "' is already registered.");
+  }
+  slot = std::move(pyClass);
+}
+
+/// Records `pyClass` as the implementation class for `operationName` and
+/// `rawClass` as its raw counterpart. Throws RuntimeError if an operation
+/// class is already registered under that name.
+void PyGlobals::registerOperationImpl(const std::string &operationName,
+                                      py::object pyClass, py::object rawClass) {
+  py::object &registered = operationClassMap[operationName];
+  if (!registered) {
+    // First registration for this name: fill both maps.
+    registered = std::move(pyClass);
+    rawOperationClassMap[operationName] = std::move(rawClass);
+    return;
+  }
+  throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
+                                           operationName +
+                                           "' is already registered.");
+}
+
+/// Returns the class registered for `dialectNamespace`, first giving the
+/// corresponding dialect module a chance to load and register itself. Returns
+/// llvm::None when nothing is registered, and remembers that outcome by
+/// storing None in the class map.
+llvm::Optional<py::object>
+PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
+  loadDialectModule(dialectNamespace);
+
+  auto entry = dialectClassMap.find(dialectNamespace);
+  if (entry == dialectClassMap.end()) {
+    // Not found and loading did not yield a registration. Negative cache.
+    dialectClassMap[dialectNamespace] = py::none();
+    return llvm::None;
+  }
+
+  // An entry exists (common case): it is either a real class or the None
+  // marker left by an earlier failed lookup.
+  const py::object &cached = entry->second;
+  if (cached.is_none())
+    return llvm::None;
+  assert(cached && "py::object is defined");
+  return cached;
+}
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
PYBIND11_MODULE(_mlir, m) {
m.doc() = "MLIR Python Native Extension";
+ py::class_<PyGlobals>(m, "_Globals")
+ .def_property("dialect_search_modules",
+ &PyGlobals::getDialectSearchPrefixes,
+ &PyGlobals::setDialectSearchPrefixes)
+ .def("append_dialect_search_prefix",
+ [](PyGlobals &self, std::string moduleName) {
+ self.getDialectSearchPrefixes().push_back(std::move(moduleName));
+ })
+ .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
+ "Testing hook for directly registering a dialect")
+ .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
+ "Testing hook for directly registering an operation");
+
+ // Aside from making the globals accessible to python, having python manage
+ // it is necessary to make sure it is destroyed (and releases its python
+ // resources) properly.
+ m.attr("globals") =
+ py::cast(new PyGlobals, py::return_value_policy::take_ownership);
+
+ // Registration decorators.
+ m.def(
+ "register_dialect",
+ [](py::object pyClass) {
+ std::string dialectNamespace =
+ pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
+ PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
+ return pyClass;
+ },
+ "Class decorator for registering a custom Dialect wrapper");
+ m.def(
+ "register_operation",
+ [](py::object dialectClass) -> py::cpp_function {
+ return py::cpp_function(
+ [dialectClass](py::object opClass) -> py::object {
+ std::string operationName =
+ opClass.attr("OPERATION_NAME").cast<std::string>();
+ auto rawSubclass = PyOpView::createRawSubclass(opClass);
+ PyGlobals::get().registerOperationImpl(operationName, opClass,
+ rawSubclass);
+
+ // Dict-stuff the new opClass by name onto the dialect class.
+ py::object opClassName = opClass.attr("__name__");
+ dialectClass.attr(opClassName) = opClass;
+
+ // Now create a special "Raw" subclass that passes through
+ // construction to the OpView parent (bypasses the intermediate
+ // child's __init__).
+ opClass.attr("_Raw") = rawSubclass;
+ return opClass;
+ });
+ },
+ "Class decorator for registering a custom Operation wrapper");
+
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRSubmodule(irModule);
diff --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py
index 717526771c25..8f3b52c30f35 100644
--- a/mlir/lib/Bindings/Python/mlir/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/__init__.py
@@ -8,4 +8,37 @@
# and arbitrate any one-time initialization needed in various shared-library
# scenarios.
-from _mlir import *
+__all__ = [
+ "ir",
+]
+
+# Expose the corresponding C-Extension module with a well-known name at this
+# top-level module. This allows relative imports like the following to
+# function:
+# from .. import _cext
+# This reduces coupling, allowing embedding of the python sources into another
+# project that can just vary based on this top-level loader module.
+import _mlir as _cext
+
+def _reexport_cext(cext_module_name, target_module_name):
+ """Re-exports a named sub-module of the C-Extension into another module.
+
+ Typically:
+ from . import _reexport_cext
+ _reexport_cext("ir", __name__)
+ del _reexport_cext
+ """
+ import sys
+ target_module = sys.modules[target_module_name]
+ source_module = getattr(_cext, cext_module_name)
+ for attr_name in dir(source_module):
+ if not attr_name.startswith("__"):
+ setattr(target_module, attr_name, getattr(source_module, attr_name))
+
+
+# Import sub-modules. Since these may import from here, this must come after
+# any exported definitions.
+from . import ir
+
+# Add our 'dialects' parent module to the search path for implementations.
+_cext.globals.append_dialect_search_prefix("mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
new file mode 100644
index 000000000000..1b7e62c030fb
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Re-export the parent _cext so that every level of the API can get it locally.
+from .. import _cext
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/lib/Bindings/Python/mlir/dialects/std.py
new file mode 100644
index 000000000000..2afc642e0e3d
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/std.py
@@ -0,0 +1,33 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# TODO: This file should be auto-generated.
+
+from . import _cext
+
+ at _cext.register_dialect
+class _Dialect(_cext.ir.Dialect):
+ # Special case: 'std' namespace aliases to the empty namespace.
+ DIALECT_NAMESPACE = "std"
+ pass
+
+ at _cext.register_operation(_Dialect)
+class AddFOp(_cext.ir.OpView):
+ OPERATION_NAME = "std.addf"
+
+ def __init__(self, loc, lhs, rhs):
+ super().__init__(loc.context.create_operation(
+ "std.addf", loc, operands=[lhs, rhs], results=[lhs.type]))
+
+ @property
+ def lhs(self):
+ return self.operation.operands[0]
+
+ @property
+ def rhs(self):
+ return self.operation.operands[1]
+
+ @property
+ def result(self):
+ return self.operation.results[0]
diff --git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/lib/Bindings/Python/mlir/ir.py
new file mode 100644
index 000000000000..70d19737f5e6
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/ir.py
@@ -0,0 +1,8 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Simply a wrapper around the extension module of the same name.
+from . import _reexport_cext
+_reexport_cext("ir", __name__)
+del _reexport_cext
diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
new file mode 100644
index 000000000000..bc88e8668f4d
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -0,0 +1,107 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import mlir
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert mlir.ir.Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testDialectDescriptor
+def testDialectDescriptor():
+ ctx = mlir.ir.Context()
+ d = ctx.get_dialect_descriptor("std")
+ # CHECK: <DialectDescriptor std>
+ print(d)
+ # CHECK: std
+ print(d.namespace)
+ try:
+ _ = ctx.get_dialect_descriptor("not_existing")
+ except ValueError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+run(testDialectDescriptor)
+
+
+# CHECK-LABEL: TEST: testUserDialectClass
+def testUserDialectClass():
+ ctx = mlir.ir.Context()
+ # Access using attribute.
+ d = ctx.dialects.std
+ # Note that the standard dialect namespace prints as ''. Others will print
+ # as "<Dialect %namespace (..."
+ # CHECK: <Dialect (class mlir.dialects.std._Dialect)>
+ print(d)
+ try:
+ _ = ctx.dialects.not_existing
+ except AttributeError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # Access using index.
+ d = ctx.dialects["std"]
+ # CHECK: <Dialect (class mlir.dialects.std._Dialect)>
+ print(d)
+ try:
+ _ = ctx.dialects["not_existing"]
+ except IndexError:
+ pass
+ else:
+ assert False, "Expected exception"
+
+ # Using the 'd' alias.
+ d = ctx.d["std"]
+ # CHECK: <Dialect (class mlir.dialects.std._Dialect)>
+ print(d)
+
+run(testUserDialectClass)
+
+
+# CHECK-LABEL: TEST: testCustomOpView
+# This test uses the standard dialect AddFOp as an example of a user op.
+# TODO: Op creation and access is still quite verbose: simplify this test as
+# additional capabilities come online.
+def testCustomOpView():
+ ctx = mlir.ir.Context()
+ ctx.allow_unregistered_dialects = True
+ f32 = mlir.ir.F32Type.get(ctx)
+ loc = ctx.get_unknown_location()
+ m = ctx.create_module(loc)
+ m_block = m.operation.regions[0].blocks[0]
+ # TODO: Remove integer insertion in favor of InsertionPoint and/or op-based.
+ ip = [0]
+ def createInput():
+ op = ctx.create_operation("pytest_dummy.intinput", loc, results=[f32])
+ m_block.operations.insert(ip[0], op)
+ ip[0] += 1
+ # TODO: Auto result cast from operation
+ return op.results[0]
+
+ # Create via dialects context collection.
+ input1 = createInput()
+ input2 = createInput()
+ op1 = ctx.dialects.std.AddFOp(loc, input1, input2)
+ # TODO: Auto operation cast from OpView
+ # TODO: Context manager insertion point
+ m_block.operations.insert(ip[0], op1.operation)
+ ip[0] += 1
+
+ # Create via an import
+ from mlir.dialects.std import AddFOp
+ op2 = AddFOp(loc, input1, op1.result)
+ m_block.operations.insert(ip[0], op2.operation)
+ ip[0] += 1
+
+ # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
+ # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
+ # CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
+ # CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
+ m.operation.print()
+
+run(testCustomOpView)
More information about the Mlir-commits
mailing list