[Mlir-commits] [mlir] [mlir][python] Add bindings for mlirDenseElementsAttrGet (PR #91389)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 7 13:04:49 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (pranavm-nvidia)
<details>
<summary>Changes</summary>
This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating `DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
---
Full diff: https://github.com/llvm/llvm-project/pull/91389.diff
2 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+76-5)
- (modified) mlir/test/python/ir/array_attributes.py (+82)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dda2003ba037..b7ad4f3a78b7 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -72,6 +72,23 @@ or 255), then a splat will be created.
type or if the buffer does not meet expectations.
)";
+static const char kDenseElementsAttrGetFromListDocstring[] =
+    R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Args:
+  attrs: A non-empty list of attributes, all of the same type.
+  type: Optional statically shaped type of the result. Defaults to a 1-D
+    tensor of shape `[len(attrs)]` with the element type of `attrs[0]`.
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseElementsAttr on success.
+
+Raises:
+  ValueError: If `attrs` is empty, `type` is not statically shaped, or an
+    attribute's type differs from the element type of the result.
+)";
+
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
@@ -647,6 +664,75 @@ class PyDenseElementsAttribute
static constexpr const char *pyClassName = "DenseElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
+  /// Builds a DenseElementsAttr from a non-empty Python list of attributes.
+  /// The result type is `explicitType` if given (it must be a statically
+  /// shaped type), else a 1-D ranked tensor of `len(attributes)` elements of
+  /// the first attribute's type. Throws py::value_error if the list is empty,
+  /// the type is not statically shaped, the attribute count does not fit the
+  /// type, or an attribute's type differs from the result's element type.
+  static PyDenseElementsAttribute
+  getFromList(py::list attributes, std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper) {
+    const size_t numAttributes = py::len(attributes);
+    if (numAttributes == 0)
+      throw py::value_error("Attributes list must be non-empty");
+
+    MlirType shapedType;
+    if (explicitType) {
+      if (!mlirTypeIsAShaped(*explicitType) ||
+          !mlirShapedTypeHasStaticShape(*explicitType)) {
+        std::string message =
+            "Expected a static ShapedType for the type parameter: ";
+        message.append(py::repr(py::cast(*explicitType)));
+        throw py::value_error(message);
+      }
+      shapedType = *explicitType;
+    } else {
+      // Cast explicitly: brace-initializing int64_t from size_t is narrowing.
+      SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
+      shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(),
+          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+          mlirAttributeGetNull());
+    }
+
+    // DenseElementsAttr::get (behind mlirDenseElementsAttrGet) requires one
+    // attribute per element or a single splat value; raise instead of asserting.
+    int64_t numElements = 1;
+    const intptr_t rank = mlirShapedTypeGetRank(shapedType);
+    for (intptr_t dim = 0; dim < rank; ++dim)
+      numElements *= mlirShapedTypeGetDimSize(shapedType, dim);
+    if (numAttributes != 1 &&
+        static_cast<int64_t>(numAttributes) != numElements) {
+      throw py::value_error("Expected " + std::to_string(numElements) +
+                            " attributes to match the type parameter, got " +
+                            std::to_string(numAttributes));
+    }
+
+    // The element type is loop-invariant; query it once.
+    const MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+    SmallVector<MlirAttribute> mlirAttributes;
+    mlirAttributes.reserve(numAttributes);
+    for (py::handle attribute : attributes) {
+      MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+      MlirType attrType = mlirAttributeGetType(mlirAttribute);
+      if (!mlirTypeEqual(elementType, attrType)) {
+        // Report the expected element type, since that is what is compared.
+        std::string message = "All attributes must be of the same type and "
+                              "match the type parameter: expected=";
+        message.append(py::repr(py::cast(elementType)));
+        message.append(", but got=");
+        message.append(py::repr(py::cast(attrType)));
+        throw py::value_error(message);
+      }
+      mlirAttributes.push_back(mlirAttribute);
+    }
+
+    MlirAttribute elements = mlirDenseElementsAttrGet(
+        shapedType, mlirAttributes.size(), mlirAttributes.data());
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
static PyDenseElementsAttribute
getFromBuffer(py::buffer array, bool signless,
std::optional<PyType> explicitType,
@@ -883,6 +949,10 @@ class PyDenseElementsAttribute
py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ py::arg("attrs"), py::arg("type") = py::none(),
+ py::arg("context") = py::none(),
+ kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@@ -954,8 +1024,8 @@ class PyDenseElementsAttribute
}
}; // namespace
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
class PyDenseIntElementsAttribute
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
PyDenseElementsAttribute> {
@@ -964,8 +1034,8 @@ class PyDenseIntElementsAttribute
static constexpr const char *pyClassName = "DenseIntElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- /// Returns the element at the given linear position. Asserts if the index is
- /// out of range.
+ /// Returns the element at the given linear position. Asserts if the index
+ /// is out of range.
py::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw py::index_error("attempt to access out of bounds element");
@@ -1267,7 +1337,8 @@ class PyStridedLayoutAttribute
return PyStridedLayoutAttribute(ctx->getRef(), attr);
},
py::arg("rank"), py::arg("context") = py::none(),
- "Gets a strided layout attribute with dynamic offset and strides of a "
+ "Gets a strided layout attribute with dynamic offset and strides of "
+ "a "
"given rank.");
c.def_property_readonly(
"offset",
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9251588a4c48..2bc403aace83 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
print(np.array(attr))
+################################################################################
+# Tests of the list of attributes .get() factory method
+################################################################################
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromList
+@run
+def testGetDenseElementsFromList():
+    """Without an explicit type, a 1-D tensor of the attributes' type is inferred."""
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), v) for v in (1.0, 2.0)]
+        attr = DenseElementsAttr.get(attrs)
+        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+        print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
+@run
+def testGetDenseElementsFromListWithExplicitType():
+    """An explicit statically shaped type is used as the result type."""
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), v) for v in (1.0, 2.0)]
+        shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
+        attr = DenseElementsAttr.get(attrs, shaped_type)
+        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+        print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
+@run
+def testGetDenseElementsFromListEmptyList():
+    """An empty attribute list is rejected with a ValueError."""
+    with Context(), Location.unknown():
+        attrs = []
+        try:
+            DenseElementsAttr.get(attrs)
+        except ValueError as e:
+            # CHECK: Attributes list must be non-empty
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
+@run
+def testGetDenseElementsFromListNonAttributeType():
+    """List elements that are not attributes are rejected with a RuntimeError."""
+    with Context(), Location.unknown():
+        attrs = [1.0]
+        try:
+            DenseElementsAttr.get(attrs)
+        except RuntimeError as e:
+            # CHECK: Invalid attribute when attempting to create an ArrayAttribute
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
+@run
+def testGetDenseElementsFromListMismatchedType():
+    """Attributes whose type differs from the explicit element type are rejected."""
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), v) for v in (1.0, 2.0)]
+        shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
+        try:
+            DenseElementsAttr.get(attrs, shaped_type)
+        except ValueError as e:
+            # CHECK: All attributes must be of the same type and match the type parameter
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
+@run
+def testGetDenseElementsFromListMixedTypes():
+    """Mixed attribute types are rejected when the result type is inferred."""
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
+        try:
+            DenseElementsAttr.get(attrs)
+        except ValueError as e:
+            # CHECK: All attributes must be of the same type and match the type parameter
+            print(e)
+
+
################################################################################
# Splats.
################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
### float and double arrays.
+
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
``````````
</details>
https://github.com/llvm/llvm-project/pull/91389
More information about the Mlir-commits
mailing list