[Mlir-commits] [mlir] 1d99472 - [mlir] Add Complex Type, Vector Type and Tuple Type subclasses to python bindings
Mehdi Amini
llvmlistbot at llvm.org
Tue Sep 1 22:46:27 PDT 2020
Author: ZHANG Hongbin
Date: 2020-09-02T05:46:00Z
New Revision: 1d99472875100b230bac2d9ea70b5cd4b45e788b
URL: https://github.com/llvm/llvm-project/commit/1d99472875100b230bac2d9ea70b5cd4b45e788b
DIFF: https://github.com/llvm/llvm-project/commit/1d99472875100b230bac2d9ea70b5cd4b45e788b.diff
LOG: [mlir] Add Complex Type, Vector Type and Tuple Type subclasses to python bindings
Building on the PyType and PyConcreteType classes, this patch implements Python bindings for the Complex Type, Vector Type, and Tuple Type subclasses.
To simplify type checking, this patch defines a `mlirTypeIsAIntegerOrFloat` helper function that checks whether a given type is an integer or floating-point type.
The three subclasses in this patch follow a similar binding strategy:
- The function pointer `isaFunction` points to `mlirTypeIsA***`.
- The `mlir***TypeGet` C API is bound with the `get_***` method in the python side.
- The `get_complex` and `get_vector` methods of the Complex Type and Vector Type check whether the given element type is an integer or float type.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D86785
Added:
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 19da019149f9..70c1a28e92be 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -11,11 +11,15 @@
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
+#include "llvm/ADT/SmallVector.h"
+#include <pybind11/stl.h>
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
+using llvm::SmallVector;
+
//------------------------------------------------------------------------------
// Docstrings (trivial, non-duplicated docstrings are included inline).
//------------------------------------------------------------------------------
@@ -152,6 +156,20 @@ struct PySinglePartStringAccumulator {
} // namespace
+//------------------------------------------------------------------------------
+// Type-checking utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Returns true if `type` is an integer type or one of the BF16, F16, F32 or
+/// F64 floating point types, i.e. a valid scalar element for complex and
+/// vector types. Only the float kinds listed below are recognized; extend the
+/// list if the C API exposes more float types.
+bool mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+
+} // namespace
+
//------------------------------------------------------------------------------
// PyBlock, PyRegion, and PyOperation.
//------------------------------------------------------------------------------
@@ -465,6 +483,102 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
}
};
+/// Python binding for the builtin complex type (`ComplexType` in Python).
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers the complex-specific members on the Python class: the
+  /// `get_complex` factory and the read-only `element_type` property.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_complex",
+        [](PyType &elementType) {
+          // Only integer or floating point scalars may be complex elements;
+          // reject anything else with a Python ValueError.
+          if (!mlirTypeIsAIntegerOrFloat(elementType.type)) {
+            std::string elementRepr =
+                py::repr(py::cast(elementType)).cast<std::string>();
+            throw SetPyError(
+                PyExc_ValueError,
+                llvm::Twine("invalid '") + elementRepr +
+                    "' and expected floating point or integer type.");
+          }
+          return PyComplexType(mlirComplexTypeGet(elementType.type));
+        },
+        // Tie the element type's lifetime to the returned complex type.
+        py::keep_alive<0, 1>(), "Create a complex type");
+    c.def_property_readonly(
+        "element_type",
+        [](PyComplexType &self) -> PyType {
+          return PyType(mlirComplexTypeGetElementType(self.type));
+        },
+        "Returns element type.");
+  }
+};
+
+/// Python binding for the builtin vector type (`VectorType` in Python).
+class PyVectorType : public PyConcreteType<PyVectorType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers the `get_vector(shape, element_type)` factory.
+  ///
+  /// Raises ValueError if the element type is not an integer or float type,
+  /// if `shape` is empty, or if any dimension is not strictly positive.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_vector",
+        [](std::vector<int64_t> shape, PyType &elementType) {
+          // The element must be a floating point or integer scalar type.
+          if (!mlirTypeIsAIntegerOrFloat(elementType.type)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                llvm::Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point or integer type.");
+          }
+          // MLIR vector types require at least one dimension and strictly
+          // positive static sizes. Report violations as Python errors instead
+          // of handing an invalid shape to mlirVectorTypeGet.
+          if (shape.empty())
+            throw SetPyError(PyExc_ValueError,
+                             "vector types must have at least one dimension");
+          for (int64_t dim : shape) {
+            if (dim <= 0)
+              throw SetPyError(
+                  PyExc_ValueError,
+                  "vector types must have positive constant sizes");
+          }
+          MlirType t =
+              mlirVectorTypeGet(shape.size(), shape.data(), elementType.type);
+          return PyVectorType(t);
+        },
+        // Tie the element type's lifetime (argument 2) to the result.
+        py::keep_alive<0, 2>(), "Create a vector type");
+  }
+};
+
+/// Python binding for the builtin tuple type (`TupleType` in Python).
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Registers `get_tuple`, `get_type` and `num_types` on the Python class.
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_tuple",
+        [](PyMlirContext &context, py::list elementList) {
+          // Mapping py::list to SmallVector. Every item must be castable to
+          // PyType; pybind11 raises a cast error otherwise.
+          SmallVector<MlirType, 4> elements;
+          elements.reserve(py::len(elementList));
+          for (auto element : elementList)
+            elements.push_back(element.cast<PyType>().type);
+          MlirType t = mlirTupleTypeGet(context.context, elements.size(),
+                                        elements.data());
+          return PyTupleType(t);
+        },
+        py::keep_alive<0, 1>(), "Create a tuple type");
+    c.def(
+        "get_type",
+        [](PyTupleType &self, intptr_t pos) -> PyType {
+          // Validate `pos` before calling into the C API so that an
+          // out-of-range (or negative) position raises IndexError in Python
+          // rather than reaching mlirTupleTypeGetType unchecked.
+          if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self.type))
+            throw SetPyError(PyExc_IndexError,
+                             "Position out of range for tuple type.");
+          MlirType t = mlirTupleTypeGetType(self.type, pos);
+          return PyType(t);
+        },
+        py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
+    c.def_property_readonly(
+        "num_types",
+        [](PyTupleType &self) -> intptr_t {
+          return mlirTupleTypeGetNumTypes(self.type);
+        },
+        "Returns the number of types contained in a tuple.");
+  }
+};
+
} // namespace
//------------------------------------------------------------------------------
@@ -771,4 +885,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyF32Type::bind(m);
PyF64Type::bind(m);
PyNoneType::bind(m);
+ PyComplexType::bind(m);
+ PyVectorType::bind(m);
+ PyTupleType::bind(m);
}
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 32e26c57518a..a8f3a3840497 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -154,3 +154,61 @@ def testNoneType():
print("none type:", mlir.ir.NoneType(ctx))
run(testNoneType)
+
+# CHECK-LABEL: TEST: testComplexType
+def testComplexType():
+ ctx = mlir.ir.Context()
+ complex_i32 = mlir.ir.ComplexType(ctx.parse_type("complex<i32>"))
+ # CHECK: complex type element: i32
+ print("complex type element:", complex_i32.element_type)
+
+ f32 = mlir.ir.F32Type(ctx)
+ # CHECK: complex type: complex<f32>
+ print("complex type:", mlir.ir.ComplexType.get_complex(f32))
+
+ index = mlir.ir.IndexType(ctx)
+ try:
+ complex_invalid = mlir.ir.ComplexType.get_complex(index)
+ except ValueError as e:
+ # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+ print(e)
+ else:
+ print("Exception not produced")
+
+run(testComplexType)
+
+# CHECK-LABEL: TEST: testVectorType
+def testVectorType():
+ ctx = mlir.ir.Context()
+ f32 = mlir.ir.F32Type(ctx)
+ shape = [2, 3]
+ # CHECK: vector type: vector<2x3xf32>
+ print("vector type:", mlir.ir.VectorType.get_vector(shape, f32))
+
+ index = mlir.ir.IndexType(ctx)
+ try:
+ vector_invalid = mlir.ir.VectorType.get_vector(shape, index)
+ except ValueError as e:
+ # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+ print(e)
+ else:
+ print("Exception not produced")
+
+run(testVectorType)
+
+# CHECK-LABEL: TEST: testTupleType
+def testTupleType():
+ ctx = mlir.ir.Context()
+ i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+ f32 = mlir.ir.F32Type(ctx)
+ vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
+ l = [i32, f32, vector]
+ tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
+ # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
+ print("tuple type:", tuple_type)
+ # CHECK: number of types: 3
+ print("number of types:", tuple_type.num_types)
+ # CHECK: pos-th type in the tuple type: f32
+ print("pos-th type in the tuple type:", tuple_type.get_type(1))
+
+run(testTupleType)
More information about the Mlir-commits
mailing list