[Mlir-commits] [mlir] 5d6d30e - [mlir] Extend C and Python API to support bulk loading of DenseElementsAttr.

Stella Laurenzo llvmlistbot at llvm.org
Thu Oct 7 08:44:56 PDT 2021


Author: Stella Laurenzo
Date: 2021-10-07T08:42:12-07:00
New Revision: 5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff

URL: https://github.com/llvm/llvm-project/commit/5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff
DIFF: https://github.com/llvm/llvm-project/commit/5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff.diff

LOG: [mlir] Extend C and Python API to support bulk loading of DenseElementsAttr.

* This already half existed in terms of reading the raw buffer backing a DenseElementsAttr.
* Documented the precise expectations of the buffer layout.
* Extended the Python API to support construction from bitcasted buffers, allowing construction of all primitive element types (even those that lack a compatible representation in Python).
* Specifically, the Python API can now load all integer types at all bit widths and all floating point types (f16, f32, f64, bf16).

Differential Revision: https://reviews.llvm.org/D111284

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/python/ir/array_attributes.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 247de5cc0bd62..5839cd3d2408a 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -306,6 +306,23 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
     MlirType shapedType, intptr_t numElements, MlirAttribute const *elements);
 
+/// Creates a dense elements attribute with the given Shaped type and elements
+/// populated from a packed, row-major opaque buffer of contents.
+///
+/// The format of the raw buffer is a densely packed array of values that
+/// can be bitcast to the storage format of the element type specified.
+/// Types that are not byte aligned will be:
+///   - For bitwidth > 1: Rounded up to the next byte.
+///   - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to
+///     the linear order of the shape type from MSB to LSB, padded to on the
+///     right.
+///
+/// A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255)
+/// will be interpreted as a splat. User code should be prepared for additional,
+/// conformant patterns to be identified as splats in the future.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet(
+    MlirType shapedType, size_t rawBufferSize, const void *rawBuffer);
+
 /// Creates a dense elements attribute with the given Shaped type containing a
 /// single replicated element (splat).
 MLIR_CAPI_EXPORTED MlirAttribute

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index bc7beee0dabb7..cc1d1d74615ea 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -183,15 +183,35 @@ class DenseElementsAttr : public Attribute {
   }
 
   /// Construct a dense elements attribute from a raw buffer representing the
-  /// data for this attribute. Users should generally not use this methods as
-  /// the expected buffer format may not be a form the user expects.
+  /// data for this attribute. Users are encouraged to use one of the
+  /// constructors above, which provide more safeties. However, this
+  /// constructor is useful for tools which may want to interop and can
+  /// follow the precise definition.
+  ///
+  /// The format of the raw buffer is a densely packed array of values that
+  /// can be bitcast to the storage format of the element type specified.
+  /// Types that are not byte aligned will be:
+  ///   - For bitwidth > 1: Rounded up to the next byte.
+  ///   - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to
+  ///     the linear order of the shape type from MSB to LSB, padded to on the
+  ///     right.
+  ///
+  /// If `isSplatBuffer` is true, then the raw buffer should contain a
+  /// single element (or for the case of 1-bit, a single byte of 0 or 255),
+  /// which will be used to construct a splat.
   static DenseElementsAttr getFromRawBuffer(ShapedType type,
                                             ArrayRef<char> rawBuffer,
                                             bool isSplatBuffer);
 
   /// Returns true if the given buffer is a valid raw buffer for the given type.
   /// `detectedSplat` is set if the buffer is valid and represents a splat
-  /// buffer.
+  /// buffer. The definition may be expanded over time, but currently, a
+  /// splat buffer is detected if:
+  ///   - For >1bit: The buffer consists of a single element.
+  ///   - For 1bit: The buffer consists of a single byte with value 0 or 255.
+  ///
+  /// User code should be prepared for additional, conformant patterns to be
+  /// identified as splats in the future.
   static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
                                bool &detectedSplat);
 

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 2ff75ceedcf2e..47f73ecae4784 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -17,9 +17,57 @@ namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
 
+using llvm::None;
+using llvm::Optional;
 using llvm::SmallVector;
 using llvm::Twine;
 
+//------------------------------------------------------------------------------
+// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------
+
+static const char kDenseElementsAttrGetDocstring[] =
+    R"(Gets a DenseElementsAttr from a Python buffer or array.
+
+When `type` is not provided, then some limited type inferencing is done based
+on the buffer format. Support presently exists for 8/16/32/64 signed and
+unsigned integers and float16/float32/float64. DenseElementsAttrs of these
+types can also be converted back to a corresponding buffer.
+
+For conversions outside of these types, a `type=` must be explicitly provided
+and the buffer contents must be bit-castable to the MLIR internal
+representation:
+
+  * Integer types (except for i1): the buffer must be byte aligned to the
+    next byte boundary.
+  * Floating point types: Must be bit-castable to the given floating point
+    size.
+  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
+    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
+    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
+
+If a single element buffer is passed (or for i1, a single byte with value 0
+or 255), then a splat will be created.
+
+Args:
+  array: The array or buffer to convert.
+  signless: If inferring an appropriate MLIR type, use signless types for
+    integers (defaults True).
+  type: Skips inference of the MLIR element type and uses this instead. The
+    storage size must be consistent with the actual contents of the buffer.
+  shape: Overrides the shape of the buffer when constructing the MLIR
+    shaped type. This is needed when the physical and logical shape differ (as
+    for i1).
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseElementsAttr on success.
+
+Raises:
+  ValueError: If the type of the buffer or array cannot be matched to an MLIR
+    type or if the buffer does not meet expectations.
+)";
+
 namespace {
 
 static MlirStringRef toMlirStringRef(const std::string &s) {
@@ -301,7 +349,6 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
   }
 };
 
-// TODO: Support construction of bool elements.
 // TODO: Support construction of string elements.
 class PyDenseElementsAttribute
     : public PyConcreteAttribute<PyDenseElementsAttribute> {
@@ -311,7 +358,8 @@ class PyDenseElementsAttribute
   using PyConcreteAttribute::PyConcreteAttribute;
 
   static PyDenseElementsAttribute
-  getFromBuffer(py::buffer array, bool signless,
+  getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
+                Optional<std::vector<int64_t>> explicitShape,
                 DefaultingPyMlirContext contextWrapper) {
     // Request a contiguous view. In exotic cases, this will cause a copy.
     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
@@ -321,69 +369,95 @@ class PyDenseElementsAttribute
       throw py::error_already_set();
     }
     py::buffer_info arrayInfo(view);
+    SmallVector<int64_t> shape;
+    if (explicitShape) {
+      shape.append(explicitShape->begin(), explicitShape->end());
+    } else {
+      shape.append(arrayInfo.shape.begin(),
+                   arrayInfo.shape.begin() + arrayInfo.ndim);
+    }
 
+    MlirAttribute encodingAttr = mlirAttributeGetNull();
     MlirContext context = contextWrapper->get();
-    // Switch on the types that can be bulk loaded between the Python and
-    // MLIR-C APIs.
-    // See: https://docs.python.org/3/library/struct.html#format-characters
-    if (arrayInfo.format == "f") {
+
+    // Detect format codes that are suitable for bulk loading. This includes
+    // all byte aligned integer and floating point types up to 8 bytes.
+    // Notably, this excludes, bool (which needs to be bit-packed) and
+    // other exotics which do not have a direct representation in the buffer
+    // protocol (i.e. complex, etc).
+    Optional<MlirType> bulkLoadElementType;
+    if (explicitType) {
+      bulkLoadElementType = *explicitType;
+    } else if (arrayInfo.format == "f") {
       // f32
       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
-      return PyDenseElementsAttribute(
-          contextWrapper->getRef(),
-          bulkLoad(context, mlirDenseElementsAttrFloatGet,
-                   mlirF32TypeGet(context), arrayInfo));
+      bulkLoadElementType = mlirF32TypeGet(context);
     } else if (arrayInfo.format == "d") {
       // f64
       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
-      return PyDenseElementsAttribute(
-          contextWrapper->getRef(),
-          bulkLoad(context, mlirDenseElementsAttrDoubleGet,
-                   mlirF64TypeGet(context), arrayInfo));
+      bulkLoadElementType = mlirF64TypeGet(context);
+    } else if (arrayInfo.format == "e") {
+      // f16
+      assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
+      bulkLoadElementType = mlirF16TypeGet(context);
     } else if (isSignedIntegerFormat(arrayInfo.format)) {
       if (arrayInfo.itemsize == 4) {
         // i32
-        MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
-                                        : mlirIntegerTypeSignedGet(context, 32);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrInt32Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+                                       : mlirIntegerTypeSignedGet(context, 32);
       } else if (arrayInfo.itemsize == 8) {
         // i64
-        MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
-                                        : mlirIntegerTypeSignedGet(context, 64);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrInt64Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+                                       : mlirIntegerTypeSignedGet(context, 64);
+      } else if (arrayInfo.itemsize == 1) {
+        // i8
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                       : mlirIntegerTypeSignedGet(context, 8);
+      } else if (arrayInfo.itemsize == 2) {
+        // i16
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+                                       : mlirIntegerTypeSignedGet(context, 16);
       }
     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
       if (arrayInfo.itemsize == 4) {
         // unsigned i32
-        MlirType elementType = signless
-                                   ? mlirIntegerTypeGet(context, 32)
-                                   : mlirIntegerTypeUnsignedGet(context, 32);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrUInt32Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 32)
+                                  : mlirIntegerTypeUnsignedGet(context, 32);
       } else if (arrayInfo.itemsize == 8) {
         // unsigned i64
-        MlirType elementType = signless
-                                   ? mlirIntegerTypeGet(context, 64)
-                                   : mlirIntegerTypeUnsignedGet(context, 64);
-        return PyDenseElementsAttribute(contextWrapper->getRef(),
-                                        bulkLoad(context,
-                                                 mlirDenseElementsAttrUInt64Get,
-                                                 elementType, arrayInfo));
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 64)
+                                  : mlirIntegerTypeUnsignedGet(context, 64);
+      } else if (arrayInfo.itemsize == 1) {
+        // i8
+        bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+                                       : mlirIntegerTypeUnsignedGet(context, 8);
+      } else if (arrayInfo.itemsize == 2) {
+        // i16
+        bulkLoadElementType = signless
+                                  ? mlirIntegerTypeGet(context, 16)
+                                  : mlirIntegerTypeUnsignedGet(context, 16);
       }
     }
+    if (bulkLoadElementType) {
+      auto shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+      size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
+      MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
+          shapedType, rawBufferSize, arrayInfo.ptr);
+      if (mlirAttributeIsNull(attr)) {
+        throw std::invalid_argument(
+            "DenseElementsAttr could not be constructed from the given buffer. "
+            "This may mean that the Python buffer layout does not match that "
+            "MLIR expected layout and is a bug.");
+      }
+      return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+    }
 
-    // TODO: Fall back to string-based get.
-    std::string message = "unimplemented array format conversion from format: ";
-    message.append(arrayInfo.format);
-    throw SetPyError(PyExc_ValueError, message);
+    throw std::invalid_argument(
+        std::string("unimplemented array format conversion from format: ") +
+        arrayInfo.format);
   }
 
   static PyDenseElementsAttribute getSplat(PyType shapedType,
@@ -422,47 +496,82 @@ class PyDenseElementsAttribute
   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
 
   py::buffer_info accessBuffer() {
+    if (mlirDenseElementsAttrIsSplat(*this)) {
+      // TODO: Raise an exception.
+      // Reported as https://github.com/pybind/pybind11/issues/3336
+      return py::buffer_info();
+    }
+
     MlirType shapedType = mlirAttributeGetType(*this);
     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+    std::string format;
 
     if (mlirTypeIsAF32(elementType)) {
       // f32
-      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+      return bufferInfo<float>(shapedType);
     } else if (mlirTypeIsAF64(elementType)) {
       // f64
-      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+      return bufferInfo<double>(shapedType);
+    } else if (mlirTypeIsAF16(elementType)) {
+      // f16
+      return bufferInfo<uint16_t>(shapedType, "e");
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 32) {
       if (mlirIntegerTypeIsSignless(elementType) ||
           mlirIntegerTypeIsSigned(elementType)) {
         // i32
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+        return bufferInfo<int32_t>(shapedType);
       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
         // unsigned i32
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+        return bufferInfo<uint32_t>(shapedType);
       }
     } else if (mlirTypeIsAInteger(elementType) &&
                mlirIntegerTypeGetWidth(elementType) == 64) {
       if (mlirIntegerTypeIsSignless(elementType) ||
           mlirIntegerTypeIsSigned(elementType)) {
         // i64
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+        return bufferInfo<int64_t>(shapedType);
       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
         // unsigned i64
-        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+        return bufferInfo<uint64_t>(shapedType);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 8) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i8
+        return bufferInfo<int8_t>(shapedType);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i8
+        return bufferInfo<uint8_t>(shapedType);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 16) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i16
+        return bufferInfo<int16_t>(shapedType);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i16
+        return bufferInfo<uint16_t>(shapedType);
       }
     }
 
-    std::string message = "unimplemented array format.";
-    throw SetPyError(PyExc_ValueError, message);
+    // TODO: Currently crashes the program. Just returning an empty buffer
+    // for now.
+    // Reported as https://github.com/pybind/pybind11/issues/3336
+    // throw std::invalid_argument(
+    //     "unsupported data type for conversion to Python buffer");
+    return py::buffer_info();
   }
 
   static void bindDerived(ClassTy &c) {
     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
                     py::arg("array"), py::arg("signless") = true,
+                    py::arg("type") = py::none(), py::arg("shape") = py::none(),
                     py::arg("context") = py::none(),
-                    "Gets from a buffer or ndarray")
+                    kDenseElementsAttrGetDocstring)
         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                     py::arg("shaped_type"), py::arg("element_attr"),
                     "Gets a DenseElementsAttr where all values are the same")
@@ -474,21 +583,6 @@ class PyDenseElementsAttribute
   }
 
 private:
-  template <typename ElementTy>
-  static MlirAttribute
-  bulkLoad(MlirContext context,
-           MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
-           MlirType mlirElementType, py::buffer_info &arrayInfo) {
-    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
-                                  arrayInfo.shape.begin() + arrayInfo.ndim);
-    MlirAttribute encodingAttr = mlirAttributeGetNull();
-    auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
-                                              mlirElementType, encodingAttr);
-    intptr_t numElements = arrayInfo.size;
-    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
-    return ctor(shapedType, numElements, contents);
-  }
-
   static bool isUnsignedIntegerFormat(const std::string &format) {
     if (format.empty())
       return false;
@@ -507,7 +601,7 @@ class PyDenseElementsAttribute
 
   template <typename Type>
   py::buffer_info bufferInfo(MlirType shapedType,
-                             Type (*value)(MlirAttribute, intptr_t)) {
+                             const char *explicitFormat = nullptr) {
     intptr_t rank = mlirShapedTypeGetRank(shapedType);
     // Prepare the data for the buffer_info.
     // Buffer is configured for read-only access below.
@@ -528,9 +622,14 @@ class PyDenseElementsAttribute
       strides.push_back(sizeof(Type) * strideFactor);
     }
     strides.push_back(sizeof(Type));
-    return py::buffer_info(data, sizeof(Type),
-                           py::format_descriptor<Type>::format(), rank, shape,
-                           strides, /*readonly=*/true);
+    std::string format;
+    if (explicitFormat) {
+      format = explicitFormat;
+    } else {
+      format = py::format_descriptor<Type>::format();
+    }
+    return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
+                           /*readonly=*/true);
   }
 }; // namespace
 

diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index a2ee06722f0d8..3b15212e30023 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -331,6 +331,21 @@ MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
                              unwrapList(numElements, elements, attributes)));
 }
 
+MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
+                                                size_t rawBufferSize,
+                                                const void *rawBuffer) {
+  auto shapedTypeCpp = unwrap(shapedType).cast<ShapedType>();
+  ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
+                              rawBufferSize);
+  bool isSplat = false;
+  if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
+                                           isSplat)) {
+    return mlirAttributeGetNull();
+  }
+  return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp,
+                                                  isSplat));
+}
+
 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
                                             MlirAttribute element) {
   return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),

diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e03bac1d3d5c2..27d851c6594ee 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -792,9 +792,16 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
 
   // Storage width of 1 is special as it is packed by the bit.
   if (storageWidth == 1) {
-    // Check for a splat, or a buffer equal to the number of elements.
-    if ((detectedSplat = rawBuffer.size() == 1))
-      return true;
+    // Check for a splat, or a buffer equal to the number of elements which
+    // consists of either all 0's or all 1's.
+    detectedSplat = false;
+    if (rawBuffer.size() == 1) {
+      auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
+      if (rawByte == 0 || rawByte == 0xff) {
+        detectedSplat = true;
+        return true;
+      }
+    }
     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
   }
   // All other types are 8-bit aligned.

diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 2a904e63ac5f2..13c7c215484b4 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -11,11 +11,13 @@ def run(f):
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 ################################################################################
 # Tests of the array/buffer .get() factory method on unsupported dtype.
 ################################################################################
 
+@run
 def testGetDenseElementsUnsupported():
   with Context():
     array = np.array([["hello", "goodbye"]])
@@ -25,13 +27,12 @@ def testGetDenseElementsUnsupported():
       # CHECK: unimplemented array format conversion from format:
       print(e)
 
-run(testGetDenseElementsUnsupported)
-
 ################################################################################
 # Splats.
 ################################################################################
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatInt
+@run
 def testGetDenseElementsSplatInt():
   with Context(), Location.unknown():
     t = IntegerType.get_signless(32)
@@ -43,10 +44,9 @@ def testGetDenseElementsSplatInt():
     # CHECK: is_splat: True
     print("is_splat:", attr.is_splat)
 
-run(testGetDenseElementsSplatInt)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
+@run
 def testGetDenseElementsSplatFloat():
   with Context(), Location.unknown():
     t = F32Type.get()
@@ -56,10 +56,9 @@ def testGetDenseElementsSplatFloat():
     # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
     print(attr)
 
-run(testGetDenseElementsSplatFloat)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
+@run
 def testGetDenseElementsSplatErrors():
   with Context(), Location.unknown():
     t = F32Type.get()
@@ -88,32 +87,113 @@ def testGetDenseElementsSplatErrors():
       # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
       print(e)
 
-run(testGetDenseElementsSplatErrors)
+
+# CHECK-LABEL: TEST: testRepeatedValuesSplat
+@run
+def testRepeatedValuesSplat():
+  with Context():
+    array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
+    print(attr)
+    # CHECK: is_splat: True
+    print("is_splat:", attr.is_splat)
+    # CHECK: ()
+    print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testNonSplat
+@run
+def testNonSplat():
+  with Context():
+    array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: is_splat: False
+    print("is_splat:", attr.is_splat)
 
 
 ################################################################################
 # Tests of the array/buffer .get() factory method, in all of its permutations.
 ################################################################################
 
+### explicitly provided types
+
+@run
+def testGetDenseElementsBF16():
+  with Context():
+    array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
+    attr = DenseElementsAttr.get(array, type=BF16Type.get())
+    # Note: These values don't mean much since just bit-casting. But they
+    # shouldn't change.
+    # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
+    print(attr)
+
+@run
+def testGetDenseElementsInteger4():
+  with Context():
+    array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
+    attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
+    # Note: These values don't mean much since just bit-casting. But they
+    # shouldn't change.
+    # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
+    print(attr)
+
+
+@run
+def testGetDenseElementsBool():
+  with Context():
+    bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
+    array = np.packbits(bool_array, axis=None, bitorder="little")
+    attr = DenseElementsAttr.get(
+        array, type=IntegerType.get_signless(1), shape=bool_array.shape)
+    # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
+    print(attr)
+
+
+@run
+def testGetDenseElementsBoolSplat():
+  with Context():
+    zero = np.array(0, dtype=np.uint8)
+    one = np.array(255, dtype=np.uint8)
+    print(one)
+    # CHECK: dense<false> : tensor<4x2x5xi1>
+    print(DenseElementsAttr.get(
+        zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
+    # CHECK: dense<true> : tensor<4x2x5xi1>
+    print(DenseElementsAttr.get(
+        one, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
+
+
 ### float and double arrays.
 
+# CHECK-LABEL: TEST: testGetDenseElementsF16
+@run
+def testGetDenseElementsF16():
+  with Context():
+    array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
+    print(attr)
+    # CHECK: {{\[}}[ 2. 4. 8.]
+    # CHECK: {{\[}}16. 32. 64.]]
+    print(np.array(attr))
+
+
 # CHECK-LABEL: TEST: testGetDenseElementsF32
+@run
 def testGetDenseElementsF32():
   with Context():
     array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
     print(attr)
-    # CHECK: is_splat: False
-    print("is_splat:", attr.is_splat)
     # CHECK: {{\[}}[1.1 2.2 3.3]
     # CHECK: {{\[}}4.4 5.5 6.6]]
     print(np.array(attr))
 
-run(testGetDenseElementsF32)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsF64
+@run
 def testGetDenseElementsF64():
   with Context():
     array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
@@ -124,11 +204,62 @@ def testGetDenseElementsF64():
     # CHECK: {{\[}}4.4 5.5 6.6]]
     print(np.array(attr))
 
-run(testGetDenseElementsF64)
 
+### 16 bit integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
+@run
+def testGetDenseElementsI16Signless():
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+    print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
+@run
+def testGetDenseElementsUI16Signless():
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+    attr = DenseElementsAttr.get(array)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+    print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsI16
+@run
+def testGetDenseElementsI16():
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+    attr = DenseElementsAttr.get(array, signless=False)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
+    print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI16
+@run
+def testGetDenseElementsUI16():
+  with Context():
+    array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+    attr = DenseElementsAttr.get(array, signless=False)
+    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
+    print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 ### 32 bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI32Signless
+@run
 def testGetDenseElementsI32Signless():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
@@ -139,10 +270,9 @@ def testGetDenseElementsI32Signless():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsI32Signless)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
+@run
 def testGetDenseElementsUI32Signless():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
@@ -153,9 +283,9 @@ def testGetDenseElementsUI32Signless():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsUI32Signless)
 
 # CHECK-LABEL: TEST: testGetDenseElementsI32
+@run
 def testGetDenseElementsI32():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
@@ -166,10 +296,9 @@ def testGetDenseElementsI32():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsI32)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI32
+@run
 def testGetDenseElementsUI32():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
@@ -180,11 +309,10 @@ def testGetDenseElementsUI32():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsUI32)
-
 
 ## 64bit integer arrays
 # CHECK-LABEL: TEST: testGetDenseElementsI64Signless
+@run
 def testGetDenseElementsI64Signless():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
@@ -195,10 +323,9 @@ def testGetDenseElementsI64Signless():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsI64Signless)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
+@run
 def testGetDenseElementsUI64Signless():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
@@ -209,9 +336,9 @@ def testGetDenseElementsUI64Signless():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsUI64Signless)
 
 # CHECK-LABEL: TEST: testGetDenseElementsI64
+@run
 def testGetDenseElementsI64():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
@@ -222,10 +349,9 @@ def testGetDenseElementsI64():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsI64)
-
 
 # CHECK-LABEL: TEST: testGetDenseElementsUI64
+@run
 def testGetDenseElementsUI64():
   with Context():
     array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
@@ -236,5 +362,3 @@ def testGetDenseElementsUI64():
     # CHECK: {{\[}}4 5 6]]
     print(np.array(attr))
 
-run(testGetDenseElementsUI64)
-


        


More information about the Mlir-commits mailing list