[Mlir-commits] [mlir] 32e2fec - [mlir] Move PyConcreteType to header. NFC.
John Demme
llvmlistbot at llvm.org
Wed Apr 28 16:46:31 PDT 2021
Author: John Demme
Date: 2021-04-28T16:40:56-07:00
New Revision: 32e2fec726beec2800f3db493bea8b4bdbbde936
URL: https://github.com/llvm/llvm-project/commit/32e2fec726beec2800f3db493bea8b4bdbbde936
DIFF: https://github.com/llvm/llvm-project/commit/32e2fec726beec2800f3db493bea8b4bdbbde936.diff
LOG: [mlir] Move PyConcreteType to header. NFC.
This allows out-of-tree users to derive from PyConcreteType to bind custom
types.
This is the Type counterpart of the change in https://reviews.llvm.org/D101063/new/
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D101496
Added:
Modified:
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ff3faeefd994..292080d911d1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -705,6 +705,49 @@ class PyType : public BaseContextObject {
MlirType type;
};
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+ using IsAFunctionTy = bool (*)(MlirType);
+
+ PyConcreteType() = default;
+ PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+ : BaseTy(std::move(contextRef), t) {}
+ PyConcreteType(PyType &orig)
+ : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+ static MlirType castFrom(PyType &orig) {
+ if (!DerivedTy::isaFunction(orig)) {
+ auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+ throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+ DerivedTy::pyClassName +
+ " (from " + origRepr + ")");
+ }
+ return orig;
+ }
+
+ static void bind(pybind11::module &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
+ cls.def_static("isinstance", [](PyType &otherType) -> bool {
+ return DerivedTy::isaFunction(otherType);
+ });
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
+
/// Wrapper around the generic MlirValue.
/// Values are managed completely by the operation that resulted in their
/// definition. For op result value, this is the operation that defines the
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 421df4dab7ea..b6875c76e09c 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -28,49 +28,6 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
}
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- using ClassTy = py::class_<DerivedTy, BaseTy>;
- using IsAFunctionTy = bool (*)(MlirType);
-
- PyConcreteType() = default;
- PyConcreteType(PyMlirContextRef contextRef, MlirType t)
- : BaseTy(std::move(contextRef), t) {}
- PyConcreteType(PyType &orig)
- : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
- static MlirType castFrom(PyType &orig) {
- if (!DerivedTy::isaFunction(orig)) {
- auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
- throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
- DerivedTy::pyClassName +
- " (from " + origRepr + ")");
- }
- return orig;
- }
-
- static void bind(py::module &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
- cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
- cls.def_static("isinstance", [](PyType &otherType) -> bool {
- return DerivedTy::isaFunction(otherType);
- });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
-
class PyIntegerType : public PyConcreteType<PyIntegerType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
More information about the Mlir-commits mailing list