[Mlir-commits] [mlir] [mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern (PR #141613)

Christopher McGirr llvmlistbot at llvm.org
Thu Jun 12 08:01:31 PDT 2025


================
@@ -1205,16 +1198,30 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
-  // Two assumptions are made:
-  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
-  //  2. Inner dims position correspond to the trailing `numTiles` dims.
-  SmallVector<int64_t> tilesPermNormalized =
-      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+  // Assumptions made:
+  //  1. All outer dims are 1, so the order of the corresponding transposition
+  //     doesn't matter, but all dim indices must still be present.
+  //  2. Inner dim positions may include non-adjacent trailing dimensions.
+  //     For example, a source tensor with indices [0, 1, 2] can have:
+  //       * adjacent trailing dimensions of [1, 2] or [2, 1]
+  //       * non-adjacent trailing dimensions of [0, 2] or [2, 0]
+  //     In the case above, the trailing dimension is index [2]; indices
+  //     [0] and [1] are not trailing.
   SmallVector<int64_t> srcPermForTranspose;
-  for (int64_t i = 0; i < (srcRank - numTiles); i++)
+  ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+  for (int64_t i = 0; i < srcRank; i++) {
+    // We assume the `k` inner dim positions correspond to the last `k`
+    // indices of the transpose permutation. This is done by first adding,
+    // in order from 0 to `n - 1`, the indices not contained in the inner dim
+    // positions, where `n` is the rank of the source tensor.
+    //   For example, for a source tensor with indices [0, 1, 2, 3] and
+    //   inner dim positions [3, 0], the remaining indices are [1, 2], and
+    //   the transpose will be [1, 2, 3, 0].
----------------
chrsmcgrr wrote:

Done
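The permutation that the new comments describe can be reproduced outside MLIR. What follows is a minimal standalone sketch that rests on two assumptions. First, the truncated loop in the hunk skips every index that already appears in `innerDimPos`. Second, `innerDimPos` is then appended in its given order, as the comment states. The sketch uses `std::vector` in place of `SmallVector`/`ArrayRef`. `computeSrcPermForTranspose` is a hypothetical name, not an MLIR function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone helper (not part of the MLIR API). It builds the
// source permutation for the transpose by listing the source indices that
// are NOT in innerDimPos in ascending order, then appending innerDimPos in
// its given order so the tiled dims land in the last k positions.
std::vector<int64_t>
computeSrcPermForTranspose(int64_t srcRank,
                           const std::vector<int64_t> &innerDimPos) {
  std::vector<int64_t> perm;
  perm.reserve(static_cast<std::size_t>(srcRank));
  // Untiled dims first, in order from 0 to srcRank - 1.
  for (int64_t i = 0; i < srcRank; i++) {
    if (std::find(innerDimPos.begin(), innerDimPos.end(), i) ==
        innerDimPos.end())
      perm.push_back(i);
  }
  // Tiled dims last, in the order given by innerDimPos (they need not be
  // adjacent in the source, e.g. [0, 2]).
  for (int64_t d : innerDimPos) {
    assert(d >= 0 && d < srcRank && "inner dim position out of range");
    perm.push_back(d);
  }
  return perm;
}
```

For example, `computeSrcPermForTranspose(4, {3, 0})` yields `[1, 2, 3, 0]`, which matches the example in the comment. The non-adjacent case of rank 3 with `innerDimPos = [0, 2]` yields `[1, 0, 2]`.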

https://github.com/llvm/llvm-project/pull/141613

