[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)
Balaji V. Iyer.
llvmlistbot at llvm.org
Tue Feb 6 15:47:04 PST 2024
https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/76087
>From d97a5729f496fb603f4fb9cf2977b015d8e37ed6 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 30 Nov 2023 20:39:55 +0000
Subject: [PATCH 1/4] [mlir][Vectorizer] Vectorize `tensor.unpack`
This patch allows vectorization of a `tensor.unpack` operation.
---
.../Linalg/Transforms/Vectorization.cpp | 96 +++++++++++++++++++
1 file changed, 96 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f9a53a8451a60..7a9846154bf34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
@@ -1385,6 +1386,88 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+// Vectorize a `tensor::UnPackOp` without OuterDimsPerms into these 4 ops:
+//   vector::TransferReadOp - reads a vector from the source tensor
+//   vector::TransposeOp - transposes the read vector
+//   vector::ShapeCastOp - reshapes the data based on the target
+//   vector::TransferWriteOp - writes the result vector back
+
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ LDBG("outer dimensions perms NYI for: " << unpackOp);
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType packTensorType = unpackOp.getSourceType();
+ auto maskType =
+ VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
+ auto vectorType = VectorType::get(packTensorType.getShape(),
+ packTensorType.getElementType());
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+
+ arith::ConstantIndexOp zeroOp =
+ rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), maskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+
+ vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+ unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+ SmallVector<Value>(packTensorType.getRank(), zeroOp),
+ rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+
+ vector::MaskOp maskedOp =
+ cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+ int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+ int64_t packRank = packTensorType.getRank();
+ auto lastDims =
+ llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+ PackingMetadata packMetadata =
+ computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+ packRank, lastDims, packMetadata.insertPositions);
+ SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ auto vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+ vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+ unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+ vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+ tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+ unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+
+ vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+ unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+ SmallVector<Value>(lastDims.size(), zeroOp),
+ SmallVector<bool>(lastDims.size(), true));
+
+ newResults.push_back(writeOp->getResult(0));
+ return success();
+}
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
@@ -1578,6 +1661,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
+static LogicalResult
+vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ return success();
+}
+
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1637,6 +1726,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -1724,6 +1816,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+ results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
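To illustrate the decomposition described in the comment above, this is roughly the IR the rewrite is intended to produce for a static case, e.g. unpacking tensor<7x1136x16x16xf32> with inner_tiles = [16, 16] into tensor<100x18176xf32>. This is an illustrative sketch only: the value names (%src, %init, %pad) are invented, the definitions of the size constants are elided, and the masking is refined in the later revisions of this PR; the exact output is what the test added in the next patch checks.

  %c0   = arith.constant 0 : index
  %mask = vector.create_mask %c7, %c1136, %c16, %c16 : vector<7x1136x16x16xi1>
  %read = vector.mask %mask {
            vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
              : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32>
          } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
  %tr   = vector.transpose %read, [0, 2, 1, 3]
            : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
  %sc   = vector.shape_cast %tr : vector<7x16x1136x16xf32> to vector<112x18176xf32>
  %init = tensor.empty() : tensor<100x18176xf32>
  %res  = vector.transfer_write %sc, %init[%c0, %c0]
            : vector<112x18176xf32>, tensor<100x18176xf32>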
>From a52d650fd8ff0f761fd2aa77671ffa858f3e9237 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 20 Dec 2023 11:45:20 -0600
Subject: [PATCH 2/4] Enabled tensor.unpack vectorization and added test case.
---
.../TransformOps/LinalgTransformOps.cpp | 2 +-
mlir/test/Dialect/Linalg/vectorization.mlir | 25 +++++++++++++++++++
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 14404d837ff74..4c456a5e671f5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3078,7 +3078,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::UnPackOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 610339405d1c2..acf7276626a4c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,6 +419,31 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32> {
+ // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+ // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
+ // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: return %[[tw0]]
+ %8 = tensor.empty() : tensor<100x18176xf32>
+ %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
+ return %unpack : tensor<100x18176xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
%0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
>From a77ac60ddd46e25a9183d83be302445552766e3f Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 19 Jan 2024 19:11:04 -0600
Subject: [PATCH 3/4] Added some of the changes requested by Diego and HanHan
---
.../Linalg/Transforms/Vectorization.cpp | 20 ++++++-------------
1 file changed, 6 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7a9846154bf34..611dec4af0f93 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1386,11 +1386,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
-// Vectorize a `tensor::UnPackOp` without OuterDimsPerms into these 4 ops:
-//   vector::TransferReadOp - reads a vector from the source tensor
-//   vector::TransposeOp - transposes the read vector
-//   vector::ShapeCastOp - reshapes the data based on the target
-//   vector::TransferWriteOp - writes the result vector back
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms into these 4 ops:
+///   vector::TransferReadOp - reads a vector from the source tensor
+///   vector::TransposeOp - transposes the read vector
+///   vector::ShapeCastOp - reshapes the data based on the target
+///   vector::TransferWriteOp - writes the result vector back
static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
@@ -1661,12 +1661,6 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
-static LogicalResult
-vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
- ArrayRef<int64_t> inputVectorSizes) {
- return success();
-}
-
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1726,9 +1720,7 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
- .Case<tensor::UnPackOp>([&](auto unpackOp) {
- return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
- })
+ .Case<tensor::UnPackOp>([&](auto unpackOp) { return success(); })
.Default([](auto) { return failure(); });
}
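As a worked example of how the transpose permutation is obtained (not part of the patch, and assuming the usual behaviour of computePackingMetadata and computePermutationVector): for the static test, inner_dims_pos = [0, 1] and the source rank is 4, so

  numPackedDim = 2
  lastDims     = [2, 3]   // the trailing tile dimensions

Each tile dimension is moved to sit directly after the outer dimension it was split from, which for inner_dims_pos = [0, 1] yields lastDimToInsertPosPerm = [0, 2, 1, 3]. Applying it to the read vector 7x1136x16x16 gives 7x16x1136x16, and the reassociation groups the adjacent (outer, tile) pairs, so the shape_cast collapses it to (7*16) x (1136*16) = 112x18176, matching the CHECK lines in the test.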
>From 432ce04dfe64e9c24f27061de811b4e34e113776 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Tue, 6 Feb 2024 23:45:59 +0000
Subject: [PATCH 4/4] Used vectorSizes for masks and added a dynamic shapes
test case.
---
.../Linalg/Transforms/Vectorization.cpp | 113 +++++++++++++-----
mlir/test/Dialect/Linalg/vectorization.mlir | 33 ++++-
2 files changed, 114 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4aa604610e217..c450b0635edfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1393,17 +1393,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+
/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms into these 4 ops:
///   vector::TransferReadOp - reads a vector from the source tensor
///   vector::TransposeOp - transposes the read vector
///   vector::ShapeCastOp - reshapes the data based on the target
///   vector::TransferWriteOp - writes the result vector back
-
static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
-
+ // Handling the outer_dims_perm attribute requires more work; for now, only
+ // unpack ops without it are vectorized.
if (!unpackOp.getOuterDimsPerm().empty()) {
LDBG("outer dimensions perms NYI for: " << unpackOp);
return failure();
@@ -1412,11 +1413,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType packTensorType = unpackOp.getSourceType();
- auto maskType =
- VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
- auto vectorType = VectorType::get(packTensorType.getShape(),
- packTensorType.getElementType());
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+ for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+ readMaskShape[ii] = inputVectorSizes[ii];
+ }
+
+ // readMaskShape is the shape of the vector used for the masked read. If the
+ // vectorSizes (VS) array has N entries and the source shape (SS) has M
+ // entries, with M >= N, then:
+ //   readMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+ auto vectorType =
+ VectorType::get(readMaskShape, unpackTensorType.getElementType());
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
@@ -1425,54 +1434,87 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
LDBG("Unable to reify result shapes of " << unpackOp);
return failure();
}
-
+ int64_t unpackRank = unpackTensorType.getRank();
arith::ConstantIndexOp zeroOp =
rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
- Value mask = rewriter.create<vector::CreateMaskOp>(
- unpackOp.getLoc(), maskType,
- tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
unpackOp.getLoc(), vectorType, unpackOp.getSource(),
- SmallVector<Value>(packTensorType.getRank(), zeroOp),
- rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+ SmallVector<Value>(unpackRank, zeroOp),
+ rewriter.getMultiDimIdentityMap(unpackRank));
+ auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), readMaskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
vector::MaskOp maskedOp =
cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
- int64_t packRank = packTensorType.getRank();
- auto lastDims =
- llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+ llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
PackingMetadata packMetadata =
- computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+ computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
- packRank, lastDims, packMetadata.insertPositions);
- SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+ unpackRank, lastDims, packMetadata.insertPositions);
+ ShapedType maskedOpShapedType =
+ cast<ShapedType>(maskedOp.getResult(0).getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
RankedTensorType stripMineTensorType =
- RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+ RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+ .setShape(stripMineShape);
+ // Collapse the tensor to the shape required by the result.
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
auto vecCollapsedType =
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ // Transpose the appropriate dimensions to match the output.
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
- tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
- unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+ tensor::EmptyOp emptyOp =
+ rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
+ unpackTensorType.getElementType());
- vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+ int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
+ Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
- SmallVector<Value>(lastDims.size(), zeroOp),
- SmallVector<bool>(lastDims.size(), true));
-
- newResults.push_back(writeOp->getResult(0));
+ SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
+ auto resultShape = unpackOp.getResult().getType().getShape();
+
+ // If the shape of the result doesn't match the inputVectorSizes, a mask
+ // is necessary.
+ bool needMaskForWrite =
+ llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
+ [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ mlir::OpResult result = writeOp->getResult(0);
+ if (needMaskForWrite) {
+ SmallVector<int64_t> writeMaskShape(inputVectorSizes);
+ llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+ for (auto [index, size] : enumerate(innerTiles)) {
+ writeMaskShape[innerDimPos[index]] *= size;
+ }
+ // writeMaskShape is computed from the vector sizes, the inner dim positions
+ // and the inner tiles:
+ //   writeMaskShape (WMS) is initialized to inputVectorSizes; then,
+ //   for each (index, tile) in the inner tiles:
+ //     WMS[innerDimPos[index]] *= tile
+ auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
+ Value writeMask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+ Operation *writeOpWithMask =
+ mlir::vector::maskOperation(rewriter, writeOp, writeMask);
+ result = writeOpWithMask->getResult(0);
+ }
+ newResults.push_back(result);
return success();
}
@@ -1601,6 +1643,19 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+/// The inner tiles must be static/constant for vectorization.
+static LogicalResult
+vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+ return !getConstantIntValue(res).has_value();
+ })) {
+ LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+ return failure();
+ }
+ return success();
+}
+
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1727,7 +1782,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
- .Case<tensor::UnPackOp>([&](auto unpackOp) { return success(); })
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeUnPackOpPreCondition(unpackOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
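To make the two mask-shape computations above concrete (a worked example, not part of the patch): in the updated static test below, inputVectorSizes = [2, 4] and the source is tensor<7x1136x16x16xf32>, so

  readMaskShape  = [VS[0], VS[1], SS[2], SS[3]] = [2, 4, 16, 16]

and, with inner_dims_pos = [0, 1] and inner_tiles = [16, 16],

  writeMaskShape = [2 * 16, 4 * 16] = [32, 64]

which is exactly why the checks now expect vector<2x4x16x16xi1> for the read mask and vector<32x64xi1> for the write mask.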
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a4bcfec33d11b..91ada32e7b68f 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -422,11 +422,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_vectorize_unpack
func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32> {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
- // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
- // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
- // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+ // CHECK: %[[m0:.*]] = vector.create_mask %c7, %c1136, %c16, %c16_0 : vector<2x4x16x16xi1>
+ // CHECK: %[[tr0:.*]] = vector.mask %[[m0]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<2x4x16x16xf32> } : vector<2x4x16x16xi1> -> vector<2x4x16x16xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<2x4x16x16xf32> to vector<2x16x4x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x4x16xf32> to vector<32x64xf32>
// CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
- // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: %[[mask0:.*]] = vector.create_mask %c100, %c18176 : vector<32x64xi1>
+ // CHECK: %[[tw0:.*]] = vector.mask %[[mask0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
// CHECK: return %[[tw0]]
%8 = tensor.empty() : tensor<100x18176xf32>
%unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
@@ -444,6 +446,29 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[readMsk0:.*]] = vector.create_mask %dim_3, %dim_5, %c16, %c2 : vector<4x1x16x2xi1>
+ // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<4x1x16x2xf32> } : vector<4x1x16x2xi1> -> vector<4x1x16x2xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<4x1x16x2xf32> to vector<4x2x1x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<8x16xf32>
+ // CHECK: %[[empt0:.*]] = tensor.empty
+ // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<8x16xi1>
+ // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
%0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
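As a cross-check of the dynamic-shape test above (a worked example under the same assumptions, not part of the patch): the source is tensor<?x?x16x2xf32> with inner_dims_pos = [1, 0], inner_tiles = [16, 2] and vector_sizes [4, 1], so

  readMaskShape = [4, 1, 16, 2]

The tile of size 16 belongs to outer dim 1 and the tile of size 2 to outer dim 0, so moving each tile dimension next to its outer dimension gives the permutation [0, 3, 1, 2]:

  vector<4x1x16x2xf32> -> transpose -> vector<4x2x1x16xf32> -> shape_cast -> vector<8x16xf32>

and the write mask is

  writeMaskShape = [4 * 2, 1 * 16] = [8, 16]

matching the vector<4x1x16x2xi1>, vector<8x16xf32> and vector<8x16xi1> types in the CHECK lines.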