[Mlir-commits] [mlir] 0b10fde - [mlir] Move PyConcreteAttribute to header. NFC.
Alex Zinenko
llvmlistbot at llvm.org
Thu Apr 22 07:12:07 PDT 2021
Author: Alex Zinenko
Date: 2021-04-22T16:11:59+02:00
New Revision: 0b10fdedf96eb228d782897a05b59edb8a057d18
URL: https://github.com/llvm/llvm-project/commit/0b10fdedf96eb228d782897a05b59edb8a057d18
DIFF: https://github.com/llvm/llvm-project/commit/0b10fdedf96eb228d782897a05b59edb8a057d18.diff
LOG: [mlir] Move PyConcreteAttribute to header. NFC.
This allows out-of-tree users to derive from PyConcreteAttribute to bind custom
attributes.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D101063
Added:
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRModule.h
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index b5e3c5c9c94b7..0af762d93acb0 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -27,46 +27,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
-/// CRTP base classes for Python attributes that subclass Attribute and should
-/// be castable from it (i.e. via something like StringAttr(attr)).
-/// By default, attribute class hierarchies are one level deep (i.e. a
-/// concrete attribute class extends PyAttribute); however, intermediate
-/// python-visible base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyAttribute>
-class PyConcreteAttribute : public BaseTy {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- using ClassTy = py::class_<DerivedTy, BaseTy>;
- using IsAFunctionTy = bool (*)(MlirAttribute);
-
- PyConcreteAttribute() = default;
- PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
- : BaseTy(std::move(contextRef), attr) {}
- PyConcreteAttribute(PyAttribute &orig)
- : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
-
- static MlirAttribute castFrom(PyAttribute &orig) {
- if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
- DerivedTy::pyClassName +
- " (from " + origRepr + ")");
- }
- return orig;
- }
-
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
- cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 861673abc7018..f3f5ee5edf523 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -642,6 +642,46 @@ class PyNamedAttribute {
std::unique_ptr<std::string> ownedName;
};
+/// CRTP base class for Python attribute classes that subclass Attribute and
+/// should be castable from it (e.g. via something like StringAttr(attr)).
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
+public:
+ // Derived classes must define the following static members:
+ // IsAFunctionTy isaFunction -- predicate that castFrom() checks before casting
+ // const char *pyClassName -- class name used by bind() and in cast errors
+ using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+ using IsAFunctionTy = bool (*)(MlirAttribute);
+
+ PyConcreteAttribute() = default;
+ PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+ : BaseTy(std::move(contextRef), attr) {}
+ PyConcreteAttribute(PyAttribute &orig)
+ : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
+
+ static MlirAttribute castFrom(PyAttribute &orig) {
+ if (!DerivedTy::isaFunction(orig)) {
+ auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+ throw SetPyError(PyExc_ValueError,
+ llvm::Twine("Cannot cast attribute to ") +
+ DerivedTy::pyClassName + " (from " + origRepr + ")");
+ }
+ return orig;
+ }
+
+ static void bind(pybind11::module &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol());
+ cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Called from bind(); derived classes redefine it to add methods to the Python subclass (no-op here).
+ static void bindDerived(ClassTy &m) {}
+};
+
+
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class PyType : public BaseContextObject {
More information about the Mlir-commits
mailing list