[Mlir-commits] [mlir] b0d76f4 - [mlir] Centralize handling of memref element types.

Sean Silva llvmlistbot at llvm.org
Fri Aug 7 15:24:54 PDT 2020


Author: Sean Silva
Date: 2020-08-07T15:17:23-07:00
New Revision: b0d76f454daad66482b1084b302ed252124b7bdd

URL: https://github.com/llvm/llvm-project/commit/b0d76f454daad66482b1084b302ed252124b7bdd
DIFF: https://github.com/llvm/llvm-project/commit/b0d76f454daad66482b1084b302ed252124b7bdd.diff

LOG: [mlir] Centralize handling of memref element types.

This also expands the test coverage:
- Make unranked memref testing consistent with ranked memrefs.
- Add tests for invalid element types.

This is not quite NFC (no functional change): index types are now allowed as unranked memref element types, matching what ranked memrefs already accepted.

Differential Revision: https://reviews.llvm.org/D85541

Added: 
    

Modified: 
    mlir/include/mlir/IR/StandardTypes.h
    mlir/lib/IR/StandardTypes.cpp
    mlir/test/IR/invalid.mlir
    mlir/test/IR/parser.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 11f7e442f416..a4a4566e3eb8 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -426,6 +426,11 @@ class BaseMemRefType : public ShapedType {
 public:
   using ShapedType::ShapedType;
 
+  /// Return true if `type` is an int, index, float, vector or complex type.
+  static bool isValidElementType(Type type) {
+    return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
+  }
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
 };

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index fc4f555b37f1..c2295bb14573 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -408,9 +408,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                                Optional<Location> location) {
   auto *context = elementType.getContext();
 
-  // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrIndexOrFloat() &&
-      !elementType.isa<VectorType, ComplexType>())
+  if (!BaseMemRefType::isValidElementType(elementType))
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();
 
@@ -486,9 +484,7 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
 LogicalResult
 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                  unsigned memorySpace) {
-  // Check that memref is formed from allowed types.
-  if (!elementType.isIntOrFloat() &&
-      !elementType.isa<VectorType, ComplexType>())
+  if (!BaseMemRefType::isValidElementType(elementType))
     return emitError(loc, "invalid memref element type");
   return success();
 }

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index dcf04735c901..8c098197a261 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -17,6 +17,14 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
 
 // -----
 
+func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
+
+// -----
+
+func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
+
+// -----
+
 func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
 
 // -----

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 8d3d161ef27e..aad4ba9e9c1e 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -152,6 +152,12 @@ func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
 // CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 
+// CHECK: func @unranked_memref_with_index_elems(memref<*xindex>)
+func @unranked_memref_with_index_elems(memref<*xindex>)
+
+// CHECK: func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
+func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
+
 // CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
 func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
 


        


More information about the Mlir-commits mailing list