[Mlir-commits] [mlir] [mlir][scf] Extend consumer fuse to nested loop structure (PR #94190)

donald chen llvmlistbot at llvm.org
Sun Jul 14 01:46:10 PDT 2024


================
@@ -315,3 +315,97 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:       }
 //      CHECK:   }
 //      CHECK:   return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
----------------
cxy-1993 wrote:

Why does your test case use %dest1 both as the shared_outs and as the result?

Is the fusion still valid if the tensor.parallel_insert_slice does not write the entire shared_out?

https://github.com/llvm/llvm-project/pull/94190


More information about the Mlir-commits mailing list