[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Feb 6 21:15:48 PST 2024
================
@@ -1393,6 +1394,130 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+/// Vector::TransferReadOp - Reads the Vector Array of Source data
+/// vector::TransposeOp - Transpose the Source
+/// ShapeCastOp - Reshapes the data based on the target.
+/// vector::TransferWriteOp. - Write the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ // Handling this case requires a bit more change. Right now
+ // just the required attributes are handled.
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ LDBG("outer dimensions perms NYI for: " << unpackOp);
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+ for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+ readMaskShape[ii] = inputVectorSizes[ii];
+ }
+
+ // ReadMask is the size of tensor used to read and apply mask. It is
+ // set like this. Let's say the vectorSize (VS) array is size 'N' and
+ // the sourceShape(SS) is 'M' where M >= N
+ // Thus:
+ // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+ auto vectorType =
+ VectorType::get(readMaskShape, unpackTensorType.getElementType());
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+ int64_t unpackRank = unpackTensorType.getRank();
+ arith::ConstantIndexOp zeroOp =
+ rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+
+ vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+ unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+ SmallVector<Value>(unpackRank, zeroOp),
+ rewriter.getMultiDimIdentityMap(unpackRank));
+
+ auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), readMaskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+ vector::MaskOp maskedOp =
+ cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+ int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+ llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+ PackingMetadata packMetadata =
+ computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+ unpackRank, lastDims, packMetadata.insertPositions);
+ ShapedType maskedOpShapedType =
+ cast<ShapedType>(maskedOp.getResult(0).getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+ .setShape(stripMineShape);
+
+ // Collapse the tensor to the size required by result.
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ auto vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
----------------
hanhanW wrote:
Please move these declarations to right before where they are used. The comment also seems outdated: you are doing a shape_cast on vectors, not tensors.
https://github.com/llvm/llvm-project/pull/76087
More information about the Mlir-commits
mailing list