[Mlir-commits] [mlir] ed9e52f - [mlir][python] Usability improvements for Python bindings

Alex Zinenko llvmlistbot at llvm.org
Mon Oct 4 02:45:35 PDT 2021


Author: Alex Zinenko
Date: 2021-10-04T11:45:25+02:00
New Revision: ed9e52f3af4e1d95033268b60b91cbdebe38182c

URL: https://github.com/llvm/llvm-project/commit/ed9e52f3af4e1d95033268b60b91cbdebe38182c
DIFF: https://github.com/llvm/llvm-project/commit/ed9e52f3af4e1d95033268b60b91cbdebe38182c.diff

LOG: [mlir][python] Usability improvements for Python bindings

Provide a couple of quality-of-life usability improvements for Python bindings,
in particular:

  * give access to the list of types for the list of op results or block
    arguments, similarly to ValueRange->TypeRange,

  * allow for constructing empty dictionary attributes,

  * support construction of array attributes by concatenating an existing
    attribute with a Python list of attributes.

All these are required for the upcoming customization of builtin and standard
ops.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D110946

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/test/python/ir/attributes.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index bb4b5f4f0462c..2ff75ceedcf2e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -18,7 +18,6 @@ using namespace mlir;
 using namespace mlir::python;
 
 using llvm::SmallVector;
-using llvm::StringRef;
 using llvm::Twine;
 
 namespace {
@@ -44,6 +43,27 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
   }
 };
 
+/// Casts `object` to `T`, rethrowing pybind11 cast failures as py::cast_error
+/// with a message that mentions ArrayAttribute creation (its only use here).
+/// pybind11 appears to raise reference_cast_error when `object` is None.
+template <typename T>
+static T pyTryCast(py::handle object) {
+  try {
+    return object.cast<T>();
+  } catch (py::cast_error &err) {
+    std::string msg =
+        std::string(
+            "Invalid attribute when attempting to create an ArrayAttribute (") +
+        err.what() + ")";
+    throw py::cast_error(msg);
+  } catch (py::reference_cast_error &err) {
+    std::string msg = std::string("Invalid attribute (None?) when attempting "
+                                  "to create an ArrayAttribute (") +
+                      err.what() + ")";
+    throw py::cast_error(msg);
+  }
+}
+
 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -76,6 +96,11 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
     int nextIndex = 0;
   };
 
+  /// Returns element `i` without bounds checking; callers must validate `i`.
+  PyAttribute getItem(intptr_t i) {
+    return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
+  }
+
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
@@ -83,21 +108,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
           SmallVector<MlirAttribute> mlirAttributes;
           mlirAttributes.reserve(py::len(attributes));
           for (auto attribute : attributes) {
-            try {
-              mlirAttributes.push_back(attribute.cast<PyAttribute>());
-            } catch (py::cast_error &err) {
-              std::string msg = std::string("Invalid attribute when attempting "
-                                            "to create an ArrayAttribute (") +
-                                err.what() + ")";
-              throw py::cast_error(msg);
-            } catch (py::reference_cast_error &err) {
-              // This exception seems thrown when the value is "None".
-              std::string msg =
-                  std::string("Invalid attribute (None?) when attempting to "
-                              "create an ArrayAttribute (") +
-                  err.what() + ")";
-              throw py::cast_error(msg);
-            }
+            mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
           }
           MlirAttribute attr = mlirArrayAttrGet(
               context->get(), mlirAttributes.size(), mlirAttributes.data());
@@ -109,8 +120,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
           [](PyArrayAttribute &arr, intptr_t i) {
-            if (i >= mlirArrayAttrGetNumElements(arr))
+            // Negative indices must not reach mlirArrayAttrGetElement either.
+            if (i < 0 || i >= mlirArrayAttrGetNumElements(arr))
               throw py::index_error("ArrayAttribute index out of range");
-            return PyAttribute(arr.getContext(),
-                               mlirArrayAttrGetElement(arr, i));
+            return arr.getItem(i);
           })
         .def("__len__",
              [](const PyArrayAttribute &arr) {
@@ -119,6 +130,20 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
         .def("__iter__", [](const PyArrayAttribute &arr) {
           return PyArrayAttributeIterator(arr);
         });
+    // `array + [attr, ...]`: returns a new ArrayAttr holding the elements of
+    // `arr` followed by `extras`; `arr` itself is left unchanged.
+    c.def("__add__", [](PyArrayAttribute &arr, py::list extras) {
+      SmallVector<MlirAttribute> attributes;
+      intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+      attributes.reserve(numOldElements + py::len(extras));
+      for (intptr_t i = 0; i < numOldElements; ++i)
+        attributes.push_back(arr.getItem(i));
+      for (py::handle attr : extras)
+        attributes.push_back(pyTryCast<PyAttribute>(attr));
+      MlirAttribute arrayAttr = mlirArrayAttrGet(
+          arr.getContext()->get(), attributes.size(), attributes.data());
+      return PyArrayAttribute(arr.getContext(), arrayAttr);
+    });
   }
 };
 
@@ -602,7 +627,7 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
                                     mlirNamedAttributes.data());
           return PyDictAttribute(context->getRef(), attr);
         },
-        py::arg("value"), py::arg("context") = py::none(),
+        py::arg("value") = py::dict(), py::arg("context") = py::none(),
         "Gets an uniqued dict attribute");
     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
       MlirAttribute attr =

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0434ac37c6881..8ed3bd5ed38a8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1590,6 +1590,20 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
   }
 };
 
+/// Returns the list of types of the values held by container.
+/// Each type is wrapped in a PyType associated with `context`.
+template <typename Container>
+static std::vector<PyType> getValueTypes(Container &container,
+                                         PyMlirContextRef &context) {
+  std::vector<PyType> result;
+  result.reserve(container.getNumElements());
+  for (intptr_t i = 0, e = container.getNumElements(); i < e; ++i) {
+    result.push_back(
+        PyType(context, mlirValueGetType(container.getElement(i).get())));
+  }
+  return result;
+}
+
 /// A list of block arguments. Internally, these are stored as consecutive
 /// elements, random access is cheap. The argument list is associated with the
 /// operation that contains the block (detached blocks are not allowed in
@@ -1625,6 +1639,13 @@ class PyBlockArgumentList
     return PyBlockArgumentList(operation, block, startIndex, length, step);
   }
 
+  /// Binds `types`: the types of the block arguments in this list.
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
 private:
   PyOperationRef operation;
   MlirBlock block;
@@ -1712,6 +1733,13 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
     return PyOpResultList(operation, startIndex, length, step);
   }
 
+  /// Binds `types`: the types of the results in this list.
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
 private:
   PyOperationRef operation;
 };

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 89da559669db8..d2deb39a69df3 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -343,6 +343,10 @@ def testDictAttr():
     else:
       assert False, "expected IndexError on accessing an out-of-bounds attribute"
 
+    # DictAttr.get() with no arguments yields an empty dictionary attribute.
+    # CHECK: empty: {}
+    print("empty: ", DictAttr.get())
+
 
 # CHECK-LABEL: TEST: testTypeAttr
 @run
@@ -404,3 +408,9 @@ def testArrayAttr():
     except RuntimeError as e:
       # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
       print("Error: ", e)
+
+  with Context():
+    array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
+    array = array + [StringAttr.get("c")]
+    # CHECK: concat: ["a", "b", "c"]
+    print("concat: ", array)

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index d04f52f0c0fe9..f9b4efa4d97eb 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -145,6 +145,12 @@ def testBlockArgumentList():
     print("Length: ",
           len(entry_block.arguments[:2] + entry_block.arguments[1:]))
 
+    # CHECK: Type: i8
+    # CHECK: Type: i16
+    # CHECK: Type: i24
+    for t in entry_block.arguments.types:
+      print("Type: ", t)
+
 
 run(testBlockArgumentList)
 
@@ -380,6 +386,12 @@ def testOperationResultList():
   for res in call.results:
     print(f"Result {res.result_number}, type {res.type}")
 
+  # CHECK: Result type i32
+  # CHECK: Result type f64
+  # CHECK: Result type index
+  for t in call.results.types:
+    print(f"Result type {t}")
+
 
 run(testOperationResultList)
 


        


More information about the Mlir-commits mailing list