[Mlir-commits] [mlir] [mlir][linalg] Restrict linalg.pack to not have artificial padding. (PR #149624)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Jul 22 11:46:17 PDT 2025
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/149624
>From 1ba3c59a1efd8149560cc2c1167e0582eedac389 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 18 Jul 2025 17:46:08 -0700
Subject: [PATCH 1/8] [mlir][linalg] Restrict linalg.pack to not have extra
padding sizes.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 17 ++--
.../Transforms/PackAndUnpackPatterns.cpp | 29 +++++++
mlir/test/Dialect/Linalg/canonicalize.mlir | 25 +++---
.../Linalg/data-layout-propagation.mlir | 22 ++---
mlir/test/Dialect/Linalg/invalid.mlir | 15 ++--
.../Dialect/Linalg/transform-lower-pack.mlir | 16 ++--
.../Tensor/fold-into-pack-and-unpack.mlir | 30 +++++--
.../tile-and-fuse-consumer.mlir | 81 -------------------
8 files changed, 102 insertions(+), 133 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..5f72cac21fce7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4601,8 +4601,8 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
/// Returns true if the dimension of `sourceShape` is smaller than the dimension
/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
+static bool isCompatibleShape(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> limitShape) {
assert(
sourceShape.size() == limitShape.size() &&
"expected source shape rank, and limit of the shape to have same rank");
@@ -4611,7 +4611,7 @@ static bool areAllInBound(ArrayRef<int64_t> sourceShape,
int64_t sourceExtent = std::get<0>(it);
int64_t limit = std::get<1>(it);
return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
+ ShapedType::isDynamic(limit) || sourceExtent == limit;
});
}
@@ -4673,11 +4673,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4694,6 +4689,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (!isCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape())) {
+ return op->emitError("the shape of output is not large enough to hold the "
+ "packed data. Expected at least ")
+ << expectedPackedType << ", got " << packedType;
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 2afa2f9b71c2a..02fdd01ed548b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -220,6 +221,34 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
+ RankedTensorType srcType = packOp.getSourceType();
+ RankedTensorType destType = packOp.getDestType();
+ SmallVector<int64_t> outerShapeWithoutTranspose(
+ destType.getShape().take_front(srcType.getRank()));
+ if (!packOp.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ outerShapeWithoutTranspose,
+ invertPermutationVector(packOp.getOuterDimsPerm()));
+ }
+ for (auto [pos, tileSize, high] :
+ llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+ padOp.getMixedHighPad())) {
+ if (srcType.isDynamicDim(pos))
+ return failure();
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return failure();
+ if (ShapedType::isDynamic(tileSize))
+ return failure();
+ std::optional<int64_t> cstHigh = getConstantIntValue(high);
+ if (!cstHigh)
+ return failure();
+ int64_t paddingSize =
+ outerShapeWithoutTranspose[pos] * tileSize - srcType.getDimSize(pos);
+ // Do not fold the ops if it requires extra padding sizes.
+ if (paddingSize + cstHigh.value() >= tileSize)
+ return failure();
+ }
+
rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7dbd673..dfe3bfd4a967a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1387,42 +1387,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
// CHECK-LABEL: @recursive_effect
// CHECK: linalg.map
+// -----
+
//===----------------------------------------------------------------------===//
// linalg.pack
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @fold_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
// CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 1.000000e-01 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
padding_value(%pad : f32)
outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
- inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
-
// -----
// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
// CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
%pad = arith.constant 0.0 : f32
%cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
%0 = linalg.pack %cst
@@ -1430,8 +1431,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32
outer_dims_perm = [1, 0]
inner_dims_pos = [0, 1]
inner_tiles = [8, 32]
- into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
- return %0 : tensor<8x16x8x32xf32>
+ into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+ return %0 : tensor<4x8x8x32xf32>
}
// -----
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 6fc8d9f152f4e..ae87fffd1af02 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,21 +1295,21 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
- %empty = tensor.empty() : tensor<8x4x16x8xf32>
+func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(%arg0: tensor<32x64xf32>) -> tensor<4x4x16x8xf32> {
+ %empty = tensor.empty() : tensor<4x4x16x8xf32>
%expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
- %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
- return %pack : tensor<8x4x16x8xf32>
+ %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<4x4x16x8xf32>
+ return %pack : tensor<4x4x16x8xf32>
}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-LABEL: func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x64x8xf32>
+// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<4x64x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
+// CHECK-SAME: output_shape [4, 4, 16, 8] : tensor<4x64x8xf32> into tensor<4x4x16x8xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x4x16x8xf32>
// -----
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index da1dfc7b6a624..83611a217f652 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
}
// -----
+
func.func @pack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
// expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1834,17 +1835,17 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
// -----
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
- return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+ // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+ %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+ return %0 : tensor<8x7x16x32xf32>
}
// -----
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
- %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+ // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 81fd7a8a947d7..9e7681d1a1b7d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
- -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+ -> tensor<265x12x16x1xf32> {
// CHECK: tensor.pad {{.*}} low[0, 0]
- // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
+ // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
- // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+ // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
// CHECK: linalg.transpose
- // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
- // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+ // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+ // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
// CHECK-SAME: permutation = [0, 2, 1, 3]
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.pack %src
padding_value(%cst : f32)
inner_dims_pos = [0, 1]
inner_tiles = [16, 1] into %dest
- : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
- return %0 : tensor<265x16x16x1xf32>
+ : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+ return %0 : tensor<265x12x16x1xf32>
}
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73f87a2a..eb62de13ebc94 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -59,13 +59,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
// -----
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
- } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ } : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +81,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
// -----
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_extra_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
- %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+ %padded = tensor.pad %src low[0, 0] high[15, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +93,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_extra_padding(
+// CHECK: tensor.pad
+// CHECK: linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16649x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack(
// CHECK: tensor.pad
// CHECK: linalg.pack
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index cdbca7228ded3..e48e5c6c308be 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -646,87 +646,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// It is valid to fuse the pack if the dimension is not tiled even when it needs
-// extra padding.
-
-func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
- %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
- %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
- }
- }
- %1 = tensor.empty() : tensor<33x2x3x16xf32>
- %cst = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
- return %pack : tensor<33x2x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
-// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
-// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: %[[ELEM:.*]] = linalg.exp
-// CHECK-SAME: ins(%[[ELEM_SRC]]
-// CHECK-SAME: outs(%[[ELEM_DEST]]
-// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
-// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME: into %[[TILED_PACK_DEST]]
-// CHECK: scf.forall.in_parallel {
-// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-
-// -----
-
-// If the dimension is tiled and it needs extra padding, do not fuse the pack
-// op.
-
-func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
- %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
- %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
- %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
- scf.forall.in_parallel {
- // expected-error @below {{failed to fuse consumer of slice}}
- tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
- }
- }
- %1 = tensor.empty() : tensor<23x32x3x16xf32>
- %cst = arith.constant 0.000000e+00 : f32
- %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
- return %pack : tensor<23x32x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-
-// -----
-
// Imperfect tiling is not supported in pack op consumer fusion.
#map = affine_map<(d0) -> (d0 * 5)>
>From 1e75325716a7c1a14e51969ffcec6904b2ec7b30 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 21 Jul 2025 14:02:27 -0700
Subject: [PATCH 2/8] Delete dup test
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/data-layout-propagation.mlir | 18 ------------------
1 file changed, 18 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index ae87fffd1af02..cc26fa48abf4b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
// -----
-func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(%arg0: tensor<32x64xf32>) -> tensor<4x4x16x8xf32> {
- %empty = tensor.empty() : tensor<4x4x16x8xf32>
- %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
- %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<4x4x16x8xf32>
- return %pack : tensor<4x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @bubble_up_pack_extending_dimension_through_expand_can_reassociate(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x64x8xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[ARG0]]
-// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x64xf32> -> tensor<4x64x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
-// CHECK-SAME: output_shape [4, 4, 16, 8] : tensor<4x64x8xf32> into tensor<4x4x16x8xf32>
-// CHECK: return %[[EXPANDED]] : tensor<4x4x16x8xf32>
-
-// -----
-
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
>From aafbbfce79a4d0bc312640beaa1d79c0727d0593 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 21 Jul 2025 14:59:24 -0700
Subject: [PATCH 3/8] Update docs and verifiers
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 14 +++++++++--
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 25 ++++---------------
mlir/test/Dialect/Linalg/invalid.mlir | 6 ++---
3 files changed, 20 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c384e8b638382..c1a96d5eb1dbe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -150,9 +150,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
`padding_value` specifies a padding value at the boundary on non-perfectly
divisible dimensions. Padding is optional:
- - If absent, it is UB if the tile does not perfectly divide the dimension.
+ - If absent, it assumes the tile perfectly divides the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
- tile complete.
+ tile complete. Note that it is not allowed to have artificial padding that
+ is not strictly required by linalg.pack.
Example:
```mlir
@@ -167,6 +168,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
//
// Note: Only tiled dimensions can be padded.
```
+
+ Invalid example that has artificial padding:
+ ```mlir
+ %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
+ inner_tiles = [8] into %dest
+ : tensor<9xf32> -> tensor<3x8xf32>
+ // \
+ // expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
+ ```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5f72cac21fce7..248cefc5d707f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -4599,22 +4600,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool isCompatibleShape(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent == limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4689,10 +4674,10 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
- if (!isCompatibleShape(expectedPackedType.getShape(),
- packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("the shape of unpacked domain value is not large "
+ "enough to hold the packed data. Expected at least ")
<< expectedPackedType << ", got " << packedType;
}
return success();
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 83611a217f652..4299a15026f91 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1828,7 +1828,7 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+ // expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
@@ -1836,7 +1836,7 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
// -----
func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+ // expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
return %0 : tensor<8x7x16x32xf32>
}
@@ -1844,7 +1844,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf
// -----
func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
- // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+ // expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
>From cf2adb655a8b27ec88787de5b587d802e3f5c708 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 21 Jul 2025 14:59:32 -0700
Subject: [PATCH 4/8] Update folding patterns.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Transforms/PackAndUnpackPatterns.cpp | 61 +++++++++++++++----
.../Tensor/fold-into-pack-and-unpack.mlir | 30 ++++++---
2 files changed, 68 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 02fdd01ed548b..31574f4a8791c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -195,6 +195,28 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
}
};
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy>
+static SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+ static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
+ "applies to only pack or unpack operations");
+ RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
+ SmallVector<int64_t> result(
+ packedType.getShape().take_front(unpackedType.getRank()));
+ if (!packOrUnPack.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+ }
+ return result;
+}
+
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -221,19 +243,14 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
- RankedTensorType srcType = packOp.getSourceType();
- RankedTensorType destType = packOp.getDestType();
- SmallVector<int64_t> outerShapeWithoutTranspose(
- destType.getShape().take_front(srcType.getRank()));
- if (!packOp.getOuterDimsPerm().empty()) {
- applyPermutationToVector(
- outerShapeWithoutTranspose,
- invertPermutationVector(packOp.getOuterDimsPerm()));
- }
+ // Folding is not allowed if it introduces artificial padding.
+ RankedTensorType unpackedType = packOp.getSourceType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(packOp);
for (auto [pos, tileSize, high] :
llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
padOp.getMixedHighPad())) {
- if (srcType.isDynamicDim(pos))
+ if (unpackedType.isDynamicDim(pos))
return failure();
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return failure();
@@ -242,9 +259,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
std::optional<int64_t> cstHigh = getConstantIntValue(high);
if (!cstHigh)
return failure();
- int64_t paddingSize =
- outerShapeWithoutTranspose[pos] * tileSize - srcType.getDimSize(pos);
- // Do not fold the ops if it requires extra padding sizes.
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ // Do not fold the op if it requires artificial padding.
if (paddingSize + cstHigh.value() >= tileSize)
return failure();
}
@@ -292,6 +309,24 @@ struct FoldUnpackWithExtractSliceOp
sliceOp, "expects offsets to be 0s and strides to be 1s");
}
+ // Folding is not allowed if any tile is dropped.
+ RankedTensorType unpackedType = sliceOp.getResultType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(unpackOp);
+ for (auto [pos, tileSize] : llvm::zip_equal(
+ unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles())) {
+ if (unpackedType.isDynamicDim(pos))
+ return failure();
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return failure();
+ if (ShapedType::isDynamic(tileSize))
+ return failure();
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ if (paddingSize >= tileSize)
+ return failure();
+ }
+
// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
Value output = rewriter.create<tensor::EmptyOp>(
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index eb62de13ebc94..5d3d668568058 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,32 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
+ %empty = tensor.empty() : tensor<16656x16xf32>
+ %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
+ %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
+ return %1 : tensor<16649x16xf32>
+}
+// CHECK-LABEL: func @fold_unpack_slice(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
: tensor<?x?x8x4xf32> -> tensor<?x?xf32>
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
-// CHECK: func @fold_unpack_slice(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME: into %[[INIT]]
-// CHECK: return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
// -----
>From 6c7fd89e135d42c2a2023f95a0ba317956d79f03 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 21 Jul 2025 15:00:19 -0700
Subject: [PATCH 5/8] update test name
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5d3d668568058..4a97d1df25f15 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -91,7 +91,7 @@ func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
// -----
-func.func @nofold_pad_pack_extra_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src low[0, 0] high[15, 0] {
@@ -103,7 +103,7 @@ func.func @nofold_pad_pack_extra_padding(%src: tensor<16641x16xf32>) -> tensor<2
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
-// CHECK-LABLE: func.func @nofold_pad_pack_extra_padding(
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
// CHECK: tensor.pad
// CHECK: linalg.pack
>From 0fdd02369b5eb19ad9f988d84f536b739c3d766d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 21 Jul 2025 15:06:47 -0700
Subject: [PATCH 6/8] format
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 31574f4a8791c..cac77e45e8575 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -203,8 +203,8 @@ getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getDestType()
- : packOrUnPack.getSourceType();
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
>From 6acc2e2a033b782f68b35e2127fafbc731122f00 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 11:16:56 -0700
Subject: [PATCH 7/8] Address comments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 5 ++++
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 5 +++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 29 +++++++++++++++++--
.../Transforms/PackAndUnpackPatterns.cpp | 26 ++---------------
mlir/test/Dialect/Linalg/invalid.mlir | 26 +++++++++++++++--
5 files changed, 61 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..cced80b03de4d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -89,6 +89,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim);
+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c1a96d5eb1dbe..73757ecb73b4b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -153,7 +153,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
- If absent, it assumes the tile perfectly divides the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
tile complete. Note that it is not allowed to have artificial padding that
- is not strictly required by linalg.pack.
+ is not strictly required by linalg.pack (i.e., padding past what is needed
+ to complete the last tile along each packed dimension). It is UB if extra
+ padding is requested for dynamic cases. For static cases, they are caught
+ by the verifier.
Example:
```mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 248cefc5d707f..f8ad1684aec32 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4491,6 +4491,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+ RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getDestType()
+ : packOrUnPack.getSourceType();
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
+ SmallVector<int64_t> result(
+ packedType.getShape().take_front(unpackedType.getRank()));
+ if (!packOrUnPack.getOuterDimsPerm().empty()) {
+ applyPermutationToVector(
+ result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+ }
+ return result;
+}
+template SmallVector<int64_t>
+ getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+ getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
@@ -4676,9 +4699,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
}
if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
packedType.getShape()))) {
- return op->emitError("the shape of unpacked domain value is not large "
- "enough to hold the packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
+ return op->emitError("expected ")
+ << expectedPackedType << " for the unpacked domain value, got "
+ << packedType;
}
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index cac77e45e8575..299971a14a6c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -195,28 +195,6 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
}
};
-/// Returns the outer shape in the packed domain before applying the
-/// transposition.
-template <typename OpTy>
-static SmallVector<int64_t>
-getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
- static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
- "applies to only pack or unpack operations");
- RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getDestType()
- : packOrUnPack.getSourceType();
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
- SmallVector<int64_t> result(
- packedType.getShape().take_front(unpackedType.getRank()));
- if (!packOrUnPack.getOuterDimsPerm().empty()) {
- applyPermutationToVector(
- result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
- }
- return result;
-}
-
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -243,7 +221,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
- // Folding is not allowed if it introduces artificial padding.
+ // Folding is not allowed if it introduces artificial padding. It is not
+ // safe to fold the ops if any dynamic dimension or tile size is present,
+ // because we cannot infer the padding size.
RankedTensorType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4299a15026f91..595dc96a30fbc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1825,10 +1825,21 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
// -----
+func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+ %cst = arith.constant 0.0 : f32
+ // expected-error at +1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
+ inner_tiles = [8] into %output
+ : tensor<9xf32> -> tensor<3x8xf32>
+ return %0 : tensor<3x8xf32>
+}
+
+// -----
+
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
- // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+ // expected-error at +1 {{expected 'tensor<16x4x32x16xf32>' for the unpacked domain value, got 'tensor<4x16x32x16xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
return %0 : tensor<4x16x32x16xf32>
}
@@ -1836,15 +1847,24 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
// -----
func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
- // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+ // expected-error at +1 {{expected 'tensor<8x8x16x32xf32>' for the unpacked domain value, got 'tensor<8x7x16x32xf32>'}}
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
return %0 : tensor<8x7x16x32xf32>
}
// -----
+func.func @unpack_with_slicing_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+ // expected-error at +1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
+ %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
+ : tensor<3x8xf32> -> tensor<9xf32>
+ return %0 : tensor<9xf32>
+}
+
+// -----
+
func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
- // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+ // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the unpacked domain value, got 'tensor<8x8x4x32xf32>'}}
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
>From 4f2247f739ef7ac891797957aa0dc7b69c63e8ec Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 11:45:56 -0700
Subject: [PATCH 8/8] Fix canonicalization pattern.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 1 +
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 4 ++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 32 +++++++++++++---
.../Transforms/PackAndUnpackPatterns.cpp | 31 +---------------
mlir/test/Dialect/Linalg/canonicalize.mlir | 37 ++++++++++++++-----
5 files changed, 61 insertions(+), 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index cced80b03de4d..6941939c8db5a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 73757ecb73b4b..f8543fb726e02 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -373,6 +373,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);
+ /// Returns true if it is statically known that the `sliceOp` result shape
+ /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+ bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
/// Check if this UnPackOp is like a simple unpad operation.
/// In other words, this operation:
/// 1. drops useless dimensions (dimension of size 1), and
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f8ad1684aec32..7b7a67c303ced 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5456,11 +5456,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
if (unPackOp->hasOneUse()) {
auto extractSliceUser =
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
- if (extractSliceUser &&
- areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
- areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
- extractSliceUser.getSourceType().getRank() ==
- extractSliceUser.getResultType().getRank()) {
+ if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5503,6 +5499,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
}
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+ // Rank-reduced folding is not supported.
+ if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+ return false;
+ if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+ !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+ return false;
+ RankedTensorType unpackedType = sliceOp.getResultType();
+ SmallVector<int64_t> outerShapeWithoutTranspose =
+ getPackedOuterShapeWithoutTransposition(*this);
+ for (auto [pos, tileSize] :
+ llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+ if (unpackedType.isDynamicDim(pos))
+ return false;
+ if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+ return false;
+ if (ShapedType::isDynamic(tileSize))
+ return false;
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+ unpackedType.getDimSize(pos);
+ if (paddingSize >= tileSize)
+ return false;
+ }
+ return true;
+}
+
bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 299971a14a6c3..be89ecae180bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -277,35 +277,8 @@ struct FoldUnpackWithExtractSliceOp
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();
- if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "rank-reduced folding is not supported");
- }
-
- // Check all offsets are zeros, and all strides are ones.
- if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
- !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
- return rewriter.notifyMatchFailure(
- sliceOp, "expects offsets to be 0s and strides to be 1s");
- }
-
- // Folding is not allowed if any tile is dropped.
- RankedTensorType unpackedType = sliceOp.getResultType();
- SmallVector<int64_t> outerShapeWithoutTranspose =
- getPackedOuterShapeWithoutTransposition(unpackOp);
- for (auto [pos, tileSize] : llvm::zip_equal(
- unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles())) {
- if (unpackedType.isDynamicDim(pos))
- return failure();
- if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
- return failure();
- if (ShapedType::isDynamic(tileSize))
- return failure();
- int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
- unpackedType.getDimSize(pos);
- if (paddingSize >= tileSize)
- return failure();
- }
+ if (!unpackOp.canFoldSliceOp(sliceOp))
+ return failure();
// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index dfe3bfd4a967a..686e6d7049f81 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1891,30 +1891,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
//===----------------------------------------------------------------------===//
func.func @fold_extract_slice_into_unpack(
- %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+ %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
- into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
- [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
- return %extracted_slice : tensor<28x28x?xf32>
+ [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+ return %extracted_slice : tensor<28x28x10xf32>
}
-
// CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]
// -----
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK: linalg.unpack
+// CHECK: tensor.extract_slice
+
+// -----
+
func.func @no_fold_extract_slice_into_unpack_rank_reducing(
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {
More information about the Mlir-commits
mailing list