[Mlir-commits] [mlir] [mlir][python] Add bindings for mlirDenseElementsAttrGet (PR #91389)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 21 16:20:49 PDT 2024
https://github.com/pranavm-nvidia updated https://github.com/llvm/llvm-project/pull/91389
>From 0703e7044eed65611dd0f129ec14365d6a5515a6 Mon Sep 17 00:00:00 2001
From: pranavm <pranavm at nvidia.com>
Date: Tue, 7 May 2024 10:52:27 -0700
Subject: [PATCH 1/2] [mlir][python] Add bindings for mlirDenseElementsAttrGet
This change adds bindings for `mlirDenseElementsAttrGet`, which accepts a list of
MLIR attributes and constructs a DenseElementsAttr. This allows creating
`DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without
requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 84 +++++++++++++++++++++--
mlir/test/python/ir/array_attributes.py | 82 ++++++++++++++++++++++
2 files changed, 161 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dda2003ba0375..756cd8cfab989 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -72,6 +72,26 @@ or 255), then a splat will be created.
type or if the buffer does not meet expectations.
)";
+static const char kDenseElementsAttrGetFromListDocstring[] =
+ R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Note that it can be expensive to construct attributes individually.
+For a large number of elements, consider using a Python buffer or array instead.
+
+Args:
+ attrs: A list of attributes.
+ type: The desired shape and type of the resulting DenseElementsAttr.
+ If not provided, the element type is determined based on the type
+ of the 0th attribute and the shape is `[len(attrs)]`.
+ context: Explicit context, if not from context manager.
+
+Returns:
+ DenseElementsAttr on success.
+
+Raises:
+ ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
+)";
+
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
@@ -647,6 +667,55 @@ class PyDenseElementsAttribute
static constexpr const char *pyClassName = "DenseElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
+ static PyDenseElementsAttribute
+ getFromList(py::list attributes, std::optional<PyType> explicitType,
+ DefaultingPyMlirContext contextWrapper) {
+
+ const size_t numAttributes = py::len(attributes);
+ if (numAttributes == 0)
+ throw py::value_error("Attributes list must be non-empty");
+
+ MlirType shapedType;
+ if (explicitType) {
+ if ((!mlirTypeIsAShaped(*explicitType) ||
+ !mlirShapedTypeHasStaticShape(*explicitType))) {
+ std::string message =
+ "Expected a static ShapedType for the shaped_type parameter: ";
+ message.append(py::repr(py::cast(*explicitType)));
+ throw py::value_error(message);
+ }
+ shapedType = *explicitType;
+ } else {
+ SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
+ shapedType = mlirRankedTensorTypeGet(
+ shape.size(), shape.data(),
+ mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+ mlirAttributeGetNull());
+ }
+
+ SmallVector<MlirAttribute> mlirAttributes;
+ mlirAttributes.reserve(numAttributes);
+ for (auto attribute : attributes) {
+ MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+ MlirType attrType = mlirAttributeGetType(mlirAttribute);
+ mlirAttributes.push_back(mlirAttribute);
+
+ if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+ std::string message = "All attributes must be of the same type and "
+ "match the type parameter: expected=";
+ message.append(py::repr(py::cast(shapedType)));
+ message.append(", but got=");
+ message.append(py::repr(py::cast(attrType)));
+ throw py::value_error(message);
+ }
+ }
+
+ MlirAttribute elements = mlirDenseElementsAttrGet(
+ shapedType, mlirAttributes.size(), mlirAttributes.data());
+
+ return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+ }
+
static PyDenseElementsAttribute
getFromBuffer(py::buffer array, bool signless,
std::optional<PyType> explicitType,
@@ -883,6 +952,10 @@ class PyDenseElementsAttribute
py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ py::arg("attrs"), py::arg("type") = py::none(),
+ py::arg("context") = py::none(),
+ kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
@@ -954,8 +1027,8 @@ class PyDenseElementsAttribute
}
}; // namespace
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
class PyDenseIntElementsAttribute
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
PyDenseElementsAttribute> {
@@ -964,8 +1037,8 @@ class PyDenseIntElementsAttribute
static constexpr const char *pyClassName = "DenseIntElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- /// Returns the element at the given linear position. Asserts if the index is
- /// out of range.
+ /// Returns the element at the given linear position. Asserts if the index
+ /// is out of range.
py::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw py::index_error("attempt to access out of bounds element");
@@ -1267,7 +1340,8 @@ class PyStridedLayoutAttribute
return PyStridedLayoutAttribute(ctx->getRef(), attr);
},
py::arg("rank"), py::arg("context") = py::none(),
- "Gets a strided layout attribute with dynamic offset and strides of a "
+ "Gets a strided layout attribute with dynamic offset and strides of "
+ "a "
"given rank.");
c.def_property_readonly(
"offset",
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9251588a4c48a..2bc403aace834 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
print(np.array(attr))
+################################################################################
+# Tests of the list of attributes .get() factory method
+################################################################################
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromList
+ at run
+def testGetDenseElementsFromList():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ attr = DenseElementsAttr.get(attrs)
+
+ # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+ print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
+ at run
+def testGetDenseElementsFromListWithExplicitType():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
+ attr = DenseElementsAttr.get(attrs, shaped_type)
+
+ # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+ print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
+ at run
+def testGetDenseElementsFromListEmptyList():
+ with Context(), Location.unknown():
+ attrs = []
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except ValueError as e:
+ # CHECK: Attributes list must be non-empty
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
+ at run
+def testGetDenseElementsFromListNonAttributeType():
+ with Context(), Location.unknown():
+ attrs = [1.0]
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except RuntimeError as e:
+ # CHECK: Invalid attribute when attempting to create an ArrayAttribute
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
+ at run
+def testGetDenseElementsFromListMismatchedType():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
+
+ try:
+ attr = DenseElementsAttr.get(attrs, shaped_type)
+ except ValueError as e:
+ # CHECK: All attributes must be of the same type and match the type parameter
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
+ at run
+def testGetDenseElementsFromListMixedTypes():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except ValueError as e:
+ # CHECK: All attributes must be of the same type and match the type parameter
+ print(e)
+
+
################################################################################
# Splats.
################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
### float and double arrays.
+
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
>From 11c716f51fa243545484379dcad24618bdd8861c Mon Sep 17 00:00:00 2001
From: pranavm <pranavm at nvidia.com>
Date: Tue, 21 May 2024 16:20:36 -0700
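Subject: [PATCH 2/2] [mlir][python] Build DenseElementsAttr.get error messages with raw_string_ostream; undo unrelated reformatting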
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 41 ++++++++++++-----------
1 file changed, 22 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 756cd8cfab989..d0c05bf0ffb54 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -15,6 +15,7 @@
#include "PybindUtils.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
@@ -89,7 +90,9 @@ For a large number of elements, consider using a Python buffer or array instead.
   DenseElementsAttr on success.
 
 Raises:
-  ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
+  ValueError: If `attrs` is empty, if `type` is not a statically shaped
+    ShapedType, or if any attribute's type differs from the element type
+    (taken from `type` when given, otherwise from the 0th attribute).
 )";
 
 static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
@@ -673,16 +675,37 @@ class PyDenseElementsAttribute
 
     const size_t numAttributes = py::len(attributes);
     if (numAttributes == 0)
-      throw py::value_error("Attributes list must be non-empty");
+      throw py::value_error("Attributes list must be non-empty.");
 
     MlirType shapedType;
     if (explicitType) {
       if ((!mlirTypeIsAShaped(*explicitType) ||
            !mlirShapedTypeHasStaticShape(*explicitType))) {
-        std::string message =
-            "Expected a static ShapedType for the shaped_type parameter: ";
-        message.append(py::repr(py::cast(*explicitType)));
-        throw py::value_error(message);
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "Expected a static ShapedType for the type parameter: "
+           << py::repr(py::cast(*explicitType));
+        throw py::value_error(os.str());
       }
       shapedType = *explicitType;
+
+      // mlirDenseElementsAttrGet expects either a single (splat) attribute or
+      // exactly one attribute per element of the shaped type. Check the count
+      // here (the shape is known to be static) so a mismatch surfaces as a
+      // ValueError instead of being handed to the C API.
+      int64_t numElements = 1;
+      for (intptr_t i = 0, rank = mlirShapedTypeGetRank(shapedType); i < rank;
+           ++i)
+        numElements *= mlirShapedTypeGetDimSize(shapedType, i);
+      if (numAttributes != 1 &&
+          static_cast<int64_t>(numAttributes) != numElements) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "Expected " << numElements
+           << " attributes (or a single splat attribute) for the type "
+              "parameter "
+           << py::repr(py::cast(*explicitType)) << ", but got "
+           << numAttributes << ".";
+        throw py::value_error(os.str());
+      }
     } else {
@@ -695,18 +699,22 @@ class PyDenseElementsAttribute
 
     SmallVector<MlirAttribute> mlirAttributes;
     mlirAttributes.reserve(numAttributes);
-    for (auto attribute : attributes) {
+    // Each attribute must have exactly the element type of the result; the
+    // expected element type does not change across iterations.
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+    for (const py::handle &attribute : attributes) {
       MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
       MlirType attrType = mlirAttributeGetType(mlirAttribute);
       mlirAttributes.push_back(mlirAttribute);
 
-      if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
-        std::string message = "All attributes must be of the same type and "
-                              "match the type parameter: expected=";
-        message.append(py::repr(py::cast(shapedType)));
-        message.append(", but got=");
-        message.append(py::repr(py::cast(attrType)));
-        throw py::value_error(message);
+      if (!mlirTypeEqual(elementType, attrType)) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "All attributes must be of the same type and match "
+           << "the type parameter: expected="
+           << py::repr(py::cast(elementType))
+           << ", but got=" << py::repr(py::cast(attrType));
+        throw py::value_error(os.str());
       }
     }
 
@@ -1027,8 +1031,8 @@ class PyDenseElementsAttribute
}
}; // namespace
-/// Refinement of the PyDenseElementsAttribute for attributes containing
-/// integer (and boolean) values. Supports element access.
+/// Refinement of the PyDenseElementsAttribute for attributes containing integer
+/// (and boolean) values. Supports element access.
class PyDenseIntElementsAttribute
: public PyConcreteAttribute<PyDenseIntElementsAttribute,
PyDenseElementsAttribute> {
@@ -1037,8 +1041,8 @@ class PyDenseIntElementsAttribute
static constexpr const char *pyClassName = "DenseIntElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- /// Returns the element at the given linear position. Asserts if the index
- /// is out of range.
+ /// Returns the element at the given linear position. Asserts if the index is
+ /// out of range.
py::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw py::index_error("attempt to access out of bounds element");
@@ -1340,8 +1344,7 @@ class PyStridedLayoutAttribute
return PyStridedLayoutAttribute(ctx->getRef(), attr);
},
py::arg("rank"), py::arg("context") = py::none(),
- "Gets a strided layout attribute with dynamic offset and strides of "
- "a "
+ "Gets a strided layout attribute with dynamic offset and strides of a "
"given rank.");
c.def_property_readonly(
"offset",
More information about the Mlir-commits
mailing list