[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Jan 18 23:55:59 PST 2024
================
@@ -1393,6 +1399,121 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as the following:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1)
+///   outer_dims_perm = [0, 2, 1]
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+/// The returned permutation is `[0, 3, 1, 4, 2]`: entry `i` is the index in
+/// the "tiled" shape of the dimension found at index `i` of the "packed"
+/// shape.
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  const int64_t srcRank = packOp.getSourceRank();
+
+  // Keep both index arrays in owned storage. The getters are presumably
+  // non-owning views (ArrayRef) into the op's attributes; assigning the
+  // temporary produced by `to_vector` to such a view would leave it pointing
+  // at destroyed storage as soon as the statement ends.
+  SmallVector<int64_t> innerDimsPos = to_vector(packOp.getInnerDimsPos());
+  if (innerDimsPos.empty())
+    innerDimsPos =
+        to_vector(llvm::seq<int64_t>(packOp.getInnerTiles().size()));
+  SmallVector<int64_t> outerDimsPerm = to_vector(packOp.getOuterDimsPerm());
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+
+  // Maps an index into the "packed" shape to the matching index into the
+  // "tiled" shape.
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+    // Indices below `srcRank` are outer dimensions, the rest are inner tiles.
+    const bool isInnerTile = idx >= srcRank;
+    const int64_t srcIdx =
+        isInnerTile ? innerDimsPos[idx - srcRank] : outerDimsPerm[idx];
+    // Every tiled source dimension that precedes `srcIdx` contributes one
+    // extra (inner tile) dimension in front of it in the "tiled" shape.
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        ++tiledIdx;
+    // An inner tile sits immediately after its own outer dimension.
+    if (isInnerTile)
+      ++tiledIdx;
+    return tiledIdx;
+  };
+
+  const int64_t destRank = packOp.getDestRank();
+  SmallVector<int64_t> perm;
+  perm.reserve(destRank);
+  for (int64_t i = 0; i < destRank; ++i)
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Compute the "tiled" form of the `dest` shape of `packOp`: the `dest`
+/// (i.e. "packed") shape reordered by the inverse of the tiled-to-packed
+/// permutation returned by `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+  // Inverting the tiled->packed permutation yields packed->tiled.
+  auto packedToTiledPerm =
+      invertPermutationVector(getTiledShapeToPackedShapePerm(packOp));
+  auto packedShape = packOp.getDestType().getShape();
+  return applyPermutation(packedShape, packedToTiledPerm);
+}
+
+///
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ Location loc = packOp.getLoc();
+ auto padValue = packOp.getPaddingValue();
+ if (!padValue) {
+ padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+ }
+ int64_t inputRank = inputVectorSizes.size();
+ int64_t outputRank = packOp.getDestRank();
+ auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+ ReifiedRankedShapedTypeDims reifiedReturnShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedReturnShapes);
+ (void)status; // prevent unused variable warning on non-assert builds
+ assert(succeeded(status) && "failed to reify result shapes");
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+ padValue.getType());
+ SmallVector<OpFoldResult> mixedSourceDims =
+ tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+ Value mask =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/packOp.getSource(),
+ /*indices=*/SmallVector<Value>(inputRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(inputRank, true));
+ auto maskedOp = cast<vector::MaskOp>(
+ mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+ // ShapeCast
+ auto tiledPackShape = getTiledPackShape(packOp);
+ auto tiledPackType =
+ VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+ auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, tiledPackType, maskedOp->getResult(0));
+ auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+ auto transposeOp = rewriter.create<vector::TransposeOp>(
+ loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+ Operation *write = rewriter.create<vector::TransferWriteOp>(
+ loc,
+ /*vector=*/transposeOp->getResult(0),
+ /*source=*/emptyOp,
+ /*indices=*/SmallVector<Value>(outputRank, zero),
+ /*inBounds=*/SmallVector<bool>(outputRank, true));
----------------
hanhanW wrote:
We need to mask the write if the outer dimensions of the destination shape and the provided input vector sizes do not match. If you follow what I suggested about the input_vector changes, the check will be something like:
```
bool needMaskForWrite = llvm::any_of(
llvm::zip_equal(inputVectorSizes, packOp.getDestType().getShape().drop_back(innerTiles.size())),
[](auto it) { return std::get<0>(it) != std::get<1>(it); });
```
https://github.com/llvm/llvm-project/pull/78660
More information about the Mlir-commits
mailing list