[Mlir-commits] [mlir] [mlir][linalg] Add getCollapsedVecType and update vectorization of linalg.unpack (PR #151503)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 31 05:17:45 PDT 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/151503

This patch introduces a new helper, `getCollapsedVecType`, and updates
`vectorizeAsTensorUnpackOp` to use it. The motivation stems from improving
how `vector.shape_cast` operations are generated when vectorizing
`linalg.unpack`.

Previously, the vectorizer relied on
`tensor::CollapseShapeOp::inferCollapsedType` to compute the collapsed
vector type. This approach is suboptimal because:
  * `inferCollapsedType` lacks awareness of scalable vector flags.
  * Linalg vectorization should not depend on Tensor dialect utilities.

Instead of relocating `inferCollapsedType`, we introduce
`getCollapsedVecType`, a lightweight, specialized helper that:
  * Assumes no dynamic sizes.
  * Handles scalable flags alongside shape dimensions.

This change also reduces temporary variables in
`vectorizeAsTensorUnpackOp` and paves the way for a cleaner update in
 #149293.
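
The collapsing rule described above can be illustrated outside of MLIR. The
following is a minimal, standalone sketch and not the code from this patch:
`VecShape`, `Reassociation` and `collapseVecShape` are hypothetical stand-ins
for `VectorType`, the reassociation maps and `getCollapsedVecType`. It assumes
static sizes only and, like the patch, at most one scalable dim.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: static sizes plus one scalable
// flag per dimension (no dynamic sizes).
struct VecShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Hypothetical stand-in for the reassociation: each inner vector lists the
// contiguous source dims that fold into one result dim. Only the size of
// each group is used, mirroring "only use each map's rank" in the patch.
using Reassociation = std::vector<std::vector<unsigned>>;

VecShape collapseVecShape(const VecShape &in, const Reassociation &groups) {
  assert(in.sizes.size() == in.scalable.size() && "one flag per dim");
  // Same conservative restriction as the patch: at most one scalable dim.
  assert(std::count(in.scalable.begin(), in.scalable.end(), true) < 2 &&
         "collapsing more than 1 scalable dim is not supported");

  VecShape out;
  std::size_t currentDim = 0;
  for (const auto &group : groups) {
    int64_t size = 1;
    bool flag = false;
    for (std::size_t d = 0; d < group.size(); ++d) {
      size *= in.sizes[currentDim + d];            // multiply the sizes
      flag = flag || in.scalable[currentDim + d];  // OR the scalable flags
    }
    out.sizes.push_back(size);
    out.scalable.push_back(flag);
    currentDim += group.size();
  }
  assert(currentDim == in.sizes.size() && "groups must cover every dim");
  return out;
}
```

Under these assumptions, collapsing sizes `{2, 4, 8}` with flags
`{false, true, false}` using groups `{{0, 1}, {2}}` gives sizes `{8, 8}` with
flags `{true, false}`, i.e. `vector<2x[4]x8xf32>` would become
`vector<[8]x8xf32>`.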


From 33fae274c5ef92611b79e69a944676f6fba968ed Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 31 Jul 2025 11:45:45 +0000
Subject: [PATCH] [mlir][linalg] Add getCollapsedVecType and update
 vectorization of linalg.unpack
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch introduces a new helper, `getCollapsedVecType`, and updates
`vectorizeAsTensorUnpackOp` to use it. The motivation stems from improving
how `vector.shape_cast` operations are generated when vectorizing
`linalg.unpack`.

Previously, the vectorizer relied on
* `tensor::CollapseShapeOp::inferCollapsedType`

to compute the collapsed vector type. This approach is suboptimal
because:
  * `inferCollapsedType` lacks awareness of scalable vector flags.
  * Linalg vectorization should not depend on Tensor dialect utilities.

Instead of relocating `inferCollapsedType`, we introduce
`getCollapsedVecType` — a lightweight, specialized hook that:
  * Assumes no dynamic sizes.
  * Handles scalable flags alongside shape dimensions.

This change also reduces temporary variables in
`vectorizeAsTensorUnpackOp` and paves the way for a cleaner update in
 #149293.
---
 .../Linalg/Transforms/Vectorization.cpp       | 56 +++++++++++++++----
 1 file changed, 45 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ea68b1ad572c3..a82f31d988f76 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1831,6 +1831,46 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
+/// Given the re-associations, "collapses" the input Vector type.
+///
+/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
+/// differences:
+///   * We can safely assume that there are no dynamic sizes.
+///   * Scalable flags are updated alongside regular dims.
+///
+/// When collapsing scalable flags, conservatively avoids cases with two
+/// scalable dims. We could re-visit this in the future.
+static VectorType getCollapsedVecType(VectorType type,
+                                      ArrayRef<AffineMap> reassociation) {
+  assert(type.getNumScalableDims() < 2 &&
+         "Collapsing more than 1 scalable dim is not supported ATM");
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+
+  auto shape = type.getShape();
+  auto scalableFlags = type.getScalableDims();
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableFlags;
+
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    bool flag = false;
+    for (unsigned d = 0; d < dim; ++d) {
+      size *= shape[currentDim + d];
+      flag |= scalableFlags[currentDim + d];
+    }
+    newShape.push_back(size);
+    newScalableFlags.push_back(flag);
+    currentDim += dim;
+  }
+
+  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
+}
+
 /// Vectorize a `linalg::UnPackOp` to these 4 Ops:
 ///   Vector::TransferReadOp - Reads a vector from the source tensor
 ///   vector::TransposeOp - Transpose the Source tensor
@@ -1928,23 +1968,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
-  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
-  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-  RankedTensorType stripMineTensorType =
-      RankedTensorType::get(stripMineShape, stripMineElemType);
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
   // Collapse the vector to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  VectorType collapsedVecType = getCollapsedVecType(
+      transposeOp.getType(),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
-      rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+      rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),


