[Mlir-commits] [mlir] de54bcc - Revert "[MLIR]Extend vector.gather to support n-D result"

Mehdi Amini llvmlistbot at llvm.org
Tue Aug 23 13:27:09 PDT 2022


Author: Mehdi Amini
Date: 2022-08-23T20:26:38Z
New Revision: de54bcc54c6147d90f11f70b6b53f84e62b1e74a

URL: https://github.com/llvm/llvm-project/commit/de54bcc54c6147d90f11f70b6b53f84e62b1e74a
DIFF: https://github.com/llvm/llvm-project/commit/de54bcc54c6147d90f11f70b6b53f84e62b1e74a.diff

LOG: Revert "[MLIR]Extend vector.gather to support n-D result"

This reverts commit 0cbfd6fd1633a075dcfd1bcd8a11e1c6d2785fa8.

A test is crashing with the shared_lib config.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aeffe601324e..3cc2287599a7 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1761,10 +1761,10 @@ def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[AnyInteger, Index]>:$index_vec,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$pass_thru)>,
-    Results<(outs AnyVector:$result)> {
+               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
 
   let summary = [{
     gathers elements from memory or ranked tensor into a vector as defined by an
@@ -1773,10 +1773,10 @@ def Vector_GatherOp :
 
   let description = [{
     The gather operation gathers elements from memory or ranked tensor into a
-    n-D vector as defined by a base with indices and an additional n-D index
-    vector (each index is a 1-D offset on the base), but only if the
-    corresponding bit is set in a n-D mask vector. Otherwise, the element is
-    taken from a n-D pass-through vector. Informally the semantics are:
+    1-D vector as defined by a base with indices and an additional 1-D index
+    vector, but only if the corresponding bit is set in a 1-D mask vector.
+    Otherwise, the element is taken from a 1-D pass-through vector. Informally
+    the semantics are:
     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1786,13 +1786,15 @@ def Vector_GatherOp :
 
     The gather operation can be used directly where applicable, or can be used
     during progressively lowering to bring other memory operations closer to
-    hardware ISA support for a gather.
+    hardware ISA support for a gather. The semantics of the operation closely
+    correspond to those of the `llvm.masked.gather`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
 
     Examples:
 
     ```mlir
     %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
-       : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 
     %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

diff  --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 2f0091f99dd3..f5cd8addf925 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,6 +82,14 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
+
+  SmallVector<Type> operand1DVectorTypes;
+  for (Value operand : op->getOperands()) {
+    auto operandNDVectorType = operand.getType().cast<VectorType>();
+    auto operandTypeInfo =
+        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
+    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
+  }
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9c483067610..bb293f33fb2c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,28 +91,24 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Check if the last stride is non-unit or the memory space is not zero.
-static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
+// Add an index vector component to a base pointer. This almost always succeeds
+// unless the last stride is non-unit or the memory space is not zero.
+static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+                                    Location loc, Value memref, Value base,
+                                    Value index, MemRefType memRefType,
+                                    VectorType vType, Value &ptrs) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (failed(successStrides) || strides.back() != 1 ||
       memRefType.getMemorySpaceAsInt() != 0)
     return failure();
+  auto pType = MemRefDescriptor(memref).getElementPtrType();
+  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
+  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
-// Add an index vector component to a base pointer.
-static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
-                            MemRefType memRefType, Value llvmMemref, Value base,
-                            Value index, uint64_t vLen) {
-  assert(succeeded(isMemRefTypeSupported(memRefType)) &&
-         "unsupported memref type");
-  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
-  return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
-}
-
 // Casts a strided element pointer to a vector pointer.  The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
@@ -261,53 +257,29 @@ class VectorGatherOpConversion
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto loc = gather->getLoc();
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");
 
-    if (failed(isMemRefTypeSupported(memRefType)))
-      return failure();
-
-    auto loc = gather->getLoc();
-
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
+    Value ptrs;
+    VectorType vType = gather.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value base = adaptor.getBase();
-
-    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
-    // Handle the simple case of 1-D vector.
-    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
-      auto vType = gather.getVectorType();
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
-                                  adaptor.getIndexVec(),
-                                  /*vLen=*/vType.getDimSize(0));
-      // Replace with the gather intrinsic.
-      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-      return success();
-    }
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
+      return failure();
 
-    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
-                        Type llvm1DVectorTy, ValueRange vectorOperands) {
-      // Resolve address.
-      Value ptrs = getIndexedPtrs(
-          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
-          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
-      // Create the gather intrinsic.
-      return rewriter.create<LLVM::masked_gather>(
-          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
-          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
-    };
-    ValueRange vectorOperands = {adaptor.getIndexVec(), adaptor.getMask(),
-                                 adaptor.getPassThru()};
-    return LLVM::detail::handleMultidimensionalVectors(
-        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+    // Replace with the gather intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+    return success();
   }
 };
 
@@ -323,21 +295,19 @@ class VectorScatterOpConversion
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
-    if (failed(isMemRefTypeSupported(memRefType)))
-      return failure();
-
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
+    Value ptrs;
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    Value ptrs =
-        getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
-                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
+                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
+      return failure();
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ef37005ddc91..b41147198311 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -49,11 +49,11 @@ enum class MaskFormat {
   Unknown = 2,
 };
 
-/// Helper method to classify a mask value. Currently, the method
+/// Helper method to classify a 1-D mask value. Currently, the method
 /// looks "under the hood" of a constant value with dense attributes
 /// and a constant mask operation (since the client may be called at
 /// various stages during progressive lowering).
-static MaskFormat getMaskFormat(Value mask) {
+static MaskFormat get1DMaskFormat(Value mask) {
   if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
     // Inspect constant dense values. We count up for bits that
     // are set, count down for bits that are cleared, and bail
@@ -77,20 +77,12 @@ static MaskFormat getMaskFormat(Value mask) {
     // dimension size, all bits are set. If the index is zero
     // or less, no bits are set.
     ArrayAttr masks = m.getMaskDimSizes();
-    auto shape = m.getType().getShape();
-    bool allTrue = true;
-    bool allFalse = true;
-    for (auto pair : llvm::zip(masks, shape)) {
-      int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
-      int64_t u = std::get<1>(pair);
-      if (i < u)
-        allTrue = false;
-      if (i > 0)
-        allFalse = false;
-    }
-    if (allTrue)
+    assert(masks.size() == 1);
+    int64_t i = masks[0].cast<IntegerAttr>().getInt();
+    int64_t u = m.getType().getDimSize(0);
+    if (i >= u)
       return MaskFormat::AllTrue;
-    if (allFalse)
+    if (i <= 0)
       return MaskFormat::AllFalse;
   }
   return MaskFormat::Unknown;
@@ -3988,7 +3980,7 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
-    switch (getMaskFormat(load.getMask())) {
+    switch (get1DMaskFormat(load.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           load, load.getType(), load.getBase(), load.getIndices());
@@ -4039,7 +4031,7 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
-    switch (getMaskFormat(store.getMask())) {
+    switch (get1DMaskFormat(store.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           store, store.getValueToStore(), store.getBase(), store.getIndices());
@@ -4082,9 +4074,9 @@ LogicalResult GatherOp::verify() {
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
-  if (resVType.getShape() != indVType.getShape())
+  if (resVType.getDimSize(0) != indVType.getDimSize(0))
     return emitOpError("expected result dim to match indices dim");
-  if (resVType.getShape() != maskVType.getShape())
+  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return emitOpError("expected result dim to match mask dim");
   if (resVType != getPassThruVectorType())
     return emitOpError("expected pass_thru of same type as result type");
@@ -4097,7 +4089,7 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
   using OpRewritePattern<GatherOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp gather,
                                 PatternRewriter &rewriter) const override {
-    switch (getMaskFormat(gather.getMask())) {
+    switch (get1DMaskFormat(gather.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4143,7 +4135,7 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using OpRewritePattern<ScatterOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
-    switch (getMaskFormat(scatter.getMask())) {
+    switch (get1DMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4189,7 +4181,7 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
-    switch (getMaskFormat(expand.getMask())) {
+    switch (get1DMaskFormat(expand.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           expand, expand.getType(), expand.getBase(), expand.getIndices());
@@ -4234,7 +4226,7 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
-    switch (getMaskFormat(compress.getMask())) {
+    switch (get1DMaskFormat(compress.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           compress, compress.getValueToStore(), compress.getBase(),

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 37827bd1c226..e99460172c98 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1932,56 +1932,6 @@ func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2
 
 // -----
 
-func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_multi_dims
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
-  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %2 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_with_mask
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-
-// -----
-
-func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
-  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
-  return %2 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_with_zero_mask
-// CHECK-SAME:    (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
-// CHECK-NOT:   %{{.*}} = llvm.intr.masked.gather
-// CHECK:       return %[[S]] : vector<2x3xf32>
-
-// -----
-
 func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9a7e6f4979a3..d50315970d74 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,7 +1305,7 @@ func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi3
 func.func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                            %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
+  // expected-error@+1 {{'vector.gather' op result #0 must be  of ranks 1, but got 'vector<2x16xf32>'}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
 }

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index ab8d1e07d99c..a017a8c0cfd4 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -701,14 +701,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
   return %0 : vector<16xf32>
 }
 
-// CHECK-LABEL: @gather_multi_dims
-func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
-  %c0 = arith.constant 0 : index
-  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
-  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
-  return %0 : vector<2x16xf32>
-}
-
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list