[Mlir-commits] [mlir] [MLIR, Python] Support converting boolean numpy arrays to and from mlir attributes (unrevert) (PR #115481)
Kasper Nielsen
llvmlistbot at llvm.org
Fri Nov 8 07:45:37 PST 2024
https://github.com/kasper0406 updated https://github.com/llvm/llvm-project/pull/115481
>From b587e28266bd0ea44e724d555a5e786e797f145a Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 14:12:46 +0100
Subject: [PATCH 1/4] [MLIR,Python] Support converting boolean numpy arrays to
and from mlir attributes
This reverts commit 0a68171b3c67503f7143856580f1b22a93ef566e (the earlier revert of this change), re-landing it.
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 278 ++++++++++++++--------
mlir/test/python/ir/array_attributes.py | 72 ++++++
2 files changed, 253 insertions(+), 97 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index ead81a76c0538d..c8883c0d8270a2 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -13,6 +13,7 @@
#include "IRModule.h"
#include "PybindUtils.h"
+#include <pybind11/numpy.h>
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute
throw py::error_already_set();
}
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
- SmallVector<int64_t> shape;
- if (explicitShape) {
- shape.append(explicitShape->begin(), explicitShape->end());
- } else {
- shape.append(view.shape, view.shape + view.ndim);
- }
- MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();
-
- // Detect format codes that are suitable for bulk loading. This includes
- // all byte aligned integer and floating point types up to 8 bytes.
- // Notably, this excludes, bool (which needs to be bit-packed) and
- // other exotics which do not have a direct representation in the buffer
- // protocol (i.e. complex, etc).
- std::optional<MlirType> bulkLoadElementType;
- if (explicitType) {
- bulkLoadElementType = *explicitType;
- } else {
- std::string_view format(view.format);
- if (format == "f") {
- // f32
- assert(view.itemsize == 4 && "mismatched array itemsize");
- bulkLoadElementType = mlirF32TypeGet(context);
- } else if (format == "d") {
- // f64
- assert(view.itemsize == 8 && "mismatched array itemsize");
- bulkLoadElementType = mlirF64TypeGet(context);
- } else if (format == "e") {
- // f16
- assert(view.itemsize == 2 && "mismatched array itemsize");
- bulkLoadElementType = mlirF16TypeGet(context);
- } else if (isSignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeSignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeSignedGet(context, 16);
- }
- } else if (isUnsignedIntegerFormat(format)) {
- if (view.itemsize == 4) {
- // unsigned i32
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- } else if (view.itemsize == 8) {
- // unsigned i64
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- } else if (view.itemsize == 1) {
- // i8
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 8)
- : mlirIntegerTypeUnsignedGet(context, 8);
- } else if (view.itemsize == 2) {
- // i16
- bulkLoadElementType = signless
- ? mlirIntegerTypeGet(context, 16)
- : mlirIntegerTypeUnsignedGet(context, 16);
- }
- }
- if (!bulkLoadElementType) {
- throw std::invalid_argument(
- std::string("unimplemented array format conversion from format: ") +
- std::string(format));
- }
- }
-
- MlirType shapedType;
- if (mlirTypeIsAShaped(*bulkLoadElementType)) {
- if (explicitShape) {
- throw std::invalid_argument("Shape can only be specified explicitly "
- "when the type is not a shaped type.");
- }
- shapedType = *bulkLoadElementType;
- } else {
- shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
- *bulkLoadElementType, encodingAttr);
- }
- size_t rawBufferSize = view.len;
- MlirAttribute attr =
- mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+ MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
+ explicitShape, context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute
// unsigned i16
return bufferInfo<uint16_t>(shapedType);
}
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 1) {
+ // i1 / bool
+ // We can not send the buffer directly back to Python, because the i1
+ // values are bitpacked within MLIR. We call numpy's unpackbits function
+ // to convert the bytes.
+ return getBooleanBufferFromBitpackedAttribute();
}
// TODO: Currently crashes the program.
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute
code == 'q';
}
+ // Returns the shaped type for a DenseElementsAttr built from `view`.
+ // If `bulkLoadElementType` is already a shaped type it is returned as-is, and
+ // supplying `explicitShape` as well is rejected with std::invalid_argument.
+ // Otherwise it becomes the element type of a ranked tensor whose shape is
+ // `explicitShape` when given, or the buffer's own shape (`view.shape`,
+ // `view.ndim`) when not.
+ // NOTE(review): `bulkLoadElementType` is dereferenced without a check, so
+ // callers must always pass an engaged optional.
+ static MlirType
+ getShapedType(std::optional<MlirType> bulkLoadElementType,
+ std::optional<std::vector<int64_t>> explicitShape,
+ Py_buffer &view) {
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(view.shape, view.shape + view.ndim);
+ }
+
+ if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+ if (explicitShape) {
+ throw std::invalid_argument("Shape can only be specified explicitly "
+ "when the type is not a shaped type.");
+ }
+ return *bulkLoadElementType;
+ } else {
+ // No encoding attribute is attached to the resulting tensor type.
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
+ return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+ *bulkLoadElementType, encodingAttr);
+ }
+ }
+
+ // Creates a DenseElementsAttr from the contents of a Python buffer.
+ // When `explicitType` is given it is used directly and the buffer format is
+ // not inspected (so boolean buffers are not bit-packed on that path).
+ // Otherwise the element type is inferred from the buffer's struct format
+ // code: f32 ("f"), f64 ("d"), f16 ("e"), and 1/2/4/8-byte signed or unsigned
+ // integers (signless instead when `signless` is true). Boolean ("?") buffers
+ // are delegated to `getBitpackedAttributeFromBooleanBuffer`. Any other format
+ // throws std::invalid_argument. The attribute produced by
+ // `mlirDenseElementsAttrRawBufferGet` is returned unchecked and may be null.
+ static MlirAttribute getAttributeFromBuffer(
+ Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+ std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes exotic types which do not have a direct
+ // representation in the buffer protocol (e.g. complex).
+ std::optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else {
+ std::string_view format(view.format);
+ if (format == "f") {
+ // f32
+ assert(view.itemsize == 4 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF32TypeGet(context);
+ } else if (format == "d") {
+ // f64
+ assert(view.itemsize == 8 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (format == "e") {
+ // f16
+ assert(view.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
+ } else if (format == "?") {
+ // i1
+ // The i1 type needs to be bit-packed, so we handle it separately
+ return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+ context);
+ } else if (isSignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
+ }
+ } else if (isUnsignedIntegerFormat(format)) {
+ if (view.itemsize == 4) {
+ // unsigned i32
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
+ } else if (view.itemsize == 8) {
+ // unsigned i64
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (view.itemsize == 1) {
+ // unsigned i8
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (view.itemsize == 2) {
+ // unsigned i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
+ }
+ }
+ // Integer formats with an unsupported itemsize also end up here.
+ if (!bulkLoadElementType) {
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ std::string(format));
+ }
+ }
+
+ MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+ return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+ }
+
+ // There is a complication for boolean numpy arrays, as numpy represents them
+ // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
+ // per byte.
+ //
+ // Builds an i1 DenseElementsAttr from a buffer of booleans by bit-packing it
+ // with `numpy.packbits`. The tensor shape is `explicitShape` when given,
+ // otherwise the buffer's own shape (both resolved by `getShapedType`).
+ // Throws py::type_error on big-endian hosts. The attribute produced by
+ // `mlirDenseElementsAttrRawBufferGet` is returned unchecked and may be null.
+ static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+ Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+ MlirContext &context) {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // Given we have no good way of testing the behavior on big-endian systems
+ // we will throw
+ throw py::type_error("Constructing a bit-packed MLIR attribute is "
+ "unsupported on big-endian systems");
+ }
+
+ // Load the one-byte-per-boolean data into a flat uint8 numpy array, then
+ // pack it; `bitorder="little"` places element i in bit (i % 8) of packed
+ // byte (i / 8).
+ py::array_t<uint8_t> unpackedArray(view.len,
+ static_cast<uint8_t *>(view.buf));
+
+ py::module numpy = py::module::import("numpy");
+ py::object packbits_func = numpy.attr("packbits");
+ py::object packed_booleans =
+ packbits_func(unpackedArray, "bitorder"_a = "little");
+ py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
+
+ MlirType bitpackedType =
+ getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+ return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+ pythonBuffer.ptr);
+ }
+
+ // This does the opposite transformation of
+ // `getBitpackedAttributeFromBooleanBuffer`: it unpacks this attribute's
+ // bit-packed i1 storage (8 booleans per byte) into one byte per boolean with
+ // `numpy.unpackbits` and returns a buffer describing the booleans.
+ // Throws py::type_error on big-endian hosts.
+ py::buffer_info getBooleanBufferFromBitpackedAttribute() {
+ if (llvm::endianness::native != llvm::endianness::little) {
+ // Given we have no good way of testing the behavior on big-endian systems
+ // we will throw
+ throw py::type_error("Constructing a numpy array from a MLIR attribute "
+ "is unsupported on big-endian systems");
+ }
+
+ int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+ int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
+ uint8_t *bitpackedData = static_cast<uint8_t *>(
+ const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ // A splat i1 attribute keeps only a single byte of raw storage regardless
+ // of its element count, so reading `numBitpackedBytes` bytes from it would
+ // overrun the storage once there are more than 8 elements. Expand the splat
+ // into a fully bit-packed buffer of the expected size instead.
+ std::vector<uint8_t> splatBytes;
+ if (mlirDenseElementsAttrIsSplat(*this)) {
+ uint8_t splatByte =
+ mlirDenseElementsAttrGetBoolSplatValue(*this) ? 0xFF : 0x00;
+ splatBytes.assign(static_cast<size_t>(numBitpackedBytes), splatByte);
+ bitpackedData = splatBytes.data();
+ }
+ py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
+
+ py::module numpy = py::module::import("numpy");
+ py::object unpackbits_func = numpy.attr("unpackbits");
+ py::object unpacked_booleans =
+ unpackbits_func(packedArray, "bitorder"_a = "little");
+ py::buffer_info pythonBuffer =
+ unpacked_booleans.cast<py::buffer>().request();
+
+ MlirType shapedType = mlirAttributeGetType(*this);
+ return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
+ }
+
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
const char *explicitFormat = nullptr) {
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the data for the buffer_info.
- // Buffer is configured for read-only access below.
+ // Buffer is configured for read-only access inside the `bufferInfo` call.
Type *data = static_cast<Type *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+ return bufferInfo<Type>(shapedType, data, explicitFormat);
+ }
+
+ template <typename Type>
+ py::buffer_info bufferInfo(MlirType shapedType, Type *data,
+ const char *explicitFormat = nullptr) {
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 2bc403aace8348..256a69a939658d 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -326,6 +326,78 @@ def testGetDenseElementsF64():
print(np.array(attr))
+### 1 bit/boolean integer arrays (numpy bool_ <-> bit-packed i1)
+# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
+@run
+def testGetDenseElementsI1Signless():
+ with Context():
+ array = np.array([True], dtype=np.bool_)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<true> : tensor<1xi1>
+ print(attr)
+ # CHECK{LITERAL}: [ True]
+ print(np.array(attr))
+
+ array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
+ attr = DenseElementsAttr.get(array)
+ # CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
+ print(attr)
+ # CHECK{LITERAL}: [[ True False True]
+ # CHECK{LITERAL}: [ True True False]]
+ print(np.array(attr))
+
+ array = np.array(
+ [[True, True, False, False], [True, False, True, False]], dtype=np.bool_
+ )
+ attr = DenseElementsAttr.get(array)
+ # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
+ print(attr)
+ # CHECK{LITERAL}: [[ True True False False]
+ # CHECK{LITERAL}: [ True False True False]]
+ print(np.array(attr))
+
+ array = np.array(
+ [
+ [True, True, False, False],
+ [True, False, True, False],
+ [False, False, False, False],
+ [True, True, True, True],
+ [True, False, False, True],
+ ],
+ dtype=np.bool_,
+ )
+ attr = DenseElementsAttr.get(array)
+ # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
+ print(attr)
+ # CHECK{LITERAL}: [[ True True False False]
+ # CHECK{LITERAL}: [ True False True False]
+ # CHECK{LITERAL}: [False False False False]
+ # CHECK{LITERAL}: [ True True True True]
+ # CHECK{LITERAL}: [ True False False True]]
+ print(np.array(attr))
+
+ array = np.array(
+ [
+ [True, True, False, False, True, True, False, False, False],
+ [False, False, False, True, False, True, True, False, True],
+ ],
+ dtype=np.bool_,
+ )
+ attr = DenseElementsAttr.get(array)
+ # CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
+ print(attr)
+ # CHECK{LITERAL}: [[ True True False False True True False False False]
+ # CHECK{LITERAL}: [False False False True False True True False True]]
+ print(np.array(attr))
+
+ array = np.array([], dtype=np.bool_)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<> : tensor<0xi1>
+ print(attr)
+ # CHECK{LITERAL}: []
+ print(np.array(attr))
+
### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
>From 75d34c9f939c855ea4cf58643f5ea65dc3df2996 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 14:09:14 +0100
Subject: [PATCH 2/4] Fix Python buffer lifetime issues
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 37 +++++++++++------------
1 file changed, 17 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c8883c0d8270a2..a2e06c5346020a 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1056,13 +1056,17 @@ class PyDenseElementsAttribute
static_cast<uint8_t *>(view.buf));
py::module numpy = py::module::import("numpy");
- py::object packbits_func = numpy.attr("packbits");
- py::object packed_booleans =
- packbits_func(unpackedArray, "bitorder"_a = "little");
- py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
+ py::object packbitsFunc = numpy.attr("packbits");
+ py::object packedBooleans =
+ packbitsFunc(unpackedArray, "bitorder"_a = "little");
+ py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
MlirType bitpackedType =
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+ assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+ // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
+ // packedBooleans, hence the MlirAttribute will remain valid even when
+ // packedBooleans get reclaimed by the end of the function.
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
pythonBuffer.ptr);
}
@@ -1084,29 +1088,22 @@ class PyDenseElementsAttribute
py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
py::module numpy = py::module::import("numpy");
- py::object unpackbits_func = numpy.attr("unpackbits");
- py::object unpacked_booleans =
- unpackbits_func(packedArray, "bitorder"_a = "little");
- py::buffer_info pythonBuffer =
- unpacked_booleans.cast<py::buffer>().request();
-
- MlirType shapedType = mlirAttributeGetType(*this);
- return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
+ py::object unpackbitsFunc = numpy.attr("unpackbits");
+ py::object unpackedBooleans =
+ unpackbitsFunc(packedArray, "bitorder"_a = "little");
+ py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
+
+ // Make sure the returned py::buffer_view claims ownership of the data in
+ // `pythonBuffer` so it remains valid when Python reads it
+ return pythonBuffer.request();
}
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
const char *explicitFormat = nullptr) {
- // Prepare the data for the buffer_info.
- // Buffer is configured for read-only access inside the `bufferInfo` call.
+ // Buffer is configured for read-only access below
Type *data = static_cast<Type *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- return bufferInfo<Type>(shapedType, data, explicitFormat);
- }
-
- template <typename Type>
- py::buffer_info bufferInfo(MlirType shapedType, Type *data,
- const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
>From b11026a9d36307777b2ddc9f53e9a463394a9700 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 14:18:23 +0100
Subject: [PATCH 3/4] Minor refactoring
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index a2e06c5346020a..dba8220b5543a1 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1101,10 +1101,11 @@ class PyDenseElementsAttribute
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
const char *explicitFormat = nullptr) {
- // Buffer is configured for read-only access below
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ // Prepare the data for the buffer_info.
+ // Buffer is configured for read-only access below.
Type *data = static_cast<Type *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
- intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
>From e480fa8b6661a11a925d2a5c933a8d686f1ccf98 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 16:45:00 +0100
Subject: [PATCH 4/4] Fix the boolean array padding, type and shape
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 22 ++++++++++++++++++++--
1 file changed, 20 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dba8220b5543a1..417c66b9165e3b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1089,12 +1089,30 @@ class PyDenseElementsAttribute
py::module numpy = py::module::import("numpy");
py::object unpackbitsFunc = numpy.attr("unpackbits");
- py::object unpackedBooleans =
+ py::object equalFunc = numpy.attr("equal");
+ py::object reshapeFunc = numpy.attr("reshape");
+ py::array unpackedBooleans =
unpackbitsFunc(packedArray, "bitorder"_a = "little");
- py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
+
+ // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
+ // We need to:
+ // 1. Slice away the padded bits
+ // 2. Make the boolean array have the correct shape
+ // 3. Convert the array to a boolean array
+ unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
+ unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+ std::vector<intptr_t> shape;
+ MlirType shapedType = mlirAttributeGetType(*this);
+ intptr_t rank = mlirShapedTypeGetRank(shapedType);
+ for (intptr_t i = 0; i < rank; ++i) {
+ shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+ }
+ unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
// Make sure the returned py::buffer_view claims ownership of the data in
// `pythonBuffer` so it remains valid when Python reads it
+ py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
return pythonBuffer.request();
}
More information about the Mlir-commits
mailing list