[Mlir-commits] [mlir] [Linalg] Add pattern to push down extract slice through linalg generic op (PR #154162)

Nirvedh Meshram llvmlistbot at llvm.org
Mon Aug 25 14:10:24 PDT 2025


================
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+// Positive test: the tensor.extract_slice on the first input of the
+// linalg.generic is expected to be pushed below the generic. Per the CHECK
+// lines, the rewritten generic reads the unsliced %arg0 together with a
+// poison-padded %arg1, writes into a fresh tensor.empty, and the slice is
+// re-applied on the generic's result.
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// Use FileCheck captures (not raw SSA names such as %arg1 or %3) so the
+// expectations do not depend on value numbering in the printed IR.
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG1]]
+// CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME:    outs(%[[EMPTY]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK:         return %[[EXTRACT]]
+
+// -----
+
+// Negative test: the CHECK lines for this function expect the linalg.generic
+// result to be returned directly, i.e. the extract_slice is not pushed down.
+// The slice has a dynamic offset (%arg3) on operand dim 1, which the generic
+// indexes with `d2 + d3` rather than a plain dim expression.
+// NOTE(review): presumably that non-dim-expr access is what blocks the
+// pattern (cf. the `nodimexpr` suffix) -- confirm against the pattern code.
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]          
+
+// -----
+
+// Negative test: the CHECK lines for this function expect the linalg.generic
+// result to be returned directly, i.e. the extract_slice is not pushed down.
+// Here the slice is on the second input (%arg1) and is dynamic on operand
+// dim 1, which that operand's map accesses as `d2`; the first operand's map
+// uses the same loop dim inside `d2 + d3`.
+// NOTE(review): presumably the `d2 + d3` use in the other operand is what
+// blocks the pattern (cf. the `nodimexpr` suffix) -- confirm against the
+// pattern code.
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]   
+
+// -----
+
+// Test input: a 2-D extract_slice with dynamic offsets and sizes (%arg2 in
+// both dims) feeds a linalg.generic that reduces over d1. The body reads
+// %out (it is added to the truncated input), so the output operand's values
+// are used by the computation.
+// NOTE(review): "redcution" in the function name looks like a typo for
+// "reduction"; renaming requires updating the matching CHECK-LABEL as well.
+func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK:         %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK:           tensor.yield %[[POISON_BF16]] : bf16
----------------
nirvedhmeshram wrote:

Added the shapes to the pad ops. The reason I had skipped them was that I am letting the padOp builder decide the result type by passing `Type()` to it. Note that the post-padding dims in these tests are dynamic, but they should resolve to static values with some canonicalization patterns.

https://github.com/llvm/llvm-project/pull/154162


More information about the Mlir-commits mailing list