[Mlir-commits] [mlir] [MLIR] Add patterns to bubble-up pack and push-down unpack through collapse/expand shape ops (PR #85297)
Quinn Dawkins
llvmlistbot at llvm.org
Tue Mar 26 18:29:50 PDT 2024
================
@@ -905,3 +905,127 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
+
+// -----
+
+// Pack applied to the result of a collapse_shape whose leading dim is dynamic.
+// The expected output packs the uncollapsed source directly (tiling source
+// dim 1, of size 16, by 8) and applies the collapse_shape afterwards.
+func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index) -> tensor<?x4x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
+ %2 = tensor.empty(%dim) : tensor<?x4x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
+ func.return %pack : tensor<?x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
+
+// -----
+
+// Pack with a non-identity outer_dims_perm applied to a collapse_shape result.
+// The expected output packs the uncollapsed 4-D source with the permutation
+// and inner_dims_pos remapped to the expanded dims, then collapses the
+// packed tensor.
+func.func @bubble_up_permuted_pack_through_collapse(%1: tensor<4x192x16x256xf32>) -> tensor<4x32x3072x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0], [1, 2], [3]] : tensor<4x192x16x256xf32> into tensor<4x3072x256xf32>
+ %2 = tensor.empty() : tensor<4x32x3072x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %2 : tensor<4x3072x256xf32> -> tensor<4x32x3072x8x1xf32>
+ func.return %pack : tensor<4x32x3072x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_permuted_pack_through_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x32x192x16x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<4x192x16x256xf32> -> tensor<4x32x192x16x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x192x16x8x1xf32> into tensor<4x32x3072x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<4x32x3072x8x1xf32>
+
+// -----
+
+// Pack applied to a collapse_shape that folds unit dims (1x64x1 -> 64).
+// The expected output packs the uncollapsed source, placing the tile of 8 on
+// the non-unit dim of the collapsed group (inner_dims_pos = [1, 3]), and then
+// collapses the packed tensor.
+func.func @bubble_up_pack_through_unit_collapse(%1: tensor<1x64x1x4xf32>) -> tensor<8x4x8x1xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x64x1x4xf32> into tensor<64x4xf32>
+ %2 = tensor.empty() : tensor<8x4x8x1xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %2 : tensor<64x4xf32> -> tensor<8x4x8x1xf32>
+ func.return %pack : tensor<8x4x8x1xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_unit_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x4x8x1xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<1x64x1x4xf32> -> tensor<1x8x1x4x8x1xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<1x8x1x4x8x1xf32> into tensor<8x4x8x1xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<8x4x8x1xf32>
+
+// -----
+
+// Negative test: the expected output keeps the original order (collapse_shape
+// first, then pack). NOTE(review): presumably the rewrite is rejected because
+// the innermost dim of the collapsed group [1, 2] has size 4, which is not
+// divisible by the inner tile of 8 -- confirm against the pattern.
+func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4xf32>) -> tensor<384x32x8x8xf32> {
+ %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+ %2 = tensor.empty() : tensor<384x32x8x8xf32>
+ %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %2 : tensor<3072x256xf32> -> tensor<384x32x8x8xf32>
+ func.return %pack : tensor<384x32x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_through_non_divisible_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<3072x64x4xf32> into tensor<3072x256xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[COLLAPSED]]
+// CHECK: return %[[PACK]] : tensor<384x32x8x8xf32>
+
+// -----
+
+// Unpack followed by an expand_shape, with a dynamic leading dim. The expected
+// output expands the packed source first and then unpacks the expanded
+// tensor, with inner_dims_pos remapped to the expanded dims.
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+ %6 = tensor.empty(%dim) : tensor<?x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+ func.return %expanded : tensor<?x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
+
+// -----
+
+// Unpack with a non-identity outer_dims_perm followed by an expand_shape.
+// The expected output expands the packed source first, then unpacks with the
+// permutation and inner_dims_pos remapped to the expanded dims.
+func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
+ %6 = tensor.empty() : tensor<4x3072x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+ func.return %expanded : tensor<4x12x256x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_permuted_unpack_through_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32>
+
+// -----
+
+// Unpack followed by an expand_shape that introduces a unit dim
+// (48 -> 3x16x1). The expected output expands the packed source first and
+// then unpacks, with the tiles applied to expanded dims 1 and 3
+// (inner_dims_pos = [1, 3]).
+func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
+ %6 = tensor.empty() : tensor<48x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
+ func.return %expanded : tensor<3x16x1x256xf32>
+}
+// CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
+// CHECK: return %[[UNPACK]] : tensor<3x16x1x256xf32>
+
+// -----
+
+// Negative test: the expected output keeps the original order (unpack first,
+// then expand_shape). NOTE(review): presumably the rewrite is rejected because
+// the innermost dim of the expanded group [0, 1] has size 12, which is not
+// divisible by the inner tile of 8 -- confirm against the pattern.
+func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
+ %6 = tensor.empty() : tensor<3072x256xf32>
+ %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+ func.return %expanded : tensor<256x12x256xf32>
+}
+// CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
----------------
qedawkins wrote:
Can you add a test where the reassociated dims are not affected by the pack (or unpack)? e.g.
```mlir
%collapsed = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<?x16x4xf32> into tensor<?x4xf32>
%2 = tensor.empty(%dim) : tensor<?x2x2xf32>
%pack = tensor.pack %collapsed inner_dims_pos = [1] inner_tiles = [2] into %2 : tensor<?x4xf32> -> tensor<?x2x2xf32>
```
https://github.com/llvm/llvm-project/pull/85297
More information about the Mlir-commits
mailing list