[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Nov 6 03:45:01 PST 2023


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/71050

>From f0d72c524d534c79dd8e17f720b0f39d469726f9 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 2 Nov 2023 11:43:14 +0000
Subject: [PATCH] [mlir] support scalable vectors in python bindings

The scalable dimension functionality was added to the vector type after
the bindings for it were defined, without the bindings ever being
updated. Fix that.
---
 mlir/include/mlir-c/BuiltinTypes.h   | 26 +++++++++++
 mlir/lib/Bindings/Python/IRTypes.cpp | 69 ++++++++++++++++++++++------
 mlir/lib/CAPI/IR/BuiltinTypes.cpp    | 25 ++++++++++
 mlir/test/CAPI/ir.c                  | 34 ++++++++++----
 mlir/test/python/ir/builtin_types.py | 43 ++++++++++++++++-
 5 files changed, 172 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
                                                      const int64_t *shape,
                                                      MlirType elementType);
 
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+                                                      const int64_t *shape,
+                                                      const bool *scalable,
+                                                      MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether the "dim"-th dimension of the given vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+                                                    intptr_t dim);
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked Tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..483db673f989e6b 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                                elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyVectorType(elementType.getContext(), t);
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
+    c.def_static("get", &PyVectorType::get, py::arg("shape"),
+                 py::arg("elementType"), py::kw_only(),
+                 py::arg("scalable") = py::none(),
+                 py::arg("scalable_dims") = py::none(),
+                 py::arg("loc") = py::none(), "Create a vector type")
+        .def_property_readonly(
+            "scalable",
+            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+        .def_property_readonly("scalable_dims", [](MlirType self) {
+          std::vector<bool> scalableDims;
+          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+          scalableDims.reserve(rank);
+          for (size_t i = 0; i < rank; ++i)
+            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+          return scalableDims;
+        });
+  }
+
+private:
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<py::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyLocation loc) {
+    if (scalable && scalableDims) {
+      throw py::value_error("'scalable' and 'scalable_dims' kwargs "
+                            "are mutually exclusive.");
+    }
+
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirType type;
+    if (scalable) {
+      if (scalable->size() != shape.size())
+        throw py::value_error("Expected len(scalable) == len(shape).");
+
+      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+          *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                              scalableDimFlags.data(),
+                                              elementType);
+    } else if (scalableDims) {
+      SmallVector<bool> scalableDimFlags(shape.size(), false);
+      for (int64_t dim : *scalableDims) {
+        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+          throw py::value_error("Scalable dimension index out of bounds.");
+        scalableDimFlags[dim] = true;
+      }
+      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                              scalableDimFlags.data(),
+                                              elementType);
+    } else {
+      type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                      elementType);
+    }
+    if (mlirTypeIsNull(type))
+      throw MLIRError("Invalid type", errors.take());
+    return PyVectorType(elementType.getContext(), type);
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
       unwrap(elementType)));
 }
 
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+                                   const bool *scalable, MlirType elementType) {
+  return wrap(VectorType::get(
+      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType) {
+  return wrap(VectorType::getChecked(
+      unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+  return llvm::cast<VectorType>(unwrap(type)).isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+  return llvm::cast<VectorType>(unwrap(type)).getScalableDims()[dim];
+}
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked tensor type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c6425f80a8bce9c..3a2bdb9bfc93334 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
   fprintf(stderr, "\n");
   // CHECK: vector<2x3xf32>
 
+  // Scalable vector type.
+  bool scalable[] = {false, true};
+  MlirType scalableVector = mlirVectorTypeGetScalable(
+      sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+  if (!mlirTypeIsAVector(scalableVector))
+    return 16;
+  if (!mlirVectorTypeIsScalable(scalableVector) ||
+      mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+      !mlirVectorTypeIsDimScalable(scalableVector, 1))
+    return 17;
+  mlirTypeDump(scalableVector);
+  fprintf(stderr, "\n");
+  // CHECK: vector<2x[3]xf32>
+
   // Ranked tensor type.
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
       !mlirTypeIsARankedTensor(rankedTensor) ||
       !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
-    return 16;
+    return 18;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATensor(unrankedTensor) ||
       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
       mlirShapedTypeHasRank(unrankedTensor))
-    return 17;
+    return 19;
   mlirTypeDump(unrankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
-    return 18;
+    return 20;
   mlirTypeDump(memRef);
   fprintf(stderr, "\n");
   // CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       mlirTypeIsAMemRef(unrankedMemRef) ||
       !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                           memSpace4))
-    return 19;
+    return 21;
   mlirTypeDump(unrankedMemRef);
   fprintf(stderr, "\n");
   // CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
-    return 20;
+    return 22;
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
   // CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
                              mlirIntegerTypeGet(ctx, 64)};
   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
-    return 21;
+    return 23;
   if (mlirFunctionTypeGetNumResults(funcType) != 3)
-    return 22;
+    return 24;
   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
-    return 23;
+    return 25;
   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
-    return 24;
+    return 26;
   mlirTypeDump(funcType);
   fprintf(stderr, "\n");
   // CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
                           namespace) ||
       !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
-    return 25;
+    return 27;
   mlirTypeDump(opaque);
   fprintf(stderr, "\n");
   // CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..4c891a2ca2ab9a2 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
 
         none = NoneType.get()
         try:
-            vector_invalid = VectorType.get(shape, none)
+            VectorType.get(shape, none)
         except MLIRError as e:
             # CHECK: Invalid type:
             # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,46 @@ def testVectorType():
         else:
             print("Exception not produced")
 
+        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+        scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
+        assert scalable_1.scalable
+        assert scalable_2.scalable
+        assert scalable_1.scalable_dims == [False, True]
+        assert scalable_2.scalable_dims == [True, False, True]
+        # CHECK: scalable 1: vector<2x[3]xf32>
+        print("scalable 1: ", scalable_1)
+        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+        print("scalable 2: ", scalable_2)
+
+        scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
+        scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
+        assert scalable_3 == scalable_1
+        assert scalable_4 == scalable_2
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True, True])
+        except ValueError as e:
+            # CHECK: Expected len(scalable) == len(shape).
+            print(e)
+        else:
+            print("Exception not produced")
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
+        except ValueError as e:
+            # CHECK: kwargs are mutually exclusive.
+            print(e)
+        else:
+            print("Exception not produced")
+
+        try:
+            VectorType.get(shape, f32, scalable_dims=[42])  # only one kwarg
+        except ValueError as e:
+            # CHECK: Scalable dimension index out of bounds.
+            print(e)
+        else:
+            print("Exception not produced")
+
 
 # CHECK-LABEL: TEST: testRankedTensorType
 @run
@@ -337,7 +377,6 @@ def testRankedTensorType():
         assert RankedTensorType.get(shape, f32).encoding is None
 
 
-
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run
 def testUnrankedTensorType():



More information about the Mlir-commits mailing list