[Mlir-commits] [mlir] e90f8d8 - [mlir][linalg] Extend DecomposeOuterUnitDimsPackOpPattern (linalg.pack) (#162666)

llvmlistbot at llvm.org
Mon Oct 13 01:48:08 PDT 2025


Author: Andrzej WarzyƄski
Date: 2025-10-13T09:48:03+01:00
New Revision: e90f8d84b0063b47455430fd8b57c1b1dafe7735

URL: https://github.com/llvm/llvm-project/commit/e90f8d84b0063b47455430fd8b57c1b1dafe7735
DIFF: https://github.com/llvm/llvm-project/commit/e90f8d84b0063b47455430fd8b57c1b1dafe7735.diff

LOG: [mlir][linalg] Extend DecomposeOuterUnitDimsPackOpPattern (linalg.pack) (#162666)

Similarly to #152960, this PR fixes `getTiledOuterDims` for `linalg.pack` by
ensuring that the `outer_dims_perm` attribute is properly taken into account.
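
For illustration, consider the pack from the new
`@negative_non_unit_tiled_outer_dim` test added at the bottom of this patch:
```mlir
%pack = linalg.pack %src
  padding_value(%pad : f32)
  outer_dims_perm = [0, 3, 2, 1]
  inner_dims_pos = [3]
  inner_tiles = [8]
  into %dest
  : tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32>
```
The outer dims of the result appear in permuted order as `[1, 126, 1, 1]`.
Undoing `outer_dims_perm` recovers the original order `[1, 1, 1, 126]`, so for
`inner_dims_pos = [3]` the fixed `getTiledOuterDims` returns `[126]` rather
than the `[1]` previously obtained by indexing the permuted shape directly.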

This enables the main change in this PR: relaxing the constraints in
  * `DecomposeOuterUnitDimsPackOpPattern`.

Specifically, the pattern is extended to allow non-unit untiled outer
dimensions. For example:
```mlir
func.func @example(
    %src: tensor<2x32x16x8xf32>,
    %dest: tensor<2x1x16x8x32xf32>) ->  tensor<2x1x16x8x32xf32> {

  %pack = linalg.pack %src
    inner_dims_pos = [1]
    inner_tiles = [32] into %dest
    : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>

  return %pack : tensor<2x1x16x8x32xf32>
}
```
decomposes as:
```mlir
func.func @example(
      %src: tensor<2x32x16x8xf32>,
      %dest: tensor<2x1x16x8x32xf32>) -> tensor<2x1x16x8x32xf32> {

  %init = tensor.empty() : tensor<2x16x8x32xf32>
  %transposed = linalg.transpose
    ins(%src : tensor<2x32x16x8xf32>)
    outs(%init : tensor<2x16x8x32xf32>)
    permutation = [0, 2, 3, 1]
  %inserted_slice = tensor.insert_slice %transposed
    into %dest[0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
    : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>

  return %inserted_slice : tensor<2x1x16x8x32xf32>
}
```
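
For completeness: when the source requires padding, the rewrite also emits a
leading `tensor.pad`, e.g. for the existing
`@simple_pad_and_pack_static_tiles` test:
```mlir
%0 = linalg.pack %input padding_value(%pad : f32)
  inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output
  : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
```
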
Importantly, this change makes `DecomposeOuterUnitDimsPackOpPattern` (the
decomposition pattern for `linalg.pack`) consistent with the corresponding
pattern for `linalg.unpack`:
  * `DecomposeOuterUnitDimsUnPackOpPattern`.
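
For reference, a sketch (not part of this patch) of the unpack mirror of the
example above, which that pattern handles via a rank-reducing
`tensor.extract_slice`, `linalg.transpose` and `tensor.insert_slice`:
```mlir
// Hypothetical mirror of the NCHWc example; the tiled outer dim is still 1.
%unpack = linalg.unpack %src
  inner_dims_pos = [1]
  inner_tiles = [32] into %dest
  : tensor<2x1x16x8x32xf32> -> tensor<2x32x16x8xf32>
```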

One notable assumption remains: untiled outer dimensions are not permuted. This
was already the case but is now explicitly documented.
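
For example, the following pack (from the new
`@negative_not_all_dims_tiled_outer_dim_0_permuted` test) swaps a non-unit
un-tiled outer dim, so the pattern leaves it untouched:
```mlir
%0 = linalg.pack %input padding_value(%pad : f32)
  outer_dims_perm = [1, 0, 2, 3]
  inner_dims_pos = [3, 2]
  inner_tiles = [2, %high]
  into %output
  : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32>
```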

Co-authored-by: Max Bartel <bartel at roofline.ai>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/decompose-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7266687584b38..ae7a085a1f7a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1650,8 +1650,12 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 /// Rewrites a linalg::PackOp into a sequence of:
 ///   * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
 ///     tensor::InsertSliceOp ops.
+/// (InsertSliceOp is rank-expanding).
 ///
-/// Requires that all the outer dims of the input linalg::PackOp are 1.
+/// Requires that all the tiled-outer-dims of the input linalg::PackOp are 1.
+/// Note that this constraint means that effectively exactly one tile is packed.
+///
+/// In addition, assumes that the un-tiled-outer-dims are not permuted.
 ///
 /// Before:
 /// ```
@@ -1687,10 +1691,13 @@ struct DecomposeOuterUnitDimsPackOpPattern
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Rewrites a linalg::UnPackOp into a sequence of rank-reduced
+/// Rewrites a linalg::UnPackOp into a sequence of:
 ///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
+/// (ExtractSliceOp is rank-reducing).
 ///
-/// Requires that all the tiled outer dims of the input linalg::PackOp are 1.
+/// Requires that all the tiled-outer-dims of the input linalg::UnPackOp are 1.
+/// Note that this constraint means that effectively exactly one tile is
+/// unpacked.
 ///
 /// Before:
 /// ```

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 59013a23b3e3b..cbc565b0c8cbd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5272,11 +5272,18 @@ ArrayRef<int64_t> PackOp::getAllOuterDims() {
 
 SmallVector<int64_t> PackOp::getTiledOuterDims() {
   auto innerDimsPos = getInnerDimsPos();
-  auto packedShape = getDestType().getShape();
+  SmallVector<int64_t> outerDims(getAllOuterDims());
   SmallVector<int64_t> res;
 
+  // Recover the original order of the outer dims.
+  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+  invertPermutationVector(outerDimPermInv);
+  if (!outerDimPermInv.empty())
+    applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
   for (auto index : innerDimsPos)
-    res.push_back(packedShape[index]);
+    res.push_back(outerDims[index]);
 
   return res;
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0dac688e1c26d..eb2d825e17e44 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     linalg::PackOp packOp, PatternRewriter &rewriter) const {
-  // TODO: support the case that outer dimensions are not all 1s. A
-  // tensor.expand_shape will be generated in this case.
-  if (llvm::any_of(packOp.getAllOuterDims(),
+  if (llvm::any_of(packOp.getTiledOuterDims(),
                    [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         packOp, "not all outer dimensions of the result are 1s");
   }
 
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+
+  // Verify that there are no
+  //   * non-unit, un-tiled outer dims
+  // that are permuted. Supporting such cases would require refining the logic
+  // that generates the Transpose Op.
+  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
+        static int prev = 0;
+        // Skip tiled dims - these can be permuted.
+        if (llvm::is_contained(innerDimsPos, dim))
+          return true;
+
+        // Check whether this dim has been permuted. Permuting unit dims is fine
+        // as that's effectively a no-op.
+        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
+                           packOp.getType().getShape()[dim] != 1))
+          return false;
+
+        prev = dim;
+        return true;
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
+                "this is not supported ATM!");
+  }
+
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
 
   int64_t srcRank = packOp.getSourceRank();
   int64_t destRank = packOp.getDestRank();
-  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-  int64_t numberOfTiles = innerDimsPos.size();
 
   // 1. Get the input that is going to be packed. If the input requires padding,
   // add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
   // Assumptions made:
-  //  - All outer dims are 1 - the corresponding transposition order doesn't
-  //     matter, but requires all dim indices to be present.
+  //  - All tiled outer dims are 1 - the corresponding transposition order
+  //    doesn't matter, but requires all dim indices to be present.
+  //  - Un-tiled outer dims remain un-permuted.
 
-  // 2.1 Get the permutation for linalg.transpose
+  // 2.1 Get the permutation for linalg.transpose:
+  //   [ untiled-dims, inner-dims-pos ]
+  // Note, this logic assumes that the untiled dims are not permuted.
   SmallVector<int64_t> srcPermForTranspose;
   for (int64_t i = 0; i < srcRank; i++) {
     // We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   }
   srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
 
-  // 2.2 Create the init tensor for linalg.transpose with the correct shape
-  SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
-                                            oneIdxAttr);
+  // 2.2 Create the init tensor for linalg.transpose with the correct shape:
+  //    [ untiled-dims, tiled-dims ]
+  ShapedType inputTy = cast<ShapedType>(input.getType());
+  SmallVector<OpFoldResult> shapeForEmptyOp;
+  for (int64_t i = 0; i < srcRank; i++) {
+    if (llvm::is_contained(innerDimsPos, i)) {
+      // The tiled dims are appended after this loop.
+      continue;
+    }
+    if (inputTy.isStaticDim(i))
+      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
+    else
+      shapeForEmptyOp.emplace_back(
+          tensor::DimOp::create(rewriter, loc, input, i).getResult());
+  }
   shapeForEmptyOp.append(packOp.getMixedTiles());
 
   // getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                   srcPermForTranspose);
 
-  // 3. Insert the inner tile to the destination:
+  // 3. Insert the inner tile into the destination tensor:
   //  %inserted_tile = tensor.insert_slice(%transposed_tile)
-  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
-  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  // Outer dims are all 1s!
-  SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
-  SmallVector<int64_t> writeShape;
+
+  // Compute the sizes attribute:
+  //    [ outer-dims, tile-sizes ]
+  // Note that the output from the transpose Op excludes the tiled outer dims.
+  // However, given the assumption that:
+  //  * all tiled outer dims == 1,
+  // we can just use a rank-expanding tensor.insert_slice.
+  SmallVector<OpFoldResult> writeSizes;
+  for (auto size : packOp.getAllOuterDims()) {
+    writeSizes.push_back(rewriter.getIndexAttr(size));
+  }
 
   for (auto tileSize : packOp.getMixedTiles()) {
-    auto [tileSizeStatic, tileSizeOfr] =
+    auto [_, tileSizeOfr] =
         getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
     writeSizes.push_back(tileSizeOfr);
-    writeShape.push_back(tileSizeStatic);
   }
 
-  // 4. Replace tensor.packOp with tensor.insert_slice created above
+  // TODO: Add a constructor for tensor.insert_slice that doesn't require
+  // strides nor offsets.
+  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
+  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
+
   auto insert = tensor::InsertSliceOp::create(
       rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
       writeOffsets, writeSizes, writeStrides);
+
+  // 4. Replace tensor.packOp with tensor.insert_slice created above
   rewriter.replaceOp(packOp, insert.getResult());
 
   return success();

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49a41d97..ac7200294a3a6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2310,6 +2310,7 @@ RankedTensorType ExtractSliceOp::inferResultType(
                                sourceTensorType.getEncoding());
 }
 
+// TODO: This uses neither offsets nor strides!
 RankedTensorType ExtractSliceOp::inferResultType(
     RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {

diff  --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 18a09f4c669bb..12292ee573cee 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -31,6 +31,25 @@ func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi
 
 // -----
 
+func.func @NCHW_to_NCHWc(%src: tensor<2x32x16x8xf32>, %dest: tensor<2x1x16x8x32xf32>) ->  tensor<2x1x16x8x32xf32> {
+  %pack = linalg.pack %src
+    inner_dims_pos = [1]
+    inner_tiles = [32] into %dest
+    : tensor<2x32x16x8xf32> -> tensor<2x1x16x8x32xf32>
+  return %pack : tensor<2x1x16x8x32xf32>
+}
+// CHECK-LABEL:   func.func @NCHW_to_NCHWc(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:           %[[INIT:.*]] = tensor.empty() : tensor<2x16x8x32xf32>
+// CHECK:           %[[TR:.*]] = linalg.transpose ins(%[[SRC]] : tensor<2x32x16x8xf32>) outs(%[[INIT]] : tensor<2x16x8x32xf32>) permutation = [0, 2, 3, 1]
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]]
+// CHECK-SAME:        [0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
+// CHECK-SAME:        : tensor<2x16x8x32xf32> into tensor<2x1x16x8x32xf32>
+// CHECK:           return %[[RES]] : tensor<2x1x16x8x32xf32>
+
+// -----
+
 func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2xf32>, %pad: f32) -> tensor<1x1x8x2xf32> {
   %0 = linalg.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
@@ -157,6 +176,8 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t
 
 // -----
 
+// Note - un-tiled outer dims are permuted. However, these are unit dims, which is supported.
+
 func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
   %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
   return %0 : tensor<1x1x1x1x2x?xf32>
@@ -182,6 +203,28 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x
 
 // -----
 
+// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (7, 1) -> (1, 7).
+
+func.func @negative_not_all_dims_tiled_outer_dim_0_permuted(%input: tensor<7x1x5x1xf32>, %output: tensor<1x7x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x7x1x1x2x?xf32> {
+  %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<7x1x5x1xf32> -> tensor<1x7x1x1x2x?xf32>
+  return %0 : tensor<1x7x1x1x2x?xf32>
+}
+// CHECK-LABEL:   func.func @negative_not_all_dims_tiled_outer_dim_0_permuted
+// CHECK: linalg.pack
+
+// -----
+
+// Similar to the example above, but one of the permuted un-tiled outer dims is non-unit: (1, 7) -> (7, 1).
+
+func.func @negative_not_all_dims_tiled_outer_dim_1_permuted(%input: tensor<1x7x5x1xf32>, %output: tensor<7x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<7x1x1x1x2x?xf32> {
+  %0 = linalg.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x7x5x1xf32> -> tensor<7x1x1x1x2x?xf32>
+  return %0 : tensor<7x1x1x1x2x?xf32>
+}
+// CHECK-LABEL:   func.func @negative_not_all_dims_tiled_outer_dim_1_permuted
+// CHECK: linalg.pack
+
+// -----
+
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
   %0 = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
   return %0 : tensor<1x1x32x8xf32>
@@ -295,3 +338,21 @@ func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+/// Note "126", which is a non-unit tiled-outer-dim. This is not supported.
+
+func.func @negative_non_unit_tiled_outer_dim(%dest: tensor<1x126x1x1x8xf32>, %src: tensor<1x1x1x1001xf32>, %pad: f32) -> tensor<1x126x1x1x8xf32> {
+  %pack = linalg.pack %src
+    padding_value(%pad : f32)
+    outer_dims_perm = [0, 3, 2, 1]
+    inner_dims_pos = [3]
+    inner_tiles = [8]
+    into %dest
+    : tensor<1x1x1x1001xf32> -> tensor<1x126x1x1x8xf32>
+
+  return %pack : tensor<1x126x1x1x8xf32>
+}
+// CHECK-LABEL: @negative_non_unit_tiled_outer_dim(
+// CHECK: linalg.pack


        

