[Mlir-commits] [mlir] 77363fb - [mlir][linalg] Add getCollapsedVecType and update vectorization of linalg.unpack (#151503)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 1 03:26:22 PDT 2025
Author: Andrzej Warzyński
Date: 2025-08-01T11:26:19+01:00
New Revision: 77363fbd7ca24206a991eeb43d5ded54dd52a1f6
URL: https://github.com/llvm/llvm-project/commit/77363fbd7ca24206a991eeb43d5ded54dd52a1f6
DIFF: https://github.com/llvm/llvm-project/commit/77363fbd7ca24206a991eeb43d5ded54dd52a1f6.diff
LOG: [mlir][linalg] Add getCollapsedVecType and update vectorization of linalg.unpack (#151503)
This patch introduces a new helper, `getCollapsedVecType`, and updates
`vectorizeAsTensorUnpackOp` to use it. The motivation stems from improving how
`vector.shape_cast` operations are generated when vectorizing `linalg.unpack`.
Previously, the vectorizer relied on
* `tensor::CollapseShapeOp::inferCollapsedType`
to compute the collapsed vector type. This approach is suboptimal
because:
* `inferCollapsedType` lacks awareness of scalable vector flags.
* Linalg vectorization should not depend on Tensor dialect utilities.
Instead of relocating `inferCollapsedType`, we introduce
`getCollapsedVecType` — a lightweight, specialized hook that:
* Assumes no dynamic sizes.
* Handles scalable flags alongside shape dimensions.
This change also reduces temporary variables in
`vectorizeAsTensorUnpackOp` and paves the way for a cleaner update in
#149293.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ea68b1ad572c3..0860ceafa0270 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1831,6 +1831,53 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
return success();
}
+/// Given the re-associations, "collapses" the input Vector type
+///
+/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
+/// differences:
+/// * We can safely assume that there are no dynamic sizes.
+/// * Scalable flags are updated alongside regular dims.
+///
+/// When collapsing scalable flags, conservatively avoids cases with two
+/// scalable dims. We could re-visit this in the future.
+///
+/// EXAMPLE:
+/// type = vector<4x16x[8]x16xf32>
+/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
+/// (d0, d1, d2, d3) -> (d2, d3)]
+/// Result:
+/// vector<64x[128]xf32>
+static VectorType getCollapsedVecType(VectorType type,
+ ArrayRef<AffineMap> reassociation) {
+ assert(type.getNumScalableDims() < 2 &&
+ "Collapsing more than 1 scalable dim is not supported ATM");
+
+ // Use the fact that reassociation is valid to simplify the logic: only use
+ // each map's rank.
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+
+ auto shape = type.getShape();
+ auto scalableFlags = type.getScalableDims();
+ SmallVector<int64_t> newShape;
+ SmallVector<bool> newScalableFlags;
+
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ int64_t size = 1;
+ bool flag = false;
+ for (unsigned d = 0; d < dim; ++d) {
+ size *= shape[currentDim + d];
+ flag |= scalableFlags[currentDim + d];
+ }
+ newShape.push_back(size);
+ newScalableFlags.push_back(flag);
+ currentDim += dim;
+ }
+
+ return VectorType::get(newShape, type.getElementType(), newScalableFlags);
+}
+
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
/// Vector::TransferReadOp - Reads a vector from the source tensor
/// vector::TransposeOp - Transpose the Source tensor
@@ -1928,23 +1975,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
getUnPackInverseSrcPerm(unpackOp, packMetadata);
- ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
- SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
- mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
- applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
- RankedTensorType stripMineTensorType =
- RankedTensorType::get(stripMineShape, stripMineElemType);
// Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = vector::TransposeOp::create(
rewriter, loc, readResult, lastDimToInsertPosPerm);
// Collapse the vector to the size required by result.
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMineTensorType, packMetadata.reassociations);
- mlir::VectorType vecCollapsedType =
- VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ VectorType collapsedVecType = getCollapsedVecType(
+ transposeOp.getType(),
+ getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ rewriter.getContext(), packMetadata.reassociations)));
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
- rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+ rewriter, loc, collapsedVecType, transposeOp->getResult(0));
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
More information about the Mlir-commits
mailing list