[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)
Balaji V. Iyer.
llvmlistbot at llvm.org
Fri Feb 16 21:11:08 PST 2024
================
@@ -1559,6 +1558,111 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}
+/// Vectorize a `tensor::UnPackOp` into these 4 ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor.
+/// vector::TransposeOp - Transposes the vector that was read.
+/// vector::ShapeCastOp - Reshapes the data to the destination shape.
+/// vector::TransferWriteOp - Writes the result vector back to the destination
+/// tensor.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+ ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+
+ SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
+ inputVectorSizes.end());
+ ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(readMaskShape, outerDimsPerm);
+ ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+ readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+ sourceShape.end());
+
+ // readMaskShape is the shape used to read from the source tensor and to
+ // build the mask. It is computed as follows. Let the vector-size (VS)
+ // array have size N and the source shape (SS) have rank M, where M >= N,
+ // and let the inner tile sizes (IT) have size M - N.
+ // Then:
+ // - Initially, readMaskShape = inputVectorSizes.
+ // - If outer_dims_perm is present, apply that permutation to readMaskShape.
+ // - Append the remaining M - N dims of SS.
+ // - Divide each readMaskShape entry indexed by innerDimPos by the
+ // corresponding inner tile size.
+ // E.g., let's say unpackTensorType.getShape() = <8x8x32x16>
----------------
bviyer wrote:
Fixed.
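The readMaskShape steps described in the comment can be checked with a small standalone model. The sketch below is hypothetical: `computeReadMaskShape` is not a function in the patch, `std::vector` stands in for `ArrayRef`/`SmallVector`, and rounding the division up is an assumption, since the excerpt ends before the division code. It also divides before permuting. This follows from `inner_dims_pos` indexing the dimensions of the unpacked (destination) tensor. With a non-identity `outer_dims_perm`, the two orders give different shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not the patch's code). N is the rank of the
// unpacked result (inputVectorSizes.size()), M = sourceShape.size(), M >= N.
std::vector<int64_t>
computeReadMaskShape(const std::vector<int64_t> &inputVectorSizes,
                     const std::vector<int64_t> &outerDimsPerm,
                     const std::vector<int64_t> &sourceShape,
                     const std::vector<int64_t> &innerDimPos,
                     const std::vector<int64_t> &innerTiles) {
  assert(sourceShape.size() >= inputVectorSizes.size());
  assert(innerDimPos.size() == innerTiles.size());
  // Start from the requested vector sizes (one entry per unpacked dim).
  std::vector<int64_t> shape(inputVectorSizes);
  // Divide each tiled dim by its tile size, rounding up so a partial trailing
  // tile is still read. innerDimPos indexes the unpacked dims, so this runs
  // before the permutation below.
  for (std::size_t i = 0; i < innerDimPos.size(); ++i) {
    int64_t &dim = shape[innerDimPos[i]];
    dim = (dim + innerTiles[i] - 1) / innerTiles[i];
  }
  // Apply outer_dims_perm using applyPermutationToVector's convention:
  // result[i] = input[perm[i]].
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted(shape.size());
    for (std::size_t i = 0; i < outerDimsPerm.size(); ++i)
      permuted[i] = shape[outerDimsPerm[i]];
    shape = permuted;
  }
  // Append the trailing M - N dims (the inner tiles) of the source shape.
  shape.insert(shape.end(), sourceShape.begin() + inputVectorSizes.size(),
               sourceShape.end());
  return shape;
}
```

For example, with a source shape of <8x8x32x16>, hypothetical vector sizes [512, 128], inner_dims_pos = [0, 1], inner tiles [32, 16], and outer_dims_perm = [1, 0], the sketch returns [8, 16, 32, 16]. Permuting first and dividing afterwards would instead give [4, 32, 32, 16].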
https://github.com/llvm/llvm-project/pull/76087
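The four-op lowering in the doc comment (read, transpose, shape cast, write) can be modeled on plain arrays. The sketch below is a hypothetical reference model in standard C++, not MLIR code. It assumes the following:

* a 2-D unpack with inner_dims_pos = [0, 1];
* no outer_dims_perm;
* tile sizes that divide the destination sizes evenly, so no masking is needed.

It shows why a transpose followed by a shape cast suffices. Permuting the packed shape [n0/t0, n1/t1, t0, t1] with [0, 2, 1, 3] yields [n0/t0, t0, n1/t1, t1], whose row-major layout is exactly that of [n0, n1].

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

using Shape4 = std::array<int64_t, 4>;

// Models vector.transpose on a row-major 4-D buffer: result dim k has the
// extent of source dim perm[k].
std::vector<int> transpose4D(const std::vector<int> &src, const Shape4 &shape,
                             const std::array<int, 4> &perm) {
  const Shape4 stride = {shape[1] * shape[2] * shape[3], shape[2] * shape[3],
                         shape[3], 1};
  Shape4 out{};
  for (int k = 0; k < 4; ++k)
    out[k] = shape[perm[k]];
  std::vector<int> dst;
  dst.reserve(src.size());
  // Walk the result in row-major order; index i[k] runs along source dim
  // perm[k], so it contributes i[k] * stride[perm[k]] to the source offset.
  for (int64_t i0 = 0; i0 < out[0]; ++i0)
    for (int64_t i1 = 0; i1 < out[1]; ++i1)
      for (int64_t i2 = 0; i2 < out[2]; ++i2)
        for (int64_t i3 = 0; i3 < out[3]; ++i3) {
          const Shape4 i = {i0, i1, i2, i3};
          int64_t offset = 0;
          for (int k = 0; k < 4; ++k)
            offset += i[k] * stride[perm[k]];
          dst.push_back(src[offset]);
        }
  return dst;
}

// Hypothetical reference model of a 2-D unpack with inner_dims_pos = [0, 1],
// tiles [t0, t1], no outer_dims_perm, and t0 | n0, t1 | n1 (no masking).
std::vector<int> unpack2D(const std::vector<int> &packed, int64_t n0,
                          int64_t n1, int64_t t0, int64_t t1) {
  assert(n0 % t0 == 0 && n1 % t1 == 0);
  assert(static_cast<int64_t>(packed.size()) == n0 * n1);
  // Read: the packed source has shape [n0/t0, n1/t1, t0, t1].
  const Shape4 packedShape = {n0 / t0, n1 / t1, t0, t1};
  // Transpose: interleave outer and inner dims -> [n0/t0, t0, n1/t1, t1].
  std::vector<int> transposed = transpose4D(packed, packedShape, {0, 2, 1, 3});
  // Shape cast: [n0/t0, t0, n1/t1, t1] and [n0, n1] share the same row-major
  // layout, so the buffer is returned as-is and read as n0 x n1.
  return transposed;
}
```

Under these assumptions, element (i, j) of the result equals packed[i / t0][j / t1][i % t0][j % t1], which is the unpack semantics. Handling partial tiles is what the masked read in the patch is for, and this model does not cover it.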