[Mlir-commits] [mlir] f0e847d - [mlir][python] Support buffer protocol for splat dense attributes
Rahul Kayaith
llvmlistbot at llvm.org
Thu Mar 30 07:18:09 PDT 2023
Author: Rahul Kayaith
Date: 2023-03-30T10:18:03-04:00
New Revision: f0e847d0a104b45398fd0a981110098bc29250e9
URL: https://github.com/llvm/llvm-project/commit/f0e847d0a104b45398fd0a981110098bc29250e9
DIFF: https://github.com/llvm/llvm-project/commit/f0e847d0a104b45398fd0a981110098bc29250e9.diff
LOG: [mlir][python] Support buffer protocol for splat dense attributes
These can be made to work by setting the buffer strides to 0.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D147187
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/array_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 40598ecfd21a7..d252044c8e656 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -688,13 +688,6 @@ class PyDenseElementsAttribute
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
py::buffer_info accessBuffer() {
- if (mlirDenseElementsAttrIsSplat(*this)) {
- // TODO: Currently crashes the program.
- // Reported as https://github.com/pybind/pybind11/issues/3336
- throw std::invalid_argument(
- "unsupported data type for conversion to Python buffer");
- }
-
MlirType shapedType = mlirAttributeGetType(*this);
MlirType elementType = mlirShapedTypeGetElementType(shapedType);
std::string format;
@@ -821,15 +814,18 @@ class PyDenseElementsAttribute
shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
// Prepare the strides for the buffer_info.
SmallVector<intptr_t, 4> strides;
- intptr_t strideFactor = 1;
- for (intptr_t i = 1; i < rank; ++i) {
- strideFactor = 1;
- for (intptr_t j = i; j < rank; ++j) {
- strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+ if (mlirDenseElementsAttrIsSplat(*this)) {
+ // Splats are special, only the single value is stored.
+ strides.assign(rank, 0);
+ } else {
+ for (intptr_t i = 1; i < rank; ++i) {
+ intptr_t strideFactor = 1;
+ for (intptr_t j = i; j < rank; ++j)
+ strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
+ strides.push_back(sizeof(Type) * strideFactor);
}
- strides.push_back(sizeof(Type) * strideFactor);
+ strides.push_back(sizeof(Type));
}
- strides.push_back(sizeof(Type));
std::string format;
if (explicitFormat) {
format = explicitFormat;
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index c1f1633eecaaf..36b0769b20653 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -100,10 +100,9 @@ def testRepeatedValuesSplat():
print(attr)
# CHECK: is_splat: True
print("is_splat:", attr.is_splat)
- # TODO: Re-enable this once a solution is found to raising an exception
- # from buffer protocol.
- # Reported as https://github.com/pybind/pybind11/issues/3336
- # print(np.array(attr))
+ # CHECK{LITERAL}: [[1. 1. 1.]
+ # CHECK{LITERAL}: [1. 1. 1.]]
+ print(np.array(attr))
# CHECK-LABEL: TEST: testNonSplat
More information about the Mlir-commits mailing list