[Mlir-commits] [mlir] [mlir][linalg] Enable scalable vectorization of linalg.unpack (PR #149293)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jul 28 02:20:25 PDT 2025


================
@@ -1866,59 +1879,76 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   auto destSize = unpackOp.getDestRank();
 
-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
-           "Incorrect number of input vector sizes");
-
-  // vectorSizes is the shape of the vector that will be used to do final
-  // write on the destination tensor. It is set like this: Let's say the
-  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
-  // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
-  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
-  //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // 1. Obtain vector sizes for the read and write operations.
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+
+  // CASE 1: Vector sizes are user-specified.
+  // 1.0 This is the trivial case, simply split the input vector sizes.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  }
+
+  // CASE 2: Vector sizes have to be inferred.
+  //
+  // 1.1 Infer vector sizes for the write operation.
+  //
+  // Let:
+  //    * rank(source tensor) = 'M'
+  //    * rank(dest tensor) = 'N',
+  // and N <= M. The steps are:
+  //  1. writeVectorSizes = sourceShape.take_front(N)
+  //  2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
+  //     by the corresponding values from the `inner_tiles` attribute value.
+  //  3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
+  //
+  // Note, this will only work when all sizes are static!
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];
 
     useInBoundsInsteadOfMasking = true;
   }
 
-  // readVectorSizes is the size of tensor used to read and apply mask. It is
-  // set like this: Let's say the vectorSize (VS) array is size 'N' and
-  // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
-  // size M-N
-  // Thus:
-  // - initially: readVectorSizes = vectorInputSizes
-  // - Divide all the readMaskShape locations pointed by innerDimPos
-  //   by the innerTileSize attribute value.
-  // - if outer_dims_perms is present: do that permutation on readVectorSizes.
-  // - Append the remaining shape from SS
-  // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
-  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
-  // 128] and outer_dims_perm is [1, 0] then read shape is:
-  //   ReadVectorSizes(initial): [512, 128]
-  //   Final Value(after innerDim Adjustment): [512/32, 128/16]
-  //                                           = [16, 8]
-  //   After applying outer_dims_perm: [8, 16]
-  //   After appending the rest of the sourceShape: [8, 16, 32, 16]
-
-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  // 1.2 Infer vector sizes for the read operation.
+  //
+  // The steps are:
+  //  1. readVectorSizes = writeVectorSizes
+  //  2. Take readVectorSizes from 1. and divide all locations pointed by
+  //     the inner_dims_pos attribute by the `inner_tiles` attribute value.
+  //  3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
+  //  4. Append the remaining sizes from the source tensor.
+  //
+  // Note, this will only work when all sizes are static!
+  if (readVectorSizes.empty()) {
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
----------------
banach-space wrote:

Indeed, thanks!

https://github.com/llvm/llvm-project/pull/149293


More information about the Mlir-commits mailing list