[Mlir-commits] [mlir] [MLIR] Add pattern to bubble up tensor.extract_slice (PR #126898)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Feb 23 09:28:11 PST 2025


================
@@ -0,0 +1,74 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter  %s | FileCheck %s
+
+// The slice starts at (0, 0, 5) of the 2x3x10 expanded view, i.e. at linear
+// offset 0 * 30 + 0 * 10 + 5 = 5 of the 1-D source, and reads 5 consecutive
+// elements with unit strides. The expected output takes that range directly
+// from the source and applies the expand_shape to the (smaller) result.
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_expand_shape(
+// CHECK-SAME:                                                    %[[ARG0:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> {
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[C5]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5] : tensor<5xf32> into tensor<1x1x5xf32>
+// CHECK:           return %[[EXPAND]] : tensor<1x1x5xf32>
+
+func.func @bubble_up_extract_slice_through_expand_shape(%0: tensor<60xf32>) -> tensor<1x1x5xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
+  return %extract : tensor<1x1x5xf32>
+}
+
+
+// The slice [0, 0, 5][1, 2, 5] of the 2x3x10 expanded view covers linear
+// elements 5-9 and 15-19 of the 1-D source. That is not one contiguous range,
+// so the test expects the extract_slice to stay after the expand_shape and to
+// still consume its result.
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_on_non_contiguous(
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[EXPAND]]
+// CHECK:           return %[[EXTRACT]]
+
+func.func @no_bubble_up_extract_slice_on_non_contiguous(%0: tensor<60xf32>) -> tensor<1x2x5xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 2, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x2x5xf32>
+  return %extract : tensor<1x2x5xf32>
+}
+
+
+// Non-unit stride in the innermost dimension: the slice reads linear elements
+// 0, 2, 4, 6 and 8 of the 1-D source, which is not a contiguous range, so the
+// test expects the IR to be left unchanged. The offsets keep the slice in
+// bounds: the last accessed innermost index is 0 + (5 - 1) * 2 = 8 < 10.
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_on_stride(
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[EXPAND]]
+// CHECK:           return %[[EXTRACT]]
+
+func.func @no_bubble_up_extract_slice_on_stride(%0: tensor<60xf32>) -> tensor<1x1x5xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 0][1, 1, 5][1, 1, 2] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
+  return %extract : tensor<1x1x5xf32>
+}
+
+
+// Rank-reducing slice: a 3-D slice of the expanded tensor produces a 2-D
+// result (tensor<1x5xf32>). The test expects the extract_slice to stay after
+// the expand_shape and to still consume its result.
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_on_rank_reducing(
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[EXPAND]]
+// CHECK:           return %[[EXTRACT]]
+
+func.func @no_bubble_up_extract_slice_on_rank_reducing(%0: tensor<60xf32>) -> tensor<1x5xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 5][1, 1, 5][1, 1, 1] : tensor<2x3x10xf32> to tensor<1x5xf32>
+  return %extract : tensor<1x5xf32>
+}
+
+
+// Two reassociation groups are expanded. With all-zero offsets and unit
+// strides, the 1x2x10 slice of group [0, 1, 2] (expanded shape 3x4x10) reads
+// the first 20 elements of source dimension 0, and the 1x4 slice of group
+// [3, 4] (expanded shape 7x8) reads the first 4 elements of source
+// dimension 1. Hence the expected 20x4 extract_slice at offsets [0, 0],
+// followed by an expand_shape back to 1x2x10x1x4.
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(
+// CHECK-SAME:                                                    %[[ARG0:.*]]: tensor<120x56xf32>) -> tensor<1x2x10x1x4xf32> {
+// CHECK:           %[[C0:.+]] = arith.constant 0 : index
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[C0]], %[[C0]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+// CHECK:           %[[EXPAND:.*]] = tensor.expand_shape %[[EXTRACT]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4] : tensor<20x4xf32> into tensor<1x2x10x1x4xf32>
+// CHECK:           return %[[EXPAND]] : tensor<1x2x10x1x4xf32>
+
+func.func @bubble_up_extract_slice_through_expand_shape_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<1x2x10x1x4xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+  %extract = tensor.extract_slice %expand[0, 0, 0, 0, 0][1, 2, 10, 1, 4][1, 1, 1, 1, 1] : tensor<3x4x10x7x8xf32> to tensor<1x2x10x1x4xf32>
+  return %extract : tensor<1x2x10x1x4xf32>
+}
+
+// Transform script driving this file: match every func.func in the payload
+// and apply the tensor bubble_up_extract_slice pattern set to each of them.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.bubble_up_extract_slice
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
----------------
banach-space wrote:

Missing newline at end of file (EOF).

https://github.com/llvm/llvm-project/pull/126898


More information about the Mlir-commits mailing list