[Mlir-commits] [mlir] [MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (PR #160246)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 23 00:05:45 PDT 2025
https://github.com/maxbartel created https://github.com/llvm/llvm-project/pull/160246
The original code seemed to assume that the tiled dimensions used to build the shape of the tensor.empty op (before applying the transpose) were always the trailing dimensions of the source. However, linalg.pack allows you to tile any source dimension via `inner_dims_pos` (e.g. `inner_dims_pos = [0, 3]`).
The simplest fix I found is to prefill the SmallVector with 1s and then, while extracting the tile sizes, overwrite the entry of each tiled dimension with its tile size. This yields the shape in source-dimension order, which is then permuted with the same permutation used for the transpose. That way there is no need for an additional loop.
>From 9f031bccdd56b02726055d7744de6be649fbf3fc Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Tue, 23 Sep 2025 08:57:22 +0200
Subject: [PATCH] (linalg.pack): fix empty tensor assumptions
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 17 +++++++----------
mlir/test/Dialect/Linalg/decompose-pack.mlir | 19 +++++++++++++++++++
2 files changed, 26 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b253eea35..69cbc7048f646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1151,11 +1151,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- int64_t numTiles = destRank - srcRank;
- // 1. Extract the inner tile sizes.
- // Where possible, values are replaced with constant attributes (to match the
- // behaviour of `getPackOpSourceOrPaddedSource`).
+ // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
+ // before transposing. Where possible, values are replaced with constant
+ // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
+ SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
@@ -1165,6 +1165,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
tileSizes.push_back(tileSize);
+ transShapeForEmptyOp[i] = tileSize;
}
}
@@ -1194,18 +1195,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
- // 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
- oneIdxAttr);
- transShapeForEmptyOp.append(tileSizes);
-
+ // 2.2 Transpose the tensor.empty shapes.
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
srcPermForTranspose);
Value empty =
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
packOp.getSourceType().getElementType());
- // 2.2 Create linalg.transpose
+ // 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 17e6c29754f9d..15521d415b8a7 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+ return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
+// CHECK-SAME: permutation = [1, 2, 0, 3]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
+// CHECK: return %[[INSERT]]
\ No newline at end of file
More information about the Mlir-commits
mailing list