[Mlir-commits] [mlir] 71a2545 - [MLIR:Python] Make DenseElementsAttr.get() only request a buffer format if no explicit type was provided.

Jacques Pienaar llvmlistbot at llvm.org
Fri Jul 14 16:08:22 PDT 2023


Author: Peter Hawkins
Date: 2023-07-14T16:08:15-07:00
New Revision: 71a254543d44a943dfe8790abc60795b87173f0b

URL: https://github.com/llvm/llvm-project/commit/71a254543d44a943dfe8790abc60795b87173f0b
DIFF: https://github.com/llvm/llvm-project/commit/71a254543d44a943dfe8790abc60795b87173f0b.diff

LOG: [MLIR:Python] Make DenseElementsAttr.get() only request a buffer format if no explicit type was provided.

Not every NumPy type (e.g., the `ml_dtypes.bfloat16` NumPy extension
type) has a type in the Python buffer protocol, so exporting such a
buffer with `PyBUF_FORMAT` may fail.

However, we don't care about the self-reported type of a buffer if the
user provides an explicit type. In the case that an explicit type is
provided, don't request the format from the buffer protocol, which
allows arrays whose element types are unknown to the buffer protocol to
be passed.

Reviewed By: jpienaar, ftynse

Differential Revision: https://reviews.llvm.org/D155209

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/test/python/ir/array_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 84a48a890eb409..75d743f3a3962a 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -7,12 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
+#include <string_view>
 #include <utility>
 
 #include "IRModule.h"
 
 #include "PybindUtils.h"
 
+#include "llvm/ADT/ScopeExit.h"
+
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -612,19 +615,20 @@ class PyDenseElementsAttribute
                 std::optional<std::vector<int64_t>> explicitShape,
                 DefaultingPyMlirContext contextWrapper) {
     // Request a contiguous view. In exotic cases, this will cause a copy.
-    int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
-    Py_buffer *view = new Py_buffer();
-    if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
-      delete view;
+    int flags = PyBUF_ND;
+    if (!explicitType) {
+      flags |= PyBUF_FORMAT;
+    }
+    Py_buffer view;
+    if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
       throw py::error_already_set();
     }
-    py::buffer_info arrayInfo(view);
+    auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
     SmallVector<int64_t> shape;
     if (explicitShape) {
       shape.append(explicitShape->begin(), explicitShape->end());
     } else {
-      shape.append(arrayInfo.shape.begin(),
-                   arrayInfo.shape.begin() + arrayInfo.ndim);
+      shape.append(view.shape, view.shape + view.ndim);
     }
 
     MlirAttribute encodingAttr = mlirAttributeGetNull();
@@ -638,85 +642,92 @@ class PyDenseElementsAttribute
     std::optional<MlirType> bulkLoadElementType;
     if (explicitType) {
       bulkLoadElementType = *explicitType;
-    } else if (arrayInfo.format == "f") {
-      // f32
-      assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      bulkLoadElementType = mlirF32TypeGet(context);
-    } else if (arrayInfo.format == "d") {
-      // f64
-      assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      bulkLoadElementType = mlirF64TypeGet(context);
-    } else if (arrayInfo.format == "e") {
-      // f16
-      assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
-      bulkLoadElementType = mlirF16TypeGet(context);
-    } else if (isSignedIntegerFormat(arrayInfo.format)) {
-      if (arrayInfo.itemsize == 4) {
-        // i32
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
-                                       : mlirIntegerTypeSignedGet(context, 32);
-      } else if (arrayInfo.itemsize == 8) {
-        // i64
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
-                                       : mlirIntegerTypeSignedGet(context, 64);
-      } else if (arrayInfo.itemsize == 1) {
-        // i8
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                       : mlirIntegerTypeSignedGet(context, 8);
-      } else if (arrayInfo.itemsize == 2) {
-        // i16
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
-                                       : mlirIntegerTypeSignedGet(context, 16);
-      }
-    } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
-      if (arrayInfo.itemsize == 4) {
-        // unsigned i32
-        bulkLoadElementType = signless
-                                  ? mlirIntegerTypeGet(context, 32)
-                                  : mlirIntegerTypeUnsignedGet(context, 32);
-      } else if (arrayInfo.itemsize == 8) {
-        // unsigned i64
-        bulkLoadElementType = signless
-                                  ? mlirIntegerTypeGet(context, 64)
-                                  : mlirIntegerTypeUnsignedGet(context, 64);
-      } else if (arrayInfo.itemsize == 1) {
-        // i8
-        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
-                                       : mlirIntegerTypeUnsignedGet(context, 8);
-      } else if (arrayInfo.itemsize == 2) {
-        // i16
-        bulkLoadElementType = signless
-                                  ? mlirIntegerTypeGet(context, 16)
-                                  : mlirIntegerTypeUnsignedGet(context, 16);
-      }
-    }
-    if (bulkLoadElementType) {
-      MlirType shapedType;
-      if (mlirTypeIsAShaped(*bulkLoadElementType)) {
-        if (explicitShape) {
-          throw std::invalid_argument("Shape can only be specified explicitly "
-                                      "when the type is not a shaped type.");
+    } else {
+      std::string_view format(view.format);
+      if (format == "f") {
+        // f32
+        assert(view.itemsize == 4 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF32TypeGet(context);
+      } else if (format == "d") {
+        // f64
+        assert(view.itemsize == 8 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF64TypeGet(context);
+      } else if (format == "e") {
+        // f16
+        assert(view.itemsize == 2 && "mismatched array itemsize");
+        bulkLoadElementType = mlirF16TypeGet(context);
+      } else if (isSignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeSignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeSignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                         : mlirIntegerTypeSignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeSignedGet(context, 16);
+        }
+      } else if (isUnsignedIntegerFormat(format)) {
+        if (view.itemsize == 4) {
+          // unsigned i32
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 32)
+                                    : mlirIntegerTypeUnsignedGet(context, 32);
+        } else if (view.itemsize == 8) {
+          // unsigned i64
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 64)
+                                    : mlirIntegerTypeUnsignedGet(context, 64);
+        } else if (view.itemsize == 1) {
+          // i8
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 8)
+                                    : mlirIntegerTypeUnsignedGet(context, 8);
+        } else if (view.itemsize == 2) {
+          // i16
+          bulkLoadElementType = signless
+                                    ? mlirIntegerTypeGet(context, 16)
+                                    : mlirIntegerTypeUnsignedGet(context, 16);
         }
-        shapedType = *bulkLoadElementType;
-      } else {
-        shapedType = mlirRankedTensorTypeGet(
-            shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
       }
-      size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
-      MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
-          shapedType, rawBufferSize, arrayInfo.ptr);
-      if (mlirAttributeIsNull(attr)) {
+      if (!bulkLoadElementType) {
         throw std::invalid_argument(
-            "DenseElementsAttr could not be constructed from the given buffer. "
-            "This may mean that the Python buffer layout does not match that "
-            "MLIR expected layout and is a bug.");
+            std::string("unimplemented array format conversion from format: ") +
+            std::string(format));
       }
-      return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
     }
 
-    throw std::invalid_argument(
-        std::string("unimplemented array format conversion from format: ") +
-        arrayInfo.format);
+    MlirType shapedType;
+    if (mlirTypeIsAShaped(*bulkLoadElementType)) {
+      if (explicitShape) {
+        throw std::invalid_argument("Shape can only be specified explicitly "
+                                    "when the type is not a shaped type.");
+      }
+      shapedType = *bulkLoadElementType;
+    } else {
+      shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                           *bulkLoadElementType, encodingAttr);
+    }
+    size_t rawBufferSize = view.len;
+    MlirAttribute attr =
+        mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
+    if (mlirAttributeIsNull(attr)) {
+      throw std::invalid_argument(
+          "DenseElementsAttr could not be constructed from the given buffer. "
+          "This may mean that the Python buffer layout does not match that "
+          "MLIR expected layout and is a bug.");
+    }
+    return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
   }
 
   static PyDenseElementsAttribute getSplat(const PyType &shapedType,
@@ -852,7 +863,7 @@ class PyDenseElementsAttribute
   }
 
 private:
-  static bool isUnsignedIntegerFormat(const std::string &format) {
+  static bool isUnsignedIntegerFormat(std::string_view format) {
     if (format.empty())
       return false;
     char code = format[0];
@@ -860,7 +871,7 @@ class PyDenseElementsAttribute
            code == 'Q';
   }
 
-  static bool isSignedIntegerFormat(const std::string &format) {
+  static bool isSignedIntegerFormat(std::string_view format) {
     if (format.empty())
       return false;
     char code = format[0];

diff  --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index b592804013b545..452d860861d783 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -30,6 +30,24 @@ def testGetDenseElementsUnsupported():
             # CHECK: unimplemented array format conversion from format:
             print(e)
 
+# CHECK-LABEL: TEST: testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided
+@run
+def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
+    with Context():
+        array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+        # datetime64 specifically isn't important: it's just a 64-bit type that
+        # doesn't have a format under the Python buffer protocol. A more
+        # realistic example would be a NumPy extension type like the bfloat16
+        # type from the ml_dtypes package, which isn't a dependency of this
+        # test.
+        attr = DenseElementsAttr.get(array.view(np.datetime64),
+                                     type=IntegerType.get_signless(64))
+        # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
+        print(attr)
+        # CHECK: {{\[}}[1 2 3]
+        # CHECK: {{\[}}4 5 6]]
+        print(np.array(attr))
+
 
 ################################################################################
 # Splats.


        


More information about the Mlir-commits mailing list