[Mlir-commits] [mlir] [mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP) (PR #149293)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 17 05:06:54 PDT 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/149293

This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:
  1. _Read_ from the source tensor.
  2. Transpose the value read in step (1).
  3. _Write_ the value from step (2) into the destination tensor.
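
As a rough illustration of the net effect of these three steps, here is a minimal scalar sketch in plain C++ (no MLIR APIs). It assumes a 4-D source of logical shape `[outerR][outerC][tileR][tileC]`, `inner_dims_pos = [0, 1]`, no `outer_dims_perm`, and a destination that is an exact multiple of the tile sizes (so no padding is dropped). The helper name `unpack2D` is hypothetical and not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the patch): unpack a row-major source of
// logical shape [outerR][outerC][tileR][tileC] into a row-major destination
// of shape [outerR * tileR][outerC * tileC].
std::vector<float> unpack2D(const std::vector<float> &src, int64_t outerR,
                            int64_t outerC, int64_t tileR, int64_t tileC) {
  assert(static_cast<int64_t>(src.size()) == outerR * outerC * tileR * tileC);
  const int64_t cols = outerC * tileC;
  std::vector<float> dst(src.size());
  for (int64_t i = 0; i < outerR; ++i)
    for (int64_t j = 0; j < outerC; ++j)
      for (int64_t r = 0; r < tileR; ++r)
        for (int64_t c = 0; c < tileC; ++c) {
          // Step 1: read element (i, j, r, c) of the packed source.
          int64_t srcIdx = ((i * outerC + j) * tileR + r) * tileC + c;
          // Steps 2 and 3: the (j, r) swap is the transpose; the element
          // then lands at row i * tileR + r, column j * tileC + c.
          int64_t dstIdx = (i * tileR + r) * cols + (j * tileC + c);
          dst[dstIdx] = src[srcIdx];
        }
  return dst;
}
```

The vectorized form performs the same data movement with a `vector.transfer_read`, a `vector.transpose` plus `vector.shape_cast`, and a `vector.transfer_write` instead of a scalar loop nest.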

Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes.
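
The inference described above can be sketched in standalone C++ as follows. This is a simplified model of the existing logic, assuming all sizes are static and no `outer_dims_perm` is present; `inferReadSizes` and the local `divideCeil` are hypothetical stand-ins, not the actual functions in `Vectorization.cpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for llvm::divideCeil.
static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical helper: derive the read vector sizes from the write vector
// sizes, the static source shape and the static inner tile sizes.
std::vector<int64_t> inferReadSizes(const std::vector<int64_t> &writeSizes,
                                    const std::vector<int64_t> &sourceShape,
                                    const std::vector<int64_t> &innerDimsPos,
                                    const std::vector<int64_t> &innerTiles) {
  assert(innerDimsPos.size() == innerTiles.size());
  assert(writeSizes.size() <= sourceShape.size());
  std::vector<int64_t> readSizes = writeSizes;
  for (size_t i = 0; i < innerTiles.size(); ++i) {
    // A dynamic tile size has no compile-time value to divide by, which is
    // why this approach only works in the fully static case.
    assert(innerTiles[i] > 0 && "tile sizes must be static");
    readSizes[innerDimsPos[i]] =
        divideCeil(readSizes[innerDimsPos[i]], innerTiles[i]);
  }
  // The trailing source dimensions (the tiles themselves) are read in full.
  readSizes.insert(readSizes.end(), sourceShape.begin() + writeSizes.size(),
                   sourceShape.end());
  return readSizes;
}
```

For instance, with write sizes `[256, 256]`, source shape `[8, 16, 32, 16]`, `inner_dims_pos = [0, 1]` and `inner_tiles = [32, 16]`, the sketch yields read sizes `[8, 16, 32, 16]`.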

This logic breaks when the input shapes or tile sizes are dynamic (indeed, `vectorizeUnPackOpPrecondition` currently rejects such cases and vectorization fails). This patch addresses the issue by requiring explicit vector sizes for both the read and the write side, which enables scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack  %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8],  8, [8]] : !transform.any_op
    //                                              \         /    \    /
    //                                              read-sizes   write-sizes
    transform.yield
  }
}
```
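
To make the layout of `vector_sizes` concrete, the following standalone C++ sketch models how the combined list could be split into read and write parts: the first `sourceRank` entries go to the read, the remaining `destRank` entries go to the write, and the scalable flags are split the same way. `SplitSizes` and `splitVectorSizes` are hypothetical names used for illustration only; the patch does this inline in `vectorizeAsTensorUnpackOp`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the four lists produced by the split.
struct SplitSizes {
  std::vector<int64_t> readSizes, writeSizes;
  std::vector<bool> readScalable, writeScalable;
};

// Hypothetical helper: split the user-provided sizes and scalable flags into
// a read part (source rank) and a write part (destination rank).
SplitSizes splitVectorSizes(const std::vector<int64_t> &sizes,
                            const std::vector<bool> &scalable,
                            size_t sourceRank, size_t destRank) {
  assert(sizes.size() == sourceRank + destRank);
  assert(scalable.empty() || scalable.size() == sizes.size());
  SplitSizes out;
  out.readSizes.assign(sizes.begin(), sizes.begin() + sourceRank);
  out.writeSizes.assign(sizes.begin() + sourceRank, sizes.end());
  if (scalable.empty()) {
    // No flags provided: treat every dimension as fixed-width.
    out.readScalable.assign(sourceRank, false);
    out.writeScalable.assign(destRank, false);
  } else {
    out.readScalable.assign(scalable.begin(), scalable.begin() + sourceRank);
    out.writeScalable.assign(scalable.begin() + sourceRank, scalable.end());
  }
  return out;
}
```

For the example above (`vector_sizes [1, 1, 8, [8], 8, [8]]`, source rank 4, destination rank 2), this gives read sizes `[1, 1, 8, [8]]` and write sizes `[8, [8]]`.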

Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
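
The effect of the scalable flags is easiest to see in the resulting vector types: a scalable dimension is printed in square brackets, and the mask type must carry the same flags as the vector being read or written. The sketch below is plain C++ that only formats the type as text; `vectorTypeString` is a hypothetical helper and does not use or mirror the `VectorType` API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: render a vector type in MLIR's textual syntax, with
// scalable dimensions wrapped in square brackets.
std::string vectorTypeString(const std::vector<int64_t> &shape,
                             const std::vector<bool> &scalable,
                             const std::string &elementType) {
  assert(scalable.empty() || scalable.size() == shape.size());
  std::string s = "vector<";
  for (size_t i = 0; i < shape.size(); ++i) {
    bool isScalable = !scalable.empty() && scalable[i];
    std::string dim = std::to_string(shape[i]);
    s += isScalable ? "[" + dim + "]" : dim;
    s += "x";
  }
  return s + elementType + ">";
}
```

With the read sizes from the example, the data vector would be `vector<1x1x8x[8]xf32>` and its mask `vector<1x1x8x[8]xi1>`; passing empty flags reproduces the fixed-width types used before this change.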

From 0267d2a8519b481c88beac534e44df13ed17799d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 16 Jul 2025 17:08:55 +0000
Subject: [PATCH] [mlir][linalg] Enable scalable vectorization of linalg.unpack
 (WIP)

This patch updates `vectorizeAsTensorUnpackOp` to support scalable
vectorization by requiring user-specified vector sizes for both the
_read_ and _write_ operations involved in `linalg.unpack`. Detailed
rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:
  1. _Read_ from the source tensor.
  2. Transpose the value read in step (1).
  3. _Write_ the value from step (2) into the destination tensor.

Currently, when vectorizing with user-provided vector sizes, only the
sizes for the _write_ operation (step 3) are required. Sizes for the
_read_ operation (step 1) are inferred from static shapes and inner tile
sizes.

This logic breaks when the input shapes or tile sizes are dynamic
(indeed, `vectorizeUnPackOpPrecondition` currently rejects such cases
and vectorization fails). This patch addresses the issue by requiring
explicit vector sizes for both the read and the write side, which
enables scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack  %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8],  8, [8]] : !transform.any_op
    //                                              \         /    \    /
    //                                              read-sizes   write-sizes
    transform.yield
  }
}
```

Finally, this patch also extends `createReadOrMaskedRead` and
`createWriteOrMaskedWrite` to take scalable flags.
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |   4 +-
 .../Linalg/Transforms/Vectorization.cpp       | 110 +++++++++++++-----
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |   7 +-
 .../Linalg/vectorization/linalg-ops.mlir      |  28 ++++-
 4 files changed, 114 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index cc8421b23a074..be71483410744 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -225,7 +225,9 @@ bool isLinearizableVector(VectorType type);
 ///
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
-                             ArrayRef<int64_t> inputVectorSizes, Value padValue,
+                             ArrayRef<int64_t> inputVectorSizes,
+                             ArrayRef<bool> inputScalableVecSizes,
+                             Value padValue,
                              bool useInBoundsInsteadOfMasking = false);
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 458ed543b8216..f4999f3a8d080 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1681,7 +1681,8 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
       tensor::getMixedSizes(builder, loc, dest);
@@ -1773,8 +1774,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      rewriter, loc, packOp.getSource(), inputShape,
+      /*inputScalableVecSizes=*/{}, padValue, useInBoundsInsteadOfMasking);
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1812,6 +1813,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
 
   // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1829,24 +1831,52 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   auto destSize = unpackOp.getDestRank();
 
   if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
 
-  // vectorSizes is the shape of the vector that will be used to do final
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
+
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];
 
     useInBoundsInsteadOfMasking = true;
   }
@@ -1870,17 +1900,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
 
-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());
 
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
@@ -1898,7 +1931,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // Read result, mask if necessary. If transferReadOp shape is not equal
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
+      rewriter, loc, unpackOp.getSource(), readVectorSizes,
+      readScalableVectorFlags, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   PackingMetadata packMetadata;
@@ -1918,15 +1952,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1956,7 +1992,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
+      rewriter, loc, padOp.getSource(), inputVectorSizes,
+      /*inputScalableVecSizes=*/{}, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   // Create Xfer write Op
@@ -2041,6 +2078,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2291,6 +2331,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2369,6 +2410,10 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();
 
+  // FIXME!!!
+  return success();
+
+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);
 
   // Cond 1: There's been no need for scalable vectorisation of
@@ -2469,7 +2514,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::UnPackOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2598,7 +2644,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
           })
           .Case<linalg::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
+                                             inputVectorSizes,
+                                             inputScalableVecDims, results);
           })
           .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -2988,7 +3035,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue,
+      rewriter, loc, source, vecType.getShape(), /*inputScalableVecSizes=*/{},
+      padValue,
       /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
 
   // Create write
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7e4984582b373..0a8729c0e473e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -319,6 +319,7 @@ bool vector::isLinearizableVector(VectorType type) {
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
+                                     ArrayRef<bool> inputScalableVecSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
@@ -327,7 +328,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+                                    inputScalableVecSizes);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -354,7 +356,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+                                  inputScalableVecSizes);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 679adf0a52175..dc3b9073335f9 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -986,7 +986,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
   %pad = arith.constant 0.000000e+00 : f32
-  %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, [2]] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
   return %pack : tensor<32x4x1x16x2xf32>
 }
 //  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
@@ -1017,6 +1017,32 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @test_vectorize_padded_pack(%extracted_slice: tensor<1x?xf32>, %extracted_slice_0: tensor<1x1x?x1xf32>) -> tensor<1x1x?x1xf32> {
+  %pad = arith.constant 1.23:  f32
+
+  %vs = vector.vscale
+  %c8 = arith.constant 8 : index
+  %tile_size = arith.muli %vs, %c8 : index
+
+  %pack = linalg.pack %extracted_slice
+     padding_value(%pad : f32)
+     outer_dims_perm = [1, 0]
+     inner_dims_pos = [1, 0]
+     inner_tiles = [%tile_size, 1]
+     into %extracted_slice_0  : tensor<1x?xf32> -> tensor<1x1x?x1xf32>
+  return %pack : tensor<1x1x?x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
   %pack = linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
   return %pack : tensor<?x?x16x2xf32>


