[Mlir-commits] [mlir] [mlir][python] Add bindings for mlirDenseElementsAttrGet (PR #91389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 16:22:34 PDT 2024


================
@@ -647,6 +664,55 @@ class PyDenseElementsAttribute
   static constexpr const char *pyClassName = "DenseElementsAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
 
+  static PyDenseElementsAttribute
+  getFromList(py::list attributes, std::optional<PyType> explicitType,
+              DefaultingPyMlirContext contextWrapper) {
+
+    if (py::len(attributes) == 0) {
+      throw py::value_error("Attributes list must be non-empty");
+    }
+
+    MlirType shapedType;
+    if (explicitType) {
+      if ((!mlirTypeIsAShaped(*explicitType) ||
+           !mlirShapedTypeHasStaticShape(*explicitType))) {
+        std::string message =
+            "Expected a static ShapedType for the shaped_type parameter: ";
+        message.append(py::repr(py::cast(*explicitType)));
+        throw py::value_error(message);
+      }
+      shapedType = *explicitType;
+    } else {
+      SmallVector<int64_t> shape{static_cast<int64_t>(py::len(attributes))};
+      shapedType = mlirRankedTensorTypeGet(
+          shape.size(), shape.data(),
+          mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
+          mlirAttributeGetNull());
+    }
+
+    SmallVector<MlirAttribute> mlirAttributes;
+    mlirAttributes.reserve(py::len(attributes));
+    for (auto attribute : attributes) {
----------------
pranavm-nvidia wrote:

Fixed

https://github.com/llvm/llvm-project/pull/91389


More information about the Mlir-commits mailing list