[Mlir-commits] [mlir] 9125996 - Support retrieving the splat value from DenseElementsAttrs in Python

Jacques Pienaar llvmlistbot at llvm.org
Tue Mar 21 08:43:36 PDT 2023


Author: Adam Paszke
Date: 2023-03-21T08:43:17-07:00
New Revision: 912599638027e5cbed7b11318273b8703837c6ae

URL: https://github.com/llvm/llvm-project/commit/912599638027e5cbed7b11318273b8703837c6ae
DIFF: https://github.com/llvm/llvm-project/commit/912599638027e5cbed7b11318273b8703837c6ae.diff

LOG: Support retrieving the splat value from DenseElementsAttrs in Python

This is especially convenient when trying to resize the splat.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D146510

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/test/python/ir/array_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c59a54b6699a7..40598ecfd21a7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -777,6 +777,16 @@ class PyDenseElementsAttribute
                                [](PyDenseElementsAttribute &self) -> bool {
                                  return mlirDenseElementsAttrIsSplat(self);
                                })
+        .def("get_splat_value",
+             [](PyDenseElementsAttribute &self) -> PyAttribute {
+               if (!mlirDenseElementsAttrIsSplat(self)) {
+                 throw SetPyError(
+                     PyExc_ValueError,
+                     "get_splat_value called on a non-splat attribute");
+               }
+               return PyAttribute(self.getContext(),
+                                  mlirDenseElementsAttrGetSplatValue(self));
+             })
         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
   }
 

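The binding above returns the element only when `mlirDenseElementsAttrIsSplat` reports a splat. Otherwise it raises `ValueError`. The pure-Python sketch below models that control flow so it can run without an MLIR build. `DenseElementsModel` and its `shape`/`values` fields are hypothetical stand-ins, not part of the MLIR Python API. Its splat test ("non-empty and every element equal") is also a simplifying assumption about how splat-ness is decided.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Tuple


@dataclass(frozen=True)
class DenseElementsModel:
    """Hypothetical stand-in for a DenseElementsAttr (not the MLIR API)."""
    shape: Tuple[int, ...]
    values: Tuple[Any, ...]  # flattened, row-major element values

    def __post_init__(self):
        # Keep the model consistent: one stored value per tensor element.
        if len(self.values) != prod(self.shape):
            raise ValueError("number of values does not match shape")

    @property
    def is_splat(self) -> bool:
        # Simplification: a splat is non-empty and every element is equal.
        return len(self.values) > 0 and len(set(self.values)) == 1

    def get_splat_value(self) -> Any:
        # Mirrors the binding: reject non-splat attributes with ValueError,
        # otherwise hand back the single shared element.
        if not self.is_splat:
            raise ValueError("get_splat_value called on a non-splat attribute")
        return self.values[0]


splat = DenseElementsModel((2, 3), (555,) * 6)
print(splat.is_splat, splat.get_splat_value())  # True 555
```

Raising instead of returning `None` keeps a non-splat attribute from being silently treated as a single value. That matches the explicit `is_splat` check in the C++ lambda.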
diff  --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index b618802e52436..c1f1633eecaaf 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -43,6 +43,7 @@ def testGetDenseElementsSplatInt():
     print(attr)
     # CHECK: is_splat: True
     print("is_splat:", attr.is_splat)
+    assert attr.get_splat_value() == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
@@ -55,6 +56,7 @@ def testGetDenseElementsSplatFloat():
     attr = DenseElementsAttr.get_splat(shaped_type, element)
     # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
     print(attr)
+    assert attr.get_splat_value() == element
 
 
 # CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
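The log says that retrieving the splat value is especially convenient when resizing a splat. With the real bindings, this presumably means passing `attr.get_splat_value()` to `DenseElementsAttr.get_splat` together with a new shaped type. The sketch below shows the same two steps on plain Python sequences so it runs without MLIR. `resize_splat` is a hypothetical helper that operates on flattened element lists. It is not an MLIR function.

```python
from math import prod
from typing import Any, Sequence, Tuple


def resize_splat(values: Sequence[Any], new_shape: Tuple[int, ...]) -> Tuple[Any, ...]:
    """Hypothetical helper: rebuild a splat under a different shape.

    `values` is the flattened element list of an existing splat; the
    result is the flattened element list for `new_shape`.
    """
    # Step 1: recover the splat value, refusing non-splat input with
    # ValueError in the same way the new binding does.
    if len(values) == 0 or any(v != values[0] for v in values):
        raise ValueError("get_splat_value called on a non-splat attribute")
    splat_value = values[0]
    # Step 2: broadcast that one value to every position of the new shape.
    return (splat_value,) * prod(new_shape)


# Usage: grow a 2x3 splat of 1.2 into a 2x3x4 splat.
resized = resize_splat([1.2] * 6, (2, 3, 4))
print(len(resized), resized[0])  # 24 1.2
```

Because a splat is fully described by one value and a shape, resizing needs no per-element copying. Only the shape changes.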