[Mlir-commits] [mlir] [mlir][linalg] Enable scalable vectorization of linalg.unpack (PR #149293)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jul 30 06:17:10 PDT 2025


================
@@ -1831,126 +1832,138 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-///   If the vector sizes are not provided:
-///   * the vector sizes are determined by the input operand and attributes,
-///   * update the inBounds attribute instead of masking.
+/// Vectorize `linalg.unpack` into:
+///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
+///
+/// The input-vector-sizes specify both the read and the write vector
+/// sizes and are passed as one array covering both operations, i.e.:
+///
+///  input-vector-sizes = [1, 1, 8, [8],  8, [8]]
+///                        \         /    \    /
+///                        read-sizes   write-sizes
+///
+/// (for brevity, in the diagram,
+///    * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
+/// )
+///
+/// If the vector sizes are not provided:
+///  * the vector sizes are determined by the operands,
+///  * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %unpack = linalg.unpack  %src
+///    inner_dims_pos = [0, 1]
+///    inner_tiles = [8, 8]
+///    into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
+/// ```
+/// is vectorized as:
+/// ```
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read  %src
+///     : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read [0, 2, 1, 3]
+///     : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
+/// ```
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() ==
+               unpackOp.getDestRank() + unpackOp.getSourceRank() &&
+           "Invalid number of input vector sizes!");
+    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
+           "Incompatible number of vector sizes and vector scalable flags!");
+  }
 
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
-  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
   bool useInBoundsInsteadOfMasking = false;
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
-
-  auto destSize = unpackOp.getDestRank();
-
-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
-           "Incorrect number of input vector sizes");
-
-  // vectorSizes is the shape of the vector that will be used to do final
-  // write on the destination tensor. It is set like this: Let's say the
-  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
-  // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
-  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
-  //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
-    for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
 
-    useInBoundsInsteadOfMasking = true;
-  }
+  Location loc = unpackOp->getLoc();
 
-  // readVectorSizes is the size of tensor used to read and apply mask. It is
-  // set like this: Let's say the vectorSize (VS) array is size 'N' and
-  // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
-  // size M-N
-  // Thus:
-  // - initially: readVectorSizes = vectorInputSizes
-  // - Divide all the readMaskShape locations pointed by innerDimPos
-  //   by the innerTileSize attribute value.
-  // - if outer_dims_perms is present: do that permutation on readVectorSizes.
-  // - Append the remaining shape from SS
-  // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
-  // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
-  // 128] and outer_dims_perm is [1, 0] then read shape is:
-  //   ReadVectorSizes(initial): [512, 128]
-  //   Final Value(after innerDim Adjustment): [512/32, 128/16]
-  //                                           = [16, 8]
-  //   After applying outer_dims_perm: [8, 16]
-  //   After appending the rest of the sourceShape: [8, 16, 32, 16]
-
-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  // 1. Obtain vector sizes for the read and write operations.
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+
+  // CASE 1.1: Vector sizes are user-specified.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());
 
-  Location loc = unpackOp->getLoc();
+  // CASE 1.2: Vector sizes have to be inferred.
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape) ||
+        ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
+    writeVectorSizes.assign(destShape.begin(), destShape.end());
+    useInBoundsInsteadOfMasking = true;
+  }
 
+  // 2. Generate the read operation.
   auto padValue = arith::ConstantOp::create(
       rewriter, loc,
       rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
-
-  // Read result, mask if necessary. If transferReadOp shape is not equal
-  // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
 
+  // 3. Generate the transpose operation.
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
+  vector::TransposeOp transposeOp = vector::TransposeOp::create(
+      rewriter, loc, readResult, lastDimToInsertPosPerm);
+
+  // 3. Generate the shape_cast operation.
   ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
   mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
   RankedTensorType stripMineTensorType =
       RankedTensorType::get(stripMineShape, stripMineElemType);
-  // Transpose the appropriate rows to match output.
-  vector::TransposeOp transposeOp = vector::TransposeOp::create(
-      rewriter, loc, readResult, lastDimToInsertPosPerm);
-
-  // Collapse the vector to the size required by result.
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // 4. Generate the write operation.
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
----------------
banach-space wrote:

In fact, we don't need `writeVectorSizes`, see https://github.com/llvm/llvm-project/pull/151325. Shortly I will post a note explaining "why".

https://github.com/llvm/llvm-project/pull/149293


More information about the Mlir-commits mailing list