[Mlir-commits] [mlir] c3a6e7c - [mlir] Expose operation attributes to Python bindings

Alex Zinenko llvmlistbot at llvm.org
Mon Nov 9 06:00:04 PST 2020


Author: Alex Zinenko
Date: 2020-11-09T14:59:56+01:00
New Revision: c3a6e7c9b7474f7977b77d38a1de13f27c785e5c

URL: https://github.com/llvm/llvm-project/commit/c3a6e7c9b7474f7977b77d38a1de13f27c785e5c
DIFF: https://github.com/llvm/llvm-project/commit/c3a6e7c9b7474f7977b77d38a1de13f27c785e5c.diff

LOG: [mlir] Expose operation attributes to Python bindings

Operations in MLIR have a dictionary of attributes attached. Expose
those to Python bindings through a pseudo-container that can be indexed
either by attribute name, producing a PyAttribute, or by a contiguous
index for enumeration purposes, producing a PyNamedAttribute.

Depends On D90917

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90919

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index cf71cc3eb92e..8f71181b385d 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1282,6 +1282,47 @@ class PyOpResultList {
   PyOperationRef operation;
 };
 
+/// Python-facing view of an operation's attribute dictionary. Indexing by
+/// name yields a PyAttribute; indexing by position yields a PyNamedAttribute.
+class PyOpAttributeMap {
+public:
+  PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
+  // __getitem__(str): raises KeyError if no attribute has the given name.
+  PyAttribute dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr =
+        mlirOperationGetAttributeByName(operation->get(), name.c_str());
+    if (mlirAttributeIsNull(attr)) {
+      throw SetPyError(PyExc_KeyError,
+                       "attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr);
+  }
+  // __getitem__(int): IndexError unless 0 <= index < len; negatives rejected.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
+  }
+  // __len__: number of attributes currently attached to the operation.
+  intptr_t dunderLen() {
+    return mlirOperationGetNumAttributes(operation->get());
+  }
+  // pybind11 tries the __getitem__ overloads in order: by name, then index.
+  static void bind(py::module &m) {
+    py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__len__", &PyOpAttributeMap::dunderLen)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
+  }
+  // NOTE(review): no __contains__, so `k in attrs` falls back to int indexing.
+private:
+  PyOperationRef operation;
+};
+
 } // end namespace
 
 //------------------------------------------------------------------------------
@@ -2436,6 +2477,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
            })
       .def("__eq__",
            [](PyOperationBase &self, py::object other) { return false; })
+      .def_property_readonly("attributes",
+                             [](PyOperationBase &self) {
+                               return PyOpAttributeMap(
+                                   self.getOperation().getRef());
+                             })
       .def_property_readonly("operands",
                              [](PyOperationBase &self) {
                                return PyOpOperandList(
@@ -2810,6 +2856,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyBlockList::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
+  PyOpAttributeMap::bind(m);
   PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 0ce7ceea0a5c..54bc428ce8ae 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -277,6 +277,53 @@ def testOperationResultList():
 run(testOperationResultList)
 
 
+# CHECK-LABEL: TEST: testOperationAttributes
+def testOperationAttributes():  # len, name/index lookup, iteration, errors
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True  # presumably required for "some.op"
+  module = Module.parse(r"""
+    "some.op"() { some.attribute = 1 : i8,
+                  other.attribute = 3.0,
+                  dependent = "text" } : () -> ()
+  """, ctx)
+  op = module.body.operations[0]
+  assert len(op.attributes) == 3
+  iattr = IntegerAttr(op.attributes["some.attribute"])
+  fattr = FloatAttr(op.attributes["other.attribute"])
+  sattr = StringAttr(op.attributes["dependent"])
+  # CHECK: Attribute type i8, value 1
+  print(f"Attribute type {iattr.type}, value {iattr.value}")
+  # CHECK: Attribute type f64, value 3.0
+  print(f"Attribute type {fattr.type}, value {fattr.value}")
+  # CHECK: Attribute value text
+  print(f"Attribute value {sattr.value}")
+  # Iteration uses __getitem__(0), __getitem__(1), ... until IndexError.
+  # We don't know in which order the attributes are stored.
+  # CHECK-DAG: NamedAttribute(dependent="text")
+  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
+  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
+  for attr in op.attributes:
+    print(str(attr))
+
+  # Check that exceptions are raised as expected.
+  try:
+    op.attributes["does_not_exist"]
+  except KeyError:
+    pass
+  else:
+    assert False, "expected KeyError on accessing a non-existent attribute"
+  # 42 is out of range for the three attributes parsed above.
+  try:
+    op.attributes[42]
+  except IndexError:
+    pass
+  else:
+    assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+
+run(testOperationAttributes)
+
+
 # CHECK-LABEL: TEST: testOperationPrint
 def testOperationPrint():
   ctx = Context()


        


More information about the Mlir-commits mailing list