[llvm-branch-commits] [mlir] c84b53c - [mlir] Add Python binding for MLIR Dict Attribute
Mehdi Amini via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Dec 12 21:13:34 PST 2020
Author: kweisamx
Date: 2020-12-13T04:30:35Z
New Revision: c84b53ca9bcddcbaa8b726be0a4d6cb684dedbd5
URL: https://github.com/llvm/llvm-project/commit/c84b53ca9bcddcbaa8b726be0a4d6cb684dedbd5
DIFF: https://github.com/llvm/llvm-project/commit/c84b53ca9bcddcbaa8b726be0a4d6cb684dedbd5.diff
LOG: [mlir] Add Python binding for MLIR Dict Attribute
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D93004
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 66443bf89072..8a77d60741b4 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1968,6 +1968,58 @@ class PyDenseIntElementsAttribute
}
};
+/// Refinement of PyAttribute for MLIR dictionary attributes, exposed to Python
+/// as `DictAttr`. Supports construction from a Python dict, `len()`, and
+/// element access either by name (str key) or by position (int index).
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Returns the number of named attributes held by the dictionary.
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          // Keys must be str and values must be PyAttribute instances; both
+          // requirements are enforced by the casts below.
+          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(attributes.size());
+          for (const auto &it : attributes) {
+            auto &mlirAttr = it.second.cast<PyAttribute &>();
+            auto name = it.first.cast<std::string>();
+            // Create the identifier in the context of the dictionary being
+            // built, so the names and the resulting attribute agree on it.
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                mlirIdentifierGet(context->get(), toMlirStringRef(name)),
+                mlirAttr));
+          }
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued dict attribute");
+    // Lookup by name; raises KeyError when no entry has that name.
+    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+      MlirAttribute attr =
+          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+      if (mlirAttributeIsNull(attr)) {
+        throw SetPyError(PyExc_KeyError,
+                         "attempt to access a non-existent attribute");
+      }
+      return PyAttribute(self.getContext(), attr);
+    });
+    // Lookup by position; only indices in [0, len) are accepted, anything
+    // else (including negative indices) raises IndexError.
+    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+      if (index < 0 || index >= self.dunderLen()) {
+        throw SetPyError(PyExc_IndexError,
+                         "attempt to access out of bounds attribute");
+      }
+      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+      // MlirStringRef is a (data, length) pair and is not guaranteed to be
+      // null-terminated, so copy exactly `length` bytes.
+      MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+      return PyNamedAttribute(namedAttr.attribute,
+                              std::string(name.data, name.length));
+    });
+  }
+};
+
/// Refinement of PyDenseElementsAttribute for attributes containing
/// floating-point values. Supports element access.
class PyDenseFPElementsAttribute
@@ -3181,6 +3233,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyDenseElementsAttribute::bind(m);
PyDenseIntElementsAttribute::bind(m);
PyDenseFPElementsAttribute::bind(m);
+ PyDictAttribute::bind(m);
PyTypeAttribute::bind(m);
PyUnitAttribute::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 642c1f6a836c..84f313912547 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -257,6 +257,47 @@ def testDenseFPAttr():
run(testDenseFPAttr)
+# CHECK-LABEL: TEST: testDictAttr
+def testDictAttr():
+  with Context():
+    dict_attr = {
+      'stringattr': StringAttr.get('string'),
+      'integerattr' : IntegerAttr.get(
+        IntegerType.get_signless(32), 42)
+    }
+
+    a = DictAttr.get(dict_attr)
+
+    # Entries are printed sorted by name, not in insertion order.
+    # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
+    print("attr:", a)
+
+    assert len(a) == 2
+
+    # CHECK: 42 : i32
+    print(a['integerattr'])
+
+    # CHECK: "string"
+    print(a['stringattr'])
+
+    # Integer indexing follows the same name-sorted order as the printed form.
+    # CHECK: integerattr
+    print(a[0].name)
+
+    # Check that exceptions are raised as expected.
+    try:
+      _ = a['does_not_exist']
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError on accessing a non-existent attribute"
+
+    try:
+      _ = a[42]
+    except IndexError:
+      pass
+    else:
+      assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+
+run(testDictAttr)
+
# CHECK-LABEL: TEST: testTypeAttr
def testTypeAttr():
with Context():
More information about the llvm-branch-commits
mailing list