[Mlir-commits] [mlir] [mlir][linalg] Restrict linalg.pack to not have artificial padding. (PR #149624)

Han-Chung Wang llvmlistbot at llvm.org
Wed Jul 23 10:17:09 PDT 2025


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/149624

>From 254880973fb98866332f8747f745c1ad003439c7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 15:04:19 -0700
Subject: [PATCH 1/4] [mlir][linalg] Take artificial padding into account for
 pack/unpack folding.

The revision folds the tensor.pad/extract_slice op into
linalg.pack/unpack ops only when it is safe to fold.

According to the doc, it is not valid to have artificial padding.

```
- The following relationship for the tiled dimensions holds:
shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i].
```

The documentation improvement and verifier update will be done in a
separate PR (i.e., https://github.com/llvm/llvm-project/pull/149624).
The revision is a step towards it.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/include/mlir/Dialect/Linalg/IR/Linalg.h  |  6 ++
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  4 ++
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 55 +++++++++++++++--
 .../Transforms/PackAndUnpackPatterns.cpp      | 38 ++++++++----
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 37 ++++++++----
 .../Tensor/fold-into-pack-and-unpack.mlir     | 60 ++++++++++++++-----
 6 files changed, 158 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index bb0ac414bcc2d..6941939c8db5a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_IR_LINALG_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -89,6 +90,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
                                int64_t dim);
 
+/// Returns the packed type's leading (outer) dims, one per unpacked dim, in
+/// unpacked-dim order (`outer_dims_perm` undone). OpTy: PackOp or UnPackOp.
+template <typename OpTy>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index c384e8b638382..fa572024ff72b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
                                    ArrayRef<int64_t> innerPermutation,
                                    ArrayRef<int64_t> outerPermutation);
 
+    /// Returns true if it is statically known that `sliceOp` (a slice of this
+    /// op's result) can be folded into this op, e.g. without dropping a tile.
+    bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
     /// Check if this UnPackOp is like a simple unpad operation.
     /// In other words, this operation:
     /// 1. drops useless dimensions (dimension of size 1), and
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3aa6ac3ea0918..046a73c90f110 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+  constexpr bool isPack = std::is_same<OpTy, PackOp>::value;
+  // Pack produces the packed value (dest); unpack consumes it (source).
+  RankedTensorType packedType =
+      isPack ? packOrUnPack.getDestType() : packOrUnPack.getSourceType();
+  RankedTensorType unpackedType =
+      isPack ? packOrUnPack.getSourceType() : packOrUnPack.getDestType();
+  // The outer dims come first in the packed type, one per unpacked dim.
+  SmallVector<int64_t> outerShape(
+      packedType.getShape().take_front(unpackedType.getRank()));
+  ArrayRef<int64_t> outerPerm = packOrUnPack.getOuterDimsPerm();
+  if (!outerPerm.empty())
+    applyPermutationToVector(outerShape, invertPermutationVector(outerPerm));
+  return outerShape;
+}
+template SmallVector<int64_t>
+    getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+    getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
 // updated mixed-tile-sizes attribute. A tile size is updated only
 // when:
@@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   if (unPackOp->hasOneUse()) {
     auto extractSliceUser =
         dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-    if (extractSliceUser &&
-        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-        extractSliceUser.getSourceType().getRank() ==
-            extractSliceUser.getResultType().getRank()) {
+    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
       OpBuilder::InsertionGuard g(rewriter);
       rewriter.setInsertionPoint(unPackOp);
       auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5494,6 +5513,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   return failure();
 }
 
+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+  // Only zero-offset, unit-stride, non-rank-reducing slices are supported.
+  if (sliceOp.getResultType().getRank() != getDestType().getRank() ||
+      !areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+    return false;
+  RankedTensorType unpackedType = sliceOp.getResultType();
+  SmallVector<int64_t> outerShapeWithoutTranspose =
+      getPackedOuterShapeWithoutTransposition(*this);
+  // Tiled dims: sizes must be static and the slice may not drop a full tile.
+  for (auto [pos, tileSize] :
+       llvm::zip_equal(getInnerDimsPos(), getStaticInnerTiles())) {
+    int64_t outer = outerShapeWithoutTranspose[pos];
+    if (unpackedType.isDynamicDim(pos) || ShapedType::isDynamic(tileSize) ||
+        ShapedType::isDynamic(outer) ||
+        outer * tileSize - unpackedType.getDimSize(pos) >= tileSize)
+      return false;
+  }
+  // Untiled dims map 1:1 to packed outer dims and must be kept in full.
+  for (auto [pos, outer] : llvm::enumerate(outerShapeWithoutTranspose))
+    if (!llvm::is_contained(getInnerDimsPos(), static_cast<int64_t>(pos)) &&
+        (ShapedType::isDynamic(outer) || outer != unpackedType.getDimSize(pos)))
+      return false;
+  return true;
+}
+
 bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 2afa2f9b71c2a..73e157b42235a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -220,6 +220,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
       if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
         return failure();
 
+    // Folding must not introduce artificial padding. Bail out on dynamic sizes
+    // or non-constant pads, since the padding amount cannot be inferred then.
+    // Note that `highPads` is indexed by source dim, not by tile position.
+    RankedTensorType unpackedType = packOp.getSourceType();
+    SmallVector<int64_t> outerShapeWithoutTranspose =
+        getPackedOuterShapeWithoutTransposition(packOp);
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+    // linalg.pack only pads tiled dims, so untiled dims must not be padded.
+    for (int64_t dim = 0, e = highPads.size(); dim < e; ++dim)
+      if (!llvm::is_contained(packOp.getInnerDimsPos(), dim) &&
+          !isConstantIntValue(highPads[dim], 0))
+        return failure();
+    for (auto [pos, tileSize] : llvm::zip_equal(packOp.getInnerDimsPos(),
+                                                packOp.getStaticInnerTiles())) {
+      std::optional<int64_t> cstHigh = getConstantIntValue(highPads[pos]);
+      int64_t outer = outerShapeWithoutTranspose[pos];
+      if (unpackedType.isDynamicDim(pos) || ShapedType::isDynamic(tileSize) ||
+          ShapedType::isDynamic(outer) || !cstHigh)
+        return failure();
+      int64_t paddingSize = outer * tileSize - unpackedType.getDimSize(pos);
+      // Do not fold the op if it requires artificial padding.
+      if (paddingSize + *cstHigh >= tileSize)
+        return failure();
+    }
+
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
         packOp.getMixedTiles(), constantPaddingValue,
@@ -251,17 +276,8 @@ struct FoldUnpackWithExtractSliceOp
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();
 
-    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "rank-reduced folding is not supported");
-    }
-
-    // Check all offsets are zeros, and all strides are ones.
-    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "expects offsets to be 0s and strides to be 1s");
-    }
+    if (!unpackOp.canFoldSliceOp(sliceOp))
+      return rewriter.notifyMatchFailure(sliceOp, "unsafe to fold into unpack");
 
     // Create a new empty output tensor.
     Type elementType = unpackOp.getDestType().getElementType();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7284ae7dbd673..cd14bc3d1948b 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1890,30 +1890,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 //===----------------------------------------------------------------------===//
 
 func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src
       outer_dims_perm = [0, 1, 2]
       inner_dims_pos = [1, 2]
       inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
   %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
 }
-
 // CHECK-LABEL: func @fold_extract_slice_into_unpack
-//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
-//  CHECK-SAME:     %[[SIZE:.+]]: index
+//  CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+//  CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+//  CHECK-SAME:     %[[SIZE:[a-zA-Z0-9]+]]
 //       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+//  CHECK-SAME:     [0, 0, 0] [28, 28, 10] [1, 1, 1]
 //       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 //  CHECK-SAME:       into %[[DEST_SLICE]]
 //       CHECK:   return %[[UNPACK]]
 
 // -----
 
+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_rank_reducing(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28xf32> {
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 16efa73f87a2a..4a97d1df25f15 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,22 +1,32 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack  %s | FileCheck %s
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control  %s | FileCheck %s --check-prefix=CONTROL
 
-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
+  %empty = tensor.empty() : tensor<16656x16xf32>
+  %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
+  %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
+  return %1 : tensor<16649x16xf32>
+}
+// CHECK-LABEL: func @fold_unpack_slice(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+//      CHECK:    %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
+//      CHECK:    %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+// CHECK-SAME:        into %[[INIT]]
+//      CHECK:    return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
   %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
       : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
   %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-//      CHECK: func @fold_unpack_slice(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-//      CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME:       into %[[INIT]]
-//      CHECK:   return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+//       CHECK:   linalg.unpack
+//       CHECK:   tensor.extract_slice
 
 // -----
 
@@ -59,13 +69,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :
 
 // -----
 
-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
   %empty = tensor.empty() : tensor<2082x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +91,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
 
 // -----
 
-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
   } : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +103,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
   return %pack : tensor<2082x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
+// CHECK:         tensor.pad
+// CHECK:         linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack(
 // CHECK:         tensor.pad
 // CHECK:         linalg.pack
 

>From 6de929abea2ad4fac10605e333050b64eedbb904 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 15:15:08 -0700
Subject: [PATCH 2/4] [mlir][linalg] Restrict linalg.pack to not have extra
 padding sizes.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 14 +++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 28 ++-----
 .../Transforms/PackAndUnpackPatterns.cpp      |  1 +
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 25 +++---
 .../Linalg/data-layout-propagation.mlir       | 18 -----
 mlir/test/Dialect/Linalg/invalid.mlir         | 17 ++--
 .../Dialect/Linalg/transform-lower-pack.mlir  | 16 ++--
 .../tile-and-fuse-consumer.mlir               | 81 -------------------
 8 files changed, 50 insertions(+), 150 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index fa572024ff72b..f5dd7ae2c84f3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -150,9 +150,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
 
     `padding_value` specifies a padding value at the boundary on non-perfectly
     divisible dimensions. Padding is optional:
-    - If absent, it is UB if the tile does not perfectly divide the dimension.
+    - If absent, it assumes the tile perfectly divides the dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
-      tile complete.
+      tile complete. Note that it is not allowed to have artificial padding that
+      is not strictly required by linalg.pack.
 
     Example:
     ```mlir
@@ -167,6 +168,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     //
     // Note: Only tiled dimensions can be padded.
     ```
+
+    Invalid example that has artificial padding:
+    ```mlir
+    %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
+        inner_tiles = [8] into %dest
+        : tensor<9xf32> -> tensor<3x8xf32>
+    //                             \
+    //            expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
+    ```
   }];
   let arguments = (ins AnyRankedTensor:$source,
                        AnyRankedTensor:$dest,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 046a73c90f110..2d473daa1e1f4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
@@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
   });
 }
 
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
-                          ArrayRef<int64_t> limitShape) {
-  assert(
-      sourceShape.size() == limitShape.size() &&
-      "expected source shape rank, and limit of the shape to have same rank");
-  return llvm::all_of(
-      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
-        int64_t sourceExtent = std::get<0>(it);
-        int64_t limit = std::get<1>(it);
-        return ShapedType::isDynamic(sourceExtent) ||
-               ShapedType::isDynamic(limit) || sourceExtent <= limit;
-      });
-}
-
 template <typename OpTy>
 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // represents full tiles.
   RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
-  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
-    return op->emitError("the shape of output is not large enough to hold the "
-                         "packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
-  }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
+  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+                                   packedType.getShape()))) {
+    return op->emitError("the shape of unpacked domain value is not large "
+                         "enough to hold the packed data. Expected at least ")
+           << expectedPackedType << ", got " << packedType;
+  }
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 73e157b42235a..be89ecae180bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index cd14bc3d1948b..686e6d7049f81 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1387,42 +1387,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
 // CHECK-LABEL: @recursive_effect
 //       CHECK: linalg.map
 
+// -----
+
 //===----------------------------------------------------------------------===//
 // linalg.pack
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @fold_pack_constant_splat
 //   CHECK-NOT: linalg.pack
-//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
   %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
 //   CHECK-NOT: linalg.pack
-//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 1.000000e-01 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
 //       CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
 //       CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 0.0 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
@@ -1430,8 +1431,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32
     outer_dims_perm = [1, 0]
     inner_dims_pos = [0, 1]
     inner_tiles = [8, 32]
-    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 6fc8d9f152f4e..cc26fa48abf4b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
 
 // -----
 
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
-  %empty = tensor.empty() : tensor<8x4x16x8xf32>
-  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-  return %pack : tensor<8x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
-// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME:      output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK:         %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
-// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME:      : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK:         return %[[PACK]] : tensor<8x4x16x8xf32>
-
-// -----
-
 func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
   %6 = tensor.empty(%dim) : tensor<?x256xf32>
   %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index da1dfc7b6a624..4299a15026f91 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
 }
 
 // -----
+
 func.func @pack_mismatch_inner_tile_size_and_output_shape(
   %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
   // expected-error at +1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1827,24 +1828,24 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
 // The outer dims in the output tensor are incorrectly/unexpectedly transposed.
 // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
-  // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+  // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
   return %0 : tensor<4x16x32x16xf32>
 }
 
 // -----
 
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
-  // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
-  return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+  // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+  return %0 : tensor<8x7x16x32xf32>
 }
 
 // -----
 
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
-  // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+  // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 81fd7a8a947d7..9e7681d1a1b7d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
-    -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+    -> tensor<265x12x16x1xf32> {
   //      CHECK: tensor.pad {{.*}} low[0, 0]
-  //      CHECK:   : tensor<4225x12xf32> to tensor<4240x16xf32>
+  //      CHECK:   : tensor<4225x12xf32> to tensor<4240x12xf32>
   //      CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
-  // CHECK-SAME:   : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+  // CHECK-SAME:   : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
   //      CHECK: linalg.transpose
-  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
-  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
   // CHECK-SAME:   permutation = [0, 2, 1, 3]
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.pack %src
     padding_value(%cst : f32)
     inner_dims_pos = [0, 1]
     inner_tiles = [16, 1] into %dest
-    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
-  return %0 : tensor<265x16x16x1xf32>
+    : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+  return %0 : tensor<265x12x16x1xf32>
 }
 
 module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index cdbca7228ded3..e48e5c6c308be 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -646,87 +646,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// It is valid to fuse the pack if the dimension is not tiled even when it needs
-// extra padding.
-
-func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
-  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
-    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
-    scf.forall.in_parallel {
-      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
-    }
-  }
-  %1 = tensor.empty() : tensor<33x2x3x16xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
-  return %pack : tensor<33x2x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-//      CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
-//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
-// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:      %[[ELEM:.*]] = linalg.exp
-// CHECK-SAME:        ins(%[[ELEM_SRC]]
-// CHECK-SAME:        outs(%[[ELEM_DEST]]
-//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-//      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
-// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME:        into %[[TILED_PACK_DEST]]
-//      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-
-// -----
-
-// If the dimension is tiled and it needs extra padding, do not fuse the pack
-// op.
-
-func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
-  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
-    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
-    scf.forall.in_parallel {
-      // expected-error @below {{failed to fuse consumer of slice}}
-      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
-    }
-  }
-  %1 = tensor.empty() : tensor<23x32x3x16xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
-  return %pack : tensor<23x32x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-
-// -----
-
 // Imperfect tiling is not supported in pack op consumer fusion.
 
 #map = affine_map<(d0) -> (d0 * 5)>

>From 8e501ac7e95cce3a2cb38546600ead0e9cf75b4d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 11:16:56 -0700
Subject: [PATCH 3/4] Address comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |  5 +++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  6 ++---
 mlir/test/Dialect/Linalg/invalid.mlir         | 26 ++++++++++++++++---
 3 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f5dd7ae2c84f3..f8543fb726e02 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -153,7 +153,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     - If absent, it assumes the tile perfectly divides the dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
       tile complete. Note that it is not allowed to have artificial padding that
-      is not strictly required by linalg.pack.
+      is not strictly required by linalg.pack (i.e., padding past what is needed
+      to complete the last tile along each packed dimension).. It is UB if extra
+      padding is requested for dynamic cases. For static cases, they are caught
+      by the verifier.
 
     Example:
     ```mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d473daa1e1f4..7b7a67c303ced 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4699,9 +4699,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   }
   if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
                                    packedType.getShape()))) {
-    return op->emitError("the shape of unpacked domain value is not large "
-                         "enough to hold the packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
+    return op->emitError("expected ")
+           << expectedPackedType << " for the unpacked domain value, got "
+           << packedType;
   }
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4299a15026f91..595dc96a30fbc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1825,10 +1825,21 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
 
 // -----
 
+func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+  %cst = arith.constant 0.0 : f32
+  // expected-error at +1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
+      inner_tiles = [8] into %output
+      : tensor<9xf32> -> tensor<3x8xf32>
+  return %0 : tensor<3x8xf32>
+}
+
+// -----
+
 // The outer dims in the output tensor are incorrectly/unexpectedly transposed.
 // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
-  // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+  // expected-error at +1 {{expected 'tensor<16x4x32x16xf32>' for the unpacked domain value, got 'tensor<4x16x32x16xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
   return %0 : tensor<4x16x32x16xf32>
 }
@@ -1836,15 +1847,24 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
 // -----
 
 func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
-  // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
+  // expected-error at +1 {{expected 'tensor<8x8x16x32xf32>' for the unpacked domain value, got 'tensor<8x7x16x32xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
   return %0 : tensor<8x7x16x32xf32>
 }
 
 // -----
 
+func.func @unpack_with_slicing_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+  // expected-error at +1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
+      : tensor<3x8xf32> -> tensor<9xf32>
+  return %0 : tensor<9xf32>
+}
+
+// -----
+
 func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
-  // expected-error at +1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
+  // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the unpacked domain value, got 'tensor<8x8x4x32xf32>'}}
   %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }

>From 5a1ae044eb95a873b671f58b0d9df00f1611e298 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 22 Jul 2025 15:35:29 -0700
Subject: [PATCH 4/4] Address banach-space's comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td        | 10 ++++++----
 mlir/test/Dialect/Linalg/invalid.mlir                  |  6 +++---
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f8543fb726e02..66fbd2a84370c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -150,13 +150,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
 
     `padding_value` specifies a padding value at the boundary on non-perfectly
     divisible dimensions. Padding is optional:
-    - If absent, it assumes the tile perfectly divides the dimension.
+    - If absent, it is assumed that for all inner tiles,
+      `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner
+      tiles divide perfectly the corresponding outer dimension in the result
+      tensor.
     - If present, it will pad along high dimensions (high-padding) to make the
       tile complete. Note that it is not allowed to have artificial padding that
       is not strictly required by linalg.pack (i.e., padding past what is needed
-      to complete the last tile along each packed dimension).. It is UB if extra
-      padding is requested for dynamic cases. For static cases, they are caught
-      by the verifier.
+      to complete the last tile along each packed dimension). It is UB if extra
+      padding is requested.
 
     Example:
     ```mlir
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 595dc96a30fbc..dd720805b31dc 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1846,7 +1846,7 @@ func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tenso
 
 // -----
 
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
   // expected-error at +1 {{expected 'tensor<8x8x16x32xf32>' for the unpacked domain value, got 'tensor<8x7x16x32xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
   return %0 : tensor<8x7x16x32xf32>
@@ -1854,7 +1854,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf
 
 // -----
 
-func.func @unpack_with_slicing_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+func.func @unpack_with_dropping_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
   // expected-error at +1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
   %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
       : tensor<3x8xf32> -> tensor<9xf32>
@@ -1863,7 +1863,7 @@ func.func @unpack_with_slicing_tiles(%input: tensor<3x8xf32>, %output: tensor<9x
 
 // -----
 
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
   // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the unpacked domain value, got 'tensor<8x8x4x32xf32>'}}
   %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>



More information about the Mlir-commits mailing list