[Mlir-commits] [mlir] ed9e52f - [mlir][python] Usability improvements for Python bindings
Alex Zinenko
llvmlistbot at llvm.org
Mon Oct 4 02:45:35 PDT 2021
Author: Alex Zinenko
Date: 2021-10-04T11:45:25+02:00
New Revision: ed9e52f3af4e1d95033268b60b91cbdebe38182c
URL: https://github.com/llvm/llvm-project/commit/ed9e52f3af4e1d95033268b60b91cbdebe38182c
DIFF: https://github.com/llvm/llvm-project/commit/ed9e52f3af4e1d95033268b60b91cbdebe38182c.diff
LOG: [mlir][python] Usability improvements for Python bindings
Provide a few quality-of-life usability improvements for the Python bindings,
in particular:
* give access to the list of types for the list of op results or block
arguments, similarly to ValueRange->TypeRange,
* allow constructing empty dictionary attributes (`DictAttr.get()` with no arguments),
* support constructing an array attribute by concatenating an existing array
attribute with a Python list of attributes (`array + [attr, ...]`).
All these are required for the upcoming customization of builtin and standard
ops.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D110946
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/test/python/ir/attributes.py
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index bb4b5f4f0462c..2ff75ceedcf2e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -18,7 +18,6 @@ using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
-using llvm::StringRef;
using llvm::Twine;
namespace {
@@ -44,6 +43,24 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
}
};
+/// Casts `object` to T, rethrowing pybind11 cast failures as py::cast_error
+/// whose message names ArrayAttribute construction as the failing step and
+/// embeds the original pybind11 error text.
+/// NOTE(review): despite the generic name, both messages are specific to
+/// ArrayAttribute creation; parameterize the message before reusing this
+/// helper for other attribute kinds.
+template <typename T>
+static T pyTryCast(py::handle object) {
+ try {
+ return object.cast<T>();
+ } catch (py::cast_error &err) {
+ std::string msg =
+ std::string(
+ "Invalid attribute when attempting to create an ArrayAttribute (") +
+ err.what() + ")";
+ throw py::cast_error(msg);
+ } catch (py::reference_cast_error &err) {
+ // This exception seems to be thrown when the value is "None", hence the
+ // "(None?)" hint; it is normalized to py::cast_error for callers.
+ std::string msg = std::string("Invalid attribute (None?) when attempting "
+ "to create an ArrayAttribute (") +
+ err.what() + ")";
+ throw py::cast_error(msg);
+ }
+}
+
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
@@ -76,6 +93,10 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
int nextIndex = 0;
};
+ /// Returns the `i`-th element of this array as a PyAttribute bound to the
+ /// same context. The index is passed straight to mlirArrayAttrGetElement
+ /// with no bounds checking here; callers must ensure 0 <= i < size.
+ /// NOTE(review): `__getitem__` only checks the upper bound before calling.
+ PyAttribute getItem(intptr_t i) {
+ return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
+ }
+
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
@@ -83,21 +104,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(py::len(attributes));
for (auto attribute : attributes) {
- try {
- mlirAttributes.push_back(attribute.cast<PyAttribute>());
- } catch (py::cast_error &err) {
- std::string msg = std::string("Invalid attribute when attempting "
- "to create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
- } catch (py::reference_cast_error &err) {
- // This exception seems thrown when the value is "None".
- std::string msg =
- std::string("Invalid attribute (None?) when attempting to "
- "create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
- }
+ mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
}
MlirAttribute attr = mlirArrayAttrGet(
context->get(), mlirAttributes.size(), mlirAttributes.data());
@@ -109,8 +116,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
[](PyArrayAttribute &arr, intptr_t i) {
if (i >= mlirArrayAttrGetNumElements(arr))
throw py::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(),
- mlirArrayAttrGetElement(arr, i));
+ return arr.getItem(i);
})
.def("__len__",
[](const PyArrayAttribute &arr) {
@@ -119,6 +125,18 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
+ // `array + [attr, ...]`: builds a new ArrayAttribute in the same context,
+ // holding this array's elements followed by the elements of `extras`. The
+ // left-hand array is not modified (a fresh attribute is created). Each
+ // element of `extras` goes through pyTryCast, so non-attributes surface as
+ // py::cast_error.
+ // NOTE(review): the right-hand side is declared as py::list, so other
+ // sequences (tuples, another ArrayAttr) are presumably rejected by
+ // pybind11's argument conversion; confirm if that matters.
+ c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
+ std::vector<MlirAttribute> attributes;
+ intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+ attributes.reserve(numOldElements + py::len(extras));
+ for (intptr_t i = 0; i < numOldElements; ++i)
+ attributes.push_back(arr.getItem(i));
+ for (py::handle attr : extras)
+ attributes.push_back(pyTryCast<PyAttribute>(attr));
+ MlirAttribute arrayAttr = mlirArrayAttrGet(
+ arr.getContext()->get(), attributes.size(), attributes.data());
+ return PyArrayAttribute(arr.getContext(), arrayAttr);
+ });
}
};
@@ -602,7 +620,7 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
mlirNamedAttributes.data());
return PyDictAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ py::arg("value") = py::dict(), py::arg("context") = py::none(),
"Gets an uniqued dict attribute");
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
MlirAttribute attr =
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0434ac37c6881..8ed3bd5ed38a8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1590,6 +1590,19 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
}
};
+/// Returns the types of the values held by `container`, in order, each
+/// wrapped as a PyType bound to `context`.
+/// `Container` must provide `getNumElements()` and `getElement(i)`, where the
+/// element's `get()` yields an MlirValue (e.g. PyBlockArgumentList and
+/// PyOpResultList).
+template <typename Container>
+static std::vector<PyType> getValueTypes(Container &container,
+ const PyMlirContextRef &context) {
+ std::vector<PyType> result;
+ // Query the count once, and keep it as intptr_t (the count/index type used
+ // throughout the MLIR C API) instead of narrowing it to int.
+ intptr_t numElements = container.getNumElements();
+ result.reserve(numElements);
+ for (intptr_t i = 0; i < numElements; ++i) {
+ result.push_back(
+ PyType(context, mlirValueGetType(container.getElement(i).get())));
+ }
+ return result;
+}
+
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
@@ -1625,6 +1638,12 @@ class PyBlockArgumentList
return PyBlockArgumentList(operation, block, startIndex, length, step);
}
+ /// Registers PyBlockArgumentList-specific Python bindings: a read-only
+ /// `types` property returning the types of the arguments in this list, in
+ /// order, bound to the owning operation's context.
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ }
+
private:
PyOperationRef operation;
MlirBlock block;
@@ -1712,6 +1731,12 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
return PyOpResultList(operation, startIndex, length, step);
}
+ /// Registers PyOpResultList-specific Python bindings: a read-only `types`
+ /// property returning the types of the results in this list, in order,
+ /// bound to the owning operation's context.
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly("types", [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ }
+
private:
PyOperationRef operation;
};
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 89da559669db8..d2deb39a69df3 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -343,6 +343,9 @@ def testDictAttr():
else:
assert False, "expected IndexError on accessing an out-of-bounds attribute"
+ # DictAttr.get() with no arguments now defaults to an empty dictionary.
+ # CHECK: empty: {}
+ print("empty: ", DictAttr.get())
+
# CHECK-LABEL: TEST: testTypeAttr
@run
@@ -404,3 +407,9 @@ def testArrayAttr():
except RuntimeError as e:
# CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
print("Error: ", e)
+
+ # `ArrayAttr + list` concatenates into a new ArrayAttr.
+ with Context():
+ array = ArrayAttr.get([StringAttr.get("a"), StringAttr.get("b")])
+ array = array + [StringAttr.get("c")]
+ # CHECK: concat: ["a", "b", "c"]
+ print("concat: ", array)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index d04f52f0c0fe9..f9b4efa4d97eb 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -145,6 +145,12 @@ def testBlockArgumentList():
print("Length: ",
len(entry_block.arguments[:2] + entry_block.arguments[1:]))
+ # `arguments.types` yields one type per block argument, in order.
+ # CHECK: Type: i8
+ # CHECK: Type: i16
+ # CHECK: Type: i24
+ for t in entry_block.arguments.types:
+ print("Type: ", t)
+
run(testBlockArgumentList)
@@ -380,6 +386,12 @@ def testOperationResultList():
for res in call.results:
print(f"Result {res.result_number}, type {res.type}")
+ # `results.types` yields one type per op result, in order.
+ # CHECK: Result type i32
+ # CHECK: Result type f64
+ # CHECK: Result type index
+ for t in call.results.types:
+ print(f"Result type {t}")
+
run(testOperationResultList)
More information about the Mlir-commits
mailing list