[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Nov 3 01:58:09 PDT 2023


================
@@ -463,18 +464,47 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                                elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyVectorType(elementType.getContext(), t);
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
+         "get",
+         [](std::vector<int64_t> shape, PyType &elementType,
+            std::optional<py::list> scalable, DefaultingPyLocation loc) {
+           PyMlirContext::ErrorCapture errors(loc->getContext());
+           MlirType type;
+           if (scalable) {
+             if (scalable->size() != shape.size())
+               throw py::value_error("Expected len(scalable) == len(shape).");
+
+             // Vector-of-bool may be using bit packing, so we cannot access its
+             // data directly. Explicitly create an array-of-bool instead.
+             bool *scalableData =
----------------
ftynse wrote:

Good points, thanks!

https://github.com/llvm/llvm-project/pull/71050


More information about the Mlir-commits mailing list