[Mlir-commits] [mlir] 32e2fec - [mlir] Move PyConcreteType to header. NFC.

John Demme llvmlistbot at llvm.org
Wed Apr 28 16:46:31 PDT 2021


Author: John Demme
Date: 2021-04-28T16:40:56-07:00
New Revision: 32e2fec726beec2800f3db493bea8b4bdbbde936

URL: https://github.com/llvm/llvm-project/commit/32e2fec726beec2800f3db493bea8b4bdbbde936
DIFF: https://github.com/llvm/llvm-project/commit/32e2fec726beec2800f3db493bea8b4bdbbde936.diff

LOG: [mlir] Move PyConcreteType to header. NFC.

This allows out-of-tree users to derive from PyConcreteType to bind custom
types.

The Type version of https://reviews.llvm.org/D101063/new/

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D101496

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/IRTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ff3faeefd994..292080d911d1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -705,6 +705,49 @@ class PyType : public BaseContextObject {
   MlirType type;
 };
 
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
+  using IsAFunctionTy = bool (*)(MlirType);
+
+  PyConcreteType() = default;
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
+
+  static MlirType castFrom(PyType &orig) {
+    if (!DerivedTy::isaFunction(orig)) {
+      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
+      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
+    }
+    return orig;
+  }
+
+  static void bind(pybind11::module &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
+    cls.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
 /// Wrapper around the generic MlirValue.
 /// Values are managed completely by the operation that resulted in their
 /// definition. For op result value, this is the operation that defines the

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 421df4dab7ea..b6875c76e09c 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -28,49 +28,6 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) {
          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
 }
 
-/// CRTP base classes for Python types that subclass Type and should be
-/// castable from it (i.e. via something like IntegerType(t)).
-/// By default, type class hierarchies are one level deep (i.e. a
-/// concrete type class extends PyType); however, intermediate python-visible
-/// base classes can be modeled by specifying a BaseTy.
-template <typename DerivedTy, typename BaseTy = PyType>
-class PyConcreteType : public BaseTy {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  using ClassTy = py::class_<DerivedTy, BaseTy>;
-  using IsAFunctionTy = bool (*)(MlirType);
-
-  PyConcreteType() = default;
-  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
-      : BaseTy(std::move(contextRef), t) {}
-  PyConcreteType(PyType &orig)
-      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
-
-  static MlirType castFrom(PyType &orig) {
-    if (!DerivedTy::isaFunction(orig)) {
-      auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
-                                             DerivedTy::pyClassName +
-                                             " (from " + origRepr + ")");
-    }
-    return orig;
-  }
-
-  static void bind(py::module &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
-    cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
-    cls.def_static("isinstance", [](PyType &otherType) -> bool {
-      return DerivedTy::isaFunction(otherType);
-    });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
-
 class PyIntegerType : public PyConcreteType<PyIntegerType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;


        


More information about the Mlir-commits mailing list