[Mlir-commits] [mlir] b0d76f4 - [mlir] Centralize handling of memref element types.
Sean Silva
llvmlistbot at llvm.org
Fri Aug 7 15:24:54 PDT 2020
Author: Sean Silva
Date: 2020-08-07T15:17:23-07:00
New Revision: b0d76f454daad66482b1084b302ed252124b7bdd
URL: https://github.com/llvm/llvm-project/commit/b0d76f454daad66482b1084b302ed252124b7bdd
DIFF: https://github.com/llvm/llvm-project/commit/b0d76f454daad66482b1084b302ed252124b7bdd.diff
LOG: [mlir] Centralize handling of memref element types.
This also strengthens the test coverage:
- Make unranked memref testing consistent with ranked memrefs.
- Add testing for the invalid element type cases.
This is not quite NFC (no functional change): index types are now allowed in unranked memrefs, matching what ranked memrefs already accepted.
Differential Revision: https://reviews.llvm.org/D85541
Added:
Modified:
mlir/include/mlir/IR/StandardTypes.h
mlir/lib/IR/StandardTypes.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 11f7e442f416..a4a4566e3eb8 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -426,6 +426,11 @@ class BaseMemRefType : public ShapedType {
public:
using ShapedType::ShapedType;
+ /// Return true if the specified element type is ok in a memref.
+ static bool isValidElementType(Type type) {
+ return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
+ }
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
};
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index fc4f555b37f1..c2295bb14573 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -408,9 +408,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
Optional<Location> location) {
auto *context = elementType.getContext();
- // Check that memref is formed from allowed types.
- if (!elementType.isIntOrIndexOrFloat() &&
- !elementType.isa<VectorType, ComplexType>())
+ if (!BaseMemRefType::isValidElementType(elementType))
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
@@ -486,9 +484,7 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
- // Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() &&
- !elementType.isa<VectorType, ComplexType>())
+ if (!BaseMemRefType::isValidElementType(elementType))
return emitError(loc, "invalid memref element type");
return success();
}
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index dcf04735c901..8c098197a261 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -17,6 +17,14 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
// -----
+func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
+
+// -----
+
+func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
+
+// -----
+
func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
// -----
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 8d3d161ef27e..aad4ba9e9c1e 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -152,6 +152,12 @@ func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
+// CHECK: func @unranked_memref_with_index_elems(memref<*xindex>)
+func @unranked_memref_with_index_elems(memref<*xindex>)
+
+// CHECK: func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
+func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
+
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
More information about the Mlir-commits mailing list