[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Nov 6 03:42:20 PST 2023


================
@@ -462,19 +462,63 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+    c.def_static("get", &PyVectorType::get, py::arg("shape"),
+                 py::arg("elementType"), py::kw_only(),
+                 py::arg("scalable") = py::none(),
+                 py::arg("scalable_dims") = py::none(),
+                 py::arg("loc") = py::none(), "Create a vector type")
+        .def_property_readonly(
+            "scalable",
+            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+        .def_property_readonly("scalable_dims", [](MlirType self) {
+          std::vector<bool> scalableDims;
+          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+          scalableDims.reserve(rank);
+          for (size_t i = 0; i < rank; ++i)
+            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+          return scalableDims;
+        });
+  }
+
+private:
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<py::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyLocation loc) {
+    if (scalable && scalableDims) {
+      throw py::value_error("'scalable' and 'scalable_dims' kwargs "
+                            "are mutually exclusive.");
+    }
+
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirType type;
+    if (scalable) {
+      if (scalable->size() != shape.size())
+        throw py::value_error("Expected len(scalable) == len(shape).");
+
+      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+          *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                              scalableDimFlags.data(),
+                                              elementType);
+    } else if (scalableDims) {
+      SmallVector<bool> scalableDimFlags(shape.size(), false);
+      for (int64_t dim : *scalableDims) {
+        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) {
+          throw py::value_error("Scalable dimension index out of bounds.");
+          scalableDimFlags[dim] = true;
+        }
----------------
ftynse wrote:

Good catch, thank you!

https://github.com/llvm/llvm-project/pull/71050


More information about the Mlir-commits mailing list