[Mlir-commits] [mlir] [mlir][linalg] Improve UnPackOp tiling for perfect-tiling cases. (PR #189470)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Mar 30 13:54:09 PDT 2026
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/189470
>From d318ffd333d5031ea820777b7675278f6151114d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 30 Mar 2026 13:24:14 -0700
Subject: [PATCH 1/2] [mlir][linalg] Improve UnPackOp tiling for perfect-tiling
cases.
If a dimension is not tiled at all, it should be treated as a
perfect-tiling dimension even if its size is not a multiple of the inner
tile size. No intermediate tensor is needed in this case.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 12 +++++-
.../Linalg/transform-op-tile-pack-unpack.mlir | 39 +++++++++++++++++++
2 files changed, 50 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 558ebdebd65c5..9d6ec439c29b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -1389,6 +1389,13 @@ struct UnPackOpTiling
// The perfect tiling case indicates that the tiling sizes are multiple of
// inner_tile_size. In this context, no extra data is needed when
// representing the tiled unpack op.
+ //
+ // Also, if a dimension is not tiled (tile size equals the iteration domain
+ // size), the original unpack already handles any truncation from
+ // (outerTiles * innerTileSize) down to the dest dim size. No intermediate
+ // expansion is needed for that dimension.
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(op).getIterationDomain(b);
bool isPerfectTilingCase = true;
Attribute oneAttr = b.getIndexAttr(1);
SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
@@ -1397,7 +1404,10 @@ struct UnPackOpTiling
for (auto dim : llvm::seq<int64_t>(0, destRank)) {
UnpackTileDimInfo info =
getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
- if (!info.isAlignedToInnerTileSize)
+ // If a dimension is tiled and it is not aligned to inner tile size, it is
+ // not a perfect tiling case.
+ if (!info.isAlignedToInnerTileSize &&
+ !isEqualConstantIntOrValue(sizes[dim], iterationDomain[dim].size))
isPerfectTilingCase = false;
sliceSrcIndices.push_back(info.sourceOffset);
sliceSrcSizes.push_back(info.sourceSize);
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
index 456a5ea453963..402b818938f67 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
@@ -477,6 +477,45 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
// CHECK: scf.yield %[[RES]]
+// When only some dimensions are tiled (tile_sizes [1, 0, 0]) and the untiled
+// dimensions are not aligned with inner tile sizes (123 % 128 != 0,
+// 1023 % 128 != 0), the tiled unpack should use extract_slice on the dest
+// (not tensor.empty) because the untiled dimensions pass through as-is and
+// the original unpack handles truncation.
+
+// CHECK-LABEL: func.func @unpack_with_untiled_non_aligned_dims
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITER:[a-zA-Z0-9]+]] = %[[DEST]])
+// CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]][%[[IV]], 0, 0, 0, 0] [1, 1, 8, 128, 128]
+// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[ITER]][%[[IV]], 0, 0] [1, 123, 1023] [1, 1, 1]
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC_SLICE]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [128, 128]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[UNPACK]] into %[[ITER]][%[[IV]], 0, 0] [1, 123, 1023] [1, 1, 1]
+// CHECK: scf.yield %[[INSERT]]
+func.func @unpack_with_untiled_non_aligned_dims(
+ %src: tensor<1828x1x8x128x128xf32>,
+ %dest: tensor<1828x123x1023xf32>) -> tensor<1828x123x1023xf32> {
+ %0 = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [128, 128]
+ into %dest : tensor<1828x1x8x128x128xf32> -> tensor<1828x123x1023xf32>
+ return %0 : tensor<1828x123x1023xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loops:1 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
func.func @perfect_NPQK_to_NKPQk(%source: tensor<1x6x6x8xf32>, %dest: tensor<1x4x6x6x2xf32>) -> tensor<1x4x6x6x2xf32> {
%0 = linalg.pack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x6x6x8xf32> -> tensor<1x4x6x6x2xf32>
return %0 : tensor<1x4x6x6x2xf32>
>From 9eee9a8a0be3a4d91d10628b5e7413250b0baa5e Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 30 Mar 2026 13:53:42 -0700
Subject: [PATCH 2/2] put the new test in the right scope....
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Linalg/transform-op-tile-pack-unpack.mlir | 42 +++++++++----------
1 file changed, 21 insertions(+), 21 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
index 402b818938f67..4dd5db8e9c88c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir
@@ -456,27 +456,6 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
-// CHECK: func.func @perfect_NPQK_to_NKPQk
-// CHECK-SAME: %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
-// CHECK-SAME: %{{.+}}: tensor<1x4x6x6x2xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
-// CHECK: %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
-// CHECK: %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
-// CHECK: %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
-// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
-// CHECK: %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
-// CHECK: %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
-// CHECK: %[[PACK:.+]] = linalg.pack
-// CHECK-SAME: %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
-// CHECK-SAME: into %[[SLICE_DEST]]
-// CHECK: %[[RES:.+]] = tensor.insert_slice %[[PACK]]
-// CHECK-SAME: into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
-// CHECK: scf.yield %[[RES]]
-
// When only some dimensions are tiled (tile_sizes [1, 0, 0]) and the untiled
// dimensions are not aligned with inner tile sizes (123 % 128 != 0,
// 1023 % 128 != 0), the tiled unpack should use extract_slice on the dest
@@ -516,6 +495,27 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: func.func @perfect_NPQK_to_NKPQk
+// CHECK-SAME: %[[SOURCE:.+]]: tensor<1x6x6x8xf32>,
+// CHECK-SAME: %{{.+}}: tensor<1x4x6x6x2xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK: %{{.+}} = scf.for %[[ARG2:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[ARG4:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK: %{{.+}} = scf.for %[[ARG6:.+]] = %[[C0]] to %[[C6]] step %[[C1]]
+// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP1]](%[[ARG2]])
+// CHECK: %[[SLICE_SOURCE:.+]] = tensor.extract_slice %[[SOURCE]][0, %[[ARG4]], %[[ARG6]], %[[APPLY]]]
+// CHECK: %[[SLICE_DEST:.+]] = tensor.extract_slice %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK: %[[PACK:.+]] = linalg.pack
+// CHECK-SAME: %[[SLICE_SOURCE]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+// CHECK-SAME: into %[[SLICE_DEST]]
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[PACK]]
+// CHECK-SAME: into %{{.+}}[0, %[[ARG2]], %[[ARG4]], %[[ARG6]], 0]
+// CHECK: scf.yield %[[RES]]
+
func.func @perfect_NPQK_to_NKPQk(%source: tensor<1x6x6x8xf32>, %dest: tensor<1x4x6x6x2xf32>) -> tensor<1x4x6x6x2xf32> {
%0 = linalg.pack %source outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2] into %dest : tensor<1x6x6x8xf32> -> tensor<1x4x6x6x2xf32>
return %0 : tensor<1x4x6x6x2xf32>
More information about the Mlir-commits
mailing list