[Mlir-commits] [mlir] 77133b2 - [mlir] Get array from the dense elements attribute with buffer protocol.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 17 23:51:26 PST 2020


Author: zhanghb97
Date: 2020-11-18T15:50:59+08:00
New Revision: 77133b29b93406638915c7d9a6b8b8a81a067df3

URL: https://github.com/llvm/llvm-project/commit/77133b29b93406638915c7d9a6b8b8a81a067df3
DIFF: https://github.com/llvm/llvm-project/commit/77133b29b93406638915c7d9a6b8b8a81a067df3.diff

LOG: [mlir] Get array from the dense elements attribute with buffer protocol.

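- Add `mlirDenseElementsAttrGetRawData` C API (this log line originally read `mlirElementsAttrGetType`, a function the diff below does not add).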
- Add `def_buffer` binding to PyDenseElementsAttribute.
- Implement the protocol to access the buffer.

Differential Revision: https://reviews.llvm.org/D91021

Added: 
    

Modified: 
    mlir/docs/Bindings/Python.md
    mlir/include/mlir-c/StandardAttributes.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/StandardAttributes.cpp
    mlir/test/Bindings/Python/ir_array_attributes.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index fdcd1eb6d3a6..a1626ea4505e 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -7,7 +7,7 @@ Current status: Under development and not enabled by default
 ### Pre-requisites
 
 * [`pybind11`](https://github.com/pybind/pybind11) must be installed and able to
-  be located by CMake.
+  be located by CMake. Note: minimum version required: 2.6.0
 * A relatively recent Python3 installation
 
 ### CMake variables

diff  --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
index 81f8fd366f31..161722e03914 100644
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -404,6 +404,10 @@ mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos);
 
+/** Returns the raw data of the given dense elements attribute. */
+MLIR_CAPI_EXPORTED const void *
+mlirDenseElementsAttrGetRawData(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // Opaque elements attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 7b5e341bc660..1821ff85ecce 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1366,7 +1366,7 @@ class PyConcreteAttribute : public BaseTy {
   }
 
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
     DerivedTy::bindDerived(cls);
   }
@@ -1630,6 +1630,42 @@ class PyDenseElementsAttribute
 
   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
 
+  py::buffer_info accessBuffer() {
+    MlirType shapedType = mlirAttributeGetType(this->attr);
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    if (mlirTypeIsAF32(elementType)) {
+      // f32
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
+    } else if (mlirTypeIsAF64(elementType)) {
+      // f64
+      return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 32) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i32
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i32
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
+      }
+    } else if (mlirTypeIsAInteger(elementType) &&
+               mlirIntegerTypeGetWidth(elementType) == 64) {
+      if (mlirIntegerTypeIsSignless(elementType) ||
+          mlirIntegerTypeIsSigned(elementType)) {
+        // i64
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
+      } else if (mlirIntegerTypeIsUnsigned(elementType)) {
+        // unsigned i64
+        return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
+      }
+    }
+
+    std::string message = "unimplemented array format.";
+    throw SetPyError(PyExc_ValueError, message);
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
@@ -1642,7 +1678,8 @@ class PyDenseElementsAttribute
         .def_property_readonly("is_splat",
                                [](PyDenseElementsAttribute &self) -> bool {
                                  return mlirDenseElementsAttrIsSplat(self.attr);
-                               });
+                               })
+        .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
 
 private:
@@ -1675,6 +1712,34 @@ class PyDenseElementsAttribute
     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
            code == 'q';
   }
+
+  template <typename Type>
+  py::buffer_info bufferInfo(MlirType shapedType,
+                             Type (*value)(MlirAttribute, intptr_t)) {
+    intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    // Prepare the data for the buffer_info.
+    // Buffer is configured for read-only access below.
+    Type *data = static_cast<Type *>(
+        const_cast<void *>(mlirDenseElementsAttrGetRawData(this->attr)));
+    // Prepare the shape for the buffer_info.
+    SmallVector<intptr_t, 4> shape;
+    for (intptr_t i = 0; i < rank; ++i)
+      shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
+    // Prepare the strides for the buffer_info.
+    SmallVector<intptr_t, 4> strides;
+    intptr_t strideFactor = 1;
+    for (intptr_t i = 1; i < rank; ++i) {
+      strideFactor = 1;
+      for (intptr_t j = i; j < rank; ++j) {
+        strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+      }
+      strides.push_back(sizeof(Type) * strideFactor);
+    }
+    strides.push_back(sizeof(Type));
+    return py::buffer_info(data, sizeof(Type),
+                           py::format_descriptor<Type>::format(), rank, shape,
+                           strides, /*readonly=*/true);
+  }
 }; // namespace
 
 /// Refinement of the PyDenseElementsAttribute for attributes containing integer

diff  --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
index 834ccb66f06a..4e9f03e57a3f 100644
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -516,6 +516,14 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
         pos));
 }
 
+//===----------------------------------------------------------------------===//
+// Raw data accessors.
+
+const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
+  return static_cast<const void *>(
+      unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
+}
+
 //===----------------------------------------------------------------------===//
 // Opaque elements attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Bindings/Python/ir_array_attributes.py b/mlir/test/Bindings/Python/ir_array_attributes.py
index 74b5451aafe0..2a904e63ac5f 100644
--- a/mlir/test/Bindings/Python/ir_array_attributes.py
+++ b/mlir/test/Bindings/Python/ir_array_attributes.py
@@ -106,6 +106,9 @@ def testGetDenseElementsF32():
     print(attr)
     # CHECK: is_splat: False
     print("is_splat:", attr.is_splat)
+    # CHECK: {{\[}}[1.1 2.2 3.3]
+    # CHECK: {{\[}}4.4 5.5 6.6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsF32)
 
@@ -117,6 +120,9 @@ def testGetDenseElementsF64():
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
     print(attr)
+    # CHECK: {{\[}}[1.1 2.2 3.3]
+    # CHECK: {{\[}}4.4 5.5 6.6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsF64)
 
@@ -129,6 +135,9 @@ def testGetDenseElementsI32Signless():
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI32Signless)
 
@@ -140,6 +149,9 @@ def testGetDenseElementsUI32Signless():
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI32Signless)
 
@@ -150,6 +162,9 @@ def testGetDenseElementsI32():
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI32)
 
@@ -161,6 +176,9 @@ def testGetDenseElementsUI32():
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI32)
 
@@ -173,6 +191,9 @@ def testGetDenseElementsI64Signless():
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI64Signless)
 
@@ -184,6 +205,9 @@ def testGetDenseElementsUI64Signless():
     attr = DenseElementsAttr.get(array)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI64Signless)
 
@@ -194,6 +218,9 @@ def testGetDenseElementsI64():
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsI64)
 
@@ -205,6 +232,9 @@ def testGetDenseElementsUI64():
     attr = DenseElementsAttr.get(array, signless=False)
     # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
     print(attr)
+    # CHECK: {{\[}}[1 2 3]
+    # CHECK: {{\[}}4 5 6]]
+    print(np.array(attr))
 
 run(testGetDenseElementsUI64)
 

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 677c105fc3b5..83d66555dba7 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -903,6 +903,26 @@ int printStandardAttributes(MlirContext ctx) {
       fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
     return 17;
 
+  uint32_t *uint32RawData =
+      (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
+  int32_t *int32RawData =
+      (int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
+  uint64_t *uint64RawData =
+      (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
+  int64_t *int64RawData =
+      (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
+  float *floatRawData =
+      (float *)mlirDenseElementsAttrGetRawData(floatElements);
+  double *doubleRawData =
+      (double *)mlirDenseElementsAttrGetRawData(doubleElements);
+  if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
+      int32RawData[0] != 0 || int32RawData[1] != 1 ||
+      uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
+      int64RawData[0] != 0 || int64RawData[1] != 1 ||
+      floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
+      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
+    return 18;
+
   mlirAttributeDump(splatBool);
   mlirAttributeDump(splatUInt32);
   mlirAttributeDump(splatInt32);


        


More information about the Mlir-commits mailing list