[Mlir-commits] [mlir] [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern (PR #141613)
Christopher McGirr
llvmlistbot at llvm.org
Wed Jun 4 04:23:13 PDT 2025
https://github.com/chrsmcgrr updated https://github.com/llvm/llvm-project/pull/141613
>From b06ceb8e6d8f578494a1af2cb0b85eb918fe978e Mon Sep 17 00:00:00 2001
From: Christopher McGirr <mcgirr at roofline.ai>
Date: Thu, 22 May 2025 15:55:55 +0000
Subject: [PATCH 1/2] [mlir][linalg] fix OuterUnitDims linalg.pack
decomposition pattern
Given the following example:
```
module {
func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
%pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
return %pack : tensor<1x1x1x4x1xf32>
}
}
```
We would generate an invalid transpose operation, because the calculated
permutation would be `[0, 2, 0]`. This is semantically incorrect: the
permutation must contain each source tensor dimension index exactly once.
The following change modifies how we calculate the permutation array and
ensures that the dimension indices given in the permutation array are
unique.
The above example would then translate to a transpose with a
permutation of `[1, 2, 0]`, following the rule that the `inner_dims_pos`
indices are appended to the permutation array and the preceding entries
are filled with the remaining source dimensions in increasing order.
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 23 ++++++++++++-------
mlir/test/Dialect/Linalg/decompose-pack.mlir | 19 +++++++++++++++
2 files changed, 34 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8718c57b9e86c..7b6c8243d1040 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1205,16 +1205,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
- // Two assumptions are made:
- // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
- // 2. Inner dims position correspond to the trailing `numTiles` dims.
- SmallVector<int64_t> tilesPermNormalized =
- getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+ // Assumptions made:
+ // 1. Inner dims position correspond to the trailing `numTiles` dims.
SmallVector<int64_t> srcPermForTranspose;
- for (int64_t i = 0; i < (srcRank - numTiles); i++)
+ ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+ for (int64_t i = 0; i < srcRank; i++) {
+ // As we assume the trailing dimensions of the inner dim position correspond
+ // to the trailing indices of the transpose permutation, we need to
+ // calculate the remaining indicies of the transpose permutation. This is
+ // done by adding the indices not contained in the inner dimension position.
+ // For example if we have a source tensor of dimensions [0, 1, 2, 3]
+ // and inner dim position of [3, 0], the remaining indices are [1, 2].
+ // and the transpose will be [1, 2, 3, 0].
+ if (llvm::is_contained(innerDimPos, i))
+ continue;
srcPermForTranspose.push_back(i);
-
- srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
+ }
+ srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
<< "perm: " << llvm::interleaved(srcPermForTranspose)
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 911b453f919c3..6d091406a639c 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+ return %pack : tensor<1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME: permutation = [1, 2, 0]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK: return %[[INSERT]]
\ No newline at end of file
>From 4f378a5b0b904c40236ea1840b914c068a3d00f0 Mon Sep 17 00:00:00 2001
From: Christopher McGirr <mcgirr at roofline.ai>
Date: Mon, 2 Jun 2025 12:58:05 +0000
Subject: [PATCH 2/2] Update(1) [mlir][linalg] fix OuterUnitDims linalg.pack
decomposition pattern
---
.../Dialect/Linalg/Transforms/Transforms.cpp | 26 +++++++++----------
mlir/test/Dialect/Linalg/decompose-pack.mlir | 3 +--
.../test/Dialect/Linalg/decompose-unpack.mlir | 17 ++++++++++++
3 files changed, 31 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7b6c8243d1040..69c17cba3f307 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1179,13 +1179,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
int64_t destRank = packOp.getDestRank();
int64_t numTiles = destRank - srcRank;
- if (!llvm::all_of(packOp.getInnerDimsPos(),
- [&srcRank, &numTiles](int64_t dimPos) {
- return dimPos >= (srcRank - numTiles - 1);
- }))
- return rewriter.notifyMatchFailure(
- packOp, "Attempting to tile non-trailing source dims!");
-
// 1. Extract the inner tile sizes.
// Where possible, values are replaced with constant attributes (to match the
// behaviour of `getPackOpSourceOrPaddedSource`).
@@ -1206,15 +1199,22 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // 1. Inner dims position correspond to the trailing `numTiles` dims.
+ // 1. All outer dims are 1 - the corresponding transposition order doesn't
+ // matter, but requires all dim indices to be present.
+ // 2. Inner dims position can have non-adjacent trailing dimensions. Where,
+ // For example, a source tensor with indices [0, 1, 2] can have:
+ // * adjacent trailing dimensions of [1, 2], [2, 1]
+ // * non-adjacent trailing dimensions of [0, 2] or [2, 0]
+ // Trailing dimensions are defined in the case above as index [2].
+ // And the indices [0] or [1] are not defined to be trailing.
SmallVector<int64_t> srcPermForTranspose;
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
- // As we assume the trailing dimensions of the inner dim position correspond
- // to the trailing indices of the transpose permutation, we need to
- // calculate the remaining indicies of the transpose permutation. This is
- // done by adding the indices not contained in the inner dimension position.
- // For example if we have a source tensor of dimensions [0, 1, 2, 3]
+ // We assume the `k` dimensions of the inner dim position correspond
+ // to the last `k` indices of the transpose permutation. This is
+ // done by adding the indices not contained in the inner dimension position
+ // in order from 0 to `n`. Where n is the rank of the source tensor.
+ // For example if we have a source tensor with indices [0, 1, 2, 3]
// and inner dim position of [3, 0], the remaining indices are [1, 2].
// and the transpose will be [1, 2, 3, 0].
if (llvm::is_contained(innerDimPos, i))
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 6d091406a639c..6239a82168f38 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -236,7 +236,6 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %a
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
return %pack : tensor<1x1x1x4x1xf32>
}
-
// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -247,4 +246,4 @@ func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %a
// CHECK-SAME: permutation = [1, 2, 0]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
-// CHECK: return %[[INSERT]]
\ No newline at end of file
+// CHECK: return %[[INSERT]]
diff --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
index d460c506d6e18..c6c99dca186d5 100644
--- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
@@ -169,3 +169,20 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
+ %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
+ return %0 : tensor<1x1x4xf32>
+}
+// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SLICE]] : tensor<4x1xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
+// CHECK: return %[[INSERT]]
More information about the Mlir-commits
mailing list