[Mlir-commits] [mlir] [mlir][tensor][linalg] Move Pack/Unpack Ops to Linalg (PR #123902)

llvmlistbot at llvm.org
Wed Jan 22 08:05:41 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This patch merely moves code around; no new functionality is added.

PATCH 1: Copies `tensor.pack` and `tensor.unpack` as `linalg.pack` and
`linalg.unpack`, respectively. New Ops are defined in
LinalgRelayoutOps.td.

Note that `tensor.pack` and `tensor.unpack` are still present at this point.

CONTEXT:
This change was discussed in the following RFC:
* https://discourse.llvm.org/t/rfc-move-tensor-pack-and-tensor-unpack-into-linalg


---

Patch is 714.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123902.diff


72 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt (+7) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+3) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+8) 
- (added) mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td (+332) 
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+53-33) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h (+5) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+36-23) 
- (modified) mlir/include/mlir/Dialect/Linalg/Utils/Utils.h (+18) 
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (-308) 
- (modified) mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td (-10) 
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (-9) 
- (modified) mlir/include/mlir/Dialect/Tensor/Utils/Utils.h (-19) 
- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+7) 
- (modified) mlir/lib/Dialect/Linalg/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp (+14-1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1111-17) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+18-8) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+60-60) 
- (renamed) mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp (+58-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+655) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+20-20) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-16) 
- (modified) mlir/lib/Dialect/Linalg/Utils/Utils.cpp (+54) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+5-1033) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (-652) 
- (modified) mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp (-5) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (-1) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (+2-46) 
- (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (-55) 
- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+10) 
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir (+18-18) 
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir (+10-10) 
- (modified) mlir/test/Dialect/Linalg/block-pack-matmul.mlir (+45-45) 
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+499-3) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+127-127) 
- (modified) mlir/test/Dialect/Linalg/decompose-tensor-pack-tile.mlir (+6-6) 
- (modified) mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir (+11-11) 
- (modified) mlir/test/Dialect/Linalg/decompose-tensor-unpack-tile.mlir (+6-6) 
- (modified) mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir (+9-9) 
- (added) mlir/test/Dialect/Linalg/fold-empty-op.mlir (+82) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+183) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+105) 
- (renamed) mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir (+46-46) 
- (modified) mlir/test/Dialect/Linalg/td/decompose-pack.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/td/decompose-unpack.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+86-86) 
- (modified) mlir/test/Dialect/Linalg/transform-op-fuse.mlir (+6-6) 
- (modified) mlir/test/Dialect/Linalg/transform-op-pack.mlir (+62-62) 
- (added) mlir/test/Dialect/Linalg/transform-op-tile-pack-unpack.mlir (+491) 
- (modified) mlir/test/Dialect/Linalg/transform-pack-greedily.mlir (+6-6) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir (+16-16) 
- (modified) mlir/test/Dialect/Linalg/vectorization-unsupported.mlir (+2-2) 
- (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+4-4) 
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+24-24) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (-474) 
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (-71) 
- (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+99-99) 
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (-175) 
- (modified) mlir/test/Dialect/Tensor/ops.mlir (-103) 
- (modified) mlir/test/Dialect/Tensor/tiling.mlir (-492) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/pack-scalable-inner-tile.mlir (+4-4) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/pack-dynamic-inner-tile.mlir (+4-4) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir (+15-15) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/unpack-dynamic-inner-tile.mlir (+4-4) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+8-8) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+2-2) 
- (modified) mlir/test/Transforms/loop-invariant-code-motion.mlir (+10-10) 
- (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+27-1) 
- (modified) mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp (-26) 

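The `getAllOuterDims` and `getTiledOuterDims` helpers declared in `commonExtraClassDeclaration` in the diff below differ only in which outer dimensions they return and in what order. The following is a hypothetical sketch of that documented behaviour on a static packed shape (the result of `linalg.pack` or the source of `linalg.unpack`); `allOuterDims` and `tiledOuterDims` are illustrative names, not the MLIR implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical: all leading dims, excluding the `numTiled` trailing tile dims,
// in the order they appear in the packed shape.
std::vector<int64_t> allOuterDims(const std::vector<int64_t> &packedShape,
                                  size_t numTiled) {
  assert(numTiled <= packedShape.size());
  return std::vector<int64_t>(
      packedShape.begin(),
      packedShape.end() - static_cast<std::ptrdiff_t>(numTiled));
}

// Hypothetical: only the outer dims that were tiled, ordered by inner_dims_pos.
std::vector<int64_t> tiledOuterDims(const std::vector<int64_t> &packedShape,
                                    const std::vector<int64_t> &innerDimsPos,
                                    const std::vector<int64_t> &outerDimsPerm = {}) {
  std::vector<int64_t> outer = allOuterDims(packedShape, innerDimsPos.size());
  std::vector<int64_t> result;
  for (int64_t pos : innerDimsPos) {
    // Without outer_dims_perm, source dim `pos` sits at outer position `pos`.
    size_t idx = static_cast<size_t>(pos);
    if (!outerDimsPerm.empty()) {
      // Packed outer dim i holds source dim outerDimsPerm[i].
      auto it = std::find(outerDimsPerm.begin(), outerDimsPerm.end(), pos);
      assert(it != outerDimsPerm.end());
      idx = static_cast<size_t>(it - outerDimsPerm.begin());
    }
    result.push_back(outer[idx]);
  }
  return result;
}
```

For the `CK to KCck` example (`8x16x8x32` with `outer_dims_perm = [1, 0]`), `allOuterDims` returns `{8, 16}` in packed order, while `tiledOuterDims` returns `{16, 8}`, ordered by `inner_dims_pos = [0, 1]`.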

``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index 71214b4404c550..efd708c5e5a113 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -65,6 +65,13 @@ add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen)
 add_dependencies(MLIRLinalgStructuredOpsIncGen LinalgOdsGen)
 add_dependencies(mlir-headers MLIRLinalgStructuredOpsIncGen)
 
+set(LLVM_TARGET_DEFINITIONS LinalgRelayoutOps.td)
+mlir_tablegen(LinalgRelayoutOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgRelayoutOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgRelayoutOpsIncGen)
+add_dependencies(MLIRLinalgRelayoutOpsIncGen LinalgOdsGen)
+add_dependencies(mlir-headers MLIRLinalgRelayoutOpsIncGen)
+
 set(LLVM_TARGET_DEFINITIONS LinalgInterfaces.td)
 mlir_tablegen(LinalgInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 85f5ebeb8081ee..57bf6305a469d0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -123,4 +123,7 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
+
 #endif // MLIR_DIALECT_LINALG_IR_LINALG_H
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 244db23925ab3c..5986626a727297 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -178,6 +178,14 @@ def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
   ];
 }
 
+// TODO:
+def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
+  let description = [{
+    TODO
+  }];
+  let cppNamespace = "::mlir::linalg";
+}
+
 def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
   let description = [{
     A fill operation is defined in general terms:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
new file mode 100644
index 00000000000000..fe0e826f6b7717
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -0,0 +1,332 @@
+//===- LinalgRelayoutOps.td - Linalg dialect library ops -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for the Linalg relayout operations
+// (`linalg.pack` and `linalg.unpack`), which change the layout of tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LINALG_RELEAYOUT_OPS
+#define LINALG_RELEAYOUT_OPS
+
+include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+
+//===----------------------------------------------------------------------===//
+// RelayoutOp
+//===----------------------------------------------------------------------===//
+
+class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
+      Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
+        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+        DestinationStyleOpInterface, LinalgRelayoutOpInterface,
+        ConditionallySpeculatable, NoMemoryEffect,
+        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+        TypesMatchWith<"result type matches type of dest",
+                   "dest", "result",
+                   "$_self">])> {
+
+  code commonExtraClassDeclaration = [{
+    size_t getSourceRank() { return getSourceType().getRank(); };
+    size_t getDestRank() { return getDestType().getRank(); };
+    RankedTensorType getSourceType() {
+      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
+    RankedTensorType getDestType() {
+      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+
+    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
+
+    /// Return a mapping from positions `inner_dims_pos` to their
+    /// tile factors.
+    DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
+
+    /// Return the tile sizes as OpFoldResult.
+    SmallVector<OpFoldResult> getMixedTiles();
+
+    /// Return the tile sizes as `int64_t`. If a tile size is dynamic,
+    /// a sentinel `kDynamic` is introduced at that position in
+    /// the returned vector.
+    SmallVector<int64_t> getStaticTiles();
+
+    /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+    /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+    /// that this will include both tiled and non-tiled dimensions. The order
+    /// of the output dimensions is consistent with the shape of the packed
+    /// tensor.
+    ArrayRef<int64_t> getAllOuterDims();
+
+    /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
+    /// have been tiled. Also, the order of the output dimensions is consistent
+    /// with `inner_dims_pos` rather than the packed tensor.
+    SmallVector<int64_t> getTiledOuterDims();
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
+def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
+    AttrSizedOperandSegments]> {
+  let summary = "linalg.pack operation";
+  let description = [{
+    The "pack" operation converts a source tensor of rank `n` into a result
+    tensor of rank `n + k` with a tiled and packed layout (optionally with padding)
+    and optionally transposes the tiled source tensor dimensions.
+
+    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
+    being tiled, where `0 < k <= n`. The order of the dimensions matters:
+     - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
+    tensor in the order in which they appear in `inner_dims_pos`.
+     - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
+    `inner_tiles[i]`.
+
+    `inner_tiles` (mandatory) specifies `k` tile sizes. These tile sizes
+    correspond to the least significant ("inner") result tensor dimension sizes,
+    in the same order. Tile sizes can be static or dynamic.
+
+    Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
+    `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
+    by 16 and the 1st source dimension is tiled by 32. Other source dimensions
+    (if any) are not tiled. If `inner_dims_pos = [1, 0]`, the 1st dimension is
+    tiled by 16 and the 0th dimension is tiled by 32.
+
+    Example:
+    ```mlir
+    // NC to NCnc
+    %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+        into %dest : tensor<128x256xf32> -> tensor<16x8 x 8x32 xf32>
+    //                                             \  /   \  /
+    //                                       outer dims  inner dims
+    ```
+
+    `outer_dims_perm` (optional) specifies a permutation for the outer
+    dimensions. If specified, it must have `n` elements.
+
+    Example:
+    ```mlir
+    // CK to KCck
+    %0 = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+        inner_tiles = [8, 32] into %dest
+        : tensor<128x256xf32> -> tensor<8x16 x 8x32 xf32>
+    //                                  \  /
+    //            compare with "NC to NCnc": outer dims are transposed
+    ```
+
+    `padding_value` specifies a padding value at the boundary on non-perfectly
+    divisible dimensions. Padding is optional:
+    - If absent, it is UB if the tile does not perfectly divide the dimension.
+    - If present, it will pad along high dimensions (high-padding) to make the
+      tile complete.
+
+    Example:
+    ```mlir
+    %0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
+        inner_dims_pos = [1] inner_tiles = [2] into %arg1
+        : tensor<200x127x256xf32> -> tensor<256x64x200x2xf32>
+    //                 \
+    //                padded and tiled dim
+    //
+    // Source dimension 1 is tiled. 2 does not divide 127 evenly, so 1 padded
+    // element is added at the end (127 + 1 = 128 = 64 * 2).
+    //
+    // Note: Only tiled dimensions can be padded.
+    ```
+  }];
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$dest,
+                       Optional<AnyType>:$padding_value,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+                       DenseI64ArrayAttr:$inner_dims_pos,
+                       Variadic<Index>:$inner_tiles,
+                       DenseI64ArrayAttr:$static_inner_tiles);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $source
+    (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
+    (`outer_dims_perm` `=` $outer_dims_perm^)?
+    `inner_dims_pos` `=` $inner_dims_pos
+    `inner_tiles` `=`
+    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
+    `into` $dest attr-dict `:` type($source) `->` type($dest)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<int64_t>":$innerDimsPos,
+      "ArrayRef<OpFoldResult>":$innerTiles,
+      CArg<"std::optional<Value>", "std::nullopt">:$paddingValue,
+      CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+  ];
+
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
+    // This is a static method to allow getting the shape of the destination
+    // expected while creating a `pack` op.
+    static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
+        Location loc, ArrayRef<OpFoldResult> sourceDims,
+        ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Method to get the `RankedTensorType` of the result based on the inner
+    // tiles, position of the inner tiles (innerDimsPos) and interchange vector
+    // of outer loops (outerDimsPerm).
+    static RankedTensorType inferPackedType(RankedTensorType sourceType,
+        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
+        ArrayRef<int64_t> outerDimsPerm = {});
+
+    // Returns true if we have enough static information to catch undefined
+    // behavior when the tile size does not perfectly divide the dimension of
+    // the input tensor. Detecting UB requires that the input size and either
+    // corresponding tile or output size are static.
+    static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
+                                    ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<int64_t> outputShape,
+                                    ArrayRef<int64_t> outerDimsPerm,
+                                    ArrayRef<OpFoldResult> innerTiles);
+
+    static Value createDestinationTensor(OpBuilder &b, Location loc,
+        Value source, ArrayRef<OpFoldResult> innerTileSizes,
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
+    /// Build and return a new PackOp that is a clone of the current PackOp in
+    /// which (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// A new `tensor.empty` of the proper shape is built in the process.
+    /// Asserts that:
+    ///   - At least one of innerPermutation or outerPermutation is non-empty.
+    ///   - If not empty, innerPermutation is a valid permutation of size
+    ///     matching innerDimsPos.
+    ///   - If not empty, outerPermutation is a valid permutation of size
+    ///     matching outerDimsPerm.
+    PackOp createTransposedClone(OpBuilder &b,
+                                 Location loc,
+                                 ArrayRef<int64_t> innerPermutation,
+                                 ArrayRef<int64_t> outerPermutation);
+
+    /// Check if this PackOp is like a simple pad operation.
+    /// In other words, this operation:
+    /// 1. adds useless dimensions (dimension of size 1),
+    /// 2. pads the other ones, and
+    /// 3. doesn't shuffle the dimensions
+    bool isLikePad();
+  }];
+
+  let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// UnPackOp
+//===----------------------------------------------------------------------===//
+
+def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
+  let summary = "linalg.unpack operation";
+  let description = [{
+    The "unpack" operation converts a source tensor of rank `n` with a tiled and
+    packed layout to a result tensor of rank `n - k`.
+
+    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions with
+    which the last `k` source tensor dimensions are combined, where
+    `0 < k <= n/2`. Each `inner_dims_pos` element must be `>= 0` and `< n - k`.
+    The order of the dimensions in `inner_dims_pos` matters: dimension
+    `inner_dims_pos[i]` is combined with dimension `n - k + i` (assuming that
+    `outer_dims_perm` is not specified).
+
+    `inner_tiles` (mandatory) specifies `k` tile sizes. These tile sizes
+    correspond to the least significant ("inner") source tensor dimension sizes.
+    The behavior of this op is undefined if:
+    - `inner_tiles` do not exactly match the corresponding source tensor
+      dimension sizes.
+    - Or, `inner_tiles[i]` does not divide the size of dimension
+      `inner_dims_pos[i]` (assuming that `outer_dims_perm` is not specified)
+      evenly.
+
+    `outer_dims_perm` (optional) specifies a permutation for the outer
+    dimensions. If specified, it must have `n - k` elements. If specified, this
+    permutation is applied before combining any dimensions.
+
+    Example:
+
+    ```mlir
+    // NCnc to NC:
+    %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+        into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+
+    // KCck to CK:
+    %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+        inner_tiles = [8, 32] into %dest
+        : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
+    ```
+  }];
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$dest,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
+                       DenseI64ArrayAttr:$inner_dims_pos,
+                       Variadic<Index>:$inner_tiles,
+                       DenseI64ArrayAttr:$static_inner_tiles);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    $source
+    (`outer_dims_perm` `=` $outer_dims_perm^)?
+    `inner_dims_pos` `=` $inner_dims_pos
+    `inner_tiles` `=`
+    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
+    `into` $dest attr-dict `:` type($source) `->` type($dest)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+    "ArrayRef<int64_t>":$innerDimsPos,
+    "ArrayRef<OpFoldResult>":$innerTiles,
+    CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+  ];
+
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    static Value createDestinationTensor(OpBuilder &b, Location loc,
+        Value source, ArrayRef<OpFoldResult> innerTileSizes,
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
+    /// Build and return a new UnPackOp that is a clone of the current UnPackOp
+    /// in which (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
+    /// innerPermutation (resp. outerPermutation).
+    /// Asserts that:
+    ///   - At least one of innerPermutation or outerPermutation is non-empty.
+    ///   - If not empty, innerPermutation is a valid permutation of size
+    ///     matching innerDimsPos.
+    ///   - If not empty, outerPermutation is a valid permutation of size
+    ///     matching outerDimsPerm.
+    UnPackOp createTransposedClone(OpBuilder &b,
+                                   Location loc,
+                                   Value transposedSource,
+                                   ArrayRef<int64_t> innerPermutation,
+                                   ArrayRef<int64_t> outerPermutation);
+
+    /// Check if this UnPackOp is like a simple unpad operation.
+    /// In other words, this operation:
+    /// 1. drops useless dimensions (dimension of size 1), and
+    /// 2. reduces dimensions in place (i.e., no transpose.)
+    bool isLikeUnPad();
+  }];
+
+  let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
+}
+
+#endif // LINALG_RELEAYOUT_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 081bf9b6d3b239..deee9a84aa6ae9 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -45,7 +45,7 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
     : Op<Transform_Dialect, "apply_patterns.linalg.decompose_pack_unpack",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Collect patterns to decompose tensor.pack and tensor.unpack into e.g.
+    Collect patterns to decompose linalg.pack and linalg.unpack into e.g.
     tensor::PadOp, linalg::transposeOp Ops. Requires all outer dims to be unit.
   }];
 
@@ -126,6 +126,28 @@ def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldIntoPackAndUnpackPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.fold_into_pack_and_unpack",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that operations like tensor.pad and tensor.extract_slice should
+    be folded into tensor.pack and tensor.unpack operations, respectively.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.fold_pack_unpack_into_empty",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    // TODO:
+  }];
+
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$fold_single_use_only);
+  let assemblyFormat = "attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -547,19 +569,18 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
                          TransformOpInterface,
                          ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
+    Rewrite a linalg.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
 
     #### Return modes
 
-    This operation ignores non-pack ops and drops them in the return.
-    This operation produces a silenceable failure if the rewrite fails for any
-    reason.
-    If all the operations referred to by the `target` are rewritten, the
-    transform succeeds.
-    Return handles to the newly produced pad, expand_shape and transpose ops.
+    This operation ignores non-pack ops and drops them in the return. This
+    operation produces a silenceable failure if the rewrite fails for any
+    reason. If all the operations referred to by the `target` are rewrit...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/123902

