[Mlir-commits] [mlir] c912f0e - [mlir][python] Add bindings for mlirDenseElementsAttrGet (#91389)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 22 03:44:26 PDT 2024
Author: pranavm-nvidia
Date: 2024-05-22T05:44:22-05:00
New Revision: c912f0e773386cc309155b78e2441ee5f1052c13
URL: https://github.com/llvm/llvm-project/commit/c912f0e773386cc309155b78e2441ee5f1052c13
DIFF: https://github.com/llvm/llvm-project/commit/c912f0e773386cc309155b78e2441ee5f1052c13.diff
LOG: [mlir][python] Add bindings for mlirDenseElementsAttrGet (#91389)
This change adds bindings for `mlirDenseElementsAttrGet` which accepts a
list of MLIR attributes and constructs a DenseElementsAttr. This allows
for creating `DenseElementsAttr`s of types not natively supported by
Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` +
`ml-dtypes`).
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/array_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dda2003ba0375..b5f31aa5dec54 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -15,6 +15,7 @@
#include "PybindUtils.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/raw_ostream.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
@@ -72,6 +73,27 @@ or 255), then a splat will be created.
type or if the buffer does not meet expectations.
)";
+// Python docstring for DenseElementsAttr.get(attrs, type=None, context=None).
+static const char kDenseElementsAttrGetFromListDocstring[] =
+    R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Note that it can be expensive to construct attributes individually.
+For a large number of elements, consider using a Python buffer or array instead.
+
+Args:
+  attrs: A list of attributes.
+  type: The desired shape and type of the resulting DenseElementsAttr.
+    If not provided, the element type is determined based on the type
+    of the 0th attribute and the shape is `[len(attrs)]`.
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseElementsAttr on success.
+
+Raises:
+  ValueError: If `attrs` is empty, if `type` is not a statically shaped
+    ShapedType, or if the attributes do not all have the element type
+    specified by `type` (or, when `type` is omitted, the type of the 0th
+    attribute).
+)";
+
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
@@ -647,6 +669,57 @@ class PyDenseElementsAttribute
static constexpr const char *pyClassName = "DenseElementsAttr";
using PyConcreteAttribute::PyConcreteAttribute;
+  /// Builds a DenseElementsAttr from a Python list of attributes.
+  ///
+  /// When `explicitType` is given, it must be a statically shaped ShapedType.
+  /// Every attribute must have its element type. The list length must equal
+  /// its element count, or be 1 to request a splat.
+  /// Otherwise the result is a 1-D ranked tensor of length `len(attributes)`
+  /// whose element type is taken from the first attribute.
+  ///
+  /// Throws py::value_error (Python ValueError) on an empty list, a
+  /// non-shaped or non-static type, an element-count mismatch, or an
+  /// element-type mismatch.
+  static PyDenseElementsAttribute
+  getFromList(py::list attributes, std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper) {
+    const size_t numAttributes = py::len(attributes);
+    if (numAttributes == 0)
+      throw py::value_error("Attributes list must be non-empty.");
+
+    MlirType shapedType;
+    if (explicitType) {
+      if (!mlirTypeIsAShaped(*explicitType) ||
+          !mlirShapedTypeHasStaticShape(*explicitType)) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "Expected a static ShapedType for the type parameter: "
+           << py::repr(py::cast(*explicitType));
+        throw py::value_error(os.str());
+      }
+      shapedType = *explicitType;
+
+      // The C++ DenseElementsAttr::get behind mlirDenseElementsAttrGet checks
+      // "element count matches the shape, or a single splat value" only via
+      // an assert, so validate it here and raise instead of crashing.
+      int64_t numElements = 1;
+      const int64_t rank = mlirShapedTypeGetRank(shapedType);
+      for (int64_t dim = 0; dim < rank; ++dim)
+        numElements *= mlirShapedTypeGetDimSize(shapedType, dim);
+      if (numAttributes != 1 &&
+          static_cast<int64_t>(numAttributes) != numElements) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "Expected " << numElements
+           << " attributes (or 1 to create a splat) to match the type "
+           << "parameter " << py::repr(py::cast(shapedType)) << ", but got "
+           << numAttributes;
+        throw py::value_error(os.str());
+      }
+    } else {
+      // No explicit type: infer a 1-D tensor from the list length and the
+      // type of the first attribute.
+      SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
+      shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(),
+          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+          mlirAttributeGetNull());
+    }
+
+    // Loop-invariant: every attribute must match this element type.
+    const MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+
+    SmallVector<MlirAttribute> mlirAttributes;
+    mlirAttributes.reserve(numAttributes);
+    for (const py::handle &attribute : attributes) {
+      MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+      MlirType attrType = mlirAttributeGetType(mlirAttribute);
+      if (!mlirTypeEqual(elementType, attrType)) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "All attributes must be of the same type and match "
+           << "the type parameter: expected=" << py::repr(py::cast(elementType))
+           << ", but got=" << py::repr(py::cast(attrType));
+        throw py::value_error(os.str());
+      }
+      mlirAttributes.push_back(mlirAttribute);
+    }
+
+    MlirAttribute elements = mlirDenseElementsAttrGet(
+        shapedType, mlirAttributes.size(), mlirAttributes.data());
+
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
static PyDenseElementsAttribute
getFromBuffer(py::buffer array, bool signless,
std::optional<PyType> explicitType,
@@ -883,6 +956,10 @@ class PyDenseElementsAttribute
py::arg("type") = py::none(), py::arg("shape") = py::none(),
py::arg("context") = py::none(),
kDenseElementsAttrGetDocstring)
+ .def_static("get", PyDenseElementsAttribute::getFromList,
+ py::arg("attrs"), py::arg("type") = py::none(),
+ py::arg("context") = py::none(),
+ kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9251588a4c48a..2bc403aace834 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
print(np.array(attr))
+################################################################################
+# Tests of the list of attributes .get() factory method
+################################################################################
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromList
+ at run
+def testGetDenseElementsFromList():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ attr = DenseElementsAttr.get(attrs)
+
+ # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+ print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
+ at run
+def testGetDenseElementsFromListWithExplicitType():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
+ attr = DenseElementsAttr.get(attrs, shaped_type)
+
+ # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+ print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
+ at run
+def testGetDenseElementsFromListEmptyList():
+ with Context(), Location.unknown():
+ attrs = []
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except ValueError as e:
+ # CHECK: Attributes list must be non-empty
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
+ at run
+def testGetDenseElementsFromListNonAttributeType():
+ with Context(), Location.unknown():
+ attrs = [1.0]
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except RuntimeError as e:
+ # CHECK: Invalid attribute when attempting to create an ArrayAttribute
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
+ at run
+def testGetDenseElementsFromListMismatchedType():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+ shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
+
+ try:
+ attr = DenseElementsAttr.get(attrs, shaped_type)
+ except ValueError as e:
+ # CHECK: All attributes must be of the same type and match the type parameter
+ print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
+ at run
+def testGetDenseElementsFromListMixedTypes():
+ with Context(), Location.unknown():
+ attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
+
+ try:
+ attr = DenseElementsAttr.get(attrs)
+ except ValueError as e:
+ # CHECK: All attributes must be of the same type and match the type parameter
+ print(e)
+
+
################################################################################
# Splats.
################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
### float and double arrays.
+
# CHECK-LABEL: TEST: testGetDenseElementsF16
@run
def testGetDenseElementsF16():
More information about the Mlir-commits
mailing list