[Mlir-commits] [mlir] 7403e3e - Extend PyConcreteType to support intermediate base classes.

Stella Laurenzo llvmlistbot at llvm.org
Sun Sep 6 23:40:42 PDT 2020


Author: Stella Laurenzo
Date: 2020-09-06T23:39:47-07:00
New Revision: 7403e3ee324018c79d0e55532240952dbaa4fcbe

URL: https://github.com/llvm/llvm-project/commit/7403e3ee324018c79d0e55532240952dbaa4fcbe
DIFF: https://github.com/llvm/llvm-project/commit/7403e3ee324018c79d0e55532240952dbaa4fcbe.diff

LOG: Extend PyConcreteType to support intermediate base classes.

* Resolves TODOs from D87091.
* Also modifies PyConcreteAttribute to follow suit (should be useful for ElementsAttr and friends).
* Adds a test to ensure that the ShapedType base class functions as expected.

Differential Revision: https://reviews.llvm.org/D87208

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 149e231aed0b..bf1235a77d08 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -221,34 +221,37 @@ namespace {
 
 /// CRTP base classes for Python attributes that subclass Attribute and should
 /// be castable from it (i.e. via something like StringAttr(attr)).
-template <typename T>
-class PyConcreteAttribute : public PyAttribute {
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
 public:
   // Derived classes must define statics for:
   //   IsAFunctionTy isaFunction
   //   const char *pyClassName
-  using ClassTy = py::class_<T, PyAttribute>;
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
   using IsAFunctionTy = int (*)(MlirAttribute);
 
   PyConcreteAttribute() = default;
-  PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
+  PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
   PyConcreteAttribute(PyAttribute &orig)
       : PyConcreteAttribute(castFrom(orig)) {}
 
   static MlirAttribute castFrom(PyAttribute &orig) {
-    if (!T::isaFunction(orig.attr)) {
+    if (!DerivedTy::isaFunction(orig.attr)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
       throw SetPyError(PyExc_ValueError,
                        llvm::Twine("Cannot cast attribute to ") +
-                           T::pyClassName + " (from " + origRepr + ")");
+                           DerivedTy::pyClassName + " (from " + origRepr + ")");
     }
     return orig.attr;
   }
 
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, T::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
-    T::bindDerived(cls);
+    DerivedTy::bindDerived(cls);
   }
 
   /// Implemented by derived classes to add methods to the Python subclass.
@@ -301,33 +304,36 @@ namespace {
 
 /// CRTP base classes for Python types that subclass Type and should be
 /// castable from it (i.e. via something like IntegerType(t)).
-template <typename T>
-class PyConcreteType : public PyType {
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
 public:
   // Derived classes must define statics for:
   //   IsAFunctionTy isaFunction
   //   const char *pyClassName
-  using ClassTy = py::class_<T, PyType>;
+  using ClassTy = py::class_<DerivedTy, BaseTy>;
   using IsAFunctionTy = int (*)(MlirType);
 
   PyConcreteType() = default;
-  PyConcreteType(MlirType t) : PyType(t) {}
-  PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
+  PyConcreteType(MlirType t) : BaseTy(t) {}
+  PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
 
   static MlirType castFrom(PyType &orig) {
-    if (!T::isaFunction(orig.type)) {
+    if (!DerivedTy::isaFunction(orig.type)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
-                                             T::pyClassName + " (from " +
-                                             origRepr + ")");
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
     }
     return orig.type;
   }
 
   static void bind(py::module &m) {
-    auto cls = ClassTy(m, T::pyClassName);
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
     cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
-    T::bindDerived(cls);
+    DerivedTy::bindDerived(cls);
   }
 
   /// Implemented by derived classes to add methods to the Python subclass.
@@ -590,142 +596,130 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
 };
 
 /// Vector Type subclass - VectorType.
-class PyVectorType : public PyShapedType {
+class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
   static constexpr const char *pyClassName = "VectorType";
-  using PyShapedType::PyShapedType;
-  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
-  // subclasses, exposing the ShapedType hierarchy.
-  static void bind(py::module &m) {
-    py::class_<PyVectorType, PyShapedType>(m, pyClassName)
-        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
-        .def_static(
-            "get_vector",
-            // TODO: Make the location optional and create a default location.
-            [](std::vector<int64_t> shape, PyType &elementType,
-               PyLocation &loc) {
-              MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
-                                                    elementType.type, loc.loc);
-              // TODO: Rework error reporting once diagnostic engine is exposed
-              // in C API.
-              if (mlirTypeIsNull(t)) {
-                throw SetPyError(
-                    PyExc_ValueError,
-                    llvm::Twine("invalid '") +
-                        py::repr(py::cast(elementType)).cast<std::string>() +
-                        "' and expected floating point or integer type.");
-              }
-              return PyVectorType(t);
-            },
-            py::keep_alive<0, 2>(), "Create a vector type");
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_vector",
+        // TODO: Make the location optional and create a default location.
+        [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
+          MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
+                                                elementType.type, loc.loc);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirTypeIsNull(t)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                llvm::Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point or integer type.");
+          }
+          return PyVectorType(t);
+        },
+        py::keep_alive<0, 2>(), "Create a vector type");
   }
 };
 
 /// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType : public PyShapedType {
+class PyRankedTensorType
+    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
   static constexpr const char *pyClassName = "RankedTensorType";
-  using PyShapedType::PyShapedType;
-  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
-  // subclasses, exposing the ShapedType hierarchy.
-  static void bind(py::module &m) {
-    py::class_<PyRankedTensorType, PyShapedType>(m, pyClassName)
-        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
-        .def_static(
-            "get_ranked_tensor",
-            // TODO: Make the location optional and create a default location.
-            [](std::vector<int64_t> shape, PyType &elementType,
-               PyLocation &loc) {
-              MlirType t = mlirRankedTensorTypeGetChecked(
-                  shape.size(), shape.data(), elementType.type, loc.loc);
-              // TODO: Rework error reporting once diagnostic engine is exposed
-              // in C API.
-              if (mlirTypeIsNull(t)) {
-                throw SetPyError(
-                    PyExc_ValueError,
-                    llvm::Twine("invalid '") +
-                        py::repr(py::cast(elementType)).cast<std::string>() +
-                        "' and expected floating point, integer, vector or "
-                        "complex "
-                        "type.");
-              }
-              return PyRankedTensorType(t);
-            },
-            py::keep_alive<0, 2>(), "Create a ranked tensor type");
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_ranked_tensor",
+        // TODO: Make the location optional and create a default location.
+        [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
+          MlirType t = mlirRankedTensorTypeGetChecked(
+              shape.size(), shape.data(), elementType.type, loc.loc);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirTypeIsNull(t)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                llvm::Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point, integer, vector or "
+                    "complex "
+                    "type.");
+          }
+          return PyRankedTensorType(t);
+        },
+        py::keep_alive<0, 2>(), "Create a ranked tensor type");
   }
 };
 
 /// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType : public PyShapedType {
+class PyUnrankedTensorType
+    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
   static constexpr const char *pyClassName = "UnrankedTensorType";
-  using PyShapedType::PyShapedType;
-  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
-  // subclasses, exposing the ShapedType hierarchy.
-  static void bind(py::module &m) {
-    py::class_<PyUnrankedTensorType, PyShapedType>(m, pyClassName)
-        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
-        .def_static(
-            "get_unranked_tensor",
-            // TODO: Make the location optional and create a default location.
-            [](PyType &elementType, PyLocation &loc) {
-              MlirType t =
-                  mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
-              // TODO: Rework error reporting once diagnostic engine is exposed
-              // in C API.
-              if (mlirTypeIsNull(t)) {
-                throw SetPyError(
-                    PyExc_ValueError,
-                    llvm::Twine("invalid '") +
-                        py::repr(py::cast(elementType)).cast<std::string>() +
-                        "' and expected floating point, integer, vector or "
-                        "complex "
-                        "type.");
-              }
-              return PyUnrankedTensorType(t);
-            },
-            py::keep_alive<0, 1>(), "Create a unranked tensor type");
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_unranked_tensor",
+        // TODO: Make the location optional and create a default location.
+        [](PyType &elementType, PyLocation &loc) {
+          MlirType t =
+              mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirTypeIsNull(t)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                llvm::Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point, integer, vector or "
+                    "complex "
+                    "type.");
+          }
+          return PyUnrankedTensorType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a unranked tensor type");
   }
 };
 
 /// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyShapedType {
+class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
 public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
   static constexpr const char *pyClassName = "MemRefType";
-  using PyShapedType::PyShapedType;
-  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
-  // subclasses, exposing the ShapedType hierarchy.
-  static void bind(py::module &m) {
-    py::class_<PyMemRefType, PyShapedType>(m, pyClassName)
-        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
-        // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
-        // once the affine map binding is completed.
-        .def_static(
-            "get_contiguous_memref",
-            // TODO: Make the location optional and create a default location.
-            [](PyType &elementType, std::vector<int64_t> shape,
-               unsigned memorySpace, PyLocation &loc) {
-              MlirType t = mlirMemRefTypeContiguousGetChecked(
-                  elementType.type, shape.size(), shape.data(), memorySpace,
-                  loc.loc);
-              // TODO: Rework error reporting once diagnostic engine is exposed
-              // in C API.
-              if (mlirTypeIsNull(t)) {
-                throw SetPyError(
-                    PyExc_ValueError,
-                    llvm::Twine("invalid '") +
-                        py::repr(py::cast(elementType)).cast<std::string>() +
-                        "' and expected floating point, integer, vector or "
-                        "complex "
-                        "type.");
-              }
-              return PyMemRefType(t);
-            },
-            py::keep_alive<0, 1>(), "Create a memref type")
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
+    // once the affine map binding is completed.
+    c.def_static(
+         "get_contiguous_memref",
+         // TODO: Make the location optional and create a default location.
+         [](PyType &elementType, std::vector<int64_t> shape,
+            unsigned memorySpace, PyLocation &loc) {
+           MlirType t = mlirMemRefTypeContiguousGetChecked(
+               elementType.type, shape.size(), shape.data(), memorySpace,
+               loc.loc);
+           // TODO: Rework error reporting once diagnostic engine is exposed
+           // in C API.
+           if (mlirTypeIsNull(t)) {
+             throw SetPyError(
+                 PyExc_ValueError,
+                 llvm::Twine("invalid '") +
+                     py::repr(py::cast(elementType)).cast<std::string>() +
+                     "' and expected floating point, integer, vector or "
+                     "complex "
+                     "type.");
+           }
+           return PyMemRefType(t);
+         },
+         py::keep_alive<0, 1>(), "Create a memref type")
         .def_property_readonly(
             "num_affine_maps",
             [](PyMemRefType &self) -> intptr_t {
@@ -743,36 +737,34 @@ class PyMemRefType : public PyShapedType {
 };
 
 /// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType : public PyShapedType {
+class PyUnrankedMemRefType
+    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
   static constexpr const char *pyClassName = "UnrankedMemRefType";
-  using PyShapedType::PyShapedType;
-  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
-  // subclasses, exposing the ShapedType hierarchy.
-  static void bind(py::module &m) {
-    py::class_<PyUnrankedMemRefType, PyShapedType>(m, pyClassName)
-        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
-        .def_static(
-            "get_unranked_memref",
-            // TODO: Make the location optional and create a default location.
-            [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
-              MlirType t = mlirUnrankedMemRefTypeGetChecked(
-                  elementType.type, memorySpace, loc.loc);
-              // TODO: Rework error reporting once diagnostic engine is exposed
-              // in C API.
-              if (mlirTypeIsNull(t)) {
-                throw SetPyError(
-                    PyExc_ValueError,
-                    llvm::Twine("invalid '") +
-                        py::repr(py::cast(elementType)).cast<std::string>() +
-                        "' and expected floating point, integer, vector or "
-                        "complex "
-                        "type.");
-              }
-              return PyUnrankedMemRefType(t);
-            },
-            py::keep_alive<0, 1>(), "Create a unranked memref type")
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+         "get_unranked_memref",
+         // TODO: Make the location optional and create a default location.
+         [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
+           MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
+                                                         memorySpace, loc.loc);
+           // TODO: Rework error reporting once diagnostic engine is exposed
+           // in C API.
+           if (mlirTypeIsNull(t)) {
+             throw SetPyError(
+                 PyExc_ValueError,
+                 llvm::Twine("invalid '") +
+                     py::repr(py::cast(elementType)).cast<std::string>() +
+                     "' and expected floating point, integer, vector or "
+                     "complex "
+                     "type.");
+           }
+           return PyUnrankedMemRefType(t);
+         },
+         py::keep_alive<0, 1>(), "Create a unranked memref type")
         .def_property_readonly(
             "memory_space",
             [](PyUnrankedMemRefType &self) -> unsigned {

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 00cd595843aa..4710bee27e37 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -177,11 +177,11 @@ def testComplexType():
 
 run(testComplexType)
 
-# CHECK-LABEL: TEST: testShapedType
+# CHECK-LABEL: TEST: testConcreteShapedType
 # Shaped type is not a kind of standard types, it is the base class for
 # vectors, memrefs and tensors, so this test case uses an instance of vector
-# to test the shaped type.
-def testShapedType():
+# to test the shaped type. The class hierarchy is preserved on the python side.
+def testConcreteShapedType():
   ctx = mlir.ir.Context()
   vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
   # CHECK: element type: f32
@@ -196,12 +196,25 @@ def testShapedType():
   print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
   # CHECK: dim size: 3
   print("dim size:", vector.get_dim_size(1))
-  # CHECK: False
-  print(vector.is_dynamic_size(3))
-  # CHECK: False
-  print(vector.is_dynamic_stride_or_offset(1))
+  # CHECK: is_dynamic_size: False
+  print("is_dynamic_size:", vector.is_dynamic_size(3))
+  # CHECK: is_dynamic_stride_or_offset: False
+  print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+  # CHECK: isinstance(ShapedType): True
+  print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType))
+
+run(testConcreteShapedType)
+
+# CHECK-LABEL: TEST: testAbstractShapedType
+# Tests that ShapedType operates as an abstract base class of a concrete
+# shaped type (using vector as an example).
+def testAbstractShapedType():
+  ctx = mlir.ir.Context()
+  vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>"))
+  # CHECK: element type: f32
+  print("element type:", vector.element_type)
 
-run(testShapedType)
+run(testAbstractShapedType)
 
 # CHECK-LABEL: TEST: testVectorType
 def testVectorType():


        


More information about the Mlir-commits mailing list