[Mlir-commits] [mlir] [MLIR, Python] Support converting boolean numpy arrays to and from mlir attributes (unrevert) (PR #115481)

Kasper Nielsen llvmlistbot at llvm.org
Fri Nov 8 07:45:37 PST 2024


https://github.com/kasper0406 updated https://github.com/llvm/llvm-project/pull/115481

>From b587e28266bd0ea44e724d555a5e786e797f145a Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 14:12:46 +0100
Subject: [PATCH 1/4] [MLIR,Python] Support converting boolean numpy arrays to
 and from mlir attributes

This reverts commit 0a68171b3c67503f7143856580f1b22a93ef566e.
---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 278 ++++++++++++++--------
 mlir/test/python/ir/array_attributes.py   |  72 ++++++
 2 files changed, 253 insertions(+), 97 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index ead81a76c0538d..c8883c0d8270a2 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -13,6 +13,7 @@
 #include "IRModule.h"
 
 #include "PybindUtils.h"
+#include <pybind11/numpy.h>
 
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/raw_ostream.h"
@@ -757,103 +758,10 @@ class PyDenseElementsAttribute
       throw py::error_already_set();
     }
     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-    SmallVector<int64_t> shape;
-    if (explicitShape) {
-      shape.append(explicitShape->begin(), explicitShape->end());
-    } else {
-      shape.append(view.shape, view.shape + view.ndim);
-    }
 
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
     MlirContext context = contextWrapper->get();
-
-    // Detect format codes that are suitable for bulk loading. This includes
-    // all byte aligned integer and floating point types up to 8 bytes.
-    // Notably, this excludes, bool (which needs to be bit-packed) and
-    // other exotics which do not have a direct representation in the buffer
-    // protocol (i.e. complex, etc).
-    std::optional<MlirType> bulkLoadElementType;
-    if (explicitType) {
-      bulkLoadElementType = *explicitType;
-    } else {
-      std::string_view format(view.format);
-      if (format == "f") {
-        // f32
-        assert(view.itemsize == 4 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF32TypeGet(context);
-      } else if (format == "d") {
-        // f64
-        assert(view.itemsize == 8 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF64TypeGet(context);
-      } else if (format == "e") {
-        // f16
-        assert(view.itemsize == 2 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF16TypeGet(context);
-      } else if (isSignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeSignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeSignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                         : mlirIntegerTypeSignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeSignedGet(context, 16);
-        }
-      } else if (isUnsignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // unsigned i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeUnsignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // unsigned i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeUnsignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 8)
-                                    : mlirIntegerTypeUnsignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeUnsignedGet(context, 16);
-        }
-      }
-      if (!bulkLoadElementType) {
-        throw std::invalid_argument(
-            std::string("unimplemented array format conversion from format: ") +
-            std::string(format));
-      }
-    }
-
-    MlirType shapedType;
-    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-      if (explicitShape) {
-        throw std::invalid_argument("Shape can only be specified explicitly "
-                                    "when the type is not a shaped type.");
-      }
-      shapedType = *bulkLoadElementType;
-    } else {
-      shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                           *bulkLoadElementType, encodingAttr);
-    }
-    size_t rawBufferSize = view.len;
-    MlirAttribute attr =
-        mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+    MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
+                                                explicitShape, context);
     if (mlirAttributeIsNull(attr)) {
       throw std::invalid_argument(
           "DenseElementsAttr could not be constructed from the given buffer. "
@@ -963,6 +871,13 @@ class PyDenseElementsAttribute
         // unsigned i16
         return bufferInfo<uint16_t>(shapedType);
       }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 1) {
+      // i1 / bool
+      // We cannot send the buffer directly back to Python, because the i1
+      // values are bitpacked within MLIR. We call numpy's unpackbits function
+      // to convert the bytes.
+      return getBooleanBufferFromBitpackedAttribute();
     }
 
     // TODO: Currently crashes the program.
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute
            code == 'q';
   }
 
+  static MlirType
+  getShapedType(std::optional<MlirType> bulkLoadElementType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                Py_buffer &view) {
+    SmallVector<int64_t> shape;
+    if (explicitShape) {
+      shape.append(explicitShape->begin(), explicitShape->end());
+    } else {
+      shape.append(view.shape, view.shape + view.ndim);
+    }
+
+    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+      if (explicitShape) {
+        throw std::invalid_argument("Shape can only be specified explicitly "
+                                    "when the type is not a shaped type.");
+      }
+      return *bulkLoadElementType;
+    } else {
+      MlirAttribute encodingAttr = mlirAttributeGetNull();
+      return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                     *bulkLoadElementType, encodingAttr);
+    }
+  }
+
+  static MlirAttribute getAttributeFromBuffer(
+      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+      std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
+    // Detect format codes that are suitable for bulk loading. This includes
+    // all byte aligned integer and floating point types up to 8 bytes.
+    // Notably, this excludes exotic types which do not have a direct
+    // representation in the buffer protocol (e.g. complex).
+    std::optional<MlirType> bulkLoadElementType;
+    if (explicitType) {
+      bulkLoadElementType = *explicitType;
+    } else {
+      std::string_view format(view.format);
+      if (format == "f") {
+        // f32
+        assert(view.itemsize == 4 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF32TypeGet(context);
+      } else if (format == "d") {
+        // f64
+        assert(view.itemsize == 8 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF64TypeGet(context);
+      } else if (format == "e") {
+        // f16
+        assert(view.itemsize == 2 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF16TypeGet(context);
+      } else if (format == "?") {
+        // i1
+        // The i1 type needs to be bit-packed, so we will handle it separately
+        return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+                                                      context);
+      } else if (isSignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeSignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeSignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                         : mlirIntegerTypeSignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeSignedGet(context, 16);
+        }
+      } else if (isUnsignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // unsigned i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeUnsignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // unsigned i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeUnsignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 8)
+                                    : mlirIntegerTypeUnsignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeUnsignedGet(context, 16);
+        }
+      }
+      if (!bulkLoadElementType) {
+        throw std::invalid_argument(
+            std::string("unimplemented array format conversion from format: ") +
+            std::string(format));
+      }
+    }
+
+    MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+  }
+
+  // There is a complication for boolean numpy arrays, as numpy represents them
+  // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
+  // per byte.
+  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+      MlirContext &context) {
+    if (llvm::endianness::native != llvm::endianness::little) {
+      // Given we have no good way of testing the behavior on big-endian systems
+      // we will throw
+      throw py::type_error("Constructing a bit-packed MLIR attribute is "
+                           "unsupported on big-endian systems");
+    }
+
+    py::array_t<uint8_t> unpackedArray(view.len,
+                                       static_cast<uint8_t *>(view.buf));
+
+    py::module numpy = py::module::import("numpy");
+    py::object packbits_func = numpy.attr("packbits");
+    py::object packed_booleans =
+        packbits_func(unpackedArray, "bitorder"_a = "little");
+    py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
+
+    MlirType bitpackedType =
+        getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+                                             pythonBuffer.ptr);
+  }
+
+  // This does the opposite transformation of
+  // `getBitpackedAttributeFromBooleanBuffer`
+  py::buffer_info getBooleanBufferFromBitpackedAttribute() {
+    if (llvm::endianness::native != llvm::endianness::little) {
+      // Given we have no good way of testing the behavior on big-endian systems
+      // we will throw
+      throw py::type_error("Constructing a numpy array from a MLIR attribute "
+                           "is unsupported on big-endian systems");
+    }
+
+    int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+    int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
+    uint8_t *bitpackedData = static_cast<uint8_t *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
+
+    py::module numpy = py::module::import("numpy");
+    py::object unpackbits_func = numpy.attr("unpackbits");
+    py::object unpacked_booleans =
+        unpackbits_func(packedArray, "bitorder"_a = "little");
+    py::buffer_info pythonBuffer =
+        unpacked_booleans.cast<py::buffer>().request();
+
+    MlirType shapedType = mlirAttributeGetType(*this);
+    return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
+  }
+
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
                              const char *explicitFormat = nullptr) {
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access below.
+    // Buffer is configured for read-only access inside the `bufferInfo` call.
     Type *data = static_cast<Type *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    return bufferInfo<Type>(shapedType, data, explicitFormat);
+  }
+
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType, Type *data,
+                             const char *explicitFormat = nullptr) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the shape for the buffer_info.
     SmallVector<intptr_t, 4> shape;
     for (intptr_t i = 0; i < rank; ++i)
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 2bc403aace8348..256a69a939658d 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -326,6 +326,78 @@ def testGetDenseElementsF64():
         print(np.array(attr))
 
 
+### 1 bit/boolean integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
+@run
+def testGetDenseElementsI1Signless():
+    with Context():
+        array = np.array([True], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<true> : tensor<1xi1>
+        print(attr)
+        # CHECK{LITERAL}: [ True]
+        print(np.array(attr))
+
+        array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
+        print(attr)
+        # CHECK{LITERAL}: [[ True False True]
+        # CHECK{LITERAL}:  [ True True False]]
+        print(np.array(attr))
+
+        array = np.array(
+            [[True, True, False, False], [True, False, True, False]], dtype=np.bool_
+        )
+        attr = DenseElementsAttr.get(array)
+        # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
+        print(attr)
+        # CHECK{LITERAL}: [[ True True False False]
+        # CHECK{LITERAL}:  [ True False True False]]
+        print(np.array(attr))
+
+        array = np.array(
+            [
+                [True, True, False, False],
+                [True, False, True, False],
+                [False, False, False, False],
+                [True, True, True, True],
+                [True, False, False, True],
+            ],
+            dtype=np.bool_,
+        )
+        attr = DenseElementsAttr.get(array)
+        # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
+        print(attr)
+        # CHECK{LITERAL}: [[ True True False False]
+        # CHECK{LITERAL}:  [ True False True False]
+        # CHECK{LITERAL}:  [False False False False]
+        # CHECK{LITERAL}:  [ True True True True]
+        # CHECK{LITERAL}:  [ True False False True]]
+        print(np.array(attr))
+
+        array = np.array(
+            [
+                [True, True, False, False, True, True, False, False, False],
+                [False, False, False, True, False, True, True, False, True],
+            ],
+            dtype=np.bool_,
+        )
+        attr = DenseElementsAttr.get(array)
+        # CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
+        print(attr)
+        # CHECK{LITERAL}: [[ True True False False True True False False False]
+        # CHECK{LITERAL}:  [False False False True False True True False True]]
+        print(np.array(attr))
+
+        array = np.array([], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<> : tensor<0xi1>
+        print(attr)
+        # CHECK{LITERAL}: []
+        print(np.array(attr))
+
+
 ### 16 bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI16Signless
 @run

>From 75d34c9f939c855ea4cf58643f5ea65dc3df2996 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 14:09:14 +0100
Subject: [PATCH 2/4] Fix python buffer lifetime issues

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 37 +++++++++++------------
 1 file changed, 17 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c8883c0d8270a2..a2e06c5346020a 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1056,13 +1056,17 @@ class PyDenseElementsAttribute
                                        static_cast<uint8_t *>(view.buf));
 
     py::module numpy = py::module::import("numpy");
-    py::object packbits_func = numpy.attr("packbits");
-    py::object packed_booleans =
-        packbits_func(unpackedArray, "bitorder"_a = "little");
-    py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
+    py::object packbitsFunc = numpy.attr("packbits");
+    py::object packedBooleans =
+        packbitsFunc(unpackedArray, "bitorder"_a = "little");
+    py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
 
     MlirType bitpackedType =
         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+    assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
+    // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
+    // packedBooleans, hence the MlirAttribute will remain valid even when
+    // packedBooleans is reclaimed at the end of the function.
     return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
                                              pythonBuffer.ptr);
   }
@@ -1084,29 +1088,22 @@ class PyDenseElementsAttribute
     py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
 
     py::module numpy = py::module::import("numpy");
-    py::object unpackbits_func = numpy.attr("unpackbits");
-    py::object unpacked_booleans =
-        unpackbits_func(packedArray, "bitorder"_a = "little");
-    py::buffer_info pythonBuffer =
-        unpacked_booleans.cast<py::buffer>().request();
-
-    MlirType shapedType = mlirAttributeGetType(*this);
-    return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
+    py::object unpackbitsFunc = numpy.attr("unpackbits");
+    py::object unpackedBooleans =
+        unpackbitsFunc(packedArray, "bitorder"_a = "little");
+    py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
+
+    // Make sure the returned py::buffer_view claims ownership of the data in
+    // `pythonBuffer` so it remains valid when Python reads it
+    return pythonBuffer.request();
   }
 
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
                              const char *explicitFormat = nullptr) {
-    // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access inside the `bufferInfo` call.
+    // Buffer is configured for read-only access below
     Type *data = static_cast<Type *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    return bufferInfo<Type>(shapedType, data, explicitFormat);
-  }
-
-  template <typename Type>
-  py::buffer_info bufferInfo(MlirType shapedType, Type *data,
-                             const char *explicitFormat = nullptr) {
     intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the shape for the buffer_info.
     SmallVector<intptr_t, 4> shape;

>From b11026a9d36307777b2ddc9f53e9a463394a9700 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 14:18:23 +0100
Subject: [PATCH 3/4] Minor refactoring

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index a2e06c5346020a..dba8220b5543a1 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1101,10 +1101,11 @@ class PyDenseElementsAttribute
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
                              const char *explicitFormat = nullptr) {
-    // Buffer is configured for read-only access below
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info.
+    // Buffer is configured for read-only access below.
     Type *data = static_cast<Type *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the shape for the buffer_info.
     SmallVector<intptr_t, 4> shape;
     for (intptr_t i = 0; i < rank; ++i)

>From e480fa8b6661a11a925d2a5c933a8d686f1ccf98 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Fri, 8 Nov 2024 16:45:00 +0100
Subject: [PATCH 4/4] Fix the boolean array padding, type and shape

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dba8220b5543a1..417c66b9165e3b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1089,12 +1089,30 @@ class PyDenseElementsAttribute
 
     py::module numpy = py::module::import("numpy");
     py::object unpackbitsFunc = numpy.attr("unpackbits");
-    py::object unpackedBooleans =
+    py::object equalFunc = numpy.attr("equal");
+    py::object reshapeFunc = numpy.attr("reshape");
+    py::array unpackedBooleans =
         unpackbitsFunc(packedArray, "bitorder"_a = "little");
-    py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
+
+    // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
+    // We need to:
+    //   1. Slice away the padded bits
+    //   2. Make the boolean array have the correct shape
+    //   3. Convert the array to a boolean array
+    unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
+    unpackedBooleans = equalFunc(unpackedBooleans, 1);
+
+    std::vector<intptr_t> shape;
+    MlirType shapedType = mlirAttributeGetType(*this);
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    for (intptr_t i = 0; i < rank; ++i) {
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    }
+    unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
 
     // Make sure the returned py::buffer_view claims ownership of the data in
     // `pythonBuffer` so it remains valid when Python reads it
+    py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
     return pythonBuffer.request();
   }
 



More information about the Mlir-commits mailing list