[Mlir-commits] [mlir] 0cbfd6f - [MLIR]Extend vector.gather to support n-D result

Diego Caballero llvmlistbot at llvm.org
Tue Aug 23 09:59:22 PDT 2022


Author: Che-Yu Wu
Date: 2022-08-23T16:53:19Z
New Revision: 0cbfd6fd1633a075dcfd1bcd8a11e1c6d2785fa8

URL: https://github.com/llvm/llvm-project/commit/0cbfd6fd1633a075dcfd1bcd8a11e1c6d2785fa8
DIFF: https://github.com/llvm/llvm-project/commit/0cbfd6fd1633a075dcfd1bcd8a11e1c6d2785fa8.diff

LOG: [MLIR]Extend vector.gather to support n-D result

Currently, vector.gather only supports reading memory into a 1-D result vector.
This patch extends it to support an n-D result vector, with the index, mask,
and pass-through operands given as n-D vectors of the same shape.

Since we are trying to vectorize tensor.extract with vector.gather
(https://github.com/iree-org/iree/issues/9198), we need to gather the
elements into an n-D vector. Having vector.gather with n-D results allows us
to avoid flattening and reshaping at the vectorization stage. The backends can
then decide the optimal way to lower the vector.gather op.

Note that this is different from n-D gathering, which is about reading n-D
memory with the n-D indices. The indices here are still only 1-D offsets on
the base.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D131905

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3cc2287599a78..aeffe601324e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1761,10 +1761,10 @@ def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
-               VectorOfRankAndType<[1], [I1]>:$mask,
-               VectorOfRank<[1]>:$pass_thru)>,
-    Results<(outs VectorOfRank<[1]>:$result)> {
+               VectorOf<[AnyInteger, Index]>:$index_vec,
+               VectorOf<[I1]>:$mask,
+               AnyVector:$pass_thru)>,
+    Results<(outs AnyVector:$result)> {
 
   let summary = [{
     gathers elements from memory or ranked tensor into a vector as defined by an
@@ -1773,10 +1773,10 @@ def Vector_GatherOp :
 
   let description = [{
     The gather operation gathers elements from memory or ranked tensor into a
-    1-D vector as defined by a base with indices and an additional 1-D index
-    vector, but only if the corresponding bit is set in a 1-D mask vector.
-    Otherwise, the element is taken from a 1-D pass-through vector. Informally
-    the semantics are:
+    n-D vector as defined by a base with indices and an additional n-D index
+    vector (each index is a 1-D offset on the base), but only if the
+    corresponding bit is set in an n-D mask vector. Otherwise, the element is
+    taken from an n-D pass-through vector. Informally the semantics are:
     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1786,15 +1786,13 @@ def Vector_GatherOp :
 
     The gather operation can be used directly where applicable, or can be used
     during progressively lowering to bring other memory operations closer to
-    hardware ISA support for a gather. The semantics of the operation closely
-    correspond to those of the `llvm.masked.gather`
-    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+    hardware ISA support for a gather.
 
     Examples:
 
     ```mlir
     %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
-       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+       : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
 
     %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

diff  --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index f5cd8addf9256..2f0091f99dd3b 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -82,14 +82,6 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
-
-  SmallVector<Type> operand1DVectorTypes;
-  for (Value operand : op->getOperands()) {
-    auto operandNDVectorType = operand.getType().cast<VectorType>();
-    auto operandTypeInfo =
-        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
-    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
-  }
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index bb293f33fb2cf..a9c4830676106 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,24 +91,28 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Add an index vector component to a base pointer. This almost always succeeds
-// unless the last stride is non-unit or the memory space is not zero.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                                    Location loc, Value memref, Value base,
-                                    Value index, MemRefType memRefType,
-                                    VectorType vType, Value &ptrs) {
+// Succeeds only if the last stride is unit and the memory space is zero.
+static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
   if (failed(successStrides) || strides.back() != 1 ||
       memRefType.getMemorySpaceAsInt() != 0)
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementPtrType();
-  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
-  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
+// Add an index vector component to a base pointer.
+static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                            MemRefType memRefType, Value llvmMemref, Value base,
+                            Value index, uint64_t vLen) {
+  assert(succeeded(isMemRefTypeSupported(memRefType)) &&
+         "unsupported memref type");
+  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
+  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
+  return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
+}
+
 // Casts a strided element pointer to a vector pointer.  The vector pointer
 // will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
@@ -257,29 +261,53 @@ class VectorGatherOpConversion
   LogicalResult
   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = gather->getLoc();
     MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
     assert(memRefType && "The base should be bufferized");
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
+    auto loc = gather->getLoc();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Resolve address.
-    Value ptrs;
-    VectorType vType = gather.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value base = adaptor.getBase();
+
+    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
+    // Handle the simple case of 1-D vector.
+    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
+      auto vType = gather.getVectorType();
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
+                                  adaptor.getIndexVec(),
+                                  /*vLen=*/vType.getDimSize(0));
+      // Replace with the gather intrinsic.
+      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+      return success();
+    }
 
-    // Replace with the gather intrinsic.
-    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
-    return success();
+    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
+                        Type llvm1DVectorTy, ValueRange vectorOperands) {
+      // Resolve address.
+      Value ptrs = getIndexedPtrs(
+          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
+          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
+      // Create the gather intrinsic.
+      return rewriter.create<LLVM::masked_gather>(
+          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
+          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
+    };
+    ValueRange vectorOperands = {adaptor.getIndexVec(), adaptor.getMask(),
+                                 adaptor.getPassThru()};
+    return LLVM::detail::handleMultidimensionalVectors(
+        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
   }
 };
 
@@ -295,19 +323,21 @@ class VectorScatterOpConversion
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
 
+    if (failed(isMemRefTypeSupported(memRefType)))
+      return failure();
+
     // Resolve alignment.
     unsigned align;
     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
-    Value ptrs;
     VectorType vType = scatter.getVectorType();
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
-                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
-      return failure();
+    Value ptrs =
+        getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
+                       adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b411471983115..ef37005ddc913 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -49,11 +49,11 @@ enum class MaskFormat {
   Unknown = 2,
 };
 
-/// Helper method to classify a 1-D mask value. Currently, the method
+/// Helper method to classify a mask value. Currently, the method
 /// looks "under the hood" of a constant value with dense attributes
 /// and a constant mask operation (since the client may be called at
 /// various stages during progressive lowering).
-static MaskFormat get1DMaskFormat(Value mask) {
+static MaskFormat getMaskFormat(Value mask) {
   if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
     // Inspect constant dense values. We count up for bits that
     // are set, count down for bits that are cleared, and bail
@@ -77,12 +77,20 @@ static MaskFormat get1DMaskFormat(Value mask) {
     // dimension size, all bits are set. If the index is zero
     // or less, no bits are set.
     ArrayAttr masks = m.getMaskDimSizes();
-    assert(masks.size() == 1);
-    int64_t i = masks[0].cast<IntegerAttr>().getInt();
-    int64_t u = m.getType().getDimSize(0);
-    if (i >= u)
+    auto shape = m.getType().getShape();
+    bool allTrue = true;
+    bool allFalse = true;
+    for (auto pair : llvm::zip(masks, shape)) {
+      int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
+      int64_t u = std::get<1>(pair);
+      if (i < u)
+        allTrue = false;
+      if (i > 0)
+        allFalse = false;
+    }
+    if (allTrue)
       return MaskFormat::AllTrue;
-    if (i <= 0)
+    if (allFalse)
       return MaskFormat::AllFalse;
   }
   return MaskFormat::Unknown;
@@ -3980,7 +3988,7 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(load.getMask())) {
+    switch (getMaskFormat(load.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           load, load.getType(), load.getBase(), load.getIndices());
@@ -4031,7 +4039,7 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(store.getMask())) {
+    switch (getMaskFormat(store.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           store, store.getValueToStore(), store.getBase(), store.getIndices());
@@ -4074,9 +4082,9 @@ LogicalResult GatherOp::verify() {
     return emitOpError("base and result element type should match");
   if (llvm::size(getIndices()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
-  if (resVType.getDimSize(0) != indVType.getDimSize(0))
+  if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
-  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+  if (resVType.getShape() != maskVType.getShape())
     return emitOpError("expected result dim to match mask dim");
   if (resVType != getPassThruVectorType())
     return emitOpError("expected pass_thru of same type as result type");
@@ -4089,7 +4097,7 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
   using OpRewritePattern<GatherOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherOp gather,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(gather.getMask())) {
+    switch (getMaskFormat(gather.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4135,7 +4143,7 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   using OpRewritePattern<ScatterOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp scatter,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(scatter.getMask())) {
+    switch (getMaskFormat(scatter.getMask())) {
     case MaskFormat::AllTrue:
       return failure(); // no unmasked equivalent
     case MaskFormat::AllFalse:
@@ -4181,7 +4189,7 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(expand.getMask())) {
+    switch (getMaskFormat(expand.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::LoadOp>(
           expand, expand.getType(), expand.getBase(), expand.getIndices());
@@ -4226,7 +4234,7 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
-    switch (get1DMaskFormat(compress.getMask())) {
+    switch (getMaskFormat(compress.getMask())) {
     case MaskFormat::AllTrue:
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           compress, compress.getValueToStore(), compress.getBase(),

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e99460172c98c..37827bd1c226c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1932,6 +1932,56 @@ func.func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2
 
 // -----
 
+func.func @gather_op_multi_dims(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %1 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_multi_dims
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
+// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
+// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+
+// -----
+
+func.func @gather_op_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_mask
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
+func.func @gather_op_with_zero_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.constant_mask [0, 0] : vector<2x3xi1>
+  %2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %2 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_with_zero_mask
+// CHECK-SAME:    (%{{.*}}: memref<?xf32>, %{{.*}}: vector<2x3xi32>, %[[S:.*]]: vector<2x3xf32>)
+// CHECK-NOT:   %{{.*}} = llvm.intr.masked.gather
+// CHECK:       return %[[S]] : vector<2x3xf32>
+
+// -----
+
 func.func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d50315970d744..9a7e6f4979a39 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1305,7 +1305,7 @@ func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi3
 func.func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                            %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error at +1 {{'vector.gather' op result #0 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  // expected-error at +1 {{'vector.gather' op expected result dim to match indices dim}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
 }

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a017a8c0cfd45..ab8d1e07d99c7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -701,6 +701,14 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
 // CHECK-LABEL: @expand_and_compress
 func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list