[llvm-branch-commits] [mlir] 285c0aa - Add MLIR Python binding for Array Attribute
Mehdi Amini via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Dec 10 12:56:27 PST 2020
Author: Mehdi Amini
Date: 2020-12-10T20:51:34Z
New Revision: 285c0aa262c9255e6ea4efbce1418e5f5f17e9c1
URL: https://github.com/llvm/llvm-project/commit/285c0aa262c9255e6ea4efbce1418e5f5f17e9c1
DIFF: https://github.com/llvm/llvm-project/commit/285c0aa262c9255e6ea4efbce1418e5f5f17e9c1.diff
LOG: Add MLIR Python binding for Array Attribute
Differential Revision: https://reviews.llvm.org/D92948
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 5519c66ee1ab..5ebb2e4ccee3 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1461,6 +1461,83 @@ class PyConcreteAttribute : public BaseTy {
static void bindDerived(ClassTy &m) {}
};
+/// Python binding for the builtin ArrayAttr, exposed to Python as `ArrayAttr`.
+/// Supports construction from a list of attributes (`ArrayAttr.get`),
+/// `len()`, indexing (negative indices count from the end, as for Python
+/// sequences) and iteration.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Python iterator over the elements of an ArrayAttr. Holds a copy of the
+  /// underlying PyAttribute plus the index of the next element to yield.
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
+
+    /// Implements `__iter__`: an iterator is its own iterable.
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+
+    /// Implements `__next__`: returns the next element, or raises
+    /// StopIteration once every element has been produced.
+    PyAttribute dunderNext() {
+      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+        throw py::stop_iteration();
+      return PyAttribute(attr.getContext(),
+                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+    }
+
+    static void bind(py::module &m) {
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+          .def("__next__", &PyArrayAttributeIterator::dunderNext);
+    }
+
+  private:
+    PyAttribute attr;
+    // Same index type as `__getitem__` below, so comparisons against the
+    // element count do not mix integer widths.
+    intptr_t nextIndex = 0;
+  };
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](py::list attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirAttribute> mlirAttributes;
+          mlirAttributes.reserve(py::len(attributes));
+          for (auto attribute : attributes) {
+            try {
+              mlirAttributes.push_back(attribute.cast<PyAttribute>());
+            } catch (py::cast_error &err) {
+              std::string msg = std::string("Invalid attribute when attempting "
+                                            "to create an ArrayAttribute (") +
+                                err.what() + ")";
+              throw py::cast_error(msg);
+            } catch (py::reference_cast_error &err) {
+              // In practice pybind11 raises this (rather than cast_error) when
+              // the element is None.
+              std::string msg =
+                  std::string("Invalid attribute (None?) when attempting to "
+                              "create an ArrayAttribute (") +
+                  err.what() + ")";
+              throw py::cast_error(msg);
+            }
+          }
+          MlirAttribute attr = mlirArrayAttrGet(
+              context->get(), mlirAttributes.size(), mlirAttributes.data());
+          return PyArrayAttribute(context->getRef(), attr);
+        },
+        py::arg("attributes"), py::arg("context") = py::none(),
+        "Gets a uniqued Array attribute");
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            intptr_t numElements = mlirArrayAttrGetNumElements(arr);
+            // Python sequence semantics: negative indices count from the end.
+            if (i < 0)
+              i += numElements;
+            // Anything still outside [0, numElements) must not reach
+            // mlirArrayAttrGetElement, which is not assumed to bounds-check.
+            if (i < 0 || i >= numElements)
+              throw py::index_error("ArrayAttribute index out of range");
+            return PyAttribute(arr.getContext(),
+                               mlirArrayAttrGetElement(arr, i));
+          })
+        .def("__len__",
+             [](const PyArrayAttribute &arr) {
+               return mlirArrayAttrGetNumElements(arr);
+             })
+        .def("__iter__", [](const PyArrayAttribute &arr) {
+          return PyArrayAttributeIterator(arr);
+        });
+  }
+};
+
/// Float Point Attribute subclass - FloatAttr.
class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
public:
@@ -3089,6 +3166,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Builtin attribute bindings.
PyFloatAttribute::bind(m);
+ PyArrayAttribute::bind(m);
+ PyArrayAttribute::PyArrayAttributeIterator::bind(m);
PyIntegerAttribute::bind(m);
PyBoolAttribute::bind(m);
PyStringAttribute::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 4ad180bb1b37..642c1f6a836c 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -269,3 +269,54 @@ def testTypeAttr():
run(testTypeAttr)
+
+
+# CHECK-LABEL: TEST: testArrayAttr
+def testArrayAttr():
+ with Context():
+ raw = Attribute.parse("[42, true, vector<4xf32>]")
+ # CHECK: attr: [42, true, vector<4xf32>]
+ print("raw attr:", raw)
+ # CHECK: - 42
+ # CHECK: - true
+ # CHECK: - vector<4xf32>
+ for attr in ArrayAttr(raw):
+ print("- ", attr)
+
+ with Context():
+ intAttr = Attribute.parse("42")
+ vecAttr = Attribute.parse("vector<4xf32>")
+ boolAttr = BoolAttr.get(True)
+ raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])
+ # CHECK: attr: [vector<4xf32>, true, 42]
+ print("raw attr:", raw)
+ # CHECK: - vector<4xf32>
+ # CHECK: - true
+ # CHECK: - 42
+ arr = ArrayAttr(raw)
+ for attr in arr:
+ print("- ", attr)
+ # CHECK: attr[0]: vector<4xf32>
+ print("attr[0]:", arr[0])
+ # CHECK: attr[1]: true
+ print("attr[1]:", arr[1])
+ # CHECK: attr[2]: 42
+ print("attr[2]:", arr[2])
+ try:
+ print("attr[3]:", arr[3])
+ except IndexError as e:
+ # CHECK: Error: ArrayAttribute index out of range
+ print("Error: ", e)
+ with Context():
+ try:
+ ArrayAttr.get([None])
+ except RuntimeError as e:
+ # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
+ print("Error: ", e)
+ try:
+ ArrayAttr.get([42])
+ except RuntimeError as e:
+ # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute (Unable to cast Python instance of type <class 'int'> to C++ type 'mlir::python::PyAttribute')
+ print("Error: ", e)
+run(testArrayAttr)
+
More information about the llvm-branch-commits
mailing list