[Mlir-commits] [mlir] df5ccf5 - [mlir][vector] add higher dimensional support to gather/scatter
Aart Bik
llvmlistbot at llvm.org
Fri Feb 26 14:20:34 PST 2021
Author: Aart Bik
Date: 2021-02-26T14:20:19-08:00
New Revision: df5ccf5a94ad04231e8426dd0d02689a0717453b
URL: https://github.com/llvm/llvm-project/commit/df5ccf5a94ad04231e8426dd0d02689a0717453b
DIFF: https://github.com/llvm/llvm-project/commit/df5ccf5a94ad04231e8426dd0d02689a0717453b.diff
LOG: [mlir][vector] add higher dimensional support to gather/scatter
Similar to maskedload/maskedstore and expandload/compressstore, the gather
and scatter operations now allow higher-dimensional uses. Note that, to
support the mixed-type indices, the new syntax is:
vector.gather %base [%i,%j] [%kvector] ....
The first client of this generalization is the sparse compiler, which also
needs to define gathers and scatters on dense operands of higher dimensions.
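For illustration, here is a small sketch of the 1-D and 2-D forms under the
new syntax (the %c0, %i, %j, %idxs, %mask, %pass_thru, and %value names and
the memref shapes are placeholders modeled on the op documentation updated
below, not part of this patch):

  // 1-D case: the scalar index %c0 picks the starting element, and %idxs
  // holds the per-lane offsets relative to it.
  %0 = vector.gather %base[%c0][%idxs], %mask, %pass_thru
      : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

  // 2-D case: the scalar indices %i, %j pick the base element, and %idxs
  // again holds per-lane offsets relative to that element.
  %1 = vector.gather %base[%i, %j][%idxs], %mask, %pass_thru
      : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

  // Scatter uses the same mixed indexing to write %value back.
  vector.scatter %base[%i, %j][%idxs], %mask, %value
      : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>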
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D97422
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Linalg/sparse_vector.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index cf18cd89e170..83b221fac013 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1574,11 +1574,14 @@ def Vector_MaskedLoadOp :
closely correspond to those of the `llvm.masked.load`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics).
- Example:
+ Examples:
```mlir
%0 = vector.maskedload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+
+ %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
+ : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1625,11 +1628,14 @@ def Vector_MaskedStoreOp :
closely correspond to those of the `llvm.masked.store`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics).
- Example:
+ Examples:
```mlir
vector.maskedstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
+
+ vector.maskedstore %base[%i, %j], %mask, %value
+ : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1652,7 +1658,8 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
- VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+ Variadic<Index>:$indices,
+ VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1661,9 +1668,10 @@ def Vector_GatherOp :
let description = [{
The gather operation gathers elements from memory into a 1-D vector as
- defined by a base and a 1-D index vector, but only if the corresponding
- bit is set in a 1-D mask vector. Otherwise, the element is taken from a
- 1-D pass-through vector. Informally the semantics are:
+ defined by a base with indices and an additional 1-D index vector, but
+ only if the corresponding bit is set in a 1-D mask vector. Otherwise, the
+ element is taken from a 1-D pass-through vector. Informally the semantics
+ are:
```
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1677,19 +1685,22 @@ def Vector_GatherOp :
correspond to those of the `llvm.masked.gather`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
- Example:
+ Examples:
```mlir
- %g = vector.gather %base[%indices], %mask, %pass_thru
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+ %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
+ : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
- VectorType getIndicesVectorType() {
- return indices().getType().cast<VectorType>();
+ VectorType getIndexVectorType() {
+ return index_vec().getType().cast<VectorType>();
}
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
@@ -1701,15 +1712,19 @@ def Vector_GatherOp :
return result().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
- "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+ let assemblyFormat =
+ "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
+ "type($index_vec) `,` type($mask) `,` type($pass_thru) "
+ "`into` type($result)";
let hasCanonicalizer = 1;
}
def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
- VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+ Variadic<Index>:$indices,
+ VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$valueToStore)> {
@@ -1717,9 +1732,9 @@ def Vector_ScatterOp :
let description = [{
The scatter operation scatters elements from a 1-D vector into memory as
- defined by a base and a 1-D index vector, but only if the corresponding
- bit in a 1-D mask vector is set. Otherwise, no action is taken for that
- element. Informally the semantics are:
+ defined by a base with indices and an additional 1-D index vector, but
+ only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+ action is taken for that element. Informally the semantics are:
```
if (mask[0]) base[index[0]] = value[0]
if (mask[1]) base[index[1]] = value[1]
@@ -1736,19 +1751,22 @@ def Vector_ScatterOp :
correspond to those of the `llvm.masked.scatter`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
- Example:
+ Examples:
```mlir
- vector.scatter %base[%indices], %mask, %value
+ vector.scatter %base[%c0][%v], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+
+ vector.scatter %base[%i, %j][%v], %mask, %value
+ : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return base().getType().cast<MemRefType>();
}
- VectorType getIndicesVectorType() {
- return indices().getType().cast<VectorType>();
+ VectorType getIndexVectorType() {
+ return index_vec().getType().cast<VectorType>();
}
VectorType getMaskVectorType() {
return mask().getType().cast<VectorType>();
@@ -1758,8 +1776,9 @@ def Vector_ScatterOp :
}
}];
let assemblyFormat =
- "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
- "type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)";
+ "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
+ "type($index_vec) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
}
@@ -1792,11 +1811,14 @@ def Vector_ExpandLoadOp :
correspond to those of the `llvm.masked.expandload`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
- Example:
+ Examples:
```mlir
%0 = vector.expandload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+
+ %1 = vector.expandload %base[%i, %j], %mask, %pass_thru
+ : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1846,11 +1868,14 @@ def Vector_CompressStoreOp :
correspond to those of the `llvm.masked.compressstore`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
- Example:
+ Examples:
```mlir
vector.compressstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
+
+ vector.compressstore %base[%i, %j], %mask, %value
+ : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 168665e23111..d567e065479d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -178,34 +178,21 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
return success();
}
-// Helper that returns the base address of a memref.
-static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
- Value memref, MemRefType memRefType, Value &base) {
- // Inspect stride and offset structure.
- //
- // TODO: flat memory only for now, generalize
- //
+// Add an index vector component to a base pointer. This almost always succeeds
+// unless the last stride is non-unit or the memory space is not zero.
+static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+ Location loc, Value memref, Value base,
+ Value index, MemRefType memRefType,
+ VectorType vType, Value &ptrs) {
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
- if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
- offset != 0 || memRefType.getMemorySpace() != 0)
- return failure();
- base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
- return success();
-}
-
-// Helper that returns vector of pointers given a memref base with index vector.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
- Location loc, Value memref, Value indices,
- MemRefType memRefType, VectorType vType,
- Type iType, Value &ptrs) {
- Value base;
- if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+ if (failed(successStrides) || strides.back() != 1 ||
+ memRefType.getMemorySpace() != 0)
return failure();
auto pType = MemRefDescriptor(memref).getElementPtrType();
auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
- ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
+ ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
return success();
}
@@ -435,19 +422,20 @@ class VectorGatherOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = gather->getLoc();
auto adaptor = vector::GatherOpAdaptor(operands);
+ MemRefType memRefType = gather.getMemRefType();
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
- align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
- // Get index ptrs.
- VectorType vType = gather.getVectorType();
- Type iType = gather.getIndicesVectorType().getElementType();
+ // Resolve address.
Value ptrs;
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
- gather.getMemRefType(), vType, iType, ptrs)))
+ VectorType vType = gather.getVectorType();
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
+ adaptor.index_vec(), memRefType, vType, ptrs)))
return failure();
// Replace with the gather intrinsic.
@@ -469,19 +457,20 @@ class VectorScatterOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = scatter->getLoc();
auto adaptor = vector::ScatterOpAdaptor(operands);
+ MemRefType memRefType = scatter.getMemRefType();
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
- align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
- // Get index ptrs.
- VectorType vType = scatter.getVectorType();
- Type iType = scatter.getIndicesVectorType().getElementType();
+ // Resolve address.
Value ptrs;
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
- scatter.getMemRefType(), vType, iType, ptrs)))
+ VectorType vType = scatter.getVectorType();
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
+ adaptor.index_vec(), memRefType, vType, ptrs)))
return failure();
// Replace with the scatter intrinsic.
@@ -507,8 +496,8 @@ class VectorExpandLoadOpConversion
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
- Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
@@ -530,8 +519,8 @@ class VectorCompressStoreOpConversion
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
- Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
- adaptor.indices(), rewriter);
+ Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.valueToStore(), ptr, adaptor.mask());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index fbb45c6b86d8..efaaed6f2739 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -652,9 +652,13 @@ static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
Location loc = ptr.getLoc();
VectorType vtp = vectorType(codegen, ptr);
Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
- if (args.back().getType().isa<VectorType>())
- return rewriter.create<vector::GatherOp>(loc, vtp, ptr, args.back(),
- codegen.curVecMask, pass);
+ if (args.back().getType().isa<VectorType>()) {
+ SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
+ Value indexVec = args.back();
+ scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
+ return rewriter.create<vector::GatherOp>(
+ loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
+ }
return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
codegen.curVecMask, pass);
}
@@ -663,12 +667,16 @@ static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
Value rhs, Value ptr, ArrayRef<Value> args) {
Location loc = ptr.getLoc();
- if (args.back().getType().isa<VectorType>())
- rewriter.create<vector::ScatterOp>(loc, ptr, args.back(),
+ if (args.back().getType().isa<VectorType>()) {
+ SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
+ Value indexVec = args.back();
+ scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
+ rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
codegen.curVecMask, rhs);
- else
- rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
- rhs);
+ return;
+ }
+ rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
+ rhs);
}
/// Generates a vectorized invariant. Here we rely on subsequent loop
@@ -985,11 +993,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
types.push_back(indexType);
+ assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
+ "type mismatch for sparse index");
operands.push_back(codegen.pidxs[tensor][idx]);
}
}
if (needsUniv) {
types.push_back(indexType);
+ assert(codegen.loops[idx].getType().isa<IndexType>() &&
+ "type_mismatch for universal index");
operands.push_back(codegen.loops[idx]);
}
Location loc = op.getLoc();
@@ -1160,6 +1172,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
return;
}
+ assert(codegen.curVecLength == 1);
// Construct iteration lattices for current loop index, with L0 at top.
// Then emit initialization code for the loop sequence at this level.
@@ -1239,6 +1252,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
}
genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
codegen.loops[idx] = Value();
+ codegen.curVecLength = 1;
}
namespace {
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 01f9c9805675..4702626c3e8c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2771,14 +2771,16 @@ void MaskedStoreOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
static LogicalResult verify(GatherOp op) {
- VectorType indicesVType = op.getIndicesVectorType();
+ VectorType indVType = op.getIndexVectorType();
VectorType maskVType = op.getMaskVectorType();
VectorType resVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (resVType.getElementType() != memType.getElementType())
return op.emitOpError("base and result element type should match");
- if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
+ if (llvm::size(op.indices()) != memType.getRank())
+ return op.emitOpError("requires ") << memType.getRank() << " indices";
+ if (resVType.getDimSize(0) != indVType.getDimSize(0))
return op.emitOpError("expected result dim to match indices dim");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
@@ -2817,14 +2819,16 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static LogicalResult verify(ScatterOp op) {
- VectorType indicesVType = op.getIndicesVectorType();
+ VectorType indVType = op.getIndexVectorType();
VectorType maskVType = op.getMaskVectorType();
VectorType valueVType = op.getVectorType();
MemRefType memType = op.getMemRefType();
if (valueVType.getElementType() != memType.getElementType())
return op.emitOpError("base and valueToStore element type should match");
- if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
+ if (llvm::size(op.indices()) != memType.getRank())
+ return op.emitOpError("requires ") << memType.getRank() << " indices";
+ if (valueVType.getDimSize(0) != indVType.getDimSize(0))
return op.emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected valueToStore dim to match mask dim");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e5c1adcee0ed..39826a652092 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1302,28 +1302,57 @@ func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<
// -----
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
- %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
- return %0 : vector<3xf32>
+ %0 = constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
}
// CHECK-LABEL: func @gather_op
-// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: return %[[G]] : vector<3xf32>
// -----
+func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
+ %0 = constant 3 : index
+ %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @gather_2d_op
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr<f32>, vector<4xi32>) -> !llvm.vec<4 x ptr<f32>>
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr<f32>>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+// CHECK: return %[[G]] : vector<4xf32>
+
+// -----
+
func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
- vector.scatter %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ %0 = constant 0: index
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}
// CHECK-LABEL: func @scatter_op
-// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
// -----
+func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
+ %0 = constant 3 : index
+ vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_2d_op
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr<f32>, vector<4xi32>) -> !llvm.vec<4 x ptr<f32>>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr<f32>>
+
+// -----
+
func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
%c0 = constant 0: index
%0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
index 644b231a3cb9..1ecfe27789af 100644
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -128,9 +128,9 @@ func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) ->
// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2: vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2: }
// CHECK-VEC2: return
//
@@ -159,9 +159,9 @@ func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024
// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2: vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2: }
// CHECK-VEC2: return
//
@@ -324,9 +324,9 @@ func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tenso
// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2: vector.scatter %{{.*}}[%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2: }
// CHECK-VEC2: }
// CHECK-VEC2: return
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d427cb952f09..45f5d7d2b901 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -718,7 +718,7 @@ func @dead_load(%base: memref<?xf32>, %indices: vector<16xi32>,
%c0 = constant 0 : index
%0 = vector.maskedload %base[%c0], %mask, %passthru :
memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- %1 = vector.gather %base[%indices], %mask, %passthru :
+ %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%2 = vector.expandload %base[%c0], %mask, %passthru :
memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6dbc63dc4a59..3b64a6e57e66 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1287,17 +1287,29 @@ func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %v
func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.gather' op base and result element type should match}}
- %0 = vector.gather %base[%indices], %mask, %pass_thru
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
+func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+ %c0 = constant 0 : index
+ // expected-error @+1 {{'vector.gather' op requires 2 indices}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ : memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf64>
+}
+
+// -----
+
func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
- %0 = vector.gather %base[%indices], %mask, %pass_thru
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
}
@@ -1305,8 +1317,9 @@ func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.gather' op expected result dim to match indices dim}}
- %0 = vector.gather %base[%indices], %mask, %pass_thru
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1314,8 +1327,9 @@ func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>
func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.gather' op expected result dim to match mask dim}}
- %0 = vector.gather %base[%indices], %mask, %pass_thru
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1323,8 +1337,9 @@ func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.gather' op expected pass_thru of same type as result type}}
- %0 = vector.gather %base[%indices], %mask, %pass_thru
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32>
}
@@ -1332,17 +1347,29 @@ func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi
func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.scatter' op base and valueToStore element type should match}}
- vector.scatter %base[%indices], %mask, %value
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
// -----
+func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf64>) {
+ %c0 = constant 0 : index
+ // expected-error @+1 {{'vector.scatter' op requires 2 indices}}
+ vector.scatter %base[%c0][%indices], %mask, %value
+ : memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64>
+}
+
+// -----
+
func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<2x16xf32>) {
- // expected-error @+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}}
- vector.scatter %base[%indices], %mask, %value
+ %c0 = constant 0 : index
+ // expected-error @+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}}
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
}
@@ -1350,8 +1377,9 @@ func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
- vector.scatter %base[%indices], %mask, %value
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
}
@@ -1359,8 +1387,9 @@ func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32
func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<17xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.scatter' op expected valueToStore dim to match mask dim}}
- vector.scatter %base[%indices], %mask, %value
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index eeca040844e4..c0a5af48c821 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -541,25 +541,56 @@ func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthr
return
}
+// CHECK-LABEL: @masked_load_and_store2d
+func @masked_load_and_store2d(%base: memref<?x?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.maskedload %base[%c0, %c0], %mask, %passthru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.maskedstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+ vector.maskedstore %base[%c0, %c0], %mask, %0 : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
// CHECK-LABEL: @gather_and_scatter
-func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
- // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- %0 = vector.gather %base[%indices], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
- vector.scatter %base[%indices], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+func @gather_and_scatter(%base: memref<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// CHECK-LABEL: @gather_and_scatter2d
+func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}, %{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ vector.scatter %base[%c0, %c0][%v], %mask, %0 : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: @expand_and_compress
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = constant 0 : index
- // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ // CHECK: vector.compressstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
+// CHECK-LABEL: @expand_and_compress2d
+func @expand_and_compress2d(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
+ // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0, %c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.compressstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+ vector.compressstore %base[%c0, %c0], %mask, %0 : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
// CHECK-LABEL: @extract_insert_map
func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
%id0 : index, %id1 : index) -> (vector<32xf32>, vector<16x32xf32>) {
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 5c55cc5b6f0e..ab46a7863a55 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -85,12 +85,14 @@ func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: %[[C:.*]] = constant 0 : index
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- %ld = vector.gather %base[%indices], %mask, %pass_thru
+ %ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
@@ -101,8 +103,9 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
- %ld = vector.gather %base[%indices], %mask, %pass_thru
+ %ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
@@ -111,12 +114,14 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+// CHECK-NEXT: %[[C:.*]] = constant 0 : index
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- vector.scatter %base[%indices], %mask, %value
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
@@ -127,9 +132,10 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: return
func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
%mask = vector.constant_mask [0] : vector<16xi1>
- vector.scatter %base[%indices], %mask, %value
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
index 95df5aea06e4..0e5ea8557fbc 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
@@ -5,7 +5,8 @@
func @gather8(%base: memref<?xf32>, %indices: vector<8xi32>,
%mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> {
- %g = vector.gather %base[%indices], %mask, %pass_thru
+ %c0 = constant 0: index
+ %g = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
return %g : vector<8xf32>
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
index 0666cc852c2a..f4316faee5c7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
@@ -6,7 +6,8 @@
func @scatter8(%base: memref<?xf32>,
%indices: vector<8xi32>,
%mask: vector<8xi1>, %value: vector<8xf32>) {
- vector.scatter %base[%indices], %mask, %value
+ %c0 = constant 0: index
+ vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
return
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
index 7940e8c68b1a..4f399886dd91 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
@@ -64,7 +64,7 @@ func @spmv8x8(%AVAL: memref<8xvector<4xf32>>,
scf.for %i = %c0 to %cn step %c1 {
%aval = load %AVAL[%i] : memref<8xvector<4xf32>>
%aidx = load %AIDX[%i] : memref<8xvector<4xi32>>
- %0 = vector.gather %X[%aidx], %mask, %pass
+ %0 = vector.gather %X[%c0][%aidx], %mask, %pass
: memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32
store %1, %B[%i] : memref<?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
index 31f288e0f6c5..bda77483806e 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
@@ -57,7 +57,7 @@ func @spmv8x8(%AVAL: memref<4xvector<8xf32>>,
%b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) {
%aval = load %AVAL[%k] : memref<4xvector<8xf32>>
%aidx = load %AIDX[%k] : memref<4xvector<8xi32>>
- %0 = vector.gather %X[%aidx], %mask, %pass
+ %0 = vector.gather %X[%c0][%aidx], %mask, %pass
: memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
%b_new = vector.fma %aval, %0, %b_iter : vector<8xf32>
scf.yield %b_new : vector<8xf32>