[llvm-branch-commits] [mlir] 046612d - [mlir][vector] verify memref of vector memory ops

Aart Bik via llvm-branch-commits <llvm-branch-commits@lists.llvm.org>
Mon Jan 11 13:36:51 PST 2021


Author: Aart Bik
Date: 2021-01-11T13:32:39-08:00
New Revision: 046612d29d7894783e8fcecbc62ebd6b4a78499f

URL: https://github.com/llvm/llvm-project/commit/046612d29d7894783e8fcecbc62ebd6b4a78499f
DIFF: https://github.com/llvm/llvm-project/commit/046612d29d7894783e8fcecbc62ebd6b4a78499f.diff

LOG: [mlir][vector] verify memref of vector memory ops

This ensures that the memref base + indices expression is well-formed, i.e. that the number of indices supplied matches the rank of the memref.

Reviewed By: ThomasRaoux, ftynse

Differential Revision: https://reviews.llvm.org/D94441
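
For illustration, a minimal sketch of the kind of IR the new check accepts
(names such as %base, %mask, and %pass are placeholders; the types mirror the
tests below): a rank-1 memref requires exactly one index, so

    %c0 = constant 0 : index
    %0 = vector.maskedload %base[%c0], %mask, %pass
        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

passes verification, whereas omitting the index (or supplying two) is now
rejected with "requires 1 indices".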

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 731ddae85ead..54e5e008e56f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2365,10 +2365,12 @@ static LogicalResult verify(MaskedLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -2410,10 +2412,12 @@ void MaskedLoadOp::getCanonicalizationPatterns(
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected value dim to match mask dim");
   return success();
@@ -2454,10 +2458,10 @@ static LogicalResult verify(GatherOp op) {
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
   if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
     return op.emitOpError("expected result dim to match indices dim");
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2500,10 +2504,10 @@ static LogicalResult verify(ScatterOp op) {
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
   if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
     return op.emitOpError("expected value dim to match indices dim");
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2544,10 +2548,12 @@ static LogicalResult verify(ExpandLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -2589,10 +2595,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(
 static LogicalResult verify(CompressStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected value dim to match mask dim");
   return success();

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 11100c4e615e..099dad7eada4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1222,6 +1222,13 @@ func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vecto
 
 // -----
 
+func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+  // expected-error @+1 {{'vector.maskedload' op requires 1 indices}}
+  %0 = vector.maskedload %base[], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error at +1 {{'vector.maskedstore' op base and value element type should match}}
@@ -1238,6 +1245,14 @@ func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>,
 
 // -----
 
+func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.maskedstore' op requires 1 indices}}
+  vector.maskedstore %base[%c0, %c0], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   // expected-error at +1 {{'vector.gather' op base and result element type should match}}
@@ -1343,6 +1358,14 @@ func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pas
 
 // -----
 
+func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.expandload' op requires 2 indices}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error at +1 {{'vector.compressstore' op base and value element type should match}}
@@ -1359,6 +1382,14 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
 
 // -----
 
+func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.compressstore' op requires 2 indices}}
+  vector.compressstore %base[%c0, %c0, %c0], %mask, %value : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func @extract_map_rank(%v: vector<32xf32>, %id : index) {
   // expected-error at +1 {{'vector.extract_map' op expected source and destination vectors of same rank}}
   %0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32>

