[Mlir-commits] [mlir] [mlir][linalg] Add tests for tensor.unpack decomposition (PR #118786)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Dec 6 00:05:58 PST 2024


================
@@ -1555,9 +1555,34 @@ struct DecomposeOuterUnitDimsPackOpPattern
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
-/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
-/// being all 1s.
+/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced
+///   * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
+///
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+/// %packed = tensor.unpack %input
+///   inner_dims_pos = [1, 0]
+///   inner_tiles = [2, 8]
+///   into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
+/// ```
+///
+/// After:
+/// ```
+///   // Rank-reduced extract to obtain the tile
+///   %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
+///     : tensor<1x1x2x8xf32> to tensor<2x8xf32>
+///   // EmptyOp + TransposeOp
+///   %init = tensor.empty() : tensor<8x2xf32>
+///   %transposed = linalg.transpose
+///     ins(%extracted_slice : tensor<2x8xf32>)
+///     outs(%0 : tensor<8x2xf32>) permutation = [1, 0]
+///   // Extract a slice matching the specified output size
+///   %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
+///     : tensor<8x2xf32> to tensor<5x1xf32>
+///
----------------
banach-space wrote:

Removed.

https://github.com/llvm/llvm-project/pull/118786
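
For readers following the shape arithmetic in the quoted doc comment, here is a minimal sketch of the same three steps (rank-reduced extract, transpose, trailing extract_slice) on plain C++ containers. It is not the MLIR pattern itself. The names `Matrix`, `extractTile`, `transpose`, `extractSlice` and `unpackUnitOuterDims` are hypothetical, and the packed tensor is assumed to be stored as flat row-major data.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 2-D tensor; not an MLIR type.
using Matrix = std::vector<std::vector<float>>;

// Step 1: with all outer dims equal to 1, a 1x1xRxC packed tensor holds a
// single RxC tile, so the rank-reduced slice is just that tile.
// Assumes `packed` is flat, row-major data of size rows * cols.
Matrix extractTile(const std::vector<float> &packed, std::size_t rows,
                   std::size_t cols) {
  assert(packed.size() == rows * cols && "expected exactly one tile");
  Matrix tile(rows, std::vector<float>(cols));
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      tile[r][c] = packed[r * cols + c];
  return tile;
}

// Step 2: permutation = [1, 0], i.e. a plain 2-D transpose.
Matrix transpose(const Matrix &m) {
  const std::size_t rows = m.size();
  const std::size_t cols = rows == 0 ? 0 : m[0].size();
  Matrix result(cols, std::vector<float>(rows));
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      result[c][r] = m[r][c];
  return result;
}

// Step 3: keep the leading rows x cols corner (offsets 0, strides 1),
// which drops the padding introduced by packing.
Matrix extractSlice(const Matrix &m, std::size_t rows, std::size_t cols) {
  assert(rows <= m.size() && (m.empty() || cols <= m[0].size()));
  Matrix result(rows, std::vector<float>(cols));
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      result[r][c] = m[r][c];
  return result;
}

// Shapes from the quoted example: 1x1x2x8 -> 2x8 -> 8x2 -> 5x1.
Matrix unpackUnitOuterDims(const std::vector<float> &packed) {
  Matrix tile = extractTile(packed, /*rows=*/2, /*cols=*/8);
  Matrix transposed = transpose(tile);
  return extractSlice(transposed, /*rows=*/5, /*cols=*/1);
}
```

With `inner_dims_pos = [1, 0]`, the tile dimension of size 8 maps to output dimension 0 and the one of size 2 to output dimension 1, which is why the transposed 8x2 tile is large enough to hold the 5x1 result.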

