[Mlir-commits] [mlir] 5d6d30e - [mlir] Extend C and Python API to support bulk loading of DenseElementsAttr.
Stella Laurenzo
llvmlistbot at llvm.org
Thu Oct 7 08:44:56 PDT 2021
Author: Stella Laurenzo
Date: 2021-10-07T08:42:12-07:00
New Revision: 5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff
URL: https://github.com/llvm/llvm-project/commit/5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff
DIFF: https://github.com/llvm/llvm-project/commit/5d6d30edf8b9b2c69215bdbbc651a85e4d0dc4ff.diff
LOG: [mlir] Extend C and Python API to support bulk loading of DenseElementsAttr.
* This already half existed in terms of reading the raw buffer backing a DenseElementsAttr.
* Documented the precise expectations of the buffer layout.
* Extended the Python API to support construction from bitcasted buffers, allowing construction of all primitive element types (even those that lack a compatible representation in Python).
* Specifically, the Python API can now load all integer types at all bit widths and all floating point types (f16, f32, f64, bf16).
Differential Revision: https://reviews.llvm.org/D111284
Added:
Modified:
mlir/include/mlir-c/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/python/ir/array_attributes.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 247de5cc0bd62..5839cd3d2408a 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -306,6 +306,23 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet(
MlirType shapedType, intptr_t numElements, MlirAttribute const *elements);
+/// Creates a dense elements attribute with the given Shaped type and elements
+/// populated from a packed, row-major opaque buffer of contents.
+///
+/// The format of the raw buffer is a densely packed array of values that
+/// can be bitcast to the storage format of the element type specified.
+/// Types that are not byte aligned will be:
+/// - For bitwidth > 1: Rounded up to the next byte.
+/// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to
+/// the linear order of the shape type from LSB to MSB, with padding in
+/// the unused high bits of the final byte.
+///
+/// A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255)
+/// will be interpreted as a splat. User code should be prepared for additional,
+/// conformant patterns to be identified as splats in the future.
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet(
+ MlirType shapedType, size_t rawBufferSize, const void *rawBuffer);
+
/// Creates a dense elements attribute with the given Shaped type containing a
/// single replicated element (splat).
MLIR_CAPI_EXPORTED MlirAttribute
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index bc7beee0dabb7..cc1d1d74615ea 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -183,15 +183,35 @@ class DenseElementsAttr : public Attribute {
}
/// Construct a dense elements attribute from a raw buffer representing the
- /// data for this attribute. Users should generally not use this methods as
- /// the expected buffer format may not be a form the user expects.
+ /// data for this attribute. Users are encouraged to use one of the
+ /// constructors above, which provide more safeties. However, this
+ /// constructor is useful for tools which may want to interop and can
+ /// follow the precise definition.
+ ///
+ /// The format of the raw buffer is a densely packed array of values that
+ /// can be bitcast to the storage format of the element type specified.
+ /// Types that are not byte aligned will be:
+ /// - For bitwidth > 1: Rounded up to the next byte.
+ /// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to
+ /// the linear order of the shape type from LSB to MSB, with padding in
+ /// the unused high bits of the final byte.
+ ///
+ /// If `isSplatBuffer` is true, then the raw buffer should contain a
+ /// single element (or for the case of 1-bit, a single byte of 0 or 255),
+ /// which will be used to construct a splat.
static DenseElementsAttr getFromRawBuffer(ShapedType type,
ArrayRef<char> rawBuffer,
bool isSplatBuffer);
/// Returns true if the given buffer is a valid raw buffer for the given type.
/// `detectedSplat` is set if the buffer is valid and represents a splat
- /// buffer.
+ /// buffer. The definition may be expanded over time, but currently, a
+ /// splat buffer is detected if:
+ /// - For >1bit: The buffer consists of a single element.
+ /// - For 1bit: The buffer consists of a single byte with value 0 or 255.
+ ///
+ /// User code should be prepared for additional, conformant patterns to be
+ /// identified as splats in the future.
static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
bool &detectedSplat);
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 2ff75ceedcf2e..47f73ecae4784 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -17,9 +17,57 @@ namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
+using llvm::None;
+using llvm::Optional;
using llvm::SmallVector;
using llvm::Twine;
+//------------------------------------------------------------------------------
+// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------
+
+static const char kDenseElementsAttrGetDocstring[] =
+ R"(Gets a DenseElementsAttr from a Python buffer or array.
+
+When `type` is not provided, then some limited type inferencing is done based
+on the buffer format. Support presently exists for 8/16/32/64 signed and
+unsigned integers and float16/float32/float64. DenseElementsAttrs of these
+types can also be converted back to a corresponding buffer.
+
+For conversions outside of these types, a `type=` must be explicitly provided
+and the buffer contents must be bit-castable to the MLIR internal
+representation:
+
+ * Integer types (except for i1): the buffer must be byte aligned to the
+ next byte boundary.
+ * Floating point types: Must be bit-castable to the given floating point
+ size.
+ * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
+ row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
+ this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
+
+If a single element buffer is passed (or for i1, a single byte with value 0
+or 255), then a splat will be created.
+
+Args:
+ array: The array or buffer to convert.
+ signless: If inferring an appropriate MLIR type, use signless types for
+ integers (defaults True).
+ type: Skips inference of the MLIR element type and uses this instead. The
+ storage size must be consistent with the actual contents of the buffer.
+ shape: Overrides the shape of the buffer when constructing the MLIR
+ shaped type. This is needed when the physical and logical shape differ (as
+ for i1).
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the buffer or array cannot be matched to an MLIR
+ type or if the buffer does not meet expectations.
+)";
+
namespace {
static MlirStringRef toMlirStringRef(const std::string &s) {
@@ -301,7 +349,6 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
}
};
-// TODO: Support construction of bool elements.
// TODO: Support construction of string elements.
class PyDenseElementsAttribute
: public PyConcreteAttribute<PyDenseElementsAttribute> {
@@ -311,7 +358,8 @@ class PyDenseElementsAttribute
using PyConcreteAttribute::PyConcreteAttribute;
static PyDenseElementsAttribute
- getFromBuffer(py::buffer array, bool signless,
+ getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
+ Optional<std::vector<int64_t>> explicitShape,
DefaultingPyMlirContext contextWrapper) {
// Request a contiguous view. In exotic cases, this will cause a copy.
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
@@ -321,69 +369,95 @@ class PyDenseElementsAttribute
throw py::error_already_set();
}
py::buffer_info arrayInfo(view);
+ SmallVector<int64_t> shape;
+ if (explicitShape) {
+ shape.append(explicitShape->begin(), explicitShape->end());
+ } else {
+ shape.append(arrayInfo.shape.begin(),
+ arrayInfo.shape.begin() + arrayInfo.ndim);
+ }
+ MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();
- // Switch on the types that can be bulk loaded between the Python and
- // MLIR-C APIs.
- // See: https://docs.python.org/3/library/struct.html#format-characters
- if (arrayInfo.format == "f") {
+
+ // Detect format codes that are suitable for bulk loading. This includes
+ // all byte aligned integer and floating point types up to 8 bytes.
+ // Notably, this excludes, bool (which needs to be bit-packed) and
+ // other exotics which do not have a direct representation in the buffer
+ // protocol (i.e. complex, etc).
+ Optional<MlirType> bulkLoadElementType;
+ if (explicitType) {
+ bulkLoadElementType = *explicitType;
+ } else if (arrayInfo.format == "f") {
// f32
assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
- return PyDenseElementsAttribute(
- contextWrapper->getRef(),
- bulkLoad(context, mlirDenseElementsAttrFloatGet,
- mlirF32TypeGet(context), arrayInfo));
+ bulkLoadElementType = mlirF32TypeGet(context);
} else if (arrayInfo.format == "d") {
// f64
assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
- return PyDenseElementsAttribute(
- contextWrapper->getRef(),
- bulkLoad(context, mlirDenseElementsAttrDoubleGet,
- mlirF64TypeGet(context), arrayInfo));
+ bulkLoadElementType = mlirF64TypeGet(context);
+ } else if (arrayInfo.format == "e") {
+ // f16
+ assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
+ bulkLoadElementType = mlirF16TypeGet(context);
} else if (isSignedIntegerFormat(arrayInfo.format)) {
if (arrayInfo.itemsize == 4) {
// i32
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeSignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt32Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeSignedGet(context, 32);
} else if (arrayInfo.itemsize == 8) {
// i64
- MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeSignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrInt64Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeSignedGet(context, 64);
+ } else if (arrayInfo.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeSignedGet(context, 8);
+ } else if (arrayInfo.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(arrayInfo.format)) {
if (arrayInfo.itemsize == 4) {
// unsigned i32
- MlirType elementType = signless
- ? mlirIntegerTypeGet(context, 32)
- : mlirIntegerTypeUnsignedGet(context, 32);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt32Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 32)
+ : mlirIntegerTypeUnsignedGet(context, 32);
} else if (arrayInfo.itemsize == 8) {
// unsigned i64
- MlirType elementType = signless
- ? mlirIntegerTypeGet(context, 64)
- : mlirIntegerTypeUnsignedGet(context, 64);
- return PyDenseElementsAttribute(contextWrapper->getRef(),
- bulkLoad(context,
- mlirDenseElementsAttrUInt64Get,
- elementType, arrayInfo));
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 64)
+ : mlirIntegerTypeUnsignedGet(context, 64);
+ } else if (arrayInfo.itemsize == 1) {
+ // i8
+ bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
+ : mlirIntegerTypeUnsignedGet(context, 8);
+ } else if (arrayInfo.itemsize == 2) {
+ // i16
+ bulkLoadElementType = signless
+ ? mlirIntegerTypeGet(context, 16)
+ : mlirIntegerTypeUnsignedGet(context, 16);
}
}
+ if (bulkLoadElementType) {
+ auto shapedType = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
+ size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
+ MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
+ shapedType, rawBufferSize, arrayInfo.ptr);
+ if (mlirAttributeIsNull(attr)) {
+ throw std::invalid_argument(
+ "DenseElementsAttr could not be constructed from the given buffer. "
+ "This may mean that the Python buffer layout does not match that "
+ "MLIR expected layout and is a bug.");
+ }
+ return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
+ }
- // TODO: Fall back to string-based get.
- std::string message = "unimplemented array format conversion from format: ";
- message.append(arrayInfo.format);
- throw SetPyError(PyExc_ValueError, message);
+ throw std::invalid_argument(
+ std::string("unimplemented array format conversion from format: ") +
+ arrayInfo.format);
}
static PyDenseElementsAttribute getSplat(PyType shapedType,
@@ -422,47 +496,82 @@ class PyDenseElementsAttribute
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
py::buffer_info accessBuffer() {
+ if (mlirDenseElementsAttrIsSplat(*this)) {
+ // TODO: Raise an exception.
+ // Reported as https://github.com/pybind/pybind11/issues/3336
+ return py::buffer_info();
+ }
+
MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+ std::string format;
if (mlirTypeIsAF32(elementType)) {
// f32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+ return bufferInfo<float>(shapedType);
} else if (mlirTypeIsAF64(elementType)) {
// f64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+ return bufferInfo<double>(shapedType);
+ } else if (mlirTypeIsAF16(elementType)) {
+ // f16
+ return bufferInfo<uint16_t>(shapedType, "e");
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 32) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+ return bufferInfo<int32_t>(shapedType);
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i32
- return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+ return bufferInfo<uint32_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 64) {
if (mlirIntegerTypeIsSignless(elementType) ||
mlirIntegerTypeIsSigned(elementType)) {
// i64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+ return bufferInfo<int64_t>(shapedType);
} else if (mlirIntegerTypeIsUnsigned(elementType)) {
// unsigned i64
- return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+ return bufferInfo<uint64_t>(shapedType);
+ }
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 8) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i8
+ return bufferInfo<int8_t>(shapedType);
+ } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i8
+ return bufferInfo<uint8_t>(shapedType);
+ }
+ } else if (mlirTypeIsAInteger(elementType) &&
+ mlirIntegerTypeGetWidth(elementType) == 16) {
+ if (mlirIntegerTypeIsSignless(elementType) ||
+ mlirIntegerTypeIsSigned(elementType)) {
+ // i16
+ return bufferInfo<int16_t>(shapedType);
+ } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+ // unsigned i16
+ return bufferInfo<uint16_t>(shapedType);
}
}
- std::string message = "unimplemented array format.";
- throw SetPyError(PyExc_ValueError, message);
+ // TODO: Currently crashes the program. Just returning an empty buffer
+ // for now.
+ // Reported as https://github.com/pybind/pybind11/issues/3336
+ // throw std::invalid_argument(
+ // "unsupported data type for conversion to Python buffer");
+ return py::buffer_info();
}
static void bindDerived(ClassTy &c) {
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
py::arg("array"), py::arg("signless") = true,
+ py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
- "Gets from a buffer or ndarray")
+ kDenseElementsAttrGetDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@@ -474,21 +583,6 @@ class PyDenseElementsAttribute
}
private:
- template <typename ElementTy>
- static MlirAttribute
- bulkLoad(MlirContext context,
- MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
- MlirType mlirElementType, py::buffer_info &arrayInfo) {
- SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
- arrayInfo.shape.begin() + arrayInfo.ndim);
- MlirAttribute encodingAttr = mlirAttributeGetNull();
- auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
- mlirElementType, encodingAttr);
- intptr_t numElements = arrayInfo.size;
- const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
- return ctor(shapedType, numElements, contents);
- }
-
static bool isUnsignedIntegerFormat(const std::string &format) {
if (format.empty())
return false;
@@ -507,7 +601,7 @@ class PyDenseElementsAttribute
template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
- Type (*value)(MlirAttribute, intptr_t)) {
+ const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the data for the buffer_info.
// Buffer is configured for read-only access below.
@@ -528,9 +622,14 @@ class PyDenseElementsAttribute
strides.push_back(sizeof(Type) * strideFactor);
}
strides.push_back(sizeof(Type));
- return py::buffer_info(data, sizeof(Type),
- py::format_descriptor<Type>::format(), rank, shape,
- strides, /*readonly=*/true);
+ std::string format;
+ if (explicitFormat) {
+ format = explicitFormat;
+ } else {
+ format = py::format_descriptor<Type>::format();
+ }
+ return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
+ /*readonly=*/true);
}
}; // namespace
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index a2ee06722f0d8..3b15212e30023 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -331,6 +331,21 @@ MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
unwrapList(numElements, elements, attributes)));
}
+MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
+ size_t rawBufferSize,
+ const void *rawBuffer) {
+ auto shapedTypeCpp = unwrap(shapedType).cast<ShapedType>();
+ ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
+ rawBufferSize);
+ bool isSplat = false;
+ if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
+ isSplat)) {
+ return mlirAttributeGetNull();
+ }
+ return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp,
+ isSplat));
+}
+
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
MlirAttribute element) {
return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e03bac1d3d5c2..27d851c6594ee 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -792,9 +792,16 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
// Storage width of 1 is special as it is packed by the bit.
if (storageWidth == 1) {
- // Check for a splat, or a buffer equal to the number of elements.
- if ((detectedSplat = rawBuffer.size() == 1))
- return true;
+ // Check for a splat, or a buffer equal to the number of elements which
+ // consists of either all 0's or all 1's.
+ detectedSplat = false;
+ if (rawBuffer.size() == 1) {
+ auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
+ if (rawByte == 0 || rawByte == 0xff) {
+ detectedSplat = true;
+ return true;
+ }
+ }
return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
}
// All other types are 8-bit aligned.
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 2a904e63ac5f2..13c7c215484b4 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -11,11 +11,13 @@ def run(f):
f()
gc.collect()
assert Context._get_live_count() == 0
+ return f
################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtype.
################################################################################
+@run
def testGetDenseElementsUnsupported():
with Context():
array = np.array([["hello", "goodbye"]])
@@ -25,13 +27,12 @@ def testGetDenseElementsUnsupported():
# CHECK: unimplemented array format conversion from format:
print(e)
-run(testGetDenseElementsUnsupported)
-
################################################################################
# Splats.
################################################################################
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
+@run
def testGetDenseElementsSplatInt():
with Context(), Location.unknown():
t = IntegerType.get_signless(32)
@@ -43,10 +44,9 @@ def testGetDenseElementsSplatInt():
# CHECK: is_splat: True
print("is_splat:", attr.is_splat)
-run(testGetDenseElementsSplatInt)
-
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
+@run
def testGetDenseElementsSplatFloat():
with Context(), Location.unknown():
t = F32Type.get()
@@ -56,10 +56,9 @@ def testGetDenseElementsSplatFloat():
# CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
print(attr)
-run(testGetDenseElementsSplatFloat)
-
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
+@run
def testGetDenseElementsSplatErrors():
with Context(), Location.unknown():
t = F32Type.get()
@@ -88,32 +87,113 @@ def testGetDenseElementsSplatErrors():
# CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
print(e)
-run(testGetDenseElementsSplatErrors)
+
+# CHECK-LABEL: TEST: testRepeatedValuesSplat
+@run
+def testRepeatedValuesSplat():
+ with Context():
+ array = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=np.float32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<1.000000e+00> : tensor<2x3xf32>
+ print(attr)
+ # CHECK: is_splat: True
+ print("is_splat:", attr.is_splat)
+ # CHECK: ()
+ print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testNonSplat
+@run
+def testNonSplat():
+ with Context():
+ array = np.array([2.0, 1.0, 1.0], dtype=np.float32)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: is_splat: False
+ print("is_splat:", attr.is_splat)
################################################################################
# Tests of the array/buffer .get() factory method, in all of its permutations.
################################################################################
+### explicitly provided types
+
+@run
+def testGetDenseElementsBF16():
+ with Context():
+ array = np.array([[2, 4, 8], [16, 32, 64]], dtype=np.uint16)
+ attr = DenseElementsAttr.get(array, type=BF16Type.get())
+ # Note: These values don't mean much since just bit-casting. But they
+ # shouldn't change.
+ # CHECK: dense<{{\[}}[1.836710e-40, 3.673420e-40, 7.346840e-40], [1.469370e-39, 2.938740e-39, 5.877470e-39]]> : tensor<2x3xbf16>
+ print(attr)
+
+@run
+def testGetDenseElementsInteger4():
+ with Context():
+ array = np.array([[2, 4, 7], [-2, -4, -8]], dtype=np.uint8)
+ attr = DenseElementsAttr.get(array, type=IntegerType.get_signless(4))
+ # Note: These values don't mean much since just bit-casting. But they
+ # shouldn't change.
+ # CHECK: dense<{{\[}}[2, 4, 7], [-2, -4, -8]]> : tensor<2x3xi4>
+ print(attr)
+
+
+@run
+def testGetDenseElementsBool():
+ with Context():
+ bool_array = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.bool_)
+ array = np.packbits(bool_array, axis=None, bitorder="little")
+ attr = DenseElementsAttr.get(
+ array, type=IntegerType.get_signless(1), shape=bool_array.shape)
+ # CHECK: dense<{{\[}}[true, false, true], [false, true, false]]> : tensor<2x3xi1>
+ print(attr)
+
+
+@run
+def testGetDenseElementsBoolSplat():
+ with Context():
+ zero = np.array(0, dtype=np.uint8)
+ one = np.array(255, dtype=np.uint8)
+ print(one)
+ # CHECK: dense<false> : tensor<4x2x5xi1>
+ print(DenseElementsAttr.get(
+ zero, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
+ # CHECK: dense<true> : tensor<4x2x5xi1>
+ print(DenseElementsAttr.get(
+ one, type=IntegerType.get_signless(1), shape=(4, 2, 5)))
+
+
### float and double arrays.
+# CHECK-LABEL: TEST: testGetDenseElementsF16
+@run
+def testGetDenseElementsF16():
+ with Context():
+ array = np.array([[2.0, 4.0, 8.0], [16.0, 32.0, 64.0]], dtype=np.float16)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[2.000000e+00, 4.000000e+00, 8.000000e+00], [1.600000e+01, 3.200000e+01, 6.400000e+01]]> : tensor<2x3xf16>
+ print(attr)
+ # CHECK: {{\[}}[ 2. 4. 8.]
+ # CHECK: {{\[}}16. 32. 64.]]
+ print(np.array(attr))
+
+
# CHECK-LABEL: TEST: testGetDenseElementsF32
+@run
def testGetDenseElementsF32():
with Context():
array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
attr = DenseElementsAttr.get(array)
# CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
print(attr)
- # CHECK: is_splat: False
- print("is_splat:", attr.is_splat)
# CHECK: {{\[}}[1.1 2.2 3.3]
# CHECK: {{\[}}4.4 5.5 6.6]]
print(np.array(attr))
-run(testGetDenseElementsF32)
-
# CHECK-LABEL: TEST: testGetDenseElementsF64
+@run
def testGetDenseElementsF64():
with Context():
array = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
@@ -124,11 +204,62 @@ def testGetDenseElementsF64():
# CHECK: {{\[}}4.4 5.5 6.6]]
print(np.array(attr))
-run(testGetDenseElementsF64)
+### 16 bit integer arrays
+# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
+@run
+def testGetDenseElementsI16Signless():
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI16Signless
+@run
+def testGetDenseElementsUI16Signless():
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+ attr = DenseElementsAttr.get(array)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsI16
+@run
+def testGetDenseElementsI16():
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsUI16
+@run
+def testGetDenseElementsUI16():
+ with Context():
+ array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint16)
+ attr = DenseElementsAttr.get(array, signless=False)
+ # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui16>
+ print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
### 32 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
+@run
def testGetDenseElementsI32Signless():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
@@ -139,10 +270,9 @@ def testGetDenseElementsI32Signless():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsI32Signless)
-
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
+@run
def testGetDenseElementsUI32Signless():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
@@ -153,9 +283,9 @@ def testGetDenseElementsUI32Signless():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsUI32Signless)
# CHECK-LABEL: TEST: testGetDenseElementsI32
+@run
def testGetDenseElementsI32():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
@@ -166,10 +296,9 @@ def testGetDenseElementsI32():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsI32)
-
# CHECK-LABEL: TEST: testGetDenseElementsUI32
+ at run
def testGetDenseElementsUI32():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
@@ -180,11 +309,10 @@ def testGetDenseElementsUI32():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsUI32)
-
## 64bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
+@run
def testGetDenseElementsI64Signless():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
@@ -195,10 +323,9 @@ def testGetDenseElementsI64Signless():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsI64Signless)
-
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
+@run
def testGetDenseElementsUI64Signless():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
@@ -209,9 +336,9 @@ def testGetDenseElementsUI64Signless():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsUI64Signless)
# CHECK-LABEL: TEST: testGetDenseElementsI64
+@run
def testGetDenseElementsI64():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
@@ -222,10 +349,9 @@ def testGetDenseElementsI64():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsI64)
-
# CHECK-LABEL: TEST: testGetDenseElementsUI64
+@run
def testGetDenseElementsUI64():
with Context():
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
@@ -236,5 +362,3 @@ def testGetDenseElementsUI64():
# CHECK: {{\[}}4 5 6]]
print(np.array(attr))
-run(testGetDenseElementsUI64)
-
More information about the Mlir-commits
mailing list