[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu Nov 2 04:45:10 PDT 2023


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/71050

Scalable dimension support was added to the vector type after the bindings for it were defined, and the bindings were never updated. Fix that by exposing scalable dimensions in the C API and the Python bindings.

>From bdb47c786fd4db596929b3752a2aecc79abc0df9 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 2 Nov 2023 11:43:14 +0000
Subject: [PATCH] [mlir] support scalable vectors in python bindings

Scalable dimension support was added to the vector type after the
bindings for it were defined, and the bindings were never updated.
Fix that by exposing scalable dimensions in the C API and the Python
bindings.
---
 mlir/include/mlir-c/BuiltinTypes.h   | 26 ++++++++++++++
 mlir/lib/Bindings/Python/IRTypes.cpp | 54 +++++++++++++++++++++-------
 mlir/lib/CAPI/IR/BuiltinTypes.cpp    | 25 +++++++++++++
 mlir/test/CAPI/ir.c                  | 34 ++++++++++++------
 mlir/test/python/ir/builtin_types.py | 23 +++++++++++-
 5 files changed, 139 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
                                                      const int64_t *shape,
                                                      MlirType elementType);
 
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+                                                      const int64_t *shape,
+                                                      const bool *scalable,
+                                                      MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether dimension "dim" (0 <= dim < rank) of the vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+                                                    intptr_t dim);
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked Tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..e145e05ad9b4f19 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
+#include "llvm/ADT/SmallVector.h"
 #include <optional>
 
 namespace py = pybind11;
@@ -463,18 +464,47 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                                elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyVectorType(elementType.getContext(), t);
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
+         "get",
+         [](std::vector<int64_t> shape, PyType &elementType,
+            DefaultingPyLocation loc, std::optional<py::list> scalable) {
+           PyMlirContext::ErrorCapture errors(loc->getContext());
+           MlirType type;
+           if (scalable) {
+             if (scalable->size() != shape.size())
+               throw py::value_error("Expected len(scalable) == len(shape).");
+             // std::vector<bool> may be bit-packed; SmallVector<bool> holds
+             // one bool per element and owns its storage (no manual free).
+             llvm::SmallVector<bool> scalableData(scalable->size());
+             auto range = llvm::map_range(
+                 *scalable, [](const py::handle &h) { return h.cast<bool>(); });
+             llvm::copy(range, scalableData.begin());
+             type = mlirVectorTypeGetScalableChecked(
+                 loc, shape.size(), shape.data(), scalableData.data(),
+                 elementType);
+           } else {
+             type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                             elementType);
+           }
+           if (mlirTypeIsNull(type))
+             throw MLIRError("Invalid type", errors.take());
+           return PyVectorType(elementType.getContext(), type);
+         },
+         // `loc` stays positional (as before `scalable` existed) so existing
+         // `VectorType.get(shape, type, loc)` calls keep working.
+         py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
+         py::kw_only(), py::arg("scalable") = py::none(),
+         "Create a vector type")
+        .def_property_readonly(
+            "scalable",
+            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+        .def_property_readonly("scalable_dims", [](MlirType self) {
+          std::vector<bool> scalableDims;
+          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+          scalableDims.reserve(rank);
+          for (size_t i = 0; i < rank; ++i)
+            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+          return scalableDims;
+        });
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
       unwrap(elementType)));
 }
 
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+                                   const bool *scalable, MlirType elementType) {
+  return wrap(VectorType::get(
+      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType) {
+  return wrap(VectorType::getChecked(
+      unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+  return llvm::cast<VectorType>(unwrap(type)).isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+  return llvm::cast<VectorType>(unwrap(type)).getScalableDims()[dim];
+}
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c6425f80a8bce9c..3a2bdb9bfc93334 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
   fprintf(stderr, "\n");
   // CHECK: vector<2x3xf32>
 
+  // Scalable vector type: only the second of the two dimensions is scalable.
+  bool scalable[] = {false, true};
+  MlirType scalableVector = mlirVectorTypeGetScalable(
+      sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+  if (!mlirTypeIsAVector(scalableVector))
+    return 16;
+  if (!mlirVectorTypeIsScalable(scalableVector) ||
+      mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+      !mlirVectorTypeIsDimScalable(scalableVector, 1))
+    return 17;
+  mlirTypeDump(scalableVector);
+  fprintf(stderr, "\n");
+  // CHECK: vector<2x[3]xf32>
+
   // Ranked tensor type.
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
       !mlirTypeIsARankedTensor(rankedTensor) ||
       !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
-    return 16;
+    return 18;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATensor(unrankedTensor) ||
       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
       mlirShapedTypeHasRank(unrankedTensor))
-    return 17;
+    return 19;
   mlirTypeDump(unrankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
-    return 18;
+    return 20;
   mlirTypeDump(memRef);
   fprintf(stderr, "\n");
   // CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       mlirTypeIsAMemRef(unrankedMemRef) ||
       !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                           memSpace4))
-    return 19;
+    return 21;
   mlirTypeDump(unrankedMemRef);
   fprintf(stderr, "\n");
   // CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
-    return 20;
+    return 22;
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
   // CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
                              mlirIntegerTypeGet(ctx, 64)};
   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
-    return 21;
+    return 23;
   if (mlirFunctionTypeGetNumResults(funcType) != 3)
-    return 22;
+    return 24;
   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
-    return 23;
+    return 25;
   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
-    return 24;
+    return 26;
   mlirTypeDump(funcType);
   fprintf(stderr, "\n");
   // CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
                           namespace) ||
       !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
-    return 25;
+    return 27;
   mlirTypeDump(opaque);
   fprintf(stderr, "\n");
   // CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..e2344794c839a3a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
 
         none = NoneType.get()
         try:
-            vector_invalid = VectorType.get(shape, none)
+            VectorType.get(shape, none)
         except MLIRError as e:
             # CHECK: Invalid type:
             # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,27 @@ def testVectorType():
         else:
             print("Exception not produced")
 
+        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+        scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
+        assert scalable_1.scalable
+        assert scalable_2.scalable
+        assert scalable_1.scalable_dims == [False, True]
+        assert scalable_2.scalable_dims == [True, False, True]
+        assert not VectorType.get(shape, f32).scalable
+        assert VectorType.get(shape, f32).scalable_dims == [False, False]
+        # CHECK: scalable 1: vector<2x[3]xf32>
+        print("scalable 1:", scalable_1)
+        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+        print("scalable 2:", scalable_2)
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True, True])
+        except ValueError as e:
+            # CHECK: Expected len(scalable) == len(shape).
+            print(e)
+        else:
+            print("Exception not produced")
+
 
 # CHECK-LABEL: TEST: testRankedTensorType
 @run



More information about the Mlir-commits mailing list