[Mlir-commits] [mlir] [MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (PR #160246)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 23 00:06:18 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: None (maxbartel)
<details>
<summary>Changes</summary>
The original code appears to have assumed that, in the shape of the tensor.empty op created before the transpose is applied, the tiled dimensions are always the trailing dimensions. However, linalg.pack allows any dimension to be tiled (selected via `inner_dims_pos`).
The easiest fix I found is to pre-fill the SmallVector with 1s (one entry per source dimension) and then, while computing the tile sizes, overwrite the entry of each tiled dimension with its tile size. That way there is no need to add another for loop.
---
Full diff: https://github.com/llvm/llvm-project/pull/160246.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+7-10)
- (modified) mlir/test/Dialect/Linalg/decompose-pack.mlir (+19)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b253eea35..69cbc7048f646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1151,11 +1151,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- int64_t numTiles = destRank - srcRank;
- // 1. Extract the inner tile sizes.
- // Where possible, values are replaced with constant attributes (to match the
- // behaviour of `getPackOpSourceOrPaddedSource`).
+ // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
+ // before transposing. Where possible, values are replaced with constant
+ // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
+ SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
@@ -1165,6 +1165,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
tileSizes.push_back(tileSize);
+ transShapeForEmptyOp[i] = tileSize;
}
}
@@ -1194,18 +1195,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
- // 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
- oneIdxAttr);
- transShapeForEmptyOp.append(tileSizes);
-
+ // 2.2 Transpose the tensor.empty shapes.
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
srcPermForTranspose);
Value empty =
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
packOp.getSourceType().getElementType());
- // 2.2 Create linalg.transpose
+ // 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 17e6c29754f9d..15521d415b8a7 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+ return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
+// CHECK-SAME: permutation = [1, 2, 0, 3]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
+// CHECK: return %[[INSERT]]
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/160246
More information about the Mlir-commits
mailing list