[Mlir-commits] [mlir] 9125996 - Support retrieving the splat value from DenseElementsAttrs in Python
Jacques Pienaar
llvmlistbot at llvm.org
Tue Mar 21 08:43:36 PDT 2023
Author: Adam Paszke
Date: 2023-03-21T08:43:17-07:00
New Revision: 912599638027e5cbed7b11318273b8703837c6ae
URL: https://github.com/llvm/llvm-project/commit/912599638027e5cbed7b11318273b8703837c6ae
DIFF: https://github.com/llvm/llvm-project/commit/912599638027e5cbed7b11318273b8703837c6ae.diff
LOG: Support retrieving the splat value from DenseElementsAttrs in Python
This is especially convenient when trying to resize the splat.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D146510
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/array_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c59a54b6699a7..40598ecfd21a7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -777,6 +777,16 @@ class PyDenseElementsAttribute
                                [](PyDenseElementsAttribute &self) -> bool {
                                  return mlirDenseElementsAttrIsSplat(self);
                                })
+        .def("get_splat_value",
+             [](PyDenseElementsAttribute &self) -> PyAttribute {
+               if (!mlirDenseElementsAttrIsSplat(self)) {
+                 throw SetPyError(
+                     PyExc_ValueError,
+                     "get_splat_value called on a non-splat attribute");
+               }
+               return PyAttribute(self.getContext(),
+                                  mlirDenseElementsAttrGetSplatValue(self));
+             })
         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }

diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index b618802e52436..c1f1633eecaaf 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -43,6 +43,7 @@ def testGetDenseElementsSplatInt():
     print(attr)
     # CHECK: is_splat: True
     print("is_splat:", attr.is_splat)
+    assert attr.get_splat_value() == element


 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@@ -55,6 +56,7 @@ def testGetDenseElementsSplatFloat():
     attr = DenseElementsAttr.get_splat(shaped_type, element)
     # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
     print(attr)
+    assert attr.get_splat_value() == element


 # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
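The sketch below illustrates the contract the new `get_splat_value` binding exposes, together with the "resize the splat" use case mentioned in the log message. It does not use MLIR. `ToyDenseElements` and `resize_splat` are hypothetical stand-ins, assumed here only to show the intended behavior: the method returns the single repeated element and raises `ValueError` when the data is not a splat. For simplicity, the toy stores every element explicitly. With the real bindings, the analogous resize would pass `attr.get_splat_value()` to `DenseElementsAttr.get_splat` together with a new shaped type.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Tuple


@dataclass(frozen=True)
class ToyDenseElements:
    """Hypothetical stand-in for a dense elements attribute: a shape plus a
    flat, row-major tuple of element values (not the MLIR API)."""

    shape: Tuple[int, ...]
    values: Tuple[Any, ...]

    @staticmethod
    def get_splat(shape, element):
        # Fill every position with the same element.
        return ToyDenseElements(tuple(shape), (element,) * prod(shape))

    @property
    def is_splat(self) -> bool:
        # A splat holds the same value at every position; this toy treats
        # zero-element data as non-splat.
        return len(self.values) > 0 and all(v == self.values[0] for v in self.values)

    def get_splat_value(self):
        # Mirrors the binding's contract: ValueError on non-splat input.
        if not self.is_splat:
            raise ValueError("get_splat_value called on a non-splat attribute")
        return self.values[0]


def resize_splat(attr: ToyDenseElements, new_shape) -> ToyDenseElements:
    """Hypothetical helper: rebuild a splat with a different shape by reusing
    its single value."""
    return ToyDenseElements.get_splat(new_shape, attr.get_splat_value())
```

`resize_splat` never inspects the old shape. Once the single value has been recovered, any target shape can be filled from it, which is what makes a splat accessor convenient for resizing.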