[Mlir-commits] [mlir] [mlir][linalg] Add tests for tensor.unpack decomposition (PR #118786)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Dec 6 00:05:58 PST 2024
================
@@ -1555,9 +1555,34 @@ struct DecomposeOuterUnitDimsPackOpPattern
PatternRewriter &rewriter) const override;
};
-/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
-/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
-/// being all 1s.
+/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced
+/// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp
+///
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+/// %packed = tensor.unpack %input
+/// inner_dims_pos = [1, 0]
+/// inner_tiles = [2, 8]
+/// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
+/// ```
+///
+/// After:
+/// ```
+/// // Rank-reduced extract to obtain the tile
+/// %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
+/// : tensor<1x1x2x8xf32> to tensor<2x8xf32>
+/// // EmptyOp + TransposeOp
+/// %init = tensor.empty() : tensor<8x2xf32>
+/// %transposed = linalg.transpose
+/// ins(%extracted_slice : tensor<2x8xf32>)
+/// outs(%0 : tensor<8x2xf32>) permutation = [1, 0]
+/// // Extract a slice matching the specified output size
+/// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
+/// : tensor<8x2xf32> to tensor<5x1xf32>
+///
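
The data movement described in the comment above can be sketched on plain
arrays. This is only an illustration under stated assumptions, not the MLIR
pattern itself: the type aliases and function names (Tile, transposeTile,
extractResult, unpackUnitOuterDims) are hypothetical, the shapes are fixed to
the 1x1x2x8 -> 5x1 example, and the rank-reduced extract is modelled by simply
starting from the 2x8 tile.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-ins for the tensor types in the example above.
using Tile = std::array<std::array<float, 8>, 2>;        // tensor<2x8xf32>
using Transposed = std::array<std::array<float, 2>, 8>;  // tensor<8x2xf32>
using Result = std::array<std::array<float, 1>, 5>;      // tensor<5x1xf32>

// Transpose step: swap the two tile dimensions (permutation = [1, 0]).
Transposed transposeTile(const Tile &tile) {
  Transposed out{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 8; ++j)
      out[j][i] = tile[i][j];
  return out;
}

// Final extract step: keep the leading 5x1 corner and drop the rest of the
// 8x2 tile, which only holds padding for this output shape.
Result extractResult(const Transposed &t) {
  Result out{};
  for (std::size_t i = 0; i < 5; ++i)
    out[i][0] = t[i][0];
  return out;
}

// Whole decomposition for the unit-outer-dims case: tile -> transpose -> slice.
Result unpackUnitOuterDims(const Tile &tile) {
  return extractResult(transposeTile(tile));
}
```

With inner_dims_pos = [1, 0], element (i, 0) of the 5x1 result is read from
element (0, i) of the 2x8 tile, which is what the transpose followed by the
slice produces.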
----------------
banach-space wrote:
Removed.
https://github.com/llvm/llvm-project/pull/118786