[Mlir-commits] [mlir] c537a94 - [mlir][Vector] Thread 0-d vectors through vector.transfer ops
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Dec 1 08:49:48 PST 2021
Author: Nicolas Vasilache
Date: 2021-12-01T16:49:43Z
New Revision: c537a943342be66d0876c6440a2df317b572c092
URL: https://github.com/llvm/llvm-project/commit/c537a943342be66d0876c6440a2df317b572c092
DIFF: https://github.com/llvm/llvm-project/commit/c537a943342be66d0876c6440a2df317b572c092.diff
LOG: [mlir][Vector] Thread 0-d vectors through vector.transfer ops
This revision adds 0-d vector support to vector.transfer ops.
In the process, numerous cleanups are applied, in particular around normalizing
and reducing the number of builders.
Reviewed By: ThomasRaoux, springerm
Differential Revision: https://reviews.llvm.org/D114803
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Interfaces/VectorInterfaces.cpp
mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 1bab07e77325c..8eaf785319578 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1133,26 +1133,28 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
]>,
- Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
- AffineMapAttr:$permutation_map, AnyType:$padding,
- Optional<VectorOf<[I1]>>:$mask,
- OptionalAttr<BoolArrayAttr>:$in_bounds)>,
- Results<(outs AnyVector:$vector)> {
+ Arguments<(ins AnyShaped:$source,
+ Variadic<Index>:$indices,
+ AffineMapAttr:$permutation_map,
+ AnyType:$padding,
+ Optional<VectorOf<[I1]>>:$mask,
+ OptionalAttr<BoolArrayAttr>:$in_bounds)>,
+ Results<(outs AnyVectorOfAnyRank:$vector)> {
let summary = "Reads a supervector from memory into an SSA vector value.";
let description = [{
The `vector.transfer_read` op performs a read from a slice within a
[MemRef](../LangRef.md#memref-type) or a Ranked
- [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a
- [vector](../LangRef.md#vector-type) of the same base elemental type.
+ [Tensor](../LangRef.md#tensor-type) supplied as its first operand
+ into a [vector](../LangRef.md#vector-type) of the same base elemental type.
A memref/tensor operand with vector element type, must have its vector
element type match a suffix (shape and element type) of the vector (e.g.
memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
The slice is further defined by a full-rank index within the MemRef/Tensor,
- supplied as the operands `2 .. 1 + rank(memref/tensor)`.
+ supplied as the operands `[1 .. 1 + rank(memref/tensor))`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1301,39 +1303,31 @@ def Vector_TransferReadOp :
}];
let builders = [
- // Builder that sets padding to zero.
- OpBuilder<(ins "VectorType":$vector, "Value":$source,
- "ValueRange":$indices, "AffineMap":$permutationMap,
- CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- // Builder that sets permutation map to 'getMinorIdentityMap'.
- OpBuilder<(ins "VectorType":$vector, "Value":$source,
- "ValueRange":$indices, "Value":$padding,
- CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- // Builder that sets permutation map (resp. padding) to
- // 'getMinorIdentityMap' (resp. zero).
- OpBuilder<(ins "VectorType":$vector, "Value":$source,
- "ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- // Builder that does not set mask.
- OpBuilder<(ins "Type":$vector, "Value":$source,
- "ValueRange":$indices, "AffineMapAttr":$permutationMap, "Value":$padding,
- "ArrayAttr":$inBounds)>,
- // Builder that does not set mask.
- OpBuilder<(ins "Type":$vector, "Value":$source,
- "ValueRange":$indices, "AffineMap":$permutationMap, "Value":$padding,
- "ArrayAttr":$inBounds)>
+ /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "AffineMapAttr":$permutationMapAttr,
+ "ArrayAttr":$inBoundsAttr)>,
+ /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "AffineMap":$permutationMap,
+ CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
+ /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ "Value":$padding,
+ CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
+ /// 4. Builder that sets padding to zero and permutation map to
+ /// 'getMinorIdentityMap'.
+ OpBuilder<(ins "VectorType":$vectorType,
+ "Value":$source,
+ "ValueRange":$indices,
+ CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
];
-
- let extraClassDeclaration = [{
- /// Temporary convenience builders to account for the fact that we do not
- /// have 0-d vectors atm. These create a constant `vector<1xt>` and
- /// insert/extract into it.
- // Builder that sets permutation map (resp. padding) to
- // 'getMinorIdentityMap' (resp. zero).
- static Value createScalarOp(OpBuilder &builder, Location loc, Value source,
- ValueRange indices,
- ArrayRef<bool> inBounds = ArrayRef<bool>{});
- }];
-
let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -1345,11 +1339,12 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments
]>,
- Arguments<(ins AnyVector:$vector, AnyShaped:$source,
- Variadic<Index>:$indices,
- AffineMapAttr:$permutation_map,
- Optional<VectorOf<[I1]>>:$mask,
- OptionalAttr<BoolArrayAttr>:$in_bounds)>,
+ Arguments<(ins AnyVectorOfAnyRank:$vector,
+ AnyShaped:$source,
+ Variadic<Index>:$indices,
+ AffineMapAttr:$permutation_map,
+ Optional<VectorOf<[I1]>>:$mask,
+ OptionalAttr<BoolArrayAttr>:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
let summary = "The vector.transfer_write op writes a supervector to memory.";
@@ -1367,7 +1362,7 @@ def Vector_TransferWriteOp :
new tensor of the same type.
The slice is further defined by a full-rank index within the MemRef/Tensor,
- supplied as the operands `3 .. 2 + rank(memref/tensor)`.
+ supplied as the operands `[2 .. 2 + rank(memref/tensor))`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1444,32 +1439,32 @@ def Vector_TransferWriteOp :
}];
let builders = [
- // Builder that sets an empty mask.
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
- "AffineMap":$permutationMap, CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- // Builder that sets permutation map to 'getMinorIdentityMap'.
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
- CArg<"ArrayRef<bool>", "{}">:$inBounds)>,
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
- "AffineMapAttr":$permutationMap, "ArrayAttr":$inBounds)>,
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
- "AffineMap":$permutationMap, "Value":$mask, "ArrayAttr":$inBounds)>,
- OpBuilder<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
- "AffineMap":$permutationMap, "ArrayAttr":$inBounds)>,
+ /// 1. Builder with type inference.
+ OpBuilder<(ins "Value":$vector,
+ "Value":$dest,
+ "ValueRange":$indices,
+ "AffineMapAttr":$permutationMapAttr,
+ "Value":$mask,
+ "ArrayAttr":$inBoundsAttr)>,
+ /// 2. Builder with type inference that sets an empty mask (variant with attrs).
+ OpBuilder<(ins "Value":$vector,
+ "Value":$dest,
+ "ValueRange":$indices,
+ "AffineMapAttr":$permutationMapAttr,
+ "ArrayAttr":$inBoundsAttr)>,
+ /// 3. Builder with type inference that sets an empty mask (variant without attrs).
+ OpBuilder<(ins "Value":$vector,
+ "Value":$dest,
+ "ValueRange":$indices,
+ "AffineMap":$permutationMap,
+ CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
+ /// 4. Builder with type inference that sets an empty mask and sets permutation
+ /// map to 'getMinorIdentityMap'.
+ OpBuilder<(ins "Value":$vector,
+ "Value":$dest,
+ "ValueRange":$indices,
+ CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
];
-
- let extraClassDeclaration = [{
- /// Temporary convenience builders to account for the fact that we do not
- /// have 0-d vectors atm. These create a constant `vector<1xt>` and
- /// insert/extract into it.
- // Builder that sets permutation map (resp. padding) to
- // 'getMinorIdentityMap' (resp. zero).
- static Operation *createScalarOp(
- OpBuilder &builder, Location loc, Value value,
- Value dest, ValueRange indices,
- ArrayRef<bool> inBounds = ArrayRef<bool>{});
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index c713f1806d1f1..68b88860b2ff3 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -114,29 +114,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodBody=*/"return $_op.permutation_map();"
/*defaultImplementation=*/
>,
- InterfaceMethod<
- /*desc=*/[{
- Returns true if op involves a 0-d tensor/memref and a vector
- of shape {1}. This is temporary until we have 0-d vectors.
- // TODO: turn this into 0-d vectors + empty permutation_map.
- }],
- /*retTy=*/"bool",
- /*methodName=*/"isZeroD",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- if (getShapedType().getRank() > 0)
- return false;
- if (getVectorType().getShape() != ArrayRef<int64_t>{1})
- return false;
- AffineMap map = AffineMap::get(
- /*numDims=*/0, /*numSymbols=*/0,
- getAffineConstantExpr(0, $_op->getContext()));
- if ($_op.permutation_map() != map)
- return false;
- return true;
- }]
- >,
InterfaceMethod<
/*desc=*/[{ Returns true if the specified dimension is a broadcast. }],
/*retTy=*/"bool",
@@ -157,10 +134,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- // 0-d transfers are not considered broadcasts but they need to be
- // represented with a vector<1xt> until we have 0-d vectors.
- if ($_op.isZeroD()) return false;
- for (unsigned i = 0; i < $_op.permutation_map().getNumResults(); ++i) {
+ for (unsigned i = 0, rank = getTransferRank(); i < rank; ++i) {
if ($_op.isBroadcastDim(i))
return true;
}
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 84e9cae77dd11..4c50b4f699365 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -92,6 +92,10 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
+ // TODO: support 0-d corner case.
+ if (writeOp.getTransferRank() == 0)
+ return false;
+
if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
writeOp.getVectorType().getRank() != 2)
return false;
@@ -295,6 +299,11 @@ struct CombineTransferReadOpTranspose final
auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
+
+ // TODO: support 0-d corner case.
+ if (transferReadOp.getTransferRank() == 0)
+ return failure();
+
if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
return failure();
SmallVector<int64_t, 2> perm;
@@ -307,8 +316,8 @@ struct CombineTransferReadOpTranspose final
AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
- newMap, transferReadOp.padding(), transferReadOp.mask(),
- transferReadOp.in_boundsAttr());
+ AffineMapAttr::get(newMap), transferReadOp.padding(),
+ transferReadOp.mask(), transferReadOp.in_boundsAttr());
return success();
}
};
@@ -335,6 +344,7 @@ static const char *inferFragType(OpTy op) {
static void convertTransferReadOp(vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
+ assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
assert(transferReadSupportsMMAMatrixType(op));
Optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index cc54b7f8bd2ed..50b4c3a8dd2e7 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -64,6 +64,10 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
LogicalResult
matchAndRewrite(ConcreteOp xferOp, typename ConcreteOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
if (xferOp.getVectorType().getRank() > 1 ||
llvm::size(xferOp.indices()) == 0)
return failure();
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 6d2c91f19bb68..4709bed076377 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -52,6 +52,8 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
+ // TODO: support 0-d corner case.
+ assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
auto map = xferOp.permutation_map();
if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
return expr.getPosition();
@@ -66,6 +68,8 @@ static Optional<int64_t> unpackedDim(OpTy xferOp) {
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
+ // TODO: support 0-d corner case.
+ assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
auto map = xferOp.permutation_map();
return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
b.getContext());
@@ -1081,6 +1085,7 @@ get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
SmallVector<Value, 8> &memrefIndices) {
auto indices = xferOp.indices();
auto map = xferOp.permutation_map();
+ assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
memrefIndices.append(indices.begin(), indices.end());
assert(map.getNumResults() == 1 &&
@@ -1206,6 +1211,9 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
auto map = xferOp.permutation_map();
auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index b10857e13c9c8..a8736bd3ca1c2 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -101,8 +101,7 @@ struct TransferWriteOpInterface
return failure();
b.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
- writeOp.permutation_map(),
- writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+ writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
state.mapBuffer(op->getResult(0), resultBuffer);
return success();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 206daf9c81650..42811da8f59ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -115,8 +115,6 @@ struct VectorizationResult {
/// ShapedType of `v`.
static VectorType extractVectorTypeFromShapedValue(Value v) {
auto st = v.getType().cast<ShapedType>();
- if (st.getShape().empty())
- return VectorType();
return VectorType::get(st.getShape(), st.getElementType());
}
@@ -179,21 +177,6 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
-/// Build a vector.transfer_read from `source` at indices set to all `0`.
-/// If source has rank zero, build a `vector<1xt> transfer_read + extract`.
-/// Return the produced value.
-static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
- AffineMap map) {
- Location loc = source.getLoc();
- auto shapedType = source.getType().cast<ShapedType>();
- SmallVector<Value> indices(shapedType.getRank(),
- b.create<arith::ConstantIndexOp>(loc, 0));
- if (auto vectorType = readType.dyn_cast<VectorType>())
- return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
- map);
- return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
-}
-
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
@@ -226,8 +209,11 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
Operation *write;
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
- if (VectorType vectorType =
- extractVectorTypeFromShapedValue(outputOperand->get())) {
+ ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
+ auto vectorType = VectorType::get(
+ shape, getElementTypeOrSelf(outputOperand->get().getType()));
+ if (vectorType.getRank() > 0) {
+ // 0-d case is still special: do not invert the reindexing map.
AffineMap map =
reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
SmallVector<int64_t> transposeShape =
@@ -240,8 +226,11 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
indices, map);
} else {
- write = vector::TransferWriteOp::createScalarOp(
- b, loc, value, outputOperand->get(), ValueRange{});
+ if (!value.getType().isa<VectorType>())
+ value = b.create<vector::BroadcastOp>(loc, vectorType, value);
+ assert(value.getType() == vectorType && "incorrect type");
+ write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
+ ValueRange{});
}
LDBG("vectorized op: " << *write);
if (!write->getResults().empty())
@@ -515,32 +504,42 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
// 3. Turn all BBArgs into vector.transfer_read / load.
- SmallVector<AffineMap> indexings;
+ Location loc = linalgOp.getLoc();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber());
if (linalgOp.isScalar(opOperand)) {
bvm.map(bbarg, opOperand->get());
continue;
}
- // TODO: 0-d vectors.
- Type readType;
+ VectorType readType;
AffineMap map;
- if (linalgOp.getShape(opOperand).empty()) {
- readType = bbarg.getType();
+ // TODO: can we keep this simplification?
+ // if (linalgOp.getShape(opOperand).empty()) {
+ // readType = VectorType::get({}, bbarg.getType());
+ // } else {
+ if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+ map = inverseAndBroadcastProjectedPermuation(
+ linalgOp.getTiedIndexingMap(opOperand));
+ readType = VectorType::get(commonVectorShape,
+ getElementTypeOrSelf(opOperand->get()));
} else {
- if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
- map = inverseAndBroadcastProjectedPermuation(
- linalgOp.getTiedIndexingMap(opOperand));
- readType = VectorType::get(commonVectorShape,
- getElementTypeOrSelf(opOperand->get()));
- } else {
- map = inversePermutation(
- reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
- readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
- getElementTypeOrSelf(opOperand->get()));
- }
+ map = inversePermutation(
+ reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+ readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+ getElementTypeOrSelf(opOperand->get()));
}
- Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
+ // }
+
+ auto shape = linalgOp.getShape(opOperand);
+ SmallVector<Value> indices(shape.size(), zero);
+ Value readValue = b.create<vector::TransferReadOp>(
+ loc, readType, opOperand->get(), indices, map);
+ // Not all ops support 0-d vectors, extract the scalar for now.
+ // TODO: remove this.
+ if (readValue.getType().cast<VectorType>().getRank() == 0)
+ readValue = b.create<vector::ExtractElementOp>(loc, readValue);
+
LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
bvm.map(bbarg, readValue);
bvm.map(opOperand->get(), readValue);
@@ -752,7 +751,7 @@ struct GenericPadTensorOpVectorizationPattern
rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
- readInBounds);
+ ArrayRef<bool>{readInBounds});
// If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
// tensor, write directly to the FillOp's operand.
@@ -765,7 +764,7 @@ struct GenericPadTensorOpVectorizationPattern
auto writeIndices =
ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- padOp, read, dest, writeIndices, writeInBounds);
+ padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
return success();
}
@@ -878,6 +877,10 @@ struct PadTensorOpVectorizationWithTransferWritePattern
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferWriteOp xferOp) const override {
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
// Low padding must be static 0.
if (!padOp.hasZeroLowPad())
return failure();
@@ -1072,7 +1075,8 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
SmallVector<bool> inBounds(vecRank, true);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- insertOp, read, insertOp.dest(), writeIndices, inBounds);
+ insertOp, read, insertOp.dest(), writeIndices,
+ ArrayRef<bool>{inBounds});
return success();
}
@@ -1266,6 +1270,10 @@ static memref::SubViewOp getSubViewUseIfUnique(Value v) {
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
+ // TODO: support mask.
+ if (xferOp.mask())
+ return failure();
+
// Transfer into `view`.
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
@@ -1328,7 +1336,9 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
// conservatively.
Value res = rewriter.create<vector::TransferReadOp>(
xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
- xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
+ xferOp.permutation_mapAttr(), xferOp.padding(), xferOp.mask(),
+ // in_bounds is explicitly reset
+ /*inBoundsAttr=*/ArrayAttr());
if (maybeFillOp)
rewriter.eraseOp(maybeFillOp);
@@ -1342,6 +1352,10 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
+ // TODO: support mask.
+ if (xferOp.mask())
+ return failure();
+
// Transfer into `viewOrAlloc`.
Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
@@ -1380,7 +1394,9 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
// conservatively.
rewriter.create<vector::TransferWriteOp>(
xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
- xferOp.permutation_map(), ArrayAttr());
+ xferOp.permutation_mapAttr(), xferOp.mask(),
+ // in_bounds is explicitly reset
+ /*inBoundsAttr=*/ArrayAttr());
rewriter.eraseOp(copyOp);
rewriter.eraseOp(xferOp);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 95bab64e63ed2..1feda57d8de03 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -103,9 +103,9 @@ static Value getMemRefOperand(vector::TransferWriteOp op) {
/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operations compute the
/// permutation map to use after the subview is folded with it.
-static AffineMap getPermutationMap(MLIRContext *context,
- memref::SubViewOp subViewOp,
- AffineMap currPermutationMap) {
+static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
+ memref::SubViewOp subViewOp,
+ AffineMap currPermutationMap) {
llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
SmallVector<AffineExpr> exprs;
int64_t sourceRank = subViewOp.getSourceType().getRank();
@@ -115,7 +115,8 @@ static AffineMap getPermutationMap(MLIRContext *context,
exprs.push_back(getAffineDimExpr(dim, context));
}
auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
- return currPermutationMap.compose(resultDimToSourceDimMap);
+ return AffineMapAttr::get(
+ currPermutationMap.compose(resultDimToSourceDimMap));
}
//===----------------------------------------------------------------------===//
@@ -163,13 +164,18 @@ void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
- vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
+ vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+ // TODO: support 0-d corner case.
+ if (transferReadOp.getTransferRank() == 0)
+ return;
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
- getPermutationMap(rewriter.getContext(), subViewOp,
- loadOp.permutation_map()),
- loadOp.padding(), loadOp.in_boundsAttr());
+ transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
+ sourceIndices,
+ getPermutationMapAttr(rewriter.getContext(), subViewOp,
+ transferReadOp.permutation_map()),
+ transferReadOp.padding(),
+ /*mask=*/Value(), transferReadOp.in_boundsAttr());
}
template <>
@@ -184,11 +190,14 @@ template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
+ // TODO: support 0-d corner case.
+ if (transferWriteOp.getTransferRank() == 0)
+ return;
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
sourceIndices,
- getPermutationMap(rewriter.getContext(), subViewOp,
- transferWriteOp.permutation_map()),
+ getPermutationMapAttr(rewriter.getContext(), subViewOp,
+ transferWriteOp.permutation_map()),
transferWriteOp.in_boundsAttr());
}
} // namespace
diff --git a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
index f00a4f808c695..9739c7c792e63 100644
--- a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
@@ -133,6 +133,10 @@ struct CastAwayTransferReadLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (read.getTransferRank() == 0)
+ return failure();
+
if (read.mask())
return failure();
@@ -153,14 +157,15 @@ struct CastAwayTransferReadLeadingOneDim
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
rewriter.getContext());
- ArrayAttr inBounds;
+ ArrayAttr inBoundsAttr;
if (read.in_bounds())
- inBounds = rewriter.getArrayAttr(
+ inBoundsAttr = rewriter.getArrayAttr(
read.in_boundsAttr().getValue().take_back(newType.getRank()));
auto newRead = rewriter.create<vector::TransferReadOp>(
- read.getLoc(), newType, read.source(), read.indices(), newMap,
- read.padding(), inBounds);
+ read.getLoc(), newType, read.source(), read.indices(),
+ AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(),
+ inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
@@ -176,6 +181,10 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (write.getTransferRank() == 0)
+ return failure();
+
if (write.mask())
return failure();
@@ -196,15 +205,16 @@ struct CastAwayTransferWriteLeadingOneDim
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
rewriter.getContext());
- ArrayAttr inBounds;
+ ArrayAttr inBoundsAttr;
if (write.in_bounds())
- inBounds = rewriter.getArrayAttr(
+ inBoundsAttr = rewriter.getArrayAttr(
write.in_boundsAttr().getValue().take_back(newType.getRank()));
auto newVector = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.vector(), splatZero(dropDim));
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- write, newVector, write.source(), write.indices(), newMap, inBounds);
+ write, newVector, write.source(), write.indices(),
+ AffineMapAttr::get(newMap), inBoundsAttr);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 474b08276933b..859067b2bffe8 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1613,8 +1613,8 @@ static LogicalResult verify(InsertOp op) {
static_cast<unsigned>(destVectorType.getRank())))
return op.emitOpError("expected position attribute rank + source rank to "
"match dest vector rank");
- if (!srcVectorType && (positionAttr.size() !=
- static_cast<unsigned>(destVectorType.getRank())))
+ if (!srcVectorType &&
+ (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
return op.emitOpError(
"expected position attribute rank to match the dest vector rank");
for (auto en : llvm::enumerate(positionAttr)) {
@@ -2314,6 +2314,59 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
// TransferReadOp
//===----------------------------------------------------------------------===//
+/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, AffineMapAttr permutationMapAttr,
+ /*optional*/ ArrayAttr inBoundsAttr) {
+ Type elemType = source.getType().cast<ShapedType>().getElementType();
+ Value padding = builder.create<arith::ConstantOp>(
+ result.location, elemType, builder.getZeroAttr(elemType));
+ build(builder, result, vectorType, source, indices, permutationMapAttr,
+ padding, /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, AffineMap permutationMap,
+ Optional<ArrayRef<bool>> inBounds) {
+ auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+ auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
+ ? builder.getBoolArrayAttr(inBounds.getValue())
+ : ArrayAttr();
+ build(builder, result, vectorType, source, indices, permutationMapAttr,
+ inBoundsAttr);
+}
+
+/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices, Value padding,
+ Optional<ArrayRef<bool>> inBounds) {
+ AffineMap permutationMap = getTransferMinorIdentityMap(
+ source.getType().cast<ShapedType>(), vectorType);
+ auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+ auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
+ ? builder.getBoolArrayAttr(inBounds.getValue())
+ : ArrayAttr();
+ build(builder, result, vectorType, source, indices, permutationMapAttr,
+ padding,
+ /*mask=*/Value(), inBoundsAttr);
+}
+
+/// 4. Builder that sets padding to zero and permutation map to
+/// 'getMinorIdentityMap'.
+void TransferReadOp::build(OpBuilder &builder, OperationState &result,
+ VectorType vectorType, Value source,
+ ValueRange indices,
+ Optional<ArrayRef<bool>> inBounds) {
+ Type elemType = source.getType().cast<ShapedType>().getElementType();
+ Value padding = builder.create<arith::ConstantOp>(
+ result.location, elemType, builder.getZeroAttr(elemType));
+ build(builder, result, vectorType, source, indices, padding, inBounds);
+}
+
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
@@ -2347,10 +2400,6 @@ static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
VectorType vectorType, VectorType maskType,
AffineMap permutationMap, ArrayAttr inBounds) {
- if (shapedType.getRank() == 0 && !op.isZeroD())
- return op->emitOpError("0-d transfer requires vector<1xt> shape and () -> "
- "(0) permutation_map");
-
if (op->hasAttr("masked")) {
return op->emitOpError("masked attribute has been removed. "
"Use in_bounds instead.");
@@ -2359,6 +2408,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
if (!shapedType.isa<MemRefType, RankedTensorType>())
return op->emitOpError(
"requires source to be a memref or ranked tensor type");
+
auto elementType = shapedType.getElementType();
DataLayout dataLayout = DataLayout::closest(op);
if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
@@ -2389,9 +2439,10 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
return op->emitOpError("does not support masks with vector element type");
} else {
// Memref or tensor has scalar element type.
+ unsigned minorSize =
+ vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
unsigned resultVecSize =
- dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
- vectorType.getShape().back();
+ dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
@@ -2412,8 +2463,8 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
- // TODO: implement 0-d vector corner cases.
- if (!op.isZeroD() && permutationMap.getNumInputs() != shapedType.getRank())
+
+ if (permutationMap.getNumInputs() != shapedType.getRank())
return op->emitOpError("requires a permutation_map with input dims of the "
"same rank as the source type");
@@ -2421,7 +2472,8 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
return op->emitOpError("expects the optional in_bounds attr of same rank "
"as permutation_map results: ")
- << AffineMapAttr::get(permutationMap);
+ << AffineMapAttr::get(permutationMap)
+ << " vs inBounds of size: " << inBounds.size();
for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
!inBounds.getValue()[i].cast<BoolAttr>().getValue())
@@ -2431,77 +2483,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
return success();
}
-/// Builder that sets padding to zero.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices, AffineMap permutationMap,
- ArrayRef<bool> inBounds) {
- Type elemType = source.getType().cast<ShapedType>().getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
- if (inBounds.empty())
- return build(builder, result, vectorType, source, indices, permutationMap,
- padding, ArrayAttr());
- ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
- build(builder, result, vectorType, source, indices, permutationMap, padding,
- inBoundsArrayAttr);
-}
-
-/// Builder that sets permutation map to 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices, Value padding,
- ArrayRef<bool> inBounds) {
- auto permMap = getTransferMinorIdentityMap(
- source.getType().cast<ShapedType>(), vectorType);
- if (inBounds.empty())
- return build(builder, result, vectorType, source, indices, permMap, padding,
- ArrayAttr());
- ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
- build(builder, result, vectorType, source, indices, permMap, padding,
- inBoundsArrayAttr);
-}
-
-/// Builder that sets permutation map (resp. padding) to 'getMinorIdentityMap'
-/// (resp. zero).
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices, ArrayRef<bool> inBounds) {
- auto permMap = getTransferMinorIdentityMap(
- source.getType().cast<ShapedType>(), vectorType);
- build(builder, result, vectorType, source, indices, permMap, inBounds);
-}
-
-/// Builder that does not provide a mask.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- Type vectorType, Value source, ValueRange indices,
- AffineMap permutationMap, Value padding,
- ArrayAttr inBounds) {
- build(builder, result, vectorType, source, indices, permutationMap, padding,
- /*mask=*/Value(), inBounds);
-}
-
-/// Builder that does not provide a mask.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- Type vectorType, Value source, ValueRange indices,
- AffineMapAttr permutationMap, Value padding,
- ArrayAttr inBounds) {
- build(builder, result, vectorType, source, indices, permutationMap, padding,
- /*mask=*/Value(), inBounds);
-}
-
-Value TransferReadOp::createScalarOp(OpBuilder &builder, Location loc,
- Value source, ValueRange indices,
- ArrayRef<bool> inBounds) {
- Type elemType = source.getType().cast<ShapedType>().getElementType();
- auto vectorType = VectorType::get(ArrayRef<int64_t>{1}, elemType);
- AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
- getAffineConstantExpr(0, loc.getContext()));
- Value read = builder.create<vector::TransferReadOp>(loc, vectorType, source,
- indices, map, inBounds);
- return builder.create<vector::ExtractOp>(loc, read, ArrayRef<int64_t>{0});
-}
-
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 3> elidedAttrs;
elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
@@ -2563,6 +2544,7 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
Attribute mapAttr = result.attributes.get(permutationAttrName);
if (!mapAttr) {
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+ // Update `mapAttr` that is used later to determine mask type.
mapAttr = AffineMapAttr::get(permMap);
result.attributes.set(permutationAttrName, mapAttr);
}
@@ -2677,8 +2659,9 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
- // TODO: Be less conservative once we have 0-d vectors.
- if (op.isZeroD())
+ // TODO: support 0-d corner case.
+ // TODO: Be less conservative.
+ if (op.getTransferRank() == 0)
return failure();
AffineMap permutationMap = op.permutation_map();
bool changed = false;
@@ -2783,6 +2766,9 @@ struct FoldExtractSliceIntoTransferRead
LogicalResult matchAndRewrite(TransferReadOp xferOp,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
if (xferOp.hasOutOfBoundsDim())
return failure();
if (!xferOp.permutation_map().isIdentity())
@@ -2814,9 +2800,9 @@ struct FoldExtractSliceIntoTransferRead
offset)));
}
SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferReadOp>(xferOp, xferOp.getVectorType(),
- extractOp.source(), newIndices,
- xferOp.padding(), inBounds);
+ rewriter.replaceOpWithNewOp<TransferReadOp>(
+ xferOp, xferOp.getVectorType(), extractOp.source(), newIndices,
+ xferOp.padding(), ArrayRef<bool>{inBounds});
return success();
}
@@ -2832,69 +2818,49 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
// TransferWriteOp
//===----------------------------------------------------------------------===//
+/// 1. Builder with type inference.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
- AffineMap permutationMap, ArrayRef<bool> inBounds) {
- if (inBounds.empty())
- return build(builder, result, vector, dest, indices, permutationMap,
- /*mask=*/Value(), ArrayAttr());
- build(builder, result, vector, dest, indices, permutationMap,
- /*mask=*/Value(), builder.getBoolArrayAttr(inBounds));
-}
-
-/// Builder that sets permutation map to 'getMinorIdentityMap'.
-void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value source, ValueRange indices,
- ArrayRef<bool> inBounds) {
- auto vectorType = vector.getType().cast<VectorType>();
- auto permMap = getTransferMinorIdentityMap(
- source.getType().cast<ShapedType>(), vectorType);
- if (inBounds.empty())
- return build(builder, result, vector, source, indices, permMap,
- ArrayAttr());
- ArrayAttr inBoundsArrayAttr = builder.getBoolArrayAttr(inBounds);
- build(builder, result, vector, source, indices, permMap, inBoundsArrayAttr);
+ AffineMapAttr permutationMapAttr,
+ /*optional*/ Value mask,
+ /*optional*/ ArrayAttr inBoundsAttr) {
+ Type resultType = dest.getType().dyn_cast<RankedTensorType>();
+ build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
+ mask, inBoundsAttr);
}
+/// 2. Builder with type inference that sets an empty mask (variant with attrs).
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value source, ValueRange indices,
- AffineMapAttr permutationMap,
- /*optional*/ ArrayAttr inBounds) {
- Type resultType = source.getType().dyn_cast<RankedTensorType>();
- build(builder, result, resultType, vector, source, indices, permutationMap,
- /*mask=*/Value(), inBounds);
+ Value vector, Value dest, ValueRange indices,
+ AffineMapAttr permutationMapAttr,
+ /*optional*/ ArrayAttr inBoundsAttr) {
+ build(builder, result, vector, dest, indices, permutationMapAttr,
+ /*mask=*/Value(), inBoundsAttr);
}
+/// 3. Builder with type inference that sets an empty mask (variant without
+/// attrs)
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value source, ValueRange indices,
+ Value vector, Value dest, ValueRange indices,
AffineMap permutationMap,
- /*optional*/ ArrayAttr inBounds) {
- Type resultType = source.getType().dyn_cast<RankedTensorType>();
- build(builder, result, resultType, vector, source, indices, permutationMap,
- /*mask=*/Value(), inBounds);
+ Optional<ArrayRef<bool>> inBounds) {
+ auto permutationMapAttr = AffineMapAttr::get(permutationMap);
+ auto inBoundsAttr = (inBounds && !inBounds.getValue().empty())
+ ? builder.getBoolArrayAttr(inBounds.getValue())
+ : ArrayAttr();
+ build(builder, result, vector, dest, indices, permutationMapAttr,
+ /*mask=*/Value(), inBoundsAttr);
}
+/// 4. Builder with type inference that sets an empty mask and sets permutation
+/// map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value source, ValueRange indices,
- AffineMap permutationMap, /*optional*/ Value mask,
- /*optional*/ ArrayAttr inBounds) {
- Type resultType = source.getType().dyn_cast<RankedTensorType>();
- build(builder, result, resultType, vector, source, indices, permutationMap,
- mask, inBounds);
-}
-
-Operation *TransferWriteOp::createScalarOp(OpBuilder &builder, Location loc,
- Value value, Value dest,
- ValueRange indices,
- ArrayRef<bool> inBounds) {
- Value vectorOfAScalar = value;
- if (!value.getType().isa<VectorType>())
- vectorOfAScalar = builder.create<vector::BroadcastOp>(
- loc, VectorType::get({1}, value.getType()), value);
- AffineMap map = AffineMap::get(/*numDims=*/0, /*numSymbols=*/0,
- getAffineConstantExpr(0, loc.getContext()));
- return builder.create<vector::TransferWriteOp>(loc, vectorOfAScalar, dest,
- indices, map, inBounds);
+ Value vector, Value dest, ValueRange indices,
+ Optional<ArrayRef<bool>> inBounds) {
+ auto vectorType = vector.getType().cast<VectorType>();
+ AffineMap permutationMap = getTransferMinorIdentityMap(
+ dest.getType().cast<ShapedType>(), vectorType);
+ build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
@@ -3003,6 +2969,9 @@ static LogicalResult verify(TransferWriteOp op) {
static LogicalResult foldReadInitWrite(TransferWriteOp write,
ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &results) {
+ // TODO: support 0-d corner case.
+ if (write.getTransferRank() == 0)
+ return failure();
auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
// If not operating on tensors, bail.
if (!rankedTensorType)
@@ -3011,6 +2980,9 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
if (!read)
return failure();
+ // TODO: support 0-d corner case.
+ if (read.getTransferRank() == 0)
+ return failure();
// For now, only accept minor identity. Future: composition is minor identity.
if (!read.permutation_map().isMinorIdentity() ||
!write.permutation_map().isMinorIdentity())
@@ -3179,9 +3151,14 @@ struct FoldInsertSliceIntoTransferWrite
PatternRewriter &rewriter) const override {
if (!insertOp.hasUnitStride())
return failure();
+
auto xferOp = insertOp.source().getDefiningOp<TransferWriteOp>();
if (!xferOp)
return failure();
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
if (xferOp.hasOutOfBoundsDim())
return failure();
if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
@@ -3200,8 +3177,9 @@ struct FoldInsertSliceIntoTransferWrite
SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferWriteOp>(
- insertOp, xferOp.vector(), insertOp.dest(), indices, inBounds);
+ rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.vector(),
+ insertOp.dest(), indices,
+ ArrayRef<bool>{inBounds});
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
index e8817e71ce2ac..36725e03ae09e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
@@ -31,6 +31,7 @@ transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
attr.getValue()[pos].cast<BoolAttr>().getValue());
return builder.getBoolArrayAttr(newInBoundsValues);
}
+
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
@@ -56,6 +57,10 @@ struct TransferReadPermutationLowering
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (op.getTransferRank() == 0)
+ return failure();
+
SmallVector<unsigned> permutation;
AffineMap map = op.permutation_map();
if (map.getNumResults() == 0)
@@ -99,7 +104,7 @@ struct TransferReadPermutationLowering
}
// Transpose in_bounds attribute.
- ArrayAttr newInBounds =
+ ArrayAttr newInBoundsAttr =
op.in_bounds() ? transposeInBoundsAttr(
rewriter, op.in_bounds().getValue(), permutation)
: ArrayAttr();
@@ -108,8 +113,8 @@ struct TransferReadPermutationLowering
VectorType newReadType =
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), newMask, newInBounds);
+ op.getLoc(), newReadType, op.source(), op.indices(),
+ AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr);
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
@@ -141,7 +146,8 @@ struct TransferWritePermutationLowering
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
- if (op.isZeroD())
+ // TODO: support 0-d corner case.
+ if (op.getTransferRank() == 0)
return failure();
SmallVector<unsigned> permutation;
@@ -168,7 +174,7 @@ struct TransferWritePermutationLowering
: Value();
// Transpose in_bounds attribute.
- ArrayAttr newInBounds =
+ ArrayAttr newInBoundsAttr =
op.in_bounds() ? transposeInBoundsAttr(
rewriter, op.in_bounds().getValue(), permutation)
: ArrayAttr();
@@ -179,8 +185,8 @@ struct TransferWritePermutationLowering
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
- newInBounds);
+ op, Type(), newVec, op.source(), op.indices(),
+ AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
return success();
}
@@ -199,6 +205,10 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (op.getTransferRank() == 0)
+ return failure();
+
AffineMap map = op.permutation_map();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
@@ -245,14 +255,14 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
return failure();
VectorType newReadType =
VectorType::get(newShape, originalVecType.getElementType());
- ArrayAttr newInBounds =
+ ArrayAttr newInBoundsAttr =
op.in_bounds()
? rewriter.getArrayAttr(
op.in_boundsAttr().getValue().take_back(reducedShapeRank))
: ArrayAttr();
Value newRead = rewriter.create<vector::TransferReadOp>(
- op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), op.mask(), newInBounds);
+ op.getLoc(), newReadType, op.source(), op.indices(),
+ AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6bdbeb1a550b5..876f8aeb219cb 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -229,7 +229,9 @@ struct UnrollTransferReadPattern
options(options) {}
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
-
+ // TODO: support 0-d corner case.
+ if (readOp.getTransferRank() == 0)
+ return failure();
if (readOp.mask())
return failure();
auto targetShape = getTargetShape(options, readOp);
@@ -254,9 +256,9 @@ struct UnrollTransferReadPattern
sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
readOp.permutation_map(), loc, rewriter);
auto slicedRead = rewriter.create<vector::TransferReadOp>(
- loc, targetType, readOp.source(), indices, readOp.permutation_map(),
- readOp.padding(),
- readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
+ loc, targetType, readOp.source(), indices,
+ readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(),
+ readOp.in_boundsAttr());
SmallVector<int64_t, 4> elementOffsets =
getVectorOffset(originalSize, *targetShape, i);
@@ -279,6 +281,10 @@ struct UnrollTransferWritePattern
options(options) {}
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (writeOp.getTransferRank() == 0)
+ return failure();
+
if (writeOp.mask())
return failure();
auto targetShape = getTargetShape(options, writeOp);
@@ -305,8 +311,7 @@ struct UnrollTransferWritePattern
writeOp.permutation_map(), loc, rewriter);
Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
loc, slicedVector, resultTensor ? resultTensor : writeOp.source(),
- indices, writeOp.permutation_map(),
- writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+ indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
// For the tensor case update the destination for the next transfer write.
if (!slicedWrite->getResults().empty())
resultTensor = slicedWrite->getResult(0);
@@ -2057,6 +2062,10 @@ static Value createInBoundsCond(OpBuilder &b,
/// rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
// TODO: expand support to these 2 cases.
if (!xferOp.permutation_map().isMinorIdentity())
return failure();
@@ -2682,6 +2691,10 @@ struct TransferReadExtractPattern
: OpRewritePattern<vector::TransferReadOp>(context) {}
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (read.getTransferRank() == 0)
+ return failure();
+
if (!read.getResult().hasOneUse())
return failure();
auto extract =
@@ -2711,8 +2724,8 @@ struct TransferReadExtractPattern
{indices[indexPos], extract.ids()[idCount++]});
}
Value newRead = lb.create<vector::TransferReadOp>(
- extract.getType(), read.source(), indices, read.permutation_map(),
- read.padding(), read.in_boundsAttr());
+ extract.getType(), read.source(), indices, read.permutation_mapAttr(),
+ read.padding(), read.mask(), read.in_boundsAttr());
Value dest = lb.create<arith::ConstantOp>(
read.getType(), rewriter.getZeroAttr(read.getType()));
newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.ids());
@@ -2727,6 +2740,10 @@ struct TransferWriteInsertPattern
: OpRewritePattern<vector::TransferWriteOp>(context) {}
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (write.getTransferRank() == 0)
+ return failure();
+
auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
if (!insert)
return failure();
@@ -2754,8 +2771,8 @@ struct TransferWriteInsertPattern
{indices[indexPos], insert.ids()[idCount++]});
}
rewriter.create<vector::TransferWriteOp>(
- loc, insert.vector(), write.source(), indices, write.permutation_map(),
- write.in_boundsAttr());
+ loc, insert.vector(), write.source(), indices,
+ write.permutation_mapAttr(), write.in_boundsAttr());
rewriter.eraseOp(write);
return success();
}
@@ -2780,15 +2797,19 @@ struct TransferReadToVectorLoadLowering
PatternRewriter &rewriter) const override {
if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
return failure();
+
SmallVector<unsigned, 4> broadcastedDims;
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
+ // We let the 0-d corner case pass-through as it is supported.
if (!read.permutation_map().isMinorIdentityWithBroadcasting(
&broadcastedDims))
return failure();
+
auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return failure();
+
// Non-unit strides are handled by VectorToSCF.
if (!vector::isLastMemrefDimUnitStride(memRefType))
return failure();
@@ -2808,6 +2829,7 @@ struct TransferReadToVectorLoadLowering
auto memrefElTy = memRefType.getElementType();
if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
return failure();
+
// Otherwise, element types of the memref and the vector must match.
if (!memrefElTy.isa<VectorType>() &&
memrefElTy != read.getVectorType().getElementType())
@@ -2845,7 +2867,14 @@ struct TransferReadToVectorLoadLowering
llvm::Optional<unsigned> maxTransferRank;
};
-/// Replace a scalar vector.load with a memref.load.
+/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
+// TODO: we shouldn't cross the vector/scalar domains just for this
+// but atm we lack the infra to avoid it. Possible solutions include:
+// - go directly to LLVM + bitcast
+// - introduce a bitcast op and likely a new pointer dialect
+// - let memref.load/store additionally support the 0-d vector case
+// There are still deeper data layout issues lingering even in this
+// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
: public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
@@ -2857,13 +2886,13 @@ struct VectorLoadToMemrefLoadLowering
return failure();
auto memrefLoad = rewriter.create<memref::LoadOp>(
loadOp.getLoc(), loadOp.base(), loadOp.indices());
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
+ memrefLoad);
return success();
}
};
-/// Replace a scalar vector.store with a memref.store.
+/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
: public OpRewritePattern<vector::StoreOp> {
using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
@@ -2873,9 +2902,17 @@ struct VectorStoreToMemrefStoreLowering
auto vecType = storeOp.getVectorType();
if (vecType.getNumElements() != 1)
return failure();
- SmallVector<int64_t> indices(vecType.getRank(), 0);
- Value extracted = rewriter.create<vector::ExtractOp>(
- storeOp.getLoc(), storeOp.valueToStore(), indices);
+ Value extracted;
+ if (vecType.getRank() == 0) {
+ // TODO: Unify once ExtractOp supports 0-d vectors.
+ extracted = rewriter.create<vector::ExtractElementOp>(
+ storeOp.getLoc(), storeOp.valueToStore());
+ } else {
+ SmallVector<int64_t> indices(vecType.getRank(), 0);
+ extracted = rewriter.create<vector::ExtractOp>(
+ storeOp.getLoc(), storeOp.valueToStore(), indices);
+ }
+
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, extracted, storeOp.base(), storeOp.indices());
return success();
@@ -2901,25 +2938,32 @@ struct TransferWriteToVectorStoreLowering
PatternRewriter &rewriter) const override {
if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
return failure();
+
// Permutations are handled by VectorToSCF or
// populateVectorTransferPermutationMapLoweringPatterns.
- if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
+ if ( // pass-through for the 0-d corner case.
+ !write.permutation_map().isMinorIdentity())
return failure();
+
auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
if (!memRefType)
return failure();
+
// Non-unit strides are handled by VectorToSCF.
if (!vector::isLastMemrefDimUnitStride(memRefType))
return failure();
+
// `vector.store` supports vector types as memref's elements only when the
// type of the vector value being written is the same as the element type.
auto memrefElTy = memRefType.getElementType();
if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
return failure();
+
// Otherwise, element types of the memref and the vector must match.
if (!memrefElTy.isa<VectorType>() &&
memrefElTy != write.getVectorType().getElementType())
return failure();
+
// Out-of-bounds dims are handled by MaterializeTransferMask.
if (write.hasOutOfBoundsDim())
return failure();
@@ -3319,6 +3363,14 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (readOp.getTransferRank() == 0)
+ return failure();
+
+ // TODO: support mask.
+ if (readOp.mask())
+ return failure();
+
auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
if (!srcType || !srcType.hasStaticShape())
return failure();
@@ -3375,7 +3427,7 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);
- ArrayAttr inBounds =
+ ArrayAttr inBoundsAttr =
readOp.in_bounds()
? rewriter.getArrayAttr(
readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
@@ -3387,8 +3439,10 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
Value result = rewriter.create<vector::TransferReadOp>(
loc, resultTargetVecType, rankedReducedView,
- readOp.indices().drop_back(dimsToDrop), permMap, readOp.padding(),
- inBounds);
+ readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+ readOp.padding(),
+ // TODO: support mask.
+ /*mask=*/Value(), inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
result);
return success();
diff --git a/mlir/lib/Interfaces/VectorInterfaces.cpp b/mlir/lib/Interfaces/VectorInterfaces.cpp
index 625ffa9852396..3115a5d983c45 100644
--- a/mlir/lib/Interfaces/VectorInterfaces.cpp
+++ b/mlir/lib/Interfaces/VectorInterfaces.cpp
@@ -20,7 +20,7 @@ VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
shape.push_back(vecType.getDimSize(i));
}
}
- return shape.empty() ? VectorType() : VectorType::get(shape, i1Type);
+ return VectorType::get(shape, i1Type);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 08b3ffbdb688e..7cddb46f094e1 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -2,25 +2,20 @@
// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
// CHECK-LABEL: func @vector_transfer_ops_0d(
-// CHECK-SAME: %[[MEM:.*]]: memref<f32>) {
func @vector_transfer_ops_0d(%M: memref<f32>) {
- %f0 = arith.constant 0.0 : f32
-
-// CHECK: %[[V0:.*]] = arith.constant dense<0{{.*}}> : vector<1xf32>
-// CHECK: %[[R0:.*]] = scf.for %[[I:.*]] = {{.*}} iter_args(%[[V0_ITER:.*]] = %[[V0]]) -> (vector<1xf32>) {
-// CHECK: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-// CHECK: %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[I]] : index] : vector<1xf32>
-// CHECK: scf.yield %[[R_ITER]] : vector<1xf32>
- %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
- memref<f32>, vector<1xf32>
-
-// CHECK: scf.for %[[J:.*]] = %{{.*}}
-// CHECK: %[[SS:.*]] = vector.extractelement %[[R0]][%[[J]] : index] : vector<1xf32>
-// CHECK: memref.store %[[SS]], %[[MEM]][] : memref<f32>
- vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
- vector<1xf32>, memref<f32>
-
- return
+ %f0 = arith.constant 0.0 : f32
+
+ // 0-d transfers are left untouched by vector-to-scf.
+ // They are independently lowered to the proper memref.load/store.
+ // CHECK: vector.transfer_read {{.*}}: memref<f32>, vector<f32>
+ %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->()>} :
+ memref<f32>, vector<f32>
+
+ // CHECK: vector.transfer_write {{.*}}: vector<f32>, memref<f32>
+ vector.transfer_write %0, %M[] {permutation_map = affine_map<()->()>} :
+ vector<f32>, memref<f32>
+
+ return
}
// -----
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c055ef47a36d0..b7ef524475487 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -200,8 +200,8 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK-LABEL: func @test_vectorize_fill
func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
- // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
- // CHECK: vector.transfer_write %[[VEC]], %[[M]][] {{.*}} : vector<1xf32>, memref<f32>
+ // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
linalg.fill(%arg0, %A) : f32, memref<f32>
return
}
@@ -221,10 +221,10 @@ func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
// CHECK-LABEL: func @test_vectorize_copy_scalar
func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
// CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
- // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<1xf32>
- // CHECK: %[[val:.*]] = vector.extract %[[V]][0] : vector<1xf32>
- // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<1xf32>
- // CHECK: vector.transfer_write %[[VV]], %[[B]][] {{.*}} : vector<1xf32>, memref<f32>
+ // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
+ // CHECK: %[[val:.*]] = vector.extractelement %[[V]][] : vector<f32>
+ // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
+ // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
linalg.copy(%A, %B) : memref<f32>, memref<f32>
return
}
@@ -1005,7 +1005,7 @@ func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
// CHECK-LABEL: func @reduce_1d(
// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
- // CHECK-DAG: %[[F0_v1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+ // CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%f0 = arith.constant 0.000000e+00 : f32
@@ -1013,17 +1013,18 @@ func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
%0 = linalg.init_tensor [] : tensor<f32>
- // CHECK: %[[f:.*]] = vector.transfer_write %[[F0_v1]], %[[init]][]
- // CHECK-SAME: : vector<1xf32>, tensor<f32>
+ // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
+ // CHECK-SAME: : vector<f32>, tensor<f32>
%1 = linalg.fill(%f0, %0) : f32, tensor<f32> -> tensor<f32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
+ // CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
// CHECK: %[[red:.*]] = vector.multi_reduction #vector.kind<add>, %[[r]] [0]
// CHECK-SAME: : vector<32xf32> to f32
- // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[F0]] : f32
- // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<1xf32>
+ // CHECK: %[[a:.*]] = arith.addf %[[red]], %[[f0]] : f32
+ // CHECK: %[[red_v1:.*]] = vector.broadcast %[[a]] : f32 to vector<f32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
- // CHECK-SAME: : vector<1xf32>, tensor<f32>
+ // CHECK-SAME: : vector<f32>, tensor<f32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>],
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 593686a425a51..c550a0818809f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1427,15 +1427,3 @@ func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
%0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
}
-// -----
-
-func @vector_transfer_ops_0d(%arg0: tensor<f32>)
- -> tensor<f32> {
- %f0 = arith.constant 0.0 : f32
- // expected-error@+1 {{0-d transfer requires vector<1xt> shape and () -> (0) permutation_map}}
- %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<(d0)->(d0)>} :
- tensor<f32>, vector<1xf32>
- %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
- vector<1xf32>, tensor<f32>
- return %1: tensor<f32>
-}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 11b986fc9b87c..576924e1addff 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -4,17 +4,33 @@
func @vector_transfer_ops_0d(%arg0: tensor<f32>, %arg1: memref<f32>)
-> tensor<f32> {
%f0 = arith.constant 0.0 : f32
- %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} :
- tensor<f32>, vector<1xf32>
- %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} :
- vector<1xf32>, tensor<f32>
- %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->(0)>} :
- memref<f32>, vector<1xf32>
- vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->(0)>} :
- vector<1xf32>, memref<f32>
+ %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->()>} :
+ tensor<f32>, vector<f32>
+ %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->()>} :
+ vector<f32>, tensor<f32>
+ %2 = vector.transfer_read %arg1[], %f0 {permutation_map = affine_map<()->()>} :
+ memref<f32>, vector<f32>
+ vector.transfer_write %2, %arg1[] {permutation_map = affine_map<()->()>} :
+ vector<f32>, memref<f32>
return %1: tensor<f32>
}
+// CHECK-LABEL: func @vector_transfer_ops_0d_from_higher_d(
+func @vector_transfer_ops_0d_from_higher_d(%arg0: tensor<?xf32>, %arg1: memref<?x?xf32>)
+ -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0], %f0 {permutation_map = affine_map<(d0)->()>} :
+ tensor<?xf32>, vector<f32>
+ %1 = vector.transfer_write %0, %arg0[%c0] {permutation_map = affine_map<(d0)->()>} :
+ vector<f32>, tensor<?xf32>
+ %2 = vector.transfer_read %arg1[%c0, %c0], %f0 {permutation_map = affine_map<(d0, d1)->()>} :
+ memref<?x?xf32>, vector<f32>
+ vector.transfer_write %2, %arg1[%c0, %c0] {permutation_map = affine_map<(d0, d1)->()>} :
+ vector<f32>, memref<?x?xf32>
+ return %1: tensor<?xf32>
+}
+
// CHECK-LABEL: func @vector_transfer_ops(
func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%arg1 : memref<?x?xvector<4x3xf32>>,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index a5c0cb584b11b..562870c4d9fe6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -6,13 +6,13 @@
func @vector_transfer_ops_0d_memref(%M: memref<f32>, %v: vector<1x1x1xf32>) {
%f0 = arith.constant 0.0 : f32
-// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
- %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
- memref<f32>, vector<1xf32>
+// CHECK-NEXT: %[[s:.*]] = memref.load %[[MEM]][] : memref<f32>
+// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[s]] : f32 to vector<f32>
+ %0 = vector.transfer_read %M[], %f0 : memref<f32>, vector<f32>
-// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
- vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
- vector<1xf32>, memref<f32>
+// CHECK-NEXT: %[[ss:.*]] = vector.extractelement %[[V]][] : vector<f32>
+// CHECK-NEXT: memref.store %[[ss]], %[[MEM]][] : memref<f32>
+ vector.transfer_write %0, %M[] : vector<f32>, memref<f32>
// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : vector<1x1x1xf32>
// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
More information about the Mlir-commits
mailing list