[Mlir-commits] [mlir] [mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP) (PR #149293)
llvmlistbot at llvm.org
Thu Jul 24 02:13:36 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.
Conceptually, `linalg.unpack` consists of the following high-level steps (sketched in vector form below):
1. _Read_ from the source tensor.
2. Transpose the value read in step (1).
3. _Write_ the value from step (2) into the destination tensor.
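In vector form, these three steps map onto the ops below. This is a hand-written sketch (not vectorizer output), modelled on the `inner_dims_pos = [1, 0]` / `inner_tiles = [16, 2]` test in `linalg-ops.mlir`, with static shapes used here for readability:
```mlir
func.func @unpack_as_vector_ops(%src: tensor<2x1x16x2xf32>,
                                %dest: tensor<4x16xf32>) -> tensor<4x16xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // 1. Read from the source tensor.
  %read = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true, true, true]}
    : tensor<2x1x16x2xf32>, vector<2x1x16x2xf32>
  // 2. Transpose so that each inner tile lands next to its outer dimension...
  %tr = vector.transpose %read, [0, 3, 1, 2]
    : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
  //    ...and collapse to the destination rank.
  %sc = vector.shape_cast %tr : vector<2x2x1x16xf32> to vector<4x16xf32>
  // 3. Write the result into the destination tensor.
  %res = vector.transfer_write %sc, %dest[%c0, %c0]
    {in_bounds = [true, true]}
    : vector<4x16xf32>, tensor<4x16xf32>
  return %res : tensor<4x16xf32>
}
```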
Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes.
This logic breaks when the input shapes or tile sizes are dynamic; indeed, `vectorizeUnPackOpPrecondition` currently rejects such cases and vectorization fails. This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases.
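To make the current static inference concrete, here is how the read sizes fall out of user-provided write sizes. A sketch for the `inner_dims_pos = [1, 0]`, `inner_tiles = [16, 2]` case from the existing tests (static shapes assumed for illustration):
```mlir
func.func @unpack_static_inference(%src: tensor<2x1x16x2xf32>,
                                   %dest: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // With user-provided (write) vector sizes [4, 16], the read sizes are
  // currently inferred as follows:
  //   * start from the write sizes:                        [4, 16]
  //   * ceil-divide each inner_dims_pos entry by its inner
  //     tile: pos 0 by 2, pos 1 by 16                   -> [2, 1]
  //   * append the trailing (tile) dims of the source   -> [2, 1, 16, 2]
  // With dynamic shapes or dynamic tiles this inference is not possible,
  // hence the explicit read sizes introduced by this patch.
  %0 = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2]
    into %dest : tensor<2x1x16x2xf32> -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
```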
Example:
```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
%vs = vector.vscale
%c8 = arith.constant 8 : index
%tile_size = arith.muli %vs, %c8 : index
%unpack = linalg.unpack %in
inner_dims_pos = [0, 1]
inner_tiles = [8, %tile_size]
into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
return %unpack : tensor<8x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op
// read-sizes:  [1, 1, 8, [8]]
// write-sizes: [8, [8]]
transform.yield
}
}
```
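For this example, the first four entries (`[1, 1, 8, [8]]`) drive the read and the last two (`[8, [8]]`) the write. The resulting IR would look roughly as follows; this is a hand-written sketch (the exact permutation and the mask creation are determined by the packing metadata and are simplified here):
```mlir
// Sketch only: in practice the read and write are masked along the
// scalable/dynamic dimension; mask creation is omitted for brevity.
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
%read = vector.transfer_read %in[%c0, %c0, %c0, %c0], %pad
  {in_bounds = [true, true, true, true]}
  : tensor<1x1x8x?xf32>, vector<1x1x8x[8]xf32>
%tr = vector.transpose %read, [0, 2, 1, 3]
  : vector<1x1x8x[8]xf32> to vector<1x8x1x[8]xf32>
%sc = vector.shape_cast %tr : vector<1x8x1x[8]xf32> to vector<8x[8]xf32>
%res = vector.transfer_write %sc, %out[%c0, %c0]
  {in_bounds = [true, true]}
  : vector<8x[8]xf32>, tensor<8x?xf32>
```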
Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
---
Patch is 27.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149293.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+89-41)
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+12-10)
- (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+107-28)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 7cd70e42d363c..8bd54cf31b893 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
bool useInBoundsInsteadOfMasking = false,
- ArrayRef<bool> scalableDims = {});
+ ArrayRef<bool> inputScalableVecDims = {});
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 77c85abab9aa0..78f1c524b69fd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1806,7 +1806,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
- useInBoundsInsteadOfMasking);
+ useInBoundsInsteadOfMasking,
+ /*inputScalableVecSizes=*/{});
// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1832,18 +1833,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
return success();
}
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-/// Vector::TransferReadOp - Reads a vector from the source tensor
-/// vector::TransposeOp - Transpose the Source tensor
-/// ShapeCastOp - Reshape the data based on the target.
-/// vector::TransferWriteOp. - Write the result vector back to the destination
-/// tensor.
-/// If the vector sizes are not provided:
+/// Vectorize `linalg.unpack %src into %dest` as:
+/// // Reads a vector from the source tensor
+/// %read = vector.transfer_read %src
+/// // Transpose %read as specified in `outer_dims_perm` attribute
+/// %tr = vector.transpose %read
+/// // Reshape the data based on the target
+/// %sc = vector.shape_cast %tr
+/// // Write the result vector to the destination tensor.
+/// vector.transfer_write %sc into %dest
+///
+/// If the vector sizes are not provided:
/// * the vector sizes are determined by the input operand and attributes,
/// * update the inBounds attribute instead of masking.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
+ ArrayRef<bool> inputScalableVecDims,
SmallVectorImpl<Value> &newResults) {
// TODO: Introduce a parent class that will handle the insertion point update.
@@ -1860,25 +1866,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
auto destSize = unpackOp.getDestRank();
- if (!inputVectorSizes.empty())
- assert(inputVectorSizes.size() == destSize &&
+ if (!inputVectorSizes.empty()) {
+ assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
"Incorrect number of input vector sizes");
+ }
+
+ SmallVector<bool> readScalableVectorFlags;
+ SmallVector<bool> writeScalableVectorFlags;
+ SmallVector<int64_t> readVectorSizes;
+ SmallVector<int64_t> writeVectorSizes;
- // vectorSizes is the shape of the vector that will be used to do final
+ // Split input-vector-sizes into vector sizes for the read and write
+ // operations.
+ if (!inputVectorSizes.empty()) {
+ readVectorSizes.append(inputVectorSizes.begin(),
+ inputVectorSizes.begin() + sourceShape.size());
+ writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+ inputVectorSizes.end());
+ }
+ if (!inputScalableVecDims.empty()) {
+ readScalableVectorFlags.append(inputScalableVecDims.begin(),
+ inputScalableVecDims.begin() +
+ sourceShape.size());
+ writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+ sourceShape.size(),
+ inputScalableVecDims.end());
+ } else {
+ readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+ writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+ }
+
+ // writeVectorSizes is the shape of the vector that will be used to do final
// write on the destination tensor. It is set like this: Let's say the
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
// Thus:
- // 1. vectorSizes = sourceShape.take_front(N)
- // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+ // 1. writeVectorSizes = sourceShape.take_front(N)
+ // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
// innerTiles attribute value.
- SmallVector<int64_t> vectorSizes(inputVectorSizes);
- if (vectorSizes.empty()) {
- llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+ // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+ if (writeVectorSizes.empty()) {
+ if (ShapedType::isDynamicShape(sourceShape))
+ return failure();
+
+ llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
if (!outerDimsPerm.empty())
- applyPermutationToVector(vectorSizes, outerDimsPerm);
+ applyPermutationToVector(writeVectorSizes, outerDimsPerm);
for (auto [i, pos] : llvm::enumerate(innerDimPos))
- vectorSizes[pos] *= innerTiles[i];
+ writeVectorSizes[pos] *= innerTiles[i];
useInBoundsInsteadOfMasking = true;
}
@@ -1902,17 +1937,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// After applying outer_dims_perm: [8, 16]
// After appending the rest of the sourceShape: [8, 16, 32, 16]
- SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
- for (auto [index, size] : enumerate(innerTiles)) {
- readVectorSizes[innerDimPos[index]] =
- llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
- }
- if (!outerDimsPerm.empty()) {
- applyPermutationToVector(readVectorSizes, outerDimsPerm);
+ if (readVectorSizes.empty()) {
+ // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+ // sizes. Note, this will only work when all sizes are static.
+ readVectorSizes = writeVectorSizes;
+ for (auto [index, size] : enumerate(innerTiles)) {
+ readVectorSizes[innerDimPos[index]] =
+ llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+ }
+ if (!outerDimsPerm.empty()) {
+ applyPermutationToVector(readVectorSizes, outerDimsPerm);
+ }
+ readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+ sourceShape.end());
}
- readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
- sourceShape.end());
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
@@ -1931,7 +1969,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// to shape of source, then a mask is necessary.
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false);
+ /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1950,15 +1988,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
mlir::VectorType vecCollapsedType =
- VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+ writeScalableVectorFlags);
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));
- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+ // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
// otherwise the validator complains that the mask size is invalid.
- SmallVector<int64_t> writeVectorSizes(
+ // FIXME: We should not override write-vector-sizes like this.
+ SmallVector<int64_t> writeVectorSizesFinal(
unpackOp.getDestType().hasStaticShape()
- ? vectorSizes
+ ? writeVectorSizes
: shapeCastOp.getResultVectorType().getShape());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1989,7 +2029,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
assert(succeeded(status) && "failed to reify result shapes");
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
- /*useInBoundsInsteadOfMasking=*/false);
+ /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
// Create Xfer write Op
Value dest = rewriter.create<tensor::EmptyOp>(
@@ -2073,6 +2113,9 @@ static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
+ // FIXME!!!
+ return success();
+
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
@@ -2409,6 +2452,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}
+
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
bool satisfyEmptyCond = true;
if (inputVectorSizes.empty()) {
@@ -2487,12 +2531,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
if (numOfScalableDims == 0)
return success();
+ // TODO: Check the following!
auto linalgOp = dyn_cast<LinalgOp>(op);
- // Cond 1: There's been no need for scalable vectorisation of
- // non-linalg Ops so far
- if (!linalgOp)
- return failure();
+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+ // exception of UnpackOp for which there is a dedicated hook.
+ if (!linalgOp) {
+ return isa<linalg::UnPackOp>(op) ? success() : failure();
+ }
// Cond 2: There's been no need for more than 2 scalable dims so far
if (numOfScalableDims > 2)
@@ -2588,7 +2634,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
- hasReductionIterator(linalgOp));
+ isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2723,7 +2769,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
})
.Case<linalg::UnPackOp>([&](auto unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
- inputVectorSizes, results);
+ inputVectorSizes,
+ inputScalableVecDims, results);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3114,7 +3161,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
rewriter, loc, source, vecType.getShape(), padValue,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+ /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+ /*inputScalableVecSizes=*/{});
// Create write
auto writeIndices =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c045063e8194f..a379229d3d5a5 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -281,14 +281,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
// Attempt to unroll until targetRank or the first scalable dimension (which
// cannot be unrolled).
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
- auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
- auto it = llvm::find(scalableDimsToUnroll, true);
- auto firstScalableDim = it - scalableDimsToUnroll.begin();
+ auto inputScalableVecDimsToUnroll =
+ vType.getScalableDims().drop_back(targetRank);
+ auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+ auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
if (firstScalableDim == 0)
return {};
// All scalable dimensions should be removed now.
- scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
- assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+ inputScalableVecDimsToUnroll =
+ inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+ assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
"unexpected leading scalable dimension");
// Create an unroll iterator for leading dimensions.
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -321,15 +323,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
bool useInBoundsInsteadOfMasking,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> inputScalableVecDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
- auto vectorType =
- VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+ inputScalableVecDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
@@ -358,8 +360,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);
- auto maskType =
- VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+ auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+ inputScalableVecDims);
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 98e8f5079176c..b38d3bdedd52a 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -940,34 +940,113 @@ module attributes {transform.with_named_sequence} {
///----------------------------------------------------------------------------------------
// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
-// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
-func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[C0:.*]] = arith.constant 0
-// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
-// CHECK: %[[C01:.*]] = arith.constant 0
-// CHECK: %[[C02:.*]] = arith.constant 0
-// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
-// CHECK: %[[CNST14:.*]] = arith.constant 1
-// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
-// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
-// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
-// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
-// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
-// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
-// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
-// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[ARG_0]]
-// CHECK: return %[[write0]]
- %ret = linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
- return %ret : tensor<?x?xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[SRC:.*]]: tensor<?x?x16x2xf32>
+func.func @test_vectorize_dynamic_shapes_unpack(%dest: tensor<?x?xf32>, %src: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0
+ // CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x?xf32>
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[DIM0:.*]] = tensor.dim %[[DEST]], %[[C1]] : tensor<?x?xf32>
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+ // CHECK: %[[C01:.*]] = arith.constant 0
+ // CHECK: %[[C02:.*]] = arith.constant 0
+ // CHECK: %[[DIM4:.*]] = tensor.dim %[[SRC]], %[[C02]] : tensor<?x?x16x2xf32>
+ // CHECK: %[[CNST14:.*]] = arith.constant 1
+ // CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[CNST14]] : tensor<?x?x16x2xf32>
+ // CHECK: %[[CNST16:.*]] = arith.constant 16 : index
+ // CHECK: %[[CNST2:.*]] = arith.constant 2 : index
+ // CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+ // CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
+ // CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+ // CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x16xf32> to vector<4x16xf32>
+ // CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
+ // CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_WRITE:.*]] {{.*}} vector.transfer_write %[[SC]], %[[DEST]]
+ // CHECK: return %[[WRITE]]
+ %ret = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [2, 1, 16, 2, 4, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectoriz...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/149293