[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Nov 6 03:42:20 PST 2023
================
@@ -462,19 +462,63 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+ c.def_static("get", &PyVectorType::get, py::arg("shape"),
+ py::arg("elementType"), py::kw_only(),
+ py::arg("scalable") = py::none(),
+ py::arg("scalable_dims") = py::none(),
+ py::arg("loc") = py::none(), "Create a vector type")
+ .def_property_readonly(
+ "scalable",
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_property_readonly("scalable_dims", [](MlirType self) {
+ std::vector<bool> scalableDims;
+ size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(rank);
+ for (size_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
+ }
+
+private:
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+ std::optional<py::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
+ if (scalable && scalableDims) {
+ throw py::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw py::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+ *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false);
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) {
+ throw py::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
----------------
ftynse wrote:
Good catch, thank you!
https://github.com/llvm/llvm-project/pull/71050
More information about the Mlir-commits mailing list