[Mlir-commits] [mlir] 77133b2 - [mlir] Get array from the dense elements attribute with buffer protocol.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 17 23:51:26 PST 2020
Author: zhanghb97
Date: 2020-11-18T15:50:59+08:00
New Revision: 77133b29b93406638915c7d9a6b8b8a81a067df3
URL: https://github.com/llvm/llvm-project/commit/77133b29b93406638915c7d9a6b8b8a81a067df3
DIFF: https://github.com/llvm/llvm-project/commit/77133b29b93406638915c7d9a6b8b8a81a067df3.diff
LOG: [mlir] Get array from the dense elements attribute with buffer protocol.
- Add `mlirDenseElementsAttrGetRawData` C API.
- Add `def_buffer` binding to PyDenseElementsAttribute.
- Implement the protocol to access the buffer.
Differential Revision: https://reviews.llvm.org/D91021
Added:
Modified:
mlir/docs/Bindings/Python.md
mlir/include/mlir-c/StandardAttributes.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/StandardAttributes.cpp
mlir/test/Bindings/Python/ir_array_attributes.py
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index fdcd1eb6d3a6..a1626ea4505e 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -7,7 +7,7 @@ Current status: Under development and not enabled by default
### Pre-requisites
* [`pybind11`](https://github.com/pybind/pybind11) must be installed and able to
- be located by CMake.
+ be located by CMake. Note: the minimum required version is 2.6.0.
* A relatively recent Python3 installation
### CMake variables
diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
index 81f8fd366f31..161722e03914 100644
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -404,6 +404,10 @@ mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED MlirStringRef
mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos);
+/** Returns a pointer to the raw data of the given dense elements attribute.
+ * The pointer refers to storage held by the attribute itself (no copy is
+ * made) and must be treated as read-only; the bytes are in the attribute's
+ * internal storage format. NOTE(review): a splat attribute presumably stores
+ * only a single element -- check mlirDenseElementsAttrIsSplat before indexing
+ * past element 0. */
+MLIR_CAPI_EXPORTED const void *
+mlirDenseElementsAttrGetRawData(MlirAttribute attr);
+
//===----------------------------------------------------------------------===//
// Opaque elements attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 7b5e341bc660..1821ff85ecce 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1366,7 +1366,7 @@ class PyConcreteAttribute : public BaseTy {
}
static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
+ // Every concrete attribute class is declared with py::buffer_protocol() so
+ // that a subclass can register a buffer via .def_buffer() in its bindDerived
+ // (as PyDenseElementsAttribute does). NOTE(review): classes that never call
+ // def_buffer still advertise the protocol -- confirm pybind11 reports a clean
+ // Python error if such an object is used as a buffer.
+ auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
DerivedTy::bindDerived(cls);
}
@@ -1630,6 +1630,42 @@ class PyDenseElementsAttribute
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
+  /// Exposes the attribute's storage through the Python buffer protocol
+  /// (registered with .def_buffer() in bindDerived), so that e.g.
+  /// `np.array(attr)` produces an array with the attribute's shape and
+  /// element type without copying on the C++ side.
+  ///
+  /// Supported element types: f32, f64, and 32/64-bit integers. Signless and
+  /// signed integers are exposed as signed; unsigned integers as unsigned.
+  /// Any other element type throws SetPyError(PyExc_ValueError,
+  /// "unimplemented array format.").
+  ///
+  /// Splat attributes are rejected: their raw storage holds a single element,
+  /// so describing it with the full shape and row-major strides would read
+  /// past the end of the storage.
+  py::buffer_info accessBuffer() {
+    if (mlirDenseElementsAttrIsSplat(this->attr)) {
+      // NOTE(review): depending on the pybind11 version, an exception escaping
+      // a def_buffer callback may not be translated into a Python error --
+      // confirm against the minimum supported pybind11.
+      std::string message =
+          "unsupported splat attribute for conversion to Python buffer.";
+      throw SetPyError(PyExc_ValueError, message);
+    }
+
+    MlirType shapedType = mlirAttributeGetType(this->attr);
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    if (mlirTypeIsAF32(elementType))
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+    if (mlirTypeIsAF64(elementType))
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+
+    if (mlirTypeIsAInteger(elementType)) {
+      unsigned width = mlirIntegerTypeGetWidth(elementType);
+      // Signless integers are exposed as signed integers.
+      bool isSigned = mlirIntegerTypeIsSignless(elementType) ||
+                      mlirIntegerTypeIsSigned(elementType);
+      bool isUnsigned = mlirIntegerTypeIsUnsigned(elementType);
+      if (width == 32 && isSigned)
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+      if (width == 32 && isUnsigned)
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+      if (width == 64 && isSigned)
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+      if (width == 64 && isUnsigned)
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+    }
+
+    std::string message = "unimplemented array format.";
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
static void bindDerived(ClassTy &c) {
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
@@ -1642,7 +1678,8 @@ class PyDenseElementsAttribute
.def_property_readonly("is_splat",
[](PyDenseElementsAttribute &self) -> bool {
return mlirDenseElementsAttrIsSplat(self.attr);
- });
+ })
+ .def_buffer(&PyDenseElementsAttribute::accessBuffer);
}
private:
@@ -1675,6 +1712,34 @@ class PyDenseElementsAttribute
return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
code == 'q';
}
+
+  /// Builds a read-only, row-major (C-order) py::buffer_info that views the
+  /// raw storage of this attribute as elements of type `Type`.
+  ///
+  /// `shapedType` is the attribute's shaped type, supplying rank and dims.
+  /// `value` is one of the mlirDenseElementsAttrGet*Value accessors; it is
+  /// never called and only serves to deduce `Type` at the call site.
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType,
+                             Type (*value)(MlirAttribute, intptr_t)) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // The C API returns const storage; the buffer is marked read-only below,
+    // so this const_cast does not allow writes from Python.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
+    // One shape entry per dimension of the shaped type.
+    SmallVector<intptr_t, 4> shape;
+    shape.reserve(rank);
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Row-major byte strides computed innermost-first in a single pass.
+    // py::buffer_info requires exactly `rank` strides, so a rank-0 type gets
+    // empty shape and strides.
+    SmallVector<intptr_t, 4> strides(rank);
+    intptr_t stride = sizeof(Type);
+    for (intptr_t i = rank - 1; i >= 0; --i) {
+      strides[i] = stride;
+      stride *= shape[i];
+    }
+    return py::buffer_info(data, sizeof(Type),
+                           py::format_descriptor<Type>::format(), rank, shape,
+                           strides, /*readonly=*/true);
+  }
}; // namespace
/// Refinement of the PyDenseElementsAttribute for attributes containing integer
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
index 834ccb66f06a..4e9f03e57a3f 100644
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -516,6 +516,14 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
pos));
}
+//===----------------------------------------------------------------------===//
+// Raw data accessors.
+
+const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
+  // Hand back a view of the attribute's own storage; nothing is copied, and
+  // the character pointer converts implicitly to `const void *`.
+  auto denseAttr = unwrap(attr).cast<DenseElementsAttr>();
+  auto rawData = denseAttr.getRawData();
+  return rawData.data();
+}
+
//===----------------------------------------------------------------------===//
// Opaque elements attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
index 74b5451aafe0..2a904e63ac5f 100644
--- a/mlir/test/Bindings/Python/ir_array_attributes.py
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -106,6 +106,9 @@ def testGetDenseElementsF32():
print(attr)
# CHECK: is_splat: False
print("is_splat:", attr.is_splat)
+ # CHECK: {{\[}}[1.1 2.2 3.3]
+ # CHECK: {{\[}}4.4 5.5 6.6]]
+ print(np.array(attr))
run(testGetDenseElementsF32)
@@ -117,6 +120,9 @@ def testGetDenseElementsF64():
attr = DenseElementsAttr.get(array)
# CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
print(attr)
+ # CHECK: {{\[}}[1.1 2.2 3.3]
+ # CHECK: {{\[}}4.4 5.5 6.6]]
+ print(np.array(attr))
run(testGetDenseElementsF64)
@@ -129,6 +135,9 @@ def testGetDenseElementsI32Signless():
attr = DenseElementsAttr.get(array)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsI32Signless)
@@ -140,6 +149,9 @@ def testGetDenseElementsUI32Signless():
attr = DenseElementsAttr.get(array)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsUI32Signless)
@@ -150,6 +162,9 @@ def testGetDenseElementsI32():
attr = DenseElementsAttr.get(array, signless=False)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsI32)
@@ -161,6 +176,9 @@ def testGetDenseElementsUI32():
attr = DenseElementsAttr.get(array, signless=False)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsUI32)
@@ -173,6 +191,9 @@ def testGetDenseElementsI64Signless():
attr = DenseElementsAttr.get(array)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsI64Signless)
@@ -184,6 +205,9 @@ def testGetDenseElementsUI64Signless():
attr = DenseElementsAttr.get(array)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsUI64Signless)
@@ -194,6 +218,9 @@ def testGetDenseElementsI64():
attr = DenseElementsAttr.get(array, signless=False)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsI64)
@@ -205,6 +232,9 @@ def testGetDenseElementsUI64():
attr = DenseElementsAttr.get(array, signless=False)
# CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
print(attr)
+ # CHECK: {{\[}}[1 2 3]
+ # CHECK: {{\[}}4 5 6]]
+ print(np.array(attr))
run(testGetDenseElementsUI64)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 677c105fc3b5..83d66555dba7 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -903,6 +903,26 @@ int printStandardAttributes(MlirContext ctx) {
fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
return 17;
+ uint32_t *uint32RawData =
+ (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
+ int32_t *int32RawData =
+ (int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
+ uint64_t *uint64RawData =
+ (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
+ int64_t *int64RawData =
+ (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
+ float *floatRawData =
+ (float *)mlirDenseElementsAttrGetRawData(floatElements);
+ double *doubleRawData =
+ (double *)mlirDenseElementsAttrGetRawData(doubleElements);
+ if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
+ int32RawData[0] != 0 || int32RawData[1] != 1 ||
+ uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
+ int64RawData[0] != 0 || int64RawData[1] != 1 ||
+ floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
+ doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
+ return 18;
+
mlirAttributeDump(splatBool);
mlirAttributeDump(splatUInt32);
mlirAttributeDump(splatInt32);
More information about the Mlir-commits
mailing list