[Mlir-commits] [mlir] [MLIR, Python] Support converting boolean numpy arrays to and from mlir attributes (PR #113064)
Kasper Nielsen
llvmlistbot at llvm.org
Wed Oct 23 13:03:24 PDT 2024
================
@@ -1016,14 +930,177 @@ class PyDenseElementsAttribute
code == 'q';
}
+ /// Resolves the shaped MlirType used for the attribute built from `view`.
+ ///
+ /// If `bulkLoadElementType` is itself a shaped type it is returned as-is,
+ /// and passing `explicitShape` alongside it is rejected. Otherwise a ranked
+ /// tensor type is created with `bulkLoadElementType` as its element type
+ /// and a null encoding attribute.
+ ///
+ /// \param bulkLoadElementType Element type, or an already shaped type.
+ ///        NOTE(review): dereferenced unconditionally below, with no
+ ///        engaged check in this function - confirm every caller passes a
+ ///        value.
+ /// \param explicitShape Optional shape; when given it is used instead of
+ ///        the dimensions reported by `view`.
+ /// \param view Buffer whose `shape` / `ndim` supply the dimensions when no
+ ///        explicit shape is given.
+ /// \throws std::invalid_argument if `explicitShape` is given together with
+ ///         a shaped `bulkLoadElementType`.
+ static MlirType
+ getShapedType(std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ Py_buffer &view) {
+ // The explicit shape, when present, takes precedence over the buffer's
+ // own dimensions.
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(view.shape, view.shape + view.ndim);
+ }
+
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+ // A shaped type is returned unchanged (the `shape` computed above is not
+ // used on this path), so an additional explicit shape is rejected.
+ if (explicitShape) {
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
+ }
+ return *bulkLoadElementType;
+ } else {
+ // Plain element type: wrap it in a ranked tensor of the computed shape,
+ // passing a null attribute as the encoding.
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+ }
+ }
+
+ static MlirAttribute getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes exotics types which do not have a direct
+ // representation in the buffer protocol (i.e. complex, etc).
+ std::optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (format == "?") {
----------------
kasper0406 wrote:
I believe it is because of the bit-packing. If it started returning true, I think it would no longer be possible to distinguish between the i8 (and u8) cases and the boolean case, as they would all have an itemsize of 1.
So the format would have to be compared to "?" explicitly anyway.
https://github.com/llvm/llvm-project/pull/113064
More information about the Mlir-commits
mailing list