[Mlir-commits] [mlir] Reapply "[mlir][linalg] Restrict linalg.pack to not have artificial padding." (#150675) (PR #150680)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jul 25 11:52:31 PDT 2025


https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/150680

This reapplies the patch, together with a shape fix in https://github.com/llvm/llvm-project/commit/1db4c6b27500e686fad9e55bbbe7c7c68b246b7e

This revision restricts the `linalg.pack` op to disallow artificial padding semantics. E.g., the example below is valid without the change and becomes invalid with it.

```mlir
func.func @foo(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<100x8xf32>
  %pack = linalg.pack %src
    padding_value(%cst : f32)
    inner_dims_pos = [0]
    inner_tiles = [8] into %dest
    : tensor<9xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}
```

IMO, using pack ops with artificial padding sizes is a misuse, because the intent of the pack op is to relayout the source based on target intrinsics, etc. The expected output shape here is `tensor<2x8xf32>`, since `CeilDiv(9, 8) = 2`. If extra padding is needed, it can be expressed as an explicit pad op followed by the pack op, as in the sketch below.
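For the example above, a minimal sketch of that rewrite (the `high[791]` amount is just what is needed to reach 100 full tiles of size 8; the exact padding split is up to the producer of the IR):

```mlir
func.func @foo(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Pad the source to 800 = 100 * 8 elements up front, so the pack itself
  // needs no padding at all: 800 / 8 = 100 full tiles.
  %padded = tensor.pad %src low[0] high[791] {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<9xf32> to tensor<800xf32>
  %dest = tensor.empty() : tensor<100x8xf32>
  %pack = linalg.pack %padded
    inner_dims_pos = [0]
    inner_tiles = [8] into %dest
    : tensor<800xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}
```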

This also makes consumer tiling much easier, because consumer fusion does not support artificial padding sizes. Supporting them without ad-hoc patterns is very hard: the tile sizes are defined on the source, so a destination tile that consists purely of artificial padding has no corresponding source region, i.e., there is no core_id/thread_id responsible for writing the padding values to that tile.

People may then ask why the pad tiling implementation works. The answer is that it creates an `if-else` branch to handle the case. In my experience, this is a real struggle for transformations, because most of the time only one side of the branch is needed, given that tile sizes are usually greater than padding sizes. The implementation is nevertheless conservatively correct in terms of semantics. Given that the `pack` op was introduced to serve relayout needs better, the restriction makes sense to me. A rough sketch of the branchy IR follows this paragraph.
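A minimal hand-written sketch of that branch structure (this is not verbatim output of the pad tiling implementation; the function name, arguments, and tile decomposition are illustrative assumptions):

```mlir
// One tile of a tiled pad: either the tile lies entirely in the padding
// region, or it overlaps the source and only the remainder is padded.
func.func @tiled_pad_sketch(%src: tensor<9xf32>, %is_pad_only_tile: i1,
                            %off: index, %src_sz: index, %pad_sz: index)
    -> tensor<8xf32> {
  %cst = arith.constant 0.0 : f32
  %tile = scf.if %is_pad_only_tile -> (tensor<8xf32>) {
    // The whole tile is padding: materialize the padding value directly.
    %full = tensor.generate {
    ^bb0(%i: index):
      tensor.yield %cst : f32
    } : tensor<8xf32>
    scf.yield %full : tensor<8xf32>
  } else {
    // The tile overlaps the source: slice the source and pad the rest.
    %slice = tensor.extract_slice %src[%off] [%src_sz] [1]
        : tensor<9xf32> to tensor<?xf32>
    %padded = tensor.pad %slice low[0] high[%pad_sz] {
    ^bb0(%i: index):
      tensor.yield %cst : f32
    } : tensor<?xf32> to tensor<8xf32>
    scf.yield %padded : tensor<8xf32>
  }
  return %tile : tensor<8xf32>
}
```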

Removed tests:
- `no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate` from `data-layout-propagation.mlir`: it becomes a duplicate of `bubble_up_pack_non_expanded_dims_through_expand` once the shape is fixed.
- `fuse_pack_consumer_with_untiled_extra_padding` from `tile-and-fuse-consumer.mlir`: it was created to exercise artificial padding in the consumer fusion implementation.

The remaining lit test changes just fix the shapes.
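For reference, the corrected shapes follow directly from the CeilDiv rule; e.g., for `@fold_padding_value_pack_constant_splat` in `canonicalize.mlir`:

```
source: tensor<63x127xf32>, outer_dims_perm = [1, 0], inner_tiles = [8, 32]
outer dim 0: CeilDiv(127, 32) = 4   (source dim 1, due to the perm)
outer dim 1: CeilDiv(63, 8)   = 8
packed type: tensor<4x8x8x32xf32>   (was tensor<8x16x8x32xf32>)
```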

From f6f89c000ea7511d9992e2394159b477e2229873 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 25 Jul 2025 11:43:41 -0700
Subject: [PATCH 1/2] Reapply "[mlir][linalg] Restrict linalg.pack to not have
 artificial padding." (#150675)

This reverts commit 0844812b2e9d7f5ab005223443791c9287bcf5a2.
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    | 25 +++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 28 ++-----
 .../Transforms/PackAndUnpackPatterns.cpp      |  1 +
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 25 +++---
 .../Linalg/data-layout-propagation.mlir       | 18 -----
 mlir/test/Dialect/Linalg/invalid.mlir         | 37 +++++++--
 .../Dialect/Linalg/transform-lower-pack.mlir  | 16 ++--
 .../tile-and-fuse-consumer.mlir               | 81 -------------------
 8 files changed, 80 insertions(+), 151 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index fa572024ff72b..f36b41ccf6745 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -106,7 +106,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
      result tensor in the order in which they appear, i.e.
      `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
      - The following relationship for the tiled dimensions holds:
-     `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
+     `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] ⌈/⌉ inner_tiles[i]`,
+     where ⌈/⌉ indicates CeilDiv.
+
 
     Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
     `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
@@ -150,9 +152,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
 
     `padding_value` specifies a padding value at the boundary on non-perfectly
     divisible dimensions. Padding is optional:
-    - If absent, it is UB if the tile does not perfectly divide the dimension.
+    - If absent, it is assumed that for all inner tiles,
+      `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. every
+      inner tile evenly divides the corresponding source dimension. It is UB
+      if a tile does not perfectly divide its dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
-      tile complete.
+      tile complete. Note that it is not allowed to have artificial padding that
+      is not strictly required by linalg.pack (i.e., padding past what is needed
+      to complete the last tile along each packed dimension). It is UB if extra
+      padding is requested.
+    These requirements cannot be verified statically in the presence of
+    dynamic shapes, so violating them is UB.
 
     Example:
     ```mlir
@@ -167,6 +177,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     //
     // Note: Only tiled dimensions can be padded.
     ```
+
+    Invalid example that has artificial padding:
+    ```mlir
+    %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
+        inner_tiles = [8] into %dest
+        : tensor<9xf32> -> tensor<3x8xf32>
+    //                             \
+    //            expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
+    ```
   }];
   let arguments = (ins AnyRankedTensor:$source,
                        AnyRankedTensor:$dest,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4fee81aa2ef67..e25d063fce97b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
@@ -4624,22 +4625,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
   });
 }
 
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
-                          ArrayRef<int64_t> limitShape) {
-  assert(
-      sourceShape.size() == limitShape.size() &&
-      "expected source shape rank, and limit of the shape to have same rank");
-  return llvm::all_of(
-      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
-        int64_t sourceExtent = std::get<0>(it);
-        int64_t limit = std::get<1>(it);
-        return ShapedType::isDynamic(sourceExtent) ||
-               ShapedType::isDynamic(limit) || sourceExtent <= limit;
-      });
-}
-
 template <typename OpTy>
 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4698,11 +4683,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // represents full tiles.
   RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
-  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
-    return op->emitError("the shape of output is not large enough to hold the "
-                         "packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
-  }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -4719,6 +4699,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
+  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+                                   packedType.getShape()))) {
+    return op->emitError("expected ")
+           << expectedPackedType << " for the packed domain value, got "
+           << packedType;
+  }
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index a45a4e314e511..595d2625ee27c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9cbb56e4de884..39a7b1b1a2775 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1387,42 +1387,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
 // CHECK-LABEL: @recursive_effect
 //       CHECK: linalg.map
 
+// -----
+
 //===----------------------------------------------------------------------===//
 // linalg.pack
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @fold_pack_constant_splat
 //   CHECK-NOT: linalg.pack
-//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
   %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
 //   CHECK-NOT: linalg.pack
-//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 1.000000e-01 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
 //       CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
 //       CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 0.0 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
@@ -1430,8 +1431,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32
     outer_dims_perm = [1, 0]
     inner_dims_pos = [0, 1]
     inner_tiles = [8, 32]
-    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 6fc8d9f152f4e..cc26fa48abf4b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
 
 // -----
 
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
-  %empty = tensor.empty() : tensor<8x4x16x8xf32>
-  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-  return %pack : tensor<8x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
-// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME:      output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK:         %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
-// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME:      : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK:         return %[[PACK]] : tensor<8x4x16x8xf32>
-
-// -----
-
 func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
   %6 = tensor.empty(%dim) : tensor<?x256xf32>
   %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index da1dfc7b6a624..40bf4d19d6b91 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
 }
 
 // -----
+
 func.func @pack_mismatch_inner_tile_size_and_output_shape(
   %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
   // expected-error at +1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
 
 // -----
 
+func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+  %cst = arith.constant 0.0 : f32
+  // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
+      inner_tiles = [8] into %output
+      : tensor<9xf32> -> tensor<3x8xf32>
+  return %0 : tensor<3x8xf32>
+}
+
+// -----
+
 // The outer dims in the output tensor are incorrectly/unexpectedly transposed.
 // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
-  // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+  // expected-error at +1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
   return %0 : tensor<4x16x32x16xf32>
 }
 
 // -----
 
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
-  // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
-  return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+  // expected-error at +1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+  return %0 : tensor<8x7x16x32xf32>
+}
+
+// -----
+
+func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+  // expected-error at +1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
+      : tensor<3x8xf32> -> tensor<9xf32>
+  return %0 : tensor<9xf32>
 }
 
 // -----
 
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
-  // expected-error at +1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+  // expected-error at +1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 81fd7a8a947d7..9e7681d1a1b7d 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
-    -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+    -> tensor<265x12x16x1xf32> {
   //      CHECK: tensor.pad {{.*}} low[0, 0]
-  //      CHECK:   : tensor<4225x12xf32> to tensor<4240x16xf32>
+  //      CHECK:   : tensor<4225x12xf32> to tensor<4240x12xf32>
   //      CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
-  // CHECK-SAME:   : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+  // CHECK-SAME:   : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
   //      CHECK: linalg.transpose
-  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
-  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+  // CHECK-SAME:   outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
   // CHECK-SAME:   permutation = [0, 2, 1, 3]
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.pack %src
     padding_value(%cst : f32)
     inner_dims_pos = [0, 1]
     inner_tiles = [16, 1] into %dest
-    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
-  return %0 : tensor<265x16x16x1xf32>
+    : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+  return %0 : tensor<265x12x16x1xf32>
 }
 
 module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index cdbca7228ded3..e48e5c6c308be 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -646,87 +646,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// It is valid to fuse the pack if the dimension is not tiled even when it needs
-// extra padding.
-
-func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
-  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
-    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
-    scf.forall.in_parallel {
-      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
-    }
-  }
-  %1 = tensor.empty() : tensor<33x2x3x16xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
-  return %pack : tensor<33x2x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-//      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
-//      CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
-//  CHECK-DAG:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
-//  CHECK-DAG:   %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
-//      CHECK:   %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
-// CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
-//      CHECK:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:      %[[ELEM:.*]] = linalg.exp
-// CHECK-SAME:        ins(%[[ELEM_SRC]]
-// CHECK-SAME:        outs(%[[ELEM_DEST]]
-//  CHECK-DAG:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
-//  CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-//      CHECK:      %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
-// CHECK-SAME:        padding_value(%[[PAD_VAL]] : f32)
-// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [3, 16]
-// CHECK-SAME:        into %[[TILED_PACK_DEST]]
-//      CHECK:      scf.forall.in_parallel {
-//      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
-//      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
-
-// -----
-
-// If the dimension is tiled and it needs extra padding, do not fuse the pack
-// op.
-
-func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
-  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
-    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
-    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
-    scf.forall.in_parallel {
-      // expected-error @below {{failed to fuse consumer of slice}}
-      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
-    }
-  }
-  %1 = tensor.empty() : tensor<23x32x3x16xf32>
-  %cst = arith.constant 0.000000e+00 : f32
-  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
-  return %pack : tensor<23x32x3x16xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-
-// -----
-
 // Imperfect tiling is not supported in pack op consumer fusion.
 
 #map = affine_map<(d0) -> (d0 * 5)>

From 1db4c6b27500e686fad9e55bbbe7c7c68b246b7e Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 25 Jul 2025 11:50:54 -0700
Subject: [PATCH 2/2] Fix cpu pack-unpack-mmt4d shape

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir        | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
index 05e678227de32..a7bb039b04102 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir
@@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso
 func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
   %zero = arith.constant 0 : i32
 
-  %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
   %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
-  %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
+  %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32>
 
   // Pack matrices
-  %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
+  %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
   %B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
-  %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
+  %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32>
 
   // MMT4D
-  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32>
+  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32>
 
   // Unpack output
   %C_out_empty = tensor.empty() : tensor<7x13xi32>
-  %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32>
+  %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32>
 
   return %C_out_unpack : tensor<7x13xi32>
 }
