[Mlir-commits] [mlir] df5ccf5 - [mlir][vector] add higher dimensional support to gather/scatter

Aart Bik llvmlistbot@llvm.org
Fri Feb 26 14:20:34 PST 2021


Author: Aart Bik
Date: 2021-02-26T14:20:19-08:00
New Revision: df5ccf5a94ad04231e8426dd0d02689a0717453b

URL: https://github.com/llvm/llvm-project/commit/df5ccf5a94ad04231e8426dd0d02689a0717453b
DIFF: https://github.com/llvm/llvm-project/commit/df5ccf5a94ad04231e8426dd0d02689a0717453b.diff

LOG: [mlir][vector] add higher dimensional support to gather/scatter

Similar to masked load/store and compress/expand, the gather and
scatter operations now allow for higher-dimensional uses. Note that,
to support the mixed-type indexing, the new syntax is:
   vector.gather %base[%i, %j][%kvector] ...
The first client of this generalization is the sparse compiler,
which needs to define gathers and scatters on higher-dimensional
dense operands as well.
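
For example, a gather/scatter pair on a 2-D memref under the new
syntax reads as follows (an illustrative sketch mirroring the
examples added to VectorOps.td below; %base, %i, %j, %v, %mask, and
%pass_thru are placeholder SSA values):

   %g = vector.gather %base[%i, %j][%v], %mask, %pass_thru
       : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   vector.scatter %base[%i, %j][%v], %mask, %g
       : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>

The scalar indices resolve a starting address within the base, and
the index vector then supplies the per-lane offsets relative to that
element.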

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97422

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Linalg/sparse_vector.mlir
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-mem-transforms.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index cf18cd89e170..83b221fac013 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1574,11 +1574,14 @@ def Vector_MaskedLoadOp :
     closely correspond to those of the `llvm.masked.load`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics).
 
-    Example:
+    Examples:
 
     ```mlir
     %0 = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+
+    %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
+       : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1625,11 +1628,14 @@ def Vector_MaskedStoreOp :
     closely correspond to those of the `llvm.masked.store`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics).
 
-    Example:
+    Examples:
 
     ```mlir
     vector.maskedstore %base[%i], %mask, %value
       : memref<?xf32>, vector<8xi1>, vector<8xf32>
+
+    vector.maskedstore %base[%i, %j], %mask, %value
+      : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1652,7 +1658,8 @@ def Vector_MaskedStoreOp :
 def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
-               VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+               Variadic<Index>:$indices,
+               VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1661,9 +1668,10 @@ def Vector_GatherOp :
 
   let description = [{
     The gather operation gathers elements from memory into a 1-D vector as
-    defined by a base and a 1-D index vector, but only if the corresponding
-    bit is set in a 1-D mask vector. Otherwise, the element is taken from a
-    1-D pass-through vector. Informally the semantics are:
+    defined by a base with indices and an additional 1-D index vector, but
+    only if the corresponding bit is set in a 1-D mask vector. Otherwise, the
+    element is taken from a 1-D pass-through vector. Informally the semantics
+    are:
     ```
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1677,19 +1685,22 @@ def Vector_GatherOp :
     correspond to those of the `llvm.masked.gather`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
 
-    Example:
+    Examples:
 
     ```mlir
-    %g = vector.gather %base[%indices], %mask, %pass_thru
+    %0 = vector.gather %base[%c0][%v], %mask, %pass_thru
        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+
+    %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
+       : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return base().getType().cast<MemRefType>();
     }
-    VectorType getIndicesVectorType() {
-      return indices().getType().cast<VectorType>();
+    VectorType getIndexVectorType() {
+      return index_vec().getType().cast<VectorType>();
     }
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
@@ -1701,15 +1712,19 @@ def Vector_GatherOp :
       return result().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
-    "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+  let assemblyFormat =
+    "$base `[` $indices `]` `[` $index_vec `]` `,` "
+    "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
+    "type($index_vec)  `,` type($mask) `,` type($pass_thru) "
+    "`into` type($result)";
   let hasCanonicalizer = 1;
 }
 
 def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
-               VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+               Variadic<Index>:$indices,
+               VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$valueToStore)> {
 
@@ -1717,9 +1732,9 @@ def Vector_ScatterOp :
 
   let description = [{
     The scatter operation scatters elements from a 1-D vector into memory as
-    defined by a base and a 1-D index vector, but only if the corresponding
-    bit in a 1-D mask vector is set. Otherwise, no action is taken for that
-    element. Informally the semantics are:
+    defined by a base with indices and an additional 1-D index vector, but
+    only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+    action is taken for that element. Informally the semantics are:
     ```
     if (mask[0]) base[index[0]] = value[0]
     if (mask[1]) base[index[1]] = value[1]
@@ -1736,19 +1751,22 @@ def Vector_ScatterOp :
     correspond to those of the `llvm.masked.scatter`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
 
-    Example:
+    Examples:
 
     ```mlir
-    vector.scatter %base[%indices], %mask, %value
+    vector.scatter %base[%c0][%v], %mask, %value
         : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+
+    vector.scatter %base[%i, %j][%v], %mask, %value
+        : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return base().getType().cast<MemRefType>();
     }
-    VectorType getIndicesVectorType() {
-      return indices().getType().cast<VectorType>();
+    VectorType getIndexVectorType() {
+      return index_vec().getType().cast<VectorType>();
     }
     VectorType getMaskVectorType() {
       return mask().getType().cast<VectorType>();
@@ -1758,8 +1776,9 @@ def Vector_ScatterOp :
     }
   }];
   let assemblyFormat =
-      "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
-      "type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)";
+      "$base `[` $indices `]` `[` $index_vec `]` `,` "
+      "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
+      "type($index_vec)  `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
 }
 
@@ -1792,11 +1811,14 @@ def Vector_ExpandLoadOp :
     correspond to those of the `llvm.masked.expandload`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
 
-    Example:
+    Examples:
 
     ```mlir
     %0 = vector.expandload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+
+    %1 = vector.expandload %base[%i, %j], %mask, %pass_thru
+       : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1846,11 +1868,14 @@ def Vector_CompressStoreOp :
     correspond to those of the `llvm.masked.compressstore`
     [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
 
-    Example:
+    Examples:
 
     ```mlir
     vector.compressstore %base[%i], %mask, %value
       : memref<?xf32>, vector<8xi1>, vector<8xf32>
+
+    vector.compressstore %base[%i, %j], %mask, %value
+      : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 168665e23111..d567e065479d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -178,34 +178,21 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Helper that returns the base address of a memref.
-static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
-                             Value memref, MemRefType memRefType, Value &base) {
-  // Inspect stride and offset structure.
-  //
-  // TODO: flat memory only for now, generalize
-  //
+// Add an index vector component to a base pointer. This succeeds unless the
+// last stride is non-unit or the memory space is not zero.
+static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+                                    Location loc, Value memref, Value base,
+                                    Value index, MemRefType memRefType,
+                                    VectorType vType, Value &ptrs) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
-  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
-      offset != 0 || memRefType.getMemorySpace() != 0)
-    return failure();
-  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
-  return success();
-}
-
-// Helper that returns vector of pointers given a memref base with index vector.
-static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                                    Location loc, Value memref, Value indices,
-                                    MemRefType memRefType, VectorType vType,
-                                    Type iType, Value &ptrs) {
-  Value base;
-  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+  if (failed(successStrides) || strides.back() != 1 ||
+      memRefType.getMemorySpace() != 0)
     return failure();
   auto pType = MemRefDescriptor(memref).getElementPtrType();
   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
-  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
+  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
   return success();
 }
 
@@ -435,19 +422,20 @@ class VectorGatherOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = gather->getLoc();
     auto adaptor = vector::GatherOpAdaptor(operands);
+    MemRefType memRefType = gather.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
-                                  align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Get index ptrs.
-    VectorType vType = gather.getVectorType();
-    Type iType = gather.getIndicesVectorType().getElementType();
+    // Resolve address.
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
-                              gather.getMemRefType(), vType, iType, ptrs)))
+    VectorType vType = gather.getVectorType();
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
+                              adaptor.index_vec(), memRefType, vType, ptrs)))
       return failure();
 
     // Replace with the gather intrinsic.
@@ -469,19 +457,20 @@ class VectorScatterOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = scatter->getLoc();
     auto adaptor = vector::ScatterOpAdaptor(operands);
+    MemRefType memRefType = scatter.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
-                                  align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
-    // Get index ptrs.
-    VectorType vType = scatter.getVectorType();
-    Type iType = scatter.getIndicesVectorType().getElementType();
+    // Resolve address.
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
-                              scatter.getMemRefType(), vType, iType, ptrs)))
+    VectorType vType = scatter.getVectorType();
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
+                              adaptor.index_vec(), memRefType, vType, ptrs)))
       return failure();
 
     // Replace with the scatter intrinsic.
@@ -507,8 +496,8 @@ class VectorExpandLoadOpConversion
 
     // Resolve address.
     auto vtype = typeConverter->convertType(expand.getVectorType());
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                           adaptor.indices(), rewriter);
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
@@ -530,8 +519,8 @@ class VectorCompressStoreOpConversion
     MemRefType memRefType = compress.getMemRefType();
 
     // Resolve address.
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
-                                           adaptor.indices(), rewriter);
+    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                     adaptor.indices(), rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
         compress, adaptor.valueToStore(), ptr, adaptor.mask());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index fbb45c6b86d8..efaaed6f2739 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -652,9 +652,13 @@ static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
   Location loc = ptr.getLoc();
   VectorType vtp = vectorType(codegen, ptr);
   Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
-  if (args.back().getType().isa<VectorType>())
-    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, args.back(),
-                                             codegen.curVecMask, pass);
+  if (args.back().getType().isa<VectorType>()) {
+    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
+    Value indexVec = args.back();
+    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
+    return rewriter.create<vector::GatherOp>(
+        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
+  }
   return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                                codegen.curVecMask, pass);
 }
@@ -663,12 +667,16 @@ static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                            Value rhs, Value ptr, ArrayRef<Value> args) {
   Location loc = ptr.getLoc();
-  if (args.back().getType().isa<VectorType>())
-    rewriter.create<vector::ScatterOp>(loc, ptr, args.back(),
+  if (args.back().getType().isa<VectorType>()) {
+    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
+    Value indexVec = args.back();
+    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
+    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                        codegen.curVecMask, rhs);
-  else
-    rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
-                                           rhs);
+    return;
+  }
+  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
+                                         rhs);
 }
 
 /// Generates a vectorized invariant. Here we rely on subsequent loop
@@ -985,11 +993,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
+      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
+             "type mismatch for sparse index");
       operands.push_back(codegen.pidxs[tensor][idx]);
     }
   }
   if (needsUniv) {
     types.push_back(indexType);
+    assert(codegen.loops[idx].getType().isa<IndexType>() &&
+           "type mismatch for universal index");
     operands.push_back(codegen.loops[idx]);
   }
   Location loc = op.getLoc();
@@ -1160,6 +1172,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
     return;
   }
+  assert(codegen.curVecLength == 1);
 
   // Construct iteration lattices for current loop index, with L0 at top.
   // Then emit initialization code for the loop sequence at this level.
@@ -1239,6 +1252,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   }
   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
   codegen.loops[idx] = Value();
+  codegen.curVecLength = 1;
 }
 
 namespace {

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 01f9c9805675..4702626c3e8c 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2771,14 +2771,16 @@ void MaskedStoreOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(GatherOp op) {
-  VectorType indicesVType = op.getIndicesVectorType();
+  VectorType indVType = op.getIndexVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType resVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-  if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
+  if (resVType.getDimSize(0) != indVType.getDimSize(0))
     return op.emitOpError("expected result dim to match indices dim");
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
@@ -2817,14 +2819,16 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(ScatterOp op) {
-  VectorType indicesVType = op.getIndicesVectorType();
+  VectorType indVType = op.getIndexVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getVectorType();
   MemRefType memType = op.getMemRefType();
 
   if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and valueToStore element type should match");
-  if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
+  if (valueVType.getDimSize(0) != indVType.getDimSize(0))
     return op.emitOpError("expected valueToStore dim to match indices dim");
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected valueToStore dim to match mask dim");

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e5c1adcee0ed..39826a652092 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1302,28 +1302,57 @@ func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<
 // -----
 
 func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
-  %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
-  return %0 : vector<3xf32>
+  %0 = constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
 }
 
 // CHECK-LABEL: func @gather_op
-// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
 // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
 
+func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
+  %0 = constant 3 : index
+  %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @gather_2d_op
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr<f32>, vector<4xi32>) -> !llvm.vec<4 x ptr<f32>>
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr<f32>>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+// CHECK: return %[[G]] : vector<4xf32>
+
+// -----
+
 func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
-  vector.scatter %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  %0 = constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
   return
 }
 
 // CHECK-LABEL: func @scatter_op
-// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
 
 // -----
 
+func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
+  %0 = constant 3 : index
+  vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_2d_op
+// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr<f32>, vector<4xi32>) -> !llvm.vec<4 x ptr<f32>>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr<f32>>
+
+// -----
+
 func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
   %c0 = constant 0: index
   %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>

diff  --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
index 644b231a3cb9..1ecfe27789af 100644
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -128,9 +128,9 @@ func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) ->
 // CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2:         vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
@@ -159,9 +159,9 @@ func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024
 // CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2:         vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
@@ -324,9 +324,9 @@ func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tenso
 // CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
 // CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
 // CHECK-VEC2:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-VEC2:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
-// CHECK-VEC2:           vector.scatter %{{.*}}[%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-VEC2:         }
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d427cb952f09..45f5d7d2b901 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -718,7 +718,7 @@ func @dead_load(%base: memref<?xf32>, %indices: vector<16xi32>,
   %c0 = constant 0 : index
   %0 = vector.maskedload %base[%c0], %mask, %passthru :
     memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  %1 = vector.gather %base[%indices], %mask, %passthru :
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
     memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   %2 = vector.expandload %base[%c0], %mask, %passthru :
     memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6dbc63dc4a59..3b64a6e57e66 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1287,17 +1287,29 @@ func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %v
 
 func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.gather' op base and result element type should match}}
-  %0 = vector.gather %base[%indices], %mask, %pass_thru
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
+func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
+                             %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.gather' op requires 2 indices}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf64>
+}
+
+// -----
+
 func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                            %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.gather' op result #0 must be  of ranks 1, but got 'vector<2x16xf32>'}}
-  %0 = vector.gather %base[%indices], %mask, %pass_thru
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
 }
 
@@ -1305,8 +1317,9 @@ func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
 
 func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
                                   %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.gather' op expected result dim to match indices dim}}
-  %0 = vector.gather %base[%indices], %mask, %pass_thru
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1314,8 +1327,9 @@ func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>
 
 func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                                %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.gather' op expected result dim to match mask dim}}
-  %0 = vector.gather %base[%indices], %mask, %pass_thru
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1323,8 +1337,9 @@ func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
 
 func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                                      %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.gather' op expected pass_thru of same type as result type}}
-  %0 = vector.gather %base[%indices], %mask, %pass_thru
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32>
 }
 
@@ -1332,17 +1347,29 @@ func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi
 
 func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.scatter' op base and valueToStore element type should match}}
-  vector.scatter %base[%indices], %mask, %value
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
 
 // -----
 
+func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
+                              %mask: vector<16xi1>, %value: vector<16xf64>) {
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.scatter' op requires 2 indices}}
+  vector.scatter %base[%c0][%indices], %mask, %value
+    : memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64>
+}
+
+// -----
+
 func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %value: vector<2x16xf32>) {
-  // expected-error @+1 {{'vector.scatter' op operand #3 must be  of ranks 1, but got 'vector<2x16xf32>'}}
-  vector.scatter %base[%indices], %mask, %value
+  %c0 = constant 0 : index
+  // expected-error @+1 {{'vector.scatter' op operand #4 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }
 
@@ -1350,8 +1377,9 @@ func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
 
 func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
                                    %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
-  vector.scatter %base[%indices], %mask, %value
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1359,8 +1387,9 @@ func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32
 
 func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
                                 %mask: vector<17xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error @+1 {{'vector.scatter' op expected valueToStore dim to match mask dim}}
-  vector.scatter %base[%indices], %mask, %value
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
 }
 

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index eeca040844e4..c0a5af48c821 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -541,25 +541,56 @@ func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthr
   return
 }
 
+// CHECK-LABEL: @masked_load_and_store2d
+func @masked_load_and_store2d(%base: memref<?x?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %base[%c0, %c0], %mask, %passthru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.maskedstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %base[%c0, %c0], %mask, %0 : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
 // CHECK-LABEL: @gather_and_scatter
-func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
-  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  %0 = vector.gather %base[%indices], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
-  vector.scatter %base[%indices], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+func @gather_and_scatter(%base: memref<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// CHECK-LABEL: @gather_and_scatter2d
+func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}, %{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}, %{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  vector.scatter %base[%c0, %c0][%v], %mask, %0 : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }
 
 // CHECK-LABEL: @expand_and_compress
 func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = constant 0 : index
-  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  // CHECK: vector.compressstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
   vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
+// CHECK-LABEL: @expand_and_compress2d
+func @expand_and_compress2d(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base[%c0, %c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.compressstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base[%c0, %c0], %mask, %0 : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
 // CHECK-LABEL: @extract_insert_map
 func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
   %id0 : index, %id1 : index) -> (vector<32xf32>, vector<16x32xf32>) {

diff  --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 5c55cc5b6f0e..ab46a7863a55 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -85,12 +85,14 @@ func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
 // CHECK-SAME:                  %[[A0:.*]]: memref<16xf32>,
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT:      %[[C:.*]] = constant 0 : index
 // CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  %ld = vector.gather %base[%indices], %mask, %pass_thru
+  %ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
@@ -101,8 +103,9 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT:      return %[[A2]] : vector<16xf32>
 func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
-  %ld = vector.gather %base[%indices], %mask, %pass_thru
+  %ld = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
@@ -111,12 +114,14 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
 // CHECK-SAME:                   %[[A0:.*]]: memref<16xf32>,
 // CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      %[[C:.*]] = constant 0 : index
 // CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT:      vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-NEXT:      vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT:      return
 func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  vector.scatter %base[%indices], %mask, %value
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }
@@ -127,9 +132,10 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      return
 func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   %0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
   %mask = vector.constant_mask [0] : vector<16xi1>
-  vector.scatter %base[%indices], %mask, %value
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
index 95df5aea06e4..0e5ea8557fbc 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir
@@ -5,7 +5,8 @@
 
 func @gather8(%base: memref<?xf32>, %indices: vector<8xi32>,
               %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> {
-  %g = vector.gather %base[%indices], %mask, %pass_thru
+  %c0 = constant 0: index
+  %g = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
   return %g : vector<8xf32>
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
index 0666cc852c2a..f4316faee5c7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir
@@ -6,7 +6,8 @@
 func @scatter8(%base: memref<?xf32>,
                %indices: vector<8xi32>,
                %mask: vector<8xi1>, %value: vector<8xf32>) {
-  vector.scatter %base[%indices], %mask, %value
+  %c0 = constant 0: index
+  vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
   return
 }

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
index 7940e8c68b1a..4f399886dd91 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
@@ -64,7 +64,7 @@ func @spmv8x8(%AVAL: memref<8xvector<4xf32>>,
   scf.for %i = %c0 to %cn step %c1 {
     %aval = load %AVAL[%i] : memref<8xvector<4xf32>>
     %aidx = load %AIDX[%i] : memref<8xvector<4xi32>>
-    %0 = vector.gather %X[%aidx], %mask, %pass
+    %0 = vector.gather %X[%c0][%aidx], %mask, %pass
        : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
     %1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32
     store %1, %B[%i] : memref<?xf32>

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
index 31f288e0f6c5..bda77483806e 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
@@ -57,7 +57,7 @@ func @spmv8x8(%AVAL: memref<4xvector<8xf32>>,
   %b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) {
     %aval = load %AVAL[%k] : memref<4xvector<8xf32>>
     %aidx = load %AIDX[%k] : memref<4xvector<8xi32>>
-    %0 = vector.gather %X[%aidx], %mask, %pass
+    %0 = vector.gather %X[%c0][%aidx], %mask, %pass
        : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
     %b_new = vector.fma %aval, %0, %b_iter : vector<8xf32>
     scf.yield %b_new : vector<8xf32>
