[Mlir-commits] [mlir] [mlir][linalg] Vectorize unpack op without masking (PR #89067)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu May 2 12:14:41 PDT 2024
================
@@ -1560,11 +1574,32 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+ ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+ bool useInBoundsInsteadOfMasking = false;
+ ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+
+ auto destSize = unpackOp.getDestRank();
+
+ // initVectorShape is the shape of the vector that will be read from the
+ // source tensor. It is set like this: Let's say the sourceShape is 'M' and
+ // the vectorSize (VS) array is size 'N' where N <= M. Thus:
+ // - initVectorShape = sourceShape.take_front(N)
+ // - if outer_dims_perms is present: do that permutation on initVectorShape.
+ // - Multiply all the locations pointed by innerDimPos by the innerTileSize
+ // attribute value.
+ SmallVector<int64_t> initVectorShape(sourceShape.take_front(destSize));
+ if (inputVectorSizes.empty()) {
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(initVectorShape, outerDimsPerm);
+ for (auto [i, pos] : llvm::enumerate(innerDimPos))
+ initVectorShape[pos] *= innerTiles[i];
+
+ inputVectorSizes = initVectorShape;
+ useInBoundsInsteadOfMasking = true;
+ }
SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
inputVectorSizes.end());
----------------
banach-space wrote:
Good point, I missed that!
`readMaskShape` is used when calling `createReadOrMaskedRead`:
https://github.com/llvm/llvm-project/blob/6d44a1ef55b559e59d725b07ffe1da988b4e5f1c/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1613-L1616
And the signature of `createReadOrMaskedRead` is here: https://github.com/llvm/llvm-project/blob/6d44a1ef55b559e59d725b07ffe1da988b4e5f1c/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h#L195-L197
So it's not really a mask shape (`readMaskShape`); it's the read shape (`readShape` or `readVectorSizes`, as you suggested). Am I correct that we only need to calculate it once?
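The `initVectorShape` computation in the quoted hunk can be sketched outside MLIR as follows. This is a hypothetical standalone mirror, not the PR's code: `std::vector` stands in for `ArrayRef`/`SmallVector`, and the names `UnpackVectorSizes` and `inferUnpackVectorSizes` are made up for illustration. The permutation step follows the hunk's use of `applyPermutationToVector`, which sets `result[i] = shape[perm[i]]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical result type: the vector sizes to use, plus the flag the hunk
// sets when it infers those sizes from static shapes.
struct UnpackVectorSizes {
  std::vector<int64_t> sizes;
  bool useInBoundsInsteadOfMasking;
};

// Hypothetical standalone mirror of the quoted hunk (not an MLIR API).
UnpackVectorSizes inferUnpackVectorSizes(
    const std::vector<int64_t> &inputVectorSizes,
    const std::vector<int64_t> &sourceShape, std::size_t destRank,
    const std::vector<int64_t> &outerDimsPerm,
    const std::vector<int64_t> &innerDimPos,
    const std::vector<int64_t> &innerTiles) {
  // User-provided vector sizes are kept as-is and masking stays enabled.
  if (!inputVectorSizes.empty())
    return {inputVectorSizes, false};

  // Like sourceShape.take_front(destSize): keep the outer dims of the
  // packed source tensor.
  std::vector<int64_t> shape(
      sourceShape.begin(),
      sourceShape.begin() + static_cast<std::ptrdiff_t>(destRank));

  // Like applyPermutationToVector(shape, outerDimsPerm):
  // result[i] = shape[outerDimsPerm[i]].
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted(shape.size());
    for (std::size_t i = 0; i < shape.size(); ++i)
      permuted[i] = shape[static_cast<std::size_t>(outerDimsPerm[i])];
    shape = permuted;
  }

  // Scale every tiled dimension by its inner tile size.
  for (std::size_t i = 0; i < innerDimPos.size(); ++i)
    shape[static_cast<std::size_t>(innerDimPos[i])] *= innerTiles[i];

  // The sizes come straight from the static shapes, so the read can be
  // marked in-bounds instead of being masked.
  return {shape, true};
}
```

For example, an `8x8x32x16` source with `inner_dims_pos = [0, 1]`, `inner_tiles = [32, 16]` and no `outer_dims_perm` yields sizes `{256, 128}`, i.e. the unpacked `256x128` shape. The shape is derived once from the static source shape, which is the value the hunk then copies into `readMaskShape`.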
https://github.com/llvm/llvm-project/pull/89067