[Mlir-commits] [mlir] [mlir][Vector] Introduce vector.transfer_gather (PR #130113)
Kunwar Grover
llvmlistbot at llvm.org
Thu Mar 6 06:32:25 PST 2025
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/130113
This patch introduces `vector.transfer_gather`, a generalization of `vector.transfer_read` in which selected dimensions of the read slice are gathered through explicit index vectors rather than read contiguously. It adds the op definition with builders, a custom parser/printer, a verifier, folders (for broadcast/transpose/step-produced index vectors), and canonicalization patterns (folding single-element index vectors, and rewriting fully contiguous gathers to `vector.transfer_read`), along with round-trip and canonicalization tests.
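For orientation, here is a usage sketch adapted from the tests added in this patch (the wrapping function and value names are illustrative):

```
func.func @example(%arg0 : memref<?x?xf32>, %idx : vector<7xindex>) -> vector<3x7xf32> {
  %c3 = arith.constant 3 : index
  %f0 = arith.constant 0.0 : f32
  // Gather along dim 1 of %arg0 via %idx; dim 0 is read contiguously.
  %0 = vector.transfer_gather %arg0[%c3, %c3](%idx), %f0 {
    gather_dims = [1],
    index_maps = [affine_map<(d0, d1) -> (d0)>],
    permutation_map = affine_map<(d0, d1) -> (d1, d0)>
  } : memref<?x?xf32>, vector<3x7xf32>, vector<7xindex>
  return %0 : vector<3x7xf32>
}
```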
>From 5e39a04ad581a6592e30bcf750228376aad55e6b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 6 Mar 2025 14:23:36 +0000
Subject: [PATCH] [mlir][Vector] Introduce vector.transfer_gather
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 168 ++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 773 ++++++++++++++----
mlir/test/Dialect/Vector/canonicalize.mlir | 75 ++
mlir/test/Dialect/Vector/ops.mlir | 18 +
4 files changed, 861 insertions(+), 173 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..bf4300ad3d06a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1667,6 +1667,174 @@ def Vector_TransferWriteOp :
let hasVerifier = 1;
}
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
+def Vector_TransferGatherOp :
+ Vector_Op<"transfer_gather", [
+ DeclareOpInterfaceMethods<VectorTransferOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<ConditionallySpeculatable>,
+ AttrSizedOperandSegments,
+ DestinationStyleOpInterface
+ ]>,
+ Arguments<(ins AnyShaped:$source,
+ Variadic<Index>:$indices,
+ Variadic<VectorOfAnyRankOf<[Index]>>:$index_vecs,
+ I64ArrayAttr:$gather_dims,
+ AffineMapArrayAttr:$index_maps,
+ AffineMapAttr:$permutation_map,
+ AnyType:$padding,
+ Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
+ BoolArrayAttr:$in_bounds)>,
+ Results<(outs AnyVectorOfAnyRank:$vector)> {
+
+ let summary = "Gathers a supervector from memory into an SSA vector value.";
+
+ let description = [{
+    The `vector.transfer_gather` operation is a generalization of the
+    `vector.transfer_read` op, where the slice from which the read is
+    performed is not guaranteed to be contiguous; instead, how the slice is
+    gathered is defined explicitly by the operation.
+
+    The operation can be thought of as:
+    1. A contiguous slice gathered from the source as described by the operation
+    2. A `vector.transfer_read` on the contiguous slice
+
+    The operation defines `permutation_map`, `padding`, `mask` and `in_bounds`
+    in the same way as `vector.transfer_read` does, but applied to the
+    inferred contiguous slice.
+
+    The remaining parameters of the operation define how the contiguous slice
+    is gathered from the source.
+
+    The `indices` operands contain a base offset for each dimension of the
+    source, which is added to the index computed for that dimension. The
+    dimensions of the source which are gathered are specified as an array of
+    indices in `gather_dims`. Dimensions not listed in this array are
+    contiguous. For example, for the following gather:
+
+ ```
+    slice[i, j, k] = source[i + i_offset][j][idx[i, j, k]]
+ ```
+
+ The operation would represent this as:
+
+ ```
+ indices = %i_offset, 0, 0
+ gather_dims = [2]
+ ```
+
+    For each gathered dimension, the operation defines how it is gathered: it
+    expects a vector of indices in `index_vecs` to act as the source of
+    indices for that dimension, and an AffineMap in `index_maps` describing
+    how this index vector is indexed. For example, for the following gather:
+
+ ```
+ slice[i, j, k] = source[i][indices0[i] + offset][indices1[j, k]]
+ ```
+
+ The indexing would be described by:
+
+ ```
+ indices = 0, %offset, 0
+ gather_dims = [1, 2]
+ index_vecs = %index_vec1, %index_vec2
+ index_maps = [
+      affine_map<(i, j, k) -> (i)>,
+      affine_map<(i, j, k) -> (j, k)>
+ ]
+ ```
+
+    With these additional parameters, the operation can define a supervector
+    read from a non-contiguous slice. For example:
+
+ ```
+ source: memref<8192x8x16xf32>
+    indices0 : vector<16xindex>
+    indices1 : vector<2x8xindex>
+
+ slice[i, j, k] = source[indices0[k]][j][indices1[i, j]]
+ vector = read(slice) : memref<8192x8x16xf32> -> vector<2x8x16xf32>
+ ```
+
+ Can be represented by:
+
+ ```
+    %vector = vector.transfer_gather %source[%c0, %c0, %c0](%indices0, %indices1) {
+ gather_dims = [0, 2],
+ index_maps = [
+ affine_map<(i, j, k) -> (k)>,
+ affine_map<(i, j, k) -> (i, j)>
+ ],
+ in_bounds = [true, true, true],
+ permutation_map = affine_map<(i, j, k) -> (i, j, k)>
+    } : memref<8192x8x16xf32>, vector<2x8x16xf32>, vector<16xindex>, vector<2x8xindex>
+ ```
+ }];
+
+ let builders = [
+ /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "ValueRange":$index_vecs,
+ "ArrayAttr":$gather_dims,
+ "ArrayAttr":$index_maps,
+ "AffineMapAttr":$permutationMapAttr,
+ "ArrayAttr":$inBoundsAttr)>,
+ /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "ValueRange":$index_vecs,
+ "ArrayRef<int64_t>":$gather_dims,
+ "ArrayRef<AffineMap>":$index_maps,
+ "AffineMap":$permutationMap,
+ CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+ /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "ValueRange":$index_vecs,
+ "ArrayRef<int64_t>":$gather_dims,
+ "ArrayRef<AffineMap>":$index_maps,
+ "Value":$padding,
+ CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+ /// 4. Builder that sets padding to zero and permutation map to
+ /// 'getMinorIdentityMap'.
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "ValueRange":$index_vecs,
+ "ArrayRef<int64_t>":$gather_dims,
+ "ArrayRef<AffineMap>":$index_maps,
+ CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+ ];
+
+ let extraClassDeclaration = [{
+ // MaskableOpInterface methods.
+ bool supportsPassthru() { return true; }
+
+ MutableOperandRange getDpsInitsMutable() {
+ return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
+ }
+
+ SmallVector<int64_t> getGatherDimsArray() {
+ return llvm::map_to_vector(getGatherDims().getAsValueRange<IntegerAttr>(),
+ [](APInt dim) { return dim.getSExtValue(); });
+ }
+
+ SmallVector<AffineMap> getIndexMapsArray() {
+ return llvm::to_vector(getIndexMaps().getAsValueRange<AffineMapAttr>());
+ }
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
def Vector_LoadOp : Vector_Op<"load"> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ff323983a17c0..51272d8ac4209 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -4026,64 +4027,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
}
//===----------------------------------------------------------------------===//
-// TransferReadOp
+// VectorTransferOpInterface
//===----------------------------------------------------------------------===//
-/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices, AffineMapAttr permutationMapAttr,
- /*optional*/ ArrayAttr inBoundsAttr) {
- Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
- build(builder, result, vectorType, source, indices, permutationMapAttr,
- padding, /*mask=*/Value(), inBoundsAttr);
-}
-
-/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices, AffineMap permutationMap,
- std::optional<ArrayRef<bool>> inBounds) {
- auto permutationMapAttr = AffineMapAttr::get(permutationMap);
- auto inBoundsAttr = (inBounds && !inBounds.value().empty())
- ? builder.getBoolArrayAttr(inBounds.value())
- : builder.getBoolArrayAttr(
- SmallVector<bool>(vectorType.getRank(), false));
- build(builder, result, vectorType, source, indices, permutationMapAttr,
- inBoundsAttr);
-}
-
-/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices, Value padding,
- std::optional<ArrayRef<bool>> inBounds) {
- AffineMap permutationMap = getTransferMinorIdentityMap(
- llvm::cast<ShapedType>(source.getType()), vectorType);
- auto permutationMapAttr = AffineMapAttr::get(permutationMap);
- auto inBoundsAttr = (inBounds && !inBounds.value().empty())
- ? builder.getBoolArrayAttr(inBounds.value())
- : builder.getBoolArrayAttr(
- SmallVector<bool>(vectorType.getRank(), false));
- build(builder, result, vectorType, source, indices, permutationMapAttr,
- padding,
- /*mask=*/Value(), inBoundsAttr);
-}
-
-/// 4. Builder that sets padding to zero and permutation map to
-/// 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices,
- std::optional<ArrayRef<bool>> inBounds) {
- Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
- build(builder, result, vectorType, source, indices, padding, inBounds);
-}
-
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
@@ -4204,14 +4150,6 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
-void TransferReadOp::print(OpAsmPrinter &p) {
- p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
- if (getMask())
- p << ", " << getMask();
- printTransferAttrs(p, *this);
- p << " : " << getShapedType() << ", " << getVectorType();
-}
-
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
@@ -4230,6 +4168,158 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
return VectorType::get(maskShape, i1Type, scalableDims);
}
+template <typename TransferOp>
+static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
+ // TODO: support more aggressive createOrFold on:
+ // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
+ if (op.getShapedType().isDynamicDim(indicesIdx))
+ return false;
+ Value index = op.getIndices()[indicesIdx];
+ std::optional<int64_t> cstOp = getConstantIntValue(index);
+ if (!cstOp.has_value())
+ return false;
+
+ int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
+ int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
+
+ return cstOp.value() + vectorSize <= sourceSize;
+}
+
+template <typename TransferOp>
+static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
+ // TODO: support 0-d corner case.
+ // TODO: Be less conservative.
+ if (op.getTransferRank() == 0)
+ return failure();
+ AffineMap permutationMap = op.getPermutationMap();
+ bool changed = false;
+ SmallVector<bool, 4> newInBounds;
+ newInBounds.reserve(op.getTransferRank());
+ // Idxs of non-bcast dims - used when analysing bcast dims.
+ SmallVector<unsigned> nonBcastDims;
+
+ // 1. Process non-broadcast dims
+ for (unsigned i = 0; i < op.getTransferRank(); ++i) {
+ // 1.1. Already marked as in-bounds, nothing to see here.
+ if (op.isDimInBounds(i)) {
+ newInBounds.push_back(true);
+ continue;
+ }
+ // 1.2. Currently out-of-bounds, check whether we can statically determine
+ // it is inBounds.
+ bool inBounds = false;
+ auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
+ if (dimExpr) {
+ inBounds = isInBounds(op, /*resultIdx=*/i,
+ /*indicesIdx=*/dimExpr.getPosition());
+ nonBcastDims.push_back(i);
+ }
+
+ newInBounds.push_back(inBounds);
+ // We commit the pattern if it is "more inbounds".
+ changed |= inBounds;
+ }
+
+ // 2. Handle broadcast dims
+ // If all non-broadcast dims are "in bounds", then all bcast dims should be
+ // "in bounds" as well.
+ bool allNonBcastDimsInBounds = llvm::all_of(
+ nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
+ if (allNonBcastDimsInBounds) {
+ for (size_t idx : permutationMap.getBroadcastDims()) {
+ changed |= !newInBounds[idx];
+ newInBounds[idx] = true;
+ }
+ }
+
+ if (!changed)
+ return failure();
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
+ OpBuilder b(op.getContext());
+ op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
+ return success();
+}
+
+template <typename TransferOp>
+static LogicalResult foldTransferFullMask(TransferOp op) {
+ auto mask = op.getMask();
+ if (!mask)
+ return failure();
+
+ if (getMaskFormat(mask) != MaskFormat::AllTrue)
+ return failure();
+
+ op.getMaskMutable().clear();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TransferReadOp
+//===----------------------------------------------------------------------===//
+
+/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, AffineMapAttr permutationMapAttr,
+ /*optional*/ ArrayAttr inBoundsAttr) {
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ Value padding = builder.create<arith::ConstantOp>(
+ result.location, elemType, builder.getZeroAttr(elemType));
+ build(builder, result, vectorType, source, indices, permutationMapAttr,
+ padding, /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, AffineMap permutationMap,
+ std::optional<ArrayRef<bool>> inBounds) {
+ auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+ auto inBoundsAttr = (inBounds && !inBounds.value().empty())
+ ? builder.getBoolArrayAttr(inBounds.value())
+ : builder.getBoolArrayAttr(
+ SmallVector<bool>(vectorType.getRank(), false));
+ build(builder, result, vectorType, source, indices, permutationMapAttr,
+ inBoundsAttr);
+}
+
+/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, Value padding,
+ std::optional<ArrayRef<bool>> inBounds) {
+ AffineMap permutationMap = getTransferMinorIdentityMap(
+ llvm::cast<ShapedType>(source.getType()), vectorType);
+ auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+ auto inBoundsAttr = (inBounds && !inBounds.value().empty())
+ ? builder.getBoolArrayAttr(inBounds.value())
+ : builder.getBoolArrayAttr(
+ SmallVector<bool>(vectorType.getRank(), false));
+ build(builder, result, vectorType, source, indices, permutationMapAttr,
+ padding,
+ /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 4. Builder that sets padding to zero and permutation map to
+/// 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices,
+ std::optional<ArrayRef<bool>> inBounds) {
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ Value padding = builder.create<arith::ConstantOp>(
+ result.location, elemType, builder.getZeroAttr(elemType));
+ build(builder, result, vectorType, source, indices, padding, inBounds);
+}
+
+void TransferReadOp::print(OpAsmPrinter &p) {
+ p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
+ if (getMask())
+ p << ", " << getMask();
+ printTransferAttrs(p, *this);
+ p << " : " << getShapedType() << ", " << getVectorType();
+}
+
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
SMLoc typesLoc;
@@ -4354,115 +4444,30 @@ Type TransferReadOp::getExpectedMaskType() {
return inferTransferOpMaskType(getVectorType(), getPermutationMap());
}
-template <typename TransferOp>
-static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
- // TODO: support more aggressive createOrFold on:
- // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
- if (op.getShapedType().isDynamicDim(indicesIdx))
- return false;
- Value index = op.getIndices()[indicesIdx];
- std::optional<int64_t> cstOp = getConstantIntValue(index);
- if (!cstOp.has_value())
- return false;
-
- int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
- int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
-
- return cstOp.value() + vectorSize <= sourceSize;
-}
-
-template <typename TransferOp>
-static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
- // TODO: support 0-d corner case.
- // TODO: Be less conservative.
- if (op.getTransferRank() == 0)
- return failure();
- AffineMap permutationMap = op.getPermutationMap();
- bool changed = false;
- SmallVector<bool, 4> newInBounds;
- newInBounds.reserve(op.getTransferRank());
- // Idxs of non-bcast dims - used when analysing bcast dims.
- SmallVector<unsigned> nonBcastDims;
-
- // 1. Process non-broadcast dims
- for (unsigned i = 0; i < op.getTransferRank(); ++i) {
- // 1.1. Already marked as in-bounds, nothing to see here.
- if (op.isDimInBounds(i)) {
- newInBounds.push_back(true);
- continue;
- }
- // 1.2. Currently out-of-bounds, check whether we can statically determine
- // it is inBounds.
- bool inBounds = false;
- auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
- if (dimExpr) {
- inBounds = isInBounds(op, /*resultIdx=*/i,
- /*indicesIdx=*/dimExpr.getPosition());
- nonBcastDims.push_back(i);
- }
-
- newInBounds.push_back(inBounds);
- // We commit the pattern if it is "more inbounds".
- changed |= inBounds;
- }
-
- // 2. Handle broadcast dims
- // If all non-broadcast dims are "in bounds", then all bcast dims should be
- // "in bounds" as well.
- bool allNonBcastDimsInBounds = llvm::all_of(
- nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
- if (allNonBcastDimsInBounds) {
- for (size_t idx : permutationMap.getBroadcastDims()) {
- changed |= !newInBounds[idx];
- newInBounds[idx] = true;
- }
- }
-
- if (!changed)
- return failure();
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
- OpBuilder b(op.getContext());
- op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
- return success();
-}
-
-template <typename TransferOp>
-static LogicalResult foldTransferFullMask(TransferOp op) {
- auto mask = op.getMask();
- if (!mask)
- return failure();
-
- if (getMaskFormat(mask) != MaskFormat::AllTrue)
- return failure();
-
- op.getMaskMutable().clear();
- return success();
-}
-
-/// ```
-/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
-/// : vector<1x4xf32>, tensor<4x4xf32>
-/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
-/// : tensor<4x4xf32>, vector<1x4xf32>
-/// ```
-/// -> Folds into
-/// ```
-/// %v0
-/// ```
-static Value foldRAW(TransferReadOp readOp) {
- if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
- return {};
- auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
- while (defWrite) {
- if (checkSameValueRAW(defWrite, readOp))
- return defWrite.getVector();
- if (!isDisjointTransferIndices(
- cast<VectorTransferOpInterface>(defWrite.getOperation()),
- cast<VectorTransferOpInterface>(readOp.getOperation())))
- break;
- defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
- }
- return {};
+/// ```
+/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
+/// : vector<1x4xf32>, tensor<4x4xf32>
+/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
+/// : tensor<4x4xf32>, vector<1x4xf32>
+/// ```
+/// -> Folds into
+/// ```
+/// %v0
+/// ```
+static Value foldRAW(TransferReadOp readOp) {
+ if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
+ return {};
+ auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+ while (defWrite) {
+ if (checkSameValueRAW(defWrite, readOp))
+ return defWrite.getVector();
+ if (!isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(defWrite.getOperation()),
+ cast<VectorTransferOpInterface>(readOp.getOperation())))
+ break;
+ defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
+ }
+ return {};
}
OpFoldResult TransferReadOp::fold(FoldAdaptor) {
@@ -4589,6 +4594,428 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<TransferReadAfterWriteToBroadcast>(context);
}
+//===----------------------------------------------------------------------===//
+// TransferGatherOp
+//===----------------------------------------------------------------------===//
+
+/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, ValueRange indexVecs,
+ ArrayAttr gatherDims, ArrayAttr indexMaps,
+ AffineMapAttr permutationMapAttr,
+ /*optional*/ ArrayAttr inBoundsAttr) {
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ Value padding = builder.create<arith::ConstantOp>(
+ result.location, elemType, builder.getZeroAttr(elemType));
+ build(builder, result, vectorType, source, indices, indexVecs, gatherDims,
+ indexMaps, permutationMapAttr, padding, /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, ValueRange indexVecs,
+ ArrayRef<int64_t> gatherDims,
+ ArrayRef<AffineMap> indexMaps,
+ AffineMap permutationMap,
+ std::optional<ArrayRef<bool>> inBounds) {
+ auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+ auto indexedAttr = builder.getI64ArrayAttr(gatherDims);
+ auto indexMapsAttr = builder.getAffineMapArrayAttr(indexMaps);
+ auto inBoundsAttr = (inBounds && !inBounds.value().empty())
+ ? builder.getBoolArrayAttr(inBounds.value())
+ : builder.getBoolArrayAttr(
+ SmallVector<bool>(vectorType.getRank(), false));
+ build(builder, result, vectorType, source, indices, indexVecs, indexedAttr,
+ indexMapsAttr, permutationMapAttr, inBoundsAttr);
+}
+
+/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, ValueRange indexVecs,
+ ArrayRef<int64_t> gatherDims,
+ ArrayRef<AffineMap> indexMaps, Value padding,
+ std::optional<ArrayRef<bool>> inBounds) {
+ AffineMap permutationMap = getTransferMinorIdentityMap(
+ llvm::cast<ShapedType>(source.getType()), vectorType);
+ auto inBoundsAttr = (inBounds && !inBounds.value().empty())
+ ? builder.getBoolArrayAttr(inBounds.value())
+ : builder.getBoolArrayAttr(
+ SmallVector<bool>(vectorType.getRank(), false));
+ build(builder, result, vectorType, source, indices, indexVecs,
+ builder.getI64ArrayAttr(gatherDims),
+ builder.getAffineMapArrayAttr(indexMaps), permutationMap, padding,
+ /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 4. Builder that sets padding to zero and permutation map to
+/// 'getMinorIdentityMap'.
+void TransferGatherOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, ValueRange indexVecs,
+ ArrayRef<int64_t> gatherDims,
+ ArrayRef<AffineMap> indexMaps,
+ std::optional<ArrayRef<bool>> inBounds) {
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ Value padding = builder.create<arith::ConstantOp>(
+ result.location, elemType, builder.getZeroAttr(elemType));
+ build(builder, result, vectorType, source, indices, indexVecs, gatherDims,
+ indexMaps, padding, inBounds);
+}
+
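+/// Custom assembly:
+///   vector.transfer_gather `source` `[` indices `]` `(` index_vecs `)` `,`
+///     padding (`,` mask)? attr-dict `:` source-type `,` vector-type
+///     (`,` index-vec-type)*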
+ParseResult TransferGatherOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ SMLoc typesLoc;
+ OpAsmParser::UnresolvedOperand sourceInfo;
+ SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
+ SmallVector<OpAsmParser::UnresolvedOperand, 8> indexVecInfo;
+ OpAsmParser::UnresolvedOperand paddingInfo;
+ SmallVector<Type, 2> types;
+ OpAsmParser::UnresolvedOperand maskInfo;
+ // Parsing with support for paddingValue.
+ if (parser.parseOperand(sourceInfo) ||
+ parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+ parser.parseOperandList(indexVecInfo, OpAsmParser::Delimiter::Paren) ||
+ parser.parseComma() || parser.parseOperand(paddingInfo))
+ return failure();
+
+ ParseResult hasMask = parser.parseOptionalComma();
+ if (hasMask.succeeded()) {
+ if (parser.parseOperand(maskInfo))
+ return failure();
+ }
+
+ // Parse attributes and types.
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
+ return failure();
+
+ // Check if number of types given are correct.
+ size_t nRequiredTypes = indexVecInfo.size() + 2;
+ if (types.size() != nRequiredTypes) {
+ return parser.emitError(typesLoc, "expected ")
+ << nRequiredTypes << " types";
+ }
+
+ // The types are arranged as:
+ // sourceTy, resultTy, *indexVecTy
+ auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
+ VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
+ ArrayRef<Type> indexVecTy(types.begin() + 2,
+ types.begin() + indexVecInfo.size() + 2);
+ if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
+ return parser.emitError(typesLoc, "requires memref or ranked tensor type");
+ if (!vectorType)
+ return parser.emitError(typesLoc, "requires vector type");
+ auto permMapAttrName =
+ TransferGatherOp::getPermutationMapAttrName(result.name);
+ Attribute permMapAttr = result.attributes.get(permMapAttrName);
+ AffineMap permMap;
+ if (!permMapAttr) {
+ permMap = vector::getTransferMinorIdentityMap(shapedType, vectorType);
+ result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
+ } else {
+ permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
+ }
+ auto inBoundsAttrName = TransferGatherOp::getInBoundsAttrName(result.name);
+ Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
+ if (!inBoundsAttr) {
+ result.addAttribute(inBoundsAttrName,
+ builder.getBoolArrayAttr(
+ SmallVector<bool>(permMap.getNumResults(), false)));
+ }
+ if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+ parser.resolveOperands(indexInfo, builder.getIndexType(),
+ result.operands) ||
+ parser.resolveOperands(indexVecInfo, indexVecTy, typesLoc,
+ result.operands) ||
+ parser.resolveOperand(paddingInfo, shapedType.getElementType(),
+ result.operands))
+ return failure();
+ if (hasMask.succeeded()) {
+ if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
+ return parser.emitError(
+ maskInfo.location, "does not support masks with vector element type");
+ if (vectorType.getRank() != permMap.getNumResults()) {
+ return parser.emitError(typesLoc,
+ "expected the same rank for the vector and the "
+ "results of the permutation map");
+ }
+ // Instead of adding the mask type as an op type, compute it based on the
+ // vector type and the permutation map (to keep the type signature small).
+ auto maskType = vector::inferTransferOpMaskType(vectorType, permMap);
+ if (parser.resolveOperand(maskInfo, maskType, result.operands))
+ return failure();
+ }
+ result.addAttribute(TransferGatherOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr(
+ {1, static_cast<int32_t>(indexInfo.size()),
+ static_cast<int32_t>(indexVecInfo.size()), 1,
+ static_cast<int32_t>(hasMask.succeeded())}));
+ return parser.addTypeToList(vectorType, result.types);
+}
+
+LogicalResult TransferGatherOp::verify() {
+ // Consistency of elemental types in source and vector.
+ ShapedType shapedType = getShapedType();
+ VectorType vectorType = getVectorType();
+ VectorType maskType = getMaskType();
+ auto paddingType = getPadding().getType();
+ auto permutationMap = getPermutationMap();
+ VectorType inferredMaskType =
+ maskType ? inferTransferOpMaskType(vectorType, permutationMap)
+ : VectorType();
+ auto sourceElementType = shapedType.getElementType();
+
+ if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
+ return emitOpError("requires ") << shapedType.getRank() << " indices";
+
+ if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
+ shapedType, vectorType, maskType,
+ inferredMaskType, permutationMap, getInBounds())))
+ return failure();
+
+ if (auto sourceVectorElementType =
+ llvm::dyn_cast<VectorType>(sourceElementType)) {
+ // Source has vector element type.
+ // Check that 'sourceVectorElementType' and 'paddingType' types match.
+ if (sourceVectorElementType != paddingType)
+ return emitOpError(
+ "requires source element type and padding type to match.");
+
+ } else {
+ // Check that 'paddingType' is valid to store in a vector type.
+ if (!VectorType::isValidElementType(paddingType))
+ return emitOpError("requires valid padding vector elemental type");
+
+ // Check that padding type and vector element types match.
+ if (paddingType != sourceElementType)
+ return emitOpError(
+ "requires formal padding and source of the same elemental type");
+ }
+
+ return verifyPermutationMap(permutationMap,
+ [&](Twine t) { return emitOpError(t); });
+}
+
+void TransferGatherOp::print(OpAsmPrinter &p) {
+ p << " " << getSource() << "[" << getIndices() << "](" << getIndexVecs()
+ << "), " << getPadding();
+ if (getMask())
+ p << ", " << getMask();
+ printTransferAttrs(p, *this);
+ p << " : " << getShapedType() << ", " << getVectorType() << ", "
+ << getIndexVecs().getTypes();
+}
+
+void TransferGatherOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (llvm::isa<MemRefType>(getShapedType()))
+ effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
+ SideEffects::DefaultResource::get());
+}
+
+Speculation::Speculatability TransferGatherOp::getSpeculatability() {
+ if (hasPureTensorSemantics())
+ return Speculation::Speculatable;
+ return Speculation::NotSpeculatable;
+}
+
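+/// Result of rewriting a single index vector of a transfer_gather: the
+/// (possibly new) index vector, its updated index map, and whether anything
+/// changed. A null `indexVec` signals that the corresponding dimension
+/// became contiguous and its index vector should be dropped.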
+struct IndexVecFoldResult {
+ Value indexVec;
+ AffineMap indexMap;
+ bool changed;
+};
+
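+/// Applies `indexVecFolder` to each (index vector, index map, gather dim)
+/// triple of `gatherOp` and rewrites the op's operands and attributes in
+/// place. Returns the op's result if anything changed, and a null Value
+/// otherwise.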
+static Value foldTransferGatherIndexVecs(
+ TransferGatherOp gatherOp,
+ function_ref<IndexVecFoldResult(Value, AffineMap, int64_t)>
+ indexVecFolder) {
+ bool changed = false;
+ SmallVector<Value> newIndexVecs;
+ SmallVector<AffineMap> newIndexMaps;
+ SmallVector<int64_t> gatherDims;
+  for (auto [operand, map, index] :
+       llvm::zip_equal(gatherOp.getIndexVecs(), gatherOp.getIndexMapsArray(),
+                       gatherOp.getGatherDimsArray())) {
+ auto [indexVec, indexMap, vecChanged] = indexVecFolder(operand, map, index);
+ changed |= vecChanged;
+ if (!indexVec) {
+ continue;
+ }
+ newIndexVecs.push_back(indexVec);
+ newIndexMaps.push_back(indexMap);
+    gatherDims.push_back(index);
+ }
+
+ if (!changed) {
+ return Value();
+ }
+
+ OpBuilder b(gatherOp);
+
+ SmallVector<Value> operands;
+ SmallVector<int32_t> operandSegmentSizes;
+
+ // Source.
+ operands.push_back(gatherOp.getSource());
+ operandSegmentSizes.push_back(1);
+ // Indices.
+ SmallVector<Value> indices = gatherOp.getIndices();
+ operands.append(indices);
+ operandSegmentSizes.push_back(indices.size());
+ // IndexVecs.
+ operands.append(newIndexVecs);
+ operandSegmentSizes.push_back(newIndexVecs.size());
+ // Padding.
+ operands.push_back(gatherOp.getPadding());
+ operandSegmentSizes.push_back(1);
+ // Mask.
+ if (gatherOp.getMask()) {
+ operands.push_back(gatherOp.getMask());
+ operandSegmentSizes.push_back(1);
+ } else {
+ operandSegmentSizes.push_back(0);
+ }
+
+ gatherOp.setIndexMapsAttr(b.getAffineMapArrayAttr(newIndexMaps));
+ gatherOp->setOperands(operands);
+ gatherOp.setGatherDimsAttr(b.getI64ArrayAttr(gatherDims));
+ gatherOp.getProperties().setOperandSegmentSizes(operandSegmentSizes);
+
+ return gatherOp.getResult();
+}
+
+static int64_t getVectorRank(Type type) {
+ return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
+ : 0;
+}
+
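+/// Fold a vector.broadcast feeding an index vector by narrowing the index
+/// map to the trailing dimensions coming from the broadcast source.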
+static Value foldTransferGatherFromBroadcast(TransferGatherOp gatherOp) {
+ return foldTransferGatherIndexVecs(
+ gatherOp,
+ [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
+ auto broadcast = operand.getDefiningOp<vector::BroadcastOp>();
+ if (!broadcast) {
+ return {operand, map, false};
+ }
+
+ int64_t sourceRank = getVectorRank(broadcast.getSourceType());
+ int64_t operandRank = getVectorRank(broadcast.getResultVectorType());
+ AffineMap newMap =
+ map.getSliceMap(operandRank - sourceRank, sourceRank);
+ return {broadcast.getSource(), newMap, true};
+ });
+}
+
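+/// Fold a vector.transpose feeding an index vector by composing the inverse
+/// permutation into its index map.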
+static Value foldTransferGatherFromTranspose(TransferGatherOp gatherOp) {
+ return foldTransferGatherIndexVecs(
+ gatherOp,
+ [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
+ auto transpose = operand.getDefiningOp<vector::TransposeOp>();
+ if (!transpose) {
+ return {operand, map, false};
+ }
+
+ AffineMap newMap = map.compose(AffineMap::getPermutationMap(
+ invertPermutationVector(transpose.getPermutation()),
+ transpose.getContext()));
+ return {transpose.getVector(), newMap, true};
+ });
+}
+
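+/// A vector.step index vector enumerates the dimension it gathers, so that
+/// dimension is effectively contiguous; drop its index vector.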
+static Value foldTransferGatherFromStep(TransferGatherOp gatherOp) {
+ return foldTransferGatherIndexVecs(
+ gatherOp,
+ [](Value operand, AffineMap map, int64_t) -> IndexVecFoldResult {
+ auto step = operand.getDefiningOp<vector::StepOp>();
+ if (!step) {
+ return {operand, map, false};
+ }
+
+ return {Value(), AffineMap(), true};
+ });
+}
+
+OpFoldResult TransferGatherOp::fold(FoldAdaptor adaptor) {
+ if (auto res = foldTransferGatherFromBroadcast(*this)) {
+ return res;
+ }
+ if (auto res = foldTransferGatherFromTranspose(*this)) {
+ return res;
+ }
+ if (auto res = foldTransferGatherFromStep(*this)) {
+ return res;
+ }
+ return OpFoldResult();
+}
+
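+/// Fold single-element index vectors into the corresponding base index:
+/// extract the scalar, add it to the base, and drop the index vector.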
+struct FoldSingleElementIndexVec : public OpRewritePattern<TransferGatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransferGatherOp xferOp,
+ PatternRewriter &rewriter) const override {
+
+ auto indexVecFolder = [&](Value indexVec, AffineMap map,
+ int64_t index) -> IndexVecFoldResult {
+ auto vectorTy = cast<VectorType>(indexVec.getType());
+ if (vectorTy.getNumElements() != 1) {
+ return {indexVec, map, false};
+ }
+
+      // Extract the scalar and add it to the corresponding base index.
+ OpOperand &base = xferOp.getIndicesMutable()[index];
+ Value extracted = rewriter.create<vector::ExtractOp>(
+ xferOp.getLoc(), indexVec,
+ SmallVector<int64_t>(vectorTy.getRank(), 0));
+ Value newIndex = rewriter.create<arith::AddIOp>(indexVec.getLoc(),
+ base.get(), extracted);
+ base.set(newIndex);
+
+ return {Value(), AffineMap(), true};
+ };
+
+ Value newVal = foldTransferGatherIndexVecs(xferOp, indexVecFolder);
+
+ if (!newVal) {
+ return failure();
+ }
+
+ return success();
+ }
+};
+
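+/// A transfer_gather with no index vectors is a contiguous read;
+/// canonicalize it to a vector.transfer_read.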
+struct FoldContiguousGatherToTransferRead
+ : public OpRewritePattern<TransferGatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TransferGatherOp xferOp,
+ PatternRewriter &rewriter) const override {
+ if (!xferOp.getIndexVecs().empty()) {
+ return failure();
+ }
+
+ // Canonicalize to vector.transfer_read.
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ xferOp, xferOp.getVectorType(), xferOp.getSource(), xferOp.getIndices(),
+ xferOp.getPermutationMap(), xferOp.getPadding(), xferOp.getMask(),
+ xferOp.getInBounds());
+ return success();
+  }
+};
+
+void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *ctx) {
+  results.add<FoldSingleElementIndexVec, FoldContiguousGatherToTransferRead>(
+      ctx);
+}
+
//===----------------------------------------------------------------------===//
// TransferWriteOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..3ec6a239f589d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1075,6 +1075,81 @@ func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vecto
// -----
+func.func @fold_vector_transfer_gather_broadcast(%A: memref<?x?xf32>, %idx: vector<4xindex>) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK-NOT: vector.broadcast
+ %bidx = vector.broadcast %idx : vector<4xindex> to vector<4x4xindex>
+
+ // CHECK: vector.transfer_gather
+ // CHECK-SAME: memref<?x?xf32>, vector<4x4xf32>, vector<4xindex>
+ %1 = vector.transfer_gather %A[%c0, %c0](%bidx), %f0
+ { gather_dims = [0], index_maps = [affine_map<(d0, d1) -> (d0, d1)>] }
+ : memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+
+ return %1 : vector<4x4xf32>
+}
+
+func.func @fold_vector_transfer_gather_transpose(%A: memref<?x?xf32>, %idx: vector<4x4xindex>) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK-NOT: vector.transpose
+ %tidx = vector.transpose %idx, [1, 0] : vector<4x4xindex> to vector<4x4xindex>
+
+ // CHECK: vector.transfer_gather
+ // CHECK-SAME: memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+ %1 = vector.transfer_gather %A[%c0, %c0](%tidx), %f0
+ { gather_dims = [0], index_maps = [affine_map<(d0, d1) -> (d0, d1)>] }
+ : memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+
+ return %1 : vector<4x4xf32>
+}
+
+func.func @fold_vector_transfer_gather_step(%A: memref<?x?xf32>, %idx : vector<4x4xindex>) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK-NOT: vector.step
+ %sidx = vector.step : vector<4xindex>
+
+ // CHECK: vector.transfer_gather
+ // CHECK-SAME: memref<?x?xf32>, vector<4x4xf32>, vector<4x4xindex>
+ %1 = vector.transfer_gather %A[%c0, %c0](%sidx, %idx), %f0
+ { gather_dims = [0, 1], index_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>] }
+ : memref<?x?xf32>, vector<4x4xf32>, vector<4xindex>, vector<4x4xindex>
+
+ return %1 : vector<4x4xf32>
+}
+
+func.func @fold_vector_transfer_gather_single_element(%A: memref<?x?xf32>, %idx : vector<1x1xindex>) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK: vector.transfer_read
+ %1 = vector.transfer_gather %A[%c0, %c0](%idx), %f0
+ { gather_dims = [0], index_maps = [affine_map<(d0, d1) -> (0, 0)>] }
+ : memref<?x?xf32>, vector<4x4xf32>, vector<1x1xindex>
+
+ return %1 : vector<4x4xf32>
+}
+
+func.func @fold_vector_transfer_gather_contiguous(%A: memref<?x?xf32>) -> vector<4x4xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK-NOT: vector.transfer_gather
+ // CHECK: vector.transfer_read
+ %1 = vector.transfer_gather %A[%c0, %c0](), %f0
+ { gather_dims = [], index_maps = [] }
+ : memref<?x?xf32>, vector<4x4xf32>
+
+ return %1 : vector<4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: bitcast_folding
// CHECK-SAME: %[[A:.*]]: vector<4x8xf32>
// CHECK-SAME: %[[B:.*]]: vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..b1ba8f2141397 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -148,6 +148,24 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
tensor<?x?xvector<4x3xindex>>
}
+// CHECK-LABEL: @vector_transfer_gather_memref
+func.func @vector_transfer_gather_memref(%arg0 : memref<?x?xf32>,
+ %idx0 : vector<128xindex>,
+ %idx1 : vector<7xindex>,
+ %idx2 : vector<3x7xindex>) ->
+ (vector<128xf32>, vector<3x7xf32>, vector<3x7xf32>) {
+ %c3 = arith.constant 3 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK: vector.transfer_gather
+ %0 = vector.transfer_gather %arg0[%c3, %c3](%idx0), %f0 {gather_dims = [0], index_maps = [affine_map<(d0, d1)->(d0)>], permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>, vector<128xindex>
+ // CHECK: vector.transfer_gather
+ %1 = vector.transfer_gather %arg0[%c3, %c3](%idx1), %f0 {gather_dims = [1], index_maps = [affine_map<(d0, d1)->(d0)>], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : memref<?x?xf32>, vector<3x7xf32>, vector<7xindex>
+ // CHECK: vector.transfer_gather
+ %2 = vector.transfer_gather %arg0[%c3, %c3](%idx1, %idx2), %f0 {gather_dims = [0, 1], index_maps = [affine_map<(d0, d1)->(d0)>, affine_map<(d0, d1)->(d1, d0)>], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : memref<?x?xf32>, vector<3x7xf32>, vector<7xindex>, vector<3x7xindex>
+ return %0, %1, %2 : vector<128xf32>, vector<3x7xf32>, vector<3x7xf32>
+}
+
// CHECK-LABEL: @vector_broadcast
func.func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>