[Mlir-commits] [mlir] 7403e3e - Extend PyConcreteType to support intermediate base classes.
Stella Laurenzo
llvmlistbot at llvm.org
Sun Sep 6 23:40:42 PDT 2020
Author: Stella Laurenzo
Date: 2020-09-06T23:39:47-07:00
New Revision: 7403e3ee324018c79d0e55532240952dbaa4fcbe
URL: https://github.com/llvm/llvm-project/commit/7403e3ee324018c79d0e55532240952dbaa4fcbe
DIFF: https://github.com/llvm/llvm-project/commit/7403e3ee324018c79d0e55532240952dbaa4fcbe.diff
LOG: Extend PyConcreteType to support intermediate base classes.
* Resolves todos from D87091.
* Also modifies PyConcreteAttribute to follow suit (should be useful for ElementsAttr and friends).
* Adds a test to ensure that the ShapedType base class functions as expected.
Differential Revision: https://reviews.llvm.org/D87208
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 149e231aed0b..bf1235a77d08 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -221,34 +221,37 @@ namespace {
/// CRTP base classes for Python attributes that subclass Attribute and should
/// be castable from it (i.e. via something like StringAttr(attr)).
-template <typename T>
-class PyConcreteAttribute : public PyAttribute {
+/// By default, attribute class hierarchies are one level deep (i.e. a
+/// concrete attribute class extends PyAttribute); however, intermediate
+/// python-visible base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyAttribute>
+class PyConcreteAttribute : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
- using ClassTy = py::class_<T, PyAttribute>;
+ // Register BaseTy (not PyAttribute) as the pybind11 base so that an
+ // intermediate base class stays visible in the Python class hierarchy
+ // (mirrors PyConcreteType::ClassTy). pybind11 requires BaseTy to have been
+ // bound before DerivedTy::bind() is called.
+ using ClassTy = py::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = int (*)(MlirAttribute);
PyConcreteAttribute() = default;
- PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
+ PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
PyConcreteAttribute(PyAttribute &orig)
: PyConcreteAttribute(castFrom(orig)) {}
+ /// Returns orig's MlirAttribute if DerivedTy::isaFunction accepts it;
+ /// otherwise raises a Python ValueError naming the target class.
static MlirAttribute castFrom(PyAttribute &orig) {
- if (!T::isaFunction(orig.attr)) {
+ if (!DerivedTy::isaFunction(orig.attr)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError,
llvm::Twine("Cannot cast attribute to ") +
- T::pyClassName + " (from " + origRepr + ")");
+ DerivedTy::pyClassName + " (from " + origRepr + ")");
}
return orig.attr;
}
+ /// Registers DerivedTy with the module under pyClassName, adds a casting
+ /// constructor from PyAttribute, then lets DerivedTy add its own members.
static void bind(py::module &m) {
- auto cls = ClassTy(m, T::pyClassName);
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
- T::bindDerived(cls);
+ DerivedTy::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
@@ -301,33 +304,36 @@ namespace {
/// CRTP base classes for Python types that subclass Type and should be
/// castable from it (i.e. via something like IntegerType(t)).
-template <typename T>
-class PyConcreteType : public PyType {
+/// By default, type class hierarchies are one level deep (i.e. a
+/// concrete type class extends PyType); however, intermediate python-visible
+/// base classes can be modeled by specifying a BaseTy.
+template <typename DerivedTy, typename BaseTy = PyType>
+class PyConcreteType : public BaseTy {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
- using ClassTy = py::class_<T, PyType>;
+ using ClassTy = py::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = int (*)(MlirType);
PyConcreteType() = default;
- PyConcreteType(MlirType t) : PyType(t) {}
- PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
+ PyConcreteType(MlirType t) : BaseTy(t) {}
+ PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
static MlirType castFrom(PyType &orig) {
- if (!T::isaFunction(orig.type)) {
+ if (!DerivedTy::isaFunction(orig.type)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
- T::pyClassName + " (from " +
- origRepr + ")");
+ DerivedTy::pyClassName +
+ " (from " + origRepr + ")");
}
return orig.type;
}
static void bind(py::module &m) {
- auto cls = ClassTy(m, T::pyClassName);
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
- T::bindDerived(cls);
+ DerivedTy::bindDerived(cls);
}
/// Implemented by derived classes to add methods to the Python subclass.
@@ -590,142 +596,130 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
};
/// Vector Type subclass - VectorType.
-class PyVectorType : public PyShapedType {
+class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
static constexpr const char *pyClassName = "VectorType";
- using PyShapedType::PyShapedType;
- // TODO: Switch back to bindDerived by making the ClassTy modifiable by
- // subclasses, exposing the ShapedType hierarchy.
- static void bind(py::module &m) {
- py::class_<PyVectorType, PyShapedType>(m, pyClassName)
- .def(py::init<PyType &>(), py::keep_alive<0, 1>())
- .def_static(
- "get_vector",
- // TODO: Make the location optional and create a default location.
- [](std::vector<int64_t> shape, PyType &elementType,
- PyLocation &loc) {
- MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
- elementType.type, loc.loc);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- llvm::Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point or integer type.");
- }
- return PyVectorType(t);
- },
- py::keep_alive<0, 2>(), "Create a vector type");
+ // Reuse the PyConcreteType constructors: from an MlirType, and the checked
+ // cast from a generic PyType.
+ using PyConcreteType::PyConcreteType;
+
+ /// Adds VectorType-specific members to the Python class: a static
+ /// `get_vector` factory taking a shape, an element type and a location,
+ /// which raises ValueError when mlirVectorTypeGetChecked returns null.
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_vector",
+ // TODO: Make the location optional and create a default location.
+ [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
+ MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
+ elementType.type, loc.loc);
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
+ if (mlirTypeIsNull(t)) {
+ throw SetPyError(
+ PyExc_ValueError,
+ llvm::Twine("invalid '") +
+ py::repr(py::cast(elementType)).cast<std::string>() +
+ "' and expected floating point or integer type.");
+ }
+ return PyVectorType(t);
+ },
+ py::keep_alive<0, 2>(), "Create a vector type");
}
};
/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType : public PyShapedType {
+class PyRankedTensorType
+ : public PyConcreteType<PyRankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
static constexpr const char *pyClassName = "RankedTensorType";
- using PyShapedType::PyShapedType;
- // TODO: Switch back to bindDerived by making the ClassTy modifiable by
- // subclasses, exposing the ShapedType hierarchy.
- static void bind(py::module &m) {
- py::class_<PyRankedTensorType, PyShapedType>(m, pyClassName)
- .def(py::init<PyType &>(), py::keep_alive<0, 1>())
- .def_static(
- "get_ranked_tensor",
- // TODO: Make the location optional and create a default location.
- [](std::vector<int64_t> shape, PyType &elementType,
- PyLocation &loc) {
- MlirType t = mlirRankedTensorTypeGetChecked(
- shape.size(), shape.data(), elementType.type, loc.loc);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- llvm::Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
- return PyRankedTensorType(t);
- },
- py::keep_alive<0, 2>(), "Create a ranked tensor type");
+ // Reuse the PyConcreteType constructors: from an MlirType, and the checked
+ // cast from a generic PyType.
+ using PyConcreteType::PyConcreteType;
+
+ /// Adds RankedTensorType-specific members to the Python class: a static
+ /// `get_ranked_tensor` factory taking a shape, an element type and a
+ /// location, which raises ValueError when construction returns null.
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_ranked_tensor",
+ // TODO: Make the location optional and create a default location.
+ [](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
+ MlirType t = mlirRankedTensorTypeGetChecked(
+ shape.size(), shape.data(), elementType.type, loc.loc);
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
+ if (mlirTypeIsNull(t)) {
+ throw SetPyError(
+ PyExc_ValueError,
+ llvm::Twine("invalid '") +
+ py::repr(py::cast(elementType)).cast<std::string>() +
+ "' and expected floating point, integer, vector or "
+ "complex "
+ "type.");
+ }
+ return PyRankedTensorType(t);
+ },
+ py::keep_alive<0, 2>(), "Create a ranked tensor type");
}
};
/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType : public PyShapedType {
+class PyUnrankedTensorType
+ : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
static constexpr const char *pyClassName = "UnrankedTensorType";
- using PyShapedType::PyShapedType;
- // TODO: Switch back to bindDerived by making the ClassTy modifiable by
- // subclasses, exposing the ShapedType hierarchy.
- static void bind(py::module &m) {
- py::class_<PyUnrankedTensorType, PyShapedType>(m, pyClassName)
- .def(py::init<PyType &>(), py::keep_alive<0, 1>())
- .def_static(
- "get_unranked_tensor",
- // TODO: Make the location optional and create a default location.
- [](PyType &elementType, PyLocation &loc) {
- MlirType t =
- mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- llvm::Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
- return PyUnrankedTensorType(t);
- },
- py::keep_alive<0, 1>(), "Create a unranked tensor type");
+ // Reuse the PyConcreteType constructors: from an MlirType, and the checked
+ // cast from a generic PyType.
+ using PyConcreteType::PyConcreteType;
+
+ /// Adds UnrankedTensorType-specific members to the Python class: a static
+ /// `get_unranked_tensor` factory taking an element type and a location,
+ /// which raises ValueError when construction returns null.
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_unranked_tensor",
+ // TODO: Make the location optional and create a default location.
+ [](PyType &elementType, PyLocation &loc) {
+ MlirType t =
+ mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
+ if (mlirTypeIsNull(t)) {
+ throw SetPyError(
+ PyExc_ValueError,
+ llvm::Twine("invalid '") +
+ py::repr(py::cast(elementType)).cast<std::string>() +
+ "' and expected floating point, integer, vector or "
+ "complex "
+ "type.");
+ }
+ return PyUnrankedTensorType(t);
+ },
+ py::keep_alive<0, 1>(), "Create a unranked tensor type");
}
};
/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyShapedType {
+class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
- static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+ // Must test for MemRefType: the previous mlirTypeIsARankedTensor made
+ // MemRefType(t) accept ranked tensors and reject actual memrefs.
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
static constexpr const char *pyClassName = "MemRefType";
- using PyShapedType::PyShapedType;
- // TODO: Switch back to bindDerived by making the ClassTy modifiable by
- // subclasses, exposing the ShapedType hierarchy.
- static void bind(py::module &m) {
- py::class_<PyMemRefType, PyShapedType>(m, pyClassName)
- .def(py::init<PyType &>(), py::keep_alive<0, 1>())
- // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
- // once the affine map binding is completed.
- .def_static(
- "get_contiguous_memref",
- // TODO: Make the location optional and create a default location.
- [](PyType &elementType, std::vector<int64_t> shape,
- unsigned memorySpace, PyLocation &loc) {
- MlirType t = mlirMemRefTypeContiguousGetChecked(
- elementType.type, shape.size(), shape.data(), memorySpace,
- loc.loc);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- llvm::Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
- return PyMemRefType(t);
- },
- py::keep_alive<0, 1>(), "Create a memref type")
+ // Reuse the PyConcreteType constructors: from an MlirType, and the checked
+ // cast from a generic PyType.
+ using PyConcreteType::PyConcreteType;
+
+ /// Adds MemRefType-specific members to the Python class, starting with a
+ /// static `get_contiguous_memref` factory (element type, shape, memory
+ /// space, location) that raises ValueError when construction returns null.
+ static void bindDerived(ClassTy &c) {
+ // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
+ // once the affine map binding is completed.
+ c.def_static(
+ "get_contiguous_memref",
+ // TODO: Make the location optional and create a default location.
+ [](PyType &elementType, std::vector<int64_t> shape,
+ unsigned memorySpace, PyLocation &loc) {
+ MlirType t = mlirMemRefTypeContiguousGetChecked(
+ elementType.type, shape.size(), shape.data(), memorySpace,
+ loc.loc);
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
+ if (mlirTypeIsNull(t)) {
+ throw SetPyError(
+ PyExc_ValueError,
+ llvm::Twine("invalid '") +
+ py::repr(py::cast(elementType)).cast<std::string>() +
+ "' and expected floating point, integer, vector or "
+ "complex "
+ "type.");
+ }
+ return PyMemRefType(t);
+ },
+ py::keep_alive<0, 1>(), "Create a memref type")
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
@@ -743,36 +737,34 @@ class PyMemRefType : public PyShapedType {
};
/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType : public PyShapedType {
+class PyUnrankedMemRefType
+ : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
static constexpr const char *pyClassName = "UnrankedMemRefType";
- using PyShapedType::PyShapedType;
- // TODO: Switch back to bindDerived by making the ClassTy modifiable by
- // subclasses, exposing the ShapedType hierarchy.
- static void bind(py::module &m) {
- py::class_<PyUnrankedMemRefType, PyShapedType>(m, pyClassName)
- .def(py::init<PyType &>(), py::keep_alive<0, 1>())
- .def_static(
- "get_unranked_memref",
- // TODO: Make the location optional and create a default location.
- [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
- MlirType t = mlirUnrankedMemRefTypeGetChecked(
- elementType.type, memorySpace, loc.loc);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- llvm::Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
- return PyUnrankedMemRefType(t);
- },
- py::keep_alive<0, 1>(), "Create a unranked memref type")
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get_unranked_memref",
+ // TODO: Make the location optional and create a default location.
+ [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
+ MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
+ memorySpace, loc.loc);
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
+ if (mlirTypeIsNull(t)) {
+ throw SetPyError(
+ PyExc_ValueError,
+ llvm::Twine("invalid '") +
+ py::repr(py::cast(elementType)).cast<std::string>() +
+ "' and expected floating point, integer, vector or "
+ "complex "
+ "type.");
+ }
+ return PyUnrankedMemRefType(t);
+ },
+ py::keep_alive<0, 1>(), "Create a unranked memref type")
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 00cd595843aa..4710bee27e37 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -177,11 +177,11 @@ def testComplexType():
run(testComplexType)
-# CHECK-LABEL: TEST: testShapedType
+# CHECK-LABEL: TEST: testConcreteShapedType
# Shaped type is not a kind of standard types, it is the base class for
# vectors, memrefs and tensors, so this test case uses an instance of vector
-# to test the shaped type.
-def testShapedType():
+# to test the shaped type. The class hierarchy is preserved on the python side.
+def testConcreteShapedType():
ctx = mlir.ir.Context()
vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
# CHECK: element type: f32
@@ -196,12 +196,25 @@ def testShapedType():
print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
# CHECK: dim size: 3
print("dim size:", vector.get_dim_size(1))
- # CHECK: False
- print(vector.is_dynamic_size(3))
- # CHECK: False
- print(vector.is_dynamic_stride_or_offset(1))
+ # CHECK: is_dynamic_size: False
+ print("is_dynamic_size:", vector.is_dynamic_size(3))
+ # CHECK: is_dynamic_stride_or_offset: False
+ print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+ # CHECK: isinstance(ShapedType): True
+ print("isinstance(ShapedType):", isinstance(vector, mlir.ir.ShapedType))
+
+run(testConcreteShapedType)
+
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
def testAbstractShapedType():
+ ctx = mlir.ir.Context()
+ # Cast a parsed vector type to the base ShapedType (not VectorType) and
+ # check that a ShapedType accessor still works on the result.
+ vector = mlir.ir.ShapedType(ctx.parse_type("vector<2x3xf32>"))
+ # CHECK: element type: f32
+ print("element type:", vector.element_type)
-run(testShapedType)
+run(testAbstractShapedType)
# CHECK-LABEL: TEST: testVectorType
def testVectorType():
More information about the Mlir-commits
mailing list