[Mlir-commits] [mlir] [mlir][python] Add bindings for mlirDenseElementsAttrGet (PR #91389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 13:48:28 PDT 2024


https://github.com/pranavm-nvidia updated https://github.com/llvm/llvm-project/pull/91389

>From e267483c376bcbdf7febf275d87069812fc49095 Mon Sep 17 00:00:00 2001
From: pranavm <pranavm at nvidia.com>
Date: Tue, 7 May 2024 10:52:27 -0700
Subject: [PATCH] [mlir][python] Add bindings for mlirDenseElementsAttrGet

This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of
MLIR attributes and constructs a DenseElementsAttr. This allows for creating
`DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without
requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 84 +++++++++++++++++++++--
 mlir/test/python/ir/array_attributes.py   | 82 ++++++++++++++++++++++
 2 files changed, 161 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dda2003ba0375..756cd8cfab989 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -72,6 +72,26 @@ or 255), then a splat will be created.
     type or if the buffer does not meet expectations.
 )";
 
+static const char kDenseElementsAttrGetFromListDocstring[] =
+    R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Note that it can be expensive to construct attributes individually.
+For a large number of elements, consider using a Python buffer or array instead.
+
+Args:
+  attrs: A list of attributes.
+  type: The desired shape and type of the resulting DenseElementsAttr.
+    If not provided, the element type is determined based on the type
+    of the 0th attribute and the shape is `[len(attrs)]`.
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseElementsAttr on success.
+
+Raises:
+  ValueError: If the type of the attributes does not match the type specified by `shaped_type`.
+)";
+
 static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
     R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
 
@@ -647,6 +667,55 @@ class PyDenseElementsAttribute
   static constexpr const char *pyClassName = "DenseElementsAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
 
+  static PyDenseElementsAttribute
+  getFromList(py::list attributes, std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper) {
+
+    const size_t numAttributes = py::len(attributes);
+    if (numAttributes == 0)
+      throw py::value_error("Attributes list must be non-empty");
+
+    MlirType shapedType;
+    if (explicitType) {
+      if ((!mlirTypeIsAShaped(*explicitType) ||
+           !mlirShapedTypeHasStaticShape(*explicitType))) {
+        std::string message =
+            "Expected a static ShapedType for the shaped_type parameter: ";
+        message.append(py::repr(py::cast(*explicitType)));
+        throw py::value_error(message);
+      }
+      shapedType = *explicitType;
+    } else {
+      SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
+      shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(),
+          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+          mlirAttributeGetNull());
+    }
+
+    SmallVector<MlirAttribute> mlirAttributes;
+    mlirAttributes.reserve(numAttributes);
+    for (auto attribute : attributes) {
+      MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+      MlirType attrType = mlirAttributeGetType(mlirAttribute);
+      mlirAttributes.push_back(mlirAttribute);
+
+      if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
+        std::string message = "All attributes must be of the same type and "
+                              "match the type parameter: expected=";
+        message.append(py::repr(py::cast(shapedType)));
+        message.append(", but got=");
+        message.append(py::repr(py::cast(attrType)));
+        throw py::value_error(message);
+      }
+    }
+
+    MlirAttribute elements = mlirDenseElementsAttrGet(
+        shapedType, mlirAttributes.size(), mlirAttributes.data());
+
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
   static PyDenseElementsAttribute
   getFromBuffer(py::buffer array, bool signless,
                 std::optional<PyType> explicitType,
@@ -883,6 +952,10 @@ class PyDenseElementsAttribute
                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
                     py::arg("context") = py::none(),
                     kDenseElementsAttrGetDocstring)
+        .def_static("get", PyDenseElementsAttribute::getFromList,
+                    py::arg("attrs"), py::arg("type") = py::none(),
+                    py::arg("context") = py::none(),
+                    kDenseElementsAttrGetFromListDocstring)
         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                     py::arg("shaped_type"), py::arg("element_attr"),
                     "Gets a DenseElementsAttr where all values are the same")
@@ -954,8 +1027,8 @@ class PyDenseElementsAttribute
   }
 }; // namespace
 
-/// Refinement of the PyDenseElementsAttribute for attributes containing integer
-/// (and boolean) values. Supports element access.
+/// Refinement of the PyDenseElementsAttribute for attributes containing
+/// integer (and boolean) values. Supports element access.
 class PyDenseIntElementsAttribute
     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
                                  PyDenseElementsAttribute> {
@@ -964,8 +1037,8 @@ class PyDenseIntElementsAttribute
   static constexpr const char *pyClassName = "DenseIntElementsAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
 
-  /// Returns the element at the given linear position. Asserts if the index is
-  /// out of range.
+  /// Returns the element at the given linear position. Asserts if the index
+  /// is out of range.
   py::int_ dunderGetItem(intptr_t pos) {
     if (pos < 0 || pos >= dunderLen()) {
       throw py::index_error("attempt to access out of bounds element");
@@ -1267,7 +1340,8 @@ class PyStridedLayoutAttribute
           return PyStridedLayoutAttribute(ctx->getRef(), attr);
         },
         py::arg("rank"), py::arg("context") = py::none(),
-        "Gets a strided layout attribute with dynamic offset and strides of a "
+        "Gets a strided layout attribute with dynamic offset and strides of "
+        "a "
         "given rank.");
     c.def_property_readonly(
         "offset",
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9251588a4c48a..2bc403aace834 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
         print(np.array(attr))
 
 
+################################################################################
+# Tests of the list of attributes .get() factory method
+################################################################################
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromList
+@run
+def testGetDenseElementsFromList():
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+        attr = DenseElementsAttr.get(attrs)
+
+        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+        print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
+@run
+def testGetDenseElementsFromListWithExplicitType():
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+        shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
+        attr = DenseElementsAttr.get(attrs, shaped_type)
+
+        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+        print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
+@run
+def testGetDenseElementsFromListEmptyList():
+    with Context(), Location.unknown():
+        attrs = []
+
+        try:
+            attr = DenseElementsAttr.get(attrs)
+        except ValueError as e:
+            # CHECK: Attributes list must be non-empty
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
+@run
+def testGetDenseElementsFromListNonAttributeType():
+    with Context(), Location.unknown():
+        attrs = [1.0]
+
+        try:
+            attr = DenseElementsAttr.get(attrs)
+        except RuntimeError as e:
+            # CHECK: Invalid attribute when attempting to create an ArrayAttribute
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
+@run
+def testGetDenseElementsFromListMismatchedType():
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+        shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
+
+        try:
+            attr = DenseElementsAttr.get(attrs, shaped_type)
+        except ValueError as e:
+            # CHECK: All attributes must be of the same type and match the type parameter
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
+@run
+def testGetDenseElementsFromListMixedTypes():
+    with Context(), Location.unknown():
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
+
+        try:
+            attr = DenseElementsAttr.get(attrs)
+        except ValueError as e:
+            # CHECK: All attributes must be of the same type and match the type parameter
+            print(e)
+
+
 ################################################################################
 # Splats.
 ################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
 
 ### float and double arrays.
 
+
 # CHECK-LABEL: TEST: testGetDenseElementsF16
 @run
 def testGetDenseElementsF16():



More information about the Mlir-commits mailing list