[Mlir-commits] [mlir] [MLIR, Python] Support converting boolean numpy arrays to and from mlir attributes (PR #113064)

Kasper Nielsen llvmlistbot at llvm.org
Wed Oct 23 13:28:33 PDT 2024


================
@@ -1016,14 +930,177 @@ class PyDenseElementsAttribute
            code == 'q';
   }
 
+  static MlirType
+  getShapedType(std::optional<MlirType> bulkLoadElementType,
+                std::optional<std::vector<int64_t>> explicitShape,
+                Py_buffer &view) {
+    SmallVector<int64_t> shape;
+    if (explicitShape) {
+      shape.append(explicitShape->begin(), explicitShape->end());
+    } else {
+      shape.append(view.shape, view.shape + view.ndim);
+    }
+
+    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+      if (explicitShape) {
+        throw std::invalid_argument("Shape can only be specified explicitly "
+                                    "when the type is not a shaped type.");
+      }
+      return *bulkLoadElementType;
+    } else {
+      MlirAttribute encodingAttr = mlirAttributeGetNull();
+      return mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                     *bulkLoadElementType, encodingAttr);
+    }
+  }
+
+  static MlirAttribute getAttributeFromBuffer(
+      Py_buffer &view, bool signless, std::optional<PyType> explicitType,
+      std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
+    // Detect format codes that are suitable for bulk loading. This includes
+    // all byte aligned integer and floating point types up to 8 bytes.
+    // Notably, this excludes exotic types which do not have a direct
+    // representation in the buffer protocol (e.g. complex).
+    std::optional<MlirType> bulkLoadElementType;
+    if (explicitType) {
+      bulkLoadElementType = *explicitType;
+    } else {
+      std::string_view format(view.format);
+      if (format == "f") {
+        // f32
+        assert(view.itemsize == 4 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF32TypeGet(context);
+      } else if (format == "d") {
+        // f64
+        assert(view.itemsize == 8 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF64TypeGet(context);
+      } else if (format == "e") {
+        // f16
+        assert(view.itemsize == 2 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF16TypeGet(context);
+      } else if (format == "?") {
+        // i1
+        // The i1 type needs to be bit-packed, so we will handle it separately
+        return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
+                                                      context);
+      } else if (isSignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeSignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeSignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                         : mlirIntegerTypeSignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeSignedGet(context, 16);
+        }
+      } else if (isUnsignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // unsigned i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeUnsignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // unsigned i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeUnsignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // unsigned i8
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 8)
+                                    : mlirIntegerTypeUnsignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // unsigned i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeUnsignedGet(context, 16);
+        }
+      }
+      if (!bulkLoadElementType) {
+        throw std::invalid_argument(
+            std::string("unimplemented array format conversion from format: ") +
+            std::string(format));
+      }
+    }
+
+    MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
+  }
+
+  // There is a complication for boolean numpy arrays, as numpy represents them
+  // as 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
+  // This function does the bit-packing respecting endianness.
+  static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
+      Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
+      MlirContext &context) {
+    // First read the content of the python buffer as u8's, to correct for
+    // endianness
+    MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8),
+                                      explicitShape, view);
+    MlirAttribute intermediateAttr =
+        mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);
+
+    uint8_t *unpackedData = static_cast<uint8_t *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(intermediateAttr)));
+    py::array_t<uint8_t> unpackedArray(view.len, unpackedData);
+
+    py::module numpy = py::module::import("numpy");
+    py::object packbits_func = numpy.attr("packbits");
+    py::object packed_booleans =
+        packbits_func(unpackedArray, "bitorder"_a = "little");
+    py::buffer_info buffer_info = packed_booleans.cast<py::buffer>().request();
+
+    MlirType bitpackedType =
+        getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
+    return mlirDenseElementsAttrRawBufferGet(bitpackedType, buffer_info.size,
+                                             buffer_info.ptr);
+  }
+
+  // This does the opposite transformation of
+  // `getBitpackedAttributeFromBooleanBuffer`
+  py::buffer_info getBooleanBufferFromBitpackedAttribute() {
+    int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
+    int64_t numBitpackedBytes = (numBooleans + 7) / 8;
+    uint8_t *bitpackedData = static_cast<uint8_t *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
+    py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
+
+    py::module numpy = py::module::import("numpy");
+    py::object unpackbits_func = numpy.attr("unpackbits");
+    py::object unpacked_booleans =
+        unpackbits_func(packedArray, "bitorder"_a = "little");
+    py::buffer_info buffer_info =
+        unpacked_booleans.cast<py::buffer>().request();
+
+    MlirType shapedType = mlirAttributeGetType(*this);
+    return bufferInfo<bool>(shapedType, (bool *)buffer_info.ptr, "?");
+  }
+
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
                              const char *explicitFormat = nullptr) {
-    intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the data for the buffer_info.
-    // Buffer is configured for read-only access below.
+    // Buffer is configured for read-only access in .
----------------
kasper0406 wrote:

Typo

https://github.com/llvm/llvm-project/pull/113064


More information about the Mlir-commits mailing list