[llvm-branch-commits] [mlir] 046612d - [mlir][vector] verify memref of vector memory ops
Aart Bik via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jan 11 13:36:51 PST 2021
Author: Aart Bik
Date: 2021-01-11T13:32:39-08:00
New Revision: 046612d29d7894783e8fcecbc62ebd6b4a78499f
URL: https://github.com/llvm/llvm-project/commit/046612d29d7894783e8fcecbc62ebd6b4a78499f
DIFF: https://github.com/llvm/llvm-project/commit/046612d29d7894783e8fcecbc62ebd6b4a78499f.diff
LOG: [mlir][vector] verify memref of vector memory ops
This ensures the memref base + indices expression is well-formed
Reviewed By: ThomasRaoux, ftynse
Differential Revision: https://reviews.llvm.org/D94441
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 731ddae85ead..54e5e008e56f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2365,10 +2365,12 @@ static LogicalResult verify(MaskedLoadOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType passVType = op.getPassThruVectorType();
VectorType resVType = op.getResultVectorType();
+ MemRefType memType = op.getMemRefType();
- if (resVType.getElementType() != op.getMemRefType().getElementType())
+ if (resVType.getElementType() != memType.getElementType())
return op.emitOpError("base and result element type should match");
-
+ if (llvm::size(op.indices()) != memType.getRank())
+ return op.emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
@@ -2410,10 +2412,12 @@ void MaskedLoadOp::getCanonicalizationPatterns(
static LogicalResult verify(MaskedStoreOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getValueVectorType();
+ MemRefType memType = op.getMemRefType();
- if (valueVType.getElementType() != op.getMemRefType().getElementType())
+ if (valueVType.getElementType() != memType.getElementType())
return op.emitOpError("base and value element type should match");
-
+ if (llvm::size(op.indices()) != memType.getRank())
+ return op.emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected value dim to match mask dim");
return success();
@@ -2454,10 +2458,10 @@ static LogicalResult verify(GatherOp op) {
VectorType indicesVType = op.getIndicesVectorType();
VectorType maskVType = op.getMaskVectorType();
VectorType resVType = op.getResultVectorType();
+ MemRefType memType = op.getMemRefType();
- if (resVType.getElementType() != op.getMemRefType().getElementType())
+ if (resVType.getElementType() != memType.getElementType())
return op.emitOpError("base and result element type should match");
-
if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
return op.emitOpError("expected result dim to match indices dim");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2500,10 +2504,10 @@ static LogicalResult verify(ScatterOp op) {
VectorType indicesVType = op.getIndicesVectorType();
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getValueVectorType();
+ MemRefType memType = op.getMemRefType();
- if (valueVType.getElementType() != op.getMemRefType().getElementType())
+ if (valueVType.getElementType() != memType.getElementType())
return op.emitOpError("base and value element type should match");
-
if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
return op.emitOpError("expected value dim to match indices dim");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2544,10 +2548,12 @@ static LogicalResult verify(ExpandLoadOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType passVType = op.getPassThruVectorType();
VectorType resVType = op.getResultVectorType();
+ MemRefType memType = op.getMemRefType();
- if (resVType.getElementType() != op.getMemRefType().getElementType())
+ if (resVType.getElementType() != memType.getElementType())
return op.emitOpError("base and result element type should match");
-
+ if (llvm::size(op.indices()) != memType.getRank())
+ return op.emitOpError("requires ") << memType.getRank() << " indices";
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
@@ -2589,10 +2595,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(
static LogicalResult verify(CompressStoreOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getValueVectorType();
+ MemRefType memType = op.getMemRefType();
- if (valueVType.getElementType() != op.getMemRefType().getElementType())
+ if (valueVType.getElementType() != memType.getElementType())
return op.emitOpError("base and value element type should match");
-
+ if (llvm::size(op.indices()) != memType.getRank())
+ return op.emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected value dim to match mask dim");
return success();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 11100c4e615e..099dad7eada4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1222,6 +1222,13 @@ func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vecto
// -----
+func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+ // expected-error@+1 {{'vector.maskedload' op requires 1 indices}}
+ %0 = vector.maskedload %base[], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
@@ -1238,6 +1245,14 @@ func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>,
// -----
+func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // expected-error@+1 {{'vector.maskedstore' op requires 1 indices}}
+ vector.maskedstore %base[%c0, %c0], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error@+1 {{'vector.gather' op base and result element type should match}}
@@ -1343,6 +1358,14 @@ func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pas
// -----
+func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // expected-error@+1 {{'vector.expandload' op requires 2 indices}}
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
// expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
@@ -1359,6 +1382,14 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
// -----
+func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // expected-error@+1 {{'vector.compressstore' op requires 2 indices}}
+ vector.compressstore %base[%c0, %c0, %c0], %mask, %value : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
func @extract_map_rank(%v: vector<32xf32>, %id : index) {
// expected-error@+1 {{'vector.extract_map' op expected source and destination vectors of same rank}}
%0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32>
More information about the llvm-branch-commits
mailing list