[Mlir-commits] [mlir] [MLIR][Linalg] Fix insert_slice fusion with rank reduction (PR #130961)

Andrzej Warzyński llvmlistbot at llvm.org
Tue May 20 09:09:45 PDT 2025


================
@@ -318,3 +318,76 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(
+    %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
+    %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %0 = linalg.generic
+    {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
----------------
banach-space wrote:

I would add an empty line and a big, bold comment to highlight this rank-reducing `extract_slice` - it is the key part of this test.

https://github.com/llvm/llvm-project/pull/130961


More information about the Mlir-commits mailing list