[Mlir-commits] [mlir] [mlir][linalg] Take artificial padding into account for pack/unpack folding. (PR #150272)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Jul 24 11:11:06 PDT 2025
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/150272
>From 254880973fb98866332f8747f745c1ad003439c7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 15:04:19 -0700
Subject: [PATCH 1/6] [mlir][linalg] Take artificial padding into account for
pack/unpack folding.
The revision folds the tensor.pad/extract_slice op into
linalg.pack/unpack ops only when it is safe to fold.
According to the documentation, it is not valid to have artificial padding:
```
- The following relationship for the tiled dimensions holds:
shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i].
```
The documentation improvement and verifier update will be done in a
separate PR (i.e., https://github.com/llvm/llvm-project/pull/149624);
this revision is a step towards it.
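As a minimal illustration (the shapes and SSA names below are
hypothetical, not taken from this patch): packing a dimension of size 12
with an inner tile of 8 produces two outer tiles and 4 elements of
padding, which is valid; producing two outer tiles from a dimension of
size 7 would leave the second tile consisting entirely of padding, i.e.,
artificial padding.
```
// Valid: ceil(12 / 8) == 2 outer tiles; padding = 2 * 8 - 12 = 4 < 8.
%pack = linalg.pack %src padding_value(%cst : f32)
    inner_dims_pos = [0] inner_tiles = [8]
    into %dest : tensor<12x4xf32> -> tensor<2x4x8xf32>

// Artificial padding (invalid per the rule above): ceil(7 / 8) == 1,
// but the outer dimension is 2, so the second tile would be all padding.
// linalg.pack ... : tensor<7x4xf32> -> tensor<2x4x8xf32>
```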
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 6 ++
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 4 ++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 55 +++++++++++++++--
.../Transforms/PackAndUnpackPatterns.cpp | 38 ++++++++----
mlir/test/Dialect/Linalg/canonicalize.mlir | 37 ++++++++----
.../Tensor/fold-into-pack-and-unpack.mlir | 60 ++++++++++++++-----
6 files changed, 158 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..6941939c8db5a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -89,6 +90,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim);
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c384e8b638382..fa572024ff72b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);
+ /// Returns true if it is statically known that the `sliceOp` result shape
+ /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+ bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
/// Check if this UnPackOp is like a simple unpad operation.
/// In other words, this operation:
/// 1. drops useless dimensions (dimension of size 1), and
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..046a73c90f110 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+ RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
+ SmallVector<int64_t> result(
+ packedType.getShape().take_front(unpackedType.getRank()));
+ if (!packOrUnPack.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+ }
+ return result;
+}
+template SmallVector<int64_t>
+ getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+ getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
@@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
if (unPackOp->hasOneUse()) {
auto extractSliceUser =
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
- if (extractSliceUser &&
- areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
- areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
- extractSliceUser.getSourceType().getRank() ==
- extractSliceUser.getResultType().getRank()) {
+ if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5494,6 +5513,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
}
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+ // Rank-reduced folding is not supported.
+ if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+ return false;
+ if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+ !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+ return false;
+ RankedTensorType unpackedType = sliceOp.getResultType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(*this);
+ for (auto [pos, tileSize] :
+ llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+ if (unpackedType.isDynamicDim(pos))
+ return false;
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return false;
+ if (ShapedType::isDynamic(tileSize))
+ return false;
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
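+ // Do not fold the op if the slice drops a full tile; the folded unpack
+ // op would then require artificial padding.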
+ if (paddingSize >= tileSize)
+ return false;
+ }
+ return true;
+}
+
bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 2afa2f9b71c2a..73e157b42235a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,6 +220,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
+ // Folding is not allowed if it introduces artificial padding. It is not
+ // safe to fold the ops if any dynamic dimension or tile size is present,
+ // because we cannot infer the padding size.
+ RankedTensorType unpackedType = packOp.getSourceType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(packOp);
+ for (auto [pos, tileSize, high] :
+ llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+ padOp.getMixedHighPad())) {
+ if (unpackedType.isDynamicDim(pos))
+ return failure();
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return failure();
+ if (ShapedType::isDynamic(tileSize))
+ return failure();
+ std::optional<int64_t> cstHigh = getConstantIntValue(high);
+ if (!cstHigh)
+ return failure();
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ // Do not fold the op if it requires artificial padding.
+ if (paddingSize + cstHigh.value() >= tileSize)
+ return failure();
+ }
+
rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +276,8 @@ struct FoldUnpackWithExtractSliceOp
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();
- if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "rank-reduced folding is not supported");
- }
-
- // Check all offsets are zeros, and all strides are ones.
- if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
- !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
- return rewriter.notifyMatchFailure(
- sliceOp, "expects offsets to be 0s and strides to be 1s");
- }
+ if (!unpackOp.canFoldSliceOp(sliceOp))
+ return failure();
// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7dbd673..cd14bc3d1948b 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1890,30 +1890,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
//===----------------------------------------------------------------------===//
func.func @fold_extract_slice_into_unpack(
- %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+ %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
- into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
- [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
- return %extracted_slice : tensor<28x28x?xf32>
+ [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+ return %extracted_slice : tensor<28x28x10xf32>
}
-
// CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]
// -----
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
+// -----
+
func.func @no_fold_extract_slice_into_unpack_rank_reducing(
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73f87a2a..4a97d1df25f15 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,32 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
+ %empty = tensor.empty() : tensor<16656x16xf32>
+ %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
+ %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
+ return %1 : tensor<16649x16xf32>
+}
+// CHECK-LABEL: func @fold_unpack_slice(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
: tensor<?x?x8x4xf32> -> tensor<?x?xf32>
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK: func @fold_unpack_slice(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME: into %[[INIT]]
-// CHECK: return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
// -----
@@ -59,13 +69,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
// -----
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ } : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +91,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
// -----
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[15, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +103,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
+// CHECK: tensor.pad
+// CHECK: linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16649x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack(
// CHECK: tensor.pad
// CHECK: linalg.pack
>From 2fc03160b1aa10fa738b21a3e9791f07749159e8 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 23 Jul 2025 19:15:44 -0700
Subject: [PATCH 2/6] update function name
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 4a97d1df25f15..ac1074bb7e47a 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -109,7 +109,7 @@ func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> ten
// -----
-func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
@@ -121,7 +121,7 @@ func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
-// CHECK-LABEL: func.func @nofold_pad_pack(
+// CHECK-LABEL: func.func @nofold_pad_pack_with_nofold_attribute(
// CHECK: tensor.pad
// CHECK: linalg.pack
>From 601f21a18ee1cb723558d62bbc411f7e9e262ce7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 24 Jul 2025 10:08:48 -0700
Subject: [PATCH 3/6] address comments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 18 +++++--
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 8 +--
.../Transforms/PackAndUnpackPatterns.cpp | 8 +--
mlir/test/Dialect/Linalg/canonicalize.mlir | 49 ++++++++++++++++++-
.../Tensor/fold-into-pack-and-unpack.mlir | 30 +++++-------
5 files changed, 82 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 6941939c8db5a..62c04bb2ee1ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -90,11 +90,6 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim);
-/// Returns the outer shape in the packed domain before applying the
-/// transposition.
-template <typename OpTy>
-SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
-
} // namespace linalg
} // namespace mlir
@@ -150,4 +145,17 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
+namespace mlir {
+namespace linalg {
+
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy,
+ typename = std::enable_if_t<std::is_same_v<OpTy, linalg::PackOp> ||
+ std::is_same_v<OpTy, linalg::UnPackOp>>>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
+} // namespace linalg
+} // namespace mlir
+
#endif // MLIR_DIALECT_LINALG_IR_LINALG_H
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 046a73c90f110..3f7756e79210d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4491,7 +4491,7 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
-template <typename OpTy>
+template <typename OpTy, typename>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
@@ -5520,19 +5520,19 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
return false;
- RankedTensorType unpackedType = sliceOp.getResultType();
+ RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(*this);
for (auto [pos, tileSize] :
llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
- if (unpackedType.isDynamicDim(pos))
+ if (unpackedTypeAfterFold.isDynamicDim(pos))
return false;
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return false;
if (ShapedType::isDynamic(tileSize))
return false;
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
- unpackedType.getDimSize(pos);
+ unpackedTypeAfterFold.getDimSize(pos);
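// Do not fold the op if the slice drops a full tile; the folded unpack
// op would then require artificial padding.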
if (paddingSize >= tileSize)
return false;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 73e157b42235a..0b5eb449ffa2f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,9 +220,11 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
- // Folding is not allowed if it introduces artificial padding. It is not
- // safe to fold the ops if any dynamic dimension or tile size is present,
- // because we cannot infer the padding size.
+ // Folding is not allowed if it were to introduce artificial padding.
+ // Folding is also disabled in the case of dynamic dimensions and/or tile
+ // sizes - that is because it would be impossible to compute the padding
+ // size and hence to establish whether "artificial" padding would be
+ // created.
RankedTensorType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index cd14bc3d1948b..21342a4c0d45d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1889,7 +1889,7 @@ func.func @fold_cast_unpack_dynamic_tile_size(
// linalg.unpack + tensor.extract_slice
//===----------------------------------------------------------------------===//
-func.func @fold_extract_slice_into_unpack(
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(
%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
@@ -1901,7 +1901,7 @@ func.func @fold_extract_slice_into_unpack(
[0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
return %extracted_slice : tensor<28x28x10xf32>
}
-// CHECK-LABEL: func @fold_extract_slice_into_unpack
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
@@ -1913,6 +1913,51 @@ func.func @fold_extract_slice_into_unpack(
// -----
+// Dimension 1 can be any size in [17, 32] without introducing artificial
+// padding, because CeilDiv(size, 16) == 2 for exactly those sizes.
+
+
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(
+ %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x17x15xf32> {
+ %unpack = linalg.unpack %src
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+ return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME: [0, 0, 0] [28, 17, 15] [1, 1, 1]
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+// Dimension 1 can be any size in [17, 32] without introducing artificial
+// padding, because CeilDiv(size, 16) == 2 for exactly those sizes.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(
+ %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x16x15xf32> {
+ %unpack = linalg.unpack %src
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+ return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
+// -----
+
func.func @no_fold_extract_slice_into_unpack_dynamic(
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index ac1074bb7e47a..2a4542d3d97a7 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -69,39 +69,37 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
// -----
-func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
+func.func @fold_pad_pack(%src: tensor<9x16xf32>) -> tensor<2x1x8x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
- } : tensor<16649x16xf32> to tensor<16656x16xf32>
- %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ } : tensor<9x16xf32> to tensor<16x16xf32>
+ %empty = tensor.empty() : tensor<2x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
- return %pack : tensor<2082x1x8x32xf32>
+ : tensor<16x16xf32> -> tensor<2x1x8x32xf32>
+ return %pack : tensor<2x1x8x32xf32>
}
-// CHECK-LABEL: func.func @pad_pack
+// CHECK-LABEL: func.func @fold_pad_pack
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK: %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2x1x8x32xf32>
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC]]
// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
// -----
-func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<9x16xf32>) -> tensor<3x1x8x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[8, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
- %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ } : tensor<9x16xf32> to tensor<17x16xf32>
+ %empty = tensor.empty() : tensor<3x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
- return %pack : tensor<2082x1x8x32xf32>
+ : tensor<17x16xf32> -> tensor<3x1x8x32xf32>
+ return %pack : tensor<3x1x8x32xf32>
}
// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
// CHECK: tensor.pad
@@ -110,7 +108,6 @@ func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> ten
// -----
func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
@@ -128,7 +125,6 @@ func.func @nofold_pad_pack_with_nofold_attribute(%src: tensor<16649x16xf32>) ->
// -----
func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
- %c0 = arith.constant 0 : index
%cst0 = arith.constant 0.000000e+00 : f32
%cst1 = arith.constant 1.000000e+00 : f32
%padded = tensor.pad %src low[0, 0] high[15, 0] {
>From d832b43f567979415b428e75daa02f2f68c107e1 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 24 Jul 2025 10:10:21 -0700
Subject: [PATCH 4/6] trim includes
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 62c04bb2ee1ab..954aa59164c5d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,7 +10,6 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
>From 3aac143dc0450ee74a859ede46cf4ab74099cfdb Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 24 Jul 2025 10:56:15 -0700
Subject: [PATCH 5/6] add includes back because of canFoldSlice method
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 954aa59164c5d..62c04bb2ee1ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
>From bab30f2ed82823bfc945eaf07dd2a7288974a47c Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 24 Jul 2025 11:10:23 -0700
Subject: [PATCH 6/6] Add tests and delete unused arguments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/test/Dialect/Linalg/canonicalize.mlir | 15 +---
.../Tensor/fold-into-pack-and-unpack.mlir | 86 ++++++++++++++++---
2 files changed, 76 insertions(+), 25 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 21342a4c0d45d..9cbb56e4de884 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1889,9 +1889,7 @@ func.func @fold_cast_unpack_dynamic_tile_size(
// linalg.unpack + tensor.extract_slice
//===----------------------------------------------------------------------===//
-func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(
- %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
-) -> tensor<28x28x10xf32> {
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
@@ -1904,7 +1902,6 @@ func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(
// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
@@ -1915,10 +1912,7 @@ func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(
// Dimension 1 can be any size in [17, 32] without introducing artificial
// padding, because CeilDiv(size, 16) == 2 for exactly those sizes.
-
-func.func @fold_extract_slice_into_unpack_slicing_dim_1(
- %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
-) -> tensor<28x17x15xf32> {
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x17x15xf32> {
%unpack = linalg.unpack %src
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
@@ -1930,7 +1924,6 @@ func.func @fold_extract_slice_into_unpack_slicing_dim_1(
// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME: [0, 0, 0] [28, 17, 15] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
@@ -1941,9 +1934,7 @@ func.func @fold_extract_slice_into_unpack_slicing_dim_1(
// Dimension 1 can be any size in [17, 32] without introducing artificial
// padding, because CeilDiv(size, 16) == 2 for exactly those sizes.
-func.func @no_fold_extract_slice_into_unpack_artificial_padding(
- %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
-) -> tensor<28x16x15xf32> {
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>) -> tensor<28x16x15xf32> {
%unpack = linalg.unpack %src
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 2a4542d3d97a7..9da2dea0bbd3c 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,19 +1,79 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
-func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
- %empty = tensor.empty() : tensor<16656x16xf32>
- %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
- : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
- %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
- return %1 : tensor<16649x16xf32>
-}
-// CHECK-LABEL: func @fold_unpack_slice(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
-// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-// CHECK-SAME: into %[[INIT]]
-// CHECK: return %[[UNPACK]]
+func.func @fold_extract_slice_into_unpack_slicing_trailing_dim(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x28x10xf32> {
+ %empty = tensor.empty() : tensor<28x28x15xf32>
+ %unpack = linalg.unpack %arg0
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+ return %extracted_slice : tensor<28x28x10xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_trailing_dim
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x28x10xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+// Dimension 1 can be any size in [17, 32] without introducing artificial
+// padding, because CeilDiv(size, 16) == 2 for exactly those sizes.
+
+func.func @fold_extract_slice_into_unpack_slicing_dim_1(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x17x15xf32> {
+ %empty = tensor.empty() : tensor<28x28x15xf32>
+ %unpack = linalg.unpack %arg0
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 17, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x17x15xf32>
+ return %extracted_slice : tensor<28x17x15xf32>
+}
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_slicing_dim_1(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.empty() : tensor<28x17x15xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+// Dimension 1 can be any size in [17, 32] without introducing artificial
+// padding, because CeilDiv(size, 16) == 2 for exactly those sizes.
+
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(%arg0 : tensor<28x2x1x16x16xf32>) -> tensor<28x16x15xf32> {
+ %empty = tensor.empty() : tensor<28x28x15xf32>
+ %unpack = linalg.unpack %arg0
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %empty : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 16, 15] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x16x15xf32>
+ return %extracted_slice : tensor<28x16x15xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
// -----