[Mlir-commits] [mlir] [mlir][tensor] Fold producer linalg transpose with consumer tensor pack (PR #75658)

Prathamesh Tagore llvmlistbot at llvm.org
Mon Jan 8 11:21:48 PST 2024


================
@@ -345,3 +345,164 @@ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_s
 //      CHECK:     %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
 //      CHECK:     return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
 //      CHECK:   }
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+
+  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [0, 2, 1, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
+  return %pack : tensor<1x57x56x2x32xf32>
+}
+//CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x55xf32>
+  %transpose = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x55xf32>)
+    outs(%0 : tensor<1x56x57x55xf32>)
+    permutation = [2, 0, 1, 3]
+  
+  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
+  %pack = tensor.pack %transpose padding_value(%padding : f32)
+    outer_dims_perm = [0, 2, 1, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x55xf32> -> tensor<1x57x56x2x32xf32>
+  return %pack : tensor<1x57x56x2x32xf32>
+}
+//CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_with_padding(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
+// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x56x57x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+  
+  %1 = tensor.empty() : tensor<1x56x57x2x32xf32>
+  %pack = tensor.pack %transposed
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x56x57x2x32xf32>
+  return %pack : tensor<1x56x57x2x32xf32>
+}
+//CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x2x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 0, 1, 3]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(%arg0: tensor<25x30x35x40xf32>, %transpose_dest: tensor<35x40x25x30xf32>, %pack_dest: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<25x30x35x40xf32>)
+    outs(%transpose_dest : tensor<35x40x25x30xf32>)
+    permutation = [2, 3, 0, 1]
+  
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [3, 0, 2, 1]
+    inner_dims_pos = [1, 3, 2]
+    inner_tiles = [5, 10, 5]
+    into %pack_dest : tensor<35x40x25x30xf32> -> tensor<3x35x5x8x5x10x5xf32>
+  return %pack : tensor<3x35x5x8x5x10x5xf32>
+}
+//CHECK-LABEL:   func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<25x30x35x40xf32>, 
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<35x40x25x30xf32>, 
+// CHECK-SAME:     %[[ARG2:.+]]: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
+//      CHECK:     %[[VAL0:.+]] = tensor.empty() : tensor<3x35x5x8x5x10x5xf32>
+//      CHECK:     %[[PACK:.+]] = tensor.pack %[[ARG0]] 
+// CHECK-SAME:        outer_dims_perm = [1, 2, 0, 3] 
+// CHECK-SAME:        inner_dims_pos = [3, 1, 0] 
+// CHECK-SAME:        inner_tiles = [5, 10, 5] 
+// CHECK-SAME:         into %[[VAL0]] 
+//      CHECK:     return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<?x?x?x?xf32>)
+    outs(%transpose_dest : tensor<?x?x?x?xf32>)
+    permutation = [2, 3, 0, 1]
+  
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [3, 0, 2, 1]
+    inner_dims_pos = [1, 3, 2]
+    inner_tiles = [%tile_p, %tile_q, %tile_r]
+    into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?x?x?x?xf32>
+}
+//      CHECK:   #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+//CHECK-LABEL:   func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, 
+// CHECK-SAME:   %[[ARG2:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index) -> tensor<?x?x?x?x?x?x?xf32> {
+//      CHECK:     %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:     %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:     %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:     %[[C3:.+]] = arith.constant 3 : index
+//      CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[VAL0:.+]] = affine.apply #[[map:.+]]()[%[[DIM2]], %[[ARG3]]]
+//      CHECK:     %[[VAL1:.+]] = affine.apply #[[map:.+]]()[%[[DIM0]], %[[ARG4]]]
+//      CHECK:     %[[VAL2:.+]] = affine.apply #[[map:.+]]()[%[[DIM]], %[[ARG5]]]
+//      CHECK:     %[[VAL3:.+]] = tensor.empty(%[[VAL1]], %[[DIM1]], %[[VAL2]], %[[VAL0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?x?x?x?x?xf32>
+//      CHECK:     %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [3, 1, 0] inner_tiles = [%[[ARG3]], %[[ARG4]], %[[ARG5]]] into %[[VAL3]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
+//      CHECK:     return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
+
+// -----
+
+func.func @linalg_transpose_tensor_cast_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+
+  %transposed_cast = tensor.cast %transposed : tensor<1x56x57x64xf32> to tensor<?x56x57x64xf32> 
----------------
meshtag wrote:

The `tensor.cast` serves as the trigger for the negative test. It prevents `linalg.transpose` from being the direct producer of the `tensor.pack` source, so the pattern does not transform the IR. This was added as a negative test for the pattern.

https://github.com/llvm/llvm-project/pull/75658
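
For readers following along, here is a minimal sketch of how such a negative test could continue after the cast. The shapes and attributes simply reuse the ones already visible in the quoted snippet; the pack, return, and CHECK lines below are illustrative assumptions, not necessarily the exact content of the PR:

  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
  %pack = tensor.pack %transposed_cast
    outer_dims_perm = [0, 2, 1, 3]
    inner_dims_pos = [3]
    inner_tiles = [32]
    into %1 : tensor<?x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
  return %pack : tensor<1x57x56x2x32xf32>
}
// Since the pack's source is the cast rather than the transpose, the
// producer check in the folding pattern fails, so the CHECK lines would
// assert that all three ops survive, e.g.:
// CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%[[ARG0]]
// CHECK: %[[CAST:.+]] = tensor.cast %[[TRANSPOSED]]
// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST]]
// CHECK: return %[[PACK]]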

