[Mlir-commits] [mlir] [mlir][linalg] Enable scalable vectorization of linalg.unpack (PR #149293)
Ege Beysel
llvmlistbot at llvm.org
Thu Jul 24 12:27:11 PDT 2025
================
@@ -1860,25 +1866,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
auto destSize = unpackOp.getDestRank();
- if (!inputVectorSizes.empty())
- assert(inputVectorSizes.size() == destSize &&
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
"Incorrect number of input vector sizes");
+ }
+
+ SmallVector<bool> readScalableVectorFlags;
+ SmallVector<bool> writeScalableVectorFlags;
+ SmallVector<int64_t> readVectorSizes;
+ SmallVector<int64_t> writeVectorSizes;
- // vectorSizes is the shape of the vector that will be used to do final
+ // Split input-vector-sizes into vector sizes for the read and write
+ // operations.
+ if (!inputVectorSizes.empty()) {
+ readVectorSizes.append(inputVectorSizes.begin(),
+ inputVectorSizes.begin() + sourceShape.size());
+ writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+ inputVectorSizes.end());
+ }
+ if (!inputScalableVecDims.empty()) {
+ readScalableVectorFlags.append(inputScalableVecDims.begin(),
+ inputScalableVecDims.begin() +
+ sourceShape.size());
+ writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+ sourceShape.size(),
+ inputScalableVecDims.end());
+ } else {
+ readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+ writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+ }
+
+ // writeVectorSizes is the shape of the vector that will be used to do final
// write on the destination tensor. It is set like this: Let's say the
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
// Thus:
- // 1. vectorSizes = sourceShape.take_front(N)
- // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+ // 1. writeVectorSizes = sourceShape.take_front(N)
+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
// innerTiles attribute value.
- SmallVector<int64_t> vectorSizes(inputVectorSizes);
- if (vectorSizes.empty()) {
- llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+ // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
----------------
egebeysel wrote:
Also, could we add a comment here noting that this is the case where the write vector sizes are inferred from the IR? Thanks!
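To make the two paths in the hunk concrete, here is a minimal standalone C++ sketch. It is not the PR's code and does not use MLIR APIs. It does two things:

- It splits a flat list of vector sizes and scalable flags into a read part (source rank) and a write part (dest rank).
- When no sizes are given, it infers the write sizes following the three steps in the code comment.

The names `UnpackVectorSizes`, `splitVectorSizes` and `inferWriteVectorSizes` are hypothetical, and `std::vector` stands in for `SmallVector`/`ArrayRef`. The sketch also assumes that `outer_dims_perm` maps source outer dim `i` to dest dim `perm[i]`, so step 2 scatters through the permutation rather than gathering.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the split sizes (not an MLIR type).
struct UnpackVectorSizes {
  std::vector<int64_t> readSizes;  // one entry per source (packed) dim
  std::vector<int64_t> writeSizes; // one entry per dest (unpacked) dim
  std::vector<bool> readScalable;
  std::vector<bool> writeScalable;
};

// Split flat sizes/flags: the first srcRank entries describe the read,
// the remaining destRank entries describe the write.
UnpackVectorSizes splitVectorSizes(const std::vector<int64_t> &inputSizes,
                                   const std::vector<bool> &inputScalable,
                                   std::size_t srcRank, std::size_t destRank) {
  UnpackVectorSizes out;
  const auto split = static_cast<std::ptrdiff_t>(srcRank);
  if (!inputSizes.empty()) {
    assert(inputSizes.size() == srcRank + destRank &&
           "Incorrect number of input vector sizes");
    out.readSizes.assign(inputSizes.begin(), inputSizes.begin() + split);
    out.writeSizes.assign(inputSizes.begin() + split, inputSizes.end());
  }
  if (!inputScalable.empty()) {
    out.readScalable.assign(inputScalable.begin(),
                            inputScalable.begin() + split);
    out.writeScalable.assign(inputScalable.begin() + split,
                             inputScalable.end());
  } else {
    // No scalable flags given: every dim is fixed-size.
    out.readScalable.assign(srcRank, false);
    out.writeScalable.assign(destRank, false);
  }
  return out;
}

// Fallback when no sizes were provided: infer the write sizes from the
// static source shape (steps 1-3 of the code comment).
std::vector<int64_t> inferWriteVectorSizes(
    const std::vector<int64_t> &sourceShape, std::size_t destRank,
    const std::vector<int64_t> &outerDimsPerm,
    const std::vector<int64_t> &innerDimPos,
    const std::vector<int64_t> &innerTiles) {
  // 1. The outer dims of the packed source are its first destRank dims.
  std::vector<int64_t> outer(
      sourceShape.begin(),
      sourceShape.begin() + static_cast<std::ptrdiff_t>(destRank));
  // 2. Undo outer_dims_perm. Assumption: source outer dim i corresponds
  //    to dest dim perm[i], so scatter through the permutation.
  std::vector<int64_t> sizes(outer);
  if (!outerDimsPerm.empty())
    for (std::size_t i = 0; i < destRank; ++i)
      sizes[static_cast<std::size_t>(outerDimsPerm[i])] = outer[i];
  // 3. Each tiled dest dim spans (outer size) x (inner tile size).
  for (std::size_t i = 0; i < innerDimPos.size(); ++i)
    sizes[static_cast<std::size_t>(innerDimPos[i])] *= innerTiles[i];
  return sizes;
}
```

For example, consider a packed source of shape 4x8x16x2 with `outer_dims_perm = [1, 0]`, `inner_dims_pos = [1, 0]` and inner tiles `[16, 2]`. The inferred write sizes are 16x64. If `outer_dims_perm` is instead meant to be applied as a gather, step 2 becomes `sizes[i] = outer[perm[i]]`. For a two-element swap like this one, both forms give the same result.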
https://github.com/llvm/llvm-project/pull/149293