[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Feb 6 21:15:47 PST 2024
================
@@ -1393,6 +1394,130 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms into these 4 ops:
+/// vector::TransferReadOp - Reads the vector of source data.
+/// vector::TransposeOp - Transposes the source.
+/// vector::ShapeCastOp - Reshapes the data based on the target.
+/// vector::TransferWriteOp - Writes the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ // Handling this case requires a bit more change. Right now
+ // just the required attributes are handled.
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ LDBG("outer dimensions perms NYI for: " << unpackOp);
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+ for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+ readMaskShape[ii] = inputVectorSizes[ii];
+ }
+
+ // ReadMask is the size of tensor used to read and apply mask. It is
+ // set like this. Let's say the vectorSize (VS) array is size 'N' and
+ // the sourceShape(SS) is 'M' where M >= N
+ // Thus:
+ // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+ auto vectorType =
+ VectorType::get(readMaskShape, unpackTensorType.getElementType());
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+ int64_t unpackRank = unpackTensorType.getRank();
+ arith::ConstantIndexOp zeroOp =
+ rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+
+ vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+ unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+ SmallVector<Value>(unpackRank, zeroOp),
+ rewriter.getMultiDimIdentityMap(unpackRank));
+
+ auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), readMaskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+ vector::MaskOp maskedOp =
+ cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+ int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+ llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+ PackingMetadata packMetadata =
+ computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+ unpackRank, lastDims, packMetadata.insertPositions);
+ ShapedType maskedOpShapedType =
+ cast<ShapedType>(maskedOp.getResult(0).getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+ .setShape(stripMineShape);
----------------
hanhanW wrote:
It looks like they are copied from lowerUnPack? Can you try to refactor them to `Tensor/Utils/`?
https://github.com/llvm/llvm-project/pull/76087/files#diff-be7661e240d890a641a87f194bbb12e84a340b4a7f408865d5b89082123e24fbR1453-R1468
https://github.com/llvm/llvm-project/pull/76087
More information about the Mlir-commits
mailing list