[llvm-branch-commits] [mlir] 285c0aa - Add MLIR Python binding for Array Attribute

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 10 12:56:27 PST 2020


Author: Mehdi Amini
Date: 2020-12-10T20:51:34Z
New Revision: 285c0aa262c9255e6ea4efbce1418e5f5f17e9c1

URL: https://github.com/llvm/llvm-project/commit/285c0aa262c9255e6ea4efbce1418e5f5f17e9c1
DIFF: https://github.com/llvm/llvm-project/commit/285c0aa262c9255e6ea4efbce1418e5f5f17e9c1.diff

LOG: Add MLIR Python binding for Array Attribute

Differential Revision: https://reviews.llvm.org/D92948

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 5519c66ee1ab..5ebb2e4ccee3 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1461,6 +1461,83 @@ class PyConcreteAttribute : public BaseTy {
   static void bindDerived(ClassTy &m) {}
 };
 
+/// Python binding for the builtin ArrayAttr. Supports `len`, indexing (with
+/// Python-style negative indices) and iteration via PyArrayAttributeIterator.
+class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
+  static constexpr const char *pyClassName = "ArrayAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  class PyArrayAttributeIterator {
+  public:
+    PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
+    PyArrayAttributeIterator &dunderIter() { return *this; }
+    /// Returns the next element; raises StopIteration once all are consumed.
+    PyAttribute dunderNext() {
+      if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
+        throw py::stop_iteration();
+      return PyAttribute(attr.getContext(),
+                         mlirArrayAttrGetElement(attr.get(), nextIndex++));
+    }
+
+    static void bind(py::module &m) {
+      py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
+          .def("__iter__", &PyArrayAttributeIterator::dunderIter)
+          .def("__next__", &PyArrayAttributeIterator::dunderNext);
+    }
+
+  private:
+    PyAttribute attr;
+    intptr_t nextIndex = 0; // Same type as mlirArrayAttrGetNumElements().
+  };
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](py::list attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirAttribute> mlirAttributes;
+          mlirAttributes.reserve(py::len(attributes));
+          for (auto attribute : attributes) {
+            try {
+              mlirAttributes.push_back(attribute.cast<PyAttribute>());
+            } catch (py::cast_error &err) {
+              std::string msg = std::string("Invalid attribute when attempting "
+                                            "to create an ArrayAttribute (") +
+                                err.what() + ")";
+              throw py::cast_error(msg);
+            } catch (py::reference_cast_error &err) { // Raised e.g. for None.
+              std::string msg =
+                  std::string("Invalid attribute (None?) when attempting to "
+                              "create an ArrayAttribute (") +
+                  err.what() + ")";
+              throw py::cast_error(msg);
+            }
+          }
+          MlirAttribute attr = mlirArrayAttrGet(
+              context->get(), mlirAttributes.size(), mlirAttributes.data());
+          return PyArrayAttribute(context->getRef(), attr);
+        },
+        py::arg("attributes"), py::arg("context") = py::none(),
+        "Gets a uniqued Array attribute");
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            intptr_t size = mlirArrayAttrGetNumElements(arr);
+            intptr_t index = i < 0 ? i + size : i; // Python-style negatives.
+            if (index < 0 || index >= size)
+              throw py::index_error("ArrayAttribute index out of range");
+            return PyAttribute(arr.getContext(),
+                               mlirArrayAttrGetElement(arr, index));
+          })
+        .def("__len__", [](const PyArrayAttribute &arr) {
+          return mlirArrayAttrGetNumElements(arr);
+        })
+        .def("__iter__", [](const PyArrayAttribute &arr) {
+          return PyArrayAttributeIterator(arr);
+        });
+  }
+};
+
 /// Float Point Attribute subclass - FloatAttr.
 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
 public:
@@ -3089,6 +3166,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Builtin attribute bindings.
   PyFloatAttribute::bind(m);
+  PyArrayAttribute::bind(m);
+  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
   PyIntegerAttribute::bind(m);
   PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 4ad180bb1b37..642c1f6a836c 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -269,3 +269,54 @@ def testTypeAttr():
 
 
 run(testTypeAttr)
+
+
+# CHECK-LABEL: TEST: testArrayAttr
+def testArrayAttr():  # Parsing, ArrayAttr.get, iteration, indexing, errors.
+  with Context():
+    raw = Attribute.parse("[42, true, vector<4xf32>]")
+  # CHECK: attr: [42, true, vector<4xf32>]
+  print("raw attr:", raw)
+  # CHECK: - 42
+  # CHECK: - true
+  # CHECK: - vector<4xf32>
+  for attr in ArrayAttr(raw):  # Iteration goes through ArrayAttributeIterator.
+    print("- ", attr)
+
+  with Context():
+    intAttr = Attribute.parse("42")
+    vecAttr = Attribute.parse("vector<4xf32>")
+    boolAttr = BoolAttr.get(True)
+    raw = ArrayAttr.get([vecAttr, boolAttr, intAttr])  # Element order is kept.
+  # CHECK: attr: [vector<4xf32>, true, 42]
+  print("raw attr:", raw)
+  # CHECK: - vector<4xf32>
+  # CHECK: - true
+  # CHECK: - 42
+  arr = ArrayAttr(raw)
+  for attr in arr:
+    print("- ", attr)
+  # CHECK: attr[0]: vector<4xf32>
+  print("attr[0]:", arr[0])
+  # CHECK: attr[1]: true
+  print("attr[1]:", arr[1])
+  # CHECK: attr[2]: 42
+  print("attr[2]:", arr[2])
+  try:  # Index 3 is one past the end of the 3-element array.
+    print("attr[3]:", arr[3])
+  except IndexError as e:
+    # CHECK: Error: ArrayAttribute index out of range
+    print("Error: ", e)
+  with Context():
+    try:  # None is not an Attribute, so get() must reject it.
+      ArrayAttr.get([None])
+    except RuntimeError as e:  # pybind11 surfaces py::cast_error as RuntimeError.
+      # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
+      print("Error: ", e)
+    try:  # A plain Python int is not an Attribute either.
+      ArrayAttr.get([42])
+    except RuntimeError as e:
+      # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute (Unable to cast Python instance of type <class 'int'> to C++ type 'mlir::python::PyAttribute')
+      print("Error: ", e)
+run(testArrayAttr)
+


        


More information about the llvm-branch-commits mailing list