[Mlir-commits] [mlir] [MLIR][Linalg] Fix insert_slice fusion with rank reduction (PR #130961)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue May 20 09:09:45 PDT 2025
================
@@ -318,3 +318,76 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
}
return %for0 : tensor<64x128xf32>
}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(
+ %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
+ %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c6 = arith.constant 6 : index
+ %0 = linalg.generic
+ {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %10 = arith.mulf %in, %in_1 : f32
+ %11 = arith.addf %out, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<1x6x6xf32>
+ %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
+ %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
----------------
banach-space wrote:
I would add an empty line and a big, bold comment to highlight this rank-reducing `extract_slice` — it is the key part of this test.
https://github.com/llvm/llvm-project/pull/130961
More information about the Mlir-commits
mailing list