[Mlir-commits] [mlir] 436c6c9 - NFC: Break up the mlir python bindings into individual sources.
Stella Laurenzo
llvmlistbot at llvm.org
Fri Mar 19 13:34:36 PDT 2021
Author: Stella Laurenzo
Date: 2021-03-19T13:33:51-07:00
New Revision: 436c6c9c20cc522c92a923440a5fc509c342a7db
URL: https://github.com/llvm/llvm-project/commit/436c6c9c20cc522c92a923440a5fc509c342a7db
DIFF: https://github.com/llvm/llvm-project/commit/436c6c9c20cc522c92a923440a5fc509c342a7db.diff
LOG: NFC: Break up the mlir python bindings into individual sources.
* IRModules.cpp -> (IRCore.cpp, IRAffine.cpp, IRAttributes.cpp, IRTypes.cpp).
* The individual pieces now compile in the 5-15s range whereas IRModules.cpp was starting to approach a minute (didn't capture a before time).
* More fine-grained splitting is possible, but this represents the most obvious split.
Differential Revision: https://reviews.llvm.org/D98978
Added:
mlir/lib/Bindings/Python/IRAffine.cpp
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
Modified:
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/ExecutionEngine.cpp
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/Bindings/Python/Pass.cpp
Removed:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
################################################################################
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 5f042ec57c29..5fefa80398c7 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -70,7 +70,10 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
python
SOURCES
MainModule.cpp
- IRModules.cpp
+ IRAffine.cpp
+ IRAttributes.cpp
+ IRCore.cpp
+ IRTypes.cpp
PybindUtils.cpp
Pass.cpp
ExecutionEngine.cpp
diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
index f6f52e2e0aae..5ca9b1f68128 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp
@@ -8,7 +8,7 @@
#include "ExecutionEngine.h"
-#include "IRModules.h"
+#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/ExecutionEngine.h"
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
new file mode 100644
index 000000000000..73a57d95e158
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -0,0 +1,781 @@
+//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/IntegerSet.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+/// Python docstring attached to every `dump` method bound in this file.
+static const char kDumpDocstring[] =
+    "Dumps a debug representation of the object to stderr.";
+
+/// Appends the elements of `list` to `result`, casting each one to PyType
+/// (which must then convert implicitly to CType).
+/// Throws py::cast_error if any element fails to cast; the message embeds
+/// `action` to describe what the caller was attempting to do.
+template <typename PyType, typename CType>
+static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
+                           StringRef action) {
+  // Elements are appended, so reserve room on top of what is already there.
+  result.reserve(result.size() + py::len(list));
+  for (py::handle item : list) {
+    try {
+      result.push_back(item.cast<PyType>());
+    } catch (const py::cast_error &err) {
+      std::string msg = (llvm::Twine("Invalid expression when ") + action +
+                         " (" + err.what() + ")")
+                            .str();
+      throw py::cast_error(msg);
+    } catch (const py::reference_cast_error &err) {
+      // Per the message, presumably hit when the item is None and cannot bind
+      // to a reference; normalized to py::cast_error for callers.
+      std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
+                         action + " (" + err.what() + ")")
+                            .str();
+      throw py::cast_error(msg);
+    }
+  }
+}
+
+/// Returns true iff `permutation` contains every value in
+/// [0, permutation.size()) exactly once. Out-of-range values and duplicates
+/// make it return false. For a signed PermutationTy, a negative value is
+/// converted to an unsigned type in the comparison, so it also counts as
+/// out of range.
+template <typename PermutationTy>
+static bool isPermutation(const std::vector<PermutationTy> &permutation) {
+  llvm::SmallVector<bool, 8> seen(permutation.size(), false);
+  for (const auto &val : permutation) {
+    // The range check must come first so `seen` is never indexed out of bounds.
+    if (val >= permutation.size() || seen[val])
+      return false;
+    seen[val] = true;
+  }
+  return true;
+}
+
+namespace {
+
+/// CRTP base for the Python classes modeling concrete AffineExpr kinds.
+/// A concrete expression can be constructed from a generic PyAffineExpr, in
+/// which case its kind is checked with DerivedTy::isaFunction.
+/// Intermediate Python-visible bases (e.g. the binary-op base) are modeled by
+/// passing them as BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAffineExpr>
+class PyConcreteAffineExpr : public BaseTy {
+public:
+  // Each DerivedTy is expected to provide:
+  //   static IsAFunctionTy isaFunction;  -- kind predicate from the C API
+  //   static const char *pyClassName;    -- name of the Python class
+  // and may shadow bindDerived() to add its own Python methods.
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAffineExpr);
+
+  PyConcreteAffineExpr() = default;
+  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
+      : BaseTy(std::move(contextRef), affineExpr) {}
+  /// Downcasting constructor; fails (via castFrom) if `orig` is another kind.
+  PyConcreteAffineExpr(PyAffineExpr &orig)
+      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns the raw expression of `orig` if it is of the DerivedTy kind;
+  /// otherwise throws the ValueError produced by SetPyError.
+  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
+    if (DerivedTy::isaFunction(orig))
+      return orig;
+    std::string origRepr = py::repr(py::cast(orig)).cast<std::string>();
+    throw SetPyError(PyExc_ValueError,
+                     Twine("Cannot cast affine expression to ") +
+                         DerivedTy::pyClassName + " (from " + origRepr + ")");
+  }
+
+  /// Registers DerivedTy in `m` under DerivedTy::pyClassName, constructible
+  /// from any AffineExpr, then lets DerivedTy add its own methods.
+  static void bind(py::module &m) {
+    ClassTy cls(m, DerivedTy::pyClassName);
+    cls.def(py::init<PyAffineExpr &>());
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Default hook adds nothing; derived classes shadow it.
+  static void bindDerived(ClassTy &cls) {}
+};
+
+/// Python `AffineConstantExpr`: an affine expression holding an integer
+/// constant.
+class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
+  static constexpr const char *pyClassName = "AffineConstantExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Creates the constant `value` (narrowed to int64_t) in `context`.
+  static PyAffineConstantExpr get(intptr_t value,
+                                  DefaultingPyMlirContext context) {
+    int64_t constant = static_cast<int64_t>(value);
+    MlirAffineExpr raw = mlirAffineConstantExprGet(context->get(), constant);
+    return PyAffineConstantExpr(context->getRef(), raw);
+  }
+
+  /// Adds `get(value, context=None)` and the read-only `value` property.
+  static void bindDerived(ClassTy &c) {
+    auto valueGetter = [](PyAffineConstantExpr &self) {
+      return mlirAffineConstantExprGetValue(self);
+    };
+    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("value", valueGetter);
+  }
+};
+
+/// Python `AffineDimExpr`: a reference to the dimension at a given position.
+class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
+  static constexpr const char *pyClassName = "AffineDimExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Creates the expression for dimension `pos` in `context`.
+  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr dim = mlirAffineDimExprGet(context->get(), pos);
+    return PyAffineDimExpr(context->getRef(), dim);
+  }
+
+  /// Adds `get(position, context=None)` and the read-only `position`.
+  static void bindDerived(ClassTy &c) {
+    auto positionGetter = [](PyAffineDimExpr &self) {
+      return mlirAffineDimExprGetPosition(self);
+    };
+    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", positionGetter);
+  }
+};
+
+/// Python `AffineSymbolExpr`: a reference to the symbol at a given position.
+class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
+  static constexpr const char *pyClassName = "AffineSymbolExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Creates the expression for symbol `pos` in `context`.
+  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
+    MlirAffineExpr symbol = mlirAffineSymbolExprGet(context->get(), pos);
+    return PyAffineSymbolExpr(context->getRef(), symbol);
+  }
+
+  /// Adds `get(position, context=None)` and the read-only `position`.
+  static void bindDerived(ClassTy &c) {
+    auto positionGetter = [](PyAffineSymbolExpr &self) {
+      return mlirAffineSymbolExprGetPosition(self);
+    };
+    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
+                 py::arg("context") = py::none());
+    c.def_property_readonly("position", positionGetter);
+  }
+};
+
+/// Python `AffineBinaryExpr`: the common base of the two-operand expression
+/// kinds, exposing their operands as `lhs` and `rhs`.
+class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
+  static constexpr const char *pyClassName = "AffineBinaryExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Left operand, wrapped with this expression's context.
+  PyAffineExpr lhs() {
+    MlirAffineExpr operand = mlirAffineBinaryOpExprGetLHS(get());
+    PyAffineExpr wrapped(getContext(), operand);
+    return wrapped;
+  }
+
+  /// Right operand, wrapped with this expression's context.
+  PyAffineExpr rhs() {
+    MlirAffineExpr operand = mlirAffineBinaryOpExprGetRHS(get());
+    PyAffineExpr wrapped(getContext(), operand);
+    return wrapped;
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs)
+        .def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
+  }
+};
+
+/// Python `AffineAddExpr`: the sum of two affine expressions.
+class PyAffineAddExpr
+    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
+  static constexpr const char *pyClassName = "AffineAddExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Builds `lhs + rhs`; the result is tied to the context of `lhs`.
+  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr sum = mlirAffineAddExprGet(lhs, rhs);
+    PyAffineAddExpr result(lhs.getContext(), sum);
+    return result;
+  }
+
+  static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineAddExpr::get); }
+};
+
+/// Python `AffineMulExpr`: the product of two affine expressions.
+class PyAffineMulExpr
+    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
+  static constexpr const char *pyClassName = "AffineMulExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Builds `lhs * rhs`; the result is tied to the context of `lhs`.
+  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr product = mlirAffineMulExprGet(lhs, rhs);
+    PyAffineMulExpr result(lhs.getContext(), product);
+    return result;
+  }
+
+  static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineMulExpr::get); }
+};
+
+/// Python `AffineModExpr`: the modulo of dividing one expression by another.
+class PyAffineModExpr
+    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
+  static constexpr const char *pyClassName = "AffineModExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Builds `lhs mod rhs`; the result is tied to the context of `lhs`.
+  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr remainder = mlirAffineModExprGet(lhs, rhs);
+    PyAffineModExpr result(lhs.getContext(), remainder);
+    return result;
+  }
+
+  static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineModExpr::get); }
+};
+
+/// Python `AffineFloorDivExpr`: the rounded-down quotient of two expressions.
+class PyAffineFloorDivExpr
+    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
+  static constexpr const char *pyClassName = "AffineFloorDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Builds `lhs floordiv rhs`; the result is tied to the context of `lhs`.
+  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr quotient = mlirAffineFloorDivExprGet(lhs, rhs);
+    PyAffineFloorDivExpr result(lhs.getContext(), quotient);
+    return result;
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineFloorDivExpr::get);
+  }
+};
+
+/// Python `AffineCeilDivExpr`: the rounded-up quotient of two expressions.
+class PyAffineCeilDivExpr
+    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
+  static constexpr const char *pyClassName = "AffineCeilDivExpr";
+  using PyConcreteAffineExpr::PyConcreteAffineExpr;
+
+  /// Builds `lhs ceildiv rhs`; the result is tied to the context of `lhs`.
+  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
+    MlirAffineExpr quotient = mlirAffineCeilDivExprGet(lhs, rhs);
+    PyAffineCeilDivExpr result(lhs.getContext(), quotient);
+    return result;
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static("get", &PyAffineCeilDivExpr::get);
+  }
+};
+
+} // namespace
+
+/// Equality is delegated to the C API's comparison of the raw expressions.
+bool PyAffineExpr::operator==(const PyAffineExpr &other) {
+  const MlirAffineExpr &mine = affineExpr;
+  const MlirAffineExpr &theirs = other.affineExpr;
+  return mlirAffineExprEqual(mine, theirs);
+}
+
+/// Wraps this expression in a new Python capsule; the returned py::object
+/// takes over (steals) the reference produced by the interop helper.
+py::object PyAffineExpr::getCapsule() {
+  PyObject *capsule = mlirPythonAffineExprToCapsule(*this);
+  return py::reinterpret_steal<py::object>(capsule);
+}
+
+/// Rebuilds a PyAffineExpr from a capsule, resolving its owning context from
+/// the expression itself. Throws py::error_already_set when the capsule
+/// yields a null expression (presumably the interop helper has already set a
+/// Python error in that case).
+PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
+  MlirAffineExpr raw = mlirPythonCapsuleToAffineExpr(capsule.ptr());
+  if (mlirAffineExprIsNull(raw))
+    throw py::error_already_set();
+  MlirContext owner = mlirAffineExprGetContext(raw);
+  return PyAffineExpr(PyMlirContext::forContext(owner), raw);
+}
+
+//------------------------------------------------------------------------------
+// PyAffineMap and utilities.
+//------------------------------------------------------------------------------
+namespace {
+
+/// Python `AffineExprList`: a sliceable view over the result expressions of
+/// an affine map. Results are accessed by index, and because both the map
+/// and its expressions are owned by the context, no lifetime extension is
+/// needed.
+class PyAffineMapExprList
+    : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
+public:
+  static constexpr const char *pyClassName = "AffineExprList";
+
+  /// A `length` of -1 selects all results of `map`.
+  PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
+                      intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length != -1 ? length : mlirAffineMapGetNumResults(map),
+                  step),
+        affineMap(map) {}
+
+  /// Total number of result expressions of the underlying map.
+  intptr_t getNumElements() {
+    intptr_t numResults = mlirAffineMapGetNumResults(affineMap);
+    return numResults;
+  }
+
+  /// Result expression `pos` of the underlying map.
+  PyAffineExpr getElement(intptr_t pos) {
+    MlirAffineExpr result = mlirAffineMapGetResult(affineMap, pos);
+    return PyAffineExpr(affineMap.getContext(), result);
+  }
+
+  /// A new view over the same map with the given slice parameters.
+  PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
+                            intptr_t step) {
+    PyAffineMapExprList view(affineMap, startIndex, length, step);
+    return view;
+  }
+
+private:
+  PyAffineMap affineMap;
+};
+} // end namespace
+
+/// Equality is delegated to the C API's comparison of the raw maps.
+bool PyAffineMap::operator==(const PyAffineMap &other) {
+  const MlirAffineMap &mine = affineMap;
+  const MlirAffineMap &theirs = other.affineMap;
+  return mlirAffineMapEqual(mine, theirs);
+}
+
+/// Wraps this map in a new Python capsule; the returned py::object takes
+/// over (steals) the reference produced by the interop helper.
+py::object PyAffineMap::getCapsule() {
+  PyObject *capsule = mlirPythonAffineMapToCapsule(*this);
+  return py::reinterpret_steal<py::object>(capsule);
+}
+
+/// Rebuilds a PyAffineMap from a capsule, resolving its owning context from
+/// the map itself. Throws py::error_already_set when the capsule yields a
+/// null map (presumably with a Python error already set by the helper).
+PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
+  MlirAffineMap raw = mlirPythonCapsuleToAffineMap(capsule.ptr());
+  if (mlirAffineMapIsNull(raw))
+    throw py::error_already_set();
+  MlirContext owner = mlirAffineMapGetContext(raw);
+  return PyAffineMap(PyMlirContext::forContext(owner), raw);
+}
+
+//------------------------------------------------------------------------------
+// PyIntegerSet and utilities.
+//------------------------------------------------------------------------------
+namespace {
+
+/// Python `IntegerSetConstraint`: the constraint at index `pos` of an integer
+/// set, exposing its affine expression and whether it is an equality.
+class PyIntegerSetConstraint {
+public:
+  PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
+      : set(set), pos(pos) {}
+
+  /// The constraint's affine expression, wrapped with the set's context.
+  PyAffineExpr getExpr() {
+    MlirAffineExpr constraint = mlirIntegerSetGetConstraint(set, pos);
+    return PyAffineExpr(set.getContext(), constraint);
+  }
+
+  /// True if the constraint is an equality rather than an inequality.
+  bool isEq() {
+    bool equality = mlirIntegerSetIsConstraintEq(set, pos);
+    return equality;
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyIntegerSetConstraint> cls(m, "IntegerSetConstraint");
+    cls.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr);
+    cls.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
+  }
+
+private:
+  PyIntegerSet set;
+  intptr_t pos;
+};
+
+/// Python `IntegerSetConstraintList`: a sliceable view over the constraints
+/// of an integer set.
+class PyIntegerSetConstraintList
+    : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
+public:
+  static constexpr const char *pyClassName = "IntegerSetConstraintList";
+
+  /// A `length` of -1 selects all constraints of `set`.
+  PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
+                             intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length != -1 ? length : mlirIntegerSetGetNumConstraints(set),
+                  step),
+        set(set) {}
+
+  /// Total number of constraints of the underlying set.
+  intptr_t getNumElements() {
+    intptr_t numConstraints = mlirIntegerSetGetNumConstraints(set);
+    return numConstraints;
+  }
+
+  /// A handle to constraint `pos` of the underlying set.
+  PyIntegerSetConstraint getElement(intptr_t pos) {
+    PyIntegerSetConstraint constraint(set, pos);
+    return constraint;
+  }
+
+  /// A new view over the same set with the given slice parameters.
+  PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
+                                   intptr_t step) {
+    PyIntegerSetConstraintList view(set, startIndex, length, step);
+    return view;
+  }
+
+private:
+  PyIntegerSet set;
+};
+} // namespace
+
+/// Equality is delegated to the C API's comparison of the raw sets.
+bool PyIntegerSet::operator==(const PyIntegerSet &other) {
+  const MlirIntegerSet &mine = integerSet;
+  const MlirIntegerSet &theirs = other.integerSet;
+  return mlirIntegerSetEqual(mine, theirs);
+}
+
+/// Wraps this integer set in a new Python capsule; the returned py::object
+/// takes over (steals) the reference produced by the interop helper.
+py::object PyIntegerSet::getCapsule() {
+  PyObject *capsule = mlirPythonIntegerSetToCapsule(*this);
+  return py::reinterpret_steal<py::object>(capsule);
+}
+
+/// Rebuilds a PyIntegerSet from a capsule, resolving its owning context from
+/// the set itself. Throws py::error_already_set when the capsule yields a
+/// null set (presumably with a Python error already set by the helper).
+PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
+  MlirIntegerSet raw = mlirPythonCapsuleToIntegerSet(capsule.ptr());
+  if (mlirIntegerSetIsNull(raw))
+    throw py::error_already_set();
+  MlirContext owner = mlirIntegerSetGetContext(raw);
+  return PyIntegerSet(PyMlirContext::forContext(owner), raw);
+}
+
+/// Binds `AffineExpr` and its concrete subclasses into `m`. The base class is
+/// registered before the subclasses that name it as their Python base.
+static void populateAffineExprBindings(py::module &m) {
+  py::class_<PyAffineExpr>(m, "AffineExpr")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineExpr::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
+      .def("__add__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineAddExpr::get(self, other);
+           })
+      .def("__mul__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineMulExpr::get(self, other);
+           })
+      .def("__mod__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             return PyAffineModExpr::get(self, other);
+           })
+      .def("__sub__",
+           [](PyAffineExpr &self, PyAffineExpr &other) {
+             // There is no subtraction builder: lower to self + (-1 * other),
+             // with the constant created in self's context.
+             auto negOne =
+                 PyAffineConstantExpr::get(-1, *self.getContext().get());
+             return PyAffineAddExpr::get(self,
+                                         PyAffineMulExpr::get(negOne, other));
+           })
+      .def("__eq__", [](PyAffineExpr &self,
+                        PyAffineExpr &other) { return self == other; })
+      // Fallback overload: comparing with any non-AffineExpr yields False.
+      .def("__eq__",
+           [](PyAffineExpr &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineExpr &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineExpr(");
+             mlirAffineExprPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineExpr &self) { return self.getContext().getObject(); })
+      .def_static(
+          "get_add", &PyAffineAddExpr::get,
+          "Gets an affine expression containing a sum of two expressions.")
+      .def_static(
+          "get_mul", &PyAffineMulExpr::get,
+          "Gets an affine expression containing a product of two expressions.")
+      .def_static("get_mod", &PyAffineModExpr::get,
+                  "Gets an affine expression containing the modulo of dividing "
+                  "one expression by another.")
+      .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
+                  "Gets an affine expression containing the rounded-down "
+                  "result of dividing one expression by another.")
+      .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
+                  "Gets an affine expression containing the rounded-up result "
+                  "of dividing one expression by another.")
+      .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
+                  py::arg("context") = py::none(),
+                  "Gets a constant affine expression with the given value.")
+      .def_static(
+          "get_dim", &PyAffineDimExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a dimension at the given position.")
+      .def_static(
+          "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
+          py::arg("context") = py::none(),
+          "Gets an affine expression of a symbol at the given position.")
+      .def(
+          "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
+          kDumpDocstring);
+  PyAffineConstantExpr::bind(m);
+  PyAffineDimExpr::bind(m);
+  PyAffineSymbolExpr::bind(m);
+  // AffineBinaryExpr must precede the binary kinds that derive from it.
+  PyAffineBinaryExpr::bind(m);
+  PyAffineAddExpr::bind(m);
+  PyAffineMulExpr::bind(m);
+  PyAffineModExpr::bind(m);
+  PyAffineFloorDivExpr::bind(m);
+  PyAffineCeilDivExpr::bind(m);
+}
+
+/// Binds `AffineMap` and its `AffineExprList` results view into `m`.
+static void populateAffineMapBindings(py::module &m) {
+  py::class_<PyAffineMap>(m, "AffineMap")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyAffineMap::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
+      .def("__eq__",
+           [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
+      // Fallback overload: comparing with any non-AffineMap yields False.
+      .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
+      .def("__str__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyAffineMap &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("AffineMap(");
+             mlirAffineMapPrint(self, printAccum.getCallback(),
+                                printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyAffineMap &self) { return self.getContext().getObject(); },
+          "Context that owns the Affine Map")
+      .def(
+          "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
+             DefaultingPyMlirContext context) {
+            SmallVector<MlirAffineExpr> affineExprs;
+            pyListToVector<PyAffineExpr, MlirAffineExpr>(
+                exprs, affineExprs, "attempting to create an AffineMap");
+            MlirAffineMap map =
+                mlirAffineMapGet(context->get(), dimCount, symbolCount,
+                                 affineExprs.size(), affineExprs.data());
+            return PyAffineMap(context->getRef(), map);
+          },
+          py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
+          py::arg("context") = py::none(),
+          "Gets a map with the given expressions as results.")
+      .def_static(
+          "get_constant",
+          [](intptr_t value, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapConstantGet(context->get(), value);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("value"), py::arg("context") = py::none(),
+          "Gets an affine map with a single constant result")
+      .def_static(
+          "get_empty",
+          [](DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("context") = py::none(), "Gets an empty affine map.")
+      .def_static(
+          "get_identity",
+          [](intptr_t nDims, DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("context") = py::none(),
+          "Gets an identity map with the given number of dimensions.")
+      .def_static(
+          "get_minor_identity",
+          [](intptr_t nDims, intptr_t nResults,
+             DefaultingPyMlirContext context) {
+            MlirAffineMap affineMap =
+                mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("n_dims"), py::arg("n_results"),
+          py::arg("context") = py::none(),
+          "Gets a minor identity map with the given number of dimensions and "
+          "results.")
+      .def_static(
+          "get_permutation",
+          [](std::vector<unsigned> permutation,
+             DefaultingPyMlirContext context) {
+            if (!isPermutation(permutation))
+              throw py::cast_error("Invalid permutation when attempting to "
+                                   "create an AffineMap");
+            MlirAffineMap affineMap = mlirAffineMapPermutationGet(
+                context->get(), permutation.size(), permutation.data());
+            return PyAffineMap(context->getRef(), affineMap);
+          },
+          py::arg("permutation"), py::arg("context") = py::none(),
+          "Gets an affine map that permutes its inputs.")
+      .def("get_submap",
+           [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
+             intptr_t numResults = mlirAffineMapGetNumResults(self);
+             for (intptr_t pos : resultPos) {
+               if (pos < 0 || pos >= numResults)
+                 throw py::value_error("result position out of bounds");
+             }
+             MlirAffineMap affineMap = mlirAffineMapGetSubMap(
+                 self, resultPos.size(), resultPos.data());
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def("get_major_submap",
+           [](PyAffineMap &self, intptr_t nResults) {
+             // NOTE(review): a count of 0 appears to make the underlying
+             // AffineMap::getMajorSubMap return a null map, which the Python
+             // wrapper cannot safely print, and a negative count would
+             // presumably be reinterpreted as a large unsigned value.
+             // Reject both, mirroring get_submap's bounds check. As before,
+             // a count equal to the number of results is also rejected.
+             if (nResults <= 0 || nResults >= mlirAffineMapGetNumResults(self))
+               throw py::value_error("number of results out of bounds");
+             MlirAffineMap affineMap =
+                 mlirAffineMapGetMajorSubMap(self, nResults);
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def("get_minor_submap",
+           [](PyAffineMap &self, intptr_t nResults) {
+             // Same bounds as get_major_submap (see the note there).
+             if (nResults <= 0 || nResults >= mlirAffineMapGetNumResults(self))
+               throw py::value_error("number of results out of bounds");
+             MlirAffineMap affineMap =
+                 mlirAffineMapGetMinorSubMap(self, nResults);
+             return PyAffineMap(self.getContext(), affineMap);
+           })
+      .def_property_readonly(
+          "is_permutation",
+          [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
+      .def_property_readonly("is_projected_permutation",
+                             [](PyAffineMap &self) {
+                               return mlirAffineMapIsProjectedPermutation(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
+      .def_property_readonly("results", [](PyAffineMap &self) {
+        return PyAffineMapExprList(self);
+      });
+  PyAffineMapExprList::bind(m);
+}
+
+/// Binds `IntegerSet` and its constraint views into `m`.
+static void populateIntegerSetBindings(py::module &m) {
+  py::class_<PyIntegerSet>(m, "IntegerSet")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyIntegerSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
+      .def("__eq__", [](PyIntegerSet &self,
+                        PyIntegerSet &other) { return self == other; })
+      // Fallback overload: comparing with any non-IntegerSet yields False.
+      .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
+      .def("__str__",
+           [](PyIntegerSet &self) {
+             PyPrintAccumulator printAccum;
+             mlirIntegerSetPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             return printAccum.join();
+           })
+      .def("__repr__",
+           [](PyIntegerSet &self) {
+             PyPrintAccumulator printAccum;
+             printAccum.parts.append("IntegerSet(");
+             mlirIntegerSetPrint(self, printAccum.getCallback(),
+                                 printAccum.getUserData());
+             printAccum.parts.append(")");
+             return printAccum.join();
+           })
+      .def_property_readonly(
+          "context",
+          [](PyIntegerSet &self) { return self.getContext().getObject(); })
+      .def(
+          "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
+          kDumpDocstring)
+      .def_static(
+          "get",
+          [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
+             std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
+            if (exprs.size() != eqFlags.size())
+              throw py::value_error(
+                  "Expected the number of constraints to match "
+                  "that of equality flags");
+            if (exprs.empty())
+              throw py::value_error("Expected non-empty list of constraints");
+
+            // Copy over to a SmallVector because std::vector has a
+            // specialization for booleans that packs data and does not
+            // expose a `bool *`.
+            SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
+
+            SmallVector<MlirAffineExpr> affineExprs;
+            pyListToVector<PyAffineExpr>(exprs, affineExprs,
+                                         "attempting to create an IntegerSet");
+            MlirIntegerSet set = mlirIntegerSetGet(
+                context->get(), numDims, numSymbols, exprs.size(),
+                affineExprs.data(), flags.data());
+            return PyIntegerSet(context->getRef(), set);
+          },
+          py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
+          py::arg("eq_flags"), py::arg("context") = py::none())
+      .def_static(
+          "get_empty",
+          [](intptr_t numDims, intptr_t numSymbols,
+             DefaultingPyMlirContext context) {
+            MlirIntegerSet set =
+                mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
+            return PyIntegerSet(context->getRef(), set);
+          },
+          py::arg("num_dims"), py::arg("num_symbols"),
+          py::arg("context") = py::none())
+      .def("get_replaced",
+           [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
+              intptr_t numResultDims, intptr_t numResultSymbols) {
+             // The C API receives bare pointers, so the replacement lists must
+             // cover every dimension and symbol of the set.
+             if (static_cast<intptr_t>(dimExprs.size()) !=
+                 mlirIntegerSetGetNumDims(self))
+               throw py::value_error(
+                   "Expected the number of dimension replacement expressions "
+                   "to match that of dimensions");
+             if (static_cast<intptr_t>(symbolExprs.size()) !=
+                 mlirIntegerSetGetNumSymbols(self))
+               throw py::value_error(
+                   "Expected the number of symbol replacement expressions "
+                   "to match that of symbols");
+
+             SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
+             pyListToVector<PyAffineExpr>(
+                 dimExprs, dimAffineExprs,
+                 "attempting to create an IntegerSet by replacing dimensions");
+             pyListToVector<PyAffineExpr>(
+                 symbolExprs, symbolAffineExprs,
+                 "attempting to create an IntegerSet by replacing symbols");
+             MlirIntegerSet set = mlirIntegerSetReplaceGet(
+                 self, dimAffineExprs.data(), symbolAffineExprs.data(),
+                 numResultDims, numResultSymbols);
+             return PyIntegerSet(self.getContext(), set);
+           })
+      .def_property_readonly("is_canonical_empty",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetIsCanonicalEmpty(self);
+                             })
+      .def_property_readonly(
+          "n_dims",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
+      .def_property_readonly(
+          "n_symbols",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
+      .def_property_readonly(
+          "n_inputs",
+          [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
+      .def_property_readonly("n_equalities",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetGetNumEqualities(self);
+                             })
+      .def_property_readonly("n_inequalities",
+                             [](PyIntegerSet &self) {
+                               return mlirIntegerSetGetNumInequalities(self);
+                             })
+      .def_property_readonly("constraints", [](PyIntegerSet &self) {
+        return PyIntegerSetConstraintList(self);
+      });
+  PyIntegerSetConstraint::bind(m);
+  PyIntegerSetConstraintList::bind(m);
+}
+
+/// Populates `m` with the affine-related Python bindings: affine
+/// expressions, affine maps and integer sets, in that order.
+void mlir::python::populateIRAffine(py::module &m) {
+  populateAffineExprBindings(m);
+  populateAffineMapBindings(m);
+  populateIntegerSetBindings(m);
+}
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
new file mode 100644
index 000000000000..6f9206c1b912
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -0,0 +1,761 @@
+//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
+
+namespace {
+
+static MlirStringRef toMlirStringRef(const std::string &str) {
+  return mlirStringRefCreate(str.c_str(), str.length()); // Borrows `str`.
+}
+
+/// CRTP base shared by every Python-visible attribute subclass, so that a
+/// concrete subclass can be built from a generic Attribute of the right kind
+/// (e.g. `StringAttr(attr)`). Hierarchies are one level below PyAttribute by
+/// default; pass a BaseTy to insert an intermediate Python-visible class.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
+public:
+  // Every DerivedTy must provide two statics:
+  //   IsAFunctionTy isaFunction -- kind predicate checked by castFrom()
+  //   const char *pyClassName   -- name of the Python class
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseTy(std::move(contextRef), attr) {}
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns `orig` if it is a DerivedTy; raises ValueError otherwise.
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (DerivedTy::isaFunction(orig))
+      return orig;
+    std::string reprText = py::repr(py::cast(orig)).cast<std::string>();
+    throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
+                                           DerivedTy::pyClassName + " (from " +
+                                           reprText + ")");
+  }
+
+  /// Creates the Python class for DerivedTy in module `m`.
+  static void bind(py::module &m) {
+    ClassTy pyClass(m, DerivedTy::pyClassName, py::buffer_protocol());
+    pyClass.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
+    DerivedTy::bindDerived(pyClass);
+  }
+
+  /// Hook for subclasses to add their methods; the default adds nothing.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// AffineMapAttr: an attribute that wraps an AffineMap.
+class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
+  static constexpr const char *pyClassName = "AffineMapAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    auto wrapMap = [](PyAffineMap &map) {
+      return PyAffineMapAttribute(map.getContext(),
+                                  mlirAffineMapAttrGet(map.get()));
+    };
+    c.def_static("get", wrapMap, py::arg("affine_map"),
+                 "Gets an attribute wrapping an AffineMap.");
+  }
+};
+
+/// ArrayAttr wrapper; supports len(), indexing and iteration from Python.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Python iterator over the elements of an ArrayAttr.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
+
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    PyAttribute dunderNext() {
+      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+        throw py::stop_iteration();
+      return PyAttribute(attr.getContext(),
+                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+    }
+
+    static void bind(py::module &m) {
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+          .def("__next__", &PyArrayAttributeIterator::dunderNext);
+    }
+
+  private:
+    PyAttribute attr;
+    // Same integer type as the C API's element count and index.
+    intptr_t nextIndex = 0;
+  };
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](py::list attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirAttribute> mlirAttributes;
+          mlirAttributes.reserve(py::len(attributes));
+          for (auto attribute : attributes) {
+            try {
+              mlirAttributes.push_back(attribute.cast<PyAttribute>());
+            } catch (py::cast_error &err) {
+              throw py::cast_error(
+                  std::string("Invalid attribute when attempting to create "
+                              "an ArrayAttribute (") + err.what() + ")");
+            } catch (py::reference_cast_error &err) {
+              // This exception seems thrown when the value is "None".
+              throw py::cast_error(
+                  std::string("Invalid attribute (None?) when attempting to "
+                              "create an ArrayAttribute (") + err.what() + ")");
+            }
+          }
+          MlirAttribute attr = mlirArrayAttrGet(
+              context->get(), mlirAttributes.size(), mlirAttributes.data());
+          return PyArrayAttribute(context->getRef(), attr);
+        },
+        py::arg("attributes"), py::arg("context") = py::none(),
+        "Gets a uniqued Array attribute");
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            // Negative indices are out of range too; never pass them to C.
+            if (i < 0 || i >= mlirArrayAttrGetNumElements(arr))
+              throw py::index_error("ArrayAttribute index out of range");
+            return PyAttribute(arr.getContext(),
+                               mlirArrayAttrGetElement(arr, i));
+          })
+        .def("__len__",
+             [](const PyArrayAttribute &arr) {
+               return mlirArrayAttrGetNumElements(arr);
+             })
+        .def("__iter__", [](const PyArrayAttribute &arr) {
+          return PyArrayAttributeIterator(arr);
+        });
+  }
+};
+
+/// Floating point attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // FloatAttr.get(type, value, loc=None); a null result means a bad type.
+    auto getChecked = [](PyType &type, double value, DefaultingPyLocation loc) {
+      MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
+      // TODO: Rework error reporting once diagnostic engine is exposed
+      // in C API.
+      if (!mlirAttributeIsNull(attr))
+        return PyFloatAttribute(type.getContext(), attr);
+      std::string typeRepr = py::repr(py::cast(type)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("invalid '") + typeRepr +
+                           "' and expected floating point type.");
+    };
+    c.def_static("get", getChecked, py::arg("type"), py::arg("value"),
+                 py::arg("loc") = py::none(),
+                 "Gets an uniqued float point attribute associated to a type");
+    c.def_static(
+        "get_f32",
+        [](double value, DefaultingPyMlirContext context) {
+          MlirContext ctx = context->get();
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGet(ctx, mlirF32TypeGet(ctx), value);
+          return PyFloatAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued float point attribute associated to a f32 type");
+    c.def_static(
+        "get_f64",
+        [](double value, DefaultingPyMlirContext context) {
+          MlirContext ctx = context->get();
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), value);
+          return PyFloatAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued float point attribute associated to a f64 type");
+    auto readValue = [](PyFloatAttribute &self) {
+      return mlirFloatAttrGetValueDouble(self);
+    };
+    c.def_property_readonly("value", readValue,
+                            "Returns the value of the float point attribute");
+  }
+};
+
+/// Integer attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // IntegerAttr.get(type, value): wrapped with `type`'s context.
+    auto getTyped = [](PyType &type, int64_t value) {
+      return PyIntegerAttribute(type.getContext(),
+                                mlirIntegerAttrGet(type, value));
+    };
+    c.def_static("get", getTyped, py::arg("type"), py::arg("value"),
+                 "Gets an uniqued integer attribute associated to a type");
+
+    // IntegerAttr.value (read-only property).
+    auto readValue = [](PyIntegerAttribute &self) {
+      return mlirIntegerAttrGetValueInt(self);
+    };
+    c.def_property_readonly("value", readValue,
+                            "Returns the value of the integer attribute");
+  }
+};
+
+/// Boolean attribute subclass - BoolAttr.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+  static constexpr const char *pyClassName = "BoolAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    auto getInContext = [](bool value, DefaultingPyMlirContext context) {
+      return PyBoolAttribute(context->getRef(),
+                             mlirBoolAttrGet(context->get(), value));
+    };
+    c.def_static("get", getInContext, py::arg("value"),
+                 py::arg("context") = py::none(),
+                 "Gets an uniqued bool attribute");
+    auto readValue = [](PyBoolAttribute &self) {
+      return mlirBoolAttrGetValue(self);
+    };
+    c.def_property_readonly("value", readValue,
+                            "Returns the value of the bool attribute");
+  }
+};
+
+/// Flat symbol reference attribute subclass - FlatSymbolRefAttr.
+class PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    auto getByName = [](std::string value, DefaultingPyMlirContext context) {
+      MlirStringRef symbol = toMlirStringRef(value);
+      return PyFlatSymbolRefAttribute(
+          context->getRef(), mlirFlatSymbolRefAttrGet(context->get(), symbol));
+    };
+    c.def_static("get", getByName, py::arg("value"),
+                 py::arg("context") = py::none(),
+                 "Gets a uniqued FlatSymbolRef attribute");
+    auto readValue = [](PyFlatSymbolRefAttribute &self) {
+      MlirStringRef symbol = mlirFlatSymbolRefAttrGetValue(self);
+      return py::str(symbol.data, symbol.length);
+    };
+    c.def_property_readonly(
+        "value", readValue,
+        "Returns the value of the FlatSymbolRef attribute as a string");
+  }
+};
+
+/// String attribute subclass - StringAttr.
+class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
+  static constexpr const char *pyClassName = "StringAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::string value, DefaultingPyMlirContext context) {
+          MlirAttribute attr =
+              mlirStringAttrGet(context->get(), toMlirStringRef(value));
+          return PyStringAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued string attribute");
+    c.def_static(
+        "get_typed",
+        [](PyType &type, std::string value) {
+          MlirAttribute attr =
+              mlirStringAttrTypedGet(type, toMlirStringRef(value));
+          return PyStringAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"), // Named like the other factories.
+        "Gets a uniqued string attribute associated to a type");
+    c.def_property_readonly(
+        "value",
+        [](PyStringAttribute &self) {
+          MlirStringRef stringRef = mlirStringAttrGetValue(self);
+          return py::str(stringRef.data, stringRef.length);
+        },
+        "Returns the value of the string attribute");
+  }
+};
+
+/// Python wrapper for DenseElementsAttr, convertible to/from buffers.
+/// TODO: Support construction of bool and string elements.
+class PyDenseElementsAttribute
+    : public PyConcreteAttribute<PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
+  static constexpr const char *pyClassName = "DenseElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static PyDenseElementsAttribute
+  getFromBuffer(py::buffer array, bool signless,
+                DefaultingPyMlirContext contextWrapper) {
+    // Request a contiguous view. In exotic cases, this will cause a copy.
+    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
+    Py_buffer *view = new Py_buffer();
+    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
+      delete view;
+      throw py::error_already_set();
+    }
+    py::buffer_info arrayInfo(view);
+
+    MlirContext context = contextWrapper->get();
+    // Switch on the types that can be bulk loaded between the Python and
+    // MLIR-C APIs.
+    // See: https://docs.python.org/3/library/struct.html#format-characters
+    if (arrayInfo.format == "f") {
+      // f32
+      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
+      return PyDenseElementsAttribute(
+          contextWrapper->getRef(),
+          bulkLoad(context, mlirDenseElementsAttrFloatGet,
+                   mlirF32TypeGet(context), arrayInfo));
+    } else if (arrayInfo.format == "d") {
+      // f64
+      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
+      return PyDenseElementsAttribute(
+          contextWrapper->getRef(),
+          bulkLoad(context, mlirDenseElementsAttrDoubleGet,
+                   mlirF64TypeGet(context), arrayInfo));
+    } else if (isSignedIntegerFormat(arrayInfo.format)) {
+      if (arrayInfo.itemsize == 4) {
+        // i32
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
+                                        : mlirIntegerTypeSignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // i64
+        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
+                                        : mlirIntegerTypeSignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrInt64Get,
+                                                 elementType, arrayInfo));
+      }
+    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
+      if (arrayInfo.itemsize == 4) {
+        // unsigned i32
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 32)
+                                   : mlirIntegerTypeUnsignedGet(context, 32);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt32Get,
+                                                 elementType, arrayInfo));
+      } else if (arrayInfo.itemsize == 8) {
+        // unsigned i64
+        MlirType elementType = signless
+                                   ? mlirIntegerTypeGet(context, 64)
+                                   : mlirIntegerTypeUnsignedGet(context, 64);
+        return PyDenseElementsAttribute(contextWrapper->getRef(),
+                                        bulkLoad(context,
+                                                 mlirDenseElementsAttrUInt64Get,
+                                                 elementType, arrayInfo));
+      }
+    }
+
+    // TODO: Fall back to string-based get.
+    std::string message = "unimplemented array format conversion from format: ";
+    message.append(arrayInfo.format);
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
+  static PyDenseElementsAttribute getSplat(PyType shapedType,
+                                           PyAttribute &elementAttr) {
+    auto contextWrapper =
+        PyMlirContext::forContext(mlirTypeGetContext(shapedType));
+    if (!mlirAttributeIsAInteger(elementAttr) &&
+        !mlirAttributeIsAFloat(elementAttr)) {
+      std::string message = "Illegal element type for DenseElementsAttr: ";
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+    if (!mlirTypeIsAShaped(shapedType) ||
+        !mlirShapedTypeHasStaticShape(shapedType)) {
+      std::string message =
+          "Expected a static ShapedType for the shaped_type parameter: ";
+      message.append(py::repr(py::cast(shapedType)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+    MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
+    MlirType attrType = mlirAttributeGetType(elementAttr);
+    if (!mlirTypeEqual(shapedElementType, attrType)) {
+      std::string message =
+          "Shaped element type and attribute type must be equal: shaped=";
+      message.append(py::repr(py::cast(shapedType)));
+      message.append(", element=");
+      message.append(py::repr(py::cast(elementAttr)));
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    MlirAttribute elements =
+        mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
+  intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
+
+  py::buffer_info accessBuffer() {
+    MlirType shapedType = mlirAttributeGetType(*this);
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    if (mlirTypeIsAF32(elementType)) {
+      // f32
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+    } else if (mlirTypeIsAF64(elementType)) {
+      // f64
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 32) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i32
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i32
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 64) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i64
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i64
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+      }
+    }
+
+    std::string message = "unimplemented array format.";
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+        .def_static("get", PyDenseElementsAttribute::getFromBuffer,
+                    py::arg("array"), py::arg("signless") = true,
+                    py::arg("context") = py::none(),
+                    "Gets from a buffer or ndarray")
+        .def_static("get_splat", PyDenseElementsAttribute::getSplat,
+                    py::arg("shaped_type"), py::arg("element_attr"),
+                    "Gets a DenseElementsAttr where all values are the same")
+        .def_property_readonly("is_splat",
+                               [](PyDenseElementsAttribute &self) -> bool {
+                                 return mlirDenseElementsAttrIsSplat(self);
+                               })
+        .def_buffer(&PyDenseElementsAttribute::accessBuffer);
+  }
+
+private:
+  template <typename ElementTy>
+  static MlirAttribute
+  bulkLoad(MlirContext context,
+           MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
+           MlirType mlirElementType, py::buffer_info &arrayInfo) {
+    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
+                                  arrayInfo.shape.begin() + arrayInfo.ndim);
+    auto shapedType =
+        mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
+    intptr_t numElements = arrayInfo.size;
+    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
+    return ctor(shapedType, numElements, contents);
+  }
+
+  static bool isUnsignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    char code = format[0];
+    return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
+           code == 'Q';
+  }
+
+  static bool isSignedIntegerFormat(const std::string &format) {
+    if (format.empty())
+      return false;
+    char code = format[0];
+    return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
+           code == 'q';
+  }
+
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType,
+                             Type (*value)(MlirAttribute, intptr_t)) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info.
+    // Buffer is configured for read-only access below.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    // Prepare the shape for the buffer_info.
+    SmallVector<intptr_t, 4> shape;
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Row-major byte strides, exactly one per dimension (none for rank 0).
+    // A splat keeps only its single value in the raw data, so all strides
+    // are 0 and every element aliases it instead of reading past the end.
+    SmallVector<intptr_t, 4> strides(rank, 0);
+    if (!mlirDenseElementsAttrIsSplat(*this)) {
+      intptr_t stride = sizeof(Type);
+      for (intptr_t i = rank - 1; i >= 0; --i) {
+        strides[i] = stride;
+        stride *= shape[i];
+      }
+    }
+    return py::buffer_info(data, sizeof(Type),
+                           py::format_descriptor<Type>::format(), rank, shape,
+                           strides, /*readonly=*/true);
+  }
+}; // class PyDenseElementsAttribute
+
+/// Refinement of the PyDenseElementsAttribute for attributes containing integer
+/// (and boolean) values. Supports element access.
+class PyDenseIntElementsAttribute
+    : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+  static constexpr const char *pyClassName = "DenseIntElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the element at the given linear position as a Python int.
+  /// Raises IndexError if `pos` is not in [0, len(self)) and TypeError if
+  /// the element bit width is not 1, 32 or 64. Signedness of the element
+  /// type selects between the signed and unsigned C accessors.
+  py::int_ dunderGetItem(intptr_t pos) {
+    if (pos < 0 || pos >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+    }
+
+    MlirType type = mlirAttributeGetType(*this);
+    type = mlirShapedTypeGetElementType(type);
+    assert(mlirTypeIsAInteger(type) &&
+           "expected integer element type in dense int elements attribute");
+    // Dispatch element extraction to an appropriate C function based on the
+    // elemental type of the attribute. py::int_ is implicitly constructible
+    // from any C++ integral type and handles bitwidth correctly.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    unsigned width = mlirIntegerTypeGetWidth(type);
+    bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+
+    // i1 elements are read as booleans whatever the signedness.
+    if (width == 1) {
+      return mlirDenseElementsAttrGetBoolValue(*this, pos);
+    }
+    // Wider elements pick the accessor matching width and signedness;
+    // signless integers use the signed accessors.
+    if (width == 32) {
+      if (isUnsigned) {
+        return mlirDenseElementsAttrGetUInt32Value(*this, pos);
+      }
+      return mlirDenseElementsAttrGetInt32Value(*this, pos);
+    }
+    if (width == 64) {
+      if (isUnsigned) {
+        return mlirDenseElementsAttrGetUInt64Value(*this, pos);
+      }
+      return mlirDenseElementsAttrGetInt64Value(*this, pos);
+    }
+    throw SetPyError(PyExc_TypeError, "Unsupported integer type");
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+  }
+};
+
+/// Dictionary attribute subclass - DictAttr.
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(attributes.size());
+          for (auto &it : attributes) {
+            auto &mlir_attr = it.second.cast<PyAttribute &>();
+            auto name = it.first.cast<std::string>();
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
+                                  toMlirStringRef(name)),
+                mlir_attr));
+          }
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued dict attribute");
+    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+      MlirAttribute attr =
+          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+      if (mlirAttributeIsNull(attr))
+        throw SetPyError(PyExc_KeyError,
+                         "attempt to access a non-existent attribute");
+      return PyAttribute(self.getContext(), attr);
+    });
+    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+      if (index < 0 || index >= self.dunderLen())
+        throw SetPyError(PyExc_IndexError,
+                         "attempt to access out of bounds attribute");
+      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+      // MlirStringRef need not be NUL-terminated: copy exactly `length` bytes.
+      MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+      return PyNamedAttribute(namedAttr.attribute,
+                              std::string(name.data, name.length));
+    });
+  }
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes containing
+/// floating-point values. Supports element access.
+class PyDenseFPElementsAttribute
+    : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+                                 PyDenseElementsAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+  static constexpr const char *pyClassName = "DenseFPElementsAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Python `attr[pos]`: the element at linear position `pos` as a float.
+  /// Raises IndexError when out of range and TypeError unless the element
+  /// type is f32 or f64.
+  py::float_ dunderGetItem(intptr_t pos) {
+    bool inBounds = pos >= 0 && pos < dunderLen();
+    if (!inBounds)
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds element");
+
+    MlirType elementType =
+        mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
+    // py::float_ converts implicitly from both float and double, so each
+    // branch can return the C accessor's result directly.
+    // TODO: consider caching the type properties in the constructor to avoid
+    // querying them on each element access.
+    if (mlirTypeIsAF64(elementType))
+      return mlirDenseElementsAttrGetDoubleValue(*this, pos);
+    if (mlirTypeIsAF32(elementType))
+      return mlirDenseElementsAttrGetFloatValue(*this, pos);
+    throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+  }
+};
+
+class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr const char *pyClassName = "TypeAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType value, DefaultingPyMlirContext context) {
+          // Wrap with the type's own context; `context` is kept for callers.
+          MlirAttribute attr = mlirTypeAttrGet(value.get());
+          return PyTypeAttribute(value.getContext(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued Type attribute");
+    c.def_property_readonly("value", [](PyTypeAttribute &self) {
+      return PyType(self.getContext()->getRef(),
+                    mlirTypeAttrGetValue(self.get()));
+    });
+  }
+};
+
+/// Unit attribute subclass - UnitAttr. Unit attributes carry no value.
+class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
+  static constexpr const char *pyClassName = "UnitAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // UnitAttr.get(context=None).
+    auto getInContext = [](DefaultingPyMlirContext context) {
+      MlirAttribute unit = mlirUnitAttrGet(context->get());
+      return PyUnitAttribute(context->getRef(), unit);
+    };
+    c.def_static("get", getInContext, py::arg("context") = py::none(),
+                 "Create a Unit attribute.");
+  }
+};
+
+} // namespace
+
+void mlir::python::populateIRAttributes(py::module &m) { // Binds attr classes.
+  PyAffineMapAttribute::bind(m);
+  PyArrayAttribute::bind(m);
+  PyArrayAttribute::PyArrayAttributeIterator::bind(m); // Nested iterator.
+  PyBoolAttribute::bind(m);
+  PyDenseElementsAttribute::bind(m); // BaseTy of the two below: bind first.
+  PyDenseFPElementsAttribute::bind(m);
+  PyDenseIntElementsAttribute::bind(m);
+  PyDictAttribute::bind(m);
+  PyFlatSymbolRefAttribute::bind(m);
+  PyFloatAttribute::bind(m);
+  PyIntegerAttribute::bind(m);
+  PyStringAttribute::bind(m);
+  PyTypeAttribute::bind(m);
+  PyUnitAttribute::bind(m);
+}
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
similarity index 52%
rename from mlir/lib/Bindings/Python/IRModules.cpp
rename to mlir/lib/Bindings/Python/IRCore.cpp
index 6b4e5434d1d7..9d87aa52f7c8 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -6,16 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#include "IRModules.h"
+#include "IRModule.h"
#include "Globals.h"
#include "PybindUtils.h"
-#include "mlir-c/AffineMap.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
-#include "mlir-c/IntegerSet.h"
#include "mlir-c/Registration.h"
#include "llvm/ADT/SmallVector.h"
#include <pybind11/stl.h>
@@ -138,12 +136,6 @@ py::object classmethod(Func f, Args... args) {
return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
}
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
- return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
- mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
-}
-
static py::object
createCustomDialectWrapper(const std::string &dialectNamespace,
py::object dialectDescriptor) {
@@ -161,21 +153,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
-template <typename PermutationTy>
-static bool isPermutation(std::vector<PermutationTy> permutation) {
- llvm::SmallVector<bool, 8> seen(permutation.size(), false);
- for (auto val : permutation) {
- if (val < permutation.size()) {
- if (seen[val])
- return false;
- seen[val] = true;
- continue;
- }
- return false;
- }
- return true;
-}
-
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@@ -1466,7 +1443,8 @@ namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy> class PyConcreteValue : public PyValue {
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1717,1910 +1695,169 @@ class PyOpAttributeMap {
} // end namespace
//------------------------------------------------------------------------------
-// Builtin attribute subclasses.
+// Populates the core exports of the 'ir' submodule.
//------------------------------------------------------------------------------
-namespace {
-
-/// CRTP base classes for Python attributes that subclass Attribute and should
-/// be castable from it (i.e. via something like StringAttr(attr)).
-/// By default, attribute class hierarchies are one level deep (i.e. a
-/// concrete attribute class extends PyAttribute); however, intermediate
-/// python-visible base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- using ClassTy = py::class_<DerivedTy, BaseTy>;
- using IsAFunctionTy = bool (*)(MlirAttribute);
-
- PyConcreteAttribute() = default;
- PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
- : BaseTy(std::move(contextRef), attr) {}
- PyConcreteAttribute(PyAttribute &orig)
- : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
-
- static MlirAttribute castFrom(PyAttribute &orig) {
- if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
- DerivedTy::pyClassName +
- " (from " + origRepr + ")");
- }
- return orig;
- }
-
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
- cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
- static constexpr const char *pyClassName = "AffineMapAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyAffineMap &affineMap) {
- MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
- return PyAffineMapAttribute(affineMap.getContext(), attr);
- },
- py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
- }
-};
-
-class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
- static constexpr const char *pyClassName = "ArrayAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- class PyArrayAttributeIterator {
- public:
- PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
-
- PyArrayAttributeIterator &dunderIter() { return *this; }
-
- PyAttribute dunderNext() {
- if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
- throw py::stop_iteration();
- }
- return PyAttribute(attr.getContext(),
- mlirArrayAttrGetElement(attr.get(), nextIndex++));
- }
-
- static void bind(py::module &m) {
- py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
- .def("__iter__", &PyArrayAttributeIterator::dunderIter)
- .def("__next__", &PyArrayAttributeIterator::dunderNext);
- }
-
- private:
- PyAttribute attr;
- int nextIndex = 0;
- };
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](py::list attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirAttribute> mlirAttributes;
- mlirAttributes.reserve(py::len(attributes));
- for (auto attribute : attributes) {
- try {
- mlirAttributes.push_back(attribute.cast<PyAttribute>());
- } catch (py::cast_error &err) {
- std::string msg = std::string("Invalid attribute when attempting "
- "to create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
- } catch (py::reference_cast_error &err) {
- // This exception seems thrown when the value is "None".
- std::string msg =
- std::string("Invalid attribute (None?) when attempting to "
- "create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
+void mlir::python::populateIRCore(py::module &m) {
+ //----------------------------------------------------------------------------
+ // Mapping of MlirContext
+ //----------------------------------------------------------------------------
+ py::class_<PyMlirContext>(m, "Context")
+ .def(py::init<>(&PyMlirContext::createNewContextForInit))
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount)
+ .def("_get_context_again",
+ [](PyMlirContext &self) {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ })
+ .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyMlirContext::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+ .def("__enter__", &PyMlirContext::contextEnter)
+ .def("__exit__", &PyMlirContext::contextExit)
+ .def_property_readonly_static(
+ "current",
+ [](py::object & /*class*/) {
+ auto *context = PyThreadContextEntry::getDefaultContext();
+ if (!context)
+ throw SetPyError(PyExc_ValueError, "No current Context");
+ return context;
+ },
+ "Gets the Context bound to the current thread or raises ValueError")
+ .def_property_readonly(
+ "dialects",
+ [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Gets a container for accessing dialects by name")
+ .def_property_readonly(
+ "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
+ "Alias for 'dialect'")
+ .def(
+ "get_dialect_descriptor",
+ [=](PyMlirContext &self, std::string &name) {
+ MlirDialect dialect = mlirContextGetOrLoadDialect(
+ self.get(), {name.data(), name.size()});
+ if (mlirDialectIsNull(dialect)) {
+ throw SetPyError(PyExc_ValueError,
+ Twine("Dialect '") + name + "' not found");
}
- }
- MlirAttribute attr = mlirArrayAttrGet(
- context->get(), mlirAttributes.size(), mlirAttributes.data());
- return PyArrayAttribute(context->getRef(), attr);
- },
- py::arg("attributes"), py::arg("context") = py::none(),
- "Gets a uniqued Array attribute");
- c.def("__getitem__",
- [](PyArrayAttribute &arr, intptr_t i) {
- if (i >= mlirArrayAttrGetNumElements(arr))
- throw py::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(),
- mlirArrayAttrGetElement(arr, i));
- })
- .def("__len__",
- [](const PyArrayAttribute &arr) {
- return mlirArrayAttrGetNumElements(arr);
- })
- .def("__iter__", [](const PyArrayAttribute &arr) {
- return PyArrayAttributeIterator(arr);
- });
- }
-};
-
-/// Float Point Attribute subclass - FloatAttr.
-class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
- static constexpr const char *pyClassName = "FloatAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &type, double value, DefaultingPyLocation loc) {
- MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirAttributeIsNull(attr)) {
- throw SetPyError(PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(type)).cast<std::string>() +
- "' and expected floating point type.");
- }
- return PyFloatAttribute(type.getContext(), attr);
- },
- py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
- "Gets an uniqued float point attribute associated to a type");
- c.def_static(
- "get_f32",
- [](double value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirFloatAttrDoubleGet(
- context->get(), mlirF32TypeGet(context->get()), value);
- return PyFloatAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets an uniqued float point attribute associated to a f32 type");
- c.def_static(
- "get_f64",
- [](double value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirFloatAttrDoubleGet(
- context->get(), mlirF64TypeGet(context->get()), value);
- return PyFloatAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets an uniqued float point attribute associated to a f64 type");
- c.def_property_readonly(
- "value",
- [](PyFloatAttribute &self) {
- return mlirFloatAttrGetValueDouble(self);
- },
- "Returns the value of the float point attribute");
- }
-};
-
-/// Integer Attribute subclass - IntegerAttr.
-class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
- static constexpr const char *pyClassName = "IntegerAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType &type, int64_t value) {
- MlirAttribute attr = mlirIntegerAttrGet(type, value);
- return PyIntegerAttribute(type.getContext(), attr);
- },
- py::arg("type"), py::arg("value"),
- "Gets an uniqued integer attribute associated to a type");
- c.def_property_readonly(
- "value",
- [](PyIntegerAttribute &self) {
- return mlirIntegerAttrGetValueInt(self);
- },
- "Returns the value of the integer attribute");
- }
-};
-
-/// Bool Attribute subclass - BoolAttr.
-class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
- static constexpr const char *pyClassName = "BoolAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](bool value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
- return PyBoolAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets an uniqued bool attribute");
- c.def_property_readonly(
- "value",
- [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
- "Returns the value of the bool attribute");
- }
-};
-
-class PyFlatSymbolRefAttribute
- : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
- static constexpr const char *pyClassName = "FlatSymbolRefAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::string value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
- return PyFlatSymbolRefAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets a uniqued FlatSymbolRef attribute");
- c.def_property_readonly(
- "value",
- [](PyFlatSymbolRefAttribute &self) {
- MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
- return py::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the FlatSymbolRef attribute as a string");
- }
-};
-
-class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
- static constexpr const char *pyClassName = "StringAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::string value, DefaultingPyMlirContext context) {
- MlirAttribute attr =
- mlirStringAttrGet(context->get(), toMlirStringRef(value));
- return PyStringAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets a uniqued string attribute");
- c.def_static(
- "get_typed",
- [](PyType &type, std::string value) {
- MlirAttribute attr =
- mlirStringAttrTypedGet(type, toMlirStringRef(value));
- return PyStringAttribute(type.getContext(), attr);
- },
-
- "Gets a uniqued string attribute associated to a type");
- c.def_property_readonly(
- "value",
- [](PyStringAttribute &self) {
- MlirStringRef stringRef = mlirStringAttrGetValue(self);
- return py::str(stringRef.data, stringRef.length);
- },
- "Returns the value of the string attribute");
- }
-};
-
-// TODO: Support construction of bool elements.
-// TODO: Support construction of string elements.
-class PyDenseElementsAttribute
- : public PyConcreteAttribute<PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
- static constexpr const char *pyClassName = "DenseElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static PyDenseElementsAttribute
- getFromBuffer(py::buffer array, bool signless,
- DefaultingPyMlirContext contextWrapper) {
- // Request a contiguous view. In exotic cases, this will cause a copy.
- int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
- Py_buffer *view = new Py_buffer();
- if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
- delete view;
- throw py::error_already_set();
- }
- py::buffer_info arrayInfo(view);
-
- MlirContext context = contextWrapper->get();
- // Switch on the types that can be bulk loaded between the Python and
- // MLIR-C APIs.
- // See: https://docs.python.org/3/library/struct.html#format-characters
- if (arrayInfo.format == "f") {
- // f32
- assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- return PyDenseElementsAttribute(
- contextWrapper->getRef(),
- bulkLoad(context, mlirDenseElementsAttrFloatGet,
- mlirF32TypeGet(context), arrayInfo));
- } else if (arrayInfo.format == "d") {
- // f64
- assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- return PyDenseElementsAttribute(
- contextWrapper->getRef(),
- bulkLoad(context, mlirDenseElementsAttrDoubleGet,
- mlirF64TypeGet(context), arrayInfo));
- } else if (isSignedIntegerFormat(arrayInfo.format)) {
- if (arrayInfo.itemsize == 4) {
- // i32
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt32Get,
- elementType, arrayInfo));
- } else if (arrayInfo.itemsize == 8) {
- // i64
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt64Get,
- elementType, arrayInfo));
- }
- } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
- if (arrayInfo.itemsize == 4) {
- // unsigned i32
- MlirType elementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt32Get,
- elementType, arrayInfo));
- } else if (arrayInfo.itemsize == 8) {
- // unsigned i64
- MlirType elementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt64Get,
- elementType, arrayInfo));
- }
- }
-
- // TODO: Fall back to string-based get.
- std::string message = "unimplemented array format conversion from format: ";
- message.append(arrayInfo.format);
- throw SetPyError(PyExc_ValueError, message);
- }
-
- static PyDenseElementsAttribute getSplat(PyType shapedType,
- PyAttribute &elementAttr) {
- auto contextWrapper =
- PyMlirContext::forContext(mlirTypeGetContext(shapedType));
- if (!mlirAttributeIsAInteger(elementAttr) &&
- !mlirAttributeIsAFloat(elementAttr)) {
- std::string message = "Illegal element type for DenseElementsAttr: ";
- message.append(py::repr(py::cast(elementAttr)));
- throw SetPyError(PyExc_ValueError, message);
- }
- if (!mlirTypeIsAShaped(shapedType) ||
- !mlirShapedTypeHasStaticShape(shapedType)) {
- std::string message =
- "Expected a static ShapedType for the shaped_type parameter: ";
- message.append(py::repr(py::cast(shapedType)));
- throw SetPyError(PyExc_ValueError, message);
- }
- MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
- MlirType attrType = mlirAttributeGetType(elementAttr);
- if (!mlirTypeEqual(shapedElementType, attrType)) {
- std::string message =
- "Shaped element type and attribute type must be equal: shaped=";
- message.append(py::repr(py::cast(shapedType)));
- message.append(", element=");
- message.append(py::repr(py::cast(elementAttr)));
- throw SetPyError(PyExc_ValueError, message);
- }
-
- MlirAttribute elements =
- mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
- return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
- }
+ return PyDialectDescriptor(self.getRef(), dialect);
+ },
+ "Gets or loads a dialect by name, returning its descriptor object")
+ .def_property(
+ "allow_unregistered_dialects",
+ [](PyMlirContext &self) -> bool {
+ return mlirContextGetAllowUnregisteredDialects(self.get());
+ },
+ [](PyMlirContext &self, bool value) {
+ mlirContextSetAllowUnregisteredDialects(self.get(), value);
+ });
- intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
-
- py::buffer_info accessBuffer() {
- MlirType shapedType = mlirAttributeGetType(*this);
- MlirType elementType = mlirShapedTypeGetElementType(shapedType);
-
- if (mlirTypeIsAF32(elementType)) {
- // f32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
- } else if (mlirTypeIsAF64(elementType)) {
- // f64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 32) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
- } else if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
- }
- } else if (mlirTypeIsAInteger(elementType) &&
- mlirIntegerTypeGetWidth(elementType) == 64) {
- if (mlirIntegerTypeIsSignless(elementType) ||
- mlirIntegerTypeIsSigned(elementType)) {
- // i64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
- } else if (mlirIntegerTypeIsUnsigned(elementType)) {
- // unsigned i64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
- }
- }
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialectDescriptor
+ //----------------------------------------------------------------------------
+ py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
+ .def_property_readonly("namespace",
+ [](PyDialectDescriptor &self) {
+ MlirStringRef ns =
+ mlirDialectGetNamespace(self.get());
+ return py::str(ns.data, ns.length);
+ })
+ .def("__repr__", [](PyDialectDescriptor &self) {
+ MlirStringRef ns = mlirDialectGetNamespace(self.get());
+ std::string repr("<DialectDescriptor ");
+ repr.append(ns.data, ns.length);
+ repr.append(">");
+ return repr;
+ });
- std::string message = "unimplemented array format.";
- throw SetPyError(PyExc_ValueError, message);
- }
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialects
+ //----------------------------------------------------------------------------
+ py::class_<PyDialects>(m, "Dialects")
+ .def("__getitem__",
+ [=](PyDialects &self, std::string keyName) {
+ MlirDialect dialect =
+ self.getDialectForKey(keyName, /*attrError=*/false);
+ py::object descriptor =
+ py::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(keyName, std::move(descriptor));
+ })
+ .def("__getattr__", [=](PyDialects &self, std::string attrName) {
+ MlirDialect dialect =
+ self.getDialectForKey(attrName, /*attrError=*/true);
+ py::object descriptor =
+ py::cast(PyDialectDescriptor{self.getContext(), dialect});
+ return createCustomDialectWrapper(attrName, std::move(descriptor));
+ });
- static void bindDerived(ClassTy &c) {
- c.def("__len__", &PyDenseElementsAttribute::dunderLen)
- .def_static("get", PyDenseElementsAttribute::getFromBuffer,
- py::arg("array"), py::arg("signless") = true,
- py::arg("context") = py::none(),
- "Gets from a buffer or ndarray")
- .def_static("get_splat", PyDenseElementsAttribute::getSplat,
- py::arg("shaped_type"), py::arg("element_attr"),
- "Gets a DenseElementsAttr where all values are the same")
- .def_property_readonly("is_splat",
- [](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self);
- })
- .def_buffer(&PyDenseElementsAttribute::accessBuffer);
- }
+ //----------------------------------------------------------------------------
+ // Mapping of PyDialect
+ //----------------------------------------------------------------------------
+ py::class_<PyDialect>(m, "Dialect")
+ .def(py::init<py::object>(), "descriptor")
+ .def_property_readonly(
+ "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
+ .def("__repr__", [](py::object self) {
+ auto clazz = self.attr("__class__");
+ return py::str("<Dialect ") +
+ self.attr("descriptor").attr("namespace") + py::str(" (class ") +
+ clazz.attr("__module__") + py::str(".") +
+ clazz.attr("__name__") + py::str(")>");
+ });
-private:
- template <typename ElementTy>
- static MlirAttribute
- bulkLoad(MlirContext context,
- MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
- MlirType mlirElementType, py::buffer_info &arrayInfo) {
- SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
- arrayInfo.shape.begin() + arrayInfo.ndim);
- auto shapedType =
- mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
- intptr_t numElements = arrayInfo.size;
- const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
- return ctor(shapedType, numElements, contents);
- }
-
- static bool isUnsignedIntegerFormat(const std::string &format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
- code == 'Q';
- }
-
- static bool isSignedIntegerFormat(const std::string &format) {
- if (format.empty())
- return false;
- char code = format[0];
- return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
- code == 'q';
- }
-
- template <typename Type>
- py::buffer_info bufferInfo(MlirType shapedType,
- Type (*value)(MlirAttribute, intptr_t)) {
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
- // Prepare the data for the buffer_info.
- // Buffer is configured for read-only access below.
- Type *data = static_cast<Type *>(
- const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- // Prepare the shape for the buffer_info.
- SmallVector<intptr_t, 4> shape;
- for (intptr_t i = 0; i < rank; ++i)
- shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
- // Prepare the strides for the buffer_info.
- SmallVector<intptr_t, 4> strides;
- intptr_t strideFactor = 1;
- for (intptr_t i = 1; i < rank; ++i) {
- strideFactor = 1;
- for (intptr_t j = i; j < rank; ++j) {
- strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
- }
- strides.push_back(sizeof(Type) * strideFactor);
- }
- strides.push_back(sizeof(Type));
- return py::buffer_info(data, sizeof(Type),
- py::format_descriptor<Type>::format(), rank, shape,
- strides, /*readonly=*/true);
- }
-}; // namespace
-
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
-class PyDenseIntElementsAttribute
- : public PyConcreteAttribute<PyDenseIntElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
- static constexpr const char *pyClassName = "DenseIntElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- /// Returns the element at the given linear position. Asserts if the index is
- /// out of range.
- py::int_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw SetPyError(PyExc_IndexError,
- "attempt to access out of bounds element");
- }
-
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- assert(mlirTypeIsAInteger(type) &&
- "expected integer element type in dense int elements attribute");
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. py::int_ is implicitly constructible
- // from any C++ integral type and handles bitwidth correctly.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- unsigned width = mlirIntegerTypeGetWidth(type);
- bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
- if (isUnsigned) {
- if (width == 1) {
- return mlirDenseElementsAttrGetBoolValue(*this, pos);
- }
- if (width == 32) {
- return mlirDenseElementsAttrGetUInt32Value(*this, pos);
- }
- if (width == 64) {
- return mlirDenseElementsAttrGetUInt64Value(*this, pos);
- }
- } else {
- if (width == 1) {
- return mlirDenseElementsAttrGetBoolValue(*this, pos);
- }
- if (width == 32) {
- return mlirDenseElementsAttrGetInt32Value(*this, pos);
- }
- if (width == 64) {
- return mlirDenseElementsAttrGetInt64Value(*this, pos);
- }
- }
- throw SetPyError(PyExc_TypeError, "Unsupported integer type");
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
- }
-};
-
-class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
- static constexpr const char *pyClassName = "DictAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
-
- static void bindDerived(ClassTy &c) {
- c.def("__len__", &PyDictAttribute::dunderLen);
- c.def_static(
- "get",
- [](py::dict attributes, DefaultingPyMlirContext context) {
- SmallVector<MlirNamedAttribute> mlirNamedAttributes;
- mlirNamedAttributes.reserve(attributes.size());
- for (auto &it : attributes) {
- auto &mlir_attr = it.second.cast<PyAttribute &>();
- auto name = it.first.cast<std::string>();
- mlirNamedAttributes.push_back(mlirNamedAttributeGet(
- mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
- toMlirStringRef(name)),
- mlir_attr));
- }
- MlirAttribute attr =
- mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
- mlirNamedAttributes.data());
- return PyDictAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets an uniqued dict attribute");
- c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
- throw SetPyError(PyExc_KeyError,
- "attempt to access a non-existent attribute");
- }
- return PyAttribute(self.getContext(), attr);
- });
- c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
- if (index < 0 || index >= self.dunderLen()) {
- throw SetPyError(PyExc_IndexError,
- "attempt to access out of bounds attribute");
- }
- MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
- return PyNamedAttribute(
- namedAttr.attribute,
- std::string(mlirIdentifierStr(namedAttr.name).data));
- });
- }
-};
-
-/// Refinement of PyDenseElementsAttribute for attributes containing
-/// floating-point values. Supports element access.
-class PyDenseFPElementsAttribute
- : public PyConcreteAttribute<PyDenseFPElementsAttribute,
- PyDenseElementsAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
- static constexpr const char *pyClassName = "DenseFPElementsAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- py::float_ dunderGetItem(intptr_t pos) {
- if (pos < 0 || pos >= dunderLen()) {
- throw SetPyError(PyExc_IndexError,
- "attempt to access out of bounds element");
- }
-
- MlirType type = mlirAttributeGetType(*this);
- type = mlirShapedTypeGetElementType(type);
- // Dispatch element extraction to an appropriate C function based on the
- // elemental type of the attribute. py::float_ is implicitly constructible
- // from float and double.
- // TODO: consider caching the type properties in the constructor to avoid
- // querying them on each element access.
- if (mlirTypeIsAF32(type)) {
- return mlirDenseElementsAttrGetFloatValue(*this, pos);
- }
- if (mlirTypeIsAF64(type)) {
- return mlirDenseElementsAttrGetDoubleValue(*this, pos);
- }
- throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
- }
-};
-
-class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
- static constexpr const char *pyClassName = "TypeAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](PyType value, DefaultingPyMlirContext context) {
- MlirAttribute attr = mlirTypeAttrGet(value.get());
- return PyTypeAttribute(context->getRef(), attr);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets a uniqued Type attribute");
- c.def_property_readonly("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext()->getRef(),
- mlirTypeAttrGetValue(self.get()));
- });
- }
-};
-
-/// Unit Attribute subclass. Unit attributes don't have values.
-class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
- static constexpr const char *pyClassName = "UnitAttr";
- using PyConcreteAttribute::PyConcreteAttribute;
-
- static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](DefaultingPyMlirContext context) {
- return PyUnitAttribute(context->getRef(),
- mlirUnitAttrGet(context->get()));
- },
- py::arg("context") = py::none(), "Create a Unit attribute.");
- }
-};
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// Builtin type subclasses.
-//------------------------------------------------------------------------------
-
-namespace {
-
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- using ClassTy = py::class_<DerivedTy, BaseTy>;
- using IsAFunctionTy = bool (*)(MlirType);
-
- PyConcreteType() = default;
- PyConcreteType(PyMlirContextRef contextRef, MlirType t)
- : BaseTy(std::move(contextRef), t) {}
- PyConcreteType(PyType &orig)
- : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
- static MlirType castFrom(PyType &orig) {
- if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
- DerivedTy::pyClassName +
- " (from " + origRepr + ")");
- }
- return orig;
- }
-
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
- cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
- cls.def_static("isinstance", [](PyType &otherType) -> bool {
- return DerivedTy::isaFunction(otherType);
- });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
-/// Integer Type subclass - IntegerType. Offers one factory per signedness
-/// flavour plus read-only width/signedness queries.
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
-  static constexpr const char *pyClassName = "IntegerType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    // Factories: each takes a bit width and an optional `context` (None is
-    // resolved through DefaultingPyMlirContext).
-    c.def_static(
-         "get_signless",
-         [](unsigned width, DefaultingPyMlirContext context) {
-           MlirType created = mlirIntegerTypeGet(context->get(), width);
-           return PyIntegerType(context->getRef(), created);
-         },
-         py::arg("width"), py::arg("context") = py::none(),
-         "Create a signless integer type")
-        .def_static(
-            "get_signed",
-            [](unsigned width, DefaultingPyMlirContext context) {
-              MlirType created =
-                  mlirIntegerTypeSignedGet(context->get(), width);
-              return PyIntegerType(context->getRef(), created);
-            },
-            py::arg("width"), py::arg("context") = py::none(),
-            "Create a signed integer type")
-        .def_static(
-            "get_unsigned",
-            [](unsigned width, DefaultingPyMlirContext context) {
-              MlirType created =
-                  mlirIntegerTypeUnsignedGet(context->get(), width);
-              return PyIntegerType(context->getRef(), created);
-            },
-            py::arg("width"), py::arg("context") = py::none(),
-            "Create an unsigned integer type");
-
-    // Read-only queries on an existing integer type.
-    c.def_property_readonly(
-         "width",
-         [](PyIntegerType &type) { return mlirIntegerTypeGetWidth(type); },
-         "Returns the width of the integer type")
-        .def_property_readonly(
-            "is_signless",
-            [](PyIntegerType &type) -> bool {
-              return mlirIntegerTypeIsSignless(type);
-            },
-            "Returns whether this is a signless integer")
-        .def_property_readonly(
-            "is_signed",
-            [](PyIntegerType &type) -> bool {
-              return mlirIntegerTypeIsSigned(type);
-            },
-            "Returns whether this is a signed integer")
-        .def_property_readonly(
-            "is_unsigned",
-            [](PyIntegerType &type) -> bool {
-              return mlirIntegerTypeIsUnsigned(type);
-            },
-            "Returns whether this is an unsigned integer");
-  }
-};
-
-/// Index Type subclass - IndexType. Exposes a static `get(context=None)`.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
-  static constexpr const char *pyClassName = "IndexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](DefaultingPyMlirContext context) {
-      MlirType indexType = mlirIndexTypeGet(context->get());
-      return PyIndexType(context->getRef(), indexType);
-    };
-    c.def_static("get", getFn, py::arg("context") = py::none(),
-                 "Create a index type.");
-  }
-};
-
-/// Floating Point Type subclass - BF16Type. Exposes a static
-/// `get(context=None)`.
-class PyBF16Type : public PyConcreteType<PyBF16Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
-  static constexpr const char *pyClassName = "BF16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](DefaultingPyMlirContext context) {
-      MlirType bf16Type = mlirBF16TypeGet(context->get());
-      return PyBF16Type(context->getRef(), bf16Type);
-    };
-    c.def_static("get", getFn, py::arg("context") = py::none(),
-                 "Create a bf16 type.");
-  }
-};
-
-/// Floating Point Type subclass - F16Type. Exposes a static
-/// `get(context=None)`.
-class PyF16Type : public PyConcreteType<PyF16Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
-  static constexpr const char *pyClassName = "F16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](DefaultingPyMlirContext context) {
-      MlirType f16Type = mlirF16TypeGet(context->get());
-      return PyF16Type(context->getRef(), f16Type);
-    };
-    c.def_static("get", getFn, py::arg("context") = py::none(),
-                 "Create a f16 type.");
-  }
-};
-
-/// Floating Point Type subclass - F32Type. Exposes a static
-/// `get(context=None)`.
-class PyF32Type : public PyConcreteType<PyF32Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
-  static constexpr const char *pyClassName = "F32Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](DefaultingPyMlirContext context) {
-      MlirType f32Type = mlirF32TypeGet(context->get());
-      return PyF32Type(context->getRef(), f32Type);
-    };
-    c.def_static("get", getFn, py::arg("context") = py::none(),
-                 "Create a f32 type.");
-  }
-};
-
-/// Floating Point Type subclass - F64Type. Exposes a static
-/// `get(context=None)`.
-class PyF64Type : public PyConcreteType<PyF64Type> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
-  static constexpr const char *pyClassName = "F64Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](DefaultingPyMlirContext context) {
-      MlirType f64Type = mlirF64TypeGet(context->get());
-      return PyF64Type(context->getRef(), f64Type);
-    };
-    c.def_static("get", getFn, py::arg("context") = py::none(),
-                 "Create a f64 type.");
-  }
-};
-
-/// None Type subclass - NoneType. Exposes a static `get(context=None)`.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
-  static constexpr const char *pyClassName = "NoneType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](DefaultingPyMlirContext context) {
-      MlirType noneType = mlirNoneTypeGet(context->get());
-      return PyNoneType(context->getRef(), noneType);
-    };
-    c.def_static("get", getFn, py::arg("context") = py::none(),
-                 "Create a none type.");
-  }
-};
-
-/// Complex Type subclass - ComplexType. `get(element_type)` only accepts
-/// integer or floating point element types; the element type is readable
-/// back through `element_type`.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
-  static constexpr const char *pyClassName = "ComplexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](PyType &elementType) {
-           // Reject anything that is not an integer or float scalar.
-           if (!mlirTypeIsAIntegerOrFloat(elementType)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point or integer type.");
-           }
-           MlirType complexType = mlirComplexTypeGet(elementType);
-           return PyComplexType(elementType.getContext(), complexType);
-         },
-         "Create a complex type")
-        .def_property_readonly(
-            "element_type",
-            [](PyComplexType &complexType) -> PyType {
-              MlirType element = mlirComplexTypeGetElementType(complexType);
-              return PyType(complexType.getContext(), element);
-            },
-            "Returns element type.");
-  }
-};
-
-/// Shaped Type subclass - ShapedType. Python base class for the vector,
-/// tensor and memref classes; exposes element type, rank and per-dimension
-/// shape queries.
-class PyShapedType : public PyConcreteType<PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
-  static constexpr const char *pyClassName = "ShapedType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly(
-        "element_type",
-        [](PyShapedType &self) {
-          MlirType t = mlirShapedTypeGetElementType(self);
-          return PyType(self.getContext(), t);
-        },
-        "Returns the element type of the shaped type.");
-    c.def_property_readonly(
-        "has_rank",
-        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
-        "Returns whether the given shaped type is ranked.");
-    c.def_property_readonly(
-        "rank",
-        [](PyShapedType &self) {
-          self.requireHasRank();
-          return mlirShapedTypeGetRank(self);
-        },
-        "Returns the rank of the given ranked shaped type.");
-    c.def_property_readonly(
-        "has_static_shape",
-        [](PyShapedType &self) -> bool {
-          return mlirShapedTypeHasStaticShape(self);
-        },
-        "Returns whether the given shaped type has a static shape.");
-    c.def(
-        "is_dynamic_dim",
-        [](PyShapedType &self, intptr_t dim) -> bool {
-          // Validate the Python-supplied index before using it in the C API.
-          self.requireValidDim(dim);
-          return mlirShapedTypeIsDynamicDim(self, dim);
-        },
-        "Returns whether the dim-th dimension of the given shaped type is "
-        "dynamic.");
-    c.def(
-        "get_dim_size",
-        [](PyShapedType &self, intptr_t dim) {
-          self.requireValidDim(dim);
-          return mlirShapedTypeGetDimSize(self, dim);
-        },
-        "Returns the dim-th dimension of the given ranked shaped type.");
-    c.def_static(
-        "is_dynamic_size",
-        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
-        "Returns whether the given dimension size indicates a dynamic "
-        "dimension.");
-    c.def(
-        "is_dynamic_stride_or_offset",
-        [](PyShapedType &self, int64_t val) -> bool {
-          self.requireHasRank();
-          return mlirShapedTypeIsDynamicStrideOrOffset(val);
-        },
-        "Returns whether the given value is used as a placeholder for dynamic "
-        "strides and offsets in shaped types.");
-  }
-
-private:
-  /// Raises a Python ValueError if this shaped type is unranked.
-  void requireHasRank() {
-    if (!mlirShapedTypeHasRank(*this)) {
-      throw SetPyError(
-          PyExc_ValueError,
-          "calling this method requires that the type has a rank.");
-    }
-  }
-
-  /// Raises a Python ValueError unless this type is ranked and `dim` lies in
-  /// [0, rank). Out-of-range (including negative) indices are rejected here
-  /// rather than being forwarded to mlirShapedTypeIsDynamicDim /
-  /// mlirShapedTypeGetDimSize, which are not shown to validate them.
-  void requireValidDim(intptr_t dim) {
-    requireHasRank();
-    int64_t rank = mlirShapedTypeGetRank(*this);
-    if (dim < 0 || dim >= rank) {
-      throw SetPyError(PyExc_ValueError,
-                       Twine("invalid dimension index ") + Twine(dim) +
-                           " for a shaped type of rank " + Twine(rank) + ".");
-    }
-  }
-};
-
-/// Vector Type subclass - VectorType. `get(shape, elementType, loc=None)`
-/// builds a vector type through the checked C API and raises ValueError when
-/// the C API returns a null type.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
-  static constexpr const char *pyClassName = "VectorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](std::vector<int64_t> shape, PyType &elementType,
-                    DefaultingPyLocation loc) {
-      MlirType vectorType = mlirVectorTypeGetChecked(loc, shape.size(),
-                                                     shape.data(), elementType);
-      // TODO: Rework error reporting once diagnostic engine is exposed
-      // in C API.
-      if (!mlirTypeIsNull(vectorType))
-        return PyVectorType(elementType.getContext(), vectorType);
-      throw SetPyError(
-          PyExc_ValueError,
-          Twine("invalid '") +
-              py::repr(py::cast(elementType)).cast<std::string>() +
-              "' and expected floating point or integer type.");
-    };
-    // NOTE(review): the keyword is "elementType" here, whereas the other
-    // shaped-type factories use "element_type"; renaming it would break
-    // existing keyword callers, so it is left unchanged.
-    c.def_static("get", getFn, py::arg("shape"), py::arg("elementType"),
-                 py::arg("loc") = py::none(), "Create a vector type");
-  }
-};
-
-/// Ranked Tensor Type subclass - RankedTensorType. `get(shape, element_type,
-/// loc=None)` uses the checked C API and raises ValueError on a null result.
-class PyRankedTensorType
-    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
-  static constexpr const char *pyClassName = "RankedTensorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](std::vector<int64_t> shape, PyType &elementType,
-                    DefaultingPyLocation loc) {
-      MlirType tensorType = mlirRankedTensorTypeGetChecked(
-          loc, shape.size(), shape.data(), elementType);
-      // TODO: Rework error reporting once diagnostic engine is exposed
-      // in C API.
-      if (!mlirTypeIsNull(tensorType))
-        return PyRankedTensorType(elementType.getContext(), tensorType);
-      throw SetPyError(
-          PyExc_ValueError,
-          Twine("invalid '") +
-              py::repr(py::cast(elementType)).cast<std::string>() +
-              "' and expected floating point, integer, vector or "
-              "complex "
-              "type.");
-    };
-    c.def_static("get", getFn, py::arg("shape"), py::arg("element_type"),
-                 py::arg("loc") = py::none(), "Create a ranked tensor type");
-  }
-};
-
-/// Unranked Tensor Type subclass - UnrankedTensorType. `get(element_type,
-/// loc=None)` uses the checked C API and raises ValueError on a null result.
-class PyUnrankedTensorType
-    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
-  static constexpr const char *pyClassName = "UnrankedTensorType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    auto getFn = [](PyType &elementType, DefaultingPyLocation loc) {
-      MlirType tensorType = mlirUnrankedTensorTypeGetChecked(loc, elementType);
-      // TODO: Rework error reporting once diagnostic engine is exposed
-      // in C API.
-      if (!mlirTypeIsNull(tensorType))
-        return PyUnrankedTensorType(elementType.getContext(), tensorType);
-      throw SetPyError(
-          PyExc_ValueError,
-          Twine("invalid '") +
-              py::repr(py::cast(elementType)).cast<std::string>() +
-              "' and expected floating point, integer, vector or "
-              "complex "
-              "type.");
-    };
-    c.def_static("get", getFn, py::arg("element_type"),
-                 py::arg("loc") = py::none(), "Create a unranked tensor type");
-  }
-};
-
-class PyMemRefLayoutMapList;
-
-/// Ranked MemRef Type subclass - MemRefType. Provides a checked factory and
-/// read-only access to the layout maps and the memory space.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
-  // This predicate drives both `MemRefType(t)` downcasts and
-  // `MemRefType.isinstance(t)`, so it must recognize memref types. (It was
-  // previously mlirTypeIsARankedTensor, which accepted tensors -- whose
-  // `layout`/`memory_space` would then hit memref-only C APIs -- and
-  // rejected real memrefs.)
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
-  static constexpr const char *pyClassName = "MemRefType";
-  using PyConcreteType::PyConcreteType;
-
-  /// Returns a sliceable view over this type's affine layout maps.
-  PyMemRefLayoutMapList getLayout();
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](std::vector<int64_t> shape, PyType &elementType,
-            std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           // Unwrap the Python layout maps into their C handles.
-           SmallVector<MlirAffineMap> maps;
-           maps.reserve(layout.size());
-           for (PyAffineMap &map : layout)
-             maps.push_back(map);
-
-           // memory_space=None is forwarded as a null attribute.
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
-
-           MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
-                                                 shape.data(), maps.size(),
-                                                 maps.data(), memSpaceAttr);
-           // TODO: Rework error reporting once diagnostic engine is exposed
-           // in C API.
-           if (mlirTypeIsNull(t)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point, integer, vector or "
-                     "complex "
-                     "type.");
-           }
-           return PyMemRefType(elementType.getContext(), t);
-         },
-         py::arg("shape"), py::arg("element_type"),
-         py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
-         py::arg("loc") = py::none(), "Create a memref type")
-        .def_property_readonly("layout", &PyMemRefType::getLayout,
-                               "The list of layout maps of the MemRef type.")
-        .def_property_readonly(
-            "memory_space",
-            [](PyMemRefType &self) -> PyAttribute {
-              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
-              return PyAttribute(self.getContext(), a);
-            },
-            "Returns the memory space of the given MemRef type.");
-  }
-};
-
-/// Sliceable Python sequence of the affine layout maps of a memref type.
-/// Elements are fetched by index from the C API on demand; the type and its
-/// maps belong to the context, so no extra lifetime extension is needed.
-class PyMemRefLayoutMapList
-    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
-public:
-  static constexpr const char *pyClassName = "MemRefLayoutMapList";
-
-  /// A `length` of -1 means "all maps of `type`".
-  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
-                        intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
-                  step),
-        memRefType(type) {}
-
-  /// Total number of layout maps in the underlying memref type.
-  intptr_t getNumElements() {
-    return mlirMemRefTypeGetNumAffineMaps(memRefType);
-  }
-
-  /// Wraps the `index`-th layout map.
-  PyAffineMap getElement(intptr_t index) {
-    MlirAffineMap layoutMap = mlirMemRefTypeGetAffineMap(memRefType, index);
-    return PyAffineMap(memRefType.getContext(), layoutMap);
-  }
-
-  /// Returns a sub-view over the same memref type.
-  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
-                              intptr_t step) {
-    return PyMemRefLayoutMapList(memRefType, startIndex, length, step);
-  }
-
-private:
-  PyMemRefType memRefType;
-};
-
-/// Returns a view over every layout map of this memref type.
-PyMemRefLayoutMapList PyMemRefType::getLayout() {
-  PyMemRefLayoutMapList layoutMaps(*this);
-  return layoutMaps;
-}
-
-/// Unranked MemRef Type subclass - UnrankedMemRefType. Provides a checked
-/// factory and read-only access to the memory space.
-class PyUnrankedMemRefType
-    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
-  static constexpr const char *pyClassName = "UnrankedMemRefType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](PyType &elementType, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           // memory_space=None is forwarded as a null attribute.
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
-
-           MlirType t =
-               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
-           // TODO: Rework error reporting once diagnostic engine is exposed
-           // in C API.
-           if (mlirTypeIsNull(t)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point, integer, vector or "
-                     "complex "
-                     "type.");
-           }
-           return PyUnrankedMemRefType(elementType.getContext(), t);
-         },
-         // memory_space defaults to None, matching MemRefType.get; passing it
-         // explicitly keeps working as before.
-         py::arg("element_type"), py::arg("memory_space") = py::none(),
-         py::arg("loc") = py::none(), "Create a unranked memref type")
-        .def_property_readonly(
-            "memory_space",
-            [](PyUnrankedMemRefType &self) -> PyAttribute {
-              // Use the unranked accessor: mlirMemRefTypeGetMemorySpace is
-              // the ranked-memref query used by MemRefType and is not meant
-              // for unranked memref types.
-              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
-              return PyAttribute(self.getContext(), a);
-            },
-            "Returns the memory space of the given Unranked MemRef type.");
-  }
-};
-
-/// Tuple Type subclass - TupleType. Built from a Python list of types;
-/// exposes positional access and the element count.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
-  static constexpr const char *pyClassName = "TupleType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get_tuple",
-         [](py::list elementList, DefaultingPyMlirContext context) {
-           intptr_t count = py::len(elementList);
-           // Unwrap every Python element into its MlirType handle.
-           SmallVector<MlirType, 4> members;
-           for (auto item : elementList)
-             members.push_back(item.cast<PyType>());
-           MlirType tuple =
-               mlirTupleTypeGet(context->get(), count, members.data());
-           return PyTupleType(context->getRef(), tuple);
-         },
-         py::arg("elements"), py::arg("context") = py::none(),
-         "Create a tuple type")
-        .def(
-            "get_type",
-            [](PyTupleType &tuple, intptr_t pos) -> PyType {
-              MlirType member = mlirTupleTypeGetType(tuple, pos);
-              return PyType(tuple.getContext(), member);
-            },
-            "Returns the pos-th type in the tuple type.")
-        .def_property_readonly(
-            "num_types",
-            [](PyTupleType &tuple) -> intptr_t {
-              return mlirTupleTypeGetNumTypes(tuple);
-            },
-            "Returns the number of types contained in a tuple.");
-  }
-};
-
-/// Function Type subclass - FunctionType. Built from input and result type
-/// lists; both lists can be read back as Python lists of Type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
-  static constexpr const char *pyClassName = "FunctionType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](std::vector<PyType> inputs, std::vector<PyType> results,
-            DefaultingPyMlirContext context) {
-           // PyType converts to MlirType, so the ranges copy element-wise
-           // into raw C handles.
-           SmallVector<MlirType, 4> rawInputs(inputs.begin(), inputs.end());
-           SmallVector<MlirType, 4> rawResults(results.begin(), results.end());
-           MlirType fnType = mlirFunctionTypeGet(
-               context->get(), rawInputs.size(), rawInputs.data(),
-               rawResults.size(), rawResults.data());
-           return PyFunctionType(context->getRef(), fnType);
-         },
-         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
-         "Gets a FunctionType from a list of input and result types")
-        .def_property_readonly(
-            "inputs",
-            [](PyFunctionType &fn) {
-              auto contextRef = fn.getContext();
-              py::list inputTypes;
-              intptr_t count = mlirFunctionTypeGetNumInputs(fn);
-              for (intptr_t idx = 0; idx < count; ++idx)
-                inputTypes.append(
-                    PyType(contextRef, mlirFunctionTypeGetInput(fn, idx)));
-              return inputTypes;
-            },
-            "Returns the list of input types in the FunctionType.")
-        .def_property_readonly(
-            "results",
-            [](PyFunctionType &fn) {
-              auto contextRef = fn.getContext();
-              py::list resultTypes;
-              intptr_t count = mlirFunctionTypeGetNumResults(fn);
-              for (intptr_t idx = 0; idx < count; ++idx)
-                resultTypes.append(
-                    PyType(contextRef, mlirFunctionTypeGetResult(fn, idx)));
-              return resultTypes;
-            },
-            "Returns the list of result types in the FunctionType.");
-  }
-};
-
-} // namespace
-
-//------------------------------------------------------------------------------
-// PyAffineExpr and subclasses.
-//------------------------------------------------------------------------------
-
-namespace {
-/// CRTP base for Python classes wrapping one kind of affine expression that
-/// can be downcast from a generic PyAffineExpr. BaseTy defaults to
-/// PyAffineExpr; pass an intermediate class (e.g. PyAffineBinaryExpr) to
-/// model a deeper hierarchy.
-///
-/// DerivedTy is expected to provide the statics
-///   IsAFunctionTy isaFunction  -- predicate recognizing the expression kind
-///   const char *pyClassName    -- name of the Python class
-/// and to shadow bindDerived() with its Python members.
-template <typename DerivedTy, typename BaseTy = PyAffineExpr>
-class PyConcreteAffineExpr : public BaseTy {
-public:
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirAffineExpr);
-
-  PyConcreteAffineExpr() = default;
-  PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
-      : BaseTy(std::move(contextRef), affineExpr) {}
-  /// Downcasting constructor: wraps `orig` after verifying its kind.
-  PyConcreteAffineExpr(PyAffineExpr &orig)
-      : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
-
-  /// Returns `orig` as an MlirAffineExpr if DerivedTy::isaFunction accepts
-  /// it; otherwise raises a Python ValueError naming the target class.
-  static MlirAffineExpr castFrom(PyAffineExpr &orig) {
-    if (DerivedTy::isaFunction(orig))
-      return orig;
-    std::string origRepr = py::repr(py::cast(orig)).cast<std::string>();
-    throw SetPyError(PyExc_ValueError,
-                     Twine("Cannot cast affine expression to ") +
-                         DerivedTy::pyClassName + " (from " + origRepr + ")");
-  }
-
-  /// Creates the Python class in `m` with a downcasting constructor, then
-  /// lets DerivedTy::bindDerived add its members.
-  static void bind(py::module &m) {
-    ClassTy cls(m, DerivedTy::pyClassName);
-    cls.def(py::init<PyAffineExpr &>());
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Hook for DerivedTy to add Python members; the default adds nothing.
-  static void bindDerived(ClassTy &cls) {}
-};
-
-/// Affine constant expression - AffineConstantExpr.
-class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
-  static constexpr const char *pyClassName = "AffineConstantExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Creates the constant `value` (widened to int64_t) in `context`.
-  static PyAffineConstantExpr get(intptr_t value,
-                                  DefaultingPyMlirContext context) {
-    auto wideValue = static_cast<int64_t>(value);
-    MlirAffineExpr constant =
-        mlirAffineConstantExprGet(context->get(), wideValue);
-    return PyAffineConstantExpr(context->getRef(), constant);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
-                 py::arg("context") = py::none())
-        .def_property_readonly("value", [](PyAffineConstantExpr &expr) {
-          return mlirAffineConstantExprGetValue(expr);
-        });
-  }
-};
-
-/// Affine dimension expression `d<position>` - AffineDimExpr.
-class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
-  static constexpr const char *pyClassName = "AffineDimExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Creates the dimension expression at `pos` in `context`.
-  static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
-    MlirAffineExpr dimExpr = mlirAffineDimExprGet(context->get(), pos);
-    return PyAffineDimExpr(context->getRef(), dimExpr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
-                 py::arg("context") = py::none())
-        .def_property_readonly("position", [](PyAffineDimExpr &expr) {
-          return mlirAffineDimExprGetPosition(expr);
-        });
-  }
-};
-
-/// Affine symbol expression `s<position>` - AffineSymbolExpr.
-class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
-  static constexpr const char *pyClassName = "AffineSymbolExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Creates the symbol expression at `pos` in `context`.
-  static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
-    MlirAffineExpr symbolExpr = mlirAffineSymbolExprGet(context->get(), pos);
-    return PyAffineSymbolExpr(context->getRef(), symbolExpr);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
-                 py::arg("context") = py::none())
-        .def_property_readonly("position", [](PyAffineSymbolExpr &expr) {
-          return mlirAffineSymbolExprGetPosition(expr);
-        });
-  }
-};
-
-/// Common Python base of the binary affine expressions; exposes the two
-/// operands as read-only `lhs` / `rhs` properties.
-class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
-  static constexpr const char *pyClassName = "AffineBinaryExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Left-hand operand, wrapped with this expression's context.
-  PyAffineExpr lhs() {
-    MlirAffineExpr operand = mlirAffineBinaryOpExprGetLHS(get());
-    return PyAffineExpr(getContext(), operand);
-  }
-
-  /// Right-hand operand, wrapped with this expression's context.
-  PyAffineExpr rhs() {
-    MlirAffineExpr operand = mlirAffineBinaryOpExprGetRHS(get());
-    return PyAffineExpr(getContext(), operand);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs)
-        .def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
-  }
-};
-
-/// Affine sum `lhs + rhs` - AffineAddExpr.
-class PyAffineAddExpr
-    : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
-  static constexpr const char *pyClassName = "AffineAddExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Builds `lhs + rhs`; the result is wrapped with `lhs`'s context.
-  static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr sum = mlirAffineAddExprGet(lhs, rhs);
-    return PyAffineAddExpr(lhs.getContext(), sum);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    auto factory = &PyAffineAddExpr::get;
-    c.def_static("get", factory);
-  }
-};
-
-/// Affine product `lhs * rhs` - AffineMulExpr.
-class PyAffineMulExpr
-    : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
-  static constexpr const char *pyClassName = "AffineMulExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Builds `lhs * rhs`; the result is wrapped with `lhs`'s context.
-  static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr product = mlirAffineMulExprGet(lhs, rhs);
-    return PyAffineMulExpr(lhs.getContext(), product);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    auto factory = &PyAffineMulExpr::get;
-    c.def_static("get", factory);
-  }
-};
-
-/// Affine remainder `lhs mod rhs` - AffineModExpr.
-class PyAffineModExpr
-    : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
-  static constexpr const char *pyClassName = "AffineModExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Builds `lhs mod rhs`; the result is wrapped with `lhs`'s context.
-  static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr remainder = mlirAffineModExprGet(lhs, rhs);
-    return PyAffineModExpr(lhs.getContext(), remainder);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    auto factory = &PyAffineModExpr::get;
-    c.def_static("get", factory);
-  }
-};
-
-/// Affine floor division `lhs floordiv rhs` - AffineFloorDivExpr.
-class PyAffineFloorDivExpr
-    : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
-  static constexpr const char *pyClassName = "AffineFloorDivExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Builds `lhs floordiv rhs`; the result is wrapped with `lhs`'s context.
-  static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr quotient = mlirAffineFloorDivExprGet(lhs, rhs);
-    return PyAffineFloorDivExpr(lhs.getContext(), quotient);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    auto factory = &PyAffineFloorDivExpr::get;
-    c.def_static("get", factory);
-  }
-};
-
-/// Affine ceiling division `lhs ceildiv rhs` - AffineCeilDivExpr.
-class PyAffineCeilDivExpr
-    : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
-  static constexpr const char *pyClassName = "AffineCeilDivExpr";
-  using PyConcreteAffineExpr::PyConcreteAffineExpr;
-
-  /// Builds `lhs ceildiv rhs`; the result is wrapped with `lhs`'s context.
-  static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
-    MlirAffineExpr quotient = mlirAffineCeilDivExprGet(lhs, rhs);
-    return PyAffineCeilDivExpr(lhs.getContext(), quotient);
-  }
-
-  static void bindDerived(ClassTy &c) {
-    auto factory = &PyAffineCeilDivExpr::get;
-    c.def_static("get", factory);
-  }
-};
-} // namespace
-
-/// Equality delegates to the C API comparison of the wrapped expressions.
-bool PyAffineExpr::operator==(const PyAffineExpr &other) {
-  const MlirAffineExpr &mine = this->affineExpr;
-  return mlirAffineExprEqual(mine, other.affineExpr);
-}
-
-/// Wraps this expression in a new Python capsule; the returned object takes
-/// over the reference produced by mlirPythonAffineExprToCapsule.
-py::object PyAffineExpr::getCapsule() {
-  PyObject *capsule = mlirPythonAffineExprToCapsule(*this);
-  return py::reinterpret_steal<py::object>(capsule);
-}
-
-/// Rebuilds a PyAffineExpr from a capsule. A null result from the C API is
-/// reported by rethrowing the pending Python error.
-PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
-  MlirAffineExpr unwrapped = mlirPythonCapsuleToAffineExpr(capsule.ptr());
-  if (mlirAffineExprIsNull(unwrapped))
-    throw py::error_already_set();
-  MlirContext owningContext = mlirAffineExprGetContext(unwrapped);
-  return PyAffineExpr(PyMlirContext::forContext(owningContext), unwrapped);
-}
-
-//------------------------------------------------------------------------------
-// PyAffineMap and utilities.
-//------------------------------------------------------------------------------
-
-namespace {
-/// Sliceable Python sequence of the result expressions of an affine map.
-/// Elements are fetched by position from the C API on demand; the map and
-/// its expressions belong to the context, so no lifetime extension is
-/// needed.
-class PyAffineMapExprList
-    : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
-public:
-  static constexpr const char *pyClassName = "AffineExprList";
-
-  /// A `length` of -1 means "all results of `map`".
-  PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
-                      intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirAffineMapGetNumResults(map) : length,
-                  step),
-        parentMap(map) {}
-
-  /// Total number of results of the underlying map.
-  intptr_t getNumElements() { return mlirAffineMapGetNumResults(parentMap); }
-
-  /// Wraps the `pos`-th result expression.
-  PyAffineExpr getElement(intptr_t pos) {
-    MlirAffineExpr result = mlirAffineMapGetResult(parentMap, pos);
-    return PyAffineExpr(parentMap.getContext(), result);
-  }
-
-  /// Returns a sub-view over the same affine map.
-  PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
-                            intptr_t step) {
-    return PyAffineMapExprList(parentMap, startIndex, length, step);
-  }
-
-private:
-  PyAffineMap parentMap;
-};
-} // end namespace
-
-/// Equality delegates to the C API comparison of the wrapped maps.
-bool PyAffineMap::operator==(const PyAffineMap &other) {
-  const MlirAffineMap &mine = this->affineMap;
-  return mlirAffineMapEqual(mine, other.affineMap);
-}
-
-/// Wraps this map in a new Python capsule; the returned object takes over
-/// the reference produced by mlirPythonAffineMapToCapsule.
-py::object PyAffineMap::getCapsule() {
-  PyObject *capsule = mlirPythonAffineMapToCapsule(*this);
-  return py::reinterpret_steal<py::object>(capsule);
-}
-
-/// Rebuilds a PyAffineMap from a capsule. A null result from the C API is
-/// reported by rethrowing the pending Python error.
-PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
-  MlirAffineMap unwrapped = mlirPythonCapsuleToAffineMap(capsule.ptr());
-  if (mlirAffineMapIsNull(unwrapped))
-    throw py::error_already_set();
-  MlirContext owningContext = mlirAffineMapGetContext(unwrapped);
-  return PyAffineMap(PyMlirContext::forContext(owningContext), unwrapped);
-}
-
-//------------------------------------------------------------------------------
-// PyIntegerSet and utilities.
-//------------------------------------------------------------------------------
-
-/// One constraint of an integer set, addressed by its position. Exposes the
-/// constraint's affine expression and whether it is an equality.
-class PyIntegerSetConstraint {
-public:
-  PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos)
-      : owner(set), index(pos) {}
-
-  /// Affine expression of this constraint, wrapped with the set's context.
-  PyAffineExpr getExpr() {
-    MlirAffineExpr constraint = mlirIntegerSetGetConstraint(owner, index);
-    return PyAffineExpr(owner.getContext(), constraint);
-  }
-
-  /// True if this constraint is an equality.
-  bool isEq() { return mlirIntegerSetIsConstraintEq(owner, index); }
-
-  /// Registers the read-only Python class `IntegerSetConstraint` in `m`.
-  static void bind(py::module &m) {
-    auto cls = py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint");
-    cls.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr);
-    cls.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
-  }
-
-private:
-  PyIntegerSet owner;
-  intptr_t index;
-};
-
-/// Sliceable Python sequence of the constraints of an integer set; elements
-/// are created on demand from their position.
-class PyIntegerSetConstraintList
-    : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
-public:
-  static constexpr const char *pyClassName = "IntegerSetConstraintList";
-
-  /// A `length` of -1 means "all constraints of `set`".
-  PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
-                             intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
-                  step),
-        owner(set) {}
-
-  /// Total number of constraints in the underlying set.
-  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(owner); }
-
-  /// Wraps the `pos`-th constraint.
-  PyIntegerSetConstraint getElement(intptr_t pos) {
-    return PyIntegerSetConstraint(owner, pos);
-  }
-
-  /// Returns a sub-view over the same integer set.
-  PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
-                                   intptr_t step) {
-    return PyIntegerSetConstraintList(owner, startIndex, length, step);
-  }
-
-private:
-  PyIntegerSet owner;
-};
-
-/// Equality delegates to the C API comparison of the wrapped sets.
-bool PyIntegerSet::operator==(const PyIntegerSet &other) {
-  const MlirIntegerSet &mine = this->integerSet;
-  return mlirIntegerSetEqual(mine, other.integerSet);
-}
-
-/// Wraps this integer set in a new Python capsule; the returned object takes
-/// over the reference produced by mlirPythonIntegerSetToCapsule.
-py::object PyIntegerSet::getCapsule() {
-  PyObject *capsule = mlirPythonIntegerSetToCapsule(*this);
-  return py::reinterpret_steal<py::object>(capsule);
-}
-
-/// Rebuilds a PyIntegerSet from a capsule. A null result from the C API is
-/// reported by rethrowing the pending Python error.
-PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
-  MlirIntegerSet unwrapped = mlirPythonCapsuleToIntegerSet(capsule.ptr());
-  if (mlirIntegerSetIsNull(unwrapped))
-    throw py::error_already_set();
-  MlirContext owningContext = mlirIntegerSetGetContext(unwrapped);
-  return PyIntegerSet(PyMlirContext::forContext(owningContext), unwrapped);
-}
-
-/// Appends every element of `list` to `result`: each element is cast to the
-/// Python wrapper type PyTy and then converted to CType. Space for len(list)
-/// elements is reserved up front. A failed cast is rethrown as
-/// py::cast_error whose message names `action`, the operation the caller
-/// was performing.
-template <typename PyTy, typename CType>
-static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
-                           StringRef action) {
-  result.reserve(py::len(list));
-  for (py::handle element : list) {
-    try {
-      result.push_back(element.cast<PyTy>());
-    } catch (const py::cast_error &castErr) {
-      throw py::cast_error((llvm::Twine("Invalid expression when ") + action +
-                            " (" + castErr.what() + ")")
-                               .str());
-    } catch (const py::reference_cast_error &refErr) {
-      throw py::cast_error((llvm::Twine("Invalid expression (None?) when ") +
-                            action + " (" + refErr.what() + ")")
-                               .str());
-    }
-  }
-}
-
-//------------------------------------------------------------------------------
-// Populates the pybind11 IR submodule.
-//------------------------------------------------------------------------------
-
-void mlir::python::populateIRSubmodule(py::module &m) {
- //----------------------------------------------------------------------------
- // Mapping of MlirContext
- //----------------------------------------------------------------------------
- py::class_<PyMlirContext>(m, "Context")
- .def(py::init<>(&PyMlirContext::createNewContextForInit))
- .def_static("_get_live_count", &PyMlirContext::getLiveCount)
- .def("_get_context_again",
- [](PyMlirContext &self) {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- })
- .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyMlirContext::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
- .def("__enter__", &PyMlirContext::contextEnter)
- .def("__exit__", &PyMlirContext::contextExit)
- .def_property_readonly_static(
- "current",
- [](py::object & /*class*/) {
- auto *context = PyThreadContextEntry::getDefaultContext();
- if (!context)
- throw SetPyError(PyExc_ValueError, "No current Context");
- return context;
- },
- "Gets the Context bound to the current thread or raises ValueError")
- .def_property_readonly(
- "dialects",
- [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Gets a container for accessing dialects by name")
- .def_property_readonly(
- "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
- "Alias for 'dialect'")
- .def(
- "get_dialect_descriptor",
- [=](PyMlirContext &self, std::string &name) {
- MlirDialect dialect = mlirContextGetOrLoadDialect(
- self.get(), {name.data(), name.size()});
- if (mlirDialectIsNull(dialect)) {
- throw SetPyError(PyExc_ValueError,
- Twine("Dialect '") + name + "' not found");
- }
- return PyDialectDescriptor(self.getRef(), dialect);
- },
- "Gets or loads a dialect by name, returning its descriptor object")
- .def_property(
- "allow_unregistered_dialects",
- [](PyMlirContext &self) -> bool {
- return mlirContextGetAllowUnregisteredDialects(self.get());
- },
- [](PyMlirContext &self, bool value) {
- mlirContextSetAllowUnregisteredDialects(self.get(), value);
- });
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialectDescriptor
- //----------------------------------------------------------------------------
- py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
- .def_property_readonly("namespace",
- [](PyDialectDescriptor &self) {
- MlirStringRef ns =
- mlirDialectGetNamespace(self.get());
- return py::str(ns.data, ns.length);
- })
- .def("__repr__", [](PyDialectDescriptor &self) {
- MlirStringRef ns = mlirDialectGetNamespace(self.get());
- std::string repr("<DialectDescriptor ");
- repr.append(ns.data, ns.length);
- repr.append(">");
- return repr;
- });
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialects
- //----------------------------------------------------------------------------
- py::class_<PyDialects>(m, "Dialects")
- .def("__getitem__",
- [=](PyDialects &self, std::string keyName) {
- MlirDialect dialect =
- self.getDialectForKey(keyName, /*attrError=*/false);
- py::object descriptor =
- py::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(keyName, std::move(descriptor));
- })
- .def("__getattr__", [=](PyDialects &self, std::string attrName) {
- MlirDialect dialect =
- self.getDialectForKey(attrName, /*attrError=*/true);
- py::object descriptor =
- py::cast(PyDialectDescriptor{self.getContext(), dialect});
- return createCustomDialectWrapper(attrName, std::move(descriptor));
- });
-
- //----------------------------------------------------------------------------
- // Mapping of PyDialect
- //----------------------------------------------------------------------------
- py::class_<PyDialect>(m, "Dialect")
- .def(py::init<py::object>(), "descriptor")
- .def_property_readonly(
- "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
- .def("__repr__", [](py::object self) {
- auto clazz = self.attr("__class__");
- return py::str("<Dialect ") +
- self.attr("descriptor").attr("namespace") + py::str(" (class ") +
- clazz.attr("__module__") + py::str(".") +
- clazz.attr("__name__") + py::str(")>");
- });
-
- //----------------------------------------------------------------------------
- // Mapping of Location
- //----------------------------------------------------------------------------
- py::class_<PyLocation>(m, "Location")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
- .def("__enter__", &PyLocation::contextEnter)
- .def("__exit__", &PyLocation::contextExit)
- .def("__eq__",
- [](PyLocation &self, PyLocation &other) -> bool {
- return mlirLocationEqual(self, other);
- })
- .def("__eq__", [](PyLocation &self, py::object other) { return false; })
- .def_property_readonly_static(
- "current",
- [](py::object & /*class*/) {
- auto *loc = PyThreadContextEntry::getDefaultLocation();
- if (!loc)
- throw SetPyError(PyExc_ValueError, "No current Location");
- return loc;
- },
- "Gets the Location bound to the current thread or raises ValueError")
- .def_static(
- "unknown",
- [](DefaultingPyMlirContext context) {
- return PyLocation(context->getRef(),
- mlirLocationUnknownGet(context->get()));
- },
- py::arg("context") = py::none(),
- "Gets a Location representing an unknown location")
- .def_static(
- "file",
- [](std::string filename, int line, int col,
- DefaultingPyMlirContext context) {
- return PyLocation(
- context->getRef(),
- mlirLocationFileLineColGet(
- context->get(), toMlirStringRef(filename), line, col));
- },
- py::arg("filename"), py::arg("line"), py::arg("col"),
- py::arg("context") = py::none(), kContextGetFileLocationDocstring)
- .def_property_readonly(
- "context",
- [](PyLocation &self) { return self.getContext().getObject(); },
- "Context that owns the Location")
- .def("__repr__", [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- });
+ //----------------------------------------------------------------------------
+ // Mapping of Location
+ //----------------------------------------------------------------------------
+ py::class_<PyLocation>(m, "Location")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
+ .def("__enter__", &PyLocation::contextEnter)
+ .def("__exit__", &PyLocation::contextExit)
+ .def("__eq__",
+ [](PyLocation &self, PyLocation &other) -> bool {
+ return mlirLocationEqual(self, other);
+ })
+ .def("__eq__", [](PyLocation &self, py::object other) { return false; })
+ .def_property_readonly_static(
+ "current",
+ [](py::object & /*class*/) {
+ auto *loc = PyThreadContextEntry::getDefaultLocation();
+ if (!loc)
+ throw SetPyError(PyExc_ValueError, "No current Location");
+ return loc;
+ },
+ "Gets the Location bound to the current thread or raises ValueError")
+ .def_static(
+ "unknown",
+ [](DefaultingPyMlirContext context) {
+ return PyLocation(context->getRef(),
+ mlirLocationUnknownGet(context->get()));
+ },
+ py::arg("context") = py::none(),
+ "Gets a Location representing an unknown location")
+ .def_static(
+ "file",
+ [](std::string filename, int line, int col,
+ DefaultingPyMlirContext context) {
+ return PyLocation(
+ context->getRef(),
+ mlirLocationFileLineColGet(
+ context->get(), toMlirStringRef(filename), line, col));
+ },
+ py::arg("filename"), py::arg("line"), py::arg("col"),
+ py::arg("context") = py::none(), kContextGetFileLocationDocstring)
+ .def_property_readonly(
+ "context",
+ [](PyLocation &self) { return self.getContext().getObject(); },
+ "Context that owns the Location")
+ .def("__repr__", [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ });
//----------------------------------------------------------------------------
// Mapping of Module
@@ -4022,22 +2259,6 @@ void mlir::python::populateIRSubmodule(py::module &m) {
py::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");
- // Builtin attribute bindings.
- PyAffineMapAttribute::bind(m);
- PyArrayAttribute::bind(m);
- PyArrayAttribute::PyArrayAttributeIterator::bind(m);
- PyBoolAttribute::bind(m);
- PyDenseElementsAttribute::bind(m);
- PyDenseFPElementsAttribute::bind(m);
- PyDenseIntElementsAttribute::bind(m);
- PyDictAttribute::bind(m);
- PyFlatSymbolRefAttribute::bind(m);
- PyFloatAttribute::bind(m);
- PyIntegerAttribute::bind(m);
- PyStringAttribute::bind(m);
- PyTypeAttribute::bind(m);
- PyUnitAttribute::bind(m);
-
//----------------------------------------------------------------------------
// Mapping of PyType.
//----------------------------------------------------------------------------
@@ -4088,25 +2309,6 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return printAccum.join();
});
- // Builtin type bindings.
- PyIntegerType::bind(m);
- PyIndexType::bind(m);
- PyBF16Type::bind(m);
- PyF16Type::bind(m);
- PyF32Type::bind(m);
- PyF64Type::bind(m);
- PyNoneType::bind(m);
- PyComplexType::bind(m);
- PyShapedType::bind(m);
- PyVectorType::bind(m);
- PyRankedTensorType::bind(m);
- PyUnrankedTensorType::bind(m);
- PyMemRefType::bind(m);
- PyMemRefLayoutMapList::bind(m);
- PyUnrankedMemRefType::bind(m);
- PyTupleType::bind(m);
- PyFunctionType::bind(m);
-
//----------------------------------------------------------------------------
// Mapping of Value.
//----------------------------------------------------------------------------
@@ -4152,359 +2354,4 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
PyRegionList::bind(m);
-
- //----------------------------------------------------------------------------
- // Mapping of PyAffineExpr and derived classes.
- //----------------------------------------------------------------------------
- py::class_<PyAffineExpr>(m, "AffineExpr")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyAffineExpr::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
- .def("__add__",
- [](PyAffineExpr &self, PyAffineExpr &other) {
- return PyAffineAddExpr::get(self, other);
- })
- .def("__mul__",
- [](PyAffineExpr &self, PyAffineExpr &other) {
- return PyAffineMulExpr::get(self, other);
- })
- .def("__mod__",
- [](PyAffineExpr &self, PyAffineExpr &other) {
- return PyAffineModExpr::get(self, other);
- })
- .def("__sub__",
- [](PyAffineExpr &self, PyAffineExpr &other) {
- auto negOne =
- PyAffineConstantExpr::get(-1, *self.getContext().get());
- return PyAffineAddExpr::get(self,
- PyAffineMulExpr::get(negOne, other));
- })
- .def("__eq__", [](PyAffineExpr &self,
- PyAffineExpr &other) { return self == other; })
- .def("__eq__",
- [](PyAffineExpr &self, py::object &other) { return false; })
- .def("__str__",
- [](PyAffineExpr &self) {
- PyPrintAccumulator printAccum;
- mlirAffineExprPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- })
- .def("__repr__",
- [](PyAffineExpr &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("AffineExpr(");
- mlirAffineExprPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- })
- .def_property_readonly(
- "context",
- [](PyAffineExpr &self) { return self.getContext().getObject(); })
- .def_static(
- "get_add", &PyAffineAddExpr::get,
- "Gets an affine expression containing a sum of two expressions.")
- .def_static(
- "get_mul", &PyAffineMulExpr::get,
- "Gets an affine expression containing a product of two expressions.")
- .def_static("get_mod", &PyAffineModExpr::get,
- "Gets an affine expression containing the modulo of dividing "
- "one expression by another.")
- .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
- "Gets an affine expression containing the rounded-down "
- "result of dividing one expression by another.")
- .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
- "Gets an affine expression containing the rounded-up result "
- "of dividing one expression by another.")
- .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
- py::arg("context") = py::none(),
- "Gets a constant affine expression with the given value.")
- .def_static(
- "get_dim", &PyAffineDimExpr::get, py::arg("position"),
- py::arg("context") = py::none(),
- "Gets an affine expression of a dimension at the given position.")
- .def_static(
- "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
- py::arg("context") = py::none(),
- "Gets an affine expression of a symbol at the given position.")
- .def(
- "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
- kDumpDocstring);
- PyAffineConstantExpr::bind(m);
- PyAffineDimExpr::bind(m);
- PyAffineSymbolExpr::bind(m);
- PyAffineBinaryExpr::bind(m);
- PyAffineAddExpr::bind(m);
- PyAffineMulExpr::bind(m);
- PyAffineModExpr::bind(m);
- PyAffineFloorDivExpr::bind(m);
- PyAffineCeilDivExpr::bind(m);
-
- //----------------------------------------------------------------------------
- // Mapping of PyAffineMap.
- //----------------------------------------------------------------------------
- py::class_<PyAffineMap>(m, "AffineMap")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyAffineMap::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
- .def("__eq__",
- [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
- .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
- .def("__str__",
- [](PyAffineMap &self) {
- PyPrintAccumulator printAccum;
- mlirAffineMapPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- })
- .def("__repr__",
- [](PyAffineMap &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("AffineMap(");
- mlirAffineMapPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- })
- .def_property_readonly(
- "context",
- [](PyAffineMap &self) { return self.getContext().getObject(); },
- "Context that owns the Affine Map")
- .def(
- "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
- kDumpDocstring)
- .def_static(
- "get",
- [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
- DefaultingPyMlirContext context) {
- SmallVector<MlirAffineExpr> affineExprs;
- pyListToVector<PyAffineExpr, MlirAffineExpr>(
- exprs, affineExprs, "attempting to create an AffineMap");
- MlirAffineMap map =
- mlirAffineMapGet(context->get(), dimCount, symbolCount,
- affineExprs.size(), affineExprs.data());
- return PyAffineMap(context->getRef(), map);
- },
- py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
- py::arg("context") = py::none(),
- "Gets a map with the given expressions as results.")
- .def_static(
- "get_constant",
- [](intptr_t value, DefaultingPyMlirContext context) {
- MlirAffineMap affineMap =
- mlirAffineMapConstantGet(context->get(), value);
- return PyAffineMap(context->getRef(), affineMap);
- },
- py::arg("value"), py::arg("context") = py::none(),
- "Gets an affine map with a single constant result")
- .def_static(
- "get_empty",
- [](DefaultingPyMlirContext context) {
- MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
- return PyAffineMap(context->getRef(), affineMap);
- },
- py::arg("context") = py::none(), "Gets an empty affine map.")
- .def_static(
- "get_identity",
- [](intptr_t nDims, DefaultingPyMlirContext context) {
- MlirAffineMap affineMap =
- mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
- return PyAffineMap(context->getRef(), affineMap);
- },
- py::arg("n_dims"), py::arg("context") = py::none(),
- "Gets an identity map with the given number of dimensions.")
- .def_static(
- "get_minor_identity",
- [](intptr_t nDims, intptr_t nResults,
- DefaultingPyMlirContext context) {
- MlirAffineMap affineMap =
- mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
- return PyAffineMap(context->getRef(), affineMap);
- },
- py::arg("n_dims"), py::arg("n_results"),
- py::arg("context") = py::none(),
- "Gets a minor identity map with the given number of dimensions and "
- "results.")
- .def_static(
- "get_permutation",
- [](std::vector<unsigned> permutation,
- DefaultingPyMlirContext context) {
- if (!isPermutation(permutation))
- throw py::cast_error("Invalid permutation when attempting to "
- "create an AffineMap");
- MlirAffineMap affineMap = mlirAffineMapPermutationGet(
- context->get(), permutation.size(), permutation.data());
- return PyAffineMap(context->getRef(), affineMap);
- },
- py::arg("permutation"), py::arg("context") = py::none(),
- "Gets an affine map that permutes its inputs.")
- .def("get_submap",
- [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
- intptr_t numResults = mlirAffineMapGetNumResults(self);
- for (intptr_t pos : resultPos) {
- if (pos < 0 || pos >= numResults)
- throw py::value_error("result position out of bounds");
- }
- MlirAffineMap affineMap = mlirAffineMapGetSubMap(
- self, resultPos.size(), resultPos.data());
- return PyAffineMap(self.getContext(), affineMap);
- })
- .def("get_major_submap",
- [](PyAffineMap &self, intptr_t nResults) {
- if (nResults >= mlirAffineMapGetNumResults(self))
- throw py::value_error("number of results out of bounds");
- MlirAffineMap affineMap =
- mlirAffineMapGetMajorSubMap(self, nResults);
- return PyAffineMap(self.getContext(), affineMap);
- })
- .def("get_minor_submap",
- [](PyAffineMap &self, intptr_t nResults) {
- if (nResults >= mlirAffineMapGetNumResults(self))
- throw py::value_error("number of results out of bounds");
- MlirAffineMap affineMap =
- mlirAffineMapGetMinorSubMap(self, nResults);
- return PyAffineMap(self.getContext(), affineMap);
- })
- .def_property_readonly(
- "is_permutation",
- [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
- .def_property_readonly("is_projected_permutation",
- [](PyAffineMap &self) {
- return mlirAffineMapIsProjectedPermutation(self);
- })
- .def_property_readonly(
- "n_dims",
- [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
- .def_property_readonly(
- "n_inputs",
- [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
- .def_property_readonly(
- "n_symbols",
- [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
- .def_property_readonly("results", [](PyAffineMap &self) {
- return PyAffineMapExprList(self);
- });
- PyAffineMapExprList::bind(m);
-
- //----------------------------------------------------------------------------
- // Mapping of PyIntegerSet.
- //----------------------------------------------------------------------------
- py::class_<PyIntegerSet>(m, "IntegerSet")
- .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
- &PyIntegerSet::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
- .def("__eq__", [](PyIntegerSet &self,
- PyIntegerSet &other) { return self == other; })
- .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
- .def("__str__",
- [](PyIntegerSet &self) {
- PyPrintAccumulator printAccum;
- mlirIntegerSetPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- })
- .def("__repr__",
- [](PyIntegerSet &self) {
- PyPrintAccumulator printAccum;
- printAccum.parts.append("IntegerSet(");
- mlirIntegerSetPrint(self, printAccum.getCallback(),
- printAccum.getUserData());
- printAccum.parts.append(")");
- return printAccum.join();
- })
- .def_property_readonly(
- "context",
- [](PyIntegerSet &self) { return self.getContext().getObject(); })
- .def(
- "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
- kDumpDocstring)
- .def_static(
- "get",
- [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
- std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
- if (exprs.size() != eqFlags.size())
- throw py::value_error(
- "Expected the number of constraints to match "
- "that of equality flags");
- if (exprs.empty())
- throw py::value_error("Expected non-empty list of constraints");
-
- // Copy over to a SmallVector because std::vector has a
- // specialization for booleans that packs data and does not
- // expose a `bool *`.
- SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
-
- SmallVector<MlirAffineExpr> affineExprs;
- pyListToVector<PyAffineExpr>(exprs, affineExprs,
- "attempting to create an IntegerSet");
- MlirIntegerSet set = mlirIntegerSetGet(
- context->get(), numDims, numSymbols, exprs.size(),
- affineExprs.data(), flags.data());
- return PyIntegerSet(context->getRef(), set);
- },
- py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
- py::arg("eq_flags"), py::arg("context") = py::none())
- .def_static(
- "get_empty",
- [](intptr_t numDims, intptr_t numSymbols,
- DefaultingPyMlirContext context) {
- MlirIntegerSet set =
- mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
- return PyIntegerSet(context->getRef(), set);
- },
- py::arg("num_dims"), py::arg("num_symbols"),
- py::arg("context") = py::none())
- .def("get_replaced",
- [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
- intptr_t numResultDims, intptr_t numResultSymbols) {
- if (static_cast<intptr_t>(dimExprs.size()) !=
- mlirIntegerSetGetNumDims(self))
- throw py::value_error(
- "Expected the number of dimension replacement expressions "
- "to match that of dimensions");
- if (static_cast<intptr_t>(symbolExprs.size()) !=
- mlirIntegerSetGetNumSymbols(self))
- throw py::value_error(
- "Expected the number of symbol replacement expressions "
- "to match that of symbols");
-
- SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
- pyListToVector<PyAffineExpr>(
- dimExprs, dimAffineExprs,
- "attempting to create an IntegerSet by replacing dimensions");
- pyListToVector<PyAffineExpr>(
- symbolExprs, symbolAffineExprs,
- "attempting to create an IntegerSet by replacing symbols");
- MlirIntegerSet set = mlirIntegerSetReplaceGet(
- self, dimAffineExprs.data(), symbolAffineExprs.data(),
- numResultDims, numResultSymbols);
- return PyIntegerSet(self.getContext(), set);
- })
- .def_property_readonly("is_canonical_empty",
- [](PyIntegerSet &self) {
- return mlirIntegerSetIsCanonicalEmpty(self);
- })
- .def_property_readonly(
- "n_dims",
- [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
- .def_property_readonly(
- "n_symbols",
- [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
- .def_property_readonly(
- "n_inputs",
- [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
- .def_property_readonly("n_equalities",
- [](PyIntegerSet &self) {
- return mlirIntegerSetGetNumEqualities(self);
- })
- .def_property_readonly("n_inequalities",
- [](PyIntegerSet &self) {
- return mlirIntegerSetGetNumInequalities(self);
- })
- .def_property_readonly("constraints", [](PyIntegerSet &self) {
- return PyIntegerSetConstraintList(self);
- });
- PyIntegerSetConstraint::bind(m);
- PyIntegerSetConstraintList::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModule.h
similarity index 99%
rename from mlir/lib/Bindings/Python/IRModules.h
rename to mlir/lib/Bindings/Python/IRModule.h
index 8140d704300d..5c710abe789a 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -747,7 +747,10 @@ class PyIntegerSet : public BaseContextObject {
MlirIntegerSet integerSet;
};
-void populateIRSubmodule(pybind11::module &m);
+void populateIRAffine(pybind11::module &m);
+void populateIRAttributes(pybind11::module &m);
+void populateIRCore(pybind11::module &m);
+void populateIRTypes(pybind11::module &m);
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
new file mode 100644
index 000000000000..96f6bf6666c9
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -0,0 +1,678 @@
+//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+
+#include "PybindUtils.h"
+
+#include "mlir-c/BuiltinTypes.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+using llvm::SmallVector;
+using llvm::Twine;
+
+namespace {
+
+/// Returns true if `type` is a builtin integer type or one of the floating
+/// point types BF16, F16, F32 or F64.
+///
+/// The enclosing anonymous namespace already gives this function internal
+/// linkage, so `static` is unnecessary. The result is `bool`, matching the
+/// `bool` predicates it combines and the `bool (*)(MlirType)` predicate shape
+/// used by the concrete type classes.
+bool mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+/// CRTP base for Python-visible subclasses of `Type` that can be built by
+/// casting from a generic `Type` (for example `IntegerType(t)`).
+///
+/// Concrete type classes usually derive straight from PyType. Supplying a
+/// different `BaseTy` models an intermediate Python-visible base class.
+///
+/// Every `DerivedTy` must declare:
+///   static constexpr IsAFunctionTy isaFunction;  // kind predicate
+///   static constexpr const char *pyClassName;    // Python class name
+/// It may also declare `static void bindDerived(ClassTy &)` to add members.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  /// Casting constructor: keeps the context of `orig` and validates its kind
+  /// through castFrom().
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+  /// Returns `orig` as an MlirType if DerivedTy::isaFunction accepts it.
+  /// Otherwise raises ValueError naming the target class and the repr of
+  /// `orig`.
+  static MlirType castFrom(PyType &orig) {
+    if (DerivedTy::isaFunction(orig))
+      return orig;
+    std::string reprText = py::repr(py::cast(orig)).cast<std::string>();
+    throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
+                                           DerivedTy::pyClassName + " (from " +
+                                           reprText + ")");
+  }
+
+  /// Registers the Python class `DerivedTy::pyClassName` on `m`. The class
+  /// gets a constructor taking a Type, a static `isinstance` predicate, and
+  /// whatever DerivedTy::bindDerived adds.
+  static void bind(py::module &m) {
+    ClassTy pyClass(m, DerivedTy::pyClassName);
+    // NOTE(review): for constructors pybind11 resolves keep_alive index 0 to
+    // the (None) return value of __init__, so this may be a no-op; confirm.
+    pyClass.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+    pyClass.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
+    DerivedTy::bindDerived(pyClass);
+  }
+
+  /// Hook for derived classes to add Python members; the default adds none.
+  static void bindDerived(ClassTy & /*cls*/) {}
+};
+
+/// Integer Type subclass - IntegerType.
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers the `get_signless`, `get_signed` and `get_unsigned` factories
+  /// and the `width`, `is_signless`, `is_signed` and `is_unsigned`
+  /// read-only properties. Registrations are chained on `c`.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get_signless",
+         [](unsigned width, DefaultingPyMlirContext context) {
+           MlirType signlessTy = mlirIntegerTypeGet(context->get(), width);
+           return PyIntegerType(context->getRef(), signlessTy);
+         },
+         py::arg("width"), py::arg("context") = py::none(),
+         "Create a signless integer type")
+        .def_static(
+            "get_signed",
+            [](unsigned width, DefaultingPyMlirContext context) {
+              MlirType signedTy =
+                  mlirIntegerTypeSignedGet(context->get(), width);
+              return PyIntegerType(context->getRef(), signedTy);
+            },
+            py::arg("width"), py::arg("context") = py::none(),
+            "Create a signed integer type")
+        .def_static(
+            "get_unsigned",
+            [](unsigned width, DefaultingPyMlirContext context) {
+              MlirType unsignedTy =
+                  mlirIntegerTypeUnsignedGet(context->get(), width);
+              return PyIntegerType(context->getRef(), unsignedTy);
+            },
+            py::arg("width"), py::arg("context") = py::none(),
+            "Create an unsigned integer type")
+        .def_property_readonly(
+            "width",
+            [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
+            "Returns the width of the integer type")
+        .def_property_readonly(
+            "is_signless",
+            [](PyIntegerType &self) -> bool {
+              return mlirIntegerTypeIsSignless(self);
+            },
+            "Returns whether this is a signless integer")
+        .def_property_readonly(
+            "is_signed",
+            [](PyIntegerType &self) -> bool {
+              return mlirIntegerTypeIsSigned(self);
+            },
+            "Returns whether this is a signed integer")
+        .def_property_readonly(
+            "is_unsigned",
+            [](PyIntegerType &self) -> bool {
+              return mlirIntegerTypeIsUnsigned(self);
+            },
+            "Returns whether this is an unsigned integer");
+  }
+};
+
+/// Index Type subclass - IndexType.
+class PyIndexType : public PyConcreteType<PyIndexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr const char *pyClassName = "IndexType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers the static `IndexType.get(context=None)` factory. It returns
+  /// the index type of the context resolved through DefaultingPyMlirContext.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](DefaultingPyMlirContext context) {
+          MlirType t = mlirIndexTypeGet(context->get());
+          return PyIndexType(context->getRef(), t);
+        },
+        // Python-visible docstring: "an index type", not "a index type".
+        py::arg("context") = py::none(), "Create an index type.");
+  }
+};
+
+/// Floating Point Type subclass - BF16Type.
+class PyBF16Type : public PyConcreteType<PyBF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr const char *pyClassName = "BF16Type";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers `BF16Type.get(context=None)`. It returns the bf16 type of the
+  /// context resolved through DefaultingPyMlirContext.
+  static void bindDerived(ClassTy &c) {
+    auto getBF16 = [](DefaultingPyMlirContext context) {
+      MlirType bf16 = mlirBF16TypeGet(context->get());
+      return PyBF16Type(context->getRef(), bf16);
+    };
+    c.def_static("get", getBF16, py::arg("context") = py::none(),
+                 "Create a bf16 type.");
+  }
+};
+
+/// Floating Point Type subclass - F16Type.
+class PyF16Type : public PyConcreteType<PyF16Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr const char *pyClassName = "F16Type";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers `F16Type.get(context=None)`. It returns the f16 type of the
+  /// context resolved through DefaultingPyMlirContext.
+  static void bindDerived(ClassTy &c) {
+    auto getF16 = [](DefaultingPyMlirContext context) {
+      MlirType f16 = mlirF16TypeGet(context->get());
+      return PyF16Type(context->getRef(), f16);
+    };
+    c.def_static("get", getF16, py::arg("context") = py::none(),
+                 "Create a f16 type.");
+  }
+};
+
+/// Floating Point Type subclass - F32Type.
+class PyF32Type : public PyConcreteType<PyF32Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr const char *pyClassName = "F32Type";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers `F32Type.get(context=None)`. It returns the f32 type of the
+  /// context resolved through DefaultingPyMlirContext.
+  static void bindDerived(ClassTy &c) {
+    auto getF32 = [](DefaultingPyMlirContext context) {
+      MlirType f32 = mlirF32TypeGet(context->get());
+      return PyF32Type(context->getRef(), f32);
+    };
+    c.def_static("get", getF32, py::arg("context") = py::none(),
+                 "Create a f32 type.");
+  }
+};
+
+/// Floating Point Type subclass - F64Type.
+class PyF64Type : public PyConcreteType<PyF64Type> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr const char *pyClassName = "F64Type";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers `F64Type.get(context=None)`. It returns the f64 type of the
+  /// context resolved through DefaultingPyMlirContext.
+  static void bindDerived(ClassTy &c) {
+    auto getF64 = [](DefaultingPyMlirContext context) {
+      MlirType f64 = mlirF64TypeGet(context->get());
+      return PyF64Type(context->getRef(), f64);
+    };
+    c.def_static("get", getF64, py::arg("context") = py::none(),
+                 "Create a f64 type.");
+  }
+};
+
+/// None Type subclass - NoneType.
+class PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr const char *pyClassName = "NoneType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers `NoneType.get(context=None)`. It returns the builtin none type
+  /// of the context resolved through DefaultingPyMlirContext.
+  static void bindDerived(ClassTy &c) {
+    auto getNone = [](DefaultingPyMlirContext context) {
+      MlirType noneTy = mlirNoneTypeGet(context->get());
+      return PyNoneType(context->getRef(), noneTy);
+    };
+    c.def_static("get", getNone, py::arg("context") = py::none(),
+                 "Create a none type.");
+  }
+};
+
+/// Complex Type subclass - ComplexType.
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers the static `ComplexType.get(element_type)` factory and the
+  /// read-only `element_type` property.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get",
+         [](PyType &elementType) {
+           // Only integer or floating point scalar element types are accepted;
+           // anything else raises ValueError.
+           if (!mlirTypeIsAIntegerOrFloat(elementType)) {
+             std::string elementRepr =
+                 py::repr(py::cast(elementType)).cast<std::string>();
+             throw SetPyError(
+                 PyExc_ValueError,
+                 Twine("invalid '") + elementRepr +
+                     "' and expected floating point or integer type.");
+           }
+           MlirType complexTy = mlirComplexTypeGet(elementType);
+           return PyComplexType(elementType.getContext(), complexTy);
+         },
+         "Create a complex type")
+        .def_property_readonly(
+            "element_type",
+            [](PyComplexType &self) -> PyType {
+              MlirType elementTy = mlirComplexTypeGetElementType(self);
+              return PyType(self.getContext(), elementTy);
+            },
+            "Returns element type.");
+  }
+};
+
+/// Base class of the shaped types (the vector, tensor and memref classes below
+/// derive from it). Only exposes queries; there is no `get` factory here.
+class PyShapedType : public PyConcreteType<PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
+  static constexpr const char *pyClassName = "ShapedType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly(
+        "element_type",
+        [](PyShapedType &self) {
+          MlirType t = mlirShapedTypeGetElementType(self);
+          return PyType(self.getContext(), t);
+        },
+        "Returns the element type of the shaped type.");
+    c.def_property_readonly(
+        "has_rank",
+        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
+        "Returns whether the given shaped type is ranked.");
+    c.def_property_readonly(
+        "rank",
+        [](PyShapedType &self) {
+          self.requireHasRank();
+          return mlirShapedTypeGetRank(self);
+        },
+        "Returns the rank of the given ranked shaped type.");
+    c.def_property_readonly(
+        "has_static_shape",
+        [](PyShapedType &self) -> bool {
+          return mlirShapedTypeHasStaticShape(self);
+        },
+        "Returns whether the given shaped type has a static shape.");
+    c.def(
+        "is_dynamic_dim",
+        [](PyShapedType &self, intptr_t dim) -> bool {
+          self.requireHasRank();
+          // Reject out-of-range indices before they reach the C API.
+          self.requireValidDim(dim);
+          return mlirShapedTypeIsDynamicDim(self, dim);
+        },
+        "Returns whether the dim-th dimension of the given shaped type is "
+        "dynamic.");
+    c.def(
+        "get_dim_size",
+        [](PyShapedType &self, intptr_t dim) {
+          self.requireHasRank();
+          // Reject out-of-range indices before they reach the C API.
+          self.requireValidDim(dim);
+          return mlirShapedTypeGetDimSize(self, dim);
+        },
+        "Returns the dim-th dimension of the given ranked shaped type.");
+    c.def_static(
+        "is_dynamic_size",
+        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
+        "Returns whether the given dimension size indicates a dynamic "
+        "dimension.");
+    c.def(
+        "is_dynamic_stride_or_offset",
+        [](PyShapedType &self, int64_t val) -> bool {
+          self.requireHasRank();
+          return mlirShapedTypeIsDynamicStrideOrOffset(val);
+        },
+        "Returns whether the given value is used as a placeholder for dynamic "
+        "strides and offsets in shaped types.");
+  }
+
+private:
+  /// Raises ValueError unless this shaped type has a rank.
+  void requireHasRank() {
+    if (!mlirShapedTypeHasRank(*this)) {
+      throw SetPyError(
+          PyExc_ValueError,
+          "calling this method requires that the type has a rank.");
+    }
+  }
+
+  /// Raises IndexError unless 0 <= dim < rank. Call only after
+  /// requireHasRank(), since the rank is read unconditionally.
+  void requireValidDim(intptr_t dim) {
+    int64_t rank = mlirShapedTypeGetRank(*this);
+    if (dim < 0 || dim >= rank) {
+      throw SetPyError(PyExc_IndexError,
+                       Twine("dimension index ") + Twine(dim) +
+                           " is out of range for a shaped type of rank " +
+                           Twine(rank) + ".");
+    }
+  }
+};
+
+/// Vector Type subclass - VectorType.
+class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Exposes `VectorType.get(shape, elementType, loc=None)`. The keyword is
+  /// spelled `elementType` (not `element_type`) for existing callers.
+  static void bindDerived(ClassTy &c) {
+    auto build = [](std::vector<int64_t> shape, PyType &elementType,
+                    DefaultingPyLocation loc) {
+      MlirType vectorTy = mlirVectorTypeGetChecked(loc, shape.size(),
+                                                   shape.data(), elementType);
+      // The checked builder yields a null type on failure.
+      // TODO: Rework error reporting once diagnostic engine is exposed
+      // in C API.
+      if (!mlirTypeIsNull(vectorTy))
+        return PyVectorType(elementType.getContext(), vectorTy);
+      std::string elementRepr =
+          py::repr(py::cast(elementType)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("invalid '") + elementRepr +
+                           "' and expected floating point or integer type.");
+    };
+    c.def_static("get", build, py::arg("shape"), py::arg("elementType"),
+                 py::arg("loc") = py::none(), "Create a vector type");
+  }
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType.
+class PyRankedTensorType
+    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr const char *pyClassName = "RankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Exposes `RankedTensorType.get(shape, element_type, loc=None)`.
+  static void bindDerived(ClassTy &c) {
+    auto build = [](std::vector<int64_t> shape, PyType &elementType,
+                    DefaultingPyLocation loc) {
+      MlirType tensorTy = mlirRankedTensorTypeGetChecked(
+          loc, shape.size(), shape.data(), elementType);
+      // The checked builder yields a null type on failure.
+      // TODO: Rework error reporting once diagnostic engine is exposed
+      // in C API.
+      if (!mlirTypeIsNull(tensorTy))
+        return PyRankedTensorType(elementType.getContext(), tensorTy);
+      std::string elementRepr =
+          py::repr(py::cast(elementType)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("invalid '") + elementRepr +
+                           "' and expected floating point, integer, vector "
+                           "or complex type.");
+    };
+    c.def_static("get", build, py::arg("shape"), py::arg("element_type"),
+                 py::arg("loc") = py::none(), "Create a ranked tensor type");
+  }
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+class PyUnrankedTensorType
+    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr const char *pyClassName = "UnrankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Exposes `UnrankedTensorType.get(element_type, loc=None)`.
+  static void bindDerived(ClassTy &c) {
+    auto build = [](PyType &elementType, DefaultingPyLocation loc) {
+      MlirType tensorTy = mlirUnrankedTensorTypeGetChecked(loc, elementType);
+      // The checked builder yields a null type on failure.
+      // TODO: Rework error reporting once diagnostic engine is exposed
+      // in C API.
+      if (!mlirTypeIsNull(tensorTy))
+        return PyUnrankedTensorType(elementType.getContext(), tensorTy);
+      std::string elementRepr =
+          py::repr(py::cast(elementType)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       Twine("invalid '") + elementRepr +
+                           "' and expected floating point, integer, vector "
+                           "or complex type.");
+    };
+    c.def_static("get", build, py::arg("element_type"),
+                 py::arg("loc") = py::none(), "Create a unranked tensor type");
+  }
+};
+
+class PyMemRefLayoutMapList;
+
+/// Ranked MemRef Type subclass - MemRefType.
+class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+  // Ranked memref predicate: decides which MlirTypes may be cast to
+  // MemRefType from Python (tensors must not match).
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr const char *pyClassName = "MemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Returns a sliceable view over the layout maps of this memref type.
+  PyMemRefLayoutMapList getLayout();
+
+  /// Binds `MemRefType.get(...)` and the `layout` / `memory_space`
+  /// properties.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get",
+         [](std::vector<int64_t> shape, PyType &elementType,
+            std::vector<PyAffineMap> layout, PyAttribute *memorySpace,
+            DefaultingPyLocation loc) {
+           // Unwrap the Python affine maps into C API handles.
+           SmallVector<MlirAffineMap> maps;
+           maps.reserve(layout.size());
+           for (PyAffineMap &map : layout)
+             maps.push_back(map);
+
+           // A `None` memory space arrives as nullptr and is forwarded as a
+           // null attribute.
+           MlirAttribute memSpaceAttr = {};
+           if (memorySpace)
+             memSpaceAttr = *memorySpace;
+
+           MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                                 shape.data(), maps.size(),
+                                                 maps.data(), memSpaceAttr);
+           // TODO: Rework error reporting once diagnostic engine is exposed
+           // in C API.
+           if (mlirTypeIsNull(t)) {
+             throw SetPyError(
+                 PyExc_ValueError,
+                 Twine("invalid '") +
+                     py::repr(py::cast(elementType)).cast<std::string>() +
+                     "' and expected floating point, integer, vector or "
+                     "complex "
+                     "type.");
+           }
+           return PyMemRefType(elementType.getContext(), t);
+         },
+         py::arg("shape"), py::arg("element_type"),
+         py::arg("layout") = py::list(), py::arg("memory_space") = py::none(),
+         py::arg("loc") = py::none(), "Create a memref type")
+        .def_property_readonly("layout", &PyMemRefType::getLayout,
+                               "The list of layout maps of the MemRef type.")
+        .def_property_readonly(
+            "memory_space",
+            [](PyMemRefType &self) -> PyAttribute {
+              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+              return PyAttribute(self.getContext(), a);
+            },
+            "Returns the memory space of the given MemRef type.");
+  }
+};
+
+/// Sliceable list of the affine layout maps of a memref type. The maps are
+/// fetched by index from the C API, so each element access is a direct lookup.
+/// Only the memref type wrapper is stored here.
+class PyMemRefLayoutMapList
+    : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
+public:
+  static constexpr const char *pyClassName = "MemRefLayoutMapList";
+
+  /// Creates a (possibly sliced) view; `length == -1` means "all maps".
+  PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
+                        intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex, resolveLength(type, length), step),
+        memref(type) {}
+
+  /// Total number of layout maps on the underlying memref type.
+  intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
+
+  /// Wraps the layout map at `index` in the memref's context.
+  PyAffineMap getElement(intptr_t index) {
+    MlirAffineMap map = mlirMemRefTypeGetAffineMap(memref, index);
+    return PyAffineMap(memref.getContext(), map);
+  }
+
+  /// Builds a sub-view over the same memref type.
+  PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
+                              intptr_t step) {
+    return PyMemRefLayoutMapList(memref, startIndex, length, step);
+  }
+
+private:
+  /// Maps the `-1` sentinel to the memref's full affine-map count.
+  static intptr_t resolveLength(MlirType type, intptr_t length) {
+    return length != -1 ? length : mlirMemRefTypeGetNumAffineMaps(type);
+  }
+
+  PyMemRefType memref;
+};
+
+/// Returns an unsliced view covering every layout map of this memref type.
+PyMemRefLayoutMapList PyMemRefType::getLayout() {
+  return PyMemRefLayoutMapList(*this, /*startIndex=*/0, /*length=*/-1,
+                               /*step=*/1);
+}
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class PyUnrankedMemRefType
+    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr const char *pyClassName = "UnrankedMemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Binds `UnrankedMemRefType.get(element_type, memory_space, loc=None)` and
+  /// the `memory_space` property.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get",
+         [](PyType &elementType, PyAttribute *memorySpace,
+            DefaultingPyLocation loc) {
+           // A `None` memory space arrives as nullptr and is forwarded as a
+           // null attribute.
+           MlirAttribute memSpaceAttr = {};
+           if (memorySpace)
+             memSpaceAttr = *memorySpace;
+
+           MlirType t =
+               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
+           // TODO: Rework error reporting once diagnostic engine is exposed
+           // in C API.
+           if (mlirTypeIsNull(t)) {
+             throw SetPyError(
+                 PyExc_ValueError,
+                 Twine("invalid '") +
+                     py::repr(py::cast(elementType)).cast<std::string>() +
+                     "' and expected floating point, integer, vector or "
+                     "complex "
+                     "type.");
+           }
+           return PyUnrankedMemRefType(elementType.getContext(), t);
+         },
+         py::arg("element_type"), py::arg("memory_space"),
+         py::arg("loc") = py::none(), "Create a unranked memref type")
+        .def_property_readonly(
+            "memory_space",
+            [](PyUnrankedMemRefType &self) -> PyAttribute {
+              // Use the unranked-specific accessor; the ranked
+              // mlirMemRefTypeGetMemorySpace expects a ranked MemRefType.
+              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+              return PyAttribute(self.getContext(), a);
+            },
+            "Returns the memory space of the given Unranked MemRef type.");
+  }
+};
+
+/// Tuple Type subclass - TupleType.
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Binds `TupleType.get_tuple(elements, context=None)`, `get_type(pos)` and
+  /// the `num_types` property.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_tuple",
+        [](py::list elementList, DefaultingPyMlirContext context) {
+          intptr_t num = py::len(elementList);
+          // Mapping py::list to SmallVector; each element must cast to PyType.
+          SmallVector<MlirType, 4> elements;
+          elements.reserve(num);
+          for (auto element : elementList)
+            elements.push_back(element.cast<PyType>());
+          MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
+          return PyTupleType(context->getRef(), t);
+        },
+        py::arg("elements"), py::arg("context") = py::none(),
+        "Create a tuple type");
+    c.def(
+        "get_type",
+        [](PyTupleType &self, intptr_t pos) -> PyType {
+          // Reject out-of-range positions before they reach the C API.
+          intptr_t numTypes = mlirTupleTypeGetNumTypes(self);
+          if (pos < 0 || pos >= numTypes) {
+            throw SetPyError(PyExc_IndexError,
+                             Twine("tuple type position ") + Twine(pos) +
+                                 " is out of range for a tuple of " +
+                                 Twine(numTypes) + " types.");
+          }
+          MlirType t = mlirTupleTypeGetType(self, pos);
+          return PyType(self.getContext(), t);
+        },
+        "Returns the pos-th type in the tuple type.");
+    c.def_property_readonly(
+        "num_types",
+        [](PyTupleType &self) -> intptr_t {
+          return mlirTupleTypeGetNumTypes(self);
+        },
+        "Returns the number of types contained in a tuple.");
+  }
+};
+
+/// Function type (Python class `FunctionType`).
+class PyFunctionType : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Binds `FunctionType.get(inputs, results, context=None)` and the
+  /// `inputs` / `results` list properties.
+  static void bindDerived(ClassTy &c) {
+    auto build = [](std::vector<PyType> inputs, std::vector<PyType> results,
+                    DefaultingPyMlirContext context) {
+      // Unwrap the Python type wrappers into raw C API handles.
+      SmallVector<MlirType, 4> argTypes(inputs.begin(), inputs.end());
+      SmallVector<MlirType, 4> retTypes(results.begin(), results.end());
+      MlirType fnType =
+          mlirFunctionTypeGet(context->get(), argTypes.size(), argTypes.data(),
+                              retTypes.size(), retTypes.data());
+      return PyFunctionType(context->getRef(), fnType);
+    };
+    c.def_static("get", build, py::arg("inputs"), py::arg("results"),
+                 py::arg("context") = py::none(),
+                 "Gets a FunctionType from a list of input and result types");
+
+    c.def_property_readonly(
+        "inputs",
+        [](PyFunctionType &self) {
+          return collectTypes(self, mlirFunctionTypeGetNumInputs(self),
+                              mlirFunctionTypeGetInput);
+        },
+        "Returns the list of input types in the FunctionType.");
+    c.def_property_readonly(
+        "results",
+        [](PyFunctionType &self) {
+          return collectTypes(self, mlirFunctionTypeGetNumResults(self),
+                              mlirFunctionTypeGetResult);
+        },
+        "Returns the list of result types in the FunctionType.");
+  }
+
+private:
+  /// Builds a Python list of `count` PyType wrappers, the i-th one holding
+  /// `getter(self, i)` and sharing `self`'s context.
+  static py::list collectTypes(PyFunctionType &self, intptr_t count,
+                               MlirType (*getter)(MlirType, intptr_t)) {
+    auto contextRef = self.getContext();
+    py::list collected;
+    for (intptr_t idx = 0; idx < count; ++idx)
+      collected.append(PyType(contextRef, getter(self, idx)));
+    return collected;
+  }
+};
+
+} // namespace
+
+/// Registers every builtin type wrapper class of this file on `m` (MainModule
+/// passes the `ir` submodule here).
+void mlir::python::populateIRTypes(py::module &m) {
+ // Scalar and simple types.
+ PyIntegerType::bind(m);
+ PyIndexType::bind(m);
+ PyBF16Type::bind(m);
+ PyF16Type::bind(m);
+ PyF32Type::bind(m);
+ PyF64Type::bind(m);
+ PyNoneType::bind(m);
+ PyComplexType::bind(m);
+ // NOTE(review): PyShapedType is bound before the classes declared as
+ // PyConcreteType<..., PyShapedType>; presumably bind() needs the base class
+ // registered first, so keep this ordering unless confirmed otherwise.
+ PyShapedType::bind(m);
+ PyVectorType::bind(m);
+ PyRankedTensorType::bind(m);
+ PyUnrankedTensorType::bind(m);
+ PyMemRefType::bind(m);
+ // Helper list class returned by MemRefType.layout.
+ PyMemRefLayoutMapList::bind(m);
+ PyUnrankedMemRefType::bind(m);
+ // Aggregate types.
+ PyTupleType::bind(m);
+ PyFunctionType::bind(m);
+}
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 9bfe8b09f6db..5fe0401afaeb 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,7 +12,7 @@
#include "ExecutionEngine.h"
#include "Globals.h"
-#include "IRModules.h"
+#include "IRModule.h"
#include "Pass.h"
namespace py = pybind11;
@@ -211,7 +211,10 @@ PYBIND11_MODULE(_mlir, m) {
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
- populateIRSubmodule(irModule);
+ populateIRCore(irModule);
+ populateIRAffine(irModule);
+ populateIRAttributes(irModule);
+ populateIRTypes(irModule);
// Define and populate PassManager submodule.
auto passModule =
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index dd57647f0327..0e2f5bafb465 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -8,7 +8,7 @@
#include "Pass.h"
-#include "IRModules.h"
+#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"
More information about the Mlir-commits
mailing list