[Mlir-commits] [mlir] [MLIR, Python] Support converting boolean numpy arrays to and from mlir attributes (PR #113064)

Kasper Nielsen llvmlistbot at llvm.org
Wed Oct 23 08:04:03 PDT 2024


================
@@ -757,103 +757,10 @@ class PyDenseElementsAttribute
       throw py::error_already_set();
     }
     auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
-    SmallVector<int64_t> shape;
-    if (explicitShape) {
-      shape.append(explicitShape->begin(), explicitShape->end());
-    } else {
-      shape.append(view.shape, view.shape + view.ndim);
-    }
 
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
     MlirContext context = contextWrapper->get();
-
-    // Detect format codes that are suitable for bulk loading. This includes
-    // all byte aligned integer and floating point types up to 8 bytes.
-    // Notably, this excludes, bool (which needs to be bit-packed) and
-    // other exotics which do not have a direct representation in the buffer
-    // protocol (i.e. complex, etc).
-    std::optional<MlirType> bulkLoadElementType;
-    if (explicitType) {
-      bulkLoadElementType = *explicitType;
-    } else {
-      std::string_view format(view.format);
-      if (format == "f") {
-        // f32
-        assert(view.itemsize == 4 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF32TypeGet(context);
-      } else if (format == "d") {
-        // f64
-        assert(view.itemsize == 8 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF64TypeGet(context);
-      } else if (format == "e") {
-        // f16
-        assert(view.itemsize == 2 && "mismatched array itemsize");
-        bulkLoadElementType = mlirF16TypeGet(context);
-      } else if (isSignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeSignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeSignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                         : mlirIntegerTypeSignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeSignedGet(context, 16);
-        }
-      } else if (isUnsignedIntegerFormat(format)) {
-        if (view.itemsize == 4) {
-          // unsigned i32
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 32)
-                                    : mlirIntegerTypeUnsignedGet(context, 32);
-        } else if (view.itemsize == 8) {
-          // unsigned i64
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 64)
-                                    : mlirIntegerTypeUnsignedGet(context, 64);
-        } else if (view.itemsize == 1) {
-          // i8
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 8)
-                                    : mlirIntegerTypeUnsignedGet(context, 8);
-        } else if (view.itemsize == 2) {
-          // i16
-          bulkLoadElementType = signless
-                                    ? mlirIntegerTypeGet(context, 16)
-                                    : mlirIntegerTypeUnsignedGet(context, 16);
-        }
-      }
-      if (!bulkLoadElementType) {
-        throw std::invalid_argument(
-            std::string("unimplemented array format conversion from format: ") +
-            std::string(format));
-      }
-    }
-
-    MlirType shapedType;
-    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-      if (explicitShape) {
-        throw std::invalid_argument("Shape can only be specified explicitly "
-                                    "when the type is not a shaped type.");
-      }
-      shapedType = *bulkLoadElementType;
-    } else {
-      shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                           *bulkLoadElementType, encodingAttr);
-    }
-    size_t rawBufferSize = view.len;
-    MlirAttribute attr =
-        mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+    MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
----------------
kasper0406 wrote:

I had to move the if-statements into the `getAttributeFromBuffer` method because the i1 case no longer follows the usual flow; it instead calls `getBitpackedAttributeFromBooleanBuffer` to construct the `MlirAttribute`.

https://github.com/llvm/llvm-project/pull/113064


More information about the Mlir-commits mailing list