[Mlir-commits] [mlir] f250b97 - Reland "[MLIR]Extend vector.gather to support n-D result"

Diego Caballero llvmlistbot at llvm.org
Tue Aug 23 21:24:01 PDT 2022


Author: Che-Yu Wu
Date: 2022-08-24T04:18:00Z
New Revision: f250b9722232044e6d1bf794075134efb566a284

URL: https://github.com/llvm/llvm-project/commit/f250b9722232044e6d1bf794075134efb566a284
DIFF: https://github.com/llvm/llvm-project/commit/f250b9722232044e6d1bf794075134efb566a284.diff

LOG: Reland "[MLIR]Extend vector.gather to support n-D result"

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D132507
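
For illustration, the extended op now accepts an n-D result; a sketch mirroring
the updated example and tests in this patch:

```mlir
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
   : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
```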

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3cc2287599a78..aeffe601324e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1761,10 +1761,10 @@ def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[AnyInteger, Index]>:$index_vec,
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = [{
    gathers elements from memory or a ranked tensor into a vector as defined by an
@@ -1773,10 +1773,10 @@ def Vector_GatherOp :
 
   let description = [{
     The gather operation gathers elements from memory or a ranked tensor into a
-    1-D vector as defined by a base with indices and an additional 1-D index
-    vector, but only if the corresponding bit is set in a 1-D mask vector.
-    Otherwise, the element is taken from a 1-D pass-through vector. Informally
-    the semantics are:
+    n-D vector as defined by a base with indices and an additional n-D index
+    vector (each index element is an offset into the base), but only if the
+    corresponding bit is set in an n-D mask vector. Otherwise, the element is
+    taken from an n-D pass-through vector. Informally, the semantics are:
     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1786,15 +1786,13 @@ def Vector_GatherOp :
 
     The gather operation can be used directly where applicable, or can be used
    during progressive lowering to bring other memory operations closer to
-    hardware ISA support for a gather. The semantics of the operation closely
-    correspond to those of the `llvm.masked.gather`
-    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+    hardware ISA support for a gather.
 
     Examples:
 
     ```mlir
     %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
-       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+       : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
 
     %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
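
For an n-D result, the informal semantics above generalize elementwise; e.g.,
for a 2-D result (a sketch consistent with the description):

```
result[i][j] := mask[i][j] ? base[index[i][j]] : pass_thru[i][j]
```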

diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index f5cd8addf9256..2f0091f99dd3b 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,14 +82,6 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
-
-  SmallVector<Type> operand1DVectorTypes;
-  for (Value operand : op->getOperands()) {
-    auto operandNDVectorType = operand.getType().cast<VectorType>();
-    auto operandTypeInfo =
-        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
-    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
-  }
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index bb293f33fb2cf..c3a18ae29ab49 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,24 +91,28 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Add an index vector component to a base pointer. This almost always succeeds
-// unless the last stride is non-unit or the memory space is not zero.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                                    Location loc, Value memref, Value base,
-                                    Value index, MemRefType memRefType,
-                                    VectorType vType, Value &ptrs) {
+// Check whether the memref type is supported (unit last stride, zero memory space).
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (failed(successStrides) || strides.back() != 1 ||
       memRefType.getMemorySpaceAsInt() != 0)
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
-  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
+// Add an index vector component to a base pointer.
+static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                            MemRefType memRefType, Value llvmMemref, Value base,
+                            Value index, uint64_t vLen) {
+  assert(succeeded(isMemRefTypeSupported(memRefType)) &&
+         "unsupported memref type");
+  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
+  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+  return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
+}
+
 // Casts a strided element pointer to a vector pointer.  The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
@@ -257,29 +261,53 @@ class VectorGatherOpConversion
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = gather->getLoc();
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
+    auto loc = gather->getLoc();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Resolve address.
-    Value ptrs;
-    VectorType vType = gather.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value base = adaptor.getBase();
+
+    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
+    // Handle the simple case of a 1-D vector.
+    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
+      auto vType = gather.getVectorType();
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
+                                  adaptor.getIndexVec(),
+                                  /*vLen=*/vType.getDimSize(0));
+      // Replace with the gather intrinsic.
+      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+      return success();
+    }
 
-    // Replace with the gather intrinsic.
-    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-    return success();
+    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
+                        Type llvm1DVectorTy, ValueRange vectorOperands) {
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(
+          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
+          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+      // Create the gather intrinsic.
+      return rewriter.create<LLVM::masked_gather>(
+          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
+          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
+    };
+    SmallVector<Value> vectorOperands = {
+        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
+    return LLVM::detail::handleMultidimensionalVectors(
+        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
   }
 };
 
@@ -295,19 +323,21 @@ class VectorScatterOpConversion
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    Value ptrs;
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
+                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
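
With the n-D path, handleMultidimensionalVectors unrolls the gather into
per-row 1-D gathers, each of which lowers to a masked-gather intrinsic of
roughly this shape (a sketch of one unrolled row; value names and shapes are
illustrative, mirroring the tests below):

```mlir
%g0 = llvm.intr.masked.gather %p0, %m0, %s0 {alignment = 4 : i32}
    : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
```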

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b411471983115..ef37005ddc913 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -49,11 +49,11 @@ enum class MaskFormat {
   Unknown = 2,
 };
 
-/// Helper method to classify a 1-D mask value. Currently, the method
+/// Helper method to classify a mask value. Currently, the method
 /// looks "under the hood" of a constant value with dense attributes
 /// and a constant mask operation (since the client may be called at
 /// various stages during progressive lowering).
-static MaskFormat get1DMaskFormat(Value mask) {
+static MaskFormat getMaskFormat(Value mask) {
   if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
     // Inspect constant dense values. We count up for bits that
     // are set, count down for bits that are cleared, and bail
@@ -77,12 +77,20 @@ static MaskFormat get1DMaskFormat(Value mask) {
     // dimension size, all bits are set. If the index is zero
     // or less, no bits are set.
     ArrayAttr masks = m.getMaskDimSizes();
-    assert(masks.size() == 1);
-    int64_t i = masks[0].cast<IntegerAttr>().getInt();
-    int64_t u = m.getType().getDimSize(0);
-    if (i >= u)
+    auto shape = m.getType().getShape();
+    bool allTrue = true;
+    bool allFalse = true;
+    for (auto pair : llvm::zip(masks, shape)) {
+      int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
+      int64_t u = std::get<1>(pair);
+      if (i < u)
+        allTrue = false;
+      if (i > 0)
+        allFalse = false;
+    }
+    if (allTrue)
       return MaskFormat::AllTrue;
-    if (i <= 0)
+    if (allFalse)
       return MaskFormat::AllFalse;
   }
   return MaskFormat::Unknown;
@@ -3980,7 +3988,7 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(load.getMask())) {
+    switch (getMaskFormat(load.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           load, load.getType(), load.getBase(), load.getIndices());
@@ -4031,7 +4039,7 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(store.getMask())) {
+    switch (getMaskFormat(store.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           store, store.getValueToStore(), store.getBase(), store.getIndices());
@@ -4074,9 +4082,9 @@ LogicalResult GatherOp::verify() {
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
-  if (resVType.getDimSize(0) != indVType.getDimSize(0))
+  if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
-  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (resVType.getShape() != maskVType.getShape())
     return emitOpError("expected result dim to match mask dim");
   if (resVType != getPassThruVectorType())
     return emitOpError("expected pass_thru of same type as result type");
@@ -4089,7 +4097,7 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
   using OpRewritePattern<GatherOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp gather,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(gather.getMask())) {
+    switch (getMaskFormat(gather.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4135,7 +4143,7 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using OpRewritePattern<ScatterOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(scatter.getMask())) {
+    switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4181,7 +4189,7 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(expand.getMask())) {
+    switch (getMaskFormat(expand.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           expand, expand.getType(), expand.getBase(), expand.getIndices());
@@ -4226,7 +4234,7 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(compress.getMask())) {
+    switch (getMaskFormat(compress.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           compress, compress.getValueToStore(), compress.getBase(),
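
The generalized getMaskFormat classifies an n-D vector.constant_mask dimension
by dimension; a sketch of the three possible outcomes (shapes illustrative):

```mlir
%all  = vector.constant_mask [2, 3] : vector<2x3xi1>  // every dimension fully set -> AllTrue
%none = vector.constant_mask [0, 0] : vector<2x3xi1>  // no bits set               -> AllFalse
%some = vector.constant_mask [1, 2] : vector<2x3xi1>  // partially set             -> Unknown
```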

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e99460172c98c..37827bd1c226c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1932,6 +1932,56 @@ func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2
 
 // -----
 
+func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %1 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_multi_dims
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+
+// -----
+
+func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_mask
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
+func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_zero_mask
+// CHECK-SAME:    (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
+// CHECK-NOT:   %{{.*}} = llvm.intr.masked.gather
+// CHECK:       return %[[S]] : vector<2x3xf32>
+
+// -----
+
 func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d50315970d744..9a7e6f4979a39 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,7 +1305,7 @@ func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi3
 func.func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                            %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error @+1 {{'vector.gather' op result #0 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error @+1 {{'vector.gather' op expected result dim to match indices dim}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
 }

diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a017a8c0cfd45..ab8d1e07d99c7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -701,6 +701,14 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index


        

