[Mlir-commits] [mlir] 3137c29 - Add initial python bindings for attributes.
Stella Laurenzo
llvmlistbot at llvm.org
Sun Aug 23 22:18:35 PDT 2020
Author: Stella Laurenzo
Date: 2020-08-23T22:16:23-07:00
New Revision: 3137c299269dd758c4c1630dc0c4621a1137eb7c
URL: https://github.com/llvm/llvm-project/commit/3137c299269dd758c4c1630dc0c4621a1137eb7c
DIFF: https://github.com/llvm/llvm-project/commit/3137c299269dd758c4c1630dc0c4621a1137eb7c.diff
LOG: Add initial python bindings for attributes.
* Generic mlir.ir.Attribute class.
* First standard attribute (mlir.ir.StringAttr), following the same pattern as generic vs standard types.
* NamedAttribute class.
Differential Revision: https://reviews.llvm.org/D86250
Added:
mlir/test/Bindings/Python/ir_attributes.py
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/Bindings/Python/PybindUtils.cpp
mlir/lib/Bindings/Python/PybindUtils.h
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 9293a40ebbab..225b53166306 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -336,6 +336,9 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
+/** Checks whether an attribute is null.
+ *
+ * Declared `static inline` because this header is consumed by C code: under
+ * C99/C11 rules a plain `inline` definition in a header provides no external
+ * definition, so any non-inlined call (e.g. at -O0) would fail to link. */
+static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
+
/** Checks if two attributes are equal. */
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 188fdf39ff14..ae48e33d3530 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,6 +9,7 @@
#include "IRModules.h"
#include "PybindUtils.h"
+#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
namespace py = pybind11;
@@ -76,8 +77,52 @@ struct PyPrintAccumulator {
}
};
+/// Captures the string handed out by a C API function that is expected to
+/// invoke its string callback exactly once, and returns it as a Python str.
+/// Both "never called" and "called more than once" trip an assert when
+/// assertions are enabled.
+struct PySinglePartStringAccumulator {
+  /// Opaque pointer to pass through as the C API's `userData` argument.
+  void *getUserData() { return this; }
+
+  /// Returns a capture-less callback that records the single received part.
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      auto *self = static_cast<PySinglePartStringAccumulator *>(userData);
+      assert(!self->invoked &&
+             "PySinglePartStringAccumulator called back multiple times");
+      self->invoked = true;
+      // NOTE(review): py::str throws if `part` is not valid UTF-8; that
+      // exception would unwind through the calling C API frame — confirm
+      // this is acceptable for arbitrary string attribute contents.
+      self->value = py::str(part, size);
+    };
+  }
+
+  /// Moves the captured string out; asserts that the callback has fired.
+  py::str takeValue() {
+    assert(invoked && "PySinglePartStringAccumulator not called back");
+    return std::move(value);
+  }
+
+private:
+  py::str value;
+  bool invoked = false;
+};
+
} // namespace
+//------------------------------------------------------------------------------
+// PyAttribute.
+//------------------------------------------------------------------------------
+
+/// Two wrappers are equal when the C API reports their underlying attributes
+/// as equal.
+bool PyAttribute::operator==(const PyAttribute &other) {
+  const int equal = mlirAttributeEqual(attr, other.attr);
+  return equal != 0;
+}
+
+//------------------------------------------------------------------------------
+// PyNamedAttribute.
+//------------------------------------------------------------------------------
+
+/// Binds `attr` to a heap-allocated copy of `ownedName` owned by this object.
+PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
+    : ownedName(std::make_unique<std::string>(std::move(ownedName))) {
+  // namedAttr.name aliases the heap string's buffer, whose address stays
+  // stable even when this PyNamedAttribute is moved.
+  namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
+}
+
//------------------------------------------------------------------------------
// PyType.
//------------------------------------------------------------------------------
@@ -86,6 +131,86 @@ bool PyType::operator==(const PyType &other) {
return mlirTypeEqual(type, other.type);
}
+//------------------------------------------------------------------------------
+// Standard attribute subclasses.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// CRTP base class for Python attributes that subclass Attribute and should
+/// be castable from it (i.e. via something like StringAttr(attr)).
+template <typename T>
+class PyConcreteAttribute : public PyAttribute {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = py::class_<T, PyAttribute>;
+  using IsAFunctionTy = int (*)(MlirAttribute);
+
+  PyConcreteAttribute() = default;
+  PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
+  PyConcreteAttribute(PyAttribute &orig)
+      : PyConcreteAttribute(castFrom(orig)) {}
+
+  /// Returns the underlying attribute of `orig` if T::isaFunction accepts it;
+  /// otherwise raises a Python ValueError that names T::pyClassName and
+  /// includes the repr of `orig`.
+  static MlirAttribute castFrom(PyAttribute &orig) {
+    if (!T::isaFunction(orig.attr)) {
+      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError,
+                       llvm::Twine("Cannot cast attribute to ") +
+                           T::pyClassName + " (from " + origRepr + ")");
+    }
+    return orig.attr;
+  }
+
+  /// Registers the Python class for T on module `m` with a constructor that
+  /// casts from a generic Attribute, then lets T add its own methods.
+  static void bind(py::module &m) {
+    auto cls = ClassTy(m, T::pyClassName);
+    // For constructors, index 1 is the object being initialized and index 2
+    // is `orig`; index 0 (the return value) is None, so keep_alive<0, 1>
+    // would be a no-op. Keep `orig` (which may in turn keep its context
+    // alive) alive for as long as the new object.
+    cls.def(py::init<PyAttribute &>(), py::keep_alive<1, 2>());
+    T::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &cls) {}
+};
+
+/// Python binding for the standard string attribute (mlir.ir.StringAttr).
+class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
+  static constexpr const char *pyClassName = "StringAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    // StringAttr.get(context, value): no explicit type argument.
+    auto getUntyped = [](PyMlirContext &context, std::string value) {
+      return PyStringAttribute(
+          mlirStringAttrGet(context.context, value.size(), &value[0]));
+    };
+    // StringAttr.get_typed(type, value): associates the string with `type`.
+    auto getTyped = [](PyType &type, std::string value) {
+      return PyStringAttribute(
+          mlirStringAttrTypedGet(type.type, value.size(), &value[0]));
+    };
+    // Reads the string payload; the C API delivers it via one callback call.
+    auto readValue = [](PyStringAttribute &self) {
+      PySinglePartStringAccumulator collector;
+      mlirStringAttrGetValue(self.attr, collector.getCallback(),
+                             collector.getUserData());
+      return collector.takeValue();
+    };
+
+    // keep_alive<0, 1>: the returned attribute keeps the first argument
+    // (the context, or the type for get_typed) alive.
+    c.def_static("get", getUntyped, py::keep_alive<0, 1>(),
+                 "Gets a uniqued string attribute");
+    c.def_static("get_typed", getTyped, py::keep_alive<0, 1>(),
+                 "Gets a uniqued string attribute associated to a type");
+    c.def_property_readonly("value", readValue,
+                            "Returns the value of the string attribute");
+  }
+};
+
+} // namespace
+
//------------------------------------------------------------------------------
// Standard type subclasses.
//------------------------------------------------------------------------------
@@ -118,9 +243,9 @@ class PyConcreteType : public PyType {
}
static void bind(py::module &m) {
- auto class_ = ClassTy(m, T::pyClassName);
- class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
- T::bindDerived(class_);
+ auto cls = ClassTy(m, T::pyClassName);
+ cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+ T::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
@@ -135,21 +260,21 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
static void bindDerived(ClassTy &c) {
c.def_static(
- "signless",
+ "get_signless",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeGet(context.context, width);
return PyIntegerType(t);
},
py::keep_alive<0, 1>(), "Create a signless integer type");
c.def_static(
- "signed",
+ "get_signed",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
return PyIntegerType(t);
},
py::keep_alive<0, 1>(), "Create a signed integer type");
c.def_static(
- "unsigned",
+ "get_unsigned",
[](PyMlirContext &context, unsigned width) {
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
return PyIntegerType(t);
@@ -195,6 +320,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
[](PyMlirContext &self, const std::string module) {
auto moduleRef =
mlirModuleCreateParse(self.context, module.c_str());
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
if (mlirModuleIsNull(moduleRef)) {
throw SetPyError(
PyExc_ValueError,
@@ -203,10 +330,27 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return PyModule(moduleRef);
},
py::keep_alive<0, 1>(), kContextParseDocstring)
+      .def(
+          "parse_attr",
+          [](PyMlirContext &self, std::string attrSpec) {
+            MlirAttribute attr =
+                mlirAttributeParseGet(self.context, attrSpec.c_str());
+            // TODO: Rework error reporting once diagnostic engine is exposed
+            // in C API.
+            if (mlirAttributeIsNull(attr)) {
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("Unable to parse attribute: '") +
+                                   attrSpec + "'");
+            }
+            return PyAttribute(attr);
+          },
+          // The parsed attribute keeps its context (`self`) alive.
+          py::keep_alive<0, 1>(),
+          "Parses the assembly form of an attribute; raises ValueError on "
+          "failure")
.def(
"parse_type",
[](PyMlirContext &self, std::string typeSpec) {
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
if (mlirTypeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
llvm::Twine("Unable to parse type: '") +
@@ -235,6 +379,79 @@ void mlir::python::populateIRSubmodule(py::module &m) {
},
kOperationStrDunderDocstring);
+  // Mapping of Attribute.
+  py::class_<PyAttribute>(m, "Attribute")
+      .def(
+          "get_named",
+          [](PyAttribute &self, std::string name) {
+            return PyNamedAttribute(self.attr, std::move(name));
+          },
+          // The NamedAttribute keeps this Attribute (`self`) alive.
+          py::keep_alive<0, 1>(), "Binds a name to the attribute")
+      .def("__eq__",
+           [](PyAttribute &self, py::object &other) {
+             // Comparing with a non-Attribute (including None) must yield
+             // False rather than raise. pybind11 may signal a failed cast
+             // with different exception types, hence the broad catch.
+             try {
+               PyAttribute otherAttribute = other.cast<PyAttribute>();
+               return self == otherAttribute;
+             } catch (const std::exception &) {
+               return false;
+             }
+           })
+      .def(
+          "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
+          kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyAttribute &self) {
+            PyPrintAccumulator printAccum;
+            mlirAttributePrint(self.attr, printAccum.getCallback(),
+                               printAccum.getUserData());
+            return printAccum.join();
+          },
+          "Returns the assembly form of the attribute.")
+      .def("__repr__", [](PyAttribute &self) {
+        // Generally, assembly formats are not printed for __repr__ because
+        // this can cause exceptionally long debug output and exceptions.
+        // However, attribute values are generally considered useful and are
+        // printed. This may need to be re-evaluated if debug dumps end up
+        // being excessive.
+        PyPrintAccumulator printAccum;
+        printAccum.parts.append("Attribute(");
+        mlirAttributePrint(self.attr, printAccum.getCallback(),
+                           printAccum.getUserData());
+        printAccum.parts.append(")");
+        return printAccum.join();
+      });
+
+  // Mapping of NamedAttribute.
+  py::class_<PyNamedAttribute>(m, "NamedAttribute")
+      .def("__repr__",
+           [](PyNamedAttribute &self) {
+             // Renders as NamedAttribute(<name>=<attribute assembly>).
+             const MlirNamedAttribute &binding = self.namedAttr;
+             PyPrintAccumulator out;
+             out.parts.append("NamedAttribute(");
+             out.parts.append(binding.name);
+             out.parts.append("=");
+             mlirAttributePrint(binding.attribute, out.getCallback(),
+                                out.getUserData());
+             out.parts.append(")");
+             return out.join();
+           })
+      .def_property_readonly(
+          "name",
+          [](PyNamedAttribute &self) {
+            const char *rawName = self.namedAttr.name;
+            return py::str(rawName, strlen(rawName));
+          },
+          "The name of the NamedAttribute binding")
+      .def_property_readonly(
+          "attr",
+          [](PyNamedAttribute &self) {
+            return PyAttribute(self.namedAttr.attribute);
+          },
+          // The returned Attribute keeps this NamedAttribute (`self`) alive.
+          py::keep_alive<0, 1>(),
+          "The underlying generic attribute of the NamedAttribute binding");
+
+  // Standard attribute bindings.
+  PyStringAttribute::bind(m);
+
// Mapping of Type.
py::class_<PyType>(m, "Type")
.def("__eq__",
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 4e90a9ae9795..1edfc1cead3e 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -45,6 +45,39 @@ class PyModule {
MlirModule module;
};
+/// Wrapper around the generic MlirAttribute.
+/// The lifetime of an attribute is bound by the PyContext that created it.
+class PyAttribute {
+public:
+  PyAttribute(MlirAttribute attr) : attr(attr) {}
+
+  /// Equality as reported by mlirAttributeEqual on the wrapped attributes.
+  bool operator==(const PyAttribute &other);
+
+  /// Non-owning handle to the underlying attribute.
+  MlirAttribute attr;
+};
+
+/// Represents a Python MlirNamedAttribute, carrying an owned copy of its name.
+/// TODO: Refactor this and the C-API to be based on an Identifier owned
+/// by the context so as to avoid ownership issues here.
+class PyNamedAttribute {
+public:
+  /// Constructs a PyNamedAttribute that retains an owned copy of the name.
+  /// This should be used in any code that originates an MlirNamedAttribute
+  /// from a python string.
+  /// `attr` is not owned: the attribute (and the context that owns it) must
+  /// outlive this PyNamedAttribute.
+  PyNamedAttribute(MlirAttribute attr, std::string ownedName);
+
+  /// C-API view of the binding; its `name` points into `ownedName`.
+  MlirNamedAttribute namedAttr;
+
+private:
+  // Since the MlirNamedAttribute contains an internal pointer to the actual
+  // memory of the owned string, it must be heap allocated to remain valid.
+  // Otherwise, strings that fit within the small string optimization
+  // threshold would have their memory address change as the containing
+  // object is moved, leaving a dangling pointer.
+  std::unique_ptr<std::string> ownedName;
+};
+
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class PyType {
diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp
index 9013c0669794..bd80b8c14702 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.cpp
+++ b/mlir/lib/Bindings/Python/PybindUtils.cpp
@@ -10,8 +10,8 @@
namespace py = pybind11;
-pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
- llvm::Twine message) {
+pybind11::error_already_set
+mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
auto messageStr = message.str();
PyErr_SetString(excClass, messageStr.c_str());
return pybind11::error_already_set();
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
index 1a82f8e824ec..0c0e069fae03 100644
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -20,7 +20,8 @@ namespace python {
// python runtime.
// Correct usage:
// throw SetPyError(PyExc_ValueError, "Foobar'd");
-pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
+pybind11::error_already_set SetPyError(PyObject *excClass,
+ const llvm::Twine &message);
} // namespace python
} // namespace mlir
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
new file mode 100644
index 000000000000..328dfb40b972
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -0,0 +1,119 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+
+
+# CHECK-LABEL: TEST: testParsePrint
+def testParsePrint():
+  """Parses a string attribute and prints its str() and repr() forms."""
+  ctx = mlir.ir.Context()
+  t = ctx.parse_attr('"hello"')
+  # CHECK: "hello"
+  print(str(t))
+  # CHECK: Attribute("hello")
+  print(repr(t))
+
+run(testParsePrint)
+
+
+# CHECK-LABEL: TEST: testParseError
+# TODO: Hook the diagnostic manager to capture a more meaningful error
+# message.
+def testParseError():
+  """An unparseable attribute spec must raise ValueError quoting the input."""
+  ctx = mlir.ir.Context()
+  try:
+    t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST")
+  except ValueError as e:
+    # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
+    print("testParseError:", e)
+  else:
+    print("Exception not produced")
+
+run(testParseError)
+
+
+# CHECK-LABEL: TEST: testAttrEq
+def testAttrEq():
+  """Attributes parsed from identical text compare equal; others do not."""
+  ctx = mlir.ir.Context()
+  a1 = ctx.parse_attr('"attr1"')
+  a2 = ctx.parse_attr('"attr2"')
+  # a3 is parsed independently from the same text as a1.
+  a3 = ctx.parse_attr('"attr1"')
+  # CHECK: a1 == a1: True
+  print("a1 == a1:", a1 == a1)
+  # CHECK: a1 == a2: False
+  print("a1 == a2:", a1 == a2)
+  # CHECK: a1 == a3: True
+  print("a1 == a3:", a1 == a3)
+  # CHECK: a1 == None: False
+  print("a1 == None:", a1 == None)
+
+run(testAttrEq)
+
+
+# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
+def testAttrEqDoesNotRaise():
+  """Comparing an Attribute with non-Attribute objects returns a bool."""
+  ctx = mlir.ir.Context()
+  a1 = ctx.parse_attr('"attr1"')
+  not_an_attr = "foo"
+  # CHECK: False
+  print(a1 == not_an_attr)
+  # CHECK: False
+  print(a1 == None)
+  # No __ne__ is bound; Python 3 derives != by inverting __eq__.
+  # CHECK: True
+  print(a1 != None)
+
+run(testAttrEqDoesNotRaise)
+
+
+# CHECK-LABEL: TEST: testStandardAttrCasts
+def testStandardAttrCasts():
+  """Casts a generic Attribute to StringAttr and rejects a non-string one."""
+  ctx = mlir.ir.Context()
+  a1 = ctx.parse_attr('"attr1"')
+  astr = mlir.ir.StringAttr(a1)
+  # Casting from an already-concrete StringAttr is accepted as well.
+  aself = mlir.ir.StringAttr(astr)
+  # StringAttr inherits __repr__ from Attribute.
+  # CHECK: Attribute("attr1")
+  print(repr(astr))
+  try:
+    tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0"))
+  except ValueError as e:
+    # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
+    print("ValueError:", e)
+  else:
+    print("Exception not produced")
+
+run(testStandardAttrCasts)
+
+
+# CHECK-LABEL: TEST: testStringAttr
+def testStringAttr():
+  """Reads StringAttr.value and exercises the get/get_typed factories."""
+  ctx = mlir.ir.Context()
+  sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"'))
+  # CHECK: sattr value: stringattr
+  print("sattr value:", sattr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: "foobar"
+  print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar"))
+  # CHECK: typed_get: "12345" : i32
+  print("typed_get:", mlir.ir.StringAttr.get_typed(
+      mlir.ir.IntegerType.get_signless(ctx, 32), "12345"))
+
+run(testStringAttr)
+
+
+# CHECK-LABEL: TEST: testNamedAttr
+def testNamedAttr():
+  """Binds a name to an attribute via get_named and inspects the result."""
+  ctx = mlir.ir.Context()
+  a = ctx.parse_attr('"stringattr"')
+  # Note: "foobar" is short enough to fall under the small string
+  # optimization threshold, which exercises the heap-allocated name storage.
+  named = a.get_named("foobar")
+  # CHECK: attr: "stringattr"
+  print("attr:", named.attr)
+  # CHECK: name: foobar
+  print("name:", named.name)
+  # CHECK: named: NamedAttribute(foobar="stringattr")
+  print("named:", named)
+
+run(testNamedAttr)
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index cc66b1fdb208..1dce0a95c812 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -117,10 +117,10 @@ def testIntegerType():
print("u32 unsigned:", u32.is_unsigned)
# CHECK: signless: i16
- print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
+ print("signless:", mlir.ir.IntegerType.get_signless(ctx, 16))
# CHECK: signed: si8
- print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
+ print("signed:", mlir.ir.IntegerType.get_signed(ctx, 8))
# CHECK: unsigned: ui64
- print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
+ print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
run(testIntegerType)
More information about the Mlir-commits
mailing list