[Mlir-commits] [mlir] [MLIR, Python] Support converting boolean numpy arrays to and from mlir attributes (PR #113064)

Kasper Nielsen llvmlistbot at llvm.org
Tue Oct 29 08:17:19 PDT 2024


https://github.com/kasper0406 updated https://github.com/llvm/llvm-project/pull/113064

>From 053d0b7e20bee7783bc895234e7c1b11a4f7b280 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sat, 19 Oct 2024 19:55:58 +0200
Subject: [PATCH 01/12] [Python] Attempt at getting boolean types working

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 111 ++++++++++++++++++----
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp    |   9 +-
 mlir/lib/IR/BuiltinAttributes.cpp         |   4 +
 mlir/test/python/ir/array_attributes.py   |  64 +++++++++++++
 4 files changed, 167 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index ead81a76c0538d..9671cf77397383 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -6,9 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <iostream>
+
 #include <optional>
 #include <string_view>
 #include <utility>
+#include <memory>
 
 #include "IRModule.h"
 
@@ -757,14 +760,7 @@ class PyDenseElementsAttribute
       throw py::error_already_set();
     }
     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-    SmallVector<int64_t> shape;
-    if (explicitShape) {
-      shape.append(explicitShape->begin(), explicitShape->end());
-    } else {
-      shape.append(view.shape, view.shape + view.ndim);
-    }
 
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
     MlirContext context = contextWrapper->get();
 
     // Detect format codes that are suitable for bulk loading. This includes
@@ -773,6 +769,7 @@ class PyDenseElementsAttribute
     // other exotics which do not have a direct representation in the buffer
     // protocol (i.e. complex, etc).
     std::optional<MlirType> bulkLoadElementType;
+    bool kasperTest = false;
     if (explicitType) {
       bulkLoadElementType = *explicitType;
     } else {
@@ -789,6 +786,10 @@ class PyDenseElementsAttribute
         // f16
         assert(view.itemsize == 2 && "mismatched array itemsize");
         bulkLoadElementType = mlirF16TypeGet(context);
+      } else if (format == "?") {
+        // i1
+        kasperTest = true;
+        bulkLoadElementType = mlirIntegerTypeGet(context, 1);
       } else if (isSignedIntegerFormat(format)) {
         if (view.itemsize == 4) {
           // i32
@@ -840,20 +841,46 @@ class PyDenseElementsAttribute
       }
     }
 
-    MlirType shapedType;
-    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-      if (explicitShape) {
-        throw std::invalid_argument("Shape can only be specified explicitly "
-                                    "when the type is not a shaped type.");
+    size_t rawBufferSize = view.len;
+    MlirAttribute attr;
+    if (kasperTest) {
+      std::cerr << "Buffer content:" << std::endl;
+      for (int i = 0; i < view.len; i++) {
+        std::cerr << (int)*((char*)view.buf + i) << std::endl;
+      }
+
+      std::cerr << "Constructing intermediate buffer..." << std::endl;
+      // First read the content of the python buffer as u8's, to correct for endianess
+      MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(
+        getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view), rawBufferSize, view.buf);
+
+      std::cerr << "Endian corrected buffer content:" << std::endl;
+      for (int i = 0; i < view.len; i++) {
+        std::cerr << (int) mlirDenseElementsAttrGetUInt8Value(intermediateAttr, i) << std::endl;
       }
-      shapedType = *bulkLoadElementType;
+
+      // Pack the boolean array according to the i8 bitpacking layout
+      const int numPackedBytes = (view.len + 7) / 8;
+      SmallVector<uint8_t, 8> bitpacked(numPackedBytes);
+      for (int byteNum = 0; byteNum < numPackedBytes; byteNum++) {
+        uint8_t byte = 0;
+        for (int bitNr = 0; 8 * byteNum + bitNr < view.len; bitNr++) {
+          int pos = 8 * byteNum + bitNr;
+          uint8_t boolVal = mlirDenseElementsAttrGetUInt8Value(intermediateAttr, pos) << bitNr;
+          byte |= boolVal;
+        }
+        bitpacked[byteNum] = byte;
+      }
+
+      std::cerr << "Bitpacked: " << std::endl;
+      for (int i = 0; i < numPackedBytes; i++) {
+        std::cerr << (int)*((uint8_t*)bitpacked.data() + i) << std::endl;
+      }
+
+      attr = mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), numPackedBytes, bitpacked.data());
     } else {
-      shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                           *bulkLoadElementType, encodingAttr);
+      attr = mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), rawBufferSize, view.buf);
     }
-    size_t rawBufferSize = view.len;
-    MlirAttribute attr =
-        mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
     if (mlirAttributeIsNull(attr)) {
       throw std::invalid_argument(
           "DenseElementsAttr could not be constructed from the given buffer. "
@@ -963,6 +990,20 @@ class PyDenseElementsAttribute
         // unsigned i16
         return bufferInfo<uint16_t>(shapedType);
       }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 1) {
+      // i1 / bool type
+      if (!m_boolBuffer.has_value()) {
+        // TODO(knielsen): Handle endianess
+        int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+        std::cerr << "Allocating a buffer with #elements = " << numBooleans << std::endl;
+        m_boolBuffer = SmallVector<uint8_t, 8>(numBooleans);
+        // TODO(knielsen): Bit unpack!
+        if (numBooleans > 0) {
+          m_boolBuffer.value()[0] = 0b10101011;
+        }
+      }
+      return bufferInfo<uint8_t>(shapedType, "?", m_boolBuffer.value().data());
     }
 
     // TODO: Currently crashes the program.
@@ -1016,14 +1057,44 @@ class PyDenseElementsAttribute
            code == 'q';
   }
 
+  static MlirType getShapedType(std::optional<MlirType> bulkLoadElementType,
+                                std::optional<std::vector<int64_t>> explicitShape,
+                                Py_buffer& view) {
+    SmallVector<int64_t> shape;
+    if (explicitShape) {
+      shape.append(explicitShape->begin(), explicitShape->end());
+    } else {
+      shape.append(view.shape, view.shape + view.ndim);
+    }
+
+    MlirType shapedType;
+    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+      if (explicitShape) {
+        throw std::invalid_argument("Shape can only be specified explicitly "
+                                    "when the type is not a shaped type.");
+      }
+      return *bulkLoadElementType;
+    } else {
+      MlirAttribute encodingAttr = mlirAttributeGetNull();
+      return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                     *bulkLoadElementType, encodingAttr);
+    }
+  }
+
+  std::optional<SmallVector<uint8_t, 8>> m_boolBuffer;
+
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
-                             const char *explicitFormat = nullptr) {
+                             const char *explicitFormat = nullptr,
+                             Type* dataOverride = nullptr) {
     intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the data for the buffer_info.
     // Buffer is configured for read-only access below.
     Type *data = static_cast<Type *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    if (dataOverride != nullptr) {
+      data = dataOverride;
+    }
     // Prepare the shape for the buffer_info.
     SmallVector<intptr_t, 4> shape;
     for (intptr_t i = 0; i < rank; ++i)
@@ -1083,6 +1154,7 @@ class PyDenseIntElementsAttribute
     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
     if (isUnsigned) {
       if (width == 1) {
+        std::cerr << "Loading unsigned i1 values at position: " << pos << std::endl;
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
       if (width == 8) {
@@ -1099,6 +1171,7 @@ class PyDenseIntElementsAttribute
       }
     } else {
       if (width == 1) {
+        std::cerr << "Loading signed i1 values at position: " << pos << std::endl;
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
       if (width == 8) {
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 11d1ade552f5a2..dc78be53eee0cb 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <iostream>
+
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
@@ -527,8 +529,11 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
                               rawBufferSize);
   bool isSplat = false;
   if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
-                                           isSplat))
+                                           isSplat)) {
+    std::cerr << "NULL POINTER!!!" << std::endl;
     return mlirAttributeGetNull();
+  }
+  std::cerr << "Pointer looks ok..." << std::endl;
   return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
 }
 
@@ -588,7 +593,7 @@ MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
                                            const int *elements) {
   SmallVector<bool, 8> values(elements, elements + numElements);
   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
-                                     values));
+                                    values));
 }
 
 /// Creates a dense attribute with elements of the type deduced by templates.
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8861a940336133..1009b1882e9425 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <iostream>
+
 #include "mlir/IR/BuiltinAttributes.h"
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
@@ -1088,6 +1090,8 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
     }
 
     // This is a valid non-splat buffer if it has the right size.
+    std::cerr << "Raw buffer width: " << rawBufferWidth << std::endl;
+    std::cerr << "Aligned to width: " << llvm::alignTo<8>(numElements) << std::endl;
     return rawBufferWidth == llvm::alignTo<8>(numElements);
   }
 
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 2bc403aace8348..9084cd7d55f2f8 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -326,6 +326,70 @@ def testGetDenseElementsF64():
         print(np.array(attr))
 
 
+### 1 bit/boolean integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
+@run
+def testGetDenseElementsI1Signless():
+    with Context():
+        array = np.array([True], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<true> : tensor<1xi1>
+        print(attr)
+        # CHECK: {{\[}} True]
+        print(np.array(attr))
+
+        array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[true, false, true], [true, true, false]]> : tensor<2x3xi1>
+        print(attr)
+        # CHECK: {{\[}}[ True False True]
+        # CHECK: {{\[}} True True False]]
+        print(np.array(attr))
+
+        array = np.array([[True, True, False, False], [True, False, True, False]], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
+        print(attr)
+        # CHECK: {{\[}}[ True True False False]
+        # CHECK: {{\[}} True False True False]]
+        print(np.array(attr))
+
+        array = np.array([
+            [True, True, False, False],
+            [True, False, True, False],
+            [False, False, False, False],
+            [True, True, True, True],
+            [True, False, False, True],
+        ], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
+        print(attr)
+        # CHECK: {{\[}}[ True True False False]
+        # CHECK: {{\[}} True False True False]]
+        # CHECK: {{\[}}False False False False]]
+        # CHECK: {{\[}} True True True True]]
+        # CHECK: {{\[}} True False False True]]
+        print(np.array(attr))
+
+        array = np.array([
+            [True, True, False, False, True, True, False, False, False],
+            [False, False, False, True, False, True, True, False, True]
+        ], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
+        print(attr)
+        # CHECK: {{\[}}[ True True False False True True False False False]
+        # CHECK: {{\[}}False False False True False True True False True]]
+        print(np.array(attr))
+
+        array = np.array([], dtype=np.bool_)
+        attr = DenseElementsAttr.get(array)
+        # CHECK: dense<{{\[}}]> : tensor<0xi1>
+        print(attr)
+        # CHECK: {{\[}} ]
+        print(np.array(attr))
+
+
 ### 16 bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI16Signless
 @run

>From d5da538eae082136fab07edd0fcff1c1bad77f9e Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sat, 19 Oct 2024 22:00:56 +0200
Subject: [PATCH 02/12] Refactorings

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 249 +++++++++++-----------
 mlir/test/python/ir/array_attributes.py   |  10 +-
 2 files changed, 125 insertions(+), 134 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 9671cf77397383..c85c95c896f628 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -762,125 +762,7 @@ class PyDenseElementsAttribute
     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
 
     MlirContext context = contextWrapper->get();
-
-    // Detect format codes that are suitable for bulk loading. This includes
-    // all byte aligned integer and floating point types up to 8 bytes.
-    // Notably, this excludes, bool (which needs to be bit-packed) and
-    // other exotics which do not have a direct representation in the buffer
-    // protocol (i.e. complex, etc).
-    std::optional<MlirType> bulkLoadElementType;
-    bool kasperTest = false;
-    if (explicitType) {
-      bulkLoadElementType = *explicitType;
-    } else {
-      std::string_view format(view.format);
-      if (format == "f") {
-        // f32
-        assert(view.itemsize == 4 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF32TypeGet(context);
-      } else if (format == "d") {
-        // f64
-        assert(view.itemsize == 8 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF64TypeGet(context);
-      } else if (format == "e") {
-        // f16
-        assert(view.itemsize == 2 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF16TypeGet(context);
-      } else if (format == "?") {
-        // i1
-        kasperTest = true;
-        bulkLoadElementType = mlirIntegerTypeGet(context, 1);
-      } else if (isSignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeSignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeSignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                         : mlirIntegerTypeSignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeSignedGet(context, 16);
-        }
-      } else if (isUnsignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // unsigned i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeUnsignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // unsigned i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeUnsignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 8)
-                                    : mlirIntegerTypeUnsignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeUnsignedGet(context, 16);
-        }
-      }
-      if (!bulkLoadElementType) {
-        throw std::invalid_argument(
-            std::string("unimplemented array format conversion from format: ") +
-            std::string(format));
-      }
-    }
-
-    size_t rawBufferSize = view.len;
-    MlirAttribute attr;
-    if (kasperTest) {
-      std::cerr << "Buffer content:" << std::endl;
-      for (int i = 0; i < view.len; i++) {
-        std::cerr << (int)*((char*)view.buf + i) << std::endl;
-      }
-
-      std::cerr << "Constructing intermediate buffer..." << std::endl;
-      // First read the content of the python buffer as u8's, to correct for endianess
-      MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(
-        getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view), rawBufferSize, view.buf);
-
-      std::cerr << "Endian corrected buffer content:" << std::endl;
-      for (int i = 0; i < view.len; i++) {
-        std::cerr << (int) mlirDenseElementsAttrGetUInt8Value(intermediateAttr, i) << std::endl;
-      }
-
-      // Pack the boolean array according to the i8 bitpacking layout
-      const int numPackedBytes = (view.len + 7) / 8;
-      SmallVector<uint8_t, 8> bitpacked(numPackedBytes);
-      for (int byteNum = 0; byteNum < numPackedBytes; byteNum++) {
-        uint8_t byte = 0;
-        for (int bitNr = 0; 8 * byteNum + bitNr < view.len; bitNr++) {
-          int pos = 8 * byteNum + bitNr;
-          uint8_t boolVal = mlirDenseElementsAttrGetUInt8Value(intermediateAttr, pos) << bitNr;
-          byte |= boolVal;
-        }
-        bitpacked[byteNum] = byte;
-      }
-
-      std::cerr << "Bitpacked: " << std::endl;
-      for (int i = 0; i < numPackedBytes; i++) {
-        std::cerr << (int)*((uint8_t*)bitpacked.data() + i) << std::endl;
-      }
-
-      attr = mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), numPackedBytes, bitpacked.data());
-    } else {
-      attr = mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), rawBufferSize, view.buf);
-    }
+    MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, explicitShape, context);
     if (mlirAttributeIsNull(attr)) {
       throw std::invalid_argument(
           "DenseElementsAttr could not be constructed from the given buffer. "
@@ -992,18 +874,19 @@ class PyDenseElementsAttribute
       }
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 1) {
-      // i1 / bool type
+      // i1 / bool
       if (!m_boolBuffer.has_value()) {
-        // TODO(knielsen): Handle endianess
+        // Because i1's are bitpacked within MLIR, we need to convert it into the
+        // one bool per byte representation used by numpy.
+        // We allocate a new array to keep around for this purpose.
         int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
-        std::cerr << "Allocating a buffer with #elements = " << numBooleans << std::endl;
-        m_boolBuffer = SmallVector<uint8_t, 8>(numBooleans);
-        // TODO(knielsen): Bit unpack!
-        if (numBooleans > 0) {
-          m_boolBuffer.value()[0] = 0b10101011;
+        m_boolBuffer = SmallVector<bool, 8>(numBooleans);
+        for (int i = 0; i < numBooleans; i++) {
+          bool value = mlirDenseElementsAttrGetBoolValue(*this, i);
+          m_boolBuffer.value()[i] = value;
         }
       }
-      return bufferInfo<uint8_t>(shapedType, "?", m_boolBuffer.value().data());
+      return bufferInfo<bool>(shapedType, "?", m_boolBuffer.value().data());
     }
 
     // TODO: Currently crashes the program.
@@ -1041,6 +924,8 @@ class PyDenseElementsAttribute
   }
 
 private:
+  std::optional<SmallVector<bool, 8>> m_boolBuffer;
+
   static bool isUnsignedIntegerFormat(std::string_view format) {
     if (format.empty())
       return false;
@@ -1067,7 +952,6 @@ class PyDenseElementsAttribute
       shape.append(view.shape, view.shape + view.ndim);
     }
 
-    MlirType shapedType;
     if (mlirTypeIsAShaped(*bulkLoadElementType)) {
       if (explicitShape) {
         throw std::invalid_argument("Shape can only be specified explicitly "
@@ -1081,7 +965,114 @@ class PyDenseElementsAttribute
     }
   }
 
-  std::optional<SmallVector<uint8_t, 8>> m_boolBuffer;
+  static MlirAttribute getAttributeFromBuffer(Py_buffer& view,
+                                              bool signless,
+                                              std::optional<PyType> explicitType,
+                                              std::optional<std::vector<int64_t>> explicitShape,
+                                              MlirContext& context) {
+    // Detect format codes that are suitable for bulk loading. This includes
+    // all byte aligned integer and floating point types up to 8 bytes.
+    // Notably, this excludes, bool (which needs to be bit-packed) and
+    // other exotics which do not have a direct representation in the buffer
+    // protocol (i.e. complex, etc).
+    std::optional<MlirType> bulkLoadElementType;
+    if (explicitType) {
+      bulkLoadElementType = *explicitType;
+    } else {
+      std::string_view format(view.format);
+      if (format == "f") {
+        // f32
+        assert(view.itemsize == 4 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF32TypeGet(context);
+      } else if (format == "d") {
+        // f64
+        assert(view.itemsize == 8 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF64TypeGet(context);
+      } else if (format == "e") {
+        // f16
+        assert(view.itemsize == 2 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF16TypeGet(context);
+      } else if (format == "?") {
+        // i1
+        // The i1 type needs to be bit-packed, so we will handle it seperately
+        return getAttributeFromBufferBoolean(view, explicitShape, context);
+      } else if (isSignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeSignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeSignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                         : mlirIntegerTypeSignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeSignedGet(context, 16);
+        }
+      } else if (isUnsignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // unsigned i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeUnsignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // unsigned i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeUnsignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 8)
+                                    : mlirIntegerTypeUnsignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeUnsignedGet(context, 16);
+        }
+      }
+      if (!bulkLoadElementType) {
+        throw std::invalid_argument(
+            std::string("unimplemented array format conversion from format: ") +
+            std::string(format));
+      }
+    }
+
+    return mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), view.len, view.buf);
+  }
+
+  static MlirAttribute getAttributeFromBufferBoolean(Py_buffer& view,
+                                                     std::optional<std::vector<int64_t>> explicitShape,
+                                                     MlirContext& context) {
+    // First read the content of the python buffer as u8's, to correct for endianess
+    MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(
+      getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view), view.len, view.buf);
+
+    // Pack the boolean array according to the i8 bitpacking layout
+    const int numPackedBytes = (view.len + 7) / 8;
+    SmallVector<uint8_t, 8> bitpacked(numPackedBytes);
+    for (int byteNum = 0; byteNum < numPackedBytes; byteNum++) {
+      uint8_t byte = 0;
+      for (int bitNr = 0; 8 * byteNum + bitNr < view.len; bitNr++) {
+        int pos = 8 * byteNum + bitNr;
+        uint8_t boolVal = mlirDenseElementsAttrGetUInt8Value(intermediateAttr, pos) << bitNr;
+        byte |= boolVal;
+      }
+      bitpacked[byteNum] = byte;
+    }
+
+    return mlirDenseElementsAttrRawBufferGet(getShapedType(
+      mlirIntegerTypeGet(context, 1), explicitShape, view), numPackedBytes, bitpacked.data());
+  }
 
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9084cd7d55f2f8..16d7322cfe7119 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -365,9 +365,9 @@ def testGetDenseElementsI1Signless():
         # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
         print(attr)
         # CHECK: {{\[}}[ True True False False]
-        # CHECK: {{\[}} True False True False]]
-        # CHECK: {{\[}}False False False False]]
-        # CHECK: {{\[}} True True True True]]
+        # CHECK: {{\[}} True False True False]
+        # CHECK: {{\[}}False False False False]
+        # CHECK: {{\[}} True True True True]
         # CHECK: {{\[}} True False False True]]
         print(np.array(attr))
 
@@ -384,9 +384,9 @@ def testGetDenseElementsI1Signless():
 
         array = np.array([], dtype=np.bool_)
         attr = DenseElementsAttr.get(array)
-        # CHECK: dense<{{\[}}]> : tensor<0xi1>
+        # CHECK: dense<> : tensor<0xi1>
         print(attr)
-        # CHECK: {{\[}} ]
+        # CHECK: {{\[}}]
         print(np.array(attr))
 
 

>From 52b49ac835f82c9d5ca9c1ae8f433e0a685e23a5 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sat, 19 Oct 2024 22:17:44 +0200
Subject: [PATCH 03/12] Cleanups

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 29 +++++++++++------------
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp    |  9 ++-----
 mlir/lib/IR/BuiltinAttributes.cpp         |  4 ----
 3 files changed, 16 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c85c95c896f628..043b7ed5867e8b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -6,12 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <iostream>
-
 #include <optional>
 #include <string_view>
 #include <utility>
-#include <memory>
 
 #include "IRModule.h"
 
@@ -995,7 +992,7 @@ class PyDenseElementsAttribute
       } else if (format == "?") {
         // i1
         // The i1 type needs to be bit-packed, so we will handle it seperately
-        return getAttributeFromBufferBoolean(view, explicitShape, context);
+        return getAttributeFromBufferBoolBitpack(view, explicitShape, context);
       } else if (isSignedIntegerFormat(format)) {
         if (view.itemsize == 4) {
           // i32
@@ -1047,17 +1044,21 @@ class PyDenseElementsAttribute
       }
     }
 
-    return mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), view.len, view.buf);
+    MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
   }
 
-  static MlirAttribute getAttributeFromBufferBoolean(Py_buffer& view,
-                                                     std::optional<std::vector<int64_t>> explicitShape,
-                                                     MlirContext& context) {
+  // There is a complication for boolean numpy arrays, as numpy represent them as
+  // 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
+  // This function does the bit-packing respecting endianess.
+  static MlirAttribute getAttributeFromBufferBoolBitpack(Py_buffer& view,
+                                                         std::optional<std::vector<int64_t>> explicitShape,
+                                                         MlirContext& context) {
     // First read the content of the python buffer as u8's, to correct for endianess
-    MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(
-      getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view), view.len, view.buf);
+    MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view);
+    MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);
 
-    // Pack the boolean array according to the i8 bitpacking layout
+    // Pack the boolean array according to the i1 bitpacking layout
     const int numPackedBytes = (view.len + 7) / 8;
     SmallVector<uint8_t, 8> bitpacked(numPackedBytes);
     for (int byteNum = 0; byteNum < numPackedBytes; byteNum++) {
@@ -1070,8 +1071,8 @@ class PyDenseElementsAttribute
       bitpacked[byteNum] = byte;
     }
 
-    return mlirDenseElementsAttrRawBufferGet(getShapedType(
-      mlirIntegerTypeGet(context, 1), explicitShape, view), numPackedBytes, bitpacked.data());
+    MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, numPackedBytes, bitpacked.data());
   }
 
   template <typename Type>
@@ -1145,7 +1146,6 @@ class PyDenseIntElementsAttribute
     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
     if (isUnsigned) {
       if (width == 1) {
-        std::cerr << "Loading unsigned i1 values at position: " << pos << std::endl;
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
       if (width == 8) {
@@ -1162,7 +1162,6 @@ class PyDenseIntElementsAttribute
       }
     } else {
       if (width == 1) {
-        std::cerr << "Loading signed i1 values at position: " << pos << std::endl;
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
       if (width == 8) {
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index dc78be53eee0cb..11d1ade552f5a2 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -6,8 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <iostream>
-
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Support.h"
 #include "mlir/CAPI/AffineMap.h"
@@ -529,11 +527,8 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
                               rawBufferSize);
   bool isSplat = false;
   if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
-                                           isSplat)) {
-    std::cerr << "NULL POINTER!!!" << std::endl;
+                                           isSplat))
     return mlirAttributeGetNull();
-  }
-  std::cerr << "Pointer looks ok..." << std::endl;
   return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
 }
 
@@ -593,7 +588,7 @@ MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
                                            const int *elements) {
   SmallVector<bool, 8> values(elements, elements + numElements);
   return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
-                                    values));
+                                     values));
 }
 
 /// Creates a dense attribute with elements of the type deduced by templates.
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 1009b1882e9425..8861a940336133 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -6,8 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <iostream>
-
 #include "mlir/IR/BuiltinAttributes.h"
 #include "AttributeDetail.h"
 #include "mlir/IR/AffineMap.h"
@@ -1090,8 +1088,6 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
     }
 
     // This is a valid non-splat buffer if it has the right size.
-    std::cerr << "Raw buffer width: " << rawBufferWidth << std::endl;
-    std::cerr << "Aligned to width: " << llvm::alignTo<8>(numElements) << std::endl;
     return rawBufferWidth == llvm::alignTo<8>(numElements);
   }
 

>From 6d3204c5a9433dd06f63d74203825efe6fcc3d98 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sat, 19 Oct 2024 23:01:07 +0200
Subject: [PATCH 04/12] Fix style

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 54 +++++++++++++----------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 043b7ed5867e8b..fc096c8d9e8374 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -759,7 +759,8 @@ class PyDenseElementsAttribute
     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
 
     MlirContext context = contextWrapper->get();
-    MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, explicitShape, context);
+    MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
+                                                explicitShape, context);
     if (mlirAttributeIsNull(attr)) {
       throw std::invalid_argument(
           "DenseElementsAttr could not be constructed from the given buffer. "
@@ -873,9 +874,9 @@ class PyDenseElementsAttribute
                mlirIntegerTypeGetWidth(elementType) == 1) {
       // i1 / bool
       if (!m_boolBuffer.has_value()) {
-        // Because i1's are bitpacked within MLIR, we need to convert it into the
-        // one bool per byte representation used by numpy.
-        // We allocate a new array to keep around for this purpose.
+        // Because i1's are bitpacked within MLIR, we need to convert it into
+        // the one bool per byte representation used by numpy. We allocate a new
+        // array to keep around for this purpose.
         int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
         m_boolBuffer = SmallVector<bool, 8>(numBooleans);
         for (int i = 0; i < numBooleans; i++) {
@@ -939,9 +940,10 @@ class PyDenseElementsAttribute
            code == 'q';
   }
 
-  static MlirType getShapedType(std::optional<MlirType> bulkLoadElementType,
-                                std::optional<std::vector<int64_t>> explicitShape,
-                                Py_buffer& view) {
+  static MlirType
+  getShapedType(std::optional<MlirType> bulkLoadElementType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                Py_buffer &view) {
     SmallVector<int64_t> shape;
     if (explicitShape) {
       shape.append(explicitShape->begin(), explicitShape->end());
@@ -962,11 +964,9 @@ class PyDenseElementsAttribute
     }
   }
 
-  static MlirAttribute getAttributeFromBuffer(Py_buffer& view,
-                                              bool signless,
-                                              std::optional<PyType> explicitType,
-                                              std::optional<std::vector<int64_t>> explicitShape,
-                                              MlirContext& context) {
+  static MlirAttribute getAttributeFromBuffer(
+      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+      std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
     // Detect format codes that are suitable for bulk loading. This includes
     // all byte aligned integer and floating point types up to 8 bytes.
     // Notably, this excludes, bool (which needs to be bit-packed) and
@@ -1048,15 +1048,18 @@ class PyDenseElementsAttribute
     return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
   }
 
-  // There is a complication for boolean numpy arrays, as numpy represent them as
-  // 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
+  // There is a complication for boolean numpy arrays, as numpy represent them
+  // as 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
   // This function does the bit-packing respecting endianess.
-  static MlirAttribute getAttributeFromBufferBoolBitpack(Py_buffer& view,
-                                                         std::optional<std::vector<int64_t>> explicitShape,
-                                                         MlirContext& context) {
-    // First read the content of the python buffer as u8's, to correct for endianess
-    MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view);
-    MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);
+  static MlirAttribute getAttributeFromBufferBoolBitpack(
+      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+      MlirContext &context) {
+    // First read the content of the python buffer as u8's, to correct for
+    // endianess
+    MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8),
+                                      explicitShape, view);
+    MlirAttribute intermediateAttr =
+        mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);
 
     // Pack the boolean array according to the i1 bitpacking layout
     const int numPackedBytes = (view.len + 7) / 8;
@@ -1065,20 +1068,23 @@ class PyDenseElementsAttribute
       uint8_t byte = 0;
       for (int bitNr = 0; 8 * byteNum + bitNr < view.len; bitNr++) {
         int pos = 8 * byteNum + bitNr;
-        uint8_t boolVal = mlirDenseElementsAttrGetUInt8Value(intermediateAttr, pos) << bitNr;
+        uint8_t boolVal =
+            mlirDenseElementsAttrGetUInt8Value(intermediateAttr, pos) << bitNr;
         byte |= boolVal;
       }
       bitpacked[byteNum] = byte;
     }
 
-    MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
-    return mlirDenseElementsAttrRawBufferGet(bitpackedType, numPackedBytes, bitpacked.data());
+    MlirType bitpackedType =
+        getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, numPackedBytes,
+                                             bitpacked.data());
   }
 
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
                              const char *explicitFormat = nullptr,
-                             Type* dataOverride = nullptr) {
+                             Type *dataOverride = nullptr) {
     intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the data for the buffer_info.
     // Buffer is configured for read-only access below.

>From f8a21fc3ff9213ac74cd9854585af5430adbe571 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Sat, 19 Oct 2024 23:02:37 +0200
Subject: [PATCH 05/12] More styles

---
 mlir/lib/Bindings/Python/IRAttributes.cpp |  5 ++--
 mlir/test/python/ir/array_attributes.py   | 32 ++++++++++++++---------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index fc096c8d9e8374..e4dbeb25305a20 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -969,9 +969,8 @@ class PyDenseElementsAttribute
       std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
     // Detect format codes that are suitable for bulk loading. This includes
     // all byte aligned integer and floating point types up to 8 bytes.
-    // Notably, this excludes, bool (which needs to be bit-packed) and
-    // other exotics which do not have a direct representation in the buffer
-    // protocol (i.e. complex, etc).
+    // Notably, this excludes exotic types which do not have a direct
+    // representation in the buffer protocol (e.g. complex).
     std::optional<MlirType> bulkLoadElementType;
     if (explicitType) {
       bulkLoadElementType = *explicitType;
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 16d7322cfe7119..c1e14601407925 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -346,7 +346,9 @@ def testGetDenseElementsI1Signless():
         # CHECK: {{\[}} True True False]]
         print(np.array(attr))
 
-        array = np.array([[True, True, False, False], [True, False, True, False]], dtype=np.bool_)
+        array = np.array(
+            [[True, True, False, False], [True, False, True, False]], dtype=np.bool_
+        )
         attr = DenseElementsAttr.get(array)
         # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
         print(attr)
@@ -354,13 +356,16 @@ def testGetDenseElementsI1Signless():
         # CHECK: {{\[}} True False True False]]
         print(np.array(attr))
 
-        array = np.array([
-            [True, True, False, False],
-            [True, False, True, False],
-            [False, False, False, False],
-            [True, True, True, True],
-            [True, False, False, True],
-        ], dtype=np.bool_)
+        array = np.array(
+            [
+                [True, True, False, False],
+                [True, False, True, False],
+                [False, False, False, False],
+                [True, True, True, True],
+                [True, False, False, True],
+            ],
+            dtype=np.bool_,
+        )
         attr = DenseElementsAttr.get(array)
         # CHECK: dense<{{\[}}[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
         print(attr)
@@ -371,10 +376,13 @@ def testGetDenseElementsI1Signless():
         # CHECK: {{\[}} True False False True]]
         print(np.array(attr))
 
-        array = np.array([
-            [True, True, False, False, True, True, False, False, False],
-            [False, False, False, True, False, True, True, False, True]
-        ], dtype=np.bool_)
+        array = np.array(
+            [
+                [True, True, False, False, True, True, False, False, False],
+                [False, False, False, True, False, True, True, False, True],
+            ],
+            dtype=np.bool_,
+        )
         attr = DenseElementsAttr.get(array)
         # CHECK: dense<{{\[}}[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
         print(attr)

>From 73df6fbe5cf4d51a8743fc15532ac95d695227ab Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 23 Oct 2024 15:23:11 +0200
Subject: [PATCH 06/12] Use numpy to bitpack and unpack, to avoid additional
 fields

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 63 ++++++++++++-----------
 mlir/lib/Bindings/Python/PybindUtils.h    |  1 +
 2 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index e4dbeb25305a20..9f9cf800ad8142 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -873,18 +873,10 @@ class PyDenseElementsAttribute
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 1) {
       // i1 / bool
-      if (!m_boolBuffer.has_value()) {
-        // Because i1's are bitpacked within MLIR, we need to convert it into
-        // the one bool per byte representation used by numpy. We allocate a new
-        // array to keep around for this purpose.
-        int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
-        m_boolBuffer = SmallVector<bool, 8>(numBooleans);
-        for (int i = 0; i < numBooleans; i++) {
-          bool value = mlirDenseElementsAttrGetBoolValue(*this, i);
-          m_boolBuffer.value()[i] = value;
-        }
-      }
-      return bufferInfo<bool>(shapedType, "?", m_boolBuffer.value().data());
+      // We can not send the buffer directly back to Python, because the i1 values
+      // are bitpacked within MLIR.
+      // We call numpy's unpackbits function to convert the bytes.
+      return getBooleanBufferFromBitpackedAttribute();
     }
 
     // TODO: Currently crashes the program.
@@ -922,8 +914,6 @@ class PyDenseElementsAttribute
   }
 
 private:
-  std::optional<SmallVector<bool, 8>> m_boolBuffer;
-
   static bool isUnsignedIntegerFormat(std::string_view format) {
     if (format.empty())
       return false;
@@ -991,7 +981,7 @@ class PyDenseElementsAttribute
       } else if (format == "?") {
         // i1
         // The i1 type needs to be bit-packed, so we will handle it seperately
-        return getAttributeFromBufferBoolBitpack(view, explicitShape, context);
+        return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, context);
       } else if (isSignedIntegerFormat(format)) {
         if (view.itemsize == 4) {
           // i32
@@ -1050,7 +1040,7 @@ class PyDenseElementsAttribute
   // There is a complication for boolean numpy arrays, as numpy represent them
   // as 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
   // This function does the bit-packing respecting endianess.
-  static MlirAttribute getAttributeFromBufferBoolBitpack(
+  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
       Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
       MlirContext &context) {
     // First read the content of the python buffer as u8's, to correct for
@@ -1060,24 +1050,35 @@ class PyDenseElementsAttribute
     MlirAttribute intermediateAttr =
         mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);
 
-    // Pack the boolean array according to the i1 bitpacking layout
-    const int numPackedBytes = (view.len + 7) / 8;
-    SmallVector<uint8_t, 8> bitpacked(numPackedBytes);
-    for (int byteNum = 0; byteNum < numPackedBytes; byteNum++) {
-      uint8_t byte = 0;
-      for (int bitNr = 0; 8 * byteNum + bitNr < view.len; bitNr++) {
-        int pos = 8 * byteNum + bitNr;
-        uint8_t boolVal =
-            mlirDenseElementsAttrGetUInt8Value(intermediateAttr, pos) << bitNr;
-        byte |= boolVal;
-      }
-      bitpacked[byteNum] = byte;
-    }
+    uint8_t *unpackedData = static_cast<uint8_t *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(intermediateAttr)));
+    py::array_t<uint8_t> arr(view.len, unpackedData);
+
+    py::module numpy = py::module::import("numpy");
+    py::object packbits_func = numpy.attr("packbits");
+    py::object packed_booleans = packbits_func(arr, "bitorder"_a = "little");
+    py::buffer_info buffer_info = packed_booleans.cast<py::buffer>().request();
 
     MlirType bitpackedType =
         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
-    return mlirDenseElementsAttrRawBufferGet(bitpackedType, numPackedBytes,
-                                             bitpacked.data());
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, buffer_info.size, buffer_info.ptr);
+  }
+
+  // This does the opposite transformation of `getBitpackedAttributeFromBooleanBuffer`
+  py::buffer_info getBooleanBufferFromBitpackedAttribute() {
+    int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+    int64_t numBitpackedBytes = (numBooleans + 7) / 8;
+    uint8_t *bitpackedData = static_cast<uint8_t *>(
+      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    py::array_t<uint8_t> arr(numBitpackedBytes, bitpackedData);
+
+    py::module numpy = py::module::import("numpy");
+    py::object unpackbits_func = numpy.attr("unpackbits");
+    py::object unpacked_booleans = unpackbits_func(arr, "bitorder"_a = "little");
+    py::buffer_info buffer_info = unpacked_booleans.cast<py::buffer>().request();
+
+    MlirType shapedType = mlirAttributeGetType(*this);
+    return bufferInfo<bool>(shapedType, "?", (bool*)buffer_info.ptr);
   }
 
   template <typename Type>
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 38462ac8ba6db9..b8c764c030bc6a 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -16,6 +16,7 @@
 
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11/numpy.h>
 
 namespace mlir {
 namespace python {

>From 93156b1e6f12d157d3bbc796ff7288a230fd8e14 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 23 Oct 2024 15:36:52 +0200
Subject: [PATCH 07/12] Small refactoring

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 44 +++++++++++++----------
 mlir/lib/Bindings/Python/PybindUtils.h    |  2 +-
 2 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 9f9cf800ad8142..7b66adc5c7488f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -873,9 +873,9 @@ class PyDenseElementsAttribute
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 1) {
       // i1 / bool
-      // We can not send the buffer directly back to Python, because the i1 values
-      // are bitpacked within MLIR.
-      // We call numpy's unpackbits function to convert the bytes.
+      // We cannot send the buffer directly back to Python, because the i1
+      // values are bitpacked within MLIR. We call numpy's unpackbits function
+      // to convert the bytes.
       return getBooleanBufferFromBitpackedAttribute();
     }
 
@@ -981,7 +981,8 @@ class PyDenseElementsAttribute
       } else if (format == "?") {
         // i1
         // The i1 type needs to be bit-packed, so we will handle it seperately
-        return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, context);
+        return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+                                                      context);
       } else if (isSignedIntegerFormat(format)) {
         if (view.itemsize == 4) {
           // i32
@@ -1061,38 +1062,45 @@ class PyDenseElementsAttribute
 
     MlirType bitpackedType =
         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
-    return mlirDenseElementsAttrRawBufferGet(bitpackedType, buffer_info.size, buffer_info.ptr);
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, buffer_info.size,
+                                             buffer_info.ptr);
   }
 
-  // This does the opposite transformation of `getBitpackedAttributeFromBooleanBuffer`
+  // This does the opposite transformation of
+  // `getBitpackedAttributeFromBooleanBuffer`
   py::buffer_info getBooleanBufferFromBitpackedAttribute() {
     int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
     int64_t numBitpackedBytes = (numBooleans + 7) / 8;
     uint8_t *bitpackedData = static_cast<uint8_t *>(
-      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
     py::array_t<uint8_t> arr(numBitpackedBytes, bitpackedData);
 
     py::module numpy = py::module::import("numpy");
     py::object unpackbits_func = numpy.attr("unpackbits");
-    py::object unpacked_booleans = unpackbits_func(arr, "bitorder"_a = "little");
-    py::buffer_info buffer_info = unpacked_booleans.cast<py::buffer>().request();
+    py::object unpacked_booleans =
+        unpackbits_func(arr, "bitorder"_a = "little");
+    py::buffer_info buffer_info =
+        unpacked_booleans.cast<py::buffer>().request();
 
     MlirType shapedType = mlirAttributeGetType(*this);
-    return bufferInfo<bool>(shapedType, "?", (bool*)buffer_info.ptr);
+    return bufferInfo<bool>(shapedType, (bool *)buffer_info.ptr, "?");
   }
 
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
-                             const char *explicitFormat = nullptr,
-                             Type *dataOverride = nullptr) {
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+                             const char *explicitFormat = nullptr) {
     // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access below.
+    // Buffer is configured for read-only access in .
     Type *data = static_cast<Type *>(
-        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    if (dataOverride != nullptr) {
-      data = dataOverride;
-    }
+      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    return bufferInfo<Type>(shapedType, data, explicitFormat);
+  }
+
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType,
+                             Type *data,
+                             const char *explicitFormat = nullptr) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the shape for the buffer_info.
     SmallVector<intptr_t, 4> shape;
     for (intptr_t i = 0; i < rank; ++i)
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index b8c764c030bc6a..7df078e7d27e0f 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -14,9 +14,9 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
 
+#include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-#include <pybind11/numpy.h>
 
 namespace mlir {
 namespace python {

>From 90868b8ce78c3d344b5fb84eb7fa3efda4cbd18f Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 23 Oct 2024 15:37:30 +0200
Subject: [PATCH 08/12] Fix styles

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 7b66adc5c7488f..c5692b29bd6fc8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1092,13 +1092,12 @@ class PyDenseElementsAttribute
     // Prepare the data for the buffer_info.
     // Buffer is configured for read-only access in .
     Type *data = static_cast<Type *>(
-      const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
     return bufferInfo<Type>(shapedType, data, explicitFormat);
   }
 
   template <typename Type>
-  py::buffer_info bufferInfo(MlirType shapedType,
-                             Type *data,
+  py::buffer_info bufferInfo(MlirType shapedType, Type *data,
                              const char *explicitFormat = nullptr) {
     intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the shape for the buffer_info.

>From d216d43250ecc742f59436a9c12cfabbe26154cd Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 23 Oct 2024 15:50:27 +0200
Subject: [PATCH 09/12] Minor rename

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c5692b29bd6fc8..c3d22259cab5f6 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1053,11 +1053,11 @@ class PyDenseElementsAttribute
 
     uint8_t *unpackedData = static_cast<uint8_t *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(intermediateAttr)));
-    py::array_t<uint8_t> arr(view.len, unpackedData);
+    py::array_t<uint8_t> unpackedArray(view.len, unpackedData);
 
     py::module numpy = py::module::import("numpy");
     py::object packbits_func = numpy.attr("packbits");
-    py::object packed_booleans = packbits_func(arr, "bitorder"_a = "little");
+    py::object packed_booleans = packbits_func(unpackedArray, "bitorder"_a = "little");
     py::buffer_info buffer_info = packed_booleans.cast<py::buffer>().request();
 
     MlirType bitpackedType =
@@ -1073,12 +1073,12 @@ class PyDenseElementsAttribute
     int64_t numBitpackedBytes = (numBooleans + 7) / 8;
     uint8_t *bitpackedData = static_cast<uint8_t *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
-    py::array_t<uint8_t> arr(numBitpackedBytes, bitpackedData);
+    py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
 
     py::module numpy = py::module::import("numpy");
     py::object unpackbits_func = numpy.attr("unpackbits");
     py::object unpacked_booleans =
-        unpackbits_func(arr, "bitorder"_a = "little");
+        unpackbits_func(packedArray, "bitorder"_a = "little");
     py::buffer_info buffer_info =
         unpacked_booleans.cast<py::buffer>().request();
 

>From 6543732fe9e5299d02153c67e89d5ff5640d82db Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 23 Oct 2024 15:50:44 +0200
Subject: [PATCH 10/12] Code format

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c3d22259cab5f6..2ca0765f904b32 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1057,7 +1057,8 @@ class PyDenseElementsAttribute
 
     py::module numpy = py::module::import("numpy");
     py::object packbits_func = numpy.attr("packbits");
-    py::object packed_booleans = packbits_func(unpackedArray, "bitorder"_a = "little");
+    py::object packed_booleans =
+        packbits_func(unpackedArray, "bitorder"_a = "little");
     py::buffer_info buffer_info = packed_booleans.cast<py::buffer>().request();
 
     MlirType bitpackedType =

>From 75c8264724a53c88e49bc949b47d7093b5781642 Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Wed, 23 Oct 2024 22:23:55 +0200
Subject: [PATCH 11/12] Address comments

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 2ca0765f904b32..0990f878a98399 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1038,9 +1038,9 @@ class PyDenseElementsAttribute
     return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
   }
 
-  // There is a complication for boolean numpy arrays, as numpy represent them
-  // as 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
-  // This function does the bit-packing respecting endianess.
+  // There is a complication for boolean numpy arrays, as numpy represents them
+  // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
+  // per byte. This function does the bit-packing respecting endianness.
   static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
       Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
       MlirContext &context) {
@@ -1091,7 +1091,7 @@ class PyDenseElementsAttribute
   py::buffer_info bufferInfo(MlirType shapedType,
                              const char *explicitFormat = nullptr) {
     // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access in .
+    // Buffer is configured for read-only access inside the `bufferInfo` call.
     Type *data = static_cast<Type *>(
         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
     return bufferInfo<Type>(shapedType, data, explicitFormat);

>From e5b10a3336e715a24fa8596142548e034e77b04d Mon Sep 17 00:00:00 2001
From: Kasper Nielsen <kasper0406 at gmail.com>
Date: Tue, 29 Oct 2024 16:16:31 +0100
Subject: [PATCH 12/12] Fix nits

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 11 ++++++-----
 mlir/lib/Bindings/Python/PybindUtils.h    |  1 -
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 0990f878a98399..a7a06e816ff132 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -13,6 +13,7 @@
 #include "IRModule.h"
 
 #include "PybindUtils.h"
+#include <pybind11/numpy.h>
 
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/raw_ostream.h"
@@ -1059,12 +1060,12 @@ class PyDenseElementsAttribute
     py::object packbits_func = numpy.attr("packbits");
     py::object packed_booleans =
         packbits_func(unpackedArray, "bitorder"_a = "little");
-    py::buffer_info buffer_info = packed_booleans.cast<py::buffer>().request();
+    py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
 
     MlirType bitpackedType =
         getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
-    return mlirDenseElementsAttrRawBufferGet(bitpackedType, buffer_info.size,
-                                             buffer_info.ptr);
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
+                                             pythonBuffer.ptr);
   }
 
   // This does the opposite transformation of
@@ -1080,11 +1081,11 @@ class PyDenseElementsAttribute
     py::object unpackbits_func = numpy.attr("unpackbits");
     py::object unpacked_booleans =
         unpackbits_func(packedArray, "bitorder"_a = "little");
-    py::buffer_info buffer_info =
+    py::buffer_info pythonBuffer =
         unpacked_booleans.cast<py::buffer>().request();
 
     MlirType shapedType = mlirAttributeGetType(*this);
-    return bufferInfo<bool>(shapedType, (bool *)buffer_info.ptr, "?");
+    return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
   }
 
   template <typename Type>
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 7df078e7d27e0f..38462ac8ba6db9 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -14,7 +14,6 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/DataTypes.h"
 
-#include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 



More information about the Mlir-commits mailing list