[Mlir-commits] [mlir] [MLIR][Linalg] Fix insert_slice fusion with rank reduction (PR #130961)

Andrzej Warzyński llvmlistbot at llvm.org
Fri May 16 10:20:23 PDT 2025


================
@@ -318,3 +318,81 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %cst = arith.constant 0.0 : f32
+  %init1 = tensor.empty() : tensor<1x6x6xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+  %0 = linalg.generic
+    {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
----------------
banach-space wrote:

is this required?

https://github.com/llvm/llvm-project/pull/130961


More information about the Mlir-commits mailing list