[Mlir-commits] [mlir] c3a6e7c - [mlir] Expose operation attributes to Python bindings
Alex Zinenko
llvmlistbot at llvm.org
Mon Nov 9 06:00:04 PST 2020
Author: Alex Zinenko
Date: 2020-11-09T14:59:56+01:00
New Revision: c3a6e7c9b7474f7977b77d38a1de13f27c785e5c
URL: https://github.com/llvm/llvm-project/commit/c3a6e7c9b7474f7977b77d38a1de13f27c785e5c
DIFF: https://github.com/llvm/llvm-project/commit/c3a6e7c9b7474f7977b77d38a1de13f27c785e5c.diff
LOG: [mlir] Expose operation attributes to Python bindings
Operations in MLIR have a dictionary of attributes attached. Expose
those to Python bindings through a pseudo-container that can be indexed
either by attribute name, producing a PyAttribute, or by a contiguous
index for enumeration purposes, producing a PyNamedAttribute.
Depends On D90917
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D90919
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_operation.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index cf71cc3eb92e..8f71181b385d 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -1282,6 +1282,47 @@ class PyOpResultList {
PyOperationRef operation;
};
+/// A map-like view over the attributes attached to an operation. Indexing by
+/// name (a string) produces the attribute itself as a PyAttribute; indexing by
+/// integer position produces a PyNamedAttribute, which makes the container
+/// enumerable. Membership (`name in attrs`) is answered by attribute name.
+class PyOpAttributeMap {
+public:
+  PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
+
+  /// Returns the attribute named `name`.
+  /// Raises KeyError (via SetPyError) if the operation has no such attribute.
+  PyAttribute dunderGetItemNamed(const std::string &name) {
+    MlirAttribute attr =
+        mlirOperationGetAttributeByName(operation->get(), name.c_str());
+    if (mlirAttributeIsNull(attr)) {
+      throw SetPyError(PyExc_KeyError,
+                       "attempt to access a non-existent attribute");
+    }
+    return PyAttribute(operation->getContext(), attr);
+  }
+
+  /// Returns the name/attribute pair stored at position `index`, which must be
+  /// in [0, len). Raises IndexError (via SetPyError) otherwise; negative
+  /// indices are rejected rather than counted from the end.
+  PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    if (index < 0 || index >= dunderLen()) {
+      throw SetPyError(PyExc_IndexError,
+                       "attempt to access out of bounds attribute");
+    }
+    MlirNamedAttribute namedAttr =
+        mlirOperationGetAttribute(operation->get(), index);
+    return PyNamedAttribute(namedAttr.attribute, std::string(namedAttr.name));
+  }
+
+  /// Returns true if the operation has an attribute named `name`.
+  /// Without an explicit __contains__, Python would evaluate `name in attrs`
+  /// by iterating over integer indices and comparing each PyNamedAttribute
+  /// with the string, rather than looking the name up.
+  bool dunderContains(const std::string &name) {
+    return !mlirAttributeIsNull(
+        mlirOperationGetAttributeByName(operation->get(), name.c_str()));
+  }
+
+  /// Returns the number of attributes attached to the operation.
+  intptr_t dunderLen() {
+    return mlirOperationGetNumAttributes(operation->get());
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
+        .def("__contains__", &PyOpAttributeMap::dunderContains)
+        .def("__len__", &PyOpAttributeMap::dunderLen)
+        // Overloaded: string keys look up by name, integer keys by position.
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
+        .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed);
+  }
+
+private:
+  PyOperationRef operation;
+};
+
} // end namespace
//------------------------------------------------------------------------------
@@ -2436,6 +2477,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
})
.def("__eq__",
[](PyOperationBase &self, py::object other) { return false; })
+ .def_property_readonly("attributes",
+ [](PyOperationBase &self) {
+ return PyOpAttributeMap(
+ self.getOperation().getRef());
+ })
.def_property_readonly("operands",
[](PyOperationBase &self) {
return PyOpOperandList(
@@ -2810,6 +2856,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyBlockList::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
+ PyOpAttributeMap::bind(m);
PyOpOperandList::bind(m);
PyOpResultList::bind(m);
PyRegionIterator::bind(m);
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 0ce7ceea0a5c..54bc428ce8ae 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -277,6 +277,53 @@ def testOperationResultList():
run(testOperationResultList)
+# CHECK-LABEL: TEST: testOperationAttributes
+# Exercises Operation.attributes: its length, lookup by name (wrapped in the
+# concrete attribute classes), enumeration by index, and the KeyError /
+# IndexError raised for a missing name and an out-of-range index.
+def testOperationAttributes():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ # "some.op" below is not from a registered dialect, hence the flag above.
+ module = Module.parse(r"""
+ "some.op"() { some.attribute = 1 : i8,
+ other.attribute = 3.0,
+ dependent = "text" } : () -> ()
+ """, ctx)
+ op = module.body.operations[0]
+ # The op carries exactly the three attributes spelled out in the parsed IR.
+ assert len(op.attributes) == 3
+ # Lookup by name yields a generic attribute; wrap it in the concrete class
+ # to read its type and value.
+ iattr = IntegerAttr(op.attributes["some.attribute"])
+ fattr = FloatAttr(op.attributes["other.attribute"])
+ sattr = StringAttr(op.attributes["dependent"])
+ # CHECK: Attribute type i8, value 1
+ print(f"Attribute type {iattr.type}, value {iattr.value}")
+ # CHECK: Attribute type f64, value 3.0
+ print(f"Attribute type {fattr.type}, value {fattr.value}")
+ # CHECK: Attribute value text
+ print(f"Attribute value {sattr.value}")
+
+ # We don't know in which order the attributes are stored.
+ # CHECK-DAG: NamedAttribute(dependent="text")
+ # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
+ # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
+ # Iterating walks the attributes by integer position, yielding named
+ # attributes (name plus value), one printed per line.
+ for attr in op.attributes:
+ print(str(attr))
+
+ # Check that exceptions are raised as expected.
+ # A name that is not attached to the op must raise KeyError.
+ try:
+ op.attributes["does_not_exist"]
+ except KeyError:
+ pass
+ else:
+ assert False, "expected KeyError on accessing a non-existent attribute"
+
+ # 42 is well past the three attributes present, so it must raise IndexError.
+ try:
+ op.attributes[42]
+ except IndexError:
+ pass
+ else:
+ assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+
+run(testOperationAttributes)
+
+
# CHECK-LABEL: TEST: testOperationPrint
def testOperationPrint():
ctx = Context()
More information about the Mlir-commits
mailing list