[Mlir-commits] [mlir] de54bcc - Revert "[MLIR]Extend vector.gather to support n-D result"
Mehdi Amini
llvmlistbot at llvm.org
Tue Aug 23 13:27:09 PDT 2022
Author: Mehdi Amini
Date: 2022-08-23T20:26:38Z
New Revision: de54bcc54c6147d90f11f70b6b53f84e62b1e74a
URL: https://github.com/llvm/llvm-project/commit/de54bcc54c6147d90f11f70b6b53f84e62b1e74a
DIFF: https://github.com/llvm/llvm-project/commit/de54bcc54c6147d90f11f70b6b53f84e62b1e74a.diff
LOG: Revert "[MLIR]Extend vector.gather to support n-D result"
This reverts commit 0cbfd6fd1633a075dcfd1bcd8a11e1c6d2785fa8.
A test is crashing with the shared_lib config.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aeffe601324e..3cc2287599a7 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1761,10 +1761,10 @@ def Vector_GatherOp :
Vector_Op<"gather">,
Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOf<[AnyInteger, Index]>:$index_vec,
- VectorOf<[I1]>:$mask,
- AnyVector:$pass_thru)>,
- Results<(outs AnyVector:$result)> {
+ VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
+ VectorOfRankAndType<[1], [I1]>:$mask,
+ VectorOfRank<[1]>:$pass_thru)>,
+ Results<(outs VectorOfRank<[1]>:$result)> {
let summary = [{
gathers elements from memory or ranked tensor into a vector as defined by an
@@ -1773,10 +1773,10 @@ def Vector_GatherOp :
let description = [{
The gather operation gathers elements from memory or ranked tensor into a
- n-D vector as defined by a base with indices and an additional n-D index
- vector (each index is a 1-D offset on the base), but only if the
- corresponding bit is set in a n-D mask vector. Otherwise, the element is
- taken from a n-D pass-through vector. Informally the semantics are:
+ 1-D vector as defined by a base with indices and an additional 1-D index
+ vector, but only if the corresponding bit is set in a 1-D mask vector.
+ Otherwise, the element is taken from a 1-D pass-through vector. Informally
+ the semantics are:
```
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1786,13 +1786,15 @@ def Vector_GatherOp :
The gather operation can be used directly where applicable, or can be used
during progressively lowering to bring other memory operations closer to
- hardware ISA support for a gather.
+ hardware ISA support for a gather. The semantics of the operation closely
+ correspond to those of the `llvm.masked.gather`
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
Examples:
```mlir
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
- : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
%1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 2f0091f99dd3..f5cd8addf925 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,6 +82,14 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
+
+ SmallVector<Type> operand1DVectorTypes;
+ for (Value operand : op->getOperands()) {
+ auto operandNDVectorType = operand.getType().cast<VectorType>();
+ auto operandTypeInfo =
+ extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
+ operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
+ }
auto resultTypeInfo =
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9c483067610..bb293f33fb2c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,28 +91,24 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
return success();
}
-// Check if the last stride is non-unit or the memory space is not zero.
-static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
+// Add an index vector component to a base pointer. This almost always succeeds
+// unless the last stride is non-unit or the memory space is not zero.
+static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+ Location loc, Value memref, Value base,
+ Value index, MemRefType memRefType,
+ VectorType vType, Value &ptrs) {
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (failed(successStrides) || strides.back() != 1 ||
memRefType.getMemorySpaceAsInt() != 0)
return failure();
+ auto pType = MemRefDescriptor(memref).getElementPtrType();
+ auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
+ ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
return success();
}
-// Add an index vector component to a base pointer.
-static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
- MemRefType memRefType, Value llvmMemref, Value base,
- Value index, uint64_t vLen) {
- assert(succeeded(isMemRefTypeSupported(memRefType)) &&
- "unsupported memref type");
- auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
- auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
- return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
-}
-
// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
@@ -261,53 +257,29 @@ class VectorGatherOpConversion
LogicalResult
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto loc = gather->getLoc();
MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
assert(memRefType && "The base should be bufferized");
- if (failed(isMemRefTypeSupported(memRefType)))
- return failure();
-
- auto loc = gather->getLoc();
-
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
+ Value ptrs;
+ VectorType vType = gather.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- Value base = adaptor.getBase();
-
- auto llvmNDVectorTy = adaptor.getIndexVec().getType();
- // Handle the simple case of 1-D vector.
- if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
- auto vType = gather.getVectorType();
- // Resolve address.
- Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
- adaptor.getIndexVec(),
- /*vLen=*/vType.getDimSize(0));
- // Replace with the gather intrinsic.
- rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
- return success();
- }
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), memRefType, vType, ptrs)))
+ return failure();
- auto callback = [align, memRefType, base, ptr, loc, &rewriter](
- Type llvm1DVectorTy, ValueRange vectorOperands) {
- // Resolve address.
- Value ptrs = getIndexedPtrs(
- rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
- LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
- // Create the gather intrinsic.
- return rewriter.create<LLVM::masked_gather>(
- loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
- /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
- };
- ValueRange vectorOperands = {adaptor.getIndexVec(), adaptor.getMask(),
- adaptor.getPassThru()};
- return LLVM::detail::handleMultidimensionalVectors(
- gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+ // Replace with the gather intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ return success();
}
};
@@ -323,21 +295,19 @@ class VectorScatterOpConversion
auto loc = scatter->getLoc();
MemRefType memRefType = scatter.getMemRefType();
- if (failed(isMemRefTypeSupported(memRefType)))
- return failure();
-
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
+ Value ptrs;
VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- Value ptrs =
- getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
- adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+ adaptor.getIndexVec(), memRefType, vType, ptrs)))
+ return failure();
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ef37005ddc91..b41147198311 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -49,11 +49,11 @@ enum class MaskFormat {
Unknown = 2,
};
-/// Helper method to classify a mask value. Currently, the method
+/// Helper method to classify a 1-D mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
-static MaskFormat getMaskFormat(Value mask) {
+static MaskFormat get1DMaskFormat(Value mask) {
if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
// Inspect constant dense values. We count up for bits that
// are set, count down for bits that are cleared, and bail
@@ -77,20 +77,12 @@ static MaskFormat getMaskFormat(Value mask) {
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
ArrayAttr masks = m.getMaskDimSizes();
- auto shape = m.getType().getShape();
- bool allTrue = true;
- bool allFalse = true;
- for (auto pair : llvm::zip(masks, shape)) {
- int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
- int64_t u = std::get<1>(pair);
- if (i < u)
- allTrue = false;
- if (i > 0)
- allFalse = false;
- }
- if (allTrue)
+ assert(masks.size() == 1);
+ int64_t i = masks[0].cast<IntegerAttr>().getInt();
+ int64_t u = m.getType().getDimSize(0);
+ if (i >= u)
return MaskFormat::AllTrue;
- if (allFalse)
+ if (i <= 0)
return MaskFormat::AllFalse;
}
return MaskFormat::Unknown;
@@ -3988,7 +3980,7 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
- switch (getMaskFormat(load.getMask())) {
+ switch (get1DMaskFormat(load.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::LoadOp>(
load, load.getType(), load.getBase(), load.getIndices());
@@ -4039,7 +4031,7 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
- switch (getMaskFormat(store.getMask())) {
+ switch (get1DMaskFormat(store.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::StoreOp>(
store, store.getValueToStore(), store.getBase(), store.getIndices());
@@ -4082,9 +4074,9 @@ LogicalResult GatherOp::verify() {
return emitOpError("base and result element type should match");
if (llvm::size(getIndices()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
- if (resVType.getShape() != indVType.getShape())
+ if (resVType.getDimSize(0) != indVType.getDimSize(0))
return emitOpError("expected result dim to match indices dim");
- if (resVType.getShape() != maskVType.getShape())
+ if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected result dim to match mask dim");
if (resVType != getPassThruVectorType())
return emitOpError("expected pass_thru of same type as result type");
@@ -4097,7 +4089,7 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
- switch (getMaskFormat(gather.getMask())) {
+ switch (get1DMaskFormat(gather.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
@@ -4143,7 +4135,7 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern<ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
- switch (getMaskFormat(scatter.getMask())) {
+ switch (get1DMaskFormat(scatter.getMask())) {
case MaskFormat::AllTrue:
return failure(); // no unmasked equivalent
case MaskFormat::AllFalse:
@@ -4189,7 +4181,7 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
- switch (getMaskFormat(expand.getMask())) {
+ switch (get1DMaskFormat(expand.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::LoadOp>(
expand, expand.getType(), expand.getBase(), expand.getIndices());
@@ -4234,7 +4226,7 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
- switch (getMaskFormat(compress.getMask())) {
+ switch (get1DMaskFormat(compress.getMask())) {
case MaskFormat::AllTrue:
rewriter.replaceOpWithNewOp<vector::StoreOp>(
compress, compress.getValueToStore(), compress.getBase(),
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 37827bd1c226..e99460172c98 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1932,56 +1932,6 @@ func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2
// -----
-func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_multi_dims
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
- %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %2 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_with_mask
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-
-// -----
-
-func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
- %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %2 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_with_zero_mask
-// CHECK-SAME: (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
-// CHECK-NOT: %{{.*}} = llvm.intr.masked.gather
-// CHECK: return %[[S]] : vector<2x3xf32>
-
-// -----
-
func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
%1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9a7e6f4979a3..d50315970d74 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,7 +1305,7 @@ func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi3
func.func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
+ // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index ab8d1e07d99c..a017a8c0cfd4 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -701,14 +701,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
return %0 : vector<16xf32>
}
-// CHECK-LABEL: @gather_multi_dims
-func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
- %c0 = arith.constant 0 : index
- // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
- %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
- return %0 : vector<2x16xf32>
-}
-
// CHECK-LABEL: @expand_and_compress
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list