[Mlir-commits] [mlir] 656674a - [mlir][Vector] Align gather/scatter/expand/compress API

Diego Caballero llvmlistbot at llvm.org
Fri Feb 12 10:53:30 PST 2021


Author: Diego Caballero
Date: 2021-02-12T20:48:38+02:00
New Revision: 656674a7c43305dbf0b4234f1086cc88f7f4e4fb

URL: https://github.com/llvm/llvm-project/commit/656674a7c43305dbf0b4234f1086cc88f7f4e4fb
DIFF: https://github.com/llvm/llvm-project/commit/656674a7c43305dbf0b4234f1086cc88f7f4e4fb.diff

LOG: [mlir][Vector] Align gather/scatter/expand/compress API

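Align the vector gather/scatter/expand/compress API with
the vector load/store/maskedload/maskedstore API. Concretely: rename
the stored-value operand of vector.scatter and vector.compressstore
from `value` to `valueToStore`, replace the getResultVectorType() and
getValueVectorType() accessors with a single getVectorType(), and
update the verifier error messages accordingly.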

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D96396

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 557446bddb00..3da4cd25dc62 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1619,7 +1619,7 @@ def Vector_GatherOp :
     VectorType getPassThruVectorType() {
       return pass_thru().getType().cast<VectorType>();
     }
-    VectorType getResultVectorType() {
+    VectorType getVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];
@@ -1633,7 +1633,7 @@ def Vector_ScatterOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                VectorOfRankAndType<[1], [AnyInteger]>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$value)> {
+               VectorOfRank<[1]>:$valueToStore)> {
 
   let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
 
@@ -1675,12 +1675,13 @@ def Vector_ScatterOp :
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
     }
-    VectorType getValueVectorType() {
-      return value().getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
-    "type($base) `,` type($indices) `,` type($mask) `,` type($value)";
+  let assemblyFormat =
+      "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
+      "type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
 
@@ -1730,7 +1731,7 @@ def Vector_ExpandLoadOp :
     VectorType getPassThruVectorType() {
       return pass_thru().getType().cast<VectorType>();
     }
-    VectorType getResultVectorType() {
+    VectorType getVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];
@@ -1744,7 +1745,7 @@ def Vector_CompressStoreOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$value)> {
+               VectorOfRank<[1]>:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
@@ -1781,12 +1782,13 @@ def Vector_CompressStoreOp :
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
     }
-    VectorType getValueVectorType() {
-      return value().getType().cast<VectorType>();
+    VectorType getVectorType() {
+      return valueToStore().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
-    "type($base) `,` type($mask) `,` type($value)";
+  let assemblyFormat =
+      "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
+      "type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3393bb702a78..b6106560024c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -446,7 +446,7 @@ class VectorGatherOpConversion
       return failure();
 
     // Get index ptrs.
-    VectorType vType = gather.getResultVectorType();
+    VectorType vType = gather.getVectorType();
     Type iType = gather.getIndicesVectorType().getElementType();
     Value ptrs;
     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
@@ -480,7 +480,7 @@ class VectorScatterOpConversion
       return failure();
 
     // Get index ptrs.
-    VectorType vType = scatter.getValueVectorType();
+    VectorType vType = scatter.getVectorType();
     Type iType = scatter.getIndicesVectorType().getElementType();
     Value ptrs;
     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
@@ -489,7 +489,7 @@ class VectorScatterOpConversion
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
-        scatter, adaptor.value(), ptrs, adaptor.mask(),
+        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
         rewriter.getI32IntegerAttr(align));
     return success();
   }
@@ -509,7 +509,7 @@ class VectorExpandLoadOpConversion
     MemRefType memRefType = expand.getMemRefType();
 
     // Resolve address.
-    auto vtype = typeConverter->convertType(expand.getResultVectorType());
+    auto vtype = typeConverter->convertType(expand.getVectorType());
     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                            adaptor.indices(), rewriter);
 
@@ -537,7 +537,7 @@ class VectorCompressStoreOpConversion
                                            adaptor.indices(), rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-        compress, adaptor.value(), ptr, adaptor.mask());
+        compress, adaptor.valueToStore(), ptr, adaptor.mask());
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a56b49a315d8..2c65a00092ea 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2534,7 +2534,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(
 static LogicalResult verify(GatherOp op) {
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
-  VectorType resVType = op.getResultVectorType();
+  VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())
@@ -2580,15 +2580,15 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 static LogicalResult verify(ScatterOp op) {
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
-  VectorType valueVType = op.getValueVectorType();
+  VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
-    return op.emitOpError("base and value element type should match");
+    return op.emitOpError("base and valueToStore element type should match");
   if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match indices dim");
+    return op.emitOpError("expected valueToStore dim to match indices dim");
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match mask dim");
+    return op.emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
 
@@ -2624,7 +2624,7 @@ void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 static LogicalResult verify(ExpandLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
-  VectorType resVType = op.getResultVectorType();
+  VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())
@@ -2671,15 +2671,15 @@ void ExpandLoadOp::getCanonicalizationPatterns(
 
 static LogicalResult verify(CompressStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
-  VectorType valueVType = op.getValueVectorType();
+  VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
-    return op.emitOpError("base and value element type should match");
+    return op.emitOpError("base and valueToStore element type should match");
   if (llvm::size(op.indices()) != memType.getRank())
     return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
-    return op.emitOpError("expected value dim to match mask dim");
+    return op.emitOpError("expected valueToStore dim to match mask dim");
   return success();
 }
 
@@ -2692,8 +2692,8 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
     switch (get1DMaskFormat(compress.mask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          compress, compress.value(), compress.base(), compress.indices(),
-          false);
+          compress, compress.valueToStore(), compress.base(),
+          compress.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(compress);

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ab58fdc37ccf..6dbc63dc4a59 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1332,7 +1332,7 @@ func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi
 
 func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
-  // expected-error at +1 {{'vector.scatter' op base and value element type should match}}
+  // expected-error at +1 {{'vector.scatter' op base and valueToStore element type should match}}
   vector.scatter %base[%indices], %mask, %value
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
@@ -1350,7 +1350,7 @@ func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
 
 func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
                                    %mask: vector<16xi1>, %value: vector<16xf32>) {
-  // expected-error at +1 {{'vector.scatter' op expected value dim to match indices dim}}
+  // expected-error at +1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
   vector.scatter %base[%indices], %mask, %value
     : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
 }
@@ -1359,7 +1359,7 @@ func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32
 
 func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                                 %mask: vector<17xi1>, %value: vector<16xf32>) {
-  // expected-error at +1 {{'vector.scatter' op expected value dim to match mask dim}}
+  // expected-error at +1 {{'vector.scatter' op expected valueToStore dim to match mask dim}}
   vector.scatter %base[%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
 }
@@ -1400,7 +1400,7 @@ func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass
 
 func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error at +1 {{'vector.compressstore' op base and value element type should match}}
+  // expected-error at +1 {{'vector.compressstore' op base and valueToStore element type should match}}
   vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1408,7 +1408,7 @@ func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %v
 
 func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
-  // expected-error at +1 {{'vector.compressstore' op expected value dim to match mask dim}}
+  // expected-error at +1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}
   vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
 }
 


        


More information about the Mlir-commits mailing list