[Mlir-commits] [mlir] [mlir][Vectorizer] Added support to Vectorize tensor.unpack (PR #76087)
Balaji V. Iyer.
llvmlistbot at llvm.org
Tue Feb 20 13:46:48 PST 2024
https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/76087
>From c0c6432dce8f94b2b2f07595de0973dc12f90d45 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 30 Nov 2023 20:39:55 +0000
Subject: [PATCH 01/12] [mlir][Vectorizer] Vectorize `tensor.unpack`
This patch allows vectorization of a `tensor.unpack` operation.
---
.../Linalg/Transforms/Vectorization.cpp | 348 +++++++++++-------
1 file changed, 220 insertions(+), 128 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bd6929fea6142..0760ad114b07b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1400,6 +1400,88 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+// Vector::TransferReadOp - Reads the Vector Array of Source data
+// vector::TransposeOp - Transpose the Source
+// ShapeCastOp - Reshapes the data based on the target.
+// vector::TransferWriteOp. - Write the result vector back.
+
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ LDBG("outer dimensions perms NYI for: " << unpackOp);
+ return failure();
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType packTensorType = unpackOp.getSourceType();
+ auto maskType =
+ VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
+ auto vectorType = VectorType::get(packTensorType.getShape(),
+ packTensorType.getElementType());
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+
+ arith::ConstantIndexOp zeroOp =
+ rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), maskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+
+ vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+ unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+ SmallVector<Value>(packTensorType.getRank(), zeroOp),
+ rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+
+ vector::MaskOp maskedOp =
+ cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+ int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+ int64_t packRank = packTensorType.getRank();
+ auto lastDims =
+ llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+ PackingMetadata packMetadata =
+ computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+ packRank, lastDims, packMetadata.insertPositions);
+ SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ auto vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+ vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+ unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+ vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+ tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+ unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+
+ vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+ unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+ SmallVector<Value>(lastDims.size(), zeroOp),
+ SmallVector<bool>(lastDims.size(), true));
+
+ newResults.push_back(writeOp->getResult(0));
+ return success();
+}
/// Given a tensor::PackOp, return the `dest` shape before any packing
/// permutations.
@@ -1748,6 +1830,12 @@ vectorizePackOpPrecondition(tensor::PackOp packOp,
return success();
}
+static LogicalResult
+vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ return success();
+}
+
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1801,31 +1889,32 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
- return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
- vectorizeNDExtract);
+ return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+ vectorizeNDExtract);
})
.Case<tensor::PadOp>([&](auto padOp) {
- return vectorizePadOpPrecondition(padOp, inputVectorSizes);
+ return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
.Case<tensor::PackOp>([&](auto packOp) {
- return vectorizePackOpPrecondition(packOp, inputVectorSizes);
- })
- .Default([](auto) { return failure(); });
+ return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
+ }).Default([](auto) { return failure(); });
}
/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
- OpBuilder::InsertionGuard g(rewriter);
- auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
-
- for (auto op : make_early_inc_range(toReplace)) {
- rewriter.setInsertionPoint(op);
- auto expanded = affine::expandAffineExpr(
- rewriter, op->getLoc(), op.getAffineMap().getResult(0),
- op.getOperands().take_front(op.getAffineMap().getNumDims()),
- op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
- rewriter.replaceOp(op, expanded);
- }
+ OpBuilder::InsertionGuard g(rewriter);
+ auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
+
+ for (auto op : make_early_inc_range(toReplace)) {
+ rewriter.setInsertionPoint(op);
+ auto expanded = affine::expandAffineExpr(
+ rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+ op.getOperands().take_front(op.getAffineMap().getNumDims()),
+ op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+ rewriter.replaceOp(op, expanded);
+ }
}
/// Emit a suitable vector form for an operation. If provided,
@@ -1839,117 +1928,119 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
- LDBG("Attempting to vectorize:\n" << *op << "\n");
- LDBG("Input vector sizes: ");
- LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- LDBG("Input scalable vector dims: ");
- LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
-
- if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
- vectorizeNDExtract))) {
- LDBG("Vectorization pre-conditions failed\n");
- return failure();
- }
-
- // Initialize vectorization state.
- VectorizationState state(rewriter);
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
- if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
- inputScalableVecDims))) {
- LDBG("Vectorization state couldn't be initialized\n");
+ LDBG("Attempting to vectorize:\n" << *op << "\n");
+ LDBG("Input vector sizes: ");
+ LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG("Input scalable vector dims: ");
+ LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+
+ if (failed(vectorizeOpPrecondition(
+ op, inputVectorSizes, inputScalableVecDims, vectorizeNDExtract))) {
+ LDBG("Vectorization pre-conditions failed\n");
return failure();
}
- }
- SmallVector<Value> results;
- auto vectorizeResult =
- TypeSwitch<Operation *, LogicalResult>(op)
- .Case<linalg::LinalgOp>([&](auto linalgOp) {
- // TODO: isaConvolutionOpInterface that can also infer from generic
- // features. Will require stride/dilation attributes inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- FailureOr<Operation *> convOr = vectorizeConvolution(
- rewriter, linalgOp, flatten1DDepthwiseConv);
- if (succeeded(convOr)) {
- llvm::append_range(results, (*convOr)->getResults());
- return success();
- }
-
- LDBG("Unsupported convolution can't be vectorized.\n");
- return failure();
- }
-
- LDBG("Vectorize generic by broadcasting to the canonical vector "
- "shape\n");
-
- // Pre-process before proceeding.
- convertAffineApply(rewriter, linalgOp);
-
- // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
- // to 'OpBuilder' when it is passed over to some methods like
- // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
- // erase an op within these methods, the actual rewriter won't be
- // notified and we will end up with read-after-free issues!
- return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
- })
- .Case<tensor::PadOp>([&](auto padOp) {
- return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
- results);
- })
- .Case<tensor::PackOp>([&](auto packOp) {
- return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
- results);
- })
- .Default([](auto) { return failure(); });
-
- if (failed(vectorizeResult)) {
- LDBG("Vectorization failed\n");
- return failure();
- }
+ // Initialize vectorization state.
+ VectorizationState state(rewriter);
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
+ inputScalableVecDims))) {
+ LDBG("Vectorization state couldn't be initialized\n");
+ return failure();
+ }
+ }
- if (!results.empty())
- rewriter.replaceOp(op, results);
- else
- rewriter.eraseOp(op);
+ SmallVector<Value> results;
+ auto vectorizeResult =
+ TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<linalg::LinalgOp>([&](auto linalgOp) {
+ // TODO: isaConvolutionOpInterface that can also infer from
+ // generic features. Will require stride/dilation attributes
+ // inference.
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ FailureOr<Operation *> convOr =
+ vectorizeConvolution(rewriter, linalgOp, flatten1DDepthwiseConv);
+ if (succeeded(convOr)) {
+ llvm::append_range(results, (*convOr)->getResults());
+ return success();
+ }
- return success();
+ LDBG("Unsupported convolution can't be vectorized.\n");
+ return failure();
+ }
+
+ LDBG("Vectorize generic by broadcasting to the canonical vector "
+ "shape\n");
+
+ // Pre-process before proceeding.
+ convertAffineApply(rewriter, linalgOp);
+
+ // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+ // to 'OpBuilder' when it is passed over to some methods like
+ // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+ // erase an op within these methods, the actual rewriter won't be
+ // notified and we will end up with read-after-free issues!
+ return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
+ })
+ .Case<tensor::PadOp>([&](auto padOp) {
+ return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, results);
+ })
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+ results);
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+ results);
+ }).Default([](auto) { return failure(); });
+
+ if (failed(vectorizeResult)) {
+ LDBG("Vectorization failed\n");
+ return failure();
+ }
+
+ if (!results.empty())
+ rewriter.replaceOp(op, results);
+ else
+ rewriter.eraseOp(op);
+
+ return success();
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {
+ auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+ auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
+ if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
+ return failure();
- auto srcType = cast<MemRefType>(copyOp.getSource().getType());
- auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
- if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
- return failure();
+ auto srcElementType = getElementTypeOrSelf(srcType);
+ auto dstElementType = getElementTypeOrSelf(dstType);
+ if (!VectorType::isValidElementType(srcElementType) ||
+ !VectorType::isValidElementType(dstElementType))
+ return failure();
- auto srcElementType = getElementTypeOrSelf(srcType);
- auto dstElementType = getElementTypeOrSelf(dstType);
- if (!VectorType::isValidElementType(srcElementType) ||
- !VectorType::isValidElementType(dstElementType))
- return failure();
+ auto readType = VectorType::get(srcType.getShape(), srcElementType);
+ auto writeType = VectorType::get(dstType.getShape(), dstElementType);
- auto readType = VectorType::get(srcType.getShape(), srcElementType);
- auto writeType = VectorType::get(dstType.getShape(), dstElementType);
+ Location loc = copyOp->getLoc();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(srcType.getRank(), zero);
- Location loc = copyOp->getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- SmallVector<Value> indices(srcType.getRank(), zero);
-
- Value readValue = rewriter.create<vector::TransferReadOp>(
- loc, readType, copyOp.getSource(), indices,
- rewriter.getMultiDimIdentityMap(srcType.getRank()));
- if (cast<VectorType>(readValue.getType()).getRank() == 0) {
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
- readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
- }
- Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
- loc, readValue, copyOp.getTarget(), indices,
- rewriter.getMultiDimIdentityMap(srcType.getRank()));
- rewriter.replaceOp(copyOp, writeValue->getResults());
- return success();
+ Value readValue = rewriter.create<vector::TransferReadOp>(
+ loc, readType, copyOp.getSource(), indices,
+ rewriter.getMultiDimIdentityMap(srcType.getRank()));
+ if (cast<VectorType>(readValue.getType()).getRank() == 0) {
+ readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue =
+ rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
+ }
+ Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
+ loc, readValue, copyOp.getTarget(), indices,
+ rewriter.getMultiDimIdentityMap(srcType.getRank()));
+ rewriter.replaceOp(copyOp, writeValue->getResults());
+ return success();
}
//----------------------------------------------------------------------------//
@@ -1958,7 +2049,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
- return cast<IntegerAttr>(attr).getInt();
+ return cast<IntegerAttr>(attr).getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -1966,16 +2057,16 @@ static int64_t getIntFromAttr(Attribute attr) {
/// not supported.
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> ofrs) {
- SmallVector<Value> result;
- for (auto o : ofrs) {
- if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
- result.push_back(val);
- } else {
- result.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, getIntFromAttr(o.template get<Attribute>())));
- }
- }
- return result;
+ SmallVector<Value> result;
+ for (auto o : ofrs) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+ result.push_back(val);
+ } else {
+ result.push_back(rewriter.create<arith::ConstantIndexOp>(
+ loc, getIntFromAttr(o.template get<Attribute>())));
+ }
+ }
+ return result;
}
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
@@ -2050,7 +2141,8 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
// If `dest` is a FillOp and the TransferWriteOp would overwrite the
// entire tensor, write directly to the FillOp's operand.
if (llvm::equal(vecShape, resultType.getShape()) &&
- llvm::all_of(writeInBounds, [](bool b) { return b; }))
+ llvm::all_of(writeInBounds, [](bool b) {
+ return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
dest = fill.output();
@@ -2061,7 +2153,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
return success();
- }
+}
};
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
>From 853a735ad233ad24c30a36ba7a1a870ced8eb947 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 20 Dec 2023 11:45:20 -0600
Subject: [PATCH 02/12] Enabled tensor.unpack vectorization and added test
case.
---
.../TransformOps/LinalgTransformOps.cpp | 3 ++-
mlir/test/Dialect/Linalg/vectorization.mlir | 25 +++++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 585fd14b40d764..12e3f1a5d0d31e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
+ target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..50d872c95128ba 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,6 +419,31 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32> {
+ // CHECK %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+ // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
+ // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: return %[[tw0]]
+ %8 = tensor.empty() : tensor<100x18176xf32>
+ %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
+ return %unpack : tensor<100x18176xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
%0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
>From a48dfac0d89493134e75c36ea307c64d0c941875 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 19 Jan 2024 19:11:04 -0600
Subject: [PATCH 03/12] Added some of the changes requested by Diego and HanHan
---
.../Dialect/Linalg/Transforms/Vectorization.cpp | 16 +++++-----------
1 file changed, 5 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0760ad114b07b1..f0b9da7aca4171 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1400,11 +1400,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
-// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-// Vector::TransferReadOp - Reads the Vector Array of Source data
-// vector::TransposeOp - Transpose the Source
-// ShapeCastOp - Reshapes the data based on the target.
-// vector::TransferWriteOp. - Write the result vector back.
+/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+/// Vector::TransferReadOp - Reads the Vector Array of Source data
+/// vector::TransposeOp - Transpose the Source
+/// ShapeCastOp - Reshapes the data based on the target.
+/// vector::TransferWriteOp. - Write the result vector back.
static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
@@ -1830,12 +1830,6 @@ vectorizePackOpPrecondition(tensor::PackOp packOp,
return success();
}
-static LogicalResult
-vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
- ArrayRef<int64_t> inputVectorSizes) {
- return success();
-}
-
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
>From 70cc122d0f8b77668f51592f6fa7fd78534b1f16 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Tue, 6 Feb 2024 23:45:59 +0000
Subject: [PATCH 04/12] Used vectorSizes for masks and added a dynamic shapes
test case.
---
.../Linalg/Transforms/Vectorization.cpp | 109 +++++++++++++-----
mlir/test/Dialect/Linalg/vectorization.mlir | 33 +++++-
2 files changed, 111 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f0b9da7aca4171..866b4e8774f5e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1400,17 +1400,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+
/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
/// Vector::TransferReadOp - Reads the Vector Array of Source data
/// vector::TransposeOp - Transpose the Source
/// ShapeCastOp - Reshapes the data based on the target.
/// vector::TransferWriteOp. - Write the result vector back.
-
static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
-
+ // Outer-dimension permutations are not supported yet; bail out for now.
+ // Only the required attributes (inner_dims_pos, inner_tiles) are handled.
if (!unpackOp.getOuterDimsPerm().empty()) {
LDBG("outer dimensions perms NYI for: " << unpackOp);
return failure();
@@ -1419,11 +1420,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType packTensorType = unpackOp.getSourceType();
- auto maskType =
- VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
- auto vectorType = VectorType::get(packTensorType.getShape(),
- packTensorType.getElementType());
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+ for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
+ readMaskShape[ii] = inputVectorSizes[ii];
+ }
+
+ // readMaskShape is the shape of the vector that is read and masked. With a
+ // vector-sizes (VS) array of length N and a source shape (SS) of rank M,
+ // where M >= N, the leading N dimensions come from VS and the rest from SS:
+ //
+ // readMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+ auto vectorType =
+ VectorType::get(readMaskShape, unpackTensorType.getElementType());
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
@@ -1432,54 +1441,87 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
LDBG("Unable to reify result shapes of " << unpackOp);
return failure();
}
-
+ int64_t unpackRank = unpackTensorType.getRank();
arith::ConstantIndexOp zeroOp =
rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
- Value mask = rewriter.create<vector::CreateMaskOp>(
- unpackOp.getLoc(), maskType,
- tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
unpackOp.getLoc(), vectorType, unpackOp.getSource(),
- SmallVector<Value>(packTensorType.getRank(), zeroOp),
- rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+ SmallVector<Value>(unpackRank, zeroOp),
+ rewriter.getMultiDimIdentityMap(unpackRank));
+ auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
+ Value mask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), readMaskType,
+ tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
vector::MaskOp maskedOp =
cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
- int64_t packRank = packTensorType.getRank();
- auto lastDims =
- llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+ llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
PackingMetadata packMetadata =
- computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+ computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
- packRank, lastDims, packMetadata.insertPositions);
- SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+ unpackRank, lastDims, packMetadata.insertPositions);
+ ShapedType maskedOpShapedType =
+ cast<ShapedType>(maskedOp.getResult(0).getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
RankedTensorType stripMineTensorType =
- RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+ RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+ .setShape(stripMineShape);
+ // Collapse the tensor to the size required by result.
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
auto vecCollapsedType =
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ // Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
- tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
- unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+ tensor::EmptyOp emptyOp =
+ rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
+ unpackTensorType.getElementType());
- vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+ int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
+ Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
- SmallVector<Value>(lastDims.size(), zeroOp),
- SmallVector<bool>(lastDims.size(), true));
-
- newResults.push_back(writeOp->getResult(0));
+ SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
+ auto resultShape = unpackOp.getResult().getType().getShape();
+
+ // If the shape of the result doesn't match the inputVectorSizes, a mask
+ // is necessary.
+ bool needMaskForWrite =
+ llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
+ [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ mlir::OpResult result = writeOp->getResult(0);
+ if (needMaskForWrite) {
+ SmallVector<int64_t> writeMaskShape(inputVectorSizes);
+ llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+ for (auto [index, size] : enumerate(innerTiles)) {
+ writeMaskShape[innerDimPos[index]] *= size;
+ }
+ // writeMaskShape is derived from inputVectorSizes, the inner dimension
+ // positions and the inner tile sizes:
+ // writeMaskShape (WMS) starts as a copy of inputVectorSizes;
+ // for each (index, tileSize) in the inner-tiles vector:
+ // WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * tileSize
+ auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
+ Value writeMask = rewriter.create<vector::CreateMaskOp>(
+ unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+ Operation *writeOpWithMask =
+ mlir::vector::maskOperation(rewriter, writeOp, writeMask);
+ result = writeOpWithMask->getResult(0);
+ }
+ newResults.push_back(result);
return success();
}
@@ -1737,6 +1779,19 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+/// Vectorizing a `tensor::UnPackOp` requires all inner tiles to be constant.
+static LogicalResult
+vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
+ return !getConstantIntValue(res).has_value();
+ })) {
+ LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
+ return failure();
+ }
+ return success();
+}
+
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 50d872c95128ba..a79ff6bd75795c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -422,11 +422,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_vectorize_unpack
func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32> {
// CHECK %[[c0:.*]] = arith.constant 0 : index
- // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
- // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
- // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+ // CHECK: %[[m0:.*]] = vector.create_mask %c7, %c1136, %c16, %c16_0 : vector<2x4x16x16xi1>
+ // CHECK: %[[tr0:.*]] = vector.mask %[[m0]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<2x4x16x16xf32> } : vector<2x4x16x16xi1> -> vector<2x4x16x16xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<2x4x16x16xf32> to vector<2x16x4x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x4x16xf32> to vector<32x64xf32>
// CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
- // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: %[[mask0:.*]] = vector.create_mask %c100, %c18176 : vector<32x64xi1>
+ // CHECK: %[[tw0:.*]] = vector.mask %[[mask0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
// CHECK: return %[[tw0]]
%8 = tensor.empty() : tensor<100x18176xf32>
%unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
@@ -444,6 +446,29 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[readMsk0:.*]] = vector.create_mask %dim_3, %dim_5, %c16, %c2 : vector<4x1x16x2xi1>
+ // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<4x1x16x2xf32> } : vector<4x1x16x2xi1> -> vector<4x1x16x2xf32>
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<4x1x16x2xf32> to vector<4x2x1x16xf32>
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<8x16xf32>
+ // CHECK: %[[empt0:.*]] = tensor.empty
+ // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<8x16xi1>
+ // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+ // CHECK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
%0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
>From c33642b2da876026e13968199adbbba6f6bb2432 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 7 Feb 2024 19:43:48 +0000
Subject: [PATCH 05/12] Added some changes proposed by HanHan.
---
.../Linalg/Transforms/Vectorization.cpp | 52 +++++++++++--------
1 file changed, 29 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 866b4e8774f5e2..4d0c62fbd46cfa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1410,12 +1410,6 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
- // Handling this case requires a bit more change. Right now
- // just the required attributes are handled.
- if (!unpackOp.getOuterDimsPerm().empty()) {
- LDBG("outer dimensions perms NYI for: " << unpackOp);
- return failure();
- }
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
@@ -1442,18 +1436,19 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
return failure();
}
int64_t unpackRank = unpackTensorType.getRank();
+ Location loc = unpackOp->getLoc();
arith::ConstantIndexOp zeroOp =
- rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+ rewriter.create<arith::ConstantIndexOp>(loc, 0);
vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
- unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+ loc, vectorType, unpackOp.getSource(),
SmallVector<Value>(unpackRank, zeroOp),
rewriter.getMultiDimIdentityMap(unpackRank));
auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
Value mask = rewriter.create<vector::CreateMaskOp>(
- unpackOp.getLoc(), readMaskType,
- tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+ loc, readMaskType,
+ tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()));
vector::MaskOp maskedOp =
cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
@@ -1474,25 +1469,23 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
.setShape(stripMineShape);
- // Collapse the tensor to the size required by result.
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMineTensorType, packMetadata.reassociations);
- auto vecCollapsedType =
- VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
-
// Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
- unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+ loc, maskedOp.getResult(0), lastDimToInsertPosPerm);
+ // Collapse the vector to the size required by result.
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ mlir::VectorType vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
- unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
- tensor::EmptyOp emptyOp =
- rewriter.create<tensor::EmptyOp>(unpackOp.getLoc(), reifiedRetShapes[0],
- unpackTensorType.getElementType());
+ loc, vecCollapsedType, transposeOp->getResult(0));
+ tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+ loc, reifiedRetShapes[0], unpackTensorType.getElementType());
int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
- unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+ loc, shapeCastOp->getResult(0), emptyOp,
SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
auto resultShape = unpackOp.getResult().getType().getShape();
@@ -1516,7 +1509,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
// WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
Value writeMask = rewriter.create<vector::CreateMaskOp>(
- unpackOp.getLoc(), writeMaskType, reifiedRetShapes[0]);
+ loc, writeMaskType, reifiedRetShapes[0]);
Operation *writeOpWithMask =
mlir::vector::maskOperation(rewriter, writeOp, writeMask);
result = writeOpWithMask->getResult(0);
@@ -1783,12 +1776,25 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
static LogicalResult
vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
+
+ // Handling this case requires a bit more change. Right now
+ // just the required attributes are handled.
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ LDBG("outer dimensions perms NYI for: " << unpackOp);
+ return failure();
+ }
+
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
return failure();
}
+ llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+ if (inputVectorSizes.empty() == false &&
+ failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
+ return failure();
+
return success();
}
>From 744a291b346a3f5bf36f8c744b4d28e152ca4f5e Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 9 Feb 2024 17:33:25 +0000
Subject: [PATCH 06/12] Fixed all issues pointed out by HanHan except factoring
in StripMineTensorType
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 3 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 544 +++++++++---------
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 27 +-
mlir/test/Dialect/Linalg/vectorization.mlir | 85 ++-
5 files changed, 322 insertions(+), 339 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index fe9b16cb44b3da..60522ac48d95b5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -38,7 +38,8 @@ computeTransposedType(RankedTensorType rankedTensorType,
/// i.e. for a pack from an ABCD layout to an ABCDba:
/// The packed shape would be ABCDba.
/// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+SmallVector<int64_t> getPackUnPackInverseDestPerm(
+ std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 596b7c50c1e4e4..9f8ea7f1f3969b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> packedToStripMinedShapePerm =
- tensor::getPackInverseDestPermutation(packOp);
+ tensor::getPackUnPackInverseDestPerm(packOp);
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4d0c62fbd46cfa..420ffe533ff0b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1401,129 +1401,12 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
-/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-/// Vector::TransferReadOp - Reads the Vector Array of Source data
-/// vector::TransposeOp - Transpose the Source
-/// ShapeCastOp - Reshapes the data based on the target.
-/// vector::TransferWriteOp. - Write the result vector back.
-static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
- tensor::UnPackOp unpackOp,
- ArrayRef<int64_t> inputVectorSizes,
- SmallVectorImpl<Value> &newResults) {
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(unpackOp);
-
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
- llvm::SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
- for (unsigned int ii = 0; ii < inputVectorSizes.size(); ii++) {
- readMaskShape[ii] = inputVectorSizes[ii];
- }
-
- // ReadMask is the size of tensor used to read and apply mask. It is
- // set like this. Let's say the vectorSize (VS) array is size 'N' and
- // the sourceShape(SS) is 'M' where M >= N
- // Thus:
- // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
- auto vectorType =
- VectorType::get(readMaskShape, unpackTensorType.getElementType());
- ReifiedRankedShapedTypeDims reifiedRetShapes;
- LogicalResult status =
- cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
- .reifyResultShapes(rewriter, reifiedRetShapes);
- if (status.failed()) {
- LDBG("Unable to reify result shapes of " << unpackOp);
- return failure();
- }
- int64_t unpackRank = unpackTensorType.getRank();
- Location loc = unpackOp->getLoc();
- arith::ConstantIndexOp zeroOp =
- rewriter.create<arith::ConstantIndexOp>(loc, 0);
-
- vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
- loc, vectorType, unpackOp.getSource(),
- SmallVector<Value>(unpackRank, zeroOp),
- rewriter.getMultiDimIdentityMap(unpackRank));
-
- auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
- Value mask = rewriter.create<vector::CreateMaskOp>(
- loc, readMaskType,
- tensor::getMixedSizes(rewriter, loc, unpackOp.getSource()));
- vector::MaskOp maskedOp =
- cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
-
- int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
- llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
- llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
- PackingMetadata packMetadata =
- computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
- unpackRank, lastDims, packMetadata.insertPositions);
- ShapedType maskedOpShapedType =
- cast<ShapedType>(maskedOp.getResult(0).getType());
- SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
- mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
- applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-
- RankedTensorType stripMineTensorType =
- RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
- .setShape(stripMineShape);
-
- // Transpose the appropriate rows to match output.
- vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
- loc, maskedOp.getResult(0), lastDimToInsertPosPerm);
-
- // Collapse the vector to the size required by result.
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMineTensorType, packMetadata.reassociations);
- mlir::VectorType vecCollapsedType =
- VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
- vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
- loc, vecCollapsedType, transposeOp->getResult(0));
- tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, reifiedRetShapes[0], unpackTensorType.getElementType());
-
- int64_t destRank = cast<ShapedType>(emptyOp.getType()).getRank();
- Operation *writeOp = rewriter.create<vector::TransferWriteOp>(
- loc, shapeCastOp->getResult(0), emptyOp,
- SmallVector<Value>(destRank, zeroOp), SmallVector<bool>(destRank, true));
- auto resultShape = unpackOp.getResult().getType().getShape();
-
- // If the shape of the result doesn't match the inputVectorSizes, a mask
- // is necessary.
- bool needMaskForWrite =
- llvm::any_of(llvm::zip_equal(inputVectorSizes, resultShape),
- [](auto it) { return std::get<0>(it) != std::get<1>(it); });
- mlir::OpResult result = writeOp->getResult(0);
- if (needMaskForWrite) {
- SmallVector<int64_t> writeMaskShape(inputVectorSizes);
- llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
- llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
- for (auto [index, size] : enumerate(innerTiles)) {
- writeMaskShape[innerDimPos[index]] *= size;
- }
- // WriteMaskShape is computed using the vectorSizes, inner Dim Position and
- // innerTiles.
- // WriteMaskShape (WMS) initialized to [inputVectorSizes]
- // for-each index, value in inner-Tiles vector:
- // WMS[innerDimPos[index]] = WMS[innerDimPos[index]] * value
- auto writeMaskType = VectorType::get(writeMaskShape, rewriter.getI1Type());
- Value writeMask = rewriter.create<vector::CreateMaskOp>(
- loc, writeMaskType, reifiedRetShapes[0]);
- Operation *writeOpWithMask =
- mlir::vector::maskOperation(rewriter, writeOp, writeMask);
- result = writeOpWithMask->getResult(0);
- }
- newResults.push_back(result);
- return success();
-}
-
/// Given a tensor::PackOp, return the `dest` shape before any packing
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape,
- tensor::getPackInverseDestPermutation(packOp));
+ tensor::getPackUnPackInverseDestPerm(packOp));
}
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1537,16 +1420,28 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
auto maskType = VectorType::get(readShape, builder.getI1Type());
- auto vectorType = VectorType::get(readShape, padValue.getType());
+ Type vecElemType = padValue != nullptr
+ ? padValue.getType()
+ : cast<ShapedType>(source.getType()).getElementType();
+ auto vectorType = VectorType::get(readShape, vecElemType);
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- auto transferReadOp = builder.create<vector::TransferReadOp>(
- loc,
- /*vectorType=*/vectorType,
- /*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero),
- /*padding=*/padValue,
- /*inBounds=*/SmallVector<bool>(readRank, true));
+ vector::TransferReadOp transferReadOp = nullptr;
+ if (padValue == nullptr) {
+ transferReadOp = builder.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/source,
+ /*indices=*/SmallVector<Value>(readRank, zero));
+ } else {
+ transferReadOp = builder.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/source,
+ /*indices=*/SmallVector<Value>(readRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(readRank, true));
+ }
if (llvm::equal(readShape, sourceShape)) {
return transferReadOp;
}
@@ -1664,7 +1559,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
// Create TransposeOp.
auto destPermutation =
- invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
+ invertPermutationVector(tensor::getPackUnPackInverseDestPerm(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);
@@ -1676,6 +1571,90 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor.
+/// vector::TransposeOp - Transposes the vector that was read.
+/// vector::ShapeCastOp - Reshapes the data based on the target.
+/// vector::TransferWriteOp - Writes the result vector back.
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+ tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unpackOp);
+
+ RankedTensorType unpackTensorType = unpackOp.getSourceType();
+
+ SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
+ llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+ for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
+ readMaskShape[i] = inputVectorSizes[i];
+ }
+ for (auto [index, size] : enumerate(innerTiles)) {
+ readMaskShape[innerDimPos[index]] =
+ llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+ }
+
+ // readMaskShape is the shape of the vector that is read (and masked) from
+ // the source. With a vectorSizes (VS) array of size 'N' and a sourceShape
+ // (SS) of size 'M', where M >= N, it starts out as
+ // [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]; each entry at innerDimPos[i]
+ // is then replaced by ceil(entry / innerTiles[i]), as computed above.
+ ReifiedRankedShapedTypeDims reifiedRetShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedRetShapes);
+ if (status.failed()) {
+ LDBG("Unable to reify result shapes of " << unpackOp);
+ return failure();
+ }
+ int64_t unpackRank = unpackTensorType.getRank();
+ Location loc = unpackOp->getLoc();
+
+ Value readResult = createReadOrMaskedRead(
+ rewriter, loc, unpackOp.getSource(),
+ llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
+ nullptr);
+
+ int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+ llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
+ PackingMetadata packMetadata =
+ computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+ unpackRank, lastDims, packMetadata.insertPositions);
+ ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
+ SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
+ mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+ applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+ RankedTensorType stripMineTensorType =
+ RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
+ .setShape(stripMineShape);
+
+ // Transpose the appropriate rows to match output.
+ vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+ loc, readResult, lastDimToInsertPosPerm);
+
+ // Collapse the vector to the size required by result.
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMineTensorType, packMetadata.reassociations);
+ mlir::VectorType vecCollapsedType =
+ VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, vecCollapsedType, transposeOp->getResult(0));
+
+ SmallVector<int64_t> writeMaskShape(
+ shapeCastOp.getResultVectorType().getShape());
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
+ reifiedRetShapes[0], writeMaskShape);
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1774,11 +1753,12 @@ isValidMaskedInputVector(ArrayRef<int64_t> shape,
/// Need to check if the inner-tiles are static/constant.
static LogicalResult
-vectorizeUnPackOpPreCondition(tensor::UnPackOp unpackOp,
+vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
// Handling this case requires a bit more change. Right now
// just the required attributes are handled.
+ // TODO: Handle OuterDimsPerm.
if (!unpackOp.getOuterDimsPerm().empty()) {
LDBG("outer dimensions perms NYI for: " << unpackOp);
return failure();
@@ -1846,9 +1826,10 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
}
if (isElementwise(linalgOp))
return success();
- // TODO: isaConvolutionOpInterface that can also infer from generic features.
- // But we will still need stride/dilation attributes that will be annoying to
- // reverse-engineer...
+
+ // TODO: isaConvolutionOpInterface that can also infer from generic
+ // features. But we will still need stride/dilation attributes that will be
+ // annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -1944,158 +1925,162 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
return TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
- return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
- vectorizeNDExtract);
+ return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+ vectorizeNDExtract);
})
.Case<tensor::PadOp>([&](auto padOp) {
- return vectorizePadOpPrecondition(padOp, inputVectorSizes);
+ return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
.Case<tensor::PackOp>([&](auto packOp) {
- return vectorizePackOpPrecondition(packOp, inputVectorSizes);
- .Case<tensor::UnPackOp>([&](auto unpackOp) {
- return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
- }).Default([](auto) { return failure(); });
+ return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+ })
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
+ })
+ .Default([](auto) { return failure(); });
}
/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
- OpBuilder::InsertionGuard g(rewriter);
- auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
-
- for (auto op : make_early_inc_range(toReplace)) {
- rewriter.setInsertionPoint(op);
- auto expanded = affine::expandAffineExpr(
- rewriter, op->getLoc(), op.getAffineMap().getResult(0),
- op.getOperands().take_front(op.getAffineMap().getNumDims()),
- op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
- rewriter.replaceOp(op, expanded);
- }
+ OpBuilder::InsertionGuard g(rewriter);
+ auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
+
+ for (auto op : make_early_inc_range(toReplace)) {
+ rewriter.setInsertionPoint(op);
+ auto expanded = affine::expandAffineExpr(
+ rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+ op.getOperands().take_front(op.getAffineMap().getNumDims()),
+ op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+ rewriter.replaceOp(op, expanded);
+ }
}
/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the input
-/// vector sizes must be greater than or equal to their counterpart iteration
-/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
-/// operations with dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorSizes`
+/// also allows the vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
- LDBG("Attempting to vectorize:\n" << *op << "\n");
- LDBG("Input vector sizes: ");
- LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- LDBG("Input scalable vector dims: ");
- LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
-
- if (failed(vectorizeOpPrecondition(
- op, inputVectorSizes, inputScalableVecDims, vectorizeNDExtract))) {
- LDBG("Vectorization pre-conditions failed\n");
- return failure();
- }
-
- // Initialize vectorization state.
- VectorizationState state(rewriter);
- if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
- if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
- inputScalableVecDims))) {
- LDBG("Vectorization state couldn't be initialized\n");
- return failure();
- }
- }
-
- SmallVector<Value> results;
- auto vectorizeResult =
- TypeSwitch<Operation *, LogicalResult>(op)
- .Case<linalg::LinalgOp>([&](auto linalgOp) {
- // TODO: isaConvolutionOpInterface that can also infer from
- // generic features. Will require stride/dilation attributes
- // inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- FailureOr<Operation *> convOr =
- vectorizeConvolution(rewriter, linalgOp, flatten1DDepthwiseConv);
- if (succeeded(convOr)) {
- llvm::append_range(results, (*convOr)->getResults());
- return success();
- }
+ LDBG("Attempting to vectorize:\n" << *op << "\n");
+ LDBG("Input vector sizes: ");
+ LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG("Input scalable vector dims: ");
+ LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
- LDBG("Unsupported convolution can't be vectorized.\n");
- return failure();
- }
+ if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
+ vectorizeNDExtract))) {
+ LDBG("Vectorization pre-conditions failed\n");
+ return failure();
+ }
- LDBG("Vectorize generic by broadcasting to the canonical vector "
- "shape\n");
-
- // Pre-process before proceeding.
- convertAffineApply(rewriter, linalgOp);
-
- // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
- // to 'OpBuilder' when it is passed over to some methods like
- // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
- // erase an op within these methods, the actual rewriter won't be
- // notified and we will end up with read-after-free issues!
- return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
- })
- .Case<tensor::PadOp>([&](auto padOp) {
- return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, results);
- })
- .Case<tensor::PackOp>([&](auto packOp) {
- return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
- results);
- .Case<tensor::UnPackOp>([&](auto unpackOp) {
- return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
- results);
- }).Default([](auto) { return failure(); });
+ // Initialize vectorization state.
+ VectorizationState state(rewriter);
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
+ inputScalableVecDims))) {
+ LDBG("Vectorization state couldn't be initialized\n");
+ return failure();
+ }
+ }
- if (failed(vectorizeResult)) {
- LDBG("Vectorization failed\n");
- return failure();
- }
+ SmallVector<Value> results;
+ auto vectorizeResult =
+ TypeSwitch<Operation *, LogicalResult>(op)
+ .Case<linalg::LinalgOp>([&](auto linalgOp) {
+ // TODO: isaConvolutionOpInterface that can also infer from
+ // generic features. Will require stride/dilation attributes
+ // inference.
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ FailureOr<Operation *> convOr = vectorizeConvolution(
+ rewriter, linalgOp, flatten1DDepthwiseConv);
+ if (succeeded(convOr)) {
+ llvm::append_range(results, (*convOr)->getResults());
+ return success();
+ }
+
+ LDBG("Unsupported convolution can't be vectorized.\n");
+ return failure();
+ }
+
+ LDBG("Vectorize generic by broadcasting to the canonical vector "
+ "shape\n");
+
+ // Pre-process before proceeding.
+ convertAffineApply(rewriter, linalgOp);
+
+ // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+ // to 'OpBuilder' when it is passed over to some methods like
+ // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+ // erase an op within these methods, the actual rewriter won't be
+ // notified and we will end up with read-after-free issues!
+ return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
+ })
+ .Case<tensor::PadOp>([&](auto padOp) {
+ return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
+ results);
+ })
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+ results);
+ })
+ .Case<tensor::UnPackOp>([&](auto unpackOp) {
+ return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+ results);
+ })
+ .Default([](auto) { return failure(); });
+
+ if (failed(vectorizeResult)) {
+ LDBG("Vectorization failed\n");
+ return failure();
+ }
- if (!results.empty())
- rewriter.replaceOp(op, results);
- else
- rewriter.eraseOp(op);
+ if (!results.empty())
+ rewriter.replaceOp(op, results);
+ else
+ rewriter.eraseOp(op);
- return success();
+ return success();
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {
- auto srcType = cast<MemRefType>(copyOp.getSource().getType());
- auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
- if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
- return failure();
-
- auto srcElementType = getElementTypeOrSelf(srcType);
- auto dstElementType = getElementTypeOrSelf(dstType);
- if (!VectorType::isValidElementType(srcElementType) ||
- !VectorType::isValidElementType(dstElementType))
- return failure();
+ auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+ auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
+ if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
+ return failure();
- auto readType = VectorType::get(srcType.getShape(), srcElementType);
- auto writeType = VectorType::get(dstType.getShape(), dstElementType);
+ auto srcElementType = getElementTypeOrSelf(srcType);
+ auto dstElementType = getElementTypeOrSelf(dstType);
+ if (!VectorType::isValidElementType(srcElementType) ||
+ !VectorType::isValidElementType(dstElementType))
+ return failure();
- Location loc = copyOp->getLoc();
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- SmallVector<Value> indices(srcType.getRank(), zero);
+ auto readType = VectorType::get(srcType.getShape(), srcElementType);
+ auto writeType = VectorType::get(dstType.getShape(), dstElementType);
- Value readValue = rewriter.create<vector::TransferReadOp>(
- loc, readType, copyOp.getSource(), indices,
- rewriter.getMultiDimIdentityMap(srcType.getRank()));
- if (cast<VectorType>(readValue.getType()).getRank() == 0) {
- readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
- readValue =
- rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
- }
- Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
- loc, readValue, copyOp.getTarget(), indices,
- rewriter.getMultiDimIdentityMap(srcType.getRank()));
- rewriter.replaceOp(copyOp, writeValue->getResults());
- return success();
+ Location loc = copyOp->getLoc();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(srcType.getRank(), zero);
+
+ Value readValue = rewriter.create<vector::TransferReadOp>(
+ loc, readType, copyOp.getSource(), indices,
+ rewriter.getMultiDimIdentityMap(srcType.getRank()));
+ if (cast<VectorType>(readValue.getType()).getRank() == 0) {
+ readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+ readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
+ }
+ Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
+ loc, readValue, copyOp.getTarget(), indices,
+ rewriter.getMultiDimIdentityMap(srcType.getRank()));
+ rewriter.replaceOp(copyOp, writeValue->getResults());
+ return success();
}
//----------------------------------------------------------------------------//
@@ -2104,7 +2089,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
- return cast<IntegerAttr>(attr).getInt();
+ return cast<IntegerAttr>(attr).getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -2112,16 +2097,16 @@ static int64_t getIntFromAttr(Attribute attr) {
/// not supported.
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> ofrs) {
- SmallVector<Value> result;
- for (auto o : ofrs) {
- if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
- result.push_back(val);
- } else {
- result.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, getIntFromAttr(o.template get<Attribute>())));
- }
- }
- return result;
+ SmallVector<Value> result;
+ for (auto o : ofrs) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+ result.push_back(val);
+ } else {
+ result.push_back(rewriter.create<arith::ConstantIndexOp>(
+ loc, getIntFromAttr(o.template get<Attribute>())));
+ }
+ }
+ return result;
}
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
@@ -2196,8 +2181,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
// If `dest` is a FillOp and the TransferWriteOp would overwrite the
// entire tensor, write directly to the FillOp's operand.
if (llvm::equal(vecShape, resultType.getShape()) &&
- llvm::all_of(writeInBounds, [](bool b) {
- return b; }))
+ llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
dest = fill.output();
@@ -2208,7 +2192,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
return success();
-}
+ }
};
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
@@ -2980,8 +2964,8 @@ struct Conv1DGenerator
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
resPadding);
- // The base vectorization case for channeled convolution is input: {n,w,c},
- // weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
+ // The base vectorization case for channeled convolution is input:
+ // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
// vectorization case, we do pre transpose on input, weight, and output.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
@@ -3024,9 +3008,9 @@ struct Conv1DGenerator
return kw * (wSize / wSizeStep) + w;
};
- // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} or
- // perform outerproduct for non-channeled convolution or
- // perform simple arith operation for pooling
+ // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
+ // or perform outerproduct for non-channeled convolution or perform simple
+ // arith operation for pooling
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
switch (oper) {
@@ -3055,9 +3039,9 @@ struct Conv1DGenerator
// End vector-only rewrite part
//===------------------------------------------------------------------===//
- // The base vectorization case for channeled convolution is output: {n,w,f}
- // To reuse the result from base pattern vectorization case, we post
- // transpose the base case result.
+ // The base vectorization case for channeled convolution is output:
+ // {n,w,f} To reuse the result from base pattern vectorization case, we
+ // post transpose the base case result.
switch (conv1DOpOrder) {
case Conv1DOpOrder::W:
case Conv1DOpOrder::Nwc:
@@ -3495,9 +3479,9 @@ static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
- // strides/dilations. However, we do not need to rely on those, we can simply
- // use them if present, otherwise use the default and let the generic conv.
- // matcher in the ConvGenerator succeed or fail.
+ // strides/dilations. However, we do not need to rely on those, we can
+ // simply use them if present, otherwise use the default and let the generic
+ // conv. matcher in the ConvGenerator succeed or fail.
auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index f20008a1ed2b2f..6303dec81327a0 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -73,25 +73,38 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
return transposedTensorType;
}
-SmallVector<int64_t>
-mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
+SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
+ std::variant<tensor::PackOp, tensor::UnPackOp> op) {
+
+ llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
+ RankedTensorType destType;
+ if (std::holds_alternative<tensor::PackOp>(op)) {
+ tensor::PackOp packOp = std::get<tensor::PackOp>(op);
+ innerDimsPos = packOp.getInnerDimsPos();
+ destType = packOp.getDestType();
+ outerPerm = packOp.getOuterDimsPerm();
+ } else {
+ tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
+ innerDimsPos = unpackOp.getInnerDimsPos();
+ destType = unpackOp.getDestType();
+ outerPerm = unpackOp.getOuterDimsPerm();
+ }
// The permutation can be obtained from two permutations:
// a) Compute the permutation vector to move the last `numPackedDims` into
// the `innerPosDims` of a shape of rank `packedRank`.
// b) Compute the permutation vector to move outer dims if the pack op
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
- int64_t numPackedDims = packOp.getInnerDimsPos().size();
- int64_t packedRank = packOp.getDestType().getRank();
+ int64_t numPackedDims = innerDimsPos.size();
+ int64_t packedRank = destType.getRank();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
- PackingMetadata packingMetadata = computePackingMetadata(
- packOp.getDestType().getRank(), packOp.getInnerDimsPos());
+ PackingMetadata packingMetadata =
+ computePackingMetadata(destType.getRank(), innerDimsPos);
SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
packedRank, lastDims, packingMetadata.insertPositions);
SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
- ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
if (!outerPerm.empty())
applyPermutationToVector(outerPos, outerPerm);
SmallVector<int64_t> outerPositionPerm = computePermutationVector(
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a79ff6bd75795c..76ea8d83b3c0cf 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,56 +419,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_unpack
-func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32> {
- // CHECK %[[c0:.*]] = arith.constant 0 : index
- // CHECK: %[[m0:.*]] = vector.create_mask %c7, %c1136, %c16, %c16_0 : vector<2x4x16x16xi1>
- // CHECK: %[[tr0:.*]] = vector.mask %[[m0]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<2x4x16x16xf32> } : vector<2x4x16x16xi1> -> vector<2x4x16x16xf32>
- // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<2x4x16x16xf32> to vector<2x16x4x16xf32>
- // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x4x16xf32> to vector<32x64xf32>
- // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
- // CHECK: %[[mask0:.*]] = vector.create_mask %c100, %c18176 : vector<32x64xi1>
- // CHECK: %[[tw0:.*]] = vector.mask %[[mask0]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
- // CHECK: return %[[tw0]]
- %8 = tensor.empty() : tensor<100x18176xf32>
- %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
- return %unpack : tensor<100x18176xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
- : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
-func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
- // CHECK: %[[readMsk0:.*]] = vector.create_mask %dim_3, %dim_5, %c16, %c2 : vector<4x1x16x2xi1>
- // CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<4x1x16x2xf32> } : vector<4x1x16x2xi1> -> vector<4x1x16x2xf32>
- // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<4x1x16x2xf32> to vector<4x2x1x16xf32>
- // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<8x16xf32>
- // CHECK: %[[empt0:.*]] = tensor.empty
- // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<8x16xi1>
- // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
- // CHECK: return %[[write0]]
- %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
- return %ret : tensor<?x?xf32>
-}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
%0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
@@ -722,3 +672,38 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack
+func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[C01:.*]] = arith.constant 0
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C02:.*]] = arith.constant 0
+// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST15:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST15]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
+// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
+// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
+// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+// CHEdCK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<4x16xf32>
+// CHEdCK: %[[empt0:.*]] = tensor.empty
+// CHEdCK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
+// CHEdCK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHEdCK: return %[[write0]]
+ %ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
+ return %ret : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.yield
+ }
+}
>From d5a0dec6194c6f843765ef202bf5cd3a4150ed99 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 9 Feb 2024 23:14:22 +0000
Subject: [PATCH 07/12] Fixed all the issues pointed out by HanHan and Diego.
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 5 ++
.../Linalg/Transforms/Vectorization.cpp | 30 ++++-----
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 48 +++++++------
mlir/test/Dialect/Linalg/vectorization.mlir | 67 +++++++++++++++++--
4 files changed, 108 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 60522ac48d95b5..8c8107e0507d70 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -41,6 +41,11 @@ computeTransposedType(RankedTensorType rankedTensorType,
SmallVector<int64_t> getPackUnPackInverseDestPerm(
std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
+/// Unpack vectorization also needs the packing metadata, so provide an
+/// overload that returns this metadata through a reference parameter.
+SmallVector<int64_t> getPackUnPackInverseDestPerm(
+ std::variant<tensor::PackOp, tensor::UnPackOp> packOp,
+ PackingMetadata &PackingMetadata);
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 420ffe533ff0b3..8c5fb1b03d033f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1571,11 +1571,12 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}
-/// Vectorize an `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
-/// Vector::TransferReadOp - Reads the Vector Array of Source data
-/// vector::TransposeOp - Transpose the Source
-/// ShapeCastOp - Reshapes the data based on the target.
-/// vector::TransferWriteOp. - Write the result vector back.
+/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+/// vector::TransferReadOp - Reads a vector from the source tensor.
+/// vector::TransposeOp - Transposes the vector read from the source.
+/// vector::ShapeCastOp - Reshapes the data based on the target.
+/// vector::TransferWriteOp - Writes the result vector back to the destination
+/// tensor.
static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1610,26 +1611,21 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
LDBG("Unable to reify result shapes of " << unpackOp);
return failure();
}
- int64_t unpackRank = unpackTensorType.getRank();
Location loc = unpackOp->getLoc();
+ // Read result, mask if necessary.
Value readResult = createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(),
llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
nullptr);
- int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
- llvm::SmallVector<int64_t> lastDims = llvm::to_vector(
- llvm::seq<int64_t>(unpackRank - numPackedDim, unpackRank));
- PackingMetadata packMetadata =
- computePackingMetadata(unpackRank, unpackOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
- unpackRank, lastDims, packMetadata.insertPositions);
+ PackingMetadata packMetadata;
+ SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
+ tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-
RankedTensorType stripMineTensorType =
RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
.setShape(stripMineShape);
@@ -1646,8 +1642,12 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));
+ // For dynamic sizes, writeMaskShape has to match the shape_cast result
+ // shape; otherwise the verifier complains that the mask size is invalid.
SmallVector<int64_t> writeMaskShape(
- shapeCastOp.getResultVectorType().getShape());
+ unpackOp.getDestType().hasStaticShape()
+ ? inputVectorSizes
+ : shapeCastOp.getResultVectorType().getShape());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
reifiedRetShapes[0], writeMaskShape);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 6303dec81327a0..0902e33a1f19fd 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -75,18 +75,26 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
std::variant<tensor::PackOp, tensor::UnPackOp> op) {
+ PackingMetadata pMetaData;
+ return getPackUnPackInverseDestPerm(op, pMetaData);
+}
+
+SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
+ std::variant<tensor::PackOp, tensor::UnPackOp> op,
+ PackingMetadata &packingMetadata) {
llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
- RankedTensorType destType;
- if (std::holds_alternative<tensor::PackOp>(op)) {
+ int64_t rank = 0;
+ bool isPackOp = std::holds_alternative<tensor::PackOp>(op);
+ if (isPackOp) {
tensor::PackOp packOp = std::get<tensor::PackOp>(op);
innerDimsPos = packOp.getInnerDimsPos();
- destType = packOp.getDestType();
+ rank = packOp.getDestType().getRank();
outerPerm = packOp.getOuterDimsPerm();
} else {
tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
innerDimsPos = unpackOp.getInnerDimsPos();
- destType = unpackOp.getDestType();
+ rank = unpackOp.getSourceType().getRank();
outerPerm = unpackOp.getOuterDimsPerm();
}
// The permutation can be obtained from two permutations:
@@ -96,23 +104,21 @@ SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
// has outer_dims_perm.
// Apply (b) permutation on (a) permutation to get the final permutation.
int64_t numPackedDims = innerDimsPos.size();
- int64_t packedRank = destType.getRank();
- auto lastDims = llvm::to_vector(
- llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
- PackingMetadata packingMetadata =
- computePackingMetadata(destType.getRank(), innerDimsPos);
- SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
- packedRank, lastDims, packingMetadata.insertPositions);
-
- SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
- if (!outerPerm.empty())
- applyPermutationToVector(outerPos, outerPerm);
- SmallVector<int64_t> outerPositionPerm = computePermutationVector(
- packedRank, packingMetadata.outerPositions, outerPos);
-
- SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
- applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
- return packInverseDestPermutation;
+ auto lastDims =
+ llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
+ packingMetadata = computePackingMetadata(rank, innerDimsPos);
+ SmallVector<int64_t> innerPositionsPerm =
+ computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
+
+ if (isPackOp) {
+ SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+ if (!outerPerm.empty())
+ applyPermutationToVector(outerPos, outerPerm);
+ SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+ rank, packingMetadata.outerPositions, outerPos);
+ applyPermutationToVector(innerPositionsPerm, outerPositionPerm);
+ }
+ return innerPositionsPerm;
}
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 76ea8d83b3c0cf..0c8a76d5231f02 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -691,12 +691,12 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
-// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
-// CHEdCK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<4x16xf32>
-// CHEdCK: %[[empt0:.*]] = tensor.empty
-// CHEdCK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
-// CHEdCK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
-// CHEdCK: return %[[write0]]
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x2x1xf32> to vector<32x2xf32>
+// CHECK: %[[empt0:.*]] = tensor.empty
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<32x2xi1>
+// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
+// CHECK: return %[[write0]]
%ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
return %ret : tensor<?x?xf32>
}
@@ -707,3 +707,58 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %c8, %c8_0, %c32, %c16 : vector<16x8x32x16xi1>
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
+ // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
+ // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_unpack_no_masks
+func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C00:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+ transform.yield
+ } }
\ No newline at end of file
>From 59d761fae181acf4e66075696cd46bca5c609db5 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 15 Feb 2024 23:24:24 +0000
Subject: [PATCH 08/12] Added all the changes requested by Diego and Max
(Except handling of outer Dimensions attribute)
---
.../Linalg/Transforms/Vectorization.cpp | 66 +++++++++----------
mlir/test/Dialect/Linalg/vectorization.mlir | 17 ++---
2 files changed, 42 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8c5fb1b03d033f..f57fae3baa9e6b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1420,28 +1420,16 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
assert(sourceShape.size() == readShape.size());
auto maskType = VectorType::get(readShape, builder.getI1Type());
- Type vecElemType = padValue != nullptr
- ? padValue.getType()
- : cast<ShapedType>(source.getType()).getElementType();
- auto vectorType = VectorType::get(readShape, vecElemType);
+ auto vectorType = VectorType::get(readShape, padValue.getType());
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- vector::TransferReadOp transferReadOp = nullptr;
- if (padValue == nullptr) {
- transferReadOp = builder.create<vector::TransferReadOp>(
- loc,
- /*vectorType=*/vectorType,
- /*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero));
- } else {
- transferReadOp = builder.create<vector::TransferReadOp>(
- loc,
- /*vectorType=*/vectorType,
- /*source=*/source,
- /*indices=*/SmallVector<Value>(readRank, zero),
- /*padding=*/padValue,
- /*inBounds=*/SmallVector<bool>(readRank, true));
- }
+ auto transferReadOp = builder.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/source,
+ /*indices=*/SmallVector<Value>(readRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(readRank, true));
if (llvm::equal(readShape, sourceShape)) {
return transferReadOp;
}
@@ -1588,21 +1576,32 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
RankedTensorType unpackTensorType = unpackOp.getSourceType();
SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
- llvm::ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
- llvm::ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
+ ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
readMaskShape[i] = inputVectorSizes[i];
}
+
+ // readMaskShape is the shape of the vector used to read and apply the
+ // mask. It is computed as follows. Let the vectorSizes (VS) array have
+ // size 'N', the sourceShape (SS) have size 'M', where M >= N, and the
+ // innerTileSizes (IT) have size M-N.
+ // Thus:
+ // ReadMaskShape (initial) = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
+ // Then divide all the readMaskShape locations pointed by innerDimPos
+ // by the innerTileSize attribute value.
+ // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
+ // innerDimPos = [0, 1], innerTiles = [32, 16] and vector_sizes are
+ // [512, 128]; then the read shape is:
+ // ReadMaskShape(initial): [8, 8, 32, 16]
+ // After setting vectorSizes: [512, 128, 32, 16]
+ // Final Value(after innerDim Adjustment): [512/32, 128/16, 32, 16]
+ // = [16, 8, 32, 16]
for (auto [index, size] : enumerate(innerTiles)) {
readMaskShape[innerDimPos[index]] =
llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
}
- // ReadMask is the size of tensor used to read and apply mask. It is
- // set like this. Let's say the vectorSize (VS) array is size 'N' and
- // the sourceShape(SS) is 'M' where M >= N
- // Thus:
- // ReadMaskShape = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
@@ -1613,11 +1612,14 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
}
Location loc = unpackOp->getLoc();
- // Read result, mask if necessary.
+ auto padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
+
+ // Read result, mask if necessary. If transferReadOp shape is not equal
+ // to shape of source, then a mask is necessary.
Value readResult = createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(),
- llvm::ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()),
- nullptr);
+ ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
@@ -1627,9 +1629,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
RankedTensorType stripMineTensorType =
- RankedTensorType::Builder(stripMineShape, stripMineElemType, {})
- .setShape(stripMineShape);
-
+ RankedTensorType::get(stripMineShape, stripMineElemType);
// Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
loc, readResult, lastDimToInsertPosPerm);
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0c8a76d5231f02..757cc46093daf9 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -681,12 +681,12 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %arg0, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
// CHECK: %[[C01:.*]] = arith.constant 0
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C02:.*]] = arith.constant 0
// CHECK: %[[DIM4:.*]] = tensor.dim %arg1, %[[C02]] : tensor<?x?x16x2xf32>
-// CHECK: %[[CNST15:.*]] = arith.constant 1
-// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST15]] : tensor<?x?x16x2xf32>
+// CHECK: %[[CNST14:.*]] = arith.constant 1
+// CHECK: %[[DIM6:.*]] = tensor.dim %arg1, %[[CNST14]] : tensor<?x?x16x2xf32>
// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
@@ -703,7 +703,7 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op
transform.yield
}
}
@@ -712,13 +712,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @test_vectorize_unpack
func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
- // CHECK: %[[C0:.*]]= arith.constant 0 : index
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]]= arith.constant 0 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C80:.*]] = arith.constant 8 : index
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
- // CHECK: %[[MSK0:.*]] = vector.create_mask %c8, %c8_0, %c32, %c16 : vector<16x8x32x16xi1>
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<16x8x32x16xi1>
// CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
// CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
@@ -744,8 +744,8 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
// CHECK-LABEL: func @test_vectorize_unpack_no_masks
func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
@@ -761,4 +761,5 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
transform.yield
- } }
\ No newline at end of file
+ }
+}
>From c7ed75e39f79fdf8c4de880c7ea8d1800be347d4 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 16 Feb 2024 01:45:50 +0000
Subject: [PATCH 09/12] Added outer_dims_perm support to unpack.
---
.../Linalg/Transforms/Vectorization.cpp | 41 ++++++++++---------
mlir/test/Dialect/Linalg/vectorization.mlir | 32 +++++++++++++++
2 files changed, 53 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f57fae3baa9e6b..0aa43b6c863e28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1575,28 +1575,37 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
RankedTensorType unpackTensorType = unpackOp.getSourceType();
- SmallVector<int64_t> readMaskShape(unpackTensorType.getShape());
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
- for (unsigned int i = 0; i < inputVectorSizes.size(); i++) {
- readMaskShape[i] = inputVectorSizes[i];
+
+ SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
+ inputVectorSizes.end());
+ ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+ if (outerDimsPerm.empty() == false) {
+ applyPermutationToVector(readMaskShape, outerDimsPerm);
}
+ ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+ readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+ sourceShape.end());
// ReadMask is the size of tensor used to read and apply mask. It is
- // set like this. Let's say the vectorSize (VS) array is size 'N' and
+ // set like this: Let's say the vectorSize (VS) array is size 'N' and
// the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
// size M-N
// Thus:
- // ReadMaskShape (initial) = [VS[0], ..., VS[N-1], SS[N], ..., SS[M-1]]
- // Then divide all the readMaskShape locations pointed by innerDimPos
- // by the innerTileSize attribute value.
+ // - initially: ReadMaskShape = vectorInputSizes
+ // - if outer_dims_perms is present: do that permutation on readMaskShape.
+ // - Append the remaining shape from SS
+ // - Divide all teh readMaskShape locations pointed by innerDimPos
+ // by the innerTileSize attribute value.
// E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
- // 128] then read shape is:
- // ReadMaskShape(initial): [8, 8, 32, 16]
- // After settin vectorSizes: [512, 128, 32, 16]
- // Final Value(after innerDim Adjustment): [512/32, 128/16, 32, 16]
- // = [16, 8, 32, 16]
+ // 128] and outer_dims_perm is [1, 0] then read shape is:
+ // ReadMaskShape(initial): [512, 128]
+ // After applying outer_dims_perm: [128, 512]
+ // After appending the rest of the sourceShape: [128, 512, 32, 16]
+ // Final Value(after innerDim Adjustment): [128/32, 512/16, 32, 16]
+ // = [4, 32, 32, 16]
for (auto [index, size] : enumerate(innerTiles)) {
readMaskShape[innerDimPos[index]] =
llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
@@ -1756,14 +1765,6 @@ static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
- // Handling this case requires a bit more change. Right now
- // just the required attributes are handled.
- // TODO: Handle OuterDimsPerm.
- if (!unpackOp.getOuterDimsPerm().empty()) {
- LDBG("outer dimensions perms NYI for: " << unpackOp);
- return failure();
- }
-
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 757cc46093daf9..3d37c657740055 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -762,4 +762,36 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
transform.yield
}
+ }
+
+ // -----
+
+ // This test is same as the one test_vectorize_unpack_no_masks but with outer_dims_perm.
+ // Note that adding this attribute causes a read mask.
+
+ // CHECK-LABEL: test_vectorize_unpack_with_outer_perm
+ func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<4x16x32x16xi1>
+ // CHECK: %[[READ:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<4x16x32x16xi1> -> vector<4x16x32x16xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<4x16x32x16xf32> to vector<4x32x16x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x32x16x16xf32> to vector<128x256xf32>
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+ // CHECK: %[[C00:.*]] = arith.constant 0 : index
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<128x256xf32>, tensor<256x128xf32>
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+ %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+ return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
+ transform.yield
+ }
}
>From a349b1446ee399dca59164ffef4dfc130bb1202f Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 16 Feb 2024 17:37:58 +0000
Subject: [PATCH 10/12] Fixed all the issues mentioned by Diego on 2/16.
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 18 ++--
.../Dialect/Linalg/Transforms/Transforms.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 13 ++-
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 85 ++++++++++---------
mlir/test/Dialect/Linalg/vectorization.mlir | 4 +-
5 files changed, 68 insertions(+), 54 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 8c8107e0507d70..009702f126eaf3 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -38,14 +38,22 @@ computeTransposedType(RankedTensorType rankedTensorType,
/// i.e. for a pack from an ABCD layout to an ABCDba:
/// The packed shape would be ABCDba.
/// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> getPackUnPackInverseDestPerm(
- std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
+SmallVector<int64_t> computePackUnPackPerm(int64_t rank,
+ ArrayRef<int64_t> &innerDimsPos,
+ ArrayRef<int64_t> &outerPerm,
+ PackingMetadata &packingMetadata);
+
+/// This function uses the helper function `computePackUnPackPerm` to get
+/// the permutation vector. Only major difference between UnPack and Pack is
+/// that packOp uses destination rank whereas unpack Uses source rank.
+SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
/// Unpack requires some packing metadata data, so create another
/// function where this value is passed by reference.
-SmallVector<int64_t> getPackUnPackInverseDestPerm(
- std::variant<tensor::PackOp, tensor::UnPackOp> packOp,
- PackingMetadata &PackingMetadata);
+SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
+ PackingMetadata &metadata);
+
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 9f8ea7f1f3969b..850cb861672ad6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> packedToStripMinedShapePerm =
- tensor::getPackUnPackInverseDestPerm(packOp);
+ tensor::getPackInverseDestPerm(packOp);
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0aa43b6c863e28..f066967c4a9097 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1405,8 +1405,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
- return applyPermutation(destShape,
- tensor::getPackUnPackInverseDestPerm(packOp));
+ return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
}
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
@@ -1547,7 +1546,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
// Create TransposeOp.
auto destPermutation =
- invertPermutationVector(tensor::getPackUnPackInverseDestPerm(packOp));
+ invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);
@@ -1559,7 +1558,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}
-/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
+/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
/// Vector::TransferReadOp - Reads a vector from the source tensor
/// vector::TransposeOp - Transpose the Source tensor
/// ShapeCastOp - Reshape the data based on the target.
@@ -1581,7 +1580,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
inputVectorSizes.end());
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
- if (outerDimsPerm.empty() == false) {
+ if (!outerDimsPerm.empty()) {
applyPermutationToVector(readMaskShape, outerDimsPerm);
}
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
@@ -1632,7 +1631,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
- tensor::getPackUnPackInverseDestPerm(unpackOp, packMetadata));
+ tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata));
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
@@ -1772,7 +1771,7 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
return failure();
}
llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
- if (inputVectorSizes.empty() == false &&
+ if (!inputVectorSizes.empty() &&
failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
return failure();
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 0902e33a1f19fd..f1126aaf44c76c 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -72,37 +72,15 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
RTTBuilder(rankedTensorType).setShape(transposedShape);
return transposedTensorType;
}
-
-SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
- std::variant<tensor::PackOp, tensor::UnPackOp> op) {
- PackingMetadata pMetaData;
- return getPackUnPackInverseDestPerm(op, pMetaData);
-}
-
-SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
- std::variant<tensor::PackOp, tensor::UnPackOp> op,
+/// The permutation can be obtained from two permutations:
+/// a) Compute the permutation vector to move the last `numPackedDims` into
+/// the `innerPosDims` of a shape of rank `rank`.
+/// b) Compute the permutation vector to move outer dims if the
+/// `outerPerm` parameter is not empty.
+/// Apply (b) permutation on (a) permutation to get the final permutation.
+SmallVector<int64_t> mlir::tensor::computePackUnPackPerm(
+ int64_t rank, ArrayRef<int64_t> &innerDimsPos, ArrayRef<int64_t> &outerPerm,
PackingMetadata &packingMetadata) {
-
- llvm::ArrayRef<int64_t> innerDimsPos, outerPerm;
- int64_t rank = 0;
- bool isPackOp = std::holds_alternative<tensor::PackOp>(op);
- if (isPackOp) {
- tensor::PackOp packOp = std::get<tensor::PackOp>(op);
- innerDimsPos = packOp.getInnerDimsPos();
- rank = packOp.getDestType().getRank();
- outerPerm = packOp.getOuterDimsPerm();
- } else {
- tensor::UnPackOp unpackOp = std::get<tensor::UnPackOp>(op);
- innerDimsPos = unpackOp.getInnerDimsPos();
- rank = unpackOp.getSourceType().getRank();
- outerPerm = unpackOp.getOuterDimsPerm();
- }
- // The permutation can be obtained from two permutations:
- // a) Compute the permutation vector to move the last `numPackedDims` into
- // the `innerPosDims` of a shape of rank `packedRank`.
- // b) Compute the permutation vector to move outer dims if the pack op
- // has outer_dims_perm.
- // Apply (b) permutation on (a) permutation to get the final permutation.
int64_t numPackedDims = innerDimsPos.size();
auto lastDims =
llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
@@ -110,15 +88,44 @@ SmallVector<int64_t> mlir::tensor::getPackUnPackInverseDestPerm(
SmallVector<int64_t> innerPositionsPerm =
computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
- if (isPackOp) {
- SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
- if (!outerPerm.empty())
- applyPermutationToVector(outerPos, outerPerm);
- SmallVector<int64_t> outerPositionPerm = computePermutationVector(
- rank, packingMetadata.outerPositions, outerPos);
- applyPermutationToVector(innerPositionsPerm, outerPositionPerm);
- }
- return innerPositionsPerm;
+ SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+ if (!outerPerm.empty())
+ applyPermutationToVector(outerPos, outerPerm);
+ SmallVector<int64_t> outerPositionPerm =
+ computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
+
+ SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
+ applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
+ return packInverseDestPermutation;
+}
+
+/// Shell function to compute the Destination Permutation of PackOp
+SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
+
+ PackingMetadata pMetadata;
+ int64_t packedRank = packOp.getDestType().getRank();
+ ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
+ ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+ SmallVector<int64_t> packInvDestPerm = mlir::tensor::computePackUnPackPerm(
+ packedRank, innerDimPos, outerPerm, pMetadata);
+ return packInvDestPerm;
+}
+
+SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
+ PackingMetadata metadata;
+ return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
+}
+
+/// Shell function to compute the Source rank permutation for unpackOp
+SmallVector<int64_t>
+mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
+ PackingMetadata &metadata) {
+ int64_t unpackRank = unpackOp.getSourceType().getRank();
+ ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
+ ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
+ SmallVector<int64_t> unpackInvSrcPerm = mlir::tensor::computePackUnPackPerm(
+ unpackRank, innerDimPos, outerPerm, metadata);
+ return unpackInvSrcPerm;
}
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3d37c657740055..36106312be4f95 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -779,8 +779,8 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<4x16x32x16xi1>
// CHECK: %[[READ:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<4x16x32x16xi1> -> vector<4x16x32x16xf32>
- // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<4x16x32x16xf32> to vector<4x32x16x16xf32>
- // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x32x16x16xf32> to vector<128x256xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [2, 0, 1, 3] : vector<4x16x32x16xf32> to vector<32x4x16x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<32x4x16x16xf32> to vector<128x256xf32>
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<128x256xf32>, tensor<256x128xf32>
>From e8e0d88d33dbdcfb1b26838c6d68e49070c0e1f3 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Sat, 17 Feb 2024 05:15:55 +0000
Subject: [PATCH 11/12] Added all the comment changes requested by Diego.
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 16 ------------
.../Linalg/Transforms/Vectorization.cpp | 2 +-
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 25 +++++++++++++------
3 files changed, 19 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 009702f126eaf3..d09c9e36f6ff88 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,25 +32,9 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);
-/// Given a tensor::PackOp, compute the permutation vector to shuffle the
-/// packed shape into the shape before any outer or inner permutations have
-/// been applied.
-/// i.e. for a pack from an ABCD layout to an ABCDba:
-/// The packed shape would be ABCDba.
-/// The pre-permutation shape would be AaBbCD.
-SmallVector<int64_t> computePackUnPackPerm(int64_t rank,
- ArrayRef<int64_t> &innerDimsPos,
- ArrayRef<int64_t> &outerPerm,
- PackingMetadata &packingMetadata);
-
-/// This function uses the helper function `computePackUnPackPerm` to get
-/// the permutation vector. Only major difference between UnPack and Pack is
-/// that packOp uses destination rank whereas unpack Uses source rank.
SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp);
-/// Unpack requires some packing metadata data, so create another
-/// function where this value is passed by reference.
SmallVector<int64_t> getUnPackInverseSrcPerm(tensor::UnPackOp,
PackingMetadata &metadata);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f066967c4a9097..a8b64eb149ed63 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1595,7 +1595,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
// - initially: ReadMaskShape = vectorInputSizes
// - if outer_dims_perms is present: do that permutation on readMaskShape.
// - Append the remaining shape from SS
- // - Divide all teh readMaskShape locations pointed by innerDimPos
+ // - Divide all the readMaskShape locations pointed by innerDimPos
// by the innerTileSize attribute value.
// E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index f1126aaf44c76c..186f85d2ce20a6 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -78,9 +78,10 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
/// b) Compute the permutation vector to move outer dims if the
/// `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
-SmallVector<int64_t> mlir::tensor::computePackUnPackPerm(
- int64_t rank, ArrayRef<int64_t> &innerDimsPos, ArrayRef<int64_t> &outerPerm,
- PackingMetadata &packingMetadata) {
+static SmallVector<int64_t>
+computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
+ ArrayRef<int64_t> &outerPerm,
+ PackingMetadata &packingMetadata) {
int64_t numPackedDims = innerDimsPos.size();
auto lastDims =
llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
@@ -100,31 +101,41 @@ SmallVector<int64_t> mlir::tensor::computePackUnPackPerm(
}
/// Shell function to compute the Destination Permutation of PackOp
+/// This function uses the helper function `computePackUnPackPerm` to get
+/// the permutation vector. The only major difference between UnPack and Pack
+/// is that pack uses the destination rank whereas unpack uses the source rank.
SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
PackingMetadata pMetadata;
int64_t packedRank = packOp.getDestType().getRank();
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
- SmallVector<int64_t> packInvDestPerm = mlir::tensor::computePackUnPackPerm(
- packedRank, innerDimPos, outerPerm, pMetadata);
+ SmallVector<int64_t> packInvDestPerm =
+ computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
return packInvDestPerm;
}
+/// Shell function to compute the Source Permutation of unPackOp.
+/// This function, like `getPackInverseDestPerm`, uses the helper function
+/// `computePackUnPackPerm` to get the permutation vector.
+/// The only major difference between UnPack and Pack is that pack uses the
+/// destination rank whereas unpack uses the source rank.
SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
PackingMetadata metadata;
return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
}
/// Shell function to compute the Source rank permutation for unpackOp
+/// Unpack requires some packing metadata information, so this overload
+/// exists to let the caller pass that value by reference.
SmallVector<int64_t>
mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
PackingMetadata &metadata) {
int64_t unpackRank = unpackOp.getSourceType().getRank();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
- SmallVector<int64_t> unpackInvSrcPerm = mlir::tensor::computePackUnPackPerm(
- unpackRank, innerDimPos, outerPerm, metadata);
+ SmallVector<int64_t> unpackInvSrcPerm =
+ computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
return unpackInvSrcPerm;
}
>From 524c0d97228d94f9c994bc12754850b3a6c641d6 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Tue, 20 Feb 2024 21:46:02 +0000
Subject: [PATCH 12/12] Fixed all the issues mentioned by Max on 2/20.
---
.../Linalg/Transforms/Vectorization.cpp | 39 ++++++++++---------
mlir/test/Dialect/Linalg/vectorization.mlir | 22 ++++-------
2 files changed, 27 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a8b64eb149ed63..ac043e87223dfe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1564,10 +1564,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
/// ShapeCastOp - Reshape the data based on the target.
/// vector::TransferWriteOp. - Write the result vector back to the destination
/// tensor
-static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
- tensor::UnPackOp unpackOp,
- ArrayRef<int64_t> inputVectorSizes,
- SmallVectorImpl<Value> &newResults) {
+static LogicalResult
+vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
@@ -1580,12 +1580,7 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
inputVectorSizes.end());
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
- if (!outerDimsPerm.empty()) {
- applyPermutationToVector(readMaskShape, outerDimsPerm);
- }
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
- readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
- sourceShape.end());
// ReadMask is the size of tensor used to read and apply mask. It is
// set like this: Let's say the vectorSize (VS) array is size 'N' and
@@ -1593,22 +1588,28 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
// size M-N
// Thus:
// - initially: ReadMaskShape = vectorInputSizes
- // - if outer_dims_perms is present: do that permutation on readMaskShape.
- // - Append the remaining shape from SS
// - Divide all the readMaskShape locations pointed by innerDimPos
// by the innerTileSize attribute value.
+ // - If outer_dims_perm is present: apply that permutation to readMaskShape.
+ // - Append the remaining shape from SS
// E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
// 128] and outer_dims_perm is [1, 0] then read shape is:
// ReadMaskShape(initial): [512, 128]
- // After applying outer_dims_perm: [128, 512]
- // After appending the rest of the sourceShape: [128, 512, 32, 16]
- // Final Value(after innerDim Adjustment): [128/32, 512/16, 32, 16]
- // = [4, 32, 32, 16]
+ // After innerDim adjustment: [512/32, 128/16]
+ // = [16, 8]
+ // After applying outer_dims_perm: [8, 16]
+ // After appending the rest of the sourceShape: [8, 16, 32, 16]
+
for (auto [index, size] : enumerate(innerTiles)) {
readMaskShape[innerDimPos[index]] =
llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
}
+ if (!outerDimsPerm.empty()) {
+ applyPermutationToVector(readMaskShape, outerDimsPerm);
+ }
+ readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
+ sourceShape.end());
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
@@ -1630,8 +1631,8 @@ static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
PackingMetadata packMetadata;
- SmallVector<int64_t> lastDimToInsertPosPerm = invertPermutationVector(
- tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata));
+ SmallVector<int64_t> lastDimToInsertPosPerm =
+ tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
@@ -2031,8 +2032,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
results);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
- return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
- results);
+ return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
+ inputVectorSizes, results);
})
.Default([](auto) { return failure(); });
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 36106312be4f95..64f9439d6fe3a8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -691,10 +691,10 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
-// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
-// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x2x1xf32> to vector<32x2xf32>
+// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
+// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
// CHECK: %[[empt0:.*]] = tensor.empty
-// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<32x2xi1>
+// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
// CHECK: return %[[write0]]
%ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
@@ -766,24 +766,16 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
// -----
- // This test is same as the one test_vectorize_unpack_no_masks but with outer_dims_perm.
- // Note that adding this attribute causes a read mask.
-
// CHECK-LABEL: test_vectorize_unpack_with_outer_perm
func.func @test_vectorize_unpack_with_outer_perm(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C8:.*]] = arith.constant 8 : index
- // CHECK: %[[C80:.*]] = arith.constant 8 : index
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[C16:.*]] = arith.constant 16 : index
- // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<4x16x32x16xi1>
- // CHECK: %[[READ:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<4x16x32x16xi1> -> vector<4x16x32x16xf32>
- // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [2, 0, 1, 3] : vector<4x16x32x16xf32> to vector<32x4x16x16xf32>
- // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<32x4x16x16xf32> to vector<128x256xf32>
+ // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
- // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<128x256xf32>, tensor<256x128xf32>
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
%0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
More information about the Mlir-commits
mailing list