[Mlir-commits] [mlir] [mlir][python] Add bindings for mlirDenseElementsAttrGet (PR #91389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 17:11:10 PDT 2024


https://github.com/pranavm-nvidia updated https://github.com/llvm/llvm-project/pull/91389

>From 510942ad1fca6b7f837a14a4e577c6f48332de3d Mon Sep 17 00:00:00 2001
From: pranavm <pranavm at nvidia.com>
Date: Tue, 7 May 2024 10:52:27 -0700
Subject: [PATCH] [mlir][python] Add bindings for mlirDenseElementsAttrGet

This change adds Python bindings for `mlirDenseElementsAttrGet`, which accepts a
list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating
`DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without
requiring other dependencies (e.g. `numpy` + `ml-dtypes`).
---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 77 +++++++++++++++++++++
 mlir/test/python/ir/array_attributes.py   | 82 +++++++++++++++++++++++
 2 files changed, 159 insertions(+)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index dda2003ba0375..b5f31aa5dec54 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -15,6 +15,7 @@
 #include "PybindUtils.h"
 
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/raw_ostream.h"
 
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
@@ -72,6 +73,27 @@ or 255), then a splat will be created.
     type or if the buffer does not meet expectations.
 )";
 
+static const char kDenseElementsAttrGetFromListDocstring[] =
+    R"(Gets a DenseElementsAttr from a Python list of attributes.
+
+Note that it can be expensive to construct attributes individually.
+For a large number of elements, consider using a Python buffer or array instead.
+
+Args:
+  attrs: A non-empty list of attributes.
+  type: The desired shape and type of the resulting DenseElementsAttr.
+    If not provided, the element type is determined based on the type
+    of the 0th attribute and the shape is `[len(attrs)]`.
+  context: Explicit context, if not from context manager.
+
+Returns:
+  DenseElementsAttr on success.
+
+Raises:
+  ValueError: If `attrs` is empty, `type` is not a statically shaped type,
+    or an attribute's type does not match the element type of `type`.
+)";
+
 static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
     R"(Gets a DenseResourceElementsAttr from a Python buffer or array.
 
@@ -647,6 +669,82 @@ class PyDenseElementsAttribute
   static constexpr const char *pyClassName = "DenseElementsAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
 
+  /// Builds a DenseElementsAttr from a non-empty Python list of attributes.
+  ///
+  /// When `explicitType` is given it must be a statically shaped ShapedType,
+  /// and `attributes` must supply one value per element or a single value
+  /// (which DenseElementsAttr::get treats as a splat). Otherwise a 1-D ranked
+  /// tensor of length len(attributes) is built, using the type of the first
+  /// attribute as its element type. Every attribute's type must equal the
+  /// element type; violations raise ValueError, and list entries that are not
+  /// attributes make pyTryCast throw.
+  static PyDenseElementsAttribute
+  getFromList(py::list attributes, std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper) {
+    const size_t numAttributes = py::len(attributes);
+    if (numAttributes == 0)
+      throw py::value_error("Attributes list must be non-empty.");
+
+    MlirType shapedType;
+    if (explicitType) {
+      if (!mlirTypeIsAShaped(*explicitType) ||
+          !mlirShapedTypeHasStaticShape(*explicitType)) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "Expected a static ShapedType for the type parameter: "
+           << py::repr(py::cast(*explicitType));
+        throw py::value_error(os.str());
+      }
+      shapedType = *explicitType;
+
+      // DenseElementsAttr::get asserts that it receives either one value per
+      // element or a single splat value; reject other counts here rather
+      // than crashing (or misbehaving in release builds).
+      int64_t numElements = 1;
+      for (int64_t dim = 0, rank = mlirShapedTypeGetRank(shapedType);
+           dim < rank; ++dim)
+        numElements *= mlirShapedTypeGetDimSize(shapedType, dim);
+      if (numAttributes != 1 &&
+          static_cast<int64_t>(numAttributes) != numElements) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "Expected " << numElements
+           << " attributes (or 1 for a splat) for type "
+           << py::repr(py::cast(shapedType)) << ", but got " << numAttributes;
+        throw py::value_error(os.str());
+      }
+    } else {
+      SmallVector<int64_t> shape{static_cast<int64_t>(numAttributes)};
+      shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(),
+          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+          mlirAttributeGetNull());
+    }
+
+    // Every attribute must have exactly the shaped type's element type.
+    MlirType elementType = mlirShapedTypeGetElementType(shapedType);
+    SmallVector<MlirAttribute> mlirAttributes;
+    mlirAttributes.reserve(numAttributes);
+    for (const py::handle &attribute : attributes) {
+      MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
+      MlirType attrType = mlirAttributeGetType(mlirAttribute);
+      if (!mlirTypeEqual(elementType, attrType)) {
+        std::string message;
+        llvm::raw_string_ostream os(message);
+        os << "All attributes must be of the same type and match "
+           << "the type parameter: expected=" << py::repr(py::cast(elementType))
+           << ", but got=" << py::repr(py::cast(attrType));
+        throw py::value_error(os.str());
+      }
+      mlirAttributes.push_back(mlirAttribute);
+    }
+
+    MlirAttribute elements = mlirDenseElementsAttrGet(
+        shapedType, mlirAttributes.size(), mlirAttributes.data());
+
+    return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
+  }
+
   static PyDenseElementsAttribute
   getFromBuffer(py::buffer array, bool signless,
                 std::optional<PyType> explicitType,
@@ -883,6 +956,10 @@ class PyDenseElementsAttribute
                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
                     py::arg("context") = py::none(),
                     kDenseElementsAttrGetDocstring)
+        .def_static("get", PyDenseElementsAttribute::getFromList,
+                    py::arg("attrs"), py::arg("type") = py::none(),
+                    py::arg("context") = py::none(),
+                    kDenseElementsAttrGetFromListDocstring)
         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
                     py::arg("shaped_type"), py::arg("element_attr"),
                     "Gets a DenseElementsAttr where all values are the same")
diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py
index 9251588a4c48a..2bc403aace834 100644
--- a/mlir/test/python/ir/array_attributes.py
+++ b/mlir/test/python/ir/array_attributes.py
@@ -50,6 +50,87 @@ def testGetDenseElementsUnSupportedTypeOkIfExplicitTypeProvided():
         print(np.array(attr))
 
 
+################################################################################
+# Tests of the list of attributes .get() factory method
+################################################################################
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromList
+@run
+def testGetDenseElementsFromList():
+    with Context(), Location.unknown():
+        # No explicit type: shape [len(attrs)] and element type from attrs[0].
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+        attr = DenseElementsAttr.get(attrs)
+        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+        print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListWithExplicitType
+@run
+def testGetDenseElementsFromListWithExplicitType():
+    with Context(), Location.unknown():
+        # The explicit shaped type's element type must match every attribute.
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+        shaped_type = ShapedType(Type.parse("tensor<2xf64>"))
+        attr = DenseElementsAttr.get(attrs, shaped_type)
+        # CHECK: dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
+        print(attr)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListEmptyList
+@run
+def testGetDenseElementsFromListEmptyList():
+    with Context(), Location.unknown():
+        # An empty list is rejected before any type inference happens.
+        attrs = []
+        try:
+            DenseElementsAttr.get(attrs)
+        except ValueError as e:
+            # CHECK: Attributes list must be non-empty
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListNonAttributeType
+@run
+def testGetDenseElementsFromListNonAttributeType():
+    with Context(), Location.unknown():
+        # A plain Python float is not an MLIR attribute and fails the cast.
+        attrs = [1.0]
+        try:
+            DenseElementsAttr.get(attrs)
+        except RuntimeError as e:
+            # CHECK: Invalid attribute when attempting to create an ArrayAttribute
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMismatchedType
+@run
+def testGetDenseElementsFromListMismatchedType():
+    with Context(), Location.unknown():
+        # f64 attributes cannot populate a tensor whose element type is f32.
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F64Type.get(), 2.0)]
+        shaped_type = ShapedType(Type.parse("tensor<2xf32>"))
+        try:
+            DenseElementsAttr.get(attrs, shaped_type)
+        except ValueError as e:
+            # CHECK: All attributes must be of the same type and match the type parameter
+            print(e)
+
+
+# CHECK-LABEL: TEST: testGetDenseElementsFromListMixedTypes
+@run
+def testGetDenseElementsFromListMixedTypes():
+    with Context(), Location.unknown():
+        # The element type is inferred from attrs[0] (f64), so the f32 entry fails.
+        attrs = [FloatAttr.get(F64Type.get(), 1.0), FloatAttr.get(F32Type.get(), 2.0)]
+        try:
+            DenseElementsAttr.get(attrs)
+        except ValueError as e:
+            # CHECK: All attributes must be of the same type and match the type parameter
+            print(e)
+
+
 ################################################################################
 # Splats.
 ################################################################################
@@ -205,6 +286,7 @@ def testGetDenseElementsBoolSplat():
 
 ### float and double arrays.
 
+
 # CHECK-LABEL: TEST: testGetDenseElementsF16
 @run
 def testGetDenseElementsF16():



More information about the Mlir-commits mailing list