[Mlir-commits] [mlir] [MLIR] Add pattern to bubble up tensor.extract_slice (PR #126898)

ofri frishman llvmlistbot at llvm.org
Wed Feb 26 22:34:52 PST 2025


================
@@ -278,3 +278,141 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %empty = tensor.empty() : tensor<2x3x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+  return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]]{{.*}} by (3, 4, 10)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_swap_expand_shape_with_extract_slice_non_contiguous
+//     CHECK: tensor.expand_shape
+//     CHECK: scf.for
+//     CHECK:   scf.for
+//     CHECK:     scf.for
+//     CHECK:       linalg.exp
+func.func @no_swap_expand_shape_with_extract_slice_non_contiguous(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims
+//     CHECK: %[[C0:.+]] = arith.constant 0 : index
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       scf.for %[[W:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+//     CHECK:       %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[Z]], %[[W]]] by (7, 8)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<3x4x10x7x8xf32> {
+    %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+    %empty = tensor.empty() : tensor<3x4x10x7x8xf32>
+    %exp = linalg.exp ins(%expand : tensor<3x4x10x7x8xf32>) outs(%empty : tensor<3x4x10x7x8xf32>) -> tensor<3x4x10x7x8xf32>
+    return %exp : tensor<3x4x10x7x8xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:    %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], {{.*}} by (8, 32)
+//     CHECK:    %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[0, 0, %[[LINEAR_IDX]]] [1, 1800, 32] [1, 1, 1] : tensor<1x1800x256xf32> to tensor<1x1800x32xf32>
+//     CHECK:    %[[ABS:.+]] = linalg.abs ins(%[[SLICE]]
+//     CHECK:    %[[EXPAND:.+]] = tensor.expand_shape %[[ABS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1800, 1, 32]
+//     CHECK:    linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @swap_expand_shape_with_extract_slice_and_fuse_with_expand_producer(%0: tensor<1x1800x256xf32>) -> tensor<1x1800x8x32xf32> {
+    %empty1 = tensor.empty() : tensor<1x1800x256xf32>
+    %exp1 = linalg.abs ins(%0 : tensor<1x1800x256xf32>) outs(%empty1 : tensor<1x1800x256xf32>) -> tensor<1x1800x256xf32>
+    %expand = tensor.expand_shape %exp1 [[0], [1], [2, 3]] output_shape [1, 1800, 8, 32] : tensor<1x1800x256xf32> into tensor<1x1800x8x32xf32>
+    %empty2 = tensor.empty() : tensor<1x1800x8x32xf32>
+    %exp2 = linalg.exp ins(%expand : tensor<1x1800x8x32xf32>) outs(%empty2 : tensor<1x1800x8x32xf32>) -> tensor<1x1800x8x32xf32>
+    return %exp2 : tensor<1x1800x8x32xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield 
+  }
+}
+
+
+
+
----------------
ofri-frishman wrote:

Fixed

https://github.com/llvm/llvm-project/pull/126898


More information about the Mlir-commits mailing list