[Mlir-commits] [mlir] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack (PR #189049)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 13 02:42:08 PDT 2026
https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/189049
>From c0aee44bf510503a28f6cce8242733fc4050c15d Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <fabrizio.indirli at arm.com>
Date: Wed, 11 Mar 2026 17:39:01 +0000
Subject: [PATCH] [mlir][linalg] Allow extra pad tiles in linalg.pack & unpack
- In linalg.pack, allow the result shape to contain additional
full tiles of high-padding in the tiled dimensions.
This is allowed only when the affected shape is known statically.
Note that in the tile-and-fuse path, the extra full tiles are
only fusible when the tiled pack covers the whole affected
source dimension in a single slice.
- Similarly, in linalg.unpack allow the source outer dimensions
to be larger than the minimum shape implied by the unpacked result
shape and the inner tiles: any additional outer extent is treated as
extra full-tile high-padding and is discarded from the high end when
reconstructing the result tensor.
Signed-off-by: Fabrizio Indirli <fabrizio.indirli at arm.com>
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 16 ++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 87 +++++++-----
.../Transforms/PackAndUnpackPatterns.cpp | 18 +--
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 47 +++++--
.../Dialect/Linalg/Transforms/Transforms.cpp | 3 +-
mlir/test/Dialect/Linalg/canonicalize.mlir | 26 ++++
mlir/test/Dialect/Linalg/invalid.mlir | 44 +++---
mlir/test/Dialect/Linalg/roundtrip.mlir | 13 ++
.../Dialect/Linalg/transform-lower-pack.mlir | 53 +++++++-
.../Linalg/vectorization/linalg-ops.mlir | 53 ++++++++
.../lower-to-loops-using-interface.mlir | 82 ++++++++++++
.../tile-and-fuse-consumer-using-slices.mlir | 126 +++++++++++++++++-
.../tile-and-fuse-consumer.mlir | 126 +++++++++++++++++-
mlir/test/python/dialects/linalg/ops.py | 8 +-
14 files changed, 607 insertions(+), 95 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..b7e58ae38e0b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -164,10 +164,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
tiles divide perfectly the corresponding outer dimension in the result
tensor. It is UB if the tile does not perfectly divide the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
- tile complete. Note that it is not allowed to have artificial padding that
- is not strictly required by linalg.pack (i.e., padding past what is needed
- to complete the last tile along each packed dimension). It is UB if extra
- padding is requested.
+ tile complete.
It is not possible to verify the requirements statically with dynamic
shapes, so they are treated as UB.
@@ -185,7 +182,11 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
- Invalid example that has artificial padding:
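+ The packed outer dimensions may be larger than the minimum shape implied by
+ `shape(source)` and `inner_tiles`, if `padding_value` is specified.
+ Any additional outer extent is treated as extra full-tile high-padding.
+ This is only recognized when the affected dimensions are statically known;
+ dynamic packed outer dimensions are interpreted as the minimum packed shape.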
+
+ Invalid example that has artificial padding without a `padding_value`:
```mlir
%0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
inner_tiles = [8] into %dest
@@ -329,6 +330,11 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
operation and hence the following holds
`NumElementsOf(source) >= NumElementsOf(result)`.
+ The packed source outer dimensions may be larger than the minimum shape
+ implied by `shape(result)` and `inner_tiles`. Any additional outer extent
+ is treated as extra full-tile high-padding and is discarded from the high
+ end when reconstructing the result tensor.
+
Examples:
```mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9698365765e7..0e3989aa60992 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5104,6 +5104,21 @@ static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
return staticTiles;
}
+static inline bool isEmptyOrZeroArray(ArrayRef<int64_t> values) {
+ return values.empty() ||
+ llvm::all_of(values, [](int64_t value) { return value == 0; });
+}
+
+static SmallVector<int64_t>
+getPackedOuterShapeInUnpackedOrder(ArrayRef<int64_t> packedShape,
+ size_t unpackedRank,
+ ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> result(packedShape.take_front(unpackedRank));
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(result, invertPermutationVector(outerDimsPerm));
+ return result;
+}
+
/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicate.
/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
@@ -5215,20 +5230,30 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
<< " but dynamic tile size";
}
}
- if (failed(
- verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
- auto elementType = unpackedType.getElementType();
- Type expectedType, actualType;
- if (packOrUnPack.hasPureTensorSemantics()) {
- expectedType = RankedTensorType::get(expectedPackedShape, elementType);
- actualType = RankedTensorType::get(packedType.getShape(), elementType);
- } else {
- expectedType = MemRefType::get(expectedPackedShape, elementType);
- actualType = MemRefType::get(packedType.getShape(), elementType);
+ SmallVector<int64_t> expectedOuterShape = getPackedOuterShapeInUnpackedOrder(
+ expectedPackedShape, unpackedRank, outerDimPerm);
+ SmallVector<int64_t> actualOuterShape =
+ getPackedOuterShapeWithoutTransposition(packOrUnPack);
+ llvm::SmallBitVector areOuterDimsTiled(unpackedRank);
+ for (int64_t pos : innerDimsPos)
+ areOuterDimsTiled.set(pos);
+ for (auto [index, actualOuter] : llvm::enumerate(actualOuterShape)) {
+ int64_t expectedOuter = expectedOuterShape[index];
+ if (ShapedType::isDynamic(actualOuter) ||
+ ShapedType::isDynamic(expectedOuter))
+ continue;
+ if (areOuterDimsTiled[index]) {
+ if (actualOuter < expectedOuter) {
+ return op->emitError("expected packed outer dimension ")
+ << index << " to be at least " << expectedOuter << ", got "
+ << actualOuter;
+ }
+ continue;
+ }
+ if (actualOuter != expectedOuter) {
+ return op->emitError("expected packed outer dimension ")
+ << index << " to be " << expectedOuter << ", got " << actualOuter;
}
- return op->emitError("expected ")
- << expectedType << " for the packed domain value, got "
- << actualType;
}
return success();
}
@@ -5505,28 +5530,22 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles) {
- SmallVector<int64_t> outputTileSizes(
- outputShape.take_front(inputShape.size()));
+ SmallVector<int64_t> outputTileSizes = getPackedOuterShapeInUnpackedOrder(
+ outputShape, inputShape.size(), outerDimsPerm);
if (!outerDimsPerm.empty()) {
assert(outerDimsPerm.size() == outputTileSizes.size() &&
"expected output and outer_dims_perm to have same size");
- applyPermutationToVector(outputTileSizes,
- invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
- if (ShapedType::isDynamic(inputShape[pos]))
+ if (ShapedType::isDynamic(inputShape[pos]) ||
+ ShapedType::isDynamic(outputTileSizes[pos]))
continue;
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
- if (!constantTile) {
- if (ShapedType::isStatic(outputTileSizes[pos]) &&
- (inputShape[pos] % outputTileSizes[pos] != 0))
- return true;
- } else {
- assert(*constantTile != 0 && "static tile size can't be zero");
- if (inputShape[pos] % (*constantTile) != 0) {
- return true;
- }
- }
+ if (!constantTile)
+ continue;
+ assert(*constantTile != 0 && "static tile size can't be zero");
+ if (outputTileSizes[pos] * (*constantTile) != inputShape[pos])
+ return true;
}
return false;
}
@@ -5536,13 +5555,11 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles) {
- SmallVector<int64_t> outputTileSizes(
- outputShape.take_front(inputShape.size()));
+ SmallVector<int64_t> outputTileSizes = getPackedOuterShapeInUnpackedOrder(
+ outputShape, inputShape.size(), outerDimsPerm);
if (!outerDimsPerm.empty()) {
assert(outerDimsPerm.size() == outputTileSizes.size() &&
"expected output and outer_dims_perm to have same size");
- applyPermutationToVector(outputTileSizes,
- invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
if (ShapedType::isDynamic(inputShape[pos]) ||
@@ -5614,7 +5631,6 @@ SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
resultShape[tiledDim.value()] = llvm::divideCeilSigned(
resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
}
-
// Swap tile loops if outer_dims_perm is available.
if (!outerDimsPerm.empty())
applyPermutationToVector(resultShape, outerDimsPerm);
@@ -6422,8 +6438,9 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(*this);
SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
- for (auto [pos, tileSize] :
- llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+ for (auto [index, values] : llvm::enumerate(llvm::zip_equal(
+ this->getInnerDimsPos(), this->getStaticInnerTiles()))) {
+ auto [pos, tileSize] = values;
areOuterDimsTiled[pos] = true;
if (unpackedTypeAfterFold.isDynamicDim(pos))
return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 993eae62535c3..c90981b69cd27 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -236,9 +236,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
ShapedType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
- for (auto [pos, tileSize, high] :
- llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
- padOp.getMixedHighPad())) {
+ for (auto [index, tuple] : llvm::enumerate(llvm::zip_equal(
+ packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+ padOp.getMixedHighPad()))) {
+ auto [pos, tileSize, high] = tuple;
if (unpackedType.isDynamicDim(pos))
return failure();
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
@@ -248,10 +249,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
std::optional<int64_t> cstHigh = getConstantIntValue(high);
if (!cstHigh)
return failure();
- int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
- unpackedType.getDimSize(pos);
- // Do not fold the op if it requires artificial padding.
- if (paddingSize + cstHigh.value() >= tileSize)
+ int64_t originalDim = unpackedType.getDimSize(pos) - cstHigh.value();
+ int64_t baseOuter = llvm::divideCeilSigned(originalDim, tileSize);
+ if (baseOuter != outerShapeWithoutTranspose[pos])
return failure();
}
@@ -440,7 +440,7 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
- for (auto dim : innerDimsPos)
+ for (int64_t dim : innerDimsPos)
newInnerDimsPosVec.push_back(transposePermutation[dim]);
Value output = packOp.createDestinationTensor(
@@ -497,7 +497,7 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
// permutation rank won't necessarily be equal in all cases.
- for (auto dim : innerDimsPos)
+ for (int64_t dim : innerDimsPos)
newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
if (!outerDimsPerm.empty())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 558ebdebd65c5..7981dd211ee36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -1101,13 +1101,16 @@ struct PackOpTiling
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
- SmallVector<int64_t> outerShapeWithoutTranspose(
- packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
+ SmallVector<OpFoldResult> outerShapeWithoutTranspose =
+ tensor::getMixedSizes(b, loc, packOp.getDest());
+ outerShapeWithoutTranspose.resize(packOp.getSourceRank());
if (!packOp.getOuterDimsPerm().empty()) {
applyPermutationToVector(
outerShapeWithoutTranspose,
invertPermutationVector(packOp.getOuterDimsPerm()));
}
+ SmallVector<OpFoldResult> srcDimValues =
+ tensor::getMixedSizes(b, loc, packOp.getSource());
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
if (dimAndTileMapping.count(dim)) {
FailureOr<int64_t> cstTileSize =
@@ -1125,18 +1128,12 @@ struct PackOpTiling
// TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
// hard check to determine if a dimension is tiled or not.
int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
- int64_t destDimSize = outerShapeWithoutTranspose[dim];
bool isTiled = failed(cstTileSize) ||
ShapedType::isDynamic(srcDimSize) ||
cstTileSize.value() < srcDimSize;
if (!isTiled) {
outerDimOffsets.push_back(offsets[dim]);
- if (ShapedType::isStatic(destDimSize)) {
- outerDimSizes.push_back(b.getIndexAttr(destDimSize));
- } else {
- outerDimSizes.push_back(
- b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
- }
+ outerDimSizes.push_back(outerShapeWithoutTranspose[dim]);
continue;
}
@@ -1165,9 +1162,37 @@ struct PackOpTiling
bindSymbols(b.getContext(), sym);
auto avOffset = AV(dim0).bind(offsets[dim]);
auto avSize = AV(dim0).bind(sizes[dim]);
+ auto avSrcDim = AV(dim0).bind(srcDimValues[dim]);
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
- outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
- outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
+ OpFoldResult outerDimOffset = ab.floor(avOffset, avTileSize);
+ // Minimum packed outer-dimension size needed for the current tiled
+ // source slice.
+ OpFoldResult minOuterDimSize = ab.ceil(avSize, avTileSize);
+ outerDimOffsets.push_back(outerDimOffset);
+ // Minimum packed outer-dimension size needed to represent the entire
+ // source dimension, ignoring any intentional extra pad tiles.
+ OpFoldResult minGlobalOuterDimSize = ab.ceil(avSrcDim, avTileSize);
+ std::optional<int64_t> actualOuterCst =
+ getConstantIntValue(outerShapeWithoutTranspose[dim]);
+ std::optional<int64_t> minGlobalOuterCst =
+ getConstantIntValue(minGlobalOuterDimSize);
+ bool hasStaticExtraFullTiles = actualOuterCst && minGlobalOuterCst &&
+ *actualOuterCst > *minGlobalOuterCst;
+
+ // Without statically known extra full tiles, use the minimum size
+ // needed for this slice. Dynamic packed outer dims are interpreted as
+ // the minimum packed shape.
+ if (!hasStaticExtraFullTiles) {
+ outerDimSizes.push_back(minOuterDimSize);
+ continue;
+ }
+
+ // Extra full tiles are only fusible when the tiled pack covers the
+ // whole affected source dimension in a single slice.
+ if (!isZeroInteger(offsets[dim]) ||
+ !isEqualConstantIntOrValue(sizes[dim], srcDimValues[dim]))
+ return failure();
+
+ outerDimSizes.push_back(outerShapeWithoutTranspose[dim]);
} else {
outerDimOffsets.push_back(offsets[dim]);
outerDimSizes.push_back(sizes[dim]);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 260e36fb47f04..ea401435b6449 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -588,7 +588,8 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
// Build the symmetrical UnPackOp to the existing PackOp.
unPackOps.push_back(linalg::UnPackOp::create(
rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
- maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+ maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(),
+ maybePackedInit.getOuterDimsPerm()));
results.push_back(unPackOps.back().getResult());
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 0c5a1c6108ae3..efe06d0e72963 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -2158,6 +2158,32 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
// -----
+// CHECK-LABEL: func.func @keep_unpack_pack_tensor_with_larger_packed_shape
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @keep_unpack_pack_tensor_with_larger_packed_shape(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>) -> tensor<17x10x8x32xf32> {
+ %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ %cst = arith.constant 0.0 : f32
+ %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %x : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ return %packed : tensor<17x10x8x32xf32>
+}
+
+// CHECK-LABEL: func.func @do_not_fold_unpack_pack_tensor_with_mismatched_packed_shape
+// CHECK: linalg.unpack
+// CHECK: linalg.pack
+func.func @do_not_fold_unpack_pack_tensor_with_mismatched_packed_shape(%x: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>, %packed_dest: tensor<17x9x8x32xf32>) -> tensor<17x9x8x32xf32> {
+ %unpacked = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %dest : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ %cst = arith.constant 0.0 : f32
+ %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+ into %packed_dest : tensor<127x255xf32> -> tensor<17x9x8x32xf32>
+ return %packed : tensor<17x9x8x32xf32>
+}
+
+// -----
+
// Test that pack/unpack canonicalization is disabled for memref versions.
// CHECK-LABEL: func.func @negative_pack_unpack_memref_no_canonicalization
// CHECK: linalg.pack
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d9192cbda14e7..c81c3c6b92abf 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1811,22 +1811,6 @@ func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %o
// -----
-func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
- // expected-error at +1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
- %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<10x8x?x?xf32>
- return %0 : tensor<10x8x?x?xf32>
-}
-
-// -----
-
-func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> {
- // expected-error at +1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
- %0 = linalg.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<8x10x?x?xf32>
- return %0 : tensor<8x10x?x?xf32>
-}
-
-// -----
-
func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
// expected-error at +1 {{expected padding_value has 'f32' but got: 'i32'}}
%0 = linalg.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
@@ -1911,6 +1895,14 @@ func.func @pack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: ten
// -----
+func.func @pack_invalid_larger_result_without_padding(%source: tensor<128x256xf32>, %dest: tensor<17x8x8x32xf32>) -> tensor<17x8x8x32xf32> {
+ // expected-error at +1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+ %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<17x8x8x32xf32>
+ return %0 : tensor<17x8x8x32xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// linalg.unpack
//===----------------------------------------------------------------------===//
@@ -1939,13 +1931,13 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// -----
-func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+func.func @pack_invalid_small_result_shape_1d(%input: tensor<9xf32>, %output: tensor<1x8xf32>) -> tensor<1x8xf32> {
%cst = arith.constant 0.0 : f32
- // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+ // expected-error at +1 {{expected packed outer dimension 0 to be at least 2, got 1}}
%0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
inner_tiles = [8] into %output
- : tensor<9xf32> -> tensor<3x8xf32>
- return %0 : tensor<3x8xf32>
+ : tensor<9xf32> -> tensor<1x8xf32>
+ return %0 : tensor<1x8xf32>
}
// -----
@@ -1953,7 +1945,7 @@ func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error at +1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
+ // expected-error at +1 {{expected packed outer dimension 0 to be at least 16, got 4}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
@@ -1961,24 +1953,24 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
// -----
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
- // expected-error at +1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
+ // expected-error at +1 {{expected packed outer dimension 1 to be at least 8, got 7}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
return %0 : tensor<8x7x16x32xf32>
}
// -----
-func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
- // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+func.func @unpack_invalid_small_source_shape_1d(%input: tensor<1x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+ // expected-error at +1 {{expected packed outer dimension 0 to be at least 2, got 1}}
%0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
- : tensor<3x8xf32> -> tensor<9xf32>
+ : tensor<1x8xf32> -> tensor<9xf32>
return %0 : tensor<9xf32>
}
// -----
func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
- // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
+ // expected-error at +1 {{expected packed outer dimension 1 to be at least 32, got 8}}
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index bfb92c3289a49..b776da9a0b86f 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -829,3 +829,16 @@ func.func @test_unpack_tensor(%arg0: tensor<16x8x8x32xf32>, %arg1: tensor<128x25
// CHECK: return %[[RESULT]] : tensor<128x256xf32>
return %0 : tensor<128x256xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @test_pack_unpack_tensor_with_extra_tiles
+func.func @test_pack_unpack_tensor_with_extra_tiles(%arg0: tensor<127x255xf32>, %arg1: tensor<17x10x8x32xf32>) -> tensor<127x255xf32> {
+ %pad = arith.constant 0.0 : f32
+ // CHECK: %[[PACK:.*]] = linalg.pack %{{.*}} padding_value(%{{.*}} : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ %0 = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ // CHECK: %[[UNPACK:.*]] = linalg.unpack %[[PACK]] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %{{.*}} : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ %1 = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg0 : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ // CHECK: return %[[UNPACK]] : tensor<127x255xf32>
+ return %1 : tensor<127x255xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index b6fe67a9ae1f3..9f777c5aac425 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -229,6 +229,55 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func.func @pack_with_larger_result_shape(
+func.func @pack_with_larger_result_shape(%arg0: tensor<127x255xf32>, %arg1: tensor<17x10x8x32xf32>) -> tensor<17x10x8x32xf32> {
+ %cst_0 = arith.constant 0.0 : f32
+ // CHECK: tensor.pad {{.*}} low[0, 0] high[9, 65]
+ // CHECK: : tensor<127x255xf32> to tensor<136x320xf32>
+ %pack = linalg.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1
+ : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ return %pack : tensor<17x10x8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"linalg.pack">
+ transform.structured.lower_pack %pack : (!transform.op<"linalg.pack">)
+ -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_with_larger_packed_shape(
+func.func @unpack_with_larger_packed_shape(%arg0: tensor<17x10x8x32xf32>, %arg1: tensor<127x255xf32>) -> tensor<127x255xf32> {
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x10x32xf32>
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{.*}} : tensor<17x8x10x32xf32> into tensor<136x320xf32>
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [127, 255] [1, 1]
+ %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1
+ : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ return %unpack : tensor<127x255xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"linalg.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"linalg.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">,
+ !transform.op<"linalg.copy">)
+ transform.yield
+ }
+}
+
+// -----
+
// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
// CHECK-LABEL: func.func @unpack_as_pad(
func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
@@ -625,8 +674,8 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: permutation = [0, 2, 1, 3]
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
// CHECK-SAME: : tensor<?x?x?x?xf32> into tensor<?x?xf32>
-// CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
// CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index a5d94bc4f581c..fc0d0674b299f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1252,6 +1252,31 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
// -----
+// CHECK-LABEL: test_vectorize_unpack_larger_packed_shape_no_vector_sizes
+// CHECK-SAME: %[[SRC:.*]]: tensor<17x10x8x32xf32>
+// CHECK-SAME: %[[DEST:.*]]: tensor<127x255xf32>
+func.func @test_vectorize_unpack_larger_packed_shape_no_vector_sizes(%source: tensor<17x10x8x32xf32>, %dest: tensor<127x255xf32>) -> tensor<127x255xf32> {
+ // CHECK-DAG: %[[PAD:.*]] = ub.poison : f32
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{.*}} %[[PAD]] {{.*}} : tensor<17x10x8x32xf32>, vector<17x10x8x32xf32>
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<17x10x8x32xf32> to vector<17x8x10x32xf32>
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<17x8x10x32xf32> to vector<136x320xf32>
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[DEST]]{{.*}} : vector<136x320xf32>, tensor<127x255xf32>
+ // CHECK: return %[[WRIT]] : tensor<127x255xf32>
+ %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest
+ : tensor<17x10x8x32xf32> -> tensor<127x255xf32>
+ return %0 : tensor<127x255xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: test_vectorize_unpack_no_vector_sizes_slice_output
// CHECK-SAME: %[[SRC:.*]]: tensor<8x4x16x16xf32>
// CHECK-SAME: %[[DEST:.*]]: tensor<64x127xf32>
@@ -1421,6 +1446,34 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func @pack_with_larger_result_shape_no_vector_sizes
+// CHECK-SAME: %[[SRC:.*]]: tensor<127x255xf32>,
+// CHECK-SAME: %[[DEST:.*]]: tensor<17x10x8x32xf32>
+func.func @pack_with_larger_result_shape_no_vector_sizes(%src: tensor<127x255xf32>, %dest: tensor<17x10x8x32xf32>) -> tensor<17x10x8x32xf32> {
+ %pad = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %src padding_value(%pad : f32)
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32]
+ into %dest : tensor<127x255xf32> -> tensor<17x10x8x32xf32>
+ return %pack : tensor<17x10x8x32xf32>
+}
+// CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<127x255xf32>, vector<136x320xf32>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<136x320xf32> to vector<17x8x10x32xf32>
+// CHECK: %[[TR:.*]] = vector.transpose %[[SC]], [0, 2, 1, 3] : vector<17x8x10x32xf32> to vector<17x10x8x32xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TR]], %[[DEST]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {in_bounds = [true, true, true, true]} : vector<17x10x8x32xf32>, tensor<17x10x8x32xf32>
+// CHECK: return %[[WRITE]] : tensor<17x10x8x32xf32>
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @pack_with_dynamic_dims
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x16x2xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index ec9b606491910..9c792a3ea55f9 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -488,6 +488,52 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @NC_to_NCnc_extra(%arg0: memref<128x128xf32>, %arg1: memref<18x17x8x8xf32>, %arg2: f32) {
+ linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8]
+ into %arg1
+ : memref<128x128xf32> -> memref<18x17x8x8xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %pack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[$MAP_EXTRA:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-LABEL: func @NC_to_NCnc_extra(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<128x128xf32>
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<18x17x8x8xf32>
+// CHECK-SAME: %[[PAD:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C17:.*]] = arith.constant 17 : index
+// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C18]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C17]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK-DAG: %[[SRC_I:.*]] = affine.apply #[[$MAP_EXTRA]](%[[I]], %[[L]])
+// CHECK-DAG: %[[SRC_J:.*]] = affine.apply #[[$MAP_EXTRA]](%[[J]], %[[K]])
+// CHECK: %[[BOUND_I:.*]] = arith.cmpi slt, %[[SRC_I]], %[[C128]] : index
+// CHECK: %[[BOUND_J:.*]] = arith.cmpi slt, %[[SRC_J]], %[[C128]] : index
+// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[BOUND_I]], %[[BOUND_J]] : i1
+// CHECK: %[[VAL:.*]] = scf.if %[[IN_BOUNDS]] -> (f32) {
+// CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[SRC_I]], %[[SRC_J]]] : memref<128x128xf32>
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: } else {
+// CHECK: scf.yield %[[PAD]]
+// CHECK: }
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<18x17x8x8xf32>
+
+// -----
+
func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1
: memref<128x256xf32> -> memref<4x8x32x32xf32>
@@ -564,6 +610,42 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @NCnc_to_NC_drop(%arg0: memref<128x128xf32>, %arg1: memref<18x17x8x8xf32>) {
+ linalg.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %arg0
+ : memref<18x17x8x8xf32> -> memref<128x128xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP_DROP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG: #[[$MAP_DROP_MOD:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @NCnc_to_NC_drop(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<128x128xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<18x17x8x8xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK-DAG: %[[FLOOR_I:.*]] = affine.apply #[[$MAP_DROP_FLOOR]](%[[I]])
+// CHECK-DAG: %[[MOD_I:.*]] = affine.apply #[[$MAP_DROP_MOD]](%[[I]])
+// CHECK-DAG: %[[FLOOR_J:.*]] = affine.apply #[[$MAP_DROP_FLOOR]](%[[J]])
+// CHECK-DAG: %[[MOD_J:.*]] = affine.apply #[[$MAP_DROP_MOD]](%[[J]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_I]], %[[FLOOR_J]], %[[MOD_I]], %[[MOD_J]]] : memref<18x17x8x8xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]]] : memref<128x128xf32>
+// CHECK: }
+// CHECK: }
+
+// -----
+
func.func @KCck_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg0
: memref<4x8x32x32xf32> -> memref<128x256xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
index 62dd7faec4eb7..344e4c5a76415 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -502,6 +502,58 @@ module attributes {transform.with_named_sequence} {
// -----
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(%arg0: tensor<4x4xf32>) -> tensor<2x5x16x1xf32> {
+ %0 = tensor.empty() : tensor<2x5x16x1xf32>
+ %1 = tensor.empty() : tensor<4x4xf32>
+ %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+ %3 = affine.min #map(%arg1)
+ %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+ %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+ }
+ }
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+ inner_tiles = [16, 1] into %0
+ : tensor<4x4xf32> -> tensor<2x5x16x1xf32>
+ return %pack : tensor<2x5x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<2x5x16x1xf32>
+// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+
+// -----
+
func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
%0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
%src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -546,6 +598,50 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<64x4x16xf32>) -> tensor<64x4x16xf32> {
+ %0 = scf.forall (%arg3) = (0) to (64) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+ %dest = tensor.extract_slice %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+ %1 = linalg.exp ins(%src : tensor<16x32xf32>) outs(%dest : tensor<16x32xf32>) -> tensor<16x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %1 into %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %arg2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+ return %pack : tensor<64x4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (64) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [16, 32] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%{{.*}} : f32)
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+
+// -----
+
// It is valid to fuse the pack op in perfect tiling scenario when the dimension
// is dynamic and padding is not needed.
@@ -649,7 +745,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK-SAME: into %{{.*}}
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
@@ -694,6 +790,33 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @nofuse_pack_with_larger_packed_shape_on_sliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<64x4x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ // expected-error @below {{failed to fuse consumer of slice}}
+ tensor.parallel_insert_slice %1 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %2 = tensor.empty() : tensor<64x4x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+ return %pack : tensor<64x4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %consumer, %fused_consumer = transform.test.fuse_consumer_using_slice %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
module {
func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
%c0 = arith.constant 0 : index
@@ -730,6 +853,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 0137e2a69a46e..1d3b34547b440 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -545,6 +545,58 @@ module attributes {transform.with_named_sequence} {
// -----
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(%arg0: tensor<4x4xf32>) -> tensor<2x5x16x1xf32> {
+ %0 = tensor.empty() : tensor<2x5x16x1xf32>
+ %1 = tensor.empty() : tensor<4x4xf32>
+ %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+ %3 = affine.min #map(%arg1)
+ %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+ %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+ }
+ }
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+ inner_tiles = [16, 1] into %0
+ : tensor<4x4xf32> -> tensor<2x5x16x1xf32>
+ return %pack : tensor<2x5x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %consumer = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %consumer into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: func.func @fuse_pack_consumer_if_single_iteration_extra_pad_tiles(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.empty() : tensor<2x5x16x1xf32>
+// CHECK-DAG: %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME: shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+// CHECK-DAG: %[[SIZE:.+]] = affine.min affine_map<(d0) -> (-d0 + 4, 16)>(%[[IV]])
+// CHECK-DAG: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK-DAG: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [2, 5, 16, 1] [1, 1, 1, 1]
+
+// -----
+
func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
%0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
%src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
@@ -588,6 +640,50 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<64x4x16xf32>) -> tensor<64x4x16xf32> {
+ %0 = scf.forall (%arg3) = (0) to (64) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+ %dest = tensor.extract_slice %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<64x32xf32> to tensor<16x32xf32>
+ %1 = linalg.exp ins(%src : tensor<16x32xf32>) outs(%dest : tensor<16x32xf32>) -> tensor<16x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %1 into %arg4[%arg3, 0] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %arg2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+ return %pack : tensor<64x4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @fuse_pack_consumer_with_larger_packed_shape_on_unsliced_dim(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (64) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [16, 32] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%{{.*}} : f32)
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][%[[IV]], 0] [16, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0] [16, 4, 16] [1, 1, 1]
+
+// -----
+
// It is valid to fuse the pack op in perfect tiling scenario when the dimension
// is dynamic and padding is not needed.
@@ -686,7 +782,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK-SAME: into %{{.*}}
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]]
// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1]
@@ -732,6 +828,33 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @nofuse_pack_with_larger_packed_shape_on_sliced_dim(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<64x4x16xf32> {
+ %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+ %1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %1 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+ }
+ }
+ %2 = tensor.empty() : tensor<64x4x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ // expected-error @below {{failed to fuse consumer of slice}}
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [16] into %2 : tensor<64x32xf32> -> tensor<64x4x16xf32>
+ return %pack : tensor<64x4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %fused_consumer, %new_loop = transform.test.fuse_consumer %0 into(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
module {
func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
%c0 = arith.constant 0 : index
@@ -772,6 +895,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 68b32098b7782..a903409ea06e4 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -644,7 +644,7 @@ def testPackUnPackOp():
@func.FuncOp.from_py_func(
RankedTensorType.get((128, 128), f32),
- RankedTensorType.get((16, 16, 8, 8), f32),
+ RankedTensorType.get((18, 17, 8, 8), f32),
)
def tensor_pack(src, dst):
packed = linalg.pack(
@@ -679,10 +679,10 @@ def memref_pack(src, dst):
)
# CHECK-LABEL: func.func @tensor_pack(
- # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<16x16x8x8xf32>) -> tensor<128x128xf32> {
+ # CHECK-SAME: %[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<18x17x8x8xf32>) -> tensor<128x128xf32> {
# CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
- # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
- # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
+ # CHECK: %[[VAL_3:.*]] = linalg.pack %[[VAL_0]] padding_value(%[[VAL_2]] : f32) inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %[[VAL_1]] : tensor<128x128xf32> -> tensor<18x17x8x8xf32>
+ # CHECK: %[[VAL_4:.*]] = linalg.unpack %[[VAL_3]] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %[[VAL_0]] : tensor<18x17x8x8xf32> -> tensor<128x128xf32>
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
# CHECK-LABEL: func.func @memref_pack(
More information about the Mlir-commits
mailing list