[Mlir-commits] [mlir] c84b53c - [mlir] Add Python binding for MLIR Dict Attribute

Mehdi Amini llvmlistbot at llvm.org
Sat Dec 12 21:08:55 PST 2020


Author: kweisamx
Date: 2020-12-13T04:30:35Z
New Revision: c84b53ca9bcddcbaa8b726be0a4d6cb684dedbd5

URL: https://github.com/llvm/llvm-project/commit/c84b53ca9bcddcbaa8b726be0a4d6cb684dedbd5
DIFF: https://github.com/llvm/llvm-project/commit/c84b53ca9bcddcbaa8b726be0a4d6cb684dedbd5.diff

LOG: [mlir] Add Python binding for MLIR Dict Attribute

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93004

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 66443bf89072..8a77d60741b4 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1968,6 +1968,58 @@ class PyDenseIntElementsAttribute
   }
 };
 
+/// Python binding for the MLIR dictionary attribute, exposed as `DictAttr`.
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(attributes.size());
+          for (auto &it : attributes) {
+            auto &mlir_attr = it.second.cast<PyAttribute &>();
+            auto name = it.first.cast<std::string>();
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
+                                  toMlirStringRef(name)),
+                mlir_attr));
+          }
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued dict attribute");
+    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+      MlirAttribute attr =
+          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+      if (mlirAttributeIsNull(attr))
+        throw SetPyError(PyExc_KeyError,
+                         "attempt to access a non-existent attribute");
+      return PyAttribute(self.getContext(), attr);
+    });
+    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+      if (index < 0 || index >= self.dunderLen())
+        throw SetPyError(PyExc_IndexError,
+                         "attempt to access out of bounds attribute");
+      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+      // The string ref need not be NUL-terminated; copy by length.
+      MlirStringRef name = mlirIdentifierStr(namedAttr.name);
+      return PyNamedAttribute(namedAttr.attribute,
+                              std::string(name.data, name.length));
+    });
+  }
+};
+
 /// Refinement of PyDenseElementsAttribute for attributes containing
 /// floating-point values. Supports element access.
 class PyDenseFPElementsAttribute
@@ -3181,6 +3233,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyDictAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 642c1f6a836c..84f313912547 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -257,6 +257,47 @@ def testDenseFPAttr():
 run(testDenseFPAttr)
 
 
+# CHECK-LABEL: TEST: testDictAttr
+def testDictAttr():
+  with Context():
+    dict_attr = {
+      'stringattr':  StringAttr.get('string'),
+      'integerattr' : IntegerAttr.get(
+        IntegerType.get_signless(32), 42)
+    }
+
+    a = DictAttr.get(dict_attr)
+
+    # CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
+    print("attr:", a)
+
+    assert len(a) == 2
+
+    # CHECK: 42 : i32
+    print(a['integerattr'])
+
+    # CHECK: "string"
+    print(a['stringattr'])
+
+    # Check that exceptions are raised as expected.
+    try:
+      _ = a['does_not_exist']
+    except KeyError:
+      pass
+    else:
+      assert False, "Exception not produced"
+
+    try:
+      _ = a[42]
+    except IndexError:
+      pass
+    else:
+      assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+
+
+run(testDictAttr)
+
 # CHECK-LABEL: TEST: testTypeAttr
 def testTypeAttr():
   with Context():


        


More information about the Mlir-commits mailing list