[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

donald chen llvmlistbot at llvm.org
Sun Jul 14 22:05:55 PDT 2024


================
@@ -315,3 +315,97 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
----------------
cxy-1993 wrote:

Sorry for not making myself clear. I find it strange that %dest is both the target of the insert slice and the operand of the computation. I thought the IR should look like:
```
%0 = linalg.fill ... outs(%empty)
%init = tensor.empty(...)
%1 = scf.forall ... shared_outs(%arg0 = %init) {
   %2 = tensor.extract_slice %arg0[...]
   %3 = scf.for ... iter_args(%arg1=%2) {
      %4 = scf.for ... iter_args(%arg2=%3) {
         %5 = tensor.extract_slice %arg2[...]
         %6 = linalg.matmul ... outs(%5)
         %7 = tensor.insert_slice %6 
         scf.yield %7
      }
      scf.yield %4
   }
   scf.forall.in_parallel {
       tensor.insert_in_parallel %3 into %init
   }
}
%2 = linalg.add (.., %1)
```

My point is that maybe you should create a new tensor to insert into, since you don't need the init value in %0.

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list