[Mlir-commits] [mlir] [MLIR][Linalg] Fix insert_slice fusion with rank reduction (PR #130961)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Mar 26 02:24:09 PDT 2025
================
@@ -318,3 +318,50 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
}
return %for0 : tensor<64x128xf32>
}
+
+// -----
+
+// Rank-reduction test: %init1 is a 6-D producer that is read only through
+// rank-reducing tensor.extract_slice ops (6x6x1x1x1x1 -> 6x2), one in each
+// branch of an scf.if. The two slices differ only in their dim-1 offset (0 vs 2).
+func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
+ %cst = arith.constant 0.0 : f32
+ %cst1 = arith.constant 1.0 : f32
+
+ // Producer: fill a 6x6x1x1x1x1 tensor with 0.0 (outs-only generic yielding %cst).
+ %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+ %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
+ ^bb0(%out: f32):
+ linalg.yield %cst : f32
+ } -> tensor<6x6x1x1x1x1xf32>
+
+ %if = scf.if %cond -> tensor<6x2xf32> {
+ // Then-branch: rank-reducing slice at dim-1 offset 0, then add 1.0 elementwise.
+ %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+
+ %init2 = tensor.empty() : tensor<6x2xf32>
+ %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %add = arith.addf %in, %cst1 : f32
+ linalg.yield %add : f32
+ } -> tensor<6x2xf32>
+ scf.yield %add1 : tensor<6x2xf32>
+ } else {
+ // Else-branch: rank-reducing slice at dim-1 offset 2, yielded without a consumer op.
+ %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+ scf.yield %extract2 : tensor<6x2xf32>
+ }
+
+ return %if : tensor<6x2xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
----------------
banach-space wrote:
Is `COND` in any way significant to this test? I don't see it being re-used below this.
https://github.com/llvm/llvm-project/pull/130961
More information about the Mlir-commits mailing list