[Mlir-commits] [mlir] 54d432a - [mlir] Add Shaped Type, Tensor Type and MemRef Type to python bindings.

Stella Laurenzo llvmlistbot at llvm.org
Sun Sep 6 11:46:43 PDT 2020


Author: zhanghb97
Date: 2020-09-06T11:45:54-07:00
New Revision: 54d432aa6b835ee7e835d0626c15ca5e7eb83ab4

URL: https://github.com/llvm/llvm-project/commit/54d432aa6b835ee7e835d0626c15ca5e7eb83ab4
DIFF: https://github.com/llvm/llvm-project/commit/54d432aa6b835ee7e835d0626c15ca5e7eb83ab4.diff

LOG: [mlir] Add Shaped Type, Tensor Type and MemRef Type to python bindings.

Based on the PyType and PyConcreteType classes, this patch implements the bindings of Shaped Type, Tensor Type and MemRef Type subclasses.
Tensor Type and MemRef Type are each bound separately as ranked and unranked variants.
This patch adds `*GetChecked` C API functions so that the Python side receives either a valid type or a null type.
Shaped type is not itself a standard type kind; it is the base class for vectors, memrefs and tensors. This patch therefore binds PyShapedType as the base class of the Vector Type, Tensor Type and MemRef Type subclasses.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D87091

Added: 
    

Modified: 
    mlir/include/mlir-c/StandardTypes.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/StandardTypes.cpp
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h
index ad28ea546717..eacfe0d39b6a 100644
--- a/mlir/include/mlir-c/StandardTypes.h
+++ b/mlir/include/mlir-c/StandardTypes.h
@@ -162,6 +162,11 @@ int mlirTypeIsAVector(MlirType type);
  * is owned by the context. */
 MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType);
 
+/** Same as "mlirVectorTypeGet" but returns a null MlirType (check with
+ * mlirTypeIsNull) on illegal arguments, emitting diagnostics at `loc`. */
+MlirType mlirVectorTypeGetChecked(intptr_t rank, int64_t *shape,
+                                  MlirType elementType, MlirLocation loc);
+
 /*============================================================================*/
 /* Ranked / Unranked Tensor type.                                             */
 /*============================================================================*/
@@ -180,10 +185,20 @@ int mlirTypeIsAUnrankedTensor(MlirType type);
 MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
                                  MlirType elementType);
 
+/** Same as "mlirRankedTensorTypeGet" but returns a null MlirType (check with
+ * mlirTypeIsNull) on illegal arguments, emitting diagnostics at `loc`. */
+MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, int64_t *shape,
+                                        MlirType elementType, MlirLocation loc);
+
 /** Creates an unranked tensor type with the given element type in the same
  * context as the element type. The type is owned by the context. */
 MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
 
+/** Same as "mlirUnrankedTensorTypeGet" but returns a null MlirType (check
+ * with mlirTypeIsNull) on illegal arguments, emitting diagnostics at `loc`. */
+MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
+                                          MlirLocation loc);
+
 /*============================================================================*/
 /* Ranked / Unranked MemRef type.                                             */
 /*============================================================================*/
@@ -208,10 +223,23 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                                      int64_t *shape, unsigned memorySpace);
 
+/** Same as "mlirMemRefTypeContiguousGet" but returns a null MlirType (check
+ * with mlirTypeIsNull) on illegal arguments, emitting diagnostics at `loc`. */
+MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
+                                            int64_t *shape,
+                                            unsigned memorySpace,
+                                            MlirLocation loc);
+
 /** Creates an Unranked MemRef type with the given element type and in the given
  * memory space. The type is owned by the context of element type. */
 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace);
 
+/** Same as "mlirUnrankedMemRefTypeGet" but returns a null MlirType (check
+ * with mlirTypeIsNull) on illegal arguments, emitting diagnostics at `loc`. */
+MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
+                                          unsigned memorySpace,
+                                          MlirLocation loc);
+
 /** Returns the number of affine layout maps in the given MemRef type. */
 intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 70c1a28e92be..149e231aed0b 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -516,30 +516,269 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
   }
 };
 
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType> {
+class PyShapedType : public PyConcreteType<PyShapedType> {
 public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
-  static constexpr const char *pyClassName = "VectorType";
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
+  static constexpr const char *pyClassName = "ShapedType";
   using PyConcreteType::PyConcreteType;
 
   static void bindDerived(ClassTy &c) {
+    c.def_property_readonly(
+        "element_type",
+        [](PyShapedType &self) {
+          MlirType t = mlirShapedTypeGetElementType(self.type);
+          return PyType(t);
+        },
+        py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
+    c.def_property_readonly(
+        "has_rank",
+        [](PyShapedType &self) -> bool {
+          return mlirShapedTypeHasRank(self.type);
+        },
+        "Returns whether the given shaped type is ranked.");
+    c.def_property_readonly(
+        "rank",
+        [](PyShapedType &self) {
+          self.requireHasRank(); // ValueError if unranked.
+          return mlirShapedTypeGetRank(self.type);
+        },
+        "Returns the rank of the given ranked shaped type.");
+    c.def_property_readonly(
+        "has_static_shape",
+        [](PyShapedType &self) -> bool {
+          return mlirShapedTypeHasStaticShape(self.type);
+        },
+        "Returns whether the given shaped type has a static shape.");
+    c.def(
+        "is_dynamic_dim",
+        [](PyShapedType &self, intptr_t dim) -> bool {
+          self.requireHasRank(); // ValueError if unranked.
+          return mlirShapedTypeIsDynamicDim(self.type, dim);
+        },
+        "Returns whether the dim-th dimension of the given shaped type is "
+        "dynamic.");
+    c.def(
+        "get_dim_size",
+        [](PyShapedType &self, intptr_t dim) {
+          self.requireHasRank(); // ValueError if unranked.
+          return mlirShapedTypeGetDimSize(self.type, dim);
+        },
+        "Returns the dim-th dimension of the given ranked shaped type.");
     c.def_static(
-        "get_vector",
-        [](std::vector<int64_t> shape, PyType &elementType) {
-          // The element must be a floating point or integer scalar type.
-          if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
-            MlirType t =
-                mlirVectorTypeGet(shape.size(), shape.data(), elementType.type);
-            return PyVectorType(t);
-          }
-          throw SetPyError(
-              PyExc_ValueError,
-              llvm::Twine("invalid '") +
-                  py::repr(py::cast(elementType)).cast<std::string>() +
-                  "' and expected floating point or integer type.");
+        "is_dynamic_size",
+        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
+        "Returns whether the given dimension size indicates a dynamic "
+        "dimension.");
+    c.def(
+        "is_dynamic_stride_or_offset",
+        [](PyShapedType &self, int64_t val) -> bool {
+          self.requireHasRank(); // NOTE(review): result ignores self.
+          return mlirShapedTypeIsDynamicStrideOrOffset(val);
         },
-        py::keep_alive<0, 2>(), "Create a vector type");
+        "Returns whether the given value is used as a placeholder for dynamic "
+        "strides and offsets in shaped types.");
+  }
+
+private:
+  void requireHasRank() { // Raises ValueError unless ranked.
+    if (!mlirShapedTypeHasRank(type)) {
+      throw SetPyError(
+          PyExc_ValueError,
+          "calling this method requires that the type has a rank.");
+    }
+  }
+};
+
+/// Python binding for VectorType; inherits the ShapedType API.
+class PyVectorType : public PyShapedType {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyShapedType::PyShapedType;
+  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
+  // subclasses, exposing the ShapedType hierarchy.
+  static void bind(py::module &m) {
+    py::class_<PyVectorType, PyShapedType>(m, pyClassName)
+        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
+        .def_static(
+            "get_vector",
+            // TODO: Make the location optional and create a default location.
+            [](std::vector<int64_t> shape, PyType &elementType,
+               PyLocation &loc) {
+              MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
+                                                    elementType.type, loc.loc);
+              // TODO: Rework error reporting once diagnostic engine is exposed
+              // in C API.
+              if (mlirTypeIsNull(t)) { // Checked get failed.
+                throw SetPyError(
+                    PyExc_ValueError,
+                    llvm::Twine("invalid '") +
+                        py::repr(py::cast(elementType)).cast<std::string>() +
+                        "' and expected floating point or integer type.");
+              }
+              return PyVectorType(t);
+            },
+            py::keep_alive<0, 2>(), "Create a vector type");
+  }
+};
+
+/// Python binding for RankedTensorType; inherits the ShapedType API.
+class PyRankedTensorType : public PyShapedType {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr const char *pyClassName = "RankedTensorType";
+  using PyShapedType::PyShapedType;
+  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
+  // subclasses, exposing the ShapedType hierarchy.
+  static void bind(py::module &m) {
+    py::class_<PyRankedTensorType, PyShapedType>(m, pyClassName)
+        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
+        .def_static(
+            "get_ranked_tensor",
+            // TODO: Make the location optional and create a default location.
+            [](std::vector<int64_t> shape, PyType &elementType,
+               PyLocation &loc) {
+              MlirType t = mlirRankedTensorTypeGetChecked(
+                  shape.size(), shape.data(), elementType.type, loc.loc);
+              // TODO: Rework error reporting once diagnostic engine is exposed
+              // in C API.
+              if (mlirTypeIsNull(t)) { // Checked get failed.
+                throw SetPyError(
+                    PyExc_ValueError,
+                    llvm::Twine("invalid '") +
+                        py::repr(py::cast(elementType)).cast<std::string>() +
+                        "' and expected floating point, integer, vector or "
+                        "complex "
+                        "type.");
+              }
+              return PyRankedTensorType(t);
+            },
+            py::keep_alive<0, 2>(), "Create a ranked tensor type");
+  }
+};
+
+/// Python binding for UnrankedTensorType; inherits the ShapedType API.
+class PyUnrankedTensorType : public PyShapedType {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr const char *pyClassName = "UnrankedTensorType";
+  using PyShapedType::PyShapedType;
+  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
+  // subclasses, exposing the ShapedType hierarchy.
+  static void bind(py::module &m) {
+    py::class_<PyUnrankedTensorType, PyShapedType>(m, pyClassName)
+        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
+        .def_static(
+            "get_unranked_tensor",
+            // TODO: Make the location optional and create a default location.
+            [](PyType &elementType, PyLocation &loc) {
+              MlirType t =
+                  mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
+              // TODO: Rework error reporting once diagnostic engine is exposed
+              // in C API.
+              if (mlirTypeIsNull(t)) { // Checked get failed.
+                throw SetPyError(
+                    PyExc_ValueError,
+                    llvm::Twine("invalid '") +
+                        py::repr(py::cast(elementType)).cast<std::string>() +
+                        "' and expected floating point, integer, vector or "
+                        "complex "
+                        "type.");
+              }
+              return PyUnrankedTensorType(t);
+            },
+            py::keep_alive<0, 1>(), "Create an unranked tensor type");
+  }
+};
+
+/// Ranked MemRef Type subclass - MemRefType (isa: mlirTypeIsAMemRef).
+class PyMemRefType : public PyShapedType {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr const char *pyClassName = "MemRefType";
+  using PyShapedType::PyShapedType;
+  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
+  // subclasses, exposing the ShapedType hierarchy.
+  static void bind(py::module &m) {
+    py::class_<PyMemRefType, PyShapedType>(m, pyClassName)
+        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
+        // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
+        // once the affine map binding is completed.
+        .def_static(
+            "get_contiguous_memref",
+            // TODO: Make the location optional and create a default location.
+            [](PyType &elementType, std::vector<int64_t> shape,
+               unsigned memorySpace, PyLocation &loc) {
+              MlirType t = mlirMemRefTypeContiguousGetChecked(
+                  elementType.type, shape.size(), shape.data(), memorySpace,
+                  loc.loc);
+              // TODO: Rework error reporting once diagnostic engine is exposed
+              // in C API.
+              if (mlirTypeIsNull(t)) { // Checked get failed.
+                throw SetPyError(
+                    PyExc_ValueError,
+                    llvm::Twine("invalid '") +
+                        py::repr(py::cast(elementType)).cast<std::string>() +
+                        "' and expected floating point, integer, vector or "
+                        "complex "
+                        "type.");
+              }
+              return PyMemRefType(t);
+            },
+            py::keep_alive<0, 1>(), "Create a memref type")
+        .def_property_readonly(
+            "num_affine_maps",
+            [](PyMemRefType &self) -> intptr_t {
+              return mlirMemRefTypeGetNumAffineMaps(self.type);
+            },
+            "Returns the number of affine layout maps in the given MemRef "
+            "type.")
+        .def_property_readonly(
+            "memory_space",
+            [](PyMemRefType &self) -> unsigned {
+              return mlirMemRefTypeGetMemorySpace(self.type);
+            },
+            "Returns the memory space of the given MemRef type.");
+  }
+};
+
+/// Python binding for UnrankedMemRefType; inherits the ShapedType API.
+class PyUnrankedMemRefType : public PyShapedType {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr const char *pyClassName = "UnrankedMemRefType";
+  using PyShapedType::PyShapedType;
+  // TODO: Switch back to bindDerived by making the ClassTy modifiable by
+  // subclasses, exposing the ShapedType hierarchy.
+  static void bind(py::module &m) {
+    py::class_<PyUnrankedMemRefType, PyShapedType>(m, pyClassName)
+        .def(py::init<PyType &>(), py::keep_alive<0, 1>())
+        .def_static(
+            "get_unranked_memref",
+            // TODO: Make the location optional and create a default location.
+            [](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
+              MlirType t = mlirUnrankedMemRefTypeGetChecked(
+                  elementType.type, memorySpace, loc.loc);
+              // TODO: Rework error reporting once diagnostic engine is exposed
+              // in C API.
+              if (mlirTypeIsNull(t)) { // Checked get failed.
+                throw SetPyError(
+                    PyExc_ValueError,
+                    llvm::Twine("invalid '") +
+                        py::repr(py::cast(elementType)).cast<std::string>() +
+                        "' and expected floating point, integer, vector or "
+                        "complex "
+                        "type.");
+              }
+              return PyUnrankedMemRefType(t);
+            },
+            py::keep_alive<0, 1>(), "Create an unranked memref type")
+        .def_property_readonly(
+            "memory_space",
+            [](PyUnrankedMemRefType &self) -> unsigned {
+              return mlirUnrankedMemrefGetMemorySpace(self.type);
+            },
+            "Returns the memory space of the given Unranked MemRef type.");
   }
 };
 
@@ -886,6 +1125,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyF64Type::bind(m);
   PyNoneType::bind(m);
   PyComplexType::bind(m);
+  PyShapedType::bind(m);
   PyVectorType::bind(m);
+  PyRankedTensorType::bind(m);
+  PyUnrankedTensorType::bind(m);
+  PyMemRefType::bind(m);
+  PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
 }

diff  --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp
index eb006242e880..ddd3a5e93147 100644
--- a/mlir/lib/CAPI/IR/StandardTypes.cpp
+++ b/mlir/lib/CAPI/IR/StandardTypes.cpp
@@ -168,6 +168,13 @@ MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape,
                       unwrap(elementType)));
 }
 
+MlirType mlirVectorTypeGetChecked(intptr_t rank, int64_t *shape,
+                                  MlirType elementType, MlirLocation loc) {
+  return wrap(VectorType::getChecked( // Null type on invalid arguments.
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      unwrap(loc)));
+}
+
 /* ========================================================================== */
 /* Ranked / Unranked tensor type.                                             */
 /* ========================================================================== */
@@ -189,10 +196,23 @@ MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
       unwrap(elementType)));
 }
 
+MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, int64_t *shape,
+                                        MlirType elementType,
+                                        MlirLocation loc) {
+  return wrap(RankedTensorType::getChecked( // Null type on invalid args.
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      unwrap(loc)));
+}
+
 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
   return wrap(UnrankedTensorType::get(unwrap(elementType)));
 }
 
+MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
+                                          MlirLocation loc) { // Null if invalid.
+  return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc)));
+}
+
 /* ========================================================================== */
 /* Ranked / Unranked MemRef type.                                             */
 /* ========================================================================== */
@@ -216,6 +236,15 @@ MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                       unwrap(elementType), llvm::None, memorySpace));
 }
 
+MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
+                                            int64_t *shape,
+                                            unsigned memorySpace,
+                                            MlirLocation loc) {
+  return wrap(MemRefType::getChecked( // Null type on invalid arguments.
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      llvm::None, memorySpace, unwrap(loc))); // No affine maps.
+}
+
 intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
   return static_cast<intptr_t>(
       unwrap(type).cast<MemRefType>().getAffineMaps().size());
@@ -237,6 +266,13 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
   return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
 }
 
+MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
+                                          unsigned memorySpace,
+                                          MlirLocation loc) { // Null if invalid.
+  return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace,
+                                             unwrap(loc)));
+}
+
 unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
   return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
 }

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index a8f3a3840497..00cd595843aa 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -177,25 +177,187 @@ def testComplexType():
 
 run(testComplexType)
 
+# CHECK-LABEL: TEST: testShapedType
+# ShapedType is not a concrete standard type; it is the base class for
+# vectors, memrefs and tensors, so this test exercises the ShapedType API
+# through a VectorType instance.
+def testShapedType():
+  ctx = mlir.ir.Context()
+  vector = mlir.ir.VectorType(ctx.parse_type("vector<2x3xf32>"))
+  # CHECK: element type: f32
+  print("element type:", vector.element_type)
+  # CHECK: whether the given shaped type is ranked: True
+  print("whether the given shaped type is ranked:", vector.has_rank)
+  # CHECK: rank: 2
+  print("rank:", vector.rank)
+  # CHECK: whether the shaped type has a static shape: True
+  print("whether the shaped type has a static shape:", vector.has_static_shape)
+  # CHECK: whether the dim-th dimension is dynamic: False
+  print("whether the dim-th dimension is dynamic:", vector.is_dynamic_dim(0))
+  # CHECK: dim size: 3
+  print("dim size:", vector.get_dim_size(1))
+  # CHECK: False
+  print(vector.is_dynamic_size(3))  # Static method, called via an instance.
+  # CHECK: False
+  print(vector.is_dynamic_stride_or_offset(1))
+
+run(testShapedType)
+
 # CHECK-LABEL: TEST: testVectorType
 def testVectorType():
   ctx = mlir.ir.Context()
   f32 = mlir.ir.F32Type(ctx)
   shape = [2, 3]
+  loc = ctx.get_unknown_location()  # Passed to the checked getter.
   # CHECK: vector type: vector<2x3xf32>
-  print("vector type:", mlir.ir.VectorType.get_vector(shape, f32))
+  print("vector type:", mlir.ir.VectorType.get_vector(shape, f32, loc))
 
-  index = mlir.ir.IndexType(ctx)
+  none = mlir.ir.NoneType(ctx)  # Invalid vector element type.
   try:
-    vector_invalid = mlir.ir.VectorType.get_vector(shape, index)
+    vector_invalid = mlir.ir.VectorType.get_vector(shape, none, loc)
   except ValueError as e:
-    # CHECK: invalid 'Type(index)' and expected floating point or integer type.
+    # CHECK: invalid 'Type(none)' and expected floating point or integer type.
     print(e)
   else:
     print("Exception not produced")
 
 run(testVectorType)
 
+# CHECK-LABEL: TEST: testRankedTensorType
+def testRankedTensorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  loc = ctx.get_unknown_location()
+  # CHECK: ranked tensor type: tensor<2x3xf32>
+  print("ranked tensor type:",
+        mlir.ir.RankedTensorType.get_ranked_tensor(shape, f32, loc))
+
+  none = mlir.ir.NoneType(ctx)  # Invalid tensor element type.
+  try:
+    tensor_invalid = mlir.ir.RankedTensorType.get_ranked_tensor(shape, none,
+                                                                loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testRankedTensorType)
+
+# CHECK-LABEL: TEST: testUnrankedTensorType
+def testUnrankedTensorType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  loc = ctx.get_unknown_location()
+  unranked_tensor = mlir.ir.UnrankedTensorType.get_unranked_tensor(f32, loc)
+  # CHECK: unranked tensor type: tensor<*xf32>
+  print("unranked tensor type:", unranked_tensor)
+  try:
+    invalid_rank = unranked_tensor.rank  # Rank queries raise when unranked.
+  except ValueError as e:
+    # CHECK: calling this method requires that the type has a rank.
+    print(e)
+  else:
+    print("Exception not produced")
+  try:
+    invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
+  except ValueError as e:
+    # CHECK: calling this method requires that the type has a rank.
+    print(e)
+  else:
+    print("Exception not produced")
+  try:
+    invalid_get_dim_size = unranked_tensor.get_dim_size(1)
+  except ValueError as e:
+    # CHECK: calling this method requires that the type has a rank.
+    print(e)
+  else:
+    print("Exception not produced")
+
+  none = mlir.ir.NoneType(ctx)  # Invalid tensor element type.
+  try:
+    tensor_invalid = mlir.ir.UnrankedTensorType.get_unranked_tensor(none, loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testUnrankedTensorType)
+
+# CHECK-LABEL: TEST: testMemRefType
+def testMemRefType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  shape = [2, 3]
+  loc = ctx.get_unknown_location()
+  memref = mlir.ir.MemRefType.get_contiguous_memref(f32, shape, 2, loc)
+  # CHECK: memref type: memref<2x3xf32, 2>
+  print("memref type:", memref)  # NOTE(review): MemRefType(t) not tested.
+  # CHECK: number of affine layout maps: 0
+  print("number of affine layout maps:", memref.num_affine_maps)
+  # CHECK: memory space: 2
+  print("memory space:", memref.memory_space)
+
+  none = mlir.ir.NoneType(ctx)  # Invalid memref element type.
+  try:
+    memref_invalid = mlir.ir.MemRefType.get_contiguous_memref(none, shape, 2,
+                                                              loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testMemRefType)
+
+# CHECK-LABEL: TEST: testUnrankedMemRefType
+def testUnrankedMemRefType():
+  ctx = mlir.ir.Context()
+  f32 = mlir.ir.F32Type(ctx)
+  loc = ctx.get_unknown_location()
+  unranked_memref = mlir.ir.UnrankedMemRefType.get_unranked_memref(f32, 2, loc)
+  # CHECK: unranked memref type: memref<*xf32, 2>
+  print("unranked memref type:", unranked_memref)
+  try:
+    invalid_rank = unranked_memref.rank  # Rank queries raise when unranked.
+  except ValueError as e:
+    # CHECK: calling this method requires that the type has a rank.
+    print(e)
+  else:
+    print("Exception not produced")
+  try:
+    invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
+  except ValueError as e:
+    # CHECK: calling this method requires that the type has a rank.
+    print(e)
+  else:
+    print("Exception not produced")
+  try:
+    invalid_get_dim_size = unranked_memref.get_dim_size(1)
+  except ValueError as e:
+    # CHECK: calling this method requires that the type has a rank.
+    print(e)
+  else:
+    print("Exception not produced")
+
+  none = mlir.ir.NoneType(ctx)  # Invalid memref element type.
+  try:
+    memref_invalid = mlir.ir.UnrankedMemRefType.get_unranked_memref(none, 2,
+                                                                    loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
+    # CHECK: or complex type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testUnrankedMemRefType)
+
 # CHECK-LABEL: TEST: testTupleType
 def testTupleType():
   ctx = mlir.ir.Context()


        


More information about the Mlir-commits mailing list