[Mlir-commits] [mlir] 96c1611 - [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern (#141613)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 27 00:24:37 PDT 2025


Author: Christopher McGirr
Date: 2025-06-27T09:24:33+02:00
New Revision: 96c1611163d3420c78e30b13c1f7211e3572e58b

URL: https://github.com/llvm/llvm-project/commit/96c1611163d3420c78e30b13c1f7211e3572e58b
DIFF: https://github.com/llvm/llvm-project/commit/96c1611163d3420c78e30b13c1f7211e3572e58b.diff

LOG: [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern (#141613)

Given the following example:
```
module {
  func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
    %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
    return %pack : tensor<1x1x1x4x1xf32>
  }
}
```

We would generate an invalid transpose operation because the calculated
permutation would be `[0, 2, 0]`, which is semantically incorrect: the
permutation must contain each source tensor dimension index exactly once.

The following change modifies how we calculate the permutation array and
ensures that the dimension indices in the permutation array are
unique.

The above example would then translate to a transpose with a
permutation of `[1, 2, 0]`. This follows the rule that the `inner_dims_pos`
entries are appended to the end of the permutation array, and the preceding
indices are filled with the remaining source dimensions in ascending order.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/decompose-pack.mlir
    mlir/test/Dialect/Linalg/decompose-unpack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 615d1f66414b9..a775699f99343 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1178,13 +1178,6 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   int64_t destRank = packOp.getDestRank();
   int64_t numTiles = destRank - srcRank;
 
-  if (!llvm::all_of(packOp.getInnerDimsPos(),
-                    [&srcRank, &numTiles](int64_t dimPos) {
-                      return dimPos >= (srcRank - numTiles - 1);
-                    }))
-    return rewriter.notifyMatchFailure(
-        packOp, "Attempting to tile non-trailing source dims!");
-
   // 1. Extract the inner tile sizes.
   // Where possible, values are replaced with constant attributes (to match the
   // behaviour of `getPackOpSourceOrPaddedSource`).
@@ -1204,16 +1197,24 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
-  // Two assumptions are made:
-  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
-  //  2. Inner dims position correspond to the trailing `numTiles` dims.
-  SmallVector<int64_t> tilesPermNormalized =
-      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+  // Assumptions made:
+  //  1. All outer dims are 1 - the corresponding transposition order doesn't
+  //     matter, but requires all dim indices to be present.
   SmallVector<int64_t> srcPermForTranspose;
-  for (int64_t i = 0; i < (srcRank - numTiles); i++)
+  ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+  for (int64_t i = 0; i < srcRank; i++) {
+    // We assume the `k` dimensions of the inner dim position, where `k` is the
+    // rank of the inner tiling, correspond to the last `k` indices of the
+    // transpose permutation. This is done by adding the indices not contained
+    // in the inner dimension position in order from 0 to `n`. Where n is the
+    // rank of the source tensor. For example if we have a source tensor with
+    // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
+    // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
+    if (llvm::is_contained(innerDimPos, i))
+      continue;
     srcPermForTranspose.push_back(i);
-
-  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
+  }
+  srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
                     << "perm: " << llvm::interleaved(srcPermForTranspose)

diff  --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 911b453f919c3..17e6c29754f9d 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -229,3 +229,48 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+// The following example shows a pack operation that is defined with inner
+// dimension positions that are not adjacent, i.e. `[2, 0]`. And the outer
+// dimensions of the packed tensor are of unit values, i.e. `1x1x1`.
+func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+  return %pack : tensor<1x1x1x4x1xf32>
+}
+// CHECK-LABEL: func.func @pack_with_non_adjacent_inner_dims_pos_and_unit_outer
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK:         return %[[INSERT]]
+
+// -----
+
+// The following example shows a pack operation where the inner dimension
+// positions are specified as [2, 1] which are termed adjacent trailing
+// dimensions as they contain the last dimension of the source tensor with a
+// neighboring dimension. [1, 2] would also be considered trailing adjacent.
+// And the outer dimensions of the packed tensor are all set to unit values
+// of `1x1x1`.
+func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+  return %pack : tensor<1x1x1x4x1xf32>
+}
+// CHECK-LABEL: func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME:      permutation = [0, 2, 1]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK:         return %[[INSERT]]

diff  --git a/mlir/test/Dialect/Linalg/decompose-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
index d460c506d6e18..e173d557c770d 100644
--- a/mlir/test/Dialect/Linalg/decompose-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-unpack.mlir
@@ -169,3 +169,37 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
+  %0 = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
+  return %0 : tensor<1x1x4xf32>
+}
+// CHECK-LABEL: func.func @unpack_with_non_adjacent_inner_dims_pos_and_unit_outer
+// CHECK-SAME:     %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
+// CHECK:        %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
+// CHECK:        %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:                      ins(%[[SLICE]] : tensor<4x1xf32>)
+// CHECK-SAME:                      outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
+// CHECK:        %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
+// CHECK:        return %[[INSERT]]
+
+// -----
+
+func.func @unpack_with_non_trailing_dimensions_in_inner_dims(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
+  %pack = linalg.unpack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x1x4x1xf32> -> tensor<1x1x4xf32>
+  return %pack : tensor<1x1x4xf32>
+}
+// CHECK-LABEL: func.func @unpack_with_non_trailing_dimensions_in_inner_dims
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[SLICE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x1x1x4x1xf32> to tensor<4x1xf32>
+// CHECK:        %[[EMPTY:.+]] = tensor.empty() : tensor<1x4xf32>
+// CHECK:        %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:                      ins(%[[SLICE]] : tensor<4x1xf32>)
+// CHECK-SAME:                      outs(%[[EMPTY]] : tensor<1x4xf32>) permutation = [1, 0]
+// CHECK:        %[[INSERT:.+]] = tensor.insert_slice %transposed into %[[DEST]][0, 0, 0] [1, 1, 4] [1, 1, 1] : tensor<1x4xf32> into tensor<1x1x4xf32>
+// CHECK:        return %[[INSERT]]


        


More information about the Mlir-commits mailing list