[Mlir-commits] [mlir] f0e847d - [mlir][python] Support buffer protocol for splat dense attributes

Rahul Kayaith llvmlistbot at llvm.org
Thu Mar 30 07:18:09 PDT 2023


Author: Rahul Kayaith
Date: 2023-03-30T10:18:03-04:00
New Revision: f0e847d0a104b45398fd0a981110098bc29250e9

URL: https://github.com/llvm/llvm-project/commit/f0e847d0a104b45398fd0a981110098bc29250e9
DIFF: https://github.com/llvm/llvm-project/commit/f0e847d0a104b45398fd0a981110098bc29250e9.diff

LOG: [mlir][python] Support buffer protocol for splat dense attributes

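Splat dense attributes can now be exposed through the Python buffer protocol by setting every buffer stride to 0, so that all indices map onto the single stored element.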

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D147187

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/test/python/ir/array_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 40598ecfd21a7..d252044c8e656 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -688,13 +688,6 @@ class PyDenseElementsAttribute
   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
 
   py::buffer_info accessBuffer() {
-    if (mlirDenseElementsAttrIsSplat(*this)) {
-      // TODO: Currently crashes the program.
-      // Reported as https://github.com/pybind/pybind11/issues/3336
-      throw std::invalid_argument(
-          "unsupported data type for conversion to Python buffer");
-    }
-
     MlirType shapedType = mlirAttributeGetType(*this);
     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
     std::string format;
@@ -821,15 +814,18 @@ class PyDenseElementsAttribute
       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
     // Prepare the strides for the buffer_info.
     SmallVector<intptr_t, 4> strides;
-    intptr_t strideFactor = 1;
-    for (intptr_t i = 1; i < rank; ++i) {
-      strideFactor = 1;
-      for (intptr_t j = i; j < rank; ++j) {
-        strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+    if (mlirDenseElementsAttrIsSplat(*this)) {
+      // Splats are special, only the single value is stored.
+      strides.assign(rank, 0);
+    } else {
+      for (intptr_t i = 1; i < rank; ++i) {
+        intptr_t strideFactor = 1;
+        for (intptr_t j = i; j < rank; ++j)
+          strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+        strides.push_back(sizeof(Type) * strideFactor);
       }
-      strides.push_back(sizeof(Type) * strideFactor);
+      strides.push_back(sizeof(Type));
     }
-    strides.push_back(sizeof(Type));
     std::string format;
     if (explicitFormat) {
       format = explicitFormat;

diff  --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index c1f1633eecaaf..36b0769b20653 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -100,10 +100,9 @@ def testRepeatedValuesSplat():
     print(attr)
     # CHECK: is_splat: True
     print("is_splat:", attr.is_splat)
-    # TODO: Re-enable this once a solution is found to raising an exception
-    # from buffer protocol.
-    # Reported as https://github.com/pybind/pybind11/issues/3336
-    # print(np.array(attr))
+    # CHECK{LITERAL}: [[1. 1. 1.]
+    # CHECK{LITERAL}:  [1. 1. 1.]]
+    print(np.array(attr))
 
 
 # CHECK-LABEL: TEST: testNonSplat


        


More information about the Mlir-commits mailing list