[Mlir-commits] [mlir] 3eb7cce - [mlir][linalg] Add tests for tensor.unpack decomposition (#118786)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 6 01:28:18 PST 2024
Author: Andrzej Warzyński
Date: 2024-12-06T09:28:15Z
New Revision: 3eb7ccec58ae5c7360c6b871952f5f5d7f0a1860
URL: https://github.com/llvm/llvm-project/commit/3eb7ccec58ae5c7360c6b871952f5f5d7f0a1860
DIFF: https://github.com/llvm/llvm-project/commit/3eb7ccec58ae5c7360c6b871952f5f5d7f0a1860.diff
LOG: [mlir][linalg] Add tests for tensor.unpack decomposition (#118786)
This commit adds additional tests and documentation for
`DecomposeOuterUnitDimsUnPackOpPattern` to ensure symmetry with its
counterpart for `tensor.pack`, `DecomposeOuterUnitDimsPackOpPattern`.
The new tests aim to improve implementation, documentation, and test
coverage for tensor.unpack. They cover the following scenarios:
* Static tile sizes: A simple `tensor.unpack` case
(`@simple_unpack_static_tiles`).
* Dynamic tile size: `tensor.unpack` with a single dynamic tile size
(`@simple_unpack_dynamic_tile`).
* Transpose: `tensor.unpack` with dynamic tile size and transpose
(`@simple_unpack_dynamic_tile_transpose`), currently commented out due
to some missing logic (see below)
* Scalable tile size: `tensor.unpack` with a scalable inner tile size
(`@simple_unpack_scalable_tile`).
Notes:
The test `@simple_unpack_dynamic_tile_transpose` is commented out
because the logic for capturing dynamic sizes for `tensor::EmptyOp` when
some tile sizes are dynamic is incomplete. This missing functionality
will be addressed in a follow-up patch.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3c160d55a38e75..f31371ec6a0540 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1519,7 +1519,7 @@ struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp +
/// tensor::InsertSliceOp ops.
///
-/// Required that all the outer dims of the input tensor::PackOp are 1.
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
///
/// Before:
/// ```
@@ -1555,9 +1555,33 @@ struct DecomposeOuterUnitDimsPackOpPattern
PatternRewriter &rewriter) const override;
};
-/// Rewrites a tensor::UnPackOp into a sequence of rank-reduced extract_slice op
-/// + transpose op + insert_slice op, where the tensor::UnPackOp has outer dims
-/// being all 1s.
+/// Rewrites a tensor::UnPackOp into a sequence of:
+/// * rank-reduced tensor::ExtractSliceOp + linalg::TransposeOp +
+/// tensor::InsertSliceOp ops.
+///
+/// Requires that all the outer dims of the input tensor::UnPackOp are 1.
+///
+/// Before:
+/// ```
+/// %unpacked = tensor.unpack %input
+/// inner_dims_pos = [1, 0]
+/// inner_tiles = [2, 8]
+/// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32>
+/// ```
+///
+/// After:
+/// ```
+/// // Rank-reduced extract to obtain the tile
+/// %slice = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1]
+/// : tensor<1x1x2x8xf32> to tensor<2x8xf32>
+/// // EmptyOp + TransposeOp
+/// %init = tensor.empty() : tensor<8x2xf32>
+/// %transposed = linalg.transpose
+/// ins(%slice : tensor<2x8xf32>)
+/// outs(%init : tensor<8x2xf32>) permutation = [1, 0]
+/// // Extract a slice matching the specified output size
+/// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1]
+/// : tensor<8x2xf32> to tensor<5x1xf32>
+/// ```
struct DecomposeOuterUnitDimsUnPackOpPattern
: public OpRewritePattern<tensor::UnPackOp> {
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
index 4f986606ef93ad..1cc1484ed40951 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-pack.mlir
@@ -67,6 +67,9 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+/// Same as example above, but the dynamic tile size is a compile-time constant
+/// that's folded away.
+
func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
%tile_dim_0 = arith.constant 8 : index
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
diff --git a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
index 8b15873473a972..a720c655e4be51 100644
--- a/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-tensor-unpack.mlir
@@ -19,11 +19,11 @@ func.func @simple_KCRSsr_to_KCRS(%arg0: tensor<1x1x1x1x8x32xf32>, %arg1: tensor<
// -----
-func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
+func.func @simple_unpack_static_tiles(%input: tensor<1x1x8x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
return %0 : tensor<5x1xf32>
}
-// CHECK-LABEL: func.func @simple_unpack_and_extract_slice
+// CHECK-LABEL: func.func @simple_unpack_static_tiles
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
@@ -33,6 +33,55 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
// CHECK: return %[[SLICE]]
+/// Same as example above, but with 1 dynamic tile size.
+
+func.func @simple_unpack_dynamic_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
+ %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+ return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_dynamic_tile
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[TILE_DIM_1:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_1]], 2] [1, 1, 1, 1]
+// CHECK-NOT: linalg.transpose
+// They have the same type, so the insert_slice op is folded
+// away.
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
+// CHECK: return %[[SLICE]]
+
+/// Same as example above, but with 1 dynamic tile size and a transpose.
+
+/// FIXME: This is currently broken:
+/// * 'tensor.empty' op incorrect number of dynamic sizes, has 0, expected 1
+
+//func.func @simple_unpack_dynamic_tile_transpose(%input: tensor<1x1x2x?xf32>, %output: tensor<5x1xf32>, %tile_dim_0: index) -> tensor<5x1xf32> {
+// %0 = tensor.unpack %input inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_0] into %output : tensor<1x1x2x?xf32> -> tensor<5x1xf32>
+// return %0 : tensor<5x1xf32>
+//}
+
+/// Same as example above, but with 1 scalable tile size.
+
+func.func @simple_unpack_scalable_tile(%input: tensor<1x1x?x2xf32>, %output: tensor<5x1xf32>) -> tensor<5x1xf32> {
+ %c8 = arith.constant 8 : index
+ %vscale = vector.vscale
+ %c8_vscale = arith.muli %vscale, %c8 : index
+ %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<1x1x?x2xf32> -> tensor<5x1xf32>
+ return %0 : tensor<5x1xf32>
+}
+// CHECK-LABEL: func.func @simple_unpack_scalable_tile
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[VS:.+]] = vector.vscale
+// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1]
+// CHECK-NOT: linalg.transpose
+// They have the same type, so the insert_slice op is folded
+// away.
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]
+// CHECK: return %[[SLICE]]
+
// -----
func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32>) -> tensor<32x8xf32>{
More information about the Mlir-commits
mailing list