[Mlir-commits] [mlir] [MLIR][Linalg] Fix empty tensor assumptions for linalg.pack decomposition (PR #160246)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 29 07:36:03 PDT 2025
https://github.com/maxbartel updated https://github.com/llvm/llvm-project/pull/160246
>From 9f031bccdd56b02726055d7744de6be649fbf3fc Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Tue, 23 Sep 2025 08:57:22 +0200
Subject: [PATCH 1/2] (linalg.pack): fix empty tensor assumptions
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 17 +++++++----------
mlir/test/Dialect/Linalg/decompose-pack.mlir | 19 +++++++++++++++++++
2 files changed, 26 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e9a8b253eea35..69cbc7048f646 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1151,11 +1151,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- int64_t numTiles = destRank - srcRank;
- // 1. Extract the inner tile sizes.
- // Where possible, values are replaced with constant attributes (to match the
- // behaviour of `getPackOpSourceOrPaddedSource`).
+ // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
+ // before transposing. Where possible, values are replaced with constant
+ // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
+ SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
@@ -1165,6 +1165,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
tileSizes.push_back(tileSize);
+ transShapeForEmptyOp[i] = tileSize;
}
}
@@ -1194,18 +1195,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
- // 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
- oneIdxAttr);
- transShapeForEmptyOp.append(tileSizes);
-
+ // 2.2 Transpose the tensor.empty shapes.
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
srcPermForTranspose);
Value empty =
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
packOp.getSourceType().getElementType());
- // 2.2 Create linalg.transpose
+ // 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 17e6c29754f9d..15521d415b8a7 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -274,3 +274,22 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_zero_pos_tile_size(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
+ return %pack : tensor<1x1x1x1x8x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_zero_pos_tile_size
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
+// CHECK-SAME: permutation = [1, 2, 0, 3]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
+// CHECK: return %[[INSERT]]
\ No newline at end of file
>From eae9f9946e30104dbb4e7a86b96ed3a735700929 Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Mon, 29 Sep 2025 16:35:46 +0200
Subject: [PATCH 2/2] (linalg.pack): simplify outer dims patterns after review
---
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 3 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 59 +++++++++----------
2 files changed, 30 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f36b41ccf6745..5006d815a798a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -57,7 +57,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// tile factors.
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
- /// Return the tile sizes as OpFoldResult.
+ /// Return the tile sizes as OpFoldResult. A tile size defined by a constant
+ /// op is returned as that op's Value, not as a constant Attribute.
SmallVector<OpFoldResult> getMixedTiles();
/// Return the tile sizes as `int64_t`. If a tile size is dynamic
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 69cbc7048f646..60219335d6a1c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1146,38 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
- Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
- DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
- packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ int64_t numberOfTiles = innerDimsPos.size();
- // 1. Extract the inner tile sizes and the shapes for the tensor.empty op
- // before transposing. Where possible, values are replaced with constant
- // attributes (to match the behaviour of `getPackOpSourceOrPaddedSource`).
- SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> tileSizes;
- for (auto i : llvm::seq<unsigned>(0, srcRank)) {
- if (dimAndTileMapping.count(i)) {
- // Rather than taking the tile size as is, extact the actual constant
- // value Attribute where possible, e.g.:
- // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
- auto [_, tileSize] =
- getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
- tileSizes.push_back(tileSize);
- transShapeForEmptyOp[i] = tileSize;
- }
- }
+ // 1. Get the input that is going to be packed. If the input requires padding,
+ // add a padding operation and return that as the input.
+ Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
// 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // 1. All outer dims are 1 - the corresponding transposition order doesn't
+ // - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.
+
+ // 2.1 Get the permutation for linalg.transpose
SmallVector<int64_t> srcPermForTranspose;
- ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1186,21 +1173,32 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// rank of the source tensor. For example if we have a source tensor with
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
- if (llvm::is_contained(innerDimPos, i))
+ if (llvm::is_contained(innerDimsPos, i))
continue;
srcPermForTranspose.push_back(i);
}
- srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
+ srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
+
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape
+ SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
+ oneIdxAttr);
+ shapeForEmptyOp.append(packOp.getMixedTiles());
+
+ // getMixedTiles() may contain Values defined by constant ops rather than the
+ // constant attributes. Fold each such Value to its constant attribute.
+ llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
+ [&](OpFoldResult ofr) {
+ if (auto val = llvm::dyn_cast<Value>(ofr))
+ return getAsOpFoldResult(val);
+ return ofr;
+ });
LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
+ LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
- // 2.2 Transpose the tensor.empty shapes.
- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
- srcPermForTranspose);
- Value empty =
- tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
- packOp.getSourceType().getElementType());
+ Value empty = tensor::EmptyOp::create(
+ rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
// 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
@@ -1211,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
- oneIdxAttr);
+ SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;
for (auto tileSize : packOp.getMixedTiles()) {
More information about the Mlir-commits
mailing list