[Mlir-commits] [mlir] [mlir][Vector] Introduce vector.transfer_gather (PR #130113)

Kunwar Grover llvmlistbot at llvm.org
Thu Mar 6 06:32:25 PST 2025


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/130113

None

>From 5e39a04ad581a6592e30bcf750228376aad55e6b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 6 Mar 2025 14:23:36 +0000
Subject: [PATCH] [mlir][Vector] Introduce vector.transfer_gather

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 168 ++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 773 ++++++++++++++----
 mlir/test/Dialect/Vector/canonicalize.mlir    |  75 ++
 mlir/test/Dialect/Vector/ops.mlir             |  18 +
 4 files changed, 861 insertions(+), 173 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..bf4300ad3d06a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1667,6 +1667,174 @@ def Vector_TransferWriteOp :
   let hasVerifier = 1;
 }
 
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
+def Vector_TransferGatherOp :
+  Vector_Op<"transfer_gather", [
+      DeclareOpInterfaceMethods<VectorTransferOpInterface>,
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+      AttrSizedOperandSegments,
+      DestinationStyleOpInterface
+    ]>,
+    Arguments<(ins AnyShaped:$source,
+                   Variadic<Index>:$indices,
+                   Variadic<VectorOfAnyRankOf<[Index]>>:$index_vecs,
+                   I64ArrayAttr:$gather_dims,
+                   AffineMapArrayAttr:$index_maps,
+                   AffineMapAttr:$permutation_map,
+                   AnyType:$padding,
+                   Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
+                   BoolArrayAttr:$in_bounds)>,
+    Results<(outs AnyVectorOfAnyRank:$vector)> {
+
+  let summary = "Gathers a supervector from memory into an SSA vector value.";
+
+  let description = [{
+    The `vector.transfer_gather` operation is a generalization of
+    `vector.transfer_read` op, where the slice from which the read is performed
+    is not guaranteed to be contiguous, and instead how the slice is gathered is
+    defined explicitly in the operation.
+
+    The operation can be thought of as:
+      1. A contiguous slice gathered from the source as described by the operation
+      2. A `vector.transfer_read` on the contiguous slice
+
+    The operation defines `permutation_map`, `padding`, `mask`, `in_bounds` in
+    the same way as `vector.transfer_read` defines, but on the inferred
+    contiguous slice.
+
+    The other parameters of the operation define how the contiguous slice is
+    gathered from the source.
+
+    The `indices` contains a base to offset the source by. The dimensions of
+    the source which are gathered are specified as an array of indices in
+    `gather_dims`. Dimensions not specified in this array are contiguous. For
+    example, for the following gather:
+
+    ```
+    slice[i, j, k] = source[i + i_offset][j][indices[i][j][k]]
+    ```
+
+    The operation would represent this as:
+
+    ```
+    indices = %i_offset, 0, 0
+    gather_dims = [2]
+    ```
+
+    For every dimension that is gathered, the operation defines how it is
+    gathered. For each gathered dimension, the operation expects a vector of
+    indices in `index_vecs` to act as a source of indices for that dimension
+    and an AffineMap in `index_maps` describing how this source of indices is
+    indexed. For example, for the following gather:
+
+    ```
+    slice[i, j, k] = source[i][indices0[i] + offset][indices1[j, k]]
+    ```
+
+    The indexing would be described by:
+
+    ```
+    indices      = 0, %offset, 0
+    gather_dims  = [1, 2]
+    index_vecs   = %indices0, %indices1
+    index_maps = [
+      affine_map<(i, j, k) -> (i)>,
+      affine_map<(i, j, k) -> (j, k)>
+    ]
+    ```
+
+    With these additional parameters, the operation can define a supervector
+    read from a non-contiguous slice. For example:
+
+    ```
+    source: memref<8192x8x16xf32>
+    indices0 : vector<16xindex>
+    indices1 : vector<2x8xindex>
+
+    slice[i, j, k] = source[indices0[k]][j][indices1[i, j]]
+    vector = read(slice) : memref<8192x8x16xf32> -> vector<2x8x16xf32>
+    ```
+
+    Can be represented by:
+
+    ```
+    %vector = vector.transfer_gather %source[0, 0, 0](%indices0, %indices1) {
+      gather_dims = [0, 2],
+      index_maps = [
+        affine_map<(i, j, k) -> (k)>,
+        affine_map<(i, j, k) -> (i, j)>
+      ],
+      in_bounds = [true, true, true],
+      permutation_map = affine_map<(i, j, k) -> (i, j, k)>
+    } : memref<8192x8x16xf32> -> vector<2x8x16xf32>
+    ```
+  }];
+
+  let builders = [
+    /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "ValueRange":$index_vecs,
+                   "ArrayAttr":$gather_dims,
+                   "ArrayAttr":$index_maps,
+                   "AffineMapAttr":$permutationMapAttr,
+                   "ArrayAttr":$inBoundsAttr)>,
+    /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "ValueRange":$index_vecs,
+                   "ArrayRef<int64_t>":$gather_dims,
+                   "ArrayRef<AffineMap>":$index_maps,
+                   "AffineMap":$permutationMap,
+                   CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+    /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "ValueRange":$index_vecs,
+                   "ArrayRef<int64_t>":$gather_dims,
+                   "ArrayRef<AffineMap>":$index_maps,
+                   "Value":$padding,
+                   CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+    /// 4. Builder that sets padding to zero and permutation map to
+    /// 'getMinorIdentityMap'.
+    OpBuilder<(ins "VectorType":$vectorType,
+                   "Value":$source,
+                   "ValueRange":$indices,
+                   "ValueRange":$index_vecs,
+                   "ArrayRef<int64_t>":$gather_dims,
+                   "ArrayRef<AffineMap>":$index_maps,
+                   CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // MaskableOpInterface methods.
+    bool supportsPassthru() { return true; }
+
+    MutableOperandRange getDpsInitsMutable() {
+      return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
+    }
+
+    SmallVector<int64_t> getGatherDimsArray() {
+      return llvm::map_to_vector(getGatherDims().getAsValueRange<IntegerAttr>(), 
+        [](APInt dim) { return dim.getSExtValue(); });
+    }
+
+    SmallVector<AffineMap> getIndexMapsArray() {
+      return llvm::to_vector(getIndexMaps().getAsValueRange<AffineMapAttr>());
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 def Vector_LoadOp : Vector_Op<"load"> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ff323983a17c0..51272d8ac4209 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -4026,64 +4027,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
-// TransferReadOp
+// VectorTransferOpInterface
 //===----------------------------------------------------------------------===//
 
-/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices, AffineMapAttr permutationMapAttr,
-                           /*optional*/ ArrayAttr inBoundsAttr) {
-  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
-  Value padding = builder.create<arith::ConstantOp>(
-      result.location, elemType, builder.getZeroAttr(elemType));
-  build(builder, result, vectorType, source, indices, permutationMapAttr,
-        padding, /*mask=*/Value(), inBoundsAttr);
-}
-
-/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices, AffineMap permutationMap,
-                           std::optional<ArrayRef<bool>> inBounds) {
-  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
-  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
-                          ? builder.getBoolArrayAttr(inBounds.value())
-                          : builder.getBoolArrayAttr(
-                                SmallVector<bool>(vectorType.getRank(), false));
-  build(builder, result, vectorType, source, indices, permutationMapAttr,
-        inBoundsAttr);
-}
-
-/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices, Value padding,
-                           std::optional<ArrayRef<bool>> inBounds) {
-  AffineMap permutationMap = getTransferMinorIdentityMap(
-      llvm::cast<ShapedType>(source.getType()), vectorType);
-  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
-  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
-                          ? builder.getBoolArrayAttr(inBounds.value())
-                          : builder.getBoolArrayAttr(
-                                SmallVector<bool>(vectorType.getRank(), false));
-  build(builder, result, vectorType, source, indices, permutationMapAttr,
-        padding,
-        /*mask=*/Value(), inBoundsAttr);
-}
-
-/// 4. Builder that sets padding to zero and permutation map to
-/// 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
-                           VectorType vectorType, Value source,
-                           ValueRange indices,
-                           std::optional<ArrayRef<bool>> inBounds) {
-  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
-  Value padding = builder.create<arith::ConstantOp>(
-      result.location, elemType, builder.getZeroAttr(elemType));
-  build(builder, result, vectorType, source, indices, padding, inBounds);
-}
-
 template <typename EmitFun>
 static LogicalResult verifyPermutationMap(AffineMap permutationMap,
                                           EmitFun emitOpError) {
@@ -4204,14 +4150,6 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 }
 
-void TransferReadOp::print(OpAsmPrinter &p) {
-  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
-  if (getMask())
-    p << ", " << getMask();
-  printTransferAttrs(p, *this);
-  p << " : " << getShapedType() << ", " << getVectorType();
-}
-
 VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
                                                  AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
@@ -4230,6 +4168,158 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
   return VectorType::get(maskShape, i1Type, scalableDims);
 }
 
+template <typename TransferOp>
+static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
+  // TODO: support more aggressive createOrFold on:
+  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
+  if (op.getShapedType().isDynamicDim(indicesIdx))
+    return false;
+  Value index = op.getIndices()[indicesIdx];
+  std::optional<int64_t> cstOp = getConstantIntValue(index);
+  if (!cstOp.has_value())
+    return false;
+
+  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
+  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
+
+  return cstOp.value() + vectorSize <= sourceSize;
+}
+
+template <typename TransferOp>
+static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
+  // TODO: support 0-d corner case.
+  // TODO: Be less conservative.
+  if (op.getTransferRank() == 0)
+    return failure();
+  AffineMap permutationMap = op.getPermutationMap();
+  bool changed = false;
+  SmallVector<bool, 4> newInBounds;
+  newInBounds.reserve(op.getTransferRank());
+  // Idxs of non-bcast dims - used when analysing bcast dims.
+  SmallVector<unsigned> nonBcastDims;
+
+  // 1. Process non-broadcast dims
+  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
+    // 1.1. Already marked as in-bounds, nothing to see here.
+    if (op.isDimInBounds(i)) {
+      newInBounds.push_back(true);
+      continue;
+    }
+    // 1.2. Currently out-of-bounds, check whether we can statically determine
+    // it is inBounds.
+    bool inBounds = false;
+    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
+    if (dimExpr) {
+      inBounds = isInBounds(op, /*resultIdx=*/i,
+                            /*indicesIdx=*/dimExpr.getPosition());
+      nonBcastDims.push_back(i);
+    }
+
+    newInBounds.push_back(inBounds);
+    // We commit the pattern if it is "more inbounds".
+    changed |= inBounds;
+  }
+
+  // 2. Handle broadcast dims
+  // If all non-broadcast dims are "in bounds", then all bcast dims should be
+  // "in bounds" as well.
+  bool allNonBcastDimsInBounds = llvm::all_of(
+      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
+  if (allNonBcastDimsInBounds) {
+    for (size_t idx : permutationMap.getBroadcastDims()) {
+      changed |= !newInBounds[idx];
+      newInBounds[idx] = true;
+    }
+  }
+
+  if (!changed)
+    return failure();
+  // OpBuilder is only used as a helper to build a BoolArrayAttr.
+  OpBuilder b(op.getContext());
+  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
+  return success();
+}
+
+template <typename TransferOp>
+static LogicalResult foldTransferFullMask(TransferOp op) {
+  auto mask = op.getMask();
+  if (!mask)
+    return failure();
+
+  if (getMaskFormat(mask) != MaskFormat::AllTrue)
+    return failure();
+
+  op.getMaskMutable().clear();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TransferReadOp
+//===----------------------------------------------------------------------===//
+
+/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, AffineMapAttr permutationMapAttr,
+                           /*optional*/ ArrayAttr inBoundsAttr) {
+  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+  Value padding = builder.create<arith::ConstantOp>(
+      result.location, elemType, builder.getZeroAttr(elemType));
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        padding, /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, AffineMap permutationMap,
+                           std::optional<ArrayRef<bool>> inBounds) {
+  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
+                          ? builder.getBoolArrayAttr(inBounds.value())
+                          : builder.getBoolArrayAttr(
+                                SmallVector<bool>(vectorType.getRank(), false));
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        inBoundsAttr);
+}
+
+/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices, Value padding,
+                           std::optional<ArrayRef<bool>> inBounds) {
+  AffineMap permutationMap = getTransferMinorIdentityMap(
+      llvm::cast<ShapedType>(source.getType()), vectorType);
+  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
+                          ? builder.getBoolArrayAttr(inBounds.value())
+                          : builder.getBoolArrayAttr(
+                                SmallVector<bool>(vectorType.getRank(), false));
+  build(builder, result, vectorType, source, indices, permutationMapAttr,
+        padding,
+        /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 4. Builder that sets padding to zero and permutation map to
+/// 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType vectorType, Value source,
+                           ValueRange indices,
+                           std::optional<ArrayRef<bool>> inBounds) {
+  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+  Value padding = builder.create<arith::ConstantOp>(
+      result.location, elemType, builder.getZeroAttr(elemType));
+  build(builder, result, vectorType, source, indices, padding, inBounds);
+}
+
+void TransferReadOp::print(OpAsmPrinter &p) {
+  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
+  if (getMask())
+    p << ", " << getMask();
+  printTransferAttrs(p, *this);
+  p << " : " << getShapedType() << ", " << getVectorType();
+}
+
 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   auto &builder = parser.getBuilder();
   SMLoc typesLoc;
@@ -4354,115 +4444,30 @@ Type TransferReadOp::getExpectedMaskType() {
   return inferTransferOpMaskType(getVectorType(), getPermutationMap());
 }
 
-template <typename TransferOp>
-static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
-  // TODO: support more aggressive createOrFold on:
-  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
-  if (op.getShapedType().isDynamicDim(indicesIdx))
-    return false;
-  Value index = op.getIndices()[indicesIdx];
-  std::optional<int64_t> cstOp = getConstantIntValue(index);
-  if (!cstOp.has_value())
-    return false;
-
-  int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
-  int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
-
-  return cstOp.value() + vectorSize <= sourceSize;
-}
-
-template <typename TransferOp>
-static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
-  // TODO: support 0-d corner case.
-  // TODO: Be less conservative.
-  if (op.getTransferRank() == 0)
-    return failure();
-  AffineMap permutationMap = op.getPermutationMap();
-  bool changed = false;
-  SmallVector<bool, 4> newInBounds;
-  newInBounds.reserve(op.getTransferRank());
-  // Idxs of non-bcast dims - used when analysing bcast dims.
-  SmallVector<unsigned> nonBcastDims;
-
-  // 1. Process non-broadcast dims
-  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
-    // 1.1. Already marked as in-bounds, nothing to see here.
-    if (op.isDimInBounds(i)) {
-      newInBounds.push_back(true);
-      continue;
-    }
-    // 1.2. Currently out-of-bounds, check whether we can statically determine
-    // it is inBounds.
-    bool inBounds = false;
-    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
-    if (dimExpr) {
-      inBounds = isInBounds(op, /*resultIdx=*/i,
-                            /*indicesIdx=*/dimExpr.getPosition());
-      nonBcastDims.push_back(i);
-    }
-
-    newInBounds.push_back(inBounds);
-    // We commit the pattern if it is "more inbounds".
-    changed |= inBounds;
-  }
-
-  // 2. Handle broadcast dims
-  // If all non-broadcast dims are "in bounds", then all bcast dims should be
-  // "in bounds" as well.
-  bool allNonBcastDimsInBounds = llvm::all_of(
-      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
-  if (allNonBcastDimsInBounds) {
-    for (size_t idx : permutationMap.getBroadcastDims()) {
-      changed |= !newInBounds[idx];
-      newInBounds[idx] = true;
-    }
-  }
-
-  if (!changed)
-    return failure();
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  OpBuilder b(op.getContext());
-  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
-  return success();
-}
-
-template <typename TransferOp>
-static LogicalResult foldTransferFullMask(TransferOp op) {
-  auto mask = op.getMask();
-  if (!mask)
-    return failure();
-
-  if (getMaskFormat(mask) != MaskFormat::AllTrue)
-    return failure();
-
-  op.getMaskMutable().clear();
-  return success();
-}
-
-///  ```
-///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
-///    : vector<1x4xf32>, tensor<4x4xf32>
-///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
-///    : tensor<4x4xf32>, vector<1x4xf32>
-///  ```
-///  -> Folds into
-///  ```
-///  %v0
-///  ```
-static Value foldRAW(TransferReadOp readOp) {
-  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
-    return {};
-  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
-  while (defWrite) {
-    if (checkSameValueRAW(defWrite, readOp))
-      return defWrite.getVector();
-    if (!isDisjointTransferIndices(
-            cast<VectorTransferOpInterface>(defWrite.getOperation()),
-            cast<VectorTransferOpInterface>(readOp.getOperation())))
-      break;
-    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
-  }
-  return {};
+///  ```
+///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
+///    : tensor<4x4xf32>, vector<1x4xf32>
+///  ```
+///  -> Folds into
+///  ```
+///  %v0
+///  ```
+static Value foldRAW(TransferReadOp readOp) {
+  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
+    return {};
+  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (checkSameValueRAW(defWrite, readOp))
+      return defWrite.getVector();
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(readOp.getOperation())))
+      break;
+    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
+  }
+  return {};
 }
 
 OpFoldResult TransferReadOp::fold(FoldAdaptor) {
@@ -4589,6 +4594,428 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<TransferReadAfterWriteToBroadcast>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// TransferGatherOp
+//===----------------------------------------------------------------------===//
+
+/// 1. Builder taking attribute-typed arguments: the padding is a zero constant
+/// of the source element type and no mask operand is attached.
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+                             VectorType vectorType, Value source,
+                             ValueRange indices, ValueRange indexVecs,
+                             ArrayAttr gatherDims, ArrayAttr indexMaps,
+                             AffineMapAttr permutationMapAttr,
+                             /*optional*/ ArrayAttr inBoundsAttr) {
+  // Create the zero padding constant through the supplied builder.
+  auto sourceType = llvm::cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto zeroAttr = builder.getZeroAttr(elementType);
+  Value zeroPadding = builder.create<arith::ConstantOp>(result.location,
+                                                        elementType, zeroAttr);
+  // Delegate to the overload that accepts explicit padding and mask operands.
+  build(builder, result, vectorType, source, indices, indexVecs, gatherDims,
+        indexMaps, permutationMapAttr, zeroPadding, /*mask=*/Value(),
+        inBoundsAttr);
+}
+
+/// 2. Builder that sets padding to zero and an empty mask (variant without
+/// attrs). The plain C++ arguments are wrapped into attributes and forwarded
+/// to the attribute-typed overload; a missing or empty `inBounds` becomes an
+/// all-false array with one entry per vector dimension.
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+                             VectorType vectorType, Value source,
+                             ValueRange indices, ValueRange indexVecs,
+                             ArrayRef<int64_t> gatherDims,
+                             ArrayRef<AffineMap> indexMaps,
+                             AffineMap permutationMap,
+                             std::optional<ArrayRef<bool>> inBounds) {
+  ArrayAttr inBoundsAttr;
+  if (inBounds && !inBounds->empty()) {
+    inBoundsAttr = builder.getBoolArrayAttr(*inBounds);
+  } else {
+    SmallVector<bool> allFalse(vectorType.getRank(), false);
+    inBoundsAttr = builder.getBoolArrayAttr(allFalse);
+  }
+  ArrayAttr gatherDimsAttr = builder.getI64ArrayAttr(gatherDims);
+  ArrayAttr indexMapsAttr = builder.getAffineMapArrayAttr(indexMaps);
+  AffineMapAttr permutationMapAttr = AffineMapAttr::get(permutationMap);
+  build(builder, result, vectorType, source, indices, indexVecs,
+        gatherDimsAttr, indexMapsAttr, permutationMapAttr, inBoundsAttr);
+}
+
+/// 3. Builder that derives the permutation map from the source and vector
+/// types via 'getTransferMinorIdentityMap'. No mask operand is attached, and a
+/// missing or empty `inBounds` becomes an all-false array with one entry per
+/// vector dimension.
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+                             VectorType vectorType, Value source,
+                             ValueRange indices, ValueRange indexVecs,
+                             ArrayRef<int64_t> gatherDims,
+                             ArrayRef<AffineMap> indexMaps, Value padding,
+                             std::optional<ArrayRef<bool>> inBounds) {
+  auto sourceType = llvm::cast<ShapedType>(source.getType());
+  AffineMap minorIdentityMap =
+      getTransferMinorIdentityMap(sourceType, vectorType);
+
+  ArrayAttr inBoundsAttr;
+  if (inBounds && !inBounds->empty()) {
+    inBoundsAttr = builder.getBoolArrayAttr(*inBounds);
+  } else {
+    SmallVector<bool> allFalse(vectorType.getRank(), false);
+    inBoundsAttr = builder.getBoolArrayAttr(allFalse);
+  }
+
+  ArrayAttr gatherDimsAttr = builder.getI64ArrayAttr(gatherDims);
+  ArrayAttr indexMapsAttr = builder.getAffineMapArrayAttr(indexMaps);
+  build(builder, result, vectorType, source, indices, indexVecs,
+        gatherDimsAttr, indexMapsAttr, minorIdentityMap, padding,
+        /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 4. Builder that uses a zero constant of the source element type as padding
+/// and forwards to the padding-taking overload, which sets the permutation
+/// map to 'getMinorIdentityMap'.
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+                             VectorType vectorType, Value source,
+                             ValueRange indices, ValueRange indexVecs,
+                             ArrayRef<int64_t> gatherDims,
+                             ArrayRef<AffineMap> indexMaps,
+                             std::optional<ArrayRef<bool>> inBounds) {
+  auto sourceType = llvm::cast<ShapedType>(source.getType());
+  Type elementType = sourceType.getElementType();
+  auto zeroAttr = builder.getZeroAttr(elementType);
+  Value zeroPadding = builder.create<arith::ConstantOp>(result.location,
+                                                        elementType, zeroAttr);
+  build(builder, result, vectorType, source, indices, indexVecs, gatherDims,
+        indexMaps, zeroPadding, inBounds);
+}
+
+/// Parses the custom assembly form of TransferGatherOp:
+///
+///   source `[` indices `]` `(` index-vecs `)` `,` padding (`,` mask)?
+///     attr-dict `:` source-type `,` vector-type (`,` index-vec-type)*
+///
+/// The permutation map and in_bounds attributes are filled in with defaults
+/// when they are absent from the attribute dictionary, and the mask type is
+/// inferred rather than spelled out in the type list.
+ParseResult TransferGatherOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+  SMLoc typesLoc;
+  OpAsmParser::UnresolvedOperand sourceInfo;
+  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
+  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexVecInfo;
+  OpAsmParser::UnresolvedOperand paddingInfo;
+  SmallVector<Type, 2> types;
+  OpAsmParser::UnresolvedOperand maskInfo;
+  // Parse `source[indices](index_vecs), padding`; the padding operand is
+  // mandatory.
+  if (parser.parseOperand(sourceInfo) ||
+      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser.parseOperandList(indexVecInfo, OpAsmParser::Delimiter::Paren) ||
+      parser.parseComma() || parser.parseOperand(paddingInfo))
+    return failure();
+
+  // A further comma introduces the optional mask operand.
+  ParseResult hasMask = parser.parseOptionalComma();
+  if (hasMask.succeeded()) {
+    if (parser.parseOperand(maskInfo))
+      return failure();
+  }
+
+  // Parse attributes and types.
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
+    return failure();
+
+  // Check if number of types given are correct: one for the source, one for
+  // the result vector, and one per index vector.
+  size_t nRequiredTypes = indexVecInfo.size() + 2;
+  if (types.size() != nRequiredTypes) {
+    return parser.emitError(typesLoc, "expected ")
+           << nRequiredTypes << " types";
+  }
+
+  // The types are arranged as:
+  // sourceTy, resultTy, *indexVecTy
+  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
+  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
+  ArrayRef<Type> indexVecTy(types.begin() + 2,
+                            types.begin() + indexVecInfo.size() + 2);
+  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
+    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
+  if (!vectorType)
+    return parser.emitError(typesLoc, "requires vector type");
+  // Default the permutation map to the minor identity map when it was not
+  // given explicitly.
+  auto permMapAttrName =
+      TransferGatherOp::getPermutationMapAttrName(result.name);
+  Attribute permMapAttr = result.attributes.get(permMapAttrName);
+  AffineMap permMap;
+  if (!permMapAttr) {
+    permMap = vector::getTransferMinorIdentityMap(shapedType, vectorType);
+    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
+  } else {
+    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
+  }
+  // Default in_bounds to all-false, one entry per permutation map result.
+  auto inBoundsAttrName = TransferGatherOp::getInBoundsAttrName(result.name);
+  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
+  if (!inBoundsAttr) {
+    result.addAttribute(inBoundsAttrName,
+                        builder.getBoolArrayAttr(
+                            SmallVector<bool>(permMap.getNumResults(), false)));
+  }
+  // Resolve operands in segment order: source, indices, index vectors,
+  // padding. The mask, if any, is resolved last below.
+  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+      parser.resolveOperands(indexInfo, builder.getIndexType(),
+                             result.operands) ||
+      parser.resolveOperands(indexVecInfo, indexVecTy, typesLoc,
+                             result.operands) ||
+      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
+                            result.operands))
+    return failure();
+  if (hasMask.succeeded()) {
+    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
+      return parser.emitError(
+          maskInfo.location, "does not support masks with vector element type");
+    if (vectorType.getRank() != permMap.getNumResults()) {
+      return parser.emitError(typesLoc,
+                              "expected the same rank for the vector and the "
+                              "results of the permutation map");
+    }
+    // Instead of adding the mask type as an op type, compute it based on the
+    // vector type and the permutation map (to keep the type signature small).
+    auto maskType = vector::inferTransferOpMaskType(vectorType, permMap);
+    if (parser.resolveOperand(maskInfo, maskType, result.operands))
+      return failure();
+  }
+  // Record the operand segment sizes in the same order the operands were
+  // resolved: source, indices, index vectors, padding, mask.
+  result.addAttribute(TransferGatherOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, static_cast<int32_t>(indexInfo.size()),
+                           static_cast<int32_t>(indexVecInfo.size()), 1,
+                           static_cast<int32_t>(hasMask.succeeded())}));
+  return parser.addTypeToList(vectorType, result.types);
+}
+
+/// Verifies a TransferGatherOp:
+///   * there is one base index per source dimension;
+///   * every index vector has exactly one gather dim and one index map, and
+///     every gather dim names a valid source dimension;
+///   * the common transfer-op invariants hold (see verifyTransferOp);
+///   * the padding type is compatible with the source element type;
+///   * the permutation map is well formed.
+LogicalResult TransferGatherOp::verify() {
+  // Consistency of elemental types in source and vector.
+  ShapedType shapedType = getShapedType();
+  VectorType vectorType = getVectorType();
+  VectorType maskType = getMaskType();
+  auto paddingType = getPadding().getType();
+  auto permutationMap = getPermutationMap();
+  VectorType inferredMaskType =
+      maskType ? inferTransferOpMaskType(vectorType, permutationMap)
+               : VectorType();
+  auto sourceElementType = shapedType.getElementType();
+
+  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
+    return emitOpError("requires ") << shapedType.getRank() << " indices";
+
+  // The folders iterate index vectors, index maps and gather dims in lockstep
+  // and use each gather dim to index into the base indices, so reject
+  // mismatched or out-of-range lists here instead of failing later.
+  size_t numIndexVecs = getIndexVecs().size();
+  if (getGatherDimsAttr().size() != numIndexVecs ||
+      getIndexMapsAttr().size() != numIndexVecs)
+    return emitOpError(
+        "requires one gather dim and one index map per index vector");
+  for (Attribute dimAttr : getGatherDimsAttr()) {
+    int64_t dim = llvm::cast<IntegerAttr>(dimAttr).getInt();
+    if (dim < 0 || dim >= shapedType.getRank())
+      return emitOpError("gather dim ")
+             << dim << " is out of range for a source of rank "
+             << shapedType.getRank();
+  }
+
+  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
+                              shapedType, vectorType, maskType,
+                              inferredMaskType, permutationMap, getInBounds())))
+    return failure();
+
+  if (auto sourceVectorElementType =
+          llvm::dyn_cast<VectorType>(sourceElementType)) {
+    // Source has vector element type.
+    // Check that 'sourceVectorElementType' and 'paddingType' types match.
+    if (sourceVectorElementType != paddingType)
+      return emitOpError(
+          "requires source element type and padding type to match.");
+
+  } else {
+    // Check that 'paddingType' is valid to store in a vector type.
+    if (!VectorType::isValidElementType(paddingType))
+      return emitOpError("requires valid padding vector elemental type");
+
+    // Check that padding type and vector element types match.
+    if (paddingType != sourceElementType)
+      return emitOpError(
+          "requires formal padding and source of the same elemental type");
+  }
+
+  return verifyPermutationMap(permutationMap,
+                              [&](Twine t) { return emitOpError(t); });
+}
+
+/// Prints the custom assembly form accepted by TransferGatherOp::parse:
+/// `source[indices](index_vecs), padding (, mask)? attrs : source-type,
+/// vector-type (, index-vec-types)?`.
+void TransferGatherOp::print(OpAsmPrinter &p) {
+  p << " " << getSource() << "[" << getIndices() << "](" << getIndexVecs()
+    << "), " << getPadding();
+  if (getMask())
+    p << ", " << getMask();
+  printTransferAttrs(p, *this);
+  p << " : " << getShapedType() << ", " << getVectorType();
+  // Only emit the separator when index-vector types follow: with no index
+  // vectors a dangling ", " would leave the parser expecting another type.
+  if (!getIndexVecs().empty())
+    p << ", " << getIndexVecs().getTypes();
+}
+
+/// Reports a read of the source operand when the source is a memref; no
+/// effect is recorded for any other source type.
+void TransferGatherOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!llvm::isa<MemRefType>(getShapedType()))
+    return;
+  effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
+                       SideEffects::DefaultResource::get());
+}
+
+/// The op is speculatable exactly when it has pure tensor semantics.
+Speculation::Speculatability TransferGatherOp::getSpeculatability() {
+  return hasPureTensorSemantics() ? Speculation::Speculatable
+                                  : Speculation::NotSpeculatable;
+}
+
+/// Outcome of folding a single index vector of a TransferGatherOp, as returned
+/// by the callbacks passed to foldTransferGatherIndexVecs.
+struct IndexVecFoldResult {
+  /// Replacement index vector. A null Value tells the caller to drop this
+  /// index vector from the op.
+  Value indexVec;
+  /// Index map to pair with `indexVec`.
+  AffineMap indexMap;
+  /// Whether the callback changed anything for this index vector.
+  bool changed;
+};
+
+/// Runs `indexVecFolder` on every (index vector, index map, gather dim) triple
+/// of `gatherOp`. The callback returns the replacement vector and map (or a
+/// null vector to drop the entry) plus a flag saying whether it changed
+/// anything.
+///
+/// If no callback reports a change, the op is left untouched and a null Value
+/// is returned. Otherwise the op's operands, `index_maps`, `gather_dims` and
+/// operand segment sizes are updated in place and the op's result is returned.
+static Value foldTransferGatherIndexVecs(
+    TransferGatherOp gatherOp,
+    function_ref<IndexVecFoldResult(Value, AffineMap, int64_t)>
+        indexVecFolder) {
+  bool changed = false;
+  SmallVector<Value> newIndexVecs;
+  SmallVector<AffineMap> newIndexMaps;
+  SmallVector<int64_t> newGatherDims;
+  for (auto [operand, map, gatherDim] :
+       llvm::zip_equal(gatherOp.getIndexVecs(), gatherOp.getIndexMapsArray(),
+                       gatherOp.getGatherDimsArray())) {
+    auto [indexVec, indexMap, vecChanged] =
+        indexVecFolder(operand, map, gatherDim);
+    changed |= vecChanged;
+    // A null value means this index vector was folded away entirely.
+    if (!indexVec)
+      continue;
+    newIndexVecs.push_back(indexVec);
+    newIndexMaps.push_back(indexMap);
+    // Keep the source dimension the vector gathers along. Recording its
+    // position in the old list instead would retarget the gather whenever
+    // `gather_dims` is not simply [0, 1, ...].
+    newGatherDims.push_back(gatherDim);
+  }
+
+  if (!changed)
+    return Value();
+
+  // Only used to build attributes; no operations are created here.
+  OpBuilder b(gatherOp);
+
+  SmallVector<Value> operands;
+  SmallVector<int32_t> operandSegmentSizes;
+
+  // Source.
+  operands.push_back(gatherOp.getSource());
+  operandSegmentSizes.push_back(1);
+  // Indices.
+  SmallVector<Value> indices = gatherOp.getIndices();
+  operands.append(indices);
+  operandSegmentSizes.push_back(indices.size());
+  // IndexVecs.
+  operands.append(newIndexVecs);
+  operandSegmentSizes.push_back(newIndexVecs.size());
+  // Padding.
+  operands.push_back(gatherOp.getPadding());
+  operandSegmentSizes.push_back(1);
+  // Mask.
+  if (gatherOp.getMask()) {
+    operands.push_back(gatherOp.getMask());
+    operandSegmentSizes.push_back(1);
+  } else {
+    operandSegmentSizes.push_back(0);
+  }
+
+  gatherOp.setIndexMapsAttr(b.getAffineMapArrayAttr(newIndexMaps));
+  gatherOp->setOperands(operands);
+  gatherOp.setGatherDimsAttr(b.getI64ArrayAttr(newGatherDims));
+  gatherOp.getProperties().setOperandSegmentSizes(operandSegmentSizes);
+
+  return gatherOp.getResult();
+}
+
+/// Returns the rank of `type` if it is a vector type, and 0 for anything else.
+static int64_t getVectorRank(Type type) {
+  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+    return vectorType.getRank();
+  return 0;
+}
+
+/// Looks through `vector.broadcast` producers of index vectors: the gather
+/// uses the broadcast source directly and the index map is sliced down to the
+/// trailing results that correspond to the source's dimensions.
+static Value foldTransferGatherFromBroadcast(TransferGatherOp gatherOp) {
+  return foldTransferGatherIndexVecs(
+      gatherOp,
+      [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
+        auto broadcast = operand.getDefiningOp<vector::BroadcastOp>();
+        if (!broadcast)
+          return {operand, map, false};
+
+        // Index-vector operands must be vectors, so a broadcast from a scalar
+        // cannot be replaced by its source.
+        Type sourceType = broadcast.getSourceType();
+        if (!llvm::isa<VectorType>(sourceType))
+          return {operand, map, false};
+
+        // NOTE(review): a same-rank broadcast that stretches unit dimensions
+        // keeps the full map here; confirm the op's semantics permit that.
+        int64_t sourceRank = getVectorRank(sourceType);
+        int64_t operandRank = getVectorRank(broadcast.getResultVectorType());
+        AffineMap newMap =
+            map.getSliceMap(operandRank - sourceRank, sourceRank);
+        return {broadcast.getSource(), newMap, true};
+      });
+}
+
+/// Looks through `vector.transpose` producers of index vectors: the gather
+/// uses the transpose input directly and the index map is composed with the
+/// permutation map built from the inverted transpose permutation.
+static Value foldTransferGatherFromTranspose(TransferGatherOp gatherOp) {
+  auto lookThroughTranspose = [](Value operand, AffineMap map,
+                                 int64_t) -> IndexVecFoldResult {
+    auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return {operand, map, false};
+
+    SmallVector<int64_t> inversePerm =
+        invertPermutationVector(transposeOp.getPermutation());
+    AffineMap inversePermMap =
+        AffineMap::getPermutationMap(inversePerm, transposeOp.getContext());
+    return {transposeOp.getVector(), map.compose(inversePermMap), true};
+  };
+  return foldTransferGatherIndexVecs(gatherOp, lookThroughTranspose);
+}
+
+/// Drops every index vector that is produced by a `vector.step` op; index
+/// vectors with any other producer are kept unchanged.
+///
+/// NOTE(review): the index map and gather dim of the dropped vector are not
+/// inspected. This presumably relies on a step vector reproducing the
+/// contiguous access along that dimension -- confirm that this holds for
+/// every index map / permutation map combination.
+static Value foldTransferGatherFromStep(TransferGatherOp gatherOp) {
+  return foldTransferGatherIndexVecs(
+      gatherOp,
+      [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
+        auto step = operand.getDefiningOp<vector::StepOp>();
+        if (!step) {
+          return {operand, map, false};
+        }
+
+        // A null vector tells the helper to remove this index vector.
+        return {Value(), AffineMap(), true};
+      });
+}
+
+/// Tries the index-vector folders in order (broadcast, transpose, step) and
+/// returns the result of the first one that changes the op; returns a null
+/// OpFoldResult if none applies.
+OpFoldResult TransferGatherOp::fold(FoldAdaptor adaptor) {
+  using FolderFn = Value (*)(TransferGatherOp);
+  const FolderFn folders[] = {foldTransferGatherFromBroadcast,
+                              foldTransferGatherFromTranspose,
+                              foldTransferGatherFromStep};
+  for (FolderFn folder : folders) {
+    if (Value folded = folder(*this))
+      return folded;
+  }
+  return OpFoldResult();
+}
+
+/// Folds index vectors that hold exactly one element into the scalar base
+/// index of the source dimension they gather along: the element is extracted,
+/// added to the base index, and the index vector is dropped from the op.
+struct FoldSingleElementIndexVec : public OpRewritePattern<TransferGatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferGatherOp xferOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto indexVecFolder = [&](Value indexVec, AffineMap map,
+                              int64_t index) -> IndexVecFoldResult {
+      auto vectorTy = cast<VectorType>(indexVec.getType());
+      // A scalable vector's element count depends on the runtime vector
+      // length, so a static count of 1 does not make it a single element.
+      if (vectorTy.isScalable() || vectorTy.getNumElements() != 1)
+        return {indexVec, map, false};
+
+      // Extract the scalar and add it to the corresponding base index.
+      OpOperand &base = xferOp.getIndicesMutable()[index];
+      Value extracted = rewriter.create<vector::ExtractOp>(
+          xferOp.getLoc(), indexVec,
+          SmallVector<int64_t>(vectorTy.getRank(), 0));
+      Value newIndex = rewriter.create<arith::AddIOp>(indexVec.getLoc(),
+                                                      base.get(), extracted);
+      base.set(newIndex);
+
+      return {Value(), AffineMap(), true};
+    };
+
+    // The folder and the helper update `xferOp` in place, so bracket the
+    // update with the rewriter's modification hooks to keep the driver
+    // informed. Nothing is touched when no index vector qualifies, so the
+    // modification can simply be cancelled in that case.
+    rewriter.startOpModification(xferOp);
+    if (!foldTransferGatherIndexVecs(xferOp, indexVecFolder)) {
+      rewriter.cancelOpModification(xferOp);
+      return failure();
+    }
+    rewriter.finalizeOpModification(xferOp);
+    return success();
+  }
+};
+
+/// Rewrites a TransferGatherOp that has no index vectors left into an
+/// equivalent vector.transfer_read, forwarding the source, indices,
+/// permutation map, padding, mask and in_bounds unchanged.
+struct FoldContigousGatherToTransferRead
+    : public OpRewritePattern<TransferGatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferGatherOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Only applies once every index vector has been folded away.
+    if (!xferOp.getIndexVecs().empty())
+      return failure();
+
+    // Canonicalize to vector.transfer_read.
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        xferOp, xferOp.getVectorType(), xferOp.getSource(), xferOp.getIndices(),
+        xferOp.getPermutationMap(), xferOp.getPadding(), xferOp.getMask(),
+        xferOp.getInBounds());
+    return success();
+  }
+};
+
+/// Registers the TransferGatherOp canonicalization patterns: folding of
+/// single-element index vectors, and conversion to vector.transfer_read.
+void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                   MLIRContext *ctx) {
+  results.add<FoldSingleElementIndexVec>(ctx);
+  results.add<FoldContigousGatherToTransferRead>(ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..3ec6a239f589d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1075,6 +1075,81 @@ func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vecto
 
 // -----
 
+// The broadcast feeding the index vector is folded into the gather, which
+// then takes the un-broadcast vector<4xindex> operand directly.
+// CHECK-LABEL: @fold_vector_transfer_gather_broadcast
+func.func @fold_vector_transfer_gather_broadcast(%A: memref<?x?xf32>, %idx: vector<4xindex>) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK-NOT: vector.broadcast
+  %bidx = vector.broadcast %idx : vector<4xindex> to vector<4x4xindex>
+
+  // CHECK: vector.transfer_gather
+  // CHECK-SAME: memref<?x?xf32>, vector<4x4xf32>, vector<4xindex>
+  %1 = vector.transfer_gather %A[%c0, %c0](%bidx), %f0
+      { gather_dims = [0], index_maps = [affine_map<(d0, d1) -> (d0, d1)>] }
+      : memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+
+  return %1 : vector<4x4xf32>
+}
+
+// The transpose feeding the index vector is folded into the gather.
+// CHECK-LABEL: @fold_vector_transfer_gather_transpose
+func.func @fold_vector_transfer_gather_transpose(%A: memref<?x?xf32>, %idx: vector<4x4xindex>) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK-NOT: vector.transpose
+  %tidx = vector.transpose %idx, [1, 0] : vector<4x4xindex> to vector<4x4xindex>
+
+  // CHECK: vector.transfer_gather
+  // CHECK-SAME: memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+  %1 = vector.transfer_gather %A[%c0, %c0](%tidx), %f0
+      { gather_dims = [0], index_maps = [affine_map<(d0, d1) -> (d0, d1)>] }
+      : memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+
+  return %1 : vector<4x4xf32>
+}
+
+// The index vector produced by vector.step is dropped from the gather; only
+// the vector<4x4xindex> operand remains.
+// CHECK-LABEL: @fold_vector_transfer_gather_step
+func.func @fold_vector_transfer_gather_step(%A: memref<?x?xf32>, %idx : vector<4x4xindex>) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK-NOT: vector.step
+  %sidx = vector.step : vector<4xindex>
+
+  // CHECK: vector.transfer_gather
+  // CHECK-SAME: memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+  %1 = vector.transfer_gather %A[%c0, %c0](%sidx, %idx), %f0
+      { gather_dims = [0, 1], index_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>] }
+      : memref<?x?xf32>, vector<4x4xf32>, vector<4xindex>, vector<4x4xindex>
+
+  return %1 : vector<4x4xf32>
+}
+
+// A single-element index vector is folded into the base index, after which
+// the gather has no index vectors left and becomes a transfer_read.
+// CHECK-LABEL: @fold_vector_transfer_gather_single_element
+func.func @fold_vector_transfer_gather_single_element(%A: memref<?x?xf32>, %idx : vector<1x1xindex>) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK: vector.transfer_read
+  %1 = vector.transfer_gather %A[%c0, %c0](%idx), %f0
+      { gather_dims = [0], index_maps = [affine_map<(d0, d1) -> (0, 0)>] }
+      : memref<?x?xf32>, vector<4x4xf32>, vector<1x1xindex>
+
+  return %1 : vector<4x4xf32>
+}
+
+// A gather without index vectors is a contiguous read and canonicalizes to
+// vector.transfer_read.
+// CHECK-LABEL: @fold_vector_transfer_gather_contigious
+func.func @fold_vector_transfer_gather_contigious(%A: memref<?x?xf32>) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK-NOT: vector.transfer_gather
+  // CHECK: vector.transfer_read
+  %1 = vector.transfer_gather %A[%c0, %c0](), %f0
+      { gather_dims = [], index_maps = [] }
+      : memref<?x?xf32>, vector<4x4xf32>
+
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: bitcast_folding
 //  CHECK-SAME:   %[[A:.*]]: vector<4x8xf32>
 //  CHECK-SAME:   %[[B:.*]]: vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..b1ba8f2141397 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -148,6 +148,24 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
     tensor<?x?xvector<4x3xindex>>
 }
 
+// Parse/print smoke test for vector.transfer_gather on a memref source with
+// explicit gather_dims, index_maps and permutation_map attributes.
+// CHECK-LABEL: @vector_transfer_gather_memref
+func.func @vector_transfer_gather_memref(%arg0 : memref<?x?xf32>,
+                                    %idx0 : vector<128xindex>,
+                                    %idx1 : vector<7xindex>,
+                                    %idx2 : vector<3x7xindex>) ->
+  (vector<128xf32>, vector<3x7xf32>, vector<3x7xf32>) {
+  %c3 = arith.constant 3 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // 1-D result with a single index vector on gather dim 0.
+  // CHECK: vector.transfer_gather
+  %0 = vector.transfer_gather %arg0[%c3, %c3](%idx0), %f0 {gather_dims = [0], index_maps = [affine_map<(d0, d1)->(d0)>], permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>, vector<128xindex>
+  // 2-D result through a transposing permutation map, with a single 1-D index
+  // vector on gather dim 1.
+  // CHECK: vector.transfer_gather
+  %1 = vector.transfer_gather %arg0[%c3, %c3](%idx1), %f0 {gather_dims = [1], index_maps = [affine_map<(d0, d1)->(d0)>], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : memref<?x?xf32>, vector<3x7xf32>, vector<7xindex>
+  // Two index vectors of different ranks, one per gather dim.
+  // CHECK: vector.transfer_gather
+  %2 = vector.transfer_gather %arg0[%c3, %c3](%idx1, %idx2), %f0 {gather_dims = [0, 1], index_maps = [affine_map<(d0, d1)->(d0)>, affine_map<(d0, d1)->(d1, d0)>], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : memref<?x?xf32>, vector<3x7xf32>, vector<7xindex>, vector<3x7xindex>
+  return %0, %1, %2 : vector<128xf32>, vector<3x7xf32>, vector<3x7xf32>
+}
+
 // CHECK-LABEL: @vector_broadcast
 func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
   // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>



More information about the Mlir-commits mailing list