[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)
donald chen
llvmlistbot at llvm.org
Sun Jul 14 22:05:55 PDT 2024
================
@@ -315,3 +315,97 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+ func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+ %iv0 = affine.apply #map(%arg3)
+ %iv1 = affine.apply #map(%arg4)
+ %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+ %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+ %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+ %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+ %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+ %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+ %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+ scf.yield %insert_slice : tensor<128x128xf32>
+ }
+ scf.yield %3 : tensor<128x128xf32>
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+ }
+ }
----------------
cxy-1993 wrote:
Sorry for not making myself clear. I find it strange that %dest is both the target of the insert slice and the operand of the computation. I thought the IR should look like:
```
%0 = linalg.fill ... outs(%empty)
%init = tensor.empty(...)
%1 = scf.forall ... shared_outs(%arg0 = %init) {
%2 = tensor.extract_slice %arg0[...]
%3 = scf.for ... iter_args(%arg1=%2) {
%4 = scf.for ... iter_args(%arg2=%3) {
%5 = tensor.extract_slice %arg2[...]
%6 = linalg.matmul ... outs(%5)
%7 = tensor.insert_slice %6
scf.yield %7
}
scf.yield %4
}
scf.forall.in_parallel {
tensor.insert_in_parallel %3 into %init
}
}
%2 = linalg.add (.., %1)
```
My point is that maybe you should create a new tensor to insert into, since you don't need the init value in %0.
https://github.com/llvm/llvm-project/pull/94190
More information about the Mlir-commits
mailing list