[Mlir-commits] [mlir] 656674a - [mlir][Vector] Align gather/scatter/expand/compress API
Diego Caballero
llvmlistbot at llvm.org
Fri Feb 12 10:53:30 PST 2021
Author: Diego Caballero
Date: 2021-02-12T20:48:38+02:00
New Revision: 656674a7c43305dbf0b4234f1086cc88f7f4e4fb
URL: https://github.com/llvm/llvm-project/commit/656674a7c43305dbf0b4234f1086cc88f7f4e4fb
DIFF: https://github.com/llvm/llvm-project/commit/656674a7c43305dbf0b4234f1086cc88f7f4e4fb.diff
LOG: [mlir][Vector] Align gather/scatter/expand/compress API
Align the vector gather/scatter/expand/compress API with
the vector load/store/maskedload/maskedstore API.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D96396
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 557446bddb00..3da4cd25dc62 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1619,7 +1619,7 @@ def Vector_GatherOp :
VectorType getPassThruVectorType() {
return pass_thru().getType().cast<VectorType>();
}
- VectorType getResultVectorType() {
+ VectorType getVectorType() {
return result().getType().cast<VectorType>();
}
}];
@@ -1633,7 +1633,7 @@ def Vector_ScatterOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
VectorOfRankAndType<[1], [AnyInteger]>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$value)> {
+ VectorOfRank<[1]>:$valueToStore)> {
let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
@@ -1675,12 +1675,13 @@ def Vector_ScatterOp :
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
- VectorType getValueVectorType() {
- return value().getType().cast<VectorType>();
+ VectorType getVectorType() {
+ return valueToStore().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
- "type($base) `,` type($indices) `,` type($mask) `,` type($value)";
+ let assemblyFormat =
+ "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
+ "type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
}
@@ -1730,7 +1731,7 @@ def Vector_ExpandLoadOp :
VectorType getPassThruVectorType() {
return pass_thru().getType().cast<VectorType>();
}
- VectorType getResultVectorType() {
+ VectorType getVectorType() {
return result().getType().cast<VectorType>();
}
}];
@@ -1744,7 +1745,7 @@ def Vector_CompressStoreOp :
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$value)> {
+ VectorOfRank<[1]>:$valueToStore)> {
let summary = "writes elements selectively from a vector as defined by a mask";
@@ -1781,12 +1782,13 @@ def Vector_CompressStoreOp :
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
}
- VectorType getValueVectorType() {
- return value().getType().cast<VectorType>();
+ VectorType getVectorType() {
+ return valueToStore().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
- "type($base) `,` type($mask) `,` type($value)";
+ let assemblyFormat =
+ "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
+ "type($base) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3393bb702a78..b6106560024c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -446,7 +446,7 @@ class VectorGatherOpConversion
return failure();
// Get index ptrs.
- VectorType vType = gather.getResultVectorType();
+ VectorType vType = gather.getVectorType();
Type iType = gather.getIndicesVectorType().getElementType();
Value ptrs;
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
@@ -480,7 +480,7 @@ class VectorScatterOpConversion
return failure();
// Get index ptrs.
- VectorType vType = scatter.getValueVectorType();
+ VectorType vType = scatter.getVectorType();
Type iType = scatter.getIndicesVectorType().getElementType();
Value ptrs;
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
@@ -489,7 +489,7 @@ class VectorScatterOpConversion
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
- scatter, adaptor.value(), ptrs, adaptor.mask(),
+ scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
rewriter.getI32IntegerAttr(align));
return success();
}
@@ -509,7 +509,7 @@ class VectorExpandLoadOpConversion
MemRefType memRefType = expand.getMemRefType();
// Resolve address.
- auto vtype = typeConverter->convertType(expand.getResultVectorType());
+ auto vtype = typeConverter->convertType(expand.getVectorType());
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
adaptor.indices(), rewriter);
@@ -537,7 +537,7 @@ class VectorCompressStoreOpConversion
adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
- compress, adaptor.value(), ptr, adaptor.mask());
+ compress, adaptor.valueToStore(), ptr, adaptor.mask());
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a56b49a315d8..2c65a00092ea 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2534,7 +2534,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(
static LogicalResult verify(GatherOp op) {
VectorType indicesVType = op.getIndicesVectorType();
VectorType maskVType = op.getMaskVectorType();
- VectorType resVType = op.getResultVectorType();
+ VectorType resVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (resVType.getElementType() != memType.getElementType())
@@ -2580,15 +2580,15 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
static LogicalResult verify(ScatterOp op) {
VectorType indicesVType = op.getIndicesVectorType();
VectorType maskVType = op.getMaskVectorType();
- VectorType valueVType = op.getValueVectorType();
+ VectorType valueVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and value element type should match");
+ return op.emitOpError("base and valueToStore element type should match");
if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
- return op.emitOpError("expected value dim to match indices dim");
+ return op.emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected value dim to match mask dim");
+ return op.emitOpError("expected valueToStore dim to match mask dim");
return success();
}
@@ -2624,7 +2624,7 @@ void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
static LogicalResult verify(ExpandLoadOp op) {
VectorType maskVType = op.getMaskVectorType();
VectorType passVType = op.getPassThruVectorType();
- VectorType resVType = op.getResultVectorType();
+ VectorType resVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (resVType.getElementType() != memType.getElementType())
@@ -2671,15 +2671,15 @@ void ExpandLoadOp::getCanonicalizationPatterns(
static LogicalResult verify(CompressStoreOp op) {
VectorType maskVType = op.getMaskVectorType();
- VectorType valueVType = op.getValueVectorType();
+ VectorType valueVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
- return op.emitOpError("base and value element type should match");
+ return op.emitOpError("base and valueToStore element type should match");
if (llvm::size(op.indices()) != memType.getRank())
return op.emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
- return op.emitOpError("expected value dim to match mask dim");
+ return op.emitOpError("expected valueToStore dim to match mask dim");
return success();
}
@@ -2692,8 +2692,8 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
switch (get1DMaskFormat(compress.mask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- compress, compress.value(), compress.base(), compress.indices(),
- false);
+ compress, compress.valueToStore(), compress.base(),
+ compress.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(compress);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ab58fdc37ccf..6dbc63dc4a59 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1332,7 +1332,7 @@ func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi
func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
- // expected-error @+1 {{'vector.scatter' op base and value element type should match}}
+ // expected-error @+1 {{'vector.scatter' op base and valueToStore element type should match}}
vector.scatter %base[%indices], %mask, %value
: memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
@@ -1350,7 +1350,7 @@ func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
- // expected-error @+1 {{'vector.scatter' op expected value dim to match indices dim}}
+ // expected-error @+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
vector.scatter %base[%indices], %mask, %value
: memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
}
@@ -1359,7 +1359,7 @@ func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32
func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<17xi1>, %value: vector<16xf32>) {
- // expected-error @+1 {{'vector.scatter' op expected value dim to match mask dim}}
+ // expected-error @+1 {{'vector.scatter' op expected valueToStore dim to match mask dim}}
vector.scatter %base[%indices], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
}
@@ -1400,7 +1400,7 @@ func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
- // expected-error @+1 {{'vector.compressstore' op base and value element type should match}}
+ // expected-error @+1 {{'vector.compressstore' op base and valueToStore element type should match}}
vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
@@ -1408,7 +1408,7 @@ func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %v
func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
%c0 = constant 0 : index
- // expected-error @+1 {{'vector.compressstore' op expected value dim to match mask dim}}
+ // expected-error @+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}
vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
}
More information about the Mlir-commits
mailing list