[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 04:45:47 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

Scalable-dimension support was added to the vector type after its bindings had been defined, and the bindings were never updated to expose it. This change fixes that.

---
Full diff: https://github.com/llvm/llvm-project/pull/71050.diff


5 Files Affected:

- (modified) mlir/include/mlir-c/BuiltinTypes.h (+26) 
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+42-12) 
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+25) 
- (modified) mlir/test/CAPI/ir.c (+24-10) 
- (modified) mlir/test/python/ir/builtin_types.py (+22-1) 


``````````diff
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
                                                      const int64_t *shape,
                                                      MlirType elementType);
 
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+                                                      const int64_t *shape,
+                                                      const bool *scalable,
+                                                      MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether the "dim"-th dimension of the given vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+                                                    intptr_t dim);
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked Tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..e145e05ad9b4f19 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "llvm/ADT/ScopeExit.h"
 #include <optional>
 
 namespace py = pybind11;
@@ -463,18 +464,47 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                                elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyVectorType(elementType.getContext(), t);
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
+         "get",
+         [](std::vector<int64_t> shape, PyType &elementType,
+            std::optional<py::list> scalable, DefaultingPyLocation loc) {
+           PyMlirContext::ErrorCapture errors(loc->getContext());
+           MlirType type;
+           if (scalable) {
+             if (scalable->size() != shape.size())
+               throw py::value_error("Expected len(scalable) == len(shape).");
+
+             // Vector-of-bool may be using bit packing, so we cannot access its
+             // data directly. Explicitly create an array-of-bool instead.
+             bool *scalableData =
+                 static_cast<bool *>(malloc(sizeof(bool) * scalable->size()));
+             auto deleter = llvm::make_scope_exit([&] { free(scalableData); });
+             auto range = llvm::map_range(
+                 *scalable, [](const py::handle &h) { return h.cast<bool>(); });
+             llvm::copy(range, scalableData);
+             type = mlirVectorTypeGetScalableChecked(
+                 loc, shape.size(), shape.data(), scalableData, elementType);
+           } else {
+             type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                             elementType);
+           }
+           if (mlirTypeIsNull(type))
+             throw MLIRError("Invalid type", errors.take());
+           return PyVectorType(elementType.getContext(), type);
+         },
+         py::arg("shape"), py::arg("elementType"), py::kw_only(),
+         py::arg("scalable") = py::none(), py::arg("loc") = py::none(),
+         "Create a vector type")
+        .def_property_readonly(
+            "scalable",
+            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+        .def_property_readonly("scalable_dims", [](MlirType self) {
+          std::vector<bool> scalableDims;
+          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+          scalableDims.reserve(rank);
+          for (size_t i = 0; i < rank; ++i)
+            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+          return scalableDims;
+        });
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
       unwrap(elementType)));
 }
 
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+                                   const bool *scalable, MlirType elementType) {
+  return wrap(VectorType::get(
+      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType) {
+  return wrap(VectorType::getChecked(
+      unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+  return unwrap(type).cast<VectorType>().isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+  return unwrap(type).cast<VectorType>().getScalableDims()[dim];
+}
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c6425f80a8bce9c..3a2bdb9bfc93334 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
   fprintf(stderr, "\n");
   // CHECK: vector<2x3xf32>
 
+  // Scalable vector type.
+  bool scalable[] = {false, true};
+  MlirType scalableVector = mlirVectorTypeGetScalable(
+      sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+  if (!mlirTypeIsAVector(scalableVector))
+    return 16;
+  if (!mlirVectorTypeIsScalable(scalableVector) ||
+      mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+      !mlirVectorTypeIsDimScalable(scalableVector, 1))
+    return 17;
+  mlirTypeDump(scalableVector);
+  fprintf(stderr, "\n");
+  // CHECK: vector<2x[3]xf32>
+
   // Ranked tensor type.
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
       !mlirTypeIsARankedTensor(rankedTensor) ||
       !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
-    return 16;
+    return 18;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATensor(unrankedTensor) ||
       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
       mlirShapedTypeHasRank(unrankedTensor))
-    return 17;
+    return 19;
   mlirTypeDump(unrankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
-    return 18;
+    return 20;
   mlirTypeDump(memRef);
   fprintf(stderr, "\n");
   // CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       mlirTypeIsAMemRef(unrankedMemRef) ||
       !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                           memSpace4))
-    return 19;
+    return 21;
   mlirTypeDump(unrankedMemRef);
   fprintf(stderr, "\n");
   // CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
-    return 20;
+    return 22;
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
   // CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
                              mlirIntegerTypeGet(ctx, 64)};
   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
-    return 21;
+    return 23;
   if (mlirFunctionTypeGetNumResults(funcType) != 3)
-    return 22;
+    return 24;
   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
-    return 23;
+    return 25;
   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
-    return 24;
+    return 26;
   mlirTypeDump(funcType);
   fprintf(stderr, "\n");
   // CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
                           namespace) ||
       !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
-    return 25;
+    return 27;
   mlirTypeDump(opaque);
   fprintf(stderr, "\n");
   // CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..e2344794c839a3a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
 
         none = NoneType.get()
         try:
-            vector_invalid = VectorType.get(shape, none)
+            VectorType.get(shape, none)
         except MLIRError as e:
             # CHECK: Invalid type:
             # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,27 @@ def testVectorType():
         else:
             print("Exception not produced")
 
+        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+        scalable_2 = VectorType.get([2, 3, 4],
+                                    f32,
+                                    scalable=[True, False, True])
+        assert scalable_1.scalable
+        assert scalable_2.scalable
+        assert scalable_1.scalable_dims == [False, True]
+        assert scalable_2.scalable_dims == [True, False, True]
+        # CHECK: scalable 1: vector<2x[3]xf32>
+        print("scalable 1: ", scalable_1)
+        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+        print("scalable 2: ", scalable_2)
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True, True])
+        except ValueError as e:
+            # CHECK: Expected len(scalable) == len(shape).
+            print(e)
+        else:
+            print("Exception not produced")
+
 
 # CHECK-LABEL: TEST: testRankedTensorType
 @run

``````````

</details>


https://github.com/llvm/llvm-project/pull/71050


More information about the Mlir-commits mailing list