[Mlir-commits] [mlir] [mlir][linalg] Extend DecomposeOuterUnitDimsPackOpPattern (linalg.pack) (PR #162666)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Oct 9 07:26:11 PDT 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/162666
Similarly to #152960, this PR fixes `getTiledOuterDims` for
`linalg.pack` by ensuring that the `outer_dims_perm` attribute is
properly taken into account.
This enables the main change in this PR: relaxing the constraints in
* `DecomposeOuterUnitDimsPackOpPattern`.
Specifically, the pattern is extended to allow non-unit untiled outer
dimensions. This makes it consistent with the corresponding pattern for
`linalg.unpack`:
* `DecomposeOuterUnitDimsUnPackOpPattern`.
One notable assumption remains: untiled outer dimensions are not
permuted. This was already the case but is now explicitly documented.
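To make the relaxation concrete, here is a minimal sketch of the rewrite the pattern now performs for a pack with non-unit untiled outer dims; it mirrors the new `@NCHW_to_NCHWc` test added in this patch rather than introducing anything new:

```mlir
// Input: a pack with a single unit tiled outer dim (the 1 in 2x1x16x8x32)
// and non-unit untiled outer dims (2, 16, 8), now accepted by the pattern.
%pack = linalg.pack %src
  inner_dims_pos = [1]
  inner_tiles = [32] into %dest
  : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>

// After decomposition: transpose the tiled dim to the back, then write the
// result into the packed tensor with a rank-expanding tensor.insert_slice.
%init = tensor.empty() : tensor<2x16x8x32xf32>
%tr = linalg.transpose ins(%src : tensor<2x32x16x8xf32>)
        outs(%init : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
%res = tensor.insert_slice %tr into %dest[0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
  : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
```

The tiled outer dim is still required to be 1; only the untiled outer dims may now be non-unit.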
>From f6cff334d97326dc78294616fe5fceeec1ba5713 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 14 Aug 2025 10:41:19 +0000
Subject: [PATCH] [mlir][linalg] Extend DecomposeOuterUnitDimsPackOpPattern
(linalg.pack)
Similarly to #152960, this PR fixes `getTiledOuterDims` for
`linalg.pack` by ensuring that the `outer_dims_perm` attribute is
properly taken into account.
This enables the main change in this PR: relaxing the constraints in
* `DecomposeOuterUnitDimsPackOpPattern`.
Specifically, the pattern is extended to allow non-unit untiled outer
dimensions. This makes it consistent with the corresponding pattern for
`linalg.unpack`:
* `DecomposeOuterUnitDimsUnPackOpPattern`.
One notable assumption remains: untiled outer dimensions are not
permuted. This was already the case but is now explicitly documented.
---
.../Dialect/Linalg/Transforms/Transforms.h | 6 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11 +++-
.../Dialect/Linalg/Transforms/Transforms.cpp | 58 +++++++++++++------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 1 +
mlir/test/Dialect/Linalg/decompose-pack.mlir | 36 ++++++++++++
5 files changed, 91 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7266687584b38..9dcf77ba742a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1651,7 +1651,10 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
-/// Requires that all the outer dims of the input linalg::PackOp are 1.
+/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
+/// Note that this constraint effectively means that only one tile is packed.
+///
+/// In addition, assumes that the un-tiled outer dims are not permuted.
///
/// Before:
/// ```
@@ -1691,6 +1694,7 @@ struct DecomposeOuterUnitDimsPackOpPattern
/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
///
/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
+/// Note that this constraint effectively means that only one tile is unpacked.
///
/// Before:
/// ```
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..cbc565b0c8cbd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
SmallVector<int64_t> PackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getDestType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
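To illustrate why the inverse of `outer_dims_perm` is applied here, a walkthrough (illustrative only) using the pack op from the negative test added further down in this patch:

```mlir
// outer_dims_perm = [0, 3, 2, 1], inner_dims_pos = [3],
// dest = tensor<1x126x1x1x8xf32>.
// getAllOuterDims() reads the dest shape: [1, 126, 1, 1] (permuted order).
// Applying the inverse of [0, 3, 2, 1] (which is its own inverse) recovers
// the source order: [1, 1, 1, 126], so getTiledOuterDims() = [126].
// Indexing the permuted dest shape directly with inner_dims_pos would have
// returned [1] and let the decomposition pattern match incorrectly.
%pack = linalg.pack %src
  padding_value(%pad : f32)
  outer_dims_perm = [0, 3, 2, 1]
  inner_dims_pos = [3]
  inner_tiles = [8]
  into %dest : tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32>
```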
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0dac688e1c26d..70894e81224e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,9 +1134,7 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
linalg::PackOp packOp, PatternRewriter &rewriter) const {
- // TODO: support the case that outer dimensions are not all 1s. A
- // tensor.expand_shape will be generated in this case.
- if (llvm::any_of(packOp.getAllOuterDims(),
+ if (llvm::any_of(packOp.getTiledOuterDims(),
[](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "not all outer dimensions of the result are 1s");
@@ -1149,7 +1147,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- int64_t numberOfTiles = innerDimsPos.size();
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
@@ -1160,10 +1157,14 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // - All outer dims are 1 - the corresponding transposition order doesn't
- // matter, but requires all dim indices to be present.
-
- // 2.1 Get the permutation for linalg.transpose
+  // - All tiled outer dims are 1 - the corresponding transposition order
+  //   doesn't matter, but all dim indices still need to be present.
+ // - Un-tiled outer dims remain un-permuted. (TODO: Fail when this does not
+ // hold)
+
+ // 2.1 Get the permutation for linalg.transpose:
+ // [ untiled-dims, inner-dims-pos ]
+ // Note, this logic assumes that the untiled dims are not permuted.
SmallVector<int64_t> srcPermForTranspose;
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1180,19 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
}
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
- // 2.2 Create the init tensor for linalg.transpose with the correct shape
- SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
- oneIdxAttr);
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape:
+ // [ untiled-dims, tiled-dims ]
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ SmallVector<OpFoldResult> shapeForEmptyOp;
+ for (int64_t i = 0; i < srcRank; i++) {
+ if (llvm::is_contained(innerDimsPos, i))
+ continue;
+ if (inputTy.isStaticDim(i))
+ shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
+ else
+ shapeForEmptyOp.emplace_back(
+ tensor::DimOp::create(rewriter, loc, input, i).getResult());
+ }
shapeForEmptyOp.append(packOp.getMixedTiles());
// getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1206,23 +1217,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
- SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
- SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- // Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
- SmallVector<int64_t> writeShape;
+
+ // Compute the sizes attribute:
+ // [ outer-dims, tile-sizes ]
+ // Note that the output from the transpose Op excludes the tiled outer dims.
+ // Given the assumptions (all tiled outer dims == 1), we can safely use a
+ // rank-expanding tensor.insert_slice. Rather than manually computing where to
+ // insert new unit dims (resulting from the expansion), use the Pack op
+ // attributes.
+ SmallVector<OpFoldResult> writeSizes;
+ for (auto size : packOp.getAllOuterDims()) {
+ writeSizes.push_back(rewriter.getIndexAttr(size));
+ }
for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
- writeShape.push_back(tileSizeStatic);
}
- // 4. Replace tensor.packOp with tensor.insert_slice created above
+ SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+ SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+
+  // TODO: A constructor that doesn't require strides or offsets.
auto insert = tensor::InsertSliceOp::create(
rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
writeOffsets, writeSizes, writeStrides);
+
+ // 4. Replace tensor.packOp with tensor.insert_slice created above
rewriter.replaceOp(packOp, insert.getResult());
return success();
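As a quick sanity check against the `@NCHW_to_NCHWc` test (illustrative, not an extra change): the sizes are assembled as all outer dims of the dest followed by the tile sizes, which yields exactly the rank-expanding insert the test expects:

```mlir
// writeSizes = [ all outer dims of %dest ] ++ [ inner tile sizes ]
//            = [2, 1, 16, 8] ++ [32]
%res = tensor.insert_slice %tr into %dest[0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
  : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
```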
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49a41d97..ac7200294a3a6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
sourceTensorType.getEncoding());
}
+// TODO: This uses neither offsets nor strides!
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 18a09f4c669bb..bff58d6377cc9 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi
// -----
+func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> {
+ %pack = linalg.pack %src
+ inner_dims_pos = [1]
+ inner_tiles = [32] into %dest
+ : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>
+ return %pack : tensor<2x1x16x8x32xf32>
+}
+// CHECK-LABEL: func.func @NCHW_to_NCHWc(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32>
+// CHECK: %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
+// CHECK-SAME: : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
+// CHECK: return %[[RES]] : tensor<2x1x16x8x32xf32>
+
+// -----
+
func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
%0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
return %0 : tensor<1x1x8x2xf32>
@@ -295,3 +314,20 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
// CHECK: return %[[INSERT]]
+
+// -----
+/// Note the "126", which is a non-unit tiled outer dim. This is not supported.
+
+func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
+ %pack = linalg.pack %src
+ padding_value(%pad : f32)
+ outer_dims_perm = [0, 3, 2, 1]
+ inner_dims_pos = [3]
+ inner_tiles = [8]
+ into %dest : tensor<1x1x1x1001xf32>
+ -> tensor<1x126x1x1x8xf32>
+
+ return %pack : tensor<1x126x1x1x8xf32>
+}
+// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
+// CHECK: linalg.pack