[Mlir-commits] [mlir] 96dadc9 - [mlir] support scalable vectors in python bindings (#71050)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 6 04:15:01 PST 2023


Author: Oleksandr "Alex" Zinenko
Date: 2023-11-06T13:14:56+01:00
New Revision: 96dadc9fc83dddf450e42ea5e9c3fd2616761830

URL: https://github.com/llvm/llvm-project/commit/96dadc9fc83dddf450e42ea5e9c3fd2616761830
DIFF: https://github.com/llvm/llvm-project/commit/96dadc9fc83dddf450e42ea5e9c3fd2616761830.diff

LOG: [mlir] support scalable vectors in python bindings (#71050)

The scalable dimension functionality was added to the vector type after
the bindings for it were defined, without the bindings being ever
updated. Fix that.

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
                                                      const int64_t *shape,
                                                      MlirType elementType);
 
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+                                                      const int64_t *shape,
+                                                      const bool *scalable,
+                                                      MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether the "dim"-th dimension of the given vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+                                                    intptr_t dim);
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked Tensor type.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..483db673f989e6b 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                                elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyVectorType(elementType.getContext(), t);
-        },
-        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
-        "Create a vector type");
+    c.def_static("get", &PyVectorType::get, py::arg("shape"),
+                 py::arg("elementType"), py::kw_only(),
+                 py::arg("scalable") = py::none(),
+                 py::arg("scalable_dims") = py::none(),
+                 py::arg("loc") = py::none(), "Create a vector type")
+        .def_property_readonly(
+            "scalable",
+            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+        .def_property_readonly("scalable_dims", [](MlirType self) {
+          std::vector<bool> scalableDims;
+          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+          scalableDims.reserve(rank);
+          for (size_t i = 0; i < rank; ++i)
+            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+          return scalableDims;
+        });
+  }
+
+private:
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<py::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyLocation loc) {
+    if (scalable && scalableDims) {
+      throw py::value_error("'scalable' and 'scalable_dims' kwargs "
+                            "are mutually exclusive.");
+    }
+
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirType type;
+    if (scalable) {
+      if (scalable->size() != shape.size())
+        throw py::value_error("Expected len(scalable) == len(shape).");
+
+      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
+          *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                              scalableDimFlags.data(),
+                                              elementType);
+    } else if (scalableDims) {
+      SmallVector<bool> scalableDimFlags(shape.size(), false);
+      for (int64_t dim : *scalableDims) {
+        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+          throw py::value_error("Scalable dimension index out of bounds.");
+        scalableDimFlags[dim] = true;
+      }
+      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                              scalableDimFlags.data(),
+                                              elementType);
+    } else {
+      type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                      elementType);
+    }
+    if (mlirTypeIsNull(type))
+      throw MLIRError("Invalid type", errors.take());
+    return PyVectorType(elementType.getContext(), type);
   }
 };
 

diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
       unwrap(elementType)));
 }
 
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+                                   const bool *scalable, MlirType elementType) {
+  return wrap(VectorType::get(
+      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+                                          const int64_t *shape,
+                                          const bool *scalable,
+                                          MlirType elementType) {
+  return wrap(VectorType::getChecked(
+      unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType),
+      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+  return unwrap(type).cast<VectorType>().isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+  return unwrap(type).cast<VectorType>().getScalableDims()[dim];
+}
+
 //===----------------------------------------------------------------------===//
 // Ranked / Unranked tensor type.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8d5dcbf62e85e2b..315458a08b613e0 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -746,13 +746,27 @@ static int printBuiltinTypes(MlirContext ctx) {
   fprintf(stderr, "\n");
   // CHECK: vector<2x3xf32>
 
+  // Scalable vector type.
+  bool scalable[] = {false, true};
+  MlirType scalableVector = mlirVectorTypeGetScalable(
+      sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+  if (!mlirTypeIsAVector(scalableVector))
+    return 16;
+  if (!mlirVectorTypeIsScalable(scalableVector) ||
+      mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+      !mlirVectorTypeIsDimScalable(scalableVector, 1))
+    return 17;
+  mlirTypeDump(scalableVector);
+  fprintf(stderr, "\n");
+  // CHECK: vector<2x[3]xf32>
+
   // Ranked tensor type.
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
       !mlirTypeIsARankedTensor(rankedTensor) ||
       !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
-    return 16;
+    return 18;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<2x3xf32>
@@ -762,7 +776,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATensor(unrankedTensor) ||
       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
       mlirShapedTypeHasRank(unrankedTensor))
-    return 17;
+    return 19;
   mlirTypeDump(unrankedTensor);
   fprintf(stderr, "\n");
   // CHECK: tensor<*xf32>
@@ -773,7 +787,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
   if (!mlirTypeIsAMemRef(memRef) ||
       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
-    return 18;
+    return 20;
   mlirTypeDump(memRef);
   fprintf(stderr, "\n");
   // CHECK: memref<2x3xf32, 2>
@@ -785,7 +799,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       mlirTypeIsAMemRef(unrankedMemRef) ||
       !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                           memSpace4))
-    return 19;
+    return 21;
   mlirTypeDump(unrankedMemRef);
   fprintf(stderr, "\n");
   // CHECK: memref<*xf32, 4>
@@ -796,7 +810,7 @@ static int printBuiltinTypes(MlirContext ctx) {
   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
-    return 20;
+    return 22;
   mlirTypeDump(tuple);
   fprintf(stderr, "\n");
   // CHECK: tuple<memref<*xf32, 4>, f32>
@@ -808,16 +822,16 @@ static int printBuiltinTypes(MlirContext ctx) {
                              mlirIntegerTypeGet(ctx, 64)};
   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
-    return 21;
+    return 23;
   if (mlirFunctionTypeGetNumResults(funcType) != 3)
-    return 22;
+    return 24;
   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
-    return 23;
+    return 25;
   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
-    return 24;
+    return 26;
   mlirTypeDump(funcType);
   fprintf(stderr, "\n");
   // CHECK: (index, i1) -> (i16, i32, i64)
@@ -832,7 +846,7 @@ static int printBuiltinTypes(MlirContext ctx) {
       !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
                           namespace) ||
       !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
-    return 25;
+    return 27;
   mlirTypeDump(opaque);
   fprintf(stderr, "\n");
   // CHECK: !dialect.type

diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..4c891a2ca2ab9a2 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
 
         none = NoneType.get()
         try:
-            vector_invalid = VectorType.get(shape, none)
+            VectorType.get(shape, none)
         except MLIRError as e:
             # CHECK: Invalid type:
             # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,46 @@ def testVectorType():
         else:
             print("Exception not produced")
 
+        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+        scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
+        assert scalable_1.scalable
+        assert scalable_2.scalable
+        assert scalable_1.scalable_dims == [False, True]
+        assert scalable_2.scalable_dims == [True, False, True]
+        # CHECK: scalable 1: vector<2x[3]xf32>
+        print("scalable 1: ", scalable_1)
+        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+        print("scalable 2: ", scalable_2)
+
+        scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
+        scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
+        assert scalable_3 == scalable_1
+        assert scalable_4 == scalable_2
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True, True])
+        except ValueError as e:
+            # CHECK: Expected len(scalable) == len(shape).
+            print(e)
+        else:
+            print("Exception not produced")
+
+        try:
+            VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
+        except ValueError as e:
+            # CHECK: kwargs are mutually exclusive.
+            print(e)
+        else:
+            print("Exception not produced")
+
+        try:
+            VectorType.get(shape, f32, scalable_dims=[42])
+        except ValueError as e:
+            # CHECK: Scalable dimension index out of bounds.
+            print(e)
+        else:
+            print("Exception not produced")
+
 
 # CHECK-LABEL: TEST: testRankedTensorType
 @run
@@ -337,7 +377,6 @@ def testRankedTensorType():
         assert RankedTensorType.get(shape, f32).encoding is None
 
 
-
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run
 def testUnrankedTensorType():


        


More information about the Mlir-commits mailing list