[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Feb 6 21:15:49 PST 2024
================
@@ -1393,6 +1394,130 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Vectorize a `tensor::UnPackOp` that has no outer_dims_perm into these
+/// four ops:
+///   vector::TransferReadOp  - Reads a (masked) vector from the source tensor.
+///   vector::TransposeOp     - Transposes the read vector.
+///   vector::ShapeCastOp     - Reshapes the data to match the result shape.
+///   vector::TransferWriteOp - Writes the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ // Handling this case requires a bit more change. Right now
+ // just the required attributes are handled.
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ LDBG("outer dimensions perms NYI for: " << unpackOp);
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+ for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+ readMaskShape[ii] = inputVectorSizes[ii];
+ }
+
+ // ReadMask is the size of tensor used to read and apply mask. It is
+ // set like this. Let's say the vectorSize (VS) array is size 'N' and
+ // the sourceShape(SS) is 'M' where M >= N
+ // Thus:
+ // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+ auto vectorType =
+ VectorType::get(readMaskShape, unpackTensorType.getElementType());
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+ int64_t unpackRank = unpackTensorType.getRank();
+ arith::ConstantIndexOp zeroOp =
+ rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+
+ vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+ unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+ SmallVector<Value>(unpackRank, zeroOp),
+ rewriter.getMultiDimIdentityMap(unpackRank));
+
+ auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), readMaskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+ vector::MaskOp maskedOp =
+ cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+ int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+ llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+ PackingMetadata packMetadata =
+ computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+ unpackRank, lastDims, packMetadata.insertPositions);
+ ShapedType maskedOpShapedType =
+ cast<ShapedType>(maskedOp.getResult(0).getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+ .setShape(stripMineShape);
+
+ // Collapse the tensor to the size required by result.
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ auto vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+ // Transpose the appropriate rows to match output.
+ vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+ unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+ vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+ tensor::EmptyOp emptyOp =
+ rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
+ unpackTensorType.getElementType());
+
+ int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
+ Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
+ unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+ SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
+ auto resultShape = unpackOp.getResult().getType().getShape();
+
+ // If the shape of the result doesn't match the inputVectorSizes, a mask
+ // is necessary.
+ bool needMaskForWrite =
+ llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
+ [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ mlir::OpResult result = writeOp->getResult(0);
+ if (needMaskForWrite) {
+ SmallVector<int64_t> writeMaskShape(inputVectorSizes);
+ llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+ for (auto [index, size] : enumerate(innerTiles)) {
+ writeMaskShape[innerDimPos[index]] *= size;
+ }
+ // WriteMaskShape is computed using the vectorSizes, inner Dim Position and
+ // innerTiles.
+ // WriteMaskShape (WMS) initialized to [inputVectorSizes]
+ // for-each index, value in inner-Tiles vector:
+ // WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
+ auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
+ Value writeMask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+ Operation *writeOpWithMask =
+ mlir::vector::maskOperation(rewriter, writeOp, writeMask);
+ result = writeOpWithMask->getResult(0);
+ }
----------------
hanhanW wrote:
Same here: can we try to refactor/generalize it a bit more, like Max did in his PR?
https://github.com/llvm/llvm-project/pull/76087
More information about the Mlir-commits
mailing list