[Mlir-commits] [mlir] 1d99472 - [mlir] Add Complex Type, Vector Type and Tuple Type subclasses to python bindings

Mehdi Amini llvmlistbot at llvm.org
Tue Sep 1 22:46:27 PDT 2020


Author: ZHANG Hongbin
Date: 2020-09-02T05:46:00Z
New Revision: 1d99472875100b230bac2d9ea70b5cd4b45e788b

URL: https://github.com/llvm/llvm-project/commit/1d99472875100b230bac2d9ea70b5cd4b45e788b
DIFF: https://github.com/llvm/llvm-project/commit/1d99472875100b230bac2d9ea70b5cd4b45e788b.diff

LOG: [mlir] Add Complex Type, Vector Type and Tuple Type subclasses to python bindings

Building on the PyType and PyConcreteType classes, this patch implements Python bindings for the Complex Type, Vector Type, and Tuple Type subclasses.
To simplify type checking, this patch also defines a `mlirTypeIsAIntegerOrFloat` helper that checks whether a given type is an integer or floating-point type.
All three subclasses follow a similar binding strategy:
- The function pointer `isaFunction` points to `mlirTypeIsA***`.
- The `mlir***TypeGet` C API function is bound to the `get_***` static method on the Python side.
- The Complex Type and Vector Type `get_***` methods check that the given element type is an integer or float type, and raise `ValueError` otherwise.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D86785

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 19da019149f9..70c1a28e92be 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -11,11 +11,15 @@
 
 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
+#include "llvm/ADT/SmallVector.h"
+#include <pybind11/stl.h>
 
 namespace py = pybind11;
 using namespace mlir;
 using namespace mlir::python;
 
+using llvm::SmallVector;
+
 //------------------------------------------------------------------------------
 // Docstrings (trivial, non-duplicated docstrings are included inline).
 //------------------------------------------------------------------------------
@@ -152,6 +156,20 @@ struct PySinglePartStringAccumulator {
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// Type-checking utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+/// Returns 1 if `type` is an integer type or a BF16/F16/F32/F64 float, else 0.
+int mlirTypeIsAIntegerOrFloat(MlirType type) {
+  if (mlirTypeIsAInteger(type))
+    return 1;
+  return mlirTypeIsABF16(type) || mlirTypeIsAF16(type) ||
+         mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+} // namespace
+
 //------------------------------------------------------------------------------
 // PyBlock, PyRegion, and PyOperation.
 //------------------------------------------------------------------------------
@@ -465,6 +483,102 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
   }
 };
 
+/// Binds the MLIR complex type to the Python class `ComplexType`.
+class PyComplexType : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_complex",
+        [](PyType &elementType) {
+          // Reject anything other than an integer or floating point scalar.
+          if (!mlirTypeIsAIntegerOrFloat(elementType.type)) {
+            throw SetPyError(
+                PyExc_ValueError,
+                llvm::Twine("invalid '") +
+                    py::repr(py::cast(elementType)).cast<std::string>() +
+                    "' and expected floating point or integer type.");
+          }
+          MlirType complexType = mlirComplexTypeGet(elementType.type);
+          return PyComplexType(complexType);
+        },
+        py::keep_alive<0, 1>(), "Create a complex type");
+    c.def_property_readonly(
+        "element_type",
+        [](PyComplexType &self) -> PyType {
+          MlirType elementType = mlirComplexTypeGetElementType(self.type);
+          return PyType(elementType);
+        },
+        "Returns element type.");
+  }
+};
+
+/// Binds the MLIR vector type to the Python class `VectorType`.
+class PyVectorType : public PyConcreteType<PyVectorType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_vector",
+        [](std::vector<int64_t> shape, PyType &elementType) {
+          // Int/float elements only. NOTE(review): shape is not checked here.
+          if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
+            MlirType t =
+                mlirVectorTypeGet(shape.size(), shape.data(), elementType.type);
+            return PyVectorType(t);
+          }
+          throw SetPyError(
+              PyExc_ValueError,
+              llvm::Twine("invalid '") +
+                  py::repr(py::cast(elementType)).cast<std::string>() +
+                  "' and expected floating point or integer type.");
+        },
+        py::keep_alive<0, 2>(), "Create a vector type");
+  }
+};
+
+/// Binds the MLIR tuple type to the Python class `TupleType`.
+class PyTupleType : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get_tuple",
+        [](PyMlirContext &context, py::list elementList) {
+          SmallVector<MlirType, 4> elements;
+          for (auto element : elementList)
+            elements.push_back(element.cast<PyType>().type);
+          return PyTupleType(mlirTupleTypeGet(context.context, elements.size(),
+                                              elements.data()));
+        },
+        py::keep_alive<0, 1>(), "Create a tuple type");
+    c.def(
+        "get_type",
+        [](PyTupleType &self, intptr_t pos) -> PyType {
+          // Reject negative or out-of-range positions before calling the C API.
+          if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self.type))
+            throw SetPyError(PyExc_IndexError, "tuple index out of range");
+          return PyType(mlirTupleTypeGetType(self.type, pos));
+        },
+        py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
+    c.def_property_readonly(
+        "num_types",
+        [](PyTupleType &self) -> intptr_t {
+          return mlirTupleTypeGetNumTypes(self.type);
+        },
+        "Returns the number of types contained in a tuple.");
+  }
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -771,4 +885,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyF32Type::bind(m);
   PyF64Type::bind(m);
   PyNoneType::bind(m);
+  PyComplexType::bind(m);
+  PyVectorType::bind(m);
+  PyTupleType::bind(m);
 }

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 32e26c57518a..a8f3a3840497 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -154,3 +154,61 @@ def testNoneType():
   print("none type:", mlir.ir.NoneType(ctx))
 
 run(testNoneType)
+
+# CHECK-LABEL: TEST: testComplexType
+def testComplexType():
+  ctx = mlir.ir.Context()
+  complex_i32 = mlir.ir.ComplexType(ctx.parse_type("complex<i32>"))
+  # CHECK: complex type element: i32
+  print("complex type element:", complex_i32.element_type)
+
+  f32 = mlir.ir.F32Type(ctx)
+  # CHECK: complex type: complex<f32>
+  print("complex type:", mlir.ir.ComplexType.get_complex(f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    complex_invalid = mlir.ir.ComplexType.get_complex(index)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testComplexType)
+
+# CHECK-LABEL: TEST: testVectorType
+def testVectorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  # CHECK: vector type: vector<2x3xf32>
+  print("vector type:", mlir.ir.VectorType.get_vector(shape, f32))
+
+  index = mlir.ir.IndexType(ctx)
+  try:
+    vector_invalid = mlir.ir.VectorType.get_vector(shape, index)
+  except ValueError as e:
+    # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testVectorType)
+
+# CHECK-LABEL: TEST: testTupleType
+def testTupleType():
+  ctx = mlir.ir.Context()
+  i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+  f32 = mlir.ir.F32Type(ctx)
+  vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
+  l = [i32, f32, vector]
+  tuple_type = mlir.ir.TupleType.get_tuple(ctx, l)
+  # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
+  print("tuple type:", tuple_type)
+  # CHECK: number of types: 3
+  print("number of types:", tuple_type.num_types)
+  # CHECK: pos-th type in the tuple type: f32
+  print("pos-th type in the tuple type:", tuple_type.get_type(1))
+
+run(testTupleType)


        


More information about the Mlir-commits mailing list