[Mlir-commits] [mlir] adf838d - [mlir][Vectorizer] Added support to Vectorize tensor.unpack (#76087)
llvmlistbot at llvm.org
Tue Feb 20 14:10:18 PST 2024
Author: Balaji V. Iyer
Date: 2024-02-20T16:10:14-06:00
New Revision: adf838daee63b3245c8822957988da5367e1572c
URL: https://github.com/llvm/llvm-project/commit/adf838daee63b3245c8822957988da5367e1572c
DIFF: https://github.com/llvm/llvm-project/commit/adf838daee63b3245c8822957988da5367e1572c.diff
LOG: [mlir][Vectorizer] Added support to Vectorize tensor.unpack (#76087)
Added support to vectorize tensor.unpack. The unpack op is lowered to a
`vector.transfer_read`, a `vector.transpose`, a `vector.shape_cast` and a
`vector.transfer_write`.
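For example, a minimal sketch of the lowering (modeled on the
`test_vectorize_unpack_no_masks` test added below; the padding value, index
operands and in_bounds attributes are elided, and the SSA names are
illustrative):

    %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16]
        into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>

is vectorized with vector_sizes [256, 128] into roughly:

    %read = vector.transfer_read %source[...], %pad
        : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
    %tr = vector.transpose %read, [0, 2, 1, 3]
        : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
    %cast = vector.shape_cast %tr
        : vector<8x32x8x16xf32> to vector<256x128xf32>
    %write = vector.transfer_write %cast, %dest[...]
        : vector<256x128xf32>, tensor<256x128xf32>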
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index fe9b16cb44b3da..d09c9e36f6ff88 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,13 +32,11 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);
-/// Given a tensor::PackOp, compute the permutation vector to shuffle the
-/// packed shape into the shape before any outer or inner permutations have
-/// been applied.
-/// i.e. for a pack from an ABCD layout to an ABCDba:
-/// The packed shape would be ABCDba.
-/// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
+
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
+ PackingMetadata &metadata);
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4ef8859fd5c430..299965bcfc3ab3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
+ target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 01b393644679c5..a17bc8e4cd318f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> packedToStripMinedShapePerm =
- tensor::getPackInverseDestPermutation(packOp);
+ tensor::getPackInverseDestPerm(packOp);
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..ac043e87223dfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
- return applyPermutation(destShape,
- tensor::getPackInverseDestPermutation(packOp));
+ return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
}
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
// Create TransposeOp.
auto destPermutation =
- invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+ invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);
@@ -1559,6 +1558,112 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}
+/// Vectorize a `tensor::UnPackOp` to these 4 ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor
+/// vector::TransposeOp - Transposes the read vector
+/// vector::ShapeCastOp - Reshapes the data based on the target
+/// vector::TransferWriteOp - Writes the result vector back to the
+/// destination tensor
+static LogicalResult
+vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+ ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+
+ SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
+ inputVectorSizes.end());
+ ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+ ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+
+ // readMaskShape is the shape of the vector used for the masked read. It
+ // is computed as follows: let the vector sizes (VS) have rank 'N', the
+ // sourceShape (SS) have rank 'M' with M >= N, and the innerTiles (IT)
+ // have size M-N.
+ // Thus:
+ // - initially: readMaskShape = inputVectorSizes
+ // - Divide the readMaskShape entries pointed to by innerDimPos by the
+ // corresponding innerTiles value.
+ // - if outer_dims_perm is present: apply that permutation to readMaskShape.
+ // - Append the remaining dimensions from SS.
+ // E.g. let unpackTensorType.getShape() = <8x8x32x16>,
+ // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512,
+ // 128] and outer_dims_perm = [1, 0]; then the read shape is:
+ // ReadMaskShape (initial): [512, 128]
+ // After the inner-dim adjustment: [512/32, 128/16]
+ // = [16, 8]
+ // After applying outer_dims_perm: [8, 16]
+ // After appending the rest of the sourceShape: [8, 16, 32, 16]
+
+ for (auto [index, size] : enumerate(innerTiles)) {
+ readMaskShape[innerDimPos[index]] =
+ llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+ }
+ if (!outerDimsPerm.empty()) {
+ applyPermutationToVector(readMaskShape, outerDimsPerm);
+ }
+ readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+ sourceShape.end());
+
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+ Location loc = unpackOp->getLoc();
+
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
+
+ // Read result, mask if necessary. If transferReadOp shape is not equal
+ // to shape of source, then a mask is necessary.
+ Value readResult = createReadOrMaskedRead(
+ rewriter, loc, unpackOp.getSource(),
+ ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
+
+ PackingMetadata packMetadata;
+ SmallVector<int64_t> lastDimToInsertPosPerm =
+ tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
+ ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::get(stripMineShape, stripMineElemType);
+ // Transpose the appropriate rows to match output.
+ vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+ loc, readResult, lastDimToInsertPosPerm);
+
+ // Collapse the vector to the size required by result.
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ mlir::VectorType vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, vecCollapsedType, transposeOp->getResult(0));
+
+ // writeMaskShape has to match the shape_cast result shape for dynamic
+ // sizes; otherwise the verifier complains that the mask size is invalid.
+ SmallVector<int64_t> writeMaskShape(
+ unpackOp.getDestType().hasStaticShape()
+ ? inputVectorSizes
+ : shapeCastOp.getResultVectorType().getShape());
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
+ reifiedRetShapes[0], writeMaskShape);
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1655,6 +1760,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+/// Checks that the inner tiles are static/constant and that the input vector
+/// sizes, if provided, are valid for the unpack result shape.
+static LogicalResult
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+
+ if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+ return !getConstantIntValue(res).has_value();
+ })) {
+ LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+ return failure();
+ }
+ llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+ if (!inputVectorSizes.empty() &&
+ failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+ return failure();
+
+ return success();
+}
+
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1703,9 +1827,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
}
if (isElementwise(linalgOp))
return success();
- // TODO: isaConvolutionOpInterface that can also infer from generic features.
- // But we will still need stride/dilation attributes that will be annoying to
- // reverse-engineer...
+
+ // TODO: isaConvolutionOpInterface that can also infer from generic
+ // features. But we will still need stride/dilation attributes that will be
+ // annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -1810,6 +1935,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
})
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -1829,11 +1957,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
}
/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the input
-/// vector sizes must be greater than or equal to their counterpart iteration
-/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
-/// operations with dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
@@ -1867,8 +1995,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
- // TODO: isaConvolutionOpInterface that can also infer from generic
- // features. Will require stride/dilation attributes inference.
+ // TODO: isaConvolutionOpInterface that can also infer from
+ // generic features. Will require stride/dilation attributes
+ // inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, flatten1DDepthwiseConv);
@@ -1902,6 +2031,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
+ inputVectorSizes, results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
@@ -1919,7 +2052,6 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {
-
auto srcType = cast<MemRefType>(copyOp.getSource().getType());
auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
@@ -2833,8 +2965,8 @@ struct Conv1DGenerator
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
resPadding);
- // The base vectorization case for channeled convolution is input: {n,w,c},
- // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+ // The base vectorization case for channeled convolution is input:
+ // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
// vectorization case, we do pre transpose on input, weight, and output.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
@@ -2877,9 +3009,9 @@ struct Conv1DGenerator
return kw * (wSize / wSizeStep) + w;
};
- // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
- // perform outerproduct for non-channeled convolution or
- // perform simple arith operation for pooling
+ // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+ // or perform outerproduct for non-channeled convolution or perform simple
+ // arith operation for pooling
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
switch (oper) {
@@ -2908,9 +3040,9 @@ struct Conv1DGenerator
// End vector-only rewrite part
//===------------------------------------------------------------------===//
- // The base vectorization case for channeled convolution is output: {n,w,f}
- // To reuse the result from base pattern vectorization case, we post
- // transpose the base case result.
+ // The base vectorization case for channeled convolution is output:
+ // {n,w,f} To reuse the result from base pattern vectorization case, we
+ // post transpose the base case result.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
case Conv1DOpOrder::Nwc:
@@ -3348,9 +3480,9 @@ static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
- // strides/dilations. However, we do not need to rely on those, we can simply
- // use them if present, otherwise use the default and let the generic conv.
- // matcher in the ConvGenerator succeed or fail.
+ // strides/dilations. However, we do not need to rely on those, we can
+ // simply use them if present, otherwise use the default and let the generic
+ // conv. matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index f20008a1ed2b2f..186f85d2ce20a6 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -72,36 +72,73 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
RTTBuilder(rankedTensorType).setShape(transposedShape);
return transposedTensorType;
}
-
-SmallVector<int64_t>
-mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
- // The permutation can be obtained from two permutations:
- // a) Compute the permutation vector to move the last `numPackedDims` into
- // the `innerPosDims` of a shape of rank `packedRank`.
- // b) Compute the permutation vector to move outer dims if the pack op
- // has outer_dims_perm.
- // Apply (b) permutation on (a) permutation to get the final permutation.
- int64_t numPackedDims = packOp.getInnerDimsPos().size();
- int64_t packedRank = packOp.getDestType().getRank();
- auto lastDims = llvm::to_vector(
- llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
- PackingMetadata packingMetadata = computePackingMetadata(
- packOp.getDestType().getRank(), packOp.getInnerDimsPos());
- SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
- packedRank, lastDims, packingMetadata.insertPositions);
+/// The permutation can be obtained from two permutations:
+/// a) Compute the permutation vector to move the last `numPackedDims` into
+/// the `innerPosDims` of a shape of rank `rank`.
+/// b) Compute the permutation vector to move outer dims if the
+/// `outerPerm` parameter is not empty.
+/// Apply (b) permutation on (a) permutation to get the final permutation.
+static SmallVector<int64_t>
+computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
+ ArrayRef<int64_t> &outerPerm,
+ PackingMetadata &packingMetadata) {
+ int64_t numPackedDims = innerDimsPos.size();
+ auto lastDims =
+ llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
+ packingMetadata = computePackingMetadata(rank, innerDimsPos);
+ SmallVector<int64_t> innerPositionsPerm =
+ computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
- ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
if (!outerPerm.empty())
applyPermutationToVector(outerPos, outerPerm);
- SmallVector<int64_t> outerPositionPerm = computePermutationVector(
- packedRank, packingMetadata.outerPositions, outerPos);
+ SmallVector<int64_t> outerPositionPerm =
+ computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
return packInverseDestPermutation;
}
+/// Shell function to compute the destination permutation of a PackOp.
+/// This function uses the helper function `computePackUnPackPerm` to get
+/// the permutation vector. The only major difference between UnPack and Pack
+/// is that PackOp uses the destination rank whereas UnPackOp uses the source
+/// rank.
+SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
+
+ PackingMetadata pMetadata;
+ int64_t packedRank = packOp.getDestType().getRank();
+ ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
+ ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+ SmallVector<int64_t> packInvDestPerm =
+ computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
+ return packInvDestPerm;
+}
+
+/// Shell function to compute the source permutation of an UnPackOp.
+/// This function, like getPackInverseDestPerm, uses the helper function
+/// `computePackUnPackPerm` to get the permutation vector.
+/// The only major difference between UnPack and Pack is that PackOp uses the
+/// destination rank whereas UnPackOp uses the source rank.
+SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
+ PackingMetadata metadata;
+ return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
+}
+
+/// Shell function to compute the source-rank permutation for an UnPackOp.
+/// Unpack requires the packing metadata as well, so this overload also
+/// returns that metadata by reference.
+SmallVector<int64_t>
+mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
+ PackingMetadata &metadata) {
+ int64_t unpackRank = unpackOp.getSourceType().getRank();
+ ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
+ SmallVector<int64_t> unpackInvSrcPerm =
+ computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+ return unpackInvSrcPerm;
+}
+
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0272ac599aa3db..2d01d57304013c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -697,3 +697,118 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[C01:.*]] = arith.constant 0
+// CHECK: %[[C02:.*]] = arith.constant 0
+// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST14:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
+// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
+// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
+// CHECK: %[[empt0:.*]] = tensor.empty
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+ // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
+ // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack_no_masks
+func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C00:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+ transform.yield
+ }
+ }
+
+ // -----
+
+ // CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+ func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C00:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+ transform.yield
+ }
+}