[Mlir-commits] [mlir] [mlir][linalg] Improve UnPackOp tiling for perfect-tiling cases. (PR #189470)

Han-Chung Wang llvmlistbot at llvm.org
Mon Mar 30 13:54:09 PDT 2026


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/189470

>From d318ffd333d5031ea820777b7675278f6151114d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 30 Mar 2026 13:24:14 -0700
Subject: [PATCH 1/2] [mlir][linalg] Improve UnPackOp tiling for perfect-tiling
 cases.

If a dimension is not tiled at all (its tile size equals the full iteration
domain size), it should be treated as a perfect-tiling dimension even if its
size is not a multiple of the inner tile size. No intermediate tensor is
needed in this case.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 12 +++++-
 .../Linalg/transform-op-tile-pack-unpack.mlir | 39 +++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 558ebdebd65c5..9d6ec439c29b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -1389,6 +1389,13 @@ struct UnPackOpTiling
     // The perfect tiling case indicates that the tiling sizes are multiple of
     // inner_tile_size. In this context, no extra data is needed when
     // representing the tiled unpack op.
+    //
+    // Also, if a dimension is not tiled (tile size equals the iteration domain
+    // size), the original unpack already handles any truncation from
+    // (outerTiles * innerTileSize) down to the dest dim size. No intermediate
+    // expansion is needed for that dimension.
+    SmallVector<Range> iterationDomain =
+        cast<TilingInterface>(op).getIterationDomain(b);
     bool isPerfectTilingCase = true;
     Attribute oneAttr = b.getIndexAttr(1);
     SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
@@ -1397,7 +1404,10 @@ struct UnPackOpTiling
     for (auto dim : llvm::seq<int64_t>(0, destRank)) {
       UnpackTileDimInfo info =
           getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
-      if (!info.isAlignedToInnerTileSize)
+      // If a dimension is tiled and it is not aligned to inner tile size, it is
+      // not a perfect tiling case.
+      if (!info.isAlignedToInnerTileSize &&
+          !isEqualConstantIntOrValue(sizes[dim], iterationDomain[dim].size))
         isPerfectTilingCase = false;
       sliceSrcIndices.push_back(info.sourceOffset);
       sliceSrcSizes.push_back(info.sourceSize);
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
index 456a5ea453963..402b818938f67 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
@@ -477,6 +477,45 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:    into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
 // CHECK:       scf.yield %[[RES]]
 
+// When only some dimensions are tiled (tile_sizes [1, 0, 0]) and the untiled
+// dimensions are not aligned with inner tile sizes (123 % 128 != 0,
+// 1023 % 128 != 0), the tiled unpack should use extract_slice on the dest
+// (not tensor.empty) because the untiled dimensions pass through as-is and
+// the original unpack handles truncation.
+
+// CHECK-LABEL: func.func @unpack_with_untiled_non_aligned_dims
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:       scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:    iter_args(%[[ITER:[a-zA-Z0-9]+]] = %[[DEST]])
+// CHECK:         %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]][%[[IV]], 0, 0, 0, 0] [1, 1, 8, 128, 128]
+// CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[ITER]][%[[IV]], 0, 0] [1, 123, 1023] [1, 1, 1]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC_SLICE]]
+// CHECK-SAME:      outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [128, 128]
+// CHECK-SAME:      into %[[DEST_SLICE]]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[UNPACK]] into %[[ITER]][%[[IV]], 0, 0] [1, 123, 1023] [1, 1, 1]
+// CHECK:         scf.yield %[[INSERT]]
+func.func @unpack_with_untiled_non_aligned_dims(
+    %src: tensor<1828x1x8x128x128xf32>,
+    %dest: tensor<1828x123x1023xf32>) -> tensor<1828x123x1023xf32> {
+  %0 = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [128, 128]
+      into %dest : tensor<1828x1x8x128x128xf32> -> tensor<1828x123x1023xf32>
+  return %0 : tensor<1828x123x1023xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:1 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @perfect_NPQK_to_NKPQk(%source: tensor<1x6x6x8xf32>, %dest: tensor<1x4x6x6x2xf32>) -> tensor<1x4x6x6x2xf32> {
   %0 = linalg.pack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x6x6x8xf32> -> tensor<1x4x6x6x2xf32>
   return %0 : tensor<1x4x6x6x2xf32>

>From 9eee9a8a0be3a4d91d10628b5e7413250b0baa5e Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 30 Mar 2026 13:53:42 -0700
Subject: [PATCH 2/2] Put the new test in the right scope.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Linalg/transform-op-tile-pack-unpack.mlir | 42 +++++++++----------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
index 402b818938f67..4dd5db8e9c88c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
@@ -456,27 +456,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK: func.func @perfect_NPQK_to_NKPQk
-// CHECK-SAME:  %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
-// CHECK-SAME:  %{{.+}}: tensor<1x4x6x6x2xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
-// CHECK: %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
-// CHECK:   %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
-// CHECK:     %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
-// CHECK:       %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
-// CHECK:       %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
-// CHECK:       %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
-// CHECK:       %[[PACK:.+]] = linalg.pack
-// CHECK-SAME:    %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
-// CHECK-SAME:    into %[[SLICE_DEST]]
-// CHECK:       %[[RES:.+]] = tensor.insert_slice %[[PACK]]
-// CHECK-SAME:    into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
-// CHECK:       scf.yield %[[RES]]
-
 // When only some dimensions are tiled (tile_sizes [1, 0, 0]) and the untiled
 // dimensions are not aligned with inner tile sizes (123 % 128 != 0,
 // 1023 % 128 != 0), the tiled unpack should use extract_slice on the dest
@@ -516,6 +495,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: func.func @perfect_NPQK_to_NKPQk
+// CHECK-SAME:  %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
+// CHECK-SAME:  %{{.+}}: tensor<1x4x6x6x2xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK: %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK:   %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK:     %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK:       %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
+// CHECK:       %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
+// CHECK:       %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK:       %[[PACK:.+]] = linalg.pack
+// CHECK-SAME:    %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+// CHECK-SAME:    into %[[SLICE_DEST]]
+// CHECK:       %[[RES:.+]] = tensor.insert_slice %[[PACK]]
+// CHECK-SAME:    into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK:       scf.yield %[[RES]]
+
 func.func @perfect_NPQK_to_NKPQk(%source: tensor<1x6x6x8xf32>, %dest: tensor<1x4x6x6x2xf32>) -> tensor<1x4x6x6x2xf32> {
   %0 = linalg.pack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x6x6x8xf32> -> tensor<1x4x6x6x2xf32>
   return %0 : tensor<1x4x6x6x2xf32>



More information about the Mlir-commits mailing list