[Mlir-commits] [mlir] f250b97 - Reland "[MLIR]Extend vector.gather to support n-D result"
Diego Caballero
llvmlistbot at llvm.org
Tue Aug 23 21:24:01 PDT 2022
Author: Che-Yu Wu
Date: 2022-08-24T04:18:00Z
New Revision: f250b9722232044e6d1bf794075134efb566a284
URL: https://github.com/llvm/llvm-project/commit/f250b9722232044e6d1bf794075134efb566a284
DIFF: https://github.com/llvm/llvm-project/commit/f250b9722232044e6d1bf794075134efb566a284.diff
LOG: Reland "[MLIR]Extend vector.gather to support n-D result"
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D132507
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3cc2287599a78..aeffe601324e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1761,10 +1761,10 @@ def Vector_GatherOp :
Vector_Op<"gather">,
Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$pass_thru)>,
- Results<(outs VectorOfRank<[1]>:$result)> {
+ VectorOf<[AnyInteger, Index]>:$index_vec,
+ VectorOf<[I1]>:$mask,
+ AnyVector:$pass_thru)>,
+ Results<(outs AnyVector:$result)> {
let summary = [{
gathers elements from memory or ranked tensor into a vector as defined by an
@@ -1773,10 +1773,10 @@ def Vector_GatherOp :
let description = [{
The gather operation gathers elements from memory or ranked tensor into a
- 1-D vector as defined by a base with indices and an additional 1-D index
- vector, but only if the corresponding bit is set in a 1-D mask vector.
- Otherwise, the element is taken from a 1-D pass-through vector. Informally
- the semantics are:
+ n-D vector as defined by a base with indices and an additional n-D index
+ vector (each index is a 1-D offset into the base), but only if the
+ corresponding bit is set in an n-D mask vector. Otherwise, the element is
+ taken from an n-D pass-through vector. Informally, the semantics are:
```
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1786,15 +1786,13 @@ def Vector_GatherOp :
The gather operation can be used directly where applicable, or can be used
during progressive lowering to bring other memory operations closer to
- hardware ISA support for a gather. The semantics of the operation closely
- correspond to those of the `llvm.masked.gather`
- [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+ hardware ISA support for a gather.
Examples:
```mlir
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
- : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
%1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
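For reference, here is a complete function using the newly allowed n-D form (an illustrative sketch; the function and value names are not part of the patch, but the syntax mirrors the tests added below):

```mlir
func.func @gather_2d_sketch(%base: memref<?xf32>, %v: vector<4x8xi32>,
                            %mask: vector<4x8xi1>, %pass: vector<4x8xf32>)
    -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  // Lane (i, j) reads base[%c0 + v[i][j]] when mask[i][j] is set,
  // and pass[i][j] otherwise; all four vector operands share one shape.
  %0 = vector.gather %base[%c0][%v], %mask, %pass
      : memref<?xf32>, vector<4x8xi32>, vector<4x8xi1>, vector<4x8xf32>
        into vector<4x8xf32>
  return %0 : vector<4x8xf32>
}
```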
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index f5cd8addf9256..2f0091f99dd3b 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,14 +82,6 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
-
- SmallVector<Type> operand1DVectorTypes;
- for (Value operand : op->getOperands()) {
- auto operandNDVectorType = operand.getType().cast<VectorType>();
- auto operandTypeInfo =
- extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
- operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
- }
auto resultTypeInfo =
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index bb293f33fb2cf..c3a18ae29ab49 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,24 +91,28 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
return success();
}
-// Add an index vector component to a base pointer. This almost always succeeds
-// unless the last stride is non-unit or the memory space is not zero.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
- Location loc, Value memref, Value base,
- Value index, MemRefType memRefType,
- VectorType vType, Value &ptrs) {
+// Succeeds only if the last stride is unit and the memory space is zero.
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (failed(successStrides) || strides.back() != 1 ||
memRefType.getMemorySpaceAsInt() != 0)
return failure();
- auto pType = MemRefDescriptor(memref).getElementPtrType();
- auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
- ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
return success();
}
+// Add an index vector component to a base pointer.
+static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+ MemRefType memRefType, Value llvmMemref, Value base,
+ Value index, uint64_t vLen) {
+ assert(succeeded(isMemRefTypeSupported(memRefType)) &&
+ "unsupported memref type");
+ auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
+ auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+ return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
+}
+
// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
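A note on the extracted check: the lowering only fires when the base memref has a unit innermost stride and lives in the default memory space; otherwise the pattern now bails out up front. A sketch of the two cases (the SSA names and types are illustrative only):

```mlir
// Handled: unit innermost stride, default memory space 0.
%ok = vector.gather %a[%c0][%v], %m, %p
    : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
// Not handled: a base such as memref<?xf32, 3> (memory space 3) makes
// isMemRefTypeSupported fail, so matchAndRewrite returns failure() and
// leaves the op for other patterns.
%no = vector.gather %b[%c0][%v], %m, %p
    : memref<?xf32, 3>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```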
@@ -257,29 +261,53 @@ class VectorGatherOpConversion
LogicalResult
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = gather->getLoc();
MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
assert(memRefType && "The base should be bufferized");
+ if (failed(isMemRefTypeSupported(memRefType)))
+ return failure();
+
+ auto loc = gather->getLoc();
+
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
- // Resolve address.
- Value ptrs;
- VectorType vType = gather.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
- adaptor.getIndexVec(), memRefType, vType, ptrs)))
- return failure();
+ Value base = adaptor.getBase();
+
+ auto llvmNDVectorTy = adaptor.getIndexVec().getType();
+ // Handle the simple case of 1-D vector.
+ if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
+ auto vType = gather.getVectorType();
+ // Resolve address.
+ Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
+ adaptor.getIndexVec(),
+ /*vLen=*/vType.getDimSize(0));
+ // Replace with the gather intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ return success();
+ }
- // Replace with the gather intrinsic.
- rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
- return success();
+ auto callback = [align, memRefType, base, ptr, loc, &rewriter](
+ Type llvm1DVectorTy, ValueRange vectorOperands) {
+ // Resolve address.
+ Value ptrs = getIndexedPtrs(
+ rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
+ LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+ // Create the gather intrinsic.
+ return rewriter.create<LLVM::masked_gather>(
+ loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
+ /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
+ };
+ SmallVector<Value> vectorOperands = {
+ adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
+ return LLVM::detail::handleMultidimensionalVectors(
+ gather, vectorOperands, *getTypeConverter(), callback, rewriter);
}
};
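To make the new n-D path concrete: a `vector<2x3xf32>` gather is unrolled into two 1-D masked gathers over the innermost slices, roughly as follows (a hand-written sketch of the generated LLVM dialect IR; SSA names are illustrative, and the @gather_op_multi_dims test below pins down the exact output):

```mlir
// Slice 0: extract the 1-D pieces, compute pointers, gather, reinsert.
%i0 = llvm.extractvalue %idx[0] : !llvm.array<2 x vector<3xi32>>
%m0 = llvm.extractvalue %msk[0] : !llvm.array<2 x vector<3xi1>>
%p0 = llvm.extractvalue %pss[0] : !llvm.array<2 x vector<3xf32>>
%a0 = llvm.getelementptr %base[%i0]
    : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
%g0 = llvm.intr.masked.gather %a0, %m0, %p0 {alignment = 4 : i32}
    : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
%r0 = llvm.insertvalue %g0, %res[0] : !llvm.array<2 x vector<3xf32>>
// Slice 1 repeats the same sequence with index [1].
```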
@@ -295,19 +323,21 @@ class VectorScatterOpConversion
auto loc = scatter->getLoc();
MemRefType memRefType = scatter.getMemRefType();
+ if (failed(isMemRefTypeSupported(memRefType)))
+ return failure();
+
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
- Value ptrs;
VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
- adaptor.getIndexVec(), memRefType, vType, ptrs)))
- return failure();
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b411471983115..ef37005ddc913 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -49,11 +49,11 @@ enum class MaskFormat {
Unknown = 2,
};
-/// Helper method to classify a 1-D mask value. Currently, the method
+/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
-static MaskFormat get1DMaskFormat(Value mask) {
+static MaskFormat getMaskFormat(Value mask) {
if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
// Inspect constant dense values. We count up for bits that
// are set, count down for bits that are cleared, and bail
@@ -77,12 +77,20 @@ static MaskFormat get1DMaskFormat(Value mask) {
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
ArrayAttr masks = m.getMaskDimSizes();
- assert(masks.size() == 1);
- int64_t i = masks[0].cast<IntegerAttr>().getInt();
- int64_t u = m.getType().getDimSize(0);
- if (i >= u)
+ auto shape = m.getType().getShape();
+ bool allTrue = true;
+ bool allFalse = true;
+ for (auto pair : llvm::zip(masks, shape)) {
+ int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
+ int64_t u = std::get<1>(pair);
+ if (i < u)
+ allTrue = false;
+ if (i > 0)
+ allFalse = false;
+ }
+ if (allTrue)
return MaskFormat::AllTrue;
- if (i <= 0)
+ if (allFalse)
return MaskFormat::AllFalse;
}
return MaskFormat::Unknown;
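A worked example of the generalized classification on `vector.constant_mask` (illustrative IR): a dimension counts as "full" when its prefix size reaches the vector size, so every dimension full yields AllTrue, every prefix zero yields AllFalse, and anything in between is Unknown.

```mlir
%t = vector.constant_mask [2, 3] : vector<2x3xi1>  // 2 >= 2 and 3 >= 3 -> AllTrue
%f = vector.constant_mask [0, 0] : vector<2x3xi1>  // all prefixes zero -> AllFalse
%u = vector.constant_mask [1, 2] : vector<2x3xi1>  // partial           -> Unknown
```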
@@ -3980,7 +3988,7 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(load.getMask())) {
+ switch (getMaskFormat(load.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::LoadOp>(
load, load.getType(), load.getBase(), load.getIndices());
@@ -4031,7 +4039,7 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(store.getMask())) {
+ switch (getMaskFormat(store.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::StoreOp>(
store, store.getValueToStore(), store.getBase(), store.getIndices());
@@ -4074,9 +4082,9 @@ LogicalResult GatherOp::verify() {
return emitOpError("base and result element type should match");
if (llvm::size(getIndices()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
- if (resVType.getDimSize(0) != indVType.getDimSize(0))
+ if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
- if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+ if (resVType.getShape() != maskVType.getShape())
return emitOpError("expected result dim to match mask dim");
if (resVType != getPassThruVectorType())
return emitOpError("expected pass_thru of same type as result type");
@@ -4089,7 +4097,7 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(gather.getMask())) {
+ switch (getMaskFormat(gather.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
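With the helper generalized, this fold now also fires for n-D gathers: an all-false mask folds the op to its pass-through operand (a sketch; the @gather_op_with_zero_mask test below exercises the same shape through the LLVM lowering):

```mlir
%m = vector.constant_mask [0, 0] : vector<2x3xi1>
// With no lanes active the gather reads nothing, so GatherFolder
// replaces all uses of %g with %pass and erases the op.
%g = vector.gather %base[%c0][%v], %m, %pass
    : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
      into vector<2x3xf32>
```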
@@ -4135,7 +4143,7 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern<ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(scatter.getMask())) {
+ switch (getMaskFormat(scatter.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
@@ -4181,7 +4189,7 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(expand.getMask())) {
+ switch (getMaskFormat(expand.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::LoadOp>(
expand, expand.getType(), expand.getBase(), expand.getIndices());
@@ -4226,7 +4234,7 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
- switch (get1DMaskFormat(compress.getMask())) {
+ switch (getMaskFormat(compress.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::StoreOp>(
compress, compress.getValueToStore(), compress.getBase(),
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e99460172c98c..37827bd1c226c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1932,6 +1932,56 @@ func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2
// -----
+func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+ return %1 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_multi_dims
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+
+// -----
+
+func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+ %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+ return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_mask
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
+func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
+ %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+ return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_zero_mask
+// CHECK-SAME: (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
+// CHECK-NOT: %{{.*}} = llvm.intr.masked.gather
+// CHECK: return %[[S]] : vector<2x3xf32>
+
+// -----
+
func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
%1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d50315970d744..9a7e6f4979a39 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,7 +1305,7 @@ func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi3
func.func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error @+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
+ // expected-error @+1 {{'vector.gather' op expected result dim to match indices dim}}
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a017a8c0cfd45..ab8d1e07d99c7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -701,6 +701,14 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
return %0 : vector<16xf32>
}
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+ %c0 = arith.constant 0 : index
+ // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ return %0 : vector<2x16xf32>
+}
+
// CHECK-LABEL: @expand_and_compress
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index