[Mlir-commits] [mlir] [mlir] support scalable vectors in python bindings (PR #71050)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 2 04:45:47 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
Scalable-dimension support was added to the vector type after the Python bindings for it were defined, and the bindings were never updated. Fix that by exposing scalable dimensions through the C API and the Python `VectorType` bindings.
---
Full diff: https://github.com/llvm/llvm-project/pull/71050.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/BuiltinTypes.h (+26)
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+42-12)
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+25)
- (modified) mlir/test/CAPI/ir.c (+24-10)
- (modified) mlir/test/python/ir/builtin_types.py (+22-1)
``````````diff
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,33 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
const int64_t *shape,
MlirType elementType);
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics.
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension. `type` must be a vector type.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether the "dim"-th dimension of the given vector is scalable.
+/// `type` must be a vector type and `dim` must be in [0, rank).
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+ intptr_t dim);
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked Tensor type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..e145e05ad9b4f19 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -12,6 +12,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace py = pybind11;
@@ -463,18 +464,44 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), t);
- },
- py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
- "Create a vector type");
+ "get",
+ [](std::vector<int64_t> shape, PyType &elementType,
+ std::optional<py::list> scalable, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
+ MlirType type;
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw py::value_error("Expected len(scalable) == len(shape).");
+
+ // std::vector<bool> may be bit-packed and cannot hand out a bool
+ // array; SmallVector<bool> stores plain contiguous bools and frees
+ // its storage automatically, so no manual malloc/free is needed.
+ llvm::SmallVector<bool> scalableData = llvm::to_vector(llvm::map_range(
+ *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+ type = mlirVectorTypeGetScalableChecked(
+ loc, shape.size(), shape.data(), scalableData.data(), elementType);
+ } else {
+ type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+ elementType);
+ }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
+ },
+ py::arg("shape"), py::arg("elementType"), py::kw_only(),
+ py::arg("scalable") = py::none(), py::arg("loc") = py::none(),
+ "Create a vector type")
+ .def_property_readonly(
+ "scalable",
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_property_readonly("scalable_dims", [](MlirType self) {
+ std::vector<bool> scalableDims;
+ intptr_t rank = static_cast<intptr_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(static_cast<size_t>(rank));
+ for (intptr_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
}
};
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
unwrap(elementType)));
}
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+ const bool *scalable, MlirType elementType) {
+ return wrap(VectorType::get(
+ llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType) {
+ return wrap(VectorType::getChecked(
+ unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+ return llvm::cast<VectorType>(unwrap(type)).isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+ return llvm::cast<VectorType>(unwrap(type)).getScalableDims()[dim];
+}
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c6425f80a8bce9c..3a2bdb9bfc93334 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -743,13 +743,27 @@ static int printBuiltinTypes(MlirContext ctx) {
fprintf(stderr, "\n");
// CHECK: vector<2x3xf32>
+ // Scalable vector type: reuses the 2x3 shape, marking only dim 1 scalable.
+ bool scalable[] = {false, true};
+ MlirType scalableVector = mlirVectorTypeGetScalable(
+ sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+ if (!mlirTypeIsAVector(scalableVector))
+ return 16;
+ if (!mlirVectorTypeIsScalable(scalableVector) ||
+ mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+ !mlirVectorTypeIsDimScalable(scalableVector, 1))
+ return 17;
+ mlirTypeDump(scalableVector);
+ fprintf(stderr, "\n");
+ // CHECK: vector<2x[3]xf32>
+
// Ranked tensor type.
MlirType rankedTensor = mlirRankedTensorTypeGet(
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
if (!mlirTypeIsATensor(rankedTensor) ||
!mlirTypeIsARankedTensor(rankedTensor) ||
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
- return 16;
+ return 18;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<2x3xf32>
@@ -759,7 +773,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATensor(unrankedTensor) ||
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
mlirShapedTypeHasRank(unrankedTensor))
- return 17;
+ return 19;
mlirTypeDump(unrankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<*xf32>
@@ -770,7 +784,7 @@ static int printBuiltinTypes(MlirContext ctx) {
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
if (!mlirTypeIsAMemRef(memRef) ||
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
- return 18;
+ return 20;
mlirTypeDump(memRef);
fprintf(stderr, "\n");
// CHECK: memref<2x3xf32, 2>
@@ -782,7 +796,7 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirTypeIsAMemRef(unrankedMemRef) ||
!mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
memSpace4))
- return 19;
+ return 21;
mlirTypeDump(unrankedMemRef);
fprintf(stderr, "\n");
// CHECK: memref<*xf32, 4>
@@ -793,7 +807,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
- return 20;
+ return 22;
mlirTypeDump(tuple);
fprintf(stderr, "\n");
// CHECK: tuple<memref<*xf32, 4>, f32>
@@ -805,16 +819,16 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirIntegerTypeGet(ctx, 64)};
MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
if (mlirFunctionTypeGetNumInputs(funcType) != 2)
- return 21;
+ return 23;
if (mlirFunctionTypeGetNumResults(funcType) != 3)
- return 22;
+ return 24;
if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
!mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
- return 23;
+ return 25;
if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
!mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
!mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
- return 24;
+ return 26;
mlirTypeDump(funcType);
fprintf(stderr, "\n");
// CHECK: (index, i1) -> (i16, i32, i64)
@@ -829,7 +843,7 @@ static int printBuiltinTypes(MlirContext ctx) {
!mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
namespace) ||
!mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
- return 25;
+ return 27;
mlirTypeDump(opaque);
fprintf(stderr, "\n");
// CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..e2344794c839a3a 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
none = NoneType.get()
try:
- vector_invalid = VectorType.get(shape, none)
+ VectorType.get(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,32 @@ def testVectorType():
else:
print("Exception not produced")
+ # Omitting `scalable` yields a vector with no scalable dimensions.
+ non_scalable = VectorType.get(shape, f32)
+ assert not non_scalable.scalable
+ assert non_scalable.scalable_dims == [False, False]
+
+ scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+ scalable_2 = VectorType.get([2, 3, 4],
+ f32,
+ scalable=[True, False, True])
+ assert scalable_1.scalable
+ assert scalable_2.scalable
+ assert scalable_1.scalable_dims == [False, True]
+ assert scalable_2.scalable_dims == [True, False, True]
+ # CHECK: scalable 1: vector<2x[3]xf32>
+ print("scalable 1:", scalable_1)
+ # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+ print("scalable 2:", scalable_2)
+
+ try:
+ VectorType.get(shape, f32, scalable=[False, True, True])
+ except ValueError as e:
+ # CHECK: Expected len(scalable) == len(shape).
+ print(e)
+ else:
+ print("Exception not produced")
+
# CHECK-LABEL: TEST: testRankedTensorType
@run
``````````
</details>
https://github.com/llvm/llvm-project/pull/71050
More information about the Mlir-commits
mailing list