[Mlir-commits] [mlir] 96dadc9 - [mlir] support scalable vectors in python bindings (#71050)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 6 04:15:01 PST 2023
Author: Oleksandr "Alex" Zinenko
Date: 2023-11-06T13:14:56+01:00
New Revision: 96dadc9fc83dddf450e42ea5e9c3fd2616761830
URL: https://github.com/llvm/llvm-project/commit/96dadc9fc83dddf450e42ea5e9c3fd2616761830
DIFF: https://github.com/llvm/llvm-project/commit/96dadc9fc83dddf450e42ea5e9c3fd2616761830.diff
LOG: [mlir] support scalable vectors in python bindings (#71050)
The scalable dimension functionality was added to the vector type after
the bindings for it were defined, but the bindings were never
updated. Fix that.
Added:
Modified:
mlir/include/mlir-c/BuiltinTypes.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10efbde923..1fd5691f41eec35 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
const int64_t *shape,
MlirType elementType);
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions. A subset of dimensions may be marked as scalable via the
+/// corresponding flag list, which is expected to have as many entries as the
+/// rank of the vector. The vector is created in the same context as the element
+/// type.
+MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType);
+
+/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType
+/// on illegal arguments, emitting appropriate diagnostics at "loc".
+MLIR_CAPI_EXPORTED
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType);
+
+/// Checks whether the given vector type is scalable, i.e., has at least one
+/// scalable dimension.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type);
+
+/// Checks whether dimension "dim" (in [0, rank)) of the vector is scalable.
+MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type,
+ intptr_t dim);
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked Tensor type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7ccfbea542f5c7..483db673f989e6b 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
- c.def_static(
- "get",
- [](std::vector<int64_t> shape, PyType &elementType,
- DefaultingPyLocation loc) {
- PyMlirContext::ErrorCapture errors(loc->getContext());
- MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
- elementType);
- if (mlirTypeIsNull(t))
- throw MLIRError("Invalid type", errors.take());
- return PyVectorType(elementType.getContext(), t);
- },
- py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
- "Create a vector type");
+ c.def_static("get", &PyVectorType::get, py::arg("shape"),
+ py::arg("elementType"), py::kw_only(), // Remaining args are keyword-only.
+ py::arg("scalable") = py::none(), // Per-dimension bool flags.
+ py::arg("scalable_dims") = py::none(), // Indices of scalable dims.
+ py::arg("loc") = py::none(), "Create a vector type")
+ .def_property_readonly(
+ "scalable", // True if at least one dimension is scalable.
+ [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+ .def_property_readonly("scalable_dims", [](MlirType self) { // One bool per dim.
+ std::vector<bool> scalableDims;
+ size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+ scalableDims.reserve(rank);
+ for (size_t i = 0; i < rank; ++i)
+ scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+ return scalableDims;
+ });
+ }
+
+private:
+ static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, // Backs VectorType.get.
+ std::optional<py::list> scalable,
+ std::optional<std::vector<int64_t>> scalableDims,
+ DefaultingPyLocation loc) {
+ if (scalable && scalableDims) { // Two spellings of the same per-dim flags.
+ throw py::value_error("'scalable' and 'scalable_dims' kwargs "
+ "are mutually exclusive.");
+ }
+
+ PyMlirContext::ErrorCapture errors(loc->getContext()); // Diagnostics from *Checked.
+ MlirType type; // Null if construction fails; checked below.
+ if (scalable) {
+ if (scalable->size() != shape.size())
+ throw py::value_error("Expected len(scalable) == len(shape).");
+
+ SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( // One flag per dim.
+ *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else if (scalableDims) {
+ SmallVector<bool> scalableDimFlags(shape.size(), false); // Unlisted dims stay fixed.
+ for (int64_t dim : *scalableDims) {
+ if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) // No negative wrap-around.
+ throw py::value_error("Scalable dimension index out of bounds.");
+ scalableDimFlags[dim] = true;
+ }
+ type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+ scalableDimFlags.data(),
+ elementType);
+ } else {
+ type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+ elementType);
+ }
+ if (mlirTypeIsNull(type)) // Surface captured diagnostics to Python.
+ throw MLIRError("Invalid type", errors.take());
+ return PyVectorType(elementType.getContext(), type);
}
};
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 50266b4b5233235..6e645188dac8616 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
unwrap(elementType)));
}
+MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
+ const bool *scalable, MlirType elementType) {
+ return wrap(VectorType::get(
+ llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
+ const int64_t *shape,
+ const bool *scalable,
+ MlirType elementType) {
+ return wrap(VectorType::getChecked(
+ unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+}
+
+bool mlirVectorTypeIsScalable(MlirType type) {
+ return llvm::cast<VectorType>(unwrap(type)).isScalable();
+}
+
+bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
+ return llvm::cast<VectorType>(unwrap(type)).getScalableDims()[dim];
+}
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 8d5dcbf62e85e2b..315458a08b613e0 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -746,13 +746,27 @@ static int printBuiltinTypes(MlirContext ctx) {
fprintf(stderr, "\n");
// CHECK: vector<2x3xf32>
+ // Scalable vector type.
+ bool scalable[] = {false, true};
+ MlirType scalableVector = mlirVectorTypeGetScalable(
+ sizeof(shape) / sizeof(int64_t), shape, scalable, f32);
+ if (!mlirTypeIsAVector(scalableVector))
+ return 16;
+ if (!mlirVectorTypeIsScalable(scalableVector) ||
+ mlirVectorTypeIsDimScalable(scalableVector, 0) ||
+ !mlirVectorTypeIsDimScalable(scalableVector, 1))
+ return 17;
+ mlirTypeDump(scalableVector);
+ fprintf(stderr, "\n");
+ // CHECK: vector<2x[3]xf32>
+
// Ranked tensor type.
MlirType rankedTensor = mlirRankedTensorTypeGet(
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
if (!mlirTypeIsATensor(rankedTensor) ||
!mlirTypeIsARankedTensor(rankedTensor) ||
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
- return 16;
+ return 18;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<2x3xf32>
@@ -762,7 +776,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATensor(unrankedTensor) ||
!mlirTypeIsAUnrankedTensor(unrankedTensor) ||
mlirShapedTypeHasRank(unrankedTensor))
- return 17;
+ return 19;
mlirTypeDump(unrankedTensor);
fprintf(stderr, "\n");
// CHECK: tensor<*xf32>
@@ -773,7 +787,7 @@ static int printBuiltinTypes(MlirContext ctx) {
f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
if (!mlirTypeIsAMemRef(memRef) ||
!mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
- return 18;
+ return 20;
mlirTypeDump(memRef);
fprintf(stderr, "\n");
// CHECK: memref<2x3xf32, 2>
@@ -785,7 +799,7 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirTypeIsAMemRef(unrankedMemRef) ||
!mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
memSpace4))
- return 19;
+ return 21;
mlirTypeDump(unrankedMemRef);
fprintf(stderr, "\n");
// CHECK: memref<*xf32, 4>
@@ -796,7 +810,7 @@ static int printBuiltinTypes(MlirContext ctx) {
if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
!mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
- return 20;
+ return 22;
mlirTypeDump(tuple);
fprintf(stderr, "\n");
// CHECK: tuple<memref<*xf32, 4>, f32>
@@ -808,16 +822,16 @@ static int printBuiltinTypes(MlirContext ctx) {
mlirIntegerTypeGet(ctx, 64)};
MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
if (mlirFunctionTypeGetNumInputs(funcType) != 2)
- return 21;
+ return 23;
if (mlirFunctionTypeGetNumResults(funcType) != 3)
- return 22;
+ return 24;
if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
!mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
- return 23;
+ return 25;
if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
!mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
!mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
- return 24;
+ return 26;
mlirTypeDump(funcType);
fprintf(stderr, "\n");
// CHECK: (index, i1) -> (i16, i32, i64)
@@ -832,7 +846,7 @@ static int printBuiltinTypes(MlirContext ctx) {
!mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
namespace) ||
!mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
- return 25;
+ return 27;
mlirTypeDump(opaque);
fprintf(stderr, "\n");
// CHECK: !dialect.type
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 672418b5383ae45..4c891a2ca2ab9a2 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -300,7 +300,7 @@ def testVectorType():
none = NoneType.get()
try:
- vector_invalid = VectorType.get(shape, none)
+ VectorType.get(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
@@ -308,6 +308,46 @@ def testVectorType():
else:
print("Exception not produced")
+ scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
+ scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
+ assert scalable_1.scalable
+ assert scalable_2.scalable
+ assert scalable_1.scalable_dims == [False, True]
+ assert scalable_2.scalable_dims == [True, False, True]
+ # CHECK: scalable 1: vector<2x[3]xf32>
+ print("scalable 1: ", scalable_1)
+ # CHECK: scalable 2: vector<[2]x3x[4]xf32>
+ print("scalable 2: ", scalable_2)
+
+ scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
+ scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
+ assert scalable_3 == scalable_1
+ assert scalable_4 == scalable_2
+
+ try:
+ VectorType.get(shape, f32, scalable=[False, True, True])
+ except ValueError as e:
+ # CHECK: Expected len(scalable) == len(shape).
+ print(e)
+ else:
+ print("Exception not produced")
+
+ try:
+ VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
+ except ValueError as e:
+ # CHECK: kwargs are mutually exclusive.
+ print(e)
+ else:
+ print("Exception not produced")
+
+ try:
+ VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[42])
+ except ValueError as e:
+ # CHECK: Scalable dimension index out of bounds.
+ print(e)
+ else:
+ print("Exception not produced")
+
# CHECK-LABEL: TEST: testRankedTensorType
@run
@@ -337,7 +377,6 @@ def testRankedTensorType():
assert RankedTensorType.get(shape, f32).encoding is None
-
# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
More information about the Mlir-commits
mailing list