[Mlir-commits] [mlir] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack (PR #189049)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 27 09:27:08 PDT 2026
https://github.com/fabrizio-indirli created https://github.com/llvm/llvm-project/pull/189049
- In linalg.pack, add the optional `extra_pad_tiles` attribute to append a chosen number of additional full tiles of high-padding to each tiled dimension.
- In linalg.unpack, add the dual optional attribute `drop_last_tiles` to drop a chosen number of full outer tiles for each tiled dimension before reconstructing the unpacked result.
From 557a4cb56be485b202f117ddf91f9a812ee0b2a8 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <fabrizio.indirli at arm.com>
Date: Wed, 11 Mar 2026 17:39:01 +0000
Subject: [PATCH] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack
- In linalg.pack, add the optional `extra_pad_tiles` attribute
to append a chosen number of additional full tiles of high-padding
to each tiled dimension.
- In linalg.unpack, add the dual optional attribute `drop_last_tiles` to
drop a chosen number of full outer tiles for each tiled dimension before
reconstructing the unpacked result.
Signed-off-by: Fabrizio Indirli <fabrizio.indirli at arm.com>
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 53 ++--
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 255 +++++++++++++++---
.../Transforms/DataLayoutPropagation.cpp | 38 +--
.../Transforms/PackAndUnpackPatterns.cpp | 63 +++--
.../Dialect/Linalg/Transforms/Transforms.cpp | 8 +-
mlir/python/mlir/dialects/linalg/__init__.py | 4 +
mlir/test/Dialect/Linalg/canonicalize.mlir | 26 ++
mlir/test/Dialect/Linalg/invalid.mlir | 51 ++++
mlir/test/Dialect/Linalg/roundtrip.mlir | 13 +
.../Dialect/Linalg/transform-lower-pack.mlir | 49 ++++
mlir/test/python/dialects/linalg/ops.py | 10 +-
11 files changed, 474 insertions(+), 96 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..cdecb41db3123 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -164,13 +164,16 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
tiles divide perfectly the corresponding outer dimension in the result
tensor. It is UB if the tile does not perfectly divide the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
- tile complete. Note that it is not allowed to have artificial padding that
- is not strictly required by linalg.pack (i.e., padding past what is needed
- to complete the last tile along each packed dimension). It is UB if extra
- padding is requested.
+ tile complete.
It is not possible to verify the requirements statically with dynamic
shapes, so they are treated as UB.
+ `extra_pad_tiles` (optional) appends additional full tiles of high-padding:
+ `extra_pad_tiles[i]` extra tiles are added at the end of tiled dimension
+ `inner_dims_pos[i]`. It is indexed in the same order as `inner_dims_pos` /
+ `inner_tiles`, must have the same length, and defaults to all zeros when
+ omitted. A `padding_value` is required whenever any entry is non-zero.
+
Example:
```mlir
%0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
@@ -200,7 +203,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
- DenseI64ArrayAttr:$static_inner_tiles);
+ DenseI64ArrayAttr:$static_inner_tiles,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$extra_pad_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
let builders = [
@@ -208,7 +212,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
"ArrayRef<int64_t>":$innerDimsPos,
"ArrayRef<OpFoldResult>":$innerTiles,
CArg<"std::optional<Value>", "std::nullopt">:$paddingValue,
- CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+ CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm,
+ CArg<"ArrayRef<int64_t>", "{}">:$extraPadTiles)>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
@@ -218,28 +223,32 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Method to get the `MemRefType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Returns the shape of the packed type. It is a shared helper that helps
// type inference methods in a way that ensures that they agree on which
// dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Returns true if we have enough static information to catch undefined
// behavior when the tile size does not divide perfectly the dimension of
@@ -249,7 +258,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles);
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles = {});
// Same as above function but here dynamic dimensions are assumed
// to require padding.
@@ -257,11 +267,13 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles);
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles = {});
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles = {});
/// Build and return a new PackOp that is a clone of the current PackOp with
/// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
@@ -325,6 +337,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
dimensions. If specified, it must have `n - k` elements. If specified, this
permutation is applied before combining any dimensions.
+ `drop_last_tiles` (optional) specifies how many full packed outer tiles to
+ drop for each tiled dimension before reconstructing the unpacked result. It
+ is indexed in the same order as `inner_dims_pos` / `inner_tiles`, must have
+ the same length, and defaults to all zeros when omitted. This can be used
+ to drop extra pad tiles added by a previous `pack` with `extra_pad_tiles`.
+
Note, the unpack operation may drop any padding introduced by the pack
operation and hence the following holds
`NumElementsOf(source) >= NumElementsOf(result)`.
@@ -362,20 +380,23 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
TensorOrMemRef<[AnyType]>:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
- DenseI64ArrayAttr:$static_inner_tiles);
+ DenseI64ArrayAttr:$static_inner_tiles,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$drop_last_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<int64_t>":$innerDimsPos,
"ArrayRef<OpFoldResult>":$innerTiles,
- CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+ CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm,
+ CArg<"ArrayRef<int64_t>", "{}">:$dropLastTiles)>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> dropLastTiles = {});
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
/// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9698365765e7..8223ab78d1ef5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5104,6 +5104,51 @@ static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
return staticTiles;
}
+static inline bool isEmptyOrZeroArray(ArrayRef<int64_t> values) {
+ return values.empty() ||
+ llvm::all_of(values, [](int64_t value) { return value == 0; });
+}
+
+static ArrayRef<int64_t> getExtraTrailingTiles(PackOp op) {
+ return op.getExtraPadTiles();
+}
+
+static ArrayRef<int64_t> getExtraTrailingTiles(UnPackOp op) {
+ return op.getDropLastTiles();
+}
+
+static void addExtraTrailingTiles(SmallVectorImpl<int64_t> &outerDims,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> adjustments) {
+ if (adjustments.empty())
+ return;
+ for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+ if (ShapedType::isDynamic(outerDims[dimPos]))
+ continue;
+ outerDims[dimPos] += adjustments[index];
+ }
+}
+
+static void addExtraTrailingTiles(OpBuilder &builder, Location loc,
+ SmallVectorImpl<OpFoldResult> &outerDims,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> adjustments) {
+ if (adjustments.empty())
+ return;
+ AffineExpr d0, c0;
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), c0);
+ auto addConstant = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, d0 + c0,
+ builder.getContext());
+ for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+ if (adjustments[index] == 0)
+ continue;
+ outerDims[dimPos] = affine::makeComposedFoldedAffineApply(
+ builder, loc, addConstant,
+ {outerDims[dimPos], builder.getIndexAttr(adjustments[index])});
+ }
+}
+
/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicate.
/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
@@ -5161,12 +5206,26 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
+ ArrayRef<int64_t> extraPadTiles = getExtraTrailingTiles(packOrUnPack);
if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
return op->emitError("invalid inner_dims_pos vector");
if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
return op->emitError("invalid outer_dims_perm vector");
if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
return op->emitError("outer_dims_perm must be a permutation or empty");
+ if (!extraPadTiles.empty() && extraPadTiles.size() != innerDimsPos.size()) {
+ return op->emitError() << (std::is_same<OpTy, PackOp>::value
+ ? "extra_pad_tiles"
+ : "drop_last_tiles")
+ << " must have the same number of entries as "
+ "inner_dims_pos";
+ }
+ if (llvm::any_of(extraPadTiles, [](int64_t value) { return value < 0; })) {
+ return op->emitError() << (std::is_same<OpTy, PackOp>::value
+ ? "extra_pad_tiles"
+ : "drop_last_tiles")
+ << " must contain only non-negative values";
+ }
// Tiling factors must be less than or equal to the input rank for pack (or
// output rank for unpack), and must match the number of `inner_dims_pos`.
@@ -5196,7 +5255,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
unpackedType.getShape(), packOrUnPack.getStaticTiles(),
- packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+ packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm(),
+ extraPadTiles);
for (auto it : llvm::enumerate(llvm::zip(
packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
int64_t dimSize = std::get<0>(it.value());
@@ -5244,6 +5304,7 @@ struct PackOrUnPackTransposeResult {
SmallVector<int64_t> innerDimsPos;
SmallVector<OpFoldResult> innerTiles;
SmallVector<int64_t> outerDimsPerm;
+ SmallVector<int64_t> trailingTileAdjustments;
};
} // namespace
@@ -5261,6 +5322,8 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
metadata.innerTiles =
SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
+ metadata.trailingTileAdjustments =
+ llvm::to_vector(getExtraTrailingTiles(packOrUnPackOp));
int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
? packOrUnPackOp.getSourceRank()
: packOrUnPackOp.getDestRank();
@@ -5274,6 +5337,9 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
"invalid inner permutation");
applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
applyPermutationToVector(metadata.innerTiles, innerPermutation);
+ if (!metadata.trailingTileAdjustments.empty())
+ applyPermutationToVector(metadata.trailingTileAdjustments,
+ innerPermutation);
}
if (!outerPermutation.empty()) {
assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
@@ -5299,7 +5365,7 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
SmallVector<Type> paddingValueType;
SmallVector<int64_t> staticTiles;
- DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+ DenseI64ArrayAttr innerDimsPos, outerDimsPerm, extraPadTiles;
Type sourceType, destType, resultType;
if (parser.parseOperand(source))
@@ -5352,6 +5418,24 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
for (auto val : staticTilesAttr.asArrayRef())
staticTiles.push_back(val);
+ if (succeeded(parser.parseOptionalKeyword("extra_pad_tiles"))) {
+ if (parser.parseEqual())
+ return failure();
+ SmallVector<int64_t> extraPadTilesVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ extraPadTilesVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ if (!isEmptyOrZeroArray(extraPadTilesVec)) {
+ extraPadTiles =
+ parser.getBuilder().getDenseI64ArrayAttr(extraPadTilesVec);
+ }
+ }
+
if (parser.parseKeyword("into") || parser.parseOperand(dest))
return failure();
@@ -5395,6 +5479,8 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute("inner_dims_pos", innerDimsPos);
if (outerDimsPerm)
result.addAttribute("outer_dims_perm", outerDimsPerm);
+ if (extraPadTiles)
+ result.addAttribute("extra_pad_tiles", extraPadTiles);
SmallVector<int32_t> segmentSizes = {
1, 1, static_cast<int32_t>(paddingValue.size()),
@@ -5429,11 +5515,18 @@ void PackOp::print(OpAsmPrinter &p) {
p << " inner_tiles = ";
printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+ if (!isEmptyOrZeroArray(getExtraPadTiles())) {
+ p << " extra_pad_tiles = [";
+ llvm::interleaveComma(getExtraPadTiles(), p);
+ p << "]";
+ }
+
p << " into " << getDest();
p.printOptionalAttrDict((*this)->getAttrs(),
{"static_inner_tiles", "inner_dims_pos",
- "outer_dims_perm", "operandSegmentSizes"});
+ "outer_dims_perm", "extra_pad_tiles",
+ "operandSegmentSizes"});
p << " : " << getSource().getType();
p << " -> " << getDest().getType();
@@ -5443,7 +5536,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
Value dest, ArrayRef<int64_t> innerDimsPos,
ArrayRef<OpFoldResult> innerTiles,
std::optional<Value> paddingValue,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
assert(innerDimsPos.size() == innerTiles.size() &&
"number of tile sizes specified must match the specified number of "
"original dimensions to be tiled");
@@ -5455,7 +5549,10 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
outerDimsPerm.empty() ? nullptr
: builder.getDenseI64ArrayAttr(outerDimsPerm),
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
- builder.getDenseI64ArrayAttr(staticTileSizes));
+ builder.getDenseI64ArrayAttr(staticTileSizes),
+ isEmptyOrZeroArray(extraPadTiles)
+ ? nullptr
+ : builder.getDenseI64ArrayAttr(extraPadTiles));
}
LogicalResult
@@ -5504,7 +5601,10 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles) {
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles) {
+ if (!isEmptyOrZeroArray(extraPadTiles))
+ return true;
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
@@ -5535,7 +5635,10 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles) {
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles) {
+ if (!isEmptyOrZeroArray(extraPadTiles))
+ return true;
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
@@ -5576,7 +5679,7 @@ LogicalResult PackOp::verify() {
if (!paddingValue &&
requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
getDestType().getShape(), getOuterDimsPerm(),
- getMixedTiles())) {
+ getMixedTiles(), getExtraPadTiles())) {
return emitOpError(
"invalid tile factor or output size provided. Only full tiles are "
"supported when padding_value is not set");
@@ -5602,7 +5705,8 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
@@ -5614,6 +5718,7 @@ SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
resultShape[tiledDim.value()] = llvm::divideCeilSigned(
resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
}
+ addExtraTrailingTiles(resultShape, innerDimsPos, extraPadTiles);
// Swap tile loops if outer_dims_perm is available.
if (!outerDimsPerm.empty())
@@ -5627,7 +5732,7 @@ SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
SmallVector<OpFoldResult> PackOp::getResultShape(
OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm, ArrayRef<int64_t> extraPadTiles) {
SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
AffineExpr s0, s1;
@@ -5638,6 +5743,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
builder, loc, ceilDivExpr,
{resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
}
+ addExtraTrailingTiles(builder, loc, resultDims, innerDimsPos, extraPadTiles);
if (!outerDimsPerm.empty())
applyPermutationToVector(resultDims, outerDimsPerm);
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
@@ -5645,7 +5751,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
SmallVector<int64_t> resultTypeShape =
inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
asShapeWithAnyValueAsDynamic(innerTileSizes),
- innerDimsPos, outerDimsPerm);
+ innerDimsPos, outerDimsPerm, extraPadTiles);
// Fix-up `resultDims` to ensure that they are Value's if and only if the
// result type shape says it's a dynamic dim. This is needed as callers may
@@ -5663,25 +5769,30 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
RankedTensorType PackOp::inferPackedTensorType(
RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
- SmallVector<int64_t> resultShape = inferPackedShape(
- sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
+ SmallVector<int64_t> resultShape =
+ inferPackedShape(sourceType.getShape(), innerTileSizes, innerDimsPos,
+ outerDimsPerm, extraPadTiles);
return RankedTensorType::get(resultShape, sourceType.getElementType());
}
MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
- SmallVector<int64_t> resultShape = inferPackedShape(
- sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
+ SmallVector<int64_t> resultShape =
+ inferPackedShape(sourceType.getShape(), innerTileSizes, innerDimsPos,
+ outerDimsPerm, extraPadTiles);
return MemRefType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
AffineExpr dim0, dim1;
bindDims(b.getContext(), dim0, dim1);
auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
@@ -5703,6 +5814,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
OpFoldResult tileSize = std::get<1>(it);
mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
}
+ addExtraTrailingTiles(b, loc, mixedSizes, innerDimsPos, extraPadTiles);
if (!outerDimsPerm.empty())
applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
@@ -5716,12 +5828,13 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
ArrayRef<int64_t> outerPermutation) {
PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
*this, innerPermutation, outerPermutation);
- Value transposedDest =
- createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
- metadata.innerDimsPos, metadata.outerDimsPerm);
+ Value transposedDest = createDestinationTensor(
+ b, loc, getSource(), metadata.innerTiles, metadata.innerDimsPos,
+ metadata.outerDimsPerm, metadata.trailingTileAdjustments);
return PackOp::create(b, loc, getSource(), transposedDest,
metadata.innerDimsPos, metadata.innerTiles,
- getPaddingValue(), metadata.outerDimsPerm);
+ getPaddingValue(), metadata.outerDimsPerm,
+ metadata.trailingTileAdjustments);
}
template <typename OpTy>
@@ -5810,6 +5923,10 @@ static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
isIdentityPermutation(unPackOp.getOuterDimsPerm());
}
+static bool hasSameTrailingTileAdjustments(PackOp packOp, UnPackOp unPackOp) {
+ return packOp.getExtraPadTiles() == unPackOp.getDropLastTiles();
+}
+
// Return true if pack and unpack have the same tiles.
// Same SSA values or same integer constants.
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
@@ -5834,7 +5951,7 @@ static bool paddingIsNotNeeded(PackOp op) {
return false;
return !PackOp::requirePaddingValue(
srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
- op.getOuterDimsPerm(), op.getMixedTiles());
+ op.getOuterDimsPerm(), op.getMixedTiles(), op.getExtraPadTiles());
}
/// Returns true if the `srcShape` or `destShape` is different from the one in
@@ -5883,6 +6000,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
if (unPackOp.getSourceType() == packOp.getDestType() &&
!packOp.getPaddingValue() &&
hasSameInnerOuterAttribute(packOp, unPackOp) &&
+ hasSameTrailingTileAdjustments(packOp, unPackOp) &&
haveSameTiles(packOp, unPackOp)) {
rewriter.replaceOp(packOp, unPackOp.getSource());
return success();
@@ -5938,6 +6056,8 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
+ if (!isEmptyOrZeroArray(getExtraTrailingTiles(packOp)))
+ return false;
// This is a pad if packing only adds ones and we don't transpose dimensions.
// Check that we are not transposing any dimensions.
@@ -6028,10 +6148,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
- PackOp newOp =
- PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
- op.getInnerDimsPos(), newMixedTileSizes,
- op.getPaddingValue(), op.getOuterDimsPerm());
+ PackOp newOp = PackOp::create(rewriter, op.getLoc(), newOperands[0],
+ newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getPaddingValue(),
+ op.getOuterDimsPerm(), op.getExtraPadTiles());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
// Replace op.
@@ -6064,7 +6184,7 @@ ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand source, dest;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
SmallVector<int64_t> staticTiles;
- DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+ DenseI64ArrayAttr innerDimsPos, outerDimsPerm, dropLastTiles;
Type sourceType, destType, resultType;
if (parser.parseOperand(source))
@@ -6109,6 +6229,24 @@ ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
for (auto val : staticTilesAttr.asArrayRef())
staticTiles.push_back(val);
+ if (succeeded(parser.parseOptionalKeyword("drop_last_tiles"))) {
+ if (parser.parseEqual())
+ return failure();
+ SmallVector<int64_t> dropLastTilesVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ dropLastTilesVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ if (!isEmptyOrZeroArray(dropLastTilesVec)) {
+ dropLastTiles =
+ parser.getBuilder().getDenseI64ArrayAttr(dropLastTilesVec);
+ }
+ }
+
if (parser.parseKeyword("into") || parser.parseOperand(dest))
return failure();
@@ -6147,6 +6285,8 @@ ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute("inner_dims_pos", innerDimsPos);
if (outerDimsPerm)
result.addAttribute("outer_dims_perm", outerDimsPerm);
+ if (dropLastTiles)
+ result.addAttribute("drop_last_tiles", dropLastTiles);
SmallVector<int32_t> segmentSizes = {
1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
@@ -6175,11 +6315,18 @@ void UnPackOp::print(OpAsmPrinter &p) {
p << " inner_tiles = ";
printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+ if (!isEmptyOrZeroArray(getDropLastTiles())) {
+ p << " drop_last_tiles = [";
+ llvm::interleaveComma(getDropLastTiles(), p);
+ p << "]";
+ }
+
p << " into " << getDest();
p.printOptionalAttrDict((*this)->getAttrs(),
{"static_inner_tiles", "inner_dims_pos",
- "outer_dims_perm", "operandSegmentSizes"});
+ "outer_dims_perm", "drop_last_tiles",
+ "operandSegmentSizes"});
p << " : " << getSource().getType();
p << " -> " << getDest().getType();
@@ -6244,7 +6391,8 @@ Speculation::Speculatability UnPackOp::getSpeculatability() {
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
Value dest, ArrayRef<int64_t> innerDimsPos,
ArrayRef<OpFoldResult> innerTiles,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> dropLastTiles) {
assert(innerDimsPos.size() == innerTiles.size() &&
"number of tile sizes specified must match the specified number of "
"original dimensions to be tiled");
@@ -6255,14 +6403,18 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
outerDimsPerm.empty() ? nullptr
: builder.getDenseI64ArrayAttr(outerDimsPerm),
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
- builder.getDenseI64ArrayAttr(staticTileSizes));
+ builder.getDenseI64ArrayAttr(staticTileSizes),
+ isEmptyOrZeroArray(dropLastTiles)
+ ? nullptr
+ : builder.getDenseI64ArrayAttr(dropLastTiles));
}
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
Value source,
ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> dropLastTiles) {
AffineExpr sym0, sym1;
bindSymbols(b.getContext(), sym0, sym1);
auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
@@ -6284,6 +6436,21 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
mixedSizes, invertPermutationVector(outerDimsPerm));
}
+ if (!dropLastTiles.empty()) {
+ AffineExpr d0, c0;
+ bindDims(b.getContext(), d0);
+ bindSymbols(b.getContext(), c0);
+ auto subConstant = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1,
+ d0 - c0, b.getContext());
+ for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+ if (dropLastTiles[index] == 0)
+ continue;
+ mixedSizes[dimPos] = affine::makeComposedFoldedAffineApply(
+ b, loc, subConstant,
+ {mixedSizes[dimPos], b.getIndexAttr(dropLastTiles[index])});
+ }
+ }
+
for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
@@ -6299,7 +6466,8 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
*this, innerPermutation, outerPermutation);
return UnPackOp::create(b, loc, transposedSource, getDest(),
metadata.innerDimsPos, metadata.innerTiles,
- metadata.outerDimsPerm);
+ metadata.outerDimsPerm,
+ metadata.trailingTileAdjustments);
}
/// Returns true if the `srcShape` or `destShape` is different from the one in
@@ -6350,6 +6518,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
if (packOp.getPaddingValue() ||
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
+ !hasSameTrailingTileAdjustments(packOp, unPackOp) ||
!haveSameTiles(packOp, unPackOp))
return failure();
rewriter.replaceOp(unPackOp, packOp.getSource());
@@ -6402,7 +6571,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}
UnPackOp newOp = UnPackOp::create(
rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
- unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
+ unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm(),
+ unPackOp.getDropLastTiles());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
unPackOp, unPackOp.getResult().getType(), newOp.getResult());
return success();
@@ -6422,8 +6592,9 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(*this);
SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
- for (auto [pos, tileSize] :
- llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+ for (auto [index, values] : llvm::enumerate(llvm::zip_equal(
+ this->getInnerDimsPos(), this->getStaticInnerTiles()))) {
+ auto [pos, tileSize] = values;
areOuterDimsTiled[pos] = true;
if (unpackedTypeAfterFold.isDynamicDim(pos))
return false;
@@ -6431,6 +6602,11 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
return false;
if (ShapedType::isDynamic(tileSize))
return false;
+ if (!getDropLastTiles().empty()) {
+ if (outerShapeWithoutTranspose[pos] < getDropLastTiles()[index])
+ return false;
+ outerShapeWithoutTranspose[pos] -= getDropLastTiles()[index];
+ }
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedTypeAfterFold.getDimSize(pos);
if (paddingSize >= tileSize)
@@ -6509,9 +6685,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
// this point. However, in practice, we use them for things that we'd like
// to preserve. Implement a better abstraction.
- UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
- newOperands[1], op.getInnerDimsPos(),
- newMixedTileSizes, op.getOuterDimsPerm());
+ UnPackOp newOp =
+ UnPackOp::create(rewriter, op.getLoc(), sourceTensor, newOperands[1],
+ op.getInnerDimsPos(), newMixedTileSizes,
+ op.getOuterDimsPerm(), op.getDropLastTiles());
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
// Replace op.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index d36ca43a6cbb3..c6dcbe62fc1d6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -365,6 +365,7 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
+ packOp.getExtraPadTiles() == unPackOp.getDropLastTiles() &&
llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
};
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
@@ -642,10 +643,10 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
auto empty = linalg::PackOp::createDestinationTensor(
rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
- outerDimsPerm);
+ outerDimsPerm, packOp.getExtraPadTiles());
auto sourcePack = linalg::PackOp::create(
rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
- /*padding=*/std::nullopt, outerDimsPerm);
+ /*padding=*/std::nullopt, outerDimsPerm, packOp.getExtraPadTiles());
// If we have `outer_dims_perms` we need to adjust the padded dimensions.
SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
@@ -668,10 +669,11 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
// replace the other uses.
if (!padOp->hasOneUse()) {
auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
- rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
- UnPackOp unpackedPad =
- linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
- innerDimsPos, mixedTiles, outerDimsPerm);
+ rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm,
+ packOp.getExtraPadTiles());
+ UnPackOp unpackedPad = linalg::UnPackOp::create(
+ rewriter, loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles,
+ outerDimsPerm, packOp.getExtraPadTiles());
rewriter.replaceAllUsesExcept(padOp, unpackedPad.getResult(), sourcePack);
}
@@ -803,11 +805,11 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
auto emptyOp = linalg::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
- projectedInnerDimsPos, newOuterDimsPerm);
+ projectedInnerDimsPos, newOuterDimsPerm, packOp.getExtraPadTiles());
auto newPackOp = linalg::PackOp::create(
rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
- newOuterDimsPerm);
+ newOuterDimsPerm, packOp.getExtraPadTiles());
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
// First apply the permutation on the reassociations of the outer dims.
@@ -933,7 +935,8 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
// dimensions or packing extending dimensions.
RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType(
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
- projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+ projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{},
+ packOp.getExtraPadTiles());
auto reassocExpand =
getReassociationIndicesForReshape(newPackType, packOp.getDestType());
if (!reassocExpand)
@@ -942,11 +945,12 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
Value destTensor = linalg::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
- projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+ projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{},
+ packOp.getExtraPadTiles());
PackOp packedVal = linalg::PackOp::create(
rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
- /*outerDimsPerm=*/SmallVector<int64_t>{});
+ /*outerDimsPerm=*/SmallVector<int64_t>{}, packOp.getExtraPadTiles());
Value newExpandOp = tensor::ExpandShapeOp::create(
rewriter, packOp.getLoc(), packOp.getDestType(), packedVal.getResult(),
@@ -1068,17 +1072,19 @@ static LogicalResult pushDownUnPackOpThroughExpandShape(
}
RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType(
- expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
+ expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm,
+ unPackOp.getDropLastTiles());
auto newExpandOp =
tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
unPackOp.getSource(), newReassocIndices);
auto emptyOp = linalg::UnPackOp::createDestinationTensor(
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
- projectedInnerDimsPos, newOuterDimsPerm);
+ projectedInnerDimsPos, newOuterDimsPerm, unPackOp.getDropLastTiles());
auto newUnPackOp = linalg::UnPackOp::create(
rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
- projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
+ projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm,
+ unPackOp.getDropLastTiles());
rewriter.replaceOp(expandOp, newUnPackOp);
return success();
@@ -1252,7 +1258,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value unPackOpRes =
linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
destPack.getSource(), innerDimsPos, mixedTiles,
- outerDimsPerm)
+ outerDimsPerm, destPack.getExtraPadTiles())
.getResult();
return std::make_tuple(newGenericOp, unPackOpRes);
@@ -1340,7 +1346,7 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
UnPackOp replacement = linalg::UnPackOp::create(
rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
- unpackOp.getMixedTiles(), outerDimsPerm);
+ unpackOp.getMixedTiles(), outerDimsPerm, unpackOp.getDropLastTiles());
rewriter.replaceOp(padOp, replacement);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 993eae62535c3..e513443687ffa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -180,6 +180,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
// TODO: Support Memref UnPackOp. Temporarily return failure.
if (!unpackOp.hasPureTensorSemantics())
return failure();
+ if (!unpackOp.getDropLastTiles().empty())
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no drop_last_tiles");
ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
@@ -236,9 +239,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
ShapedType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
- for (auto [pos, tileSize, high] :
- llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
- padOp.getMixedHighPad())) {
+ for (auto [index, tuple] : llvm::enumerate(llvm::zip_equal(
+ packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+ padOp.getMixedHighPad()))) {
+ auto [pos, tileSize, high] = tuple;
if (unpackedType.isDynamicDim(pos))
return failure();
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
@@ -248,17 +252,19 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
std::optional<int64_t> cstHigh = getConstantIntValue(high);
if (!cstHigh)
return failure();
- int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
- unpackedType.getDimSize(pos);
- // Do not fold the op if it requires artificial padding.
- if (paddingSize + cstHigh.value() >= tileSize)
+ int64_t originalDim = unpackedType.getDimSize(pos) - cstHigh.value();
+ int64_t baseOuter = llvm::divideCeilSigned(originalDim, tileSize);
+ int64_t adjustedOuter = baseOuter;
+ if (!packOp.getExtraPadTiles().empty())
+ adjustedOuter += packOp.getExtraPadTiles()[index];
+ if (adjustedOuter != outerShapeWithoutTranspose[pos])
return failure();
}
rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
- packOp.getMixedTiles(), constantPaddingValue,
- packOp.getOuterDimsPerm());
+ packOp.getMixedTiles(), constantPaddingValue, packOp.getOuterDimsPerm(),
+ packOp.getExtraPadTiles());
return success();
}
@@ -299,7 +305,8 @@ struct FoldUnpackWithExtractSliceOp
rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
rewriter.replaceOpWithNewOp<UnPackOp>(
sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
- unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
+ unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm(),
+ unpackOp.getDropLastTiles());
return success();
}
@@ -368,6 +375,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
+ SmallVector<int64_t> newExtraPadTilesVec;
int64_t srcRank = packOp.getSourceRank();
if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
@@ -382,15 +390,19 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
int64_t remappedPosition = transposePerm[i] - srcRank;
newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+ if (!packOp.getExtraPadTiles().empty())
+ newExtraPadTilesVec.push_back(
+ packOp.getExtraPadTiles()[remappedPosition]);
}
Value output = packOp.createDestinationTensor(
rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
- newInnerDimsPosVec, newOuterDimsPermVec);
+ newInnerDimsPosVec, newOuterDimsPermVec, newExtraPadTilesVec);
rewriter.replaceOpWithNewOp<PackOp>(
linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
- newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
+ newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec,
+ newExtraPadTilesVec);
return success();
}
@@ -432,6 +444,7 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
auto outerDimsPerm = packOp.getOuterDimsPerm();
auto innerDimsPos = packOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
+ SmallVector<int64_t> newExtraPadTilesVec;
SmallVector<int64_t> newOuterDimsPermVec =
llvm::to_vector(transposePermutation);
@@ -440,16 +453,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
- for (auto dim : innerDimsPos)
+ for (auto [index, dim] : llvm::enumerate(innerDimsPos)) {
newInnerDimsPosVec.push_back(transposePermutation[dim]);
+ if (!packOp.getExtraPadTiles().empty())
+ newExtraPadTilesVec.push_back(packOp.getExtraPadTiles()[index]);
+ }
Value output = packOp.createDestinationTensor(
rewriter, packOp.getLoc(), linalgOp->getOperand(0),
- packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
+ packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec,
+ newExtraPadTilesVec);
rewriter.replaceOpWithNewOp<PackOp>(
packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
- packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
+ packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec,
+ newExtraPadTilesVec);
return success();
}
@@ -492,13 +510,17 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
auto innerDimsPos = unPackOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
+ SmallVector<int64_t> newDropLastTilesVec;
SmallVector<int64_t> newOuterDimsPermVec =
invertPermutationVector(maybePerm.value());
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
- for (auto dim : innerDimsPos)
+ for (auto [index, dim] : llvm::enumerate(innerDimsPos)) {
newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
+ if (!unPackOp.getDropLastTiles().empty())
+ newDropLastTilesVec.push_back(unPackOp.getDropLastTiles()[index]);
+ }
if (!outerDimsPerm.empty())
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
@@ -506,7 +528,8 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
// Reuse the destination of the transpose op.
rewriter.replaceOpWithNewOp<UnPackOp>(
linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
- newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
+ newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec,
+ newDropLastTilesVec);
return success();
}
@@ -559,6 +582,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
SmallVector<int64_t> newOuterDimsPermVec;
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<OpFoldResult> newMixedInnerTilesVec;
+ SmallVector<int64_t> newDropLastTilesVec;
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
newOuterDimsPermVec, destRank))
return rewriter.notifyMatchFailure(
@@ -571,6 +595,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
int64_t remappedPosition = inverseTransposePerm[i] - destRank;
newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+ if (!unPackOp.getDropLastTiles().empty())
+ newDropLastTilesVec.push_back(
+ unPackOp.getDropLastTiles()[remappedPosition]);
}
auto elemType =
@@ -580,7 +607,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
rewriter.replaceOpWithNewOp<UnPackOp>(
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
- newMixedInnerTilesVec, newOuterDimsPermVec);
+ newMixedInnerTilesVec, newOuterDimsPermVec, newDropLastTilesVec);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 260e36fb47f04..74ecfb5ccc3e2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -549,8 +549,8 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
if (areConstantTiles && operandType.hasStaticShape() &&
!linalg::PackOp::requirePaddingValue(
operandType.getShape(), innerPos,
- cast<ShapedType>(dest.getType()).getShape(), {},
- innerPackSizes)) {
+ cast<ShapedType>(dest.getType()).getShape(), {}, innerPackSizes,
+ {})) {
packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
innerPos, innerPackSizes));
} else {
@@ -588,7 +588,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
// Build the symmetrical UnPackOp to the existing PackOp.
unPackOps.push_back(linalg::UnPackOp::create(
rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
- maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+ maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(),
+ maybePackedInit.getOuterDimsPerm(),
+ maybePackedInit.getExtraPadTiles()));
results.push_back(unPackOps.back().getResult());
}
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1f6d1a68fbbb8..f6414c934e2f5 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -299,6 +299,7 @@ def pack(
*,
padding_value=None,
outer_dims_perm=None,
+ extra_pad_tiles=None,
loc=None,
ip=None,
) -> ir.Value:
@@ -321,6 +322,7 @@ def pack(
static_inner_tiles=static_inner_tiles,
padding_value=padding_value,
outer_dims_perm=outer_dims_perm,
+ extra_pad_tiles=extra_pad_tiles,
loc=loc,
ip=ip,
)
@@ -334,6 +336,7 @@ def unpack(
inner_tiles,
*,
outer_dims_perm=None,
+ drop_last_tiles=None,
loc=None,
ip=None,
) -> ir.Value:
@@ -354,6 +357,7 @@ def unpack(
inner_tiles=dynamic_inner_tiles,
static_inner_tiles=static_inner_tiles,
outer_dims_perm=outer_dims_perm,
+ drop_last_tiles=drop_last_tiles,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 0c5a1c6108ae3..bba1261d02289 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2158,6 +2158,32 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
// -----
+// CHECK-LABEL: func.func @keep_unpack_pack_tensor_with_tile_adjustments
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @keep_unpack_pack_tensor_with_tile_adjustments(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>) -> tensor<17x10x8x32xf32> {
+ %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2]
+ into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ %cst = arith.constant 0.0 : f32
+ %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2]
+ into %x : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ return %packed : tensor<17x10x8x32xf32>
+}
+
+// CHECK-LABEL: func.func @do_not_fold_unpack_pack_tensor_with_mismatched_tile_adjustments
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @do_not_fold_unpack_pack_tensor_with_mismatched_tile_adjustments(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>, %packed_dest: tensor<17x9x8x32xf32>) -> tensor<17x9x8x32xf32> {
+ %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2]
+ into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ %cst = arith.constant 0.0 : f32
+ %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 1]
+ into %packed_dest : tensor<127x255xf32> -> tensor<17x9x8x32xf32>
+ return %packed : tensor<17x9x8x32xf32>
+}
+
+// -----
+
// Test that pack/unpack canonicalization is disabled for memref versions.
// CHECK-LABEL: func.func @negative_pack_unpack_memref_no_canonicalization
// CHECK: linalg.pack
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d9192cbda14e7..75347fe28ef87 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1911,6 +1911,32 @@ func.func @pack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: ten
// -----
+func.func @pack_invalid_extra_pad_tiles_length(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+ %pad = arith.constant 0.0 : f32
+ // expected-error @+1 {{extra_pad_tiles must have the same number of entries as inner_dims_pos}}
+ %0 = linalg.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+ return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
+func.func @pack_invalid_negative_extra_pad_tiles(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+ %pad = arith.constant 0.0 : f32
+ // expected-error @+1 {{extra_pad_tiles must contain only non-negative values}}
+ %0 = linalg.pack %source padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, -1] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+ return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
+func.func @pack_invalid_extra_pad_tiles_without_padding(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+ // expected-error @+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+ %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 0] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+ return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// linalg.unpack
//===----------------------------------------------------------------------===//
@@ -1939,6 +1965,22 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// -----
+func.func @unpack_invalid_drop_last_tiles_length(%source: tensor<17x8x8x32xf32>, %dest: tensor<128x256xf32>) -> tensor<128x256xf32> {
+ // expected-error @+1 {{drop_last_tiles must have the same number of entries as inner_dims_pos}}
+ %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1] into %dest : tensor<17x8x8x32xf32> -> tensor<128x256xf32>
+ return %0 : tensor<128x256xf32>
+}
+
+// -----
+
+func.func @unpack_invalid_negative_drop_last_tiles(%source: tensor<17x8x8x32xf32>, %dest: tensor<128x256xf32>) -> tensor<128x256xf32> {
+ // expected-error @+1 {{drop_last_tiles must contain only non-negative values}}
+ %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [-1, 0] into %dest : tensor<17x8x8x32xf32> -> tensor<128x256xf32>
+ return %0 : tensor<128x256xf32>
+}
+
+// -----
+
func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
%cst = arith.constant 0.0 : f32
// expected-error @+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
@@ -1977,6 +2019,15 @@ func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>,
// -----
+func.func @unpack_invalid_drop_last_tiles_shape(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+ // expected-error @+1 {{expected 'tensor<4x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] drop_last_tiles = [2] into %output
+ : tensor<3x8xf32> -> tensor<9xf32>
+ return %0 : tensor<9xf32>
+}
+
+// -----
+
func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
// expected-error @+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index bfb92c3289a49..ace2579566025 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -829,3 +829,16 @@ func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x25
// CHECK: return %[[RESULT]] : tensor<128x256xf32>
return %0 : tensor<128x256xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_unpack_tensor_with_tile_adjustments
+func.func @test_pack_unpack_tensor_with_tile_adjustments(%arg0: tensor<127x255xf32>, %arg1: tensor<17x10x8x32xf32>) -> tensor<127x255xf32> {
+ %pad = arith.constant 0.0 : f32
+ // CHECK: %[[PACK:.*]] = linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2] into %{{.*}} : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ %0 = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2] into %arg1 : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ // CHECK: %[[UNPACK:.*]] = linalg.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2] into %{{.*}} : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ %1 = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2] into %arg0 : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ // CHECK: return %[[UNPACK]] : tensor<127x255xf32>
+ return %1 : tensor<127x255xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b6fe67a9ae1f3..0fbbd812a9292 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -229,6 +229,55 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @pack_with_extra_pad_tiles(
+func.func @pack_with_extra_pad_tiles(%arg0: tensor<127x255xf32>, %arg1: tensor<17x10x8x32xf32>) -> tensor<17x10x8x32xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+ // CHECK: tensor.pad {{.*}} low[0, 0] high[9, 65]
+ // CHECK: : tensor<127x255xf32> to tensor<136x320xf32>
+ %pack = linalg.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] extra_pad_tiles = [1, 2] into %arg1
+ : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ return %pack : tensor<17x10x8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"linalg.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"linalg.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_with_drop_last_tiles(
+func.func @unpack_with_drop_last_tiles(%arg0: tensor<17x10x8x32xf32>, %arg1: tensor<127x255xf32>) -> tensor<127x255xf32> {
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x10x32xf32>
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{.*}} : tensor<17x8x10x32xf32> into tensor<136x320xf32>
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [127, 255] [1, 1]
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] drop_last_tiles = [1, 2] into %arg1
+ : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ return %unpack : tensor<127x255xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"linalg.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"linalg.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
+ transform.yield
+ }
+}
+
+// -----
+
// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
// CHECK-LABEL: func.func @unpack_as_pad(
func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 68b32098b7782..8f3fad247a6d7 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -644,7 +644,7 @@ def testPackUnPackOp():
@func.FuncOp.from_py_func(
RankedTensorType.get((128, 128), f32),
- RankedTensorType.get((16, 16, 8, 8), f32),
+ RankedTensorType.get((18, 17, 8, 8), f32),
)
def tensor_pack(src, dst):
packed = linalg.pack(
@@ -652,6 +652,7 @@ def tensor_pack(src, dst):
dst,
inner_dims_pos=[1, 0],
inner_tiles=[8, 8],
+ extra_pad_tiles=[1, 2],
padding_value=arith.constant(f32, 0.0),
)
@@ -660,6 +661,7 @@ def tensor_pack(src, dst):
src,
inner_dims_pos=[0, 1],
inner_tiles=[8, 8],
+ drop_last_tiles=[2, 1],
)
return unpacked
@@ -679,10 +681,10 @@ def memref_pack(src, dst):
)
# CHECK-LABEL: func.func @tensor_pack(
- # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<18x17x8x8xf32>) -> tensor<128x128xf32> {
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
- # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
- # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+ # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] extra_pad_tiles = [1, 2] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<18x17x8x8xf32>
+ # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] drop_last_tiles = [2, 1] into %[[VAL_0]] : tensor<18x17x8x8xf32> -> tensor<128x128xf32>
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
# CHECK-LABEL: func.func @memref_pack(
More information about the Mlir-commits
mailing list