[Mlir-commits] [mlir] [mlir][vector] add tensor.concat, bitcast, expand_shape, collapse_shape vectorization support (PR #97297)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Jul 9 02:11:44 PDT 2024
================
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
return success();
}
+/// Vectorize a `tensor::ExpandShapeOp` to these 3 Ops:
+///   vector::TransferReadOp - Reads a vector from the source tensor.
+///   vector::ShapeCastOp - Reshapes the data based on the target shape.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///   destination tensor.
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+ Operation *inputOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(inputOp);
+ auto src = inputOp->getOperand(0);
+ auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+ auto result = inputOp->getResults()[0];
+ auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ Location loc = inputOp->getLoc();
+
+ llvm::SmallVector<int64_t> srcVectorizedShape;
+ llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+ auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+ ArrayRef<int64_t> &inputShape) {
+ bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+ int64_t cur = 1, resultIdx = 0;
+ for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+ cur *= ss;
+ if (!isResultShapeBigger) {
+ // collapse
+ srcVectorizedShape.emplace_back(ss);
+ if (cur == retShape[resultIdx]) {
+ if (shapeScales.count(resultIdx)) {
+ srcVectorizedShape.back() *= shapeScales[resultIdx];
+ }
+ cur = 1;
+ resultIdx++;
+ }
+ } else {
+ // expand
+ if (cur == retShape[resultIdx]) {
+ srcVectorizedShape.emplace_back(cur);
+ if (shapeScales.count(srcIdx)) {
+ srcVectorizedShape.back() *= shapeScales[srcIdx];
+ }
+ cur = 1;
+ resultIdx++;
+ }
+ }
+ }
+ };
+ if (!inputVectorSizes.empty()) {
+ for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+ if (vs != resultShape[idx])
+ shapeScales[idx] = vs / resultShape[idx];
+ }
+
+ bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+ if (!isResultShapeBigger) {
+ getVectorizeShape(resultShape, srcShape);
+ } else {
+ getVectorizeShape(srcShape, resultShape);
+ }
+ } else {
+ srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+ }
+ // read
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(srcType.getElementType()));
+ Value readResult = vector::createReadOrMaskedRead(
+ rewriter, loc, src,
+ inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+ padValue, false);
+
+ auto shapeCastType =
+ VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+ resultType.getElementType());
+ vector::ShapeCastOp shapeCastOp =
+ rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+ // write
+ SmallVector<OpFoldResult> destSizes;
+ for (auto size : resultShape) {
+ destSizes.emplace_back(rewriter.getIndexAttr(size));
+ }
----------------
ftynse wrote:
Reserve space before appending in a loop. Or better, use a proper combinator like `map_to_vector`.
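For illustration, a minimal sketch of that suggestion (assuming `llvm/ADT/SmallVectorExtras.h` is available here, and using `resultShape` and `rewriter` as in the patch):

```cpp
// Sketch only: build destSizes in one shot with llvm::map_to_vector instead
// of push-back in a loop; the explicit return type keeps the element type
// as OpFoldResult rather than IntegerAttr.
SmallVector<OpFoldResult> destSizes =
    llvm::map_to_vector(resultShape, [&](int64_t size) -> OpFoldResult {
      return rewriter.getIndexAttr(size);
    });
```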
https://github.com/llvm/llvm-project/pull/97297
More information about the Mlir-commits mailing list