[Mlir-commits] [mlir] [mlir][linalg] Improve linalg.pack consumer fusion. (PR #148993)

Han-Chung Wang llvmlistbot at llvm.org
Thu Jul 17 11:17:16 PDT 2025


================
@@ -451,6 +451,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+  %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+    %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<23x32x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
----------------
hanhanW wrote:

Good catch! I think I fat-fingered it; it should be `2` in this case. I updated the logic a bit to make sure that the fusion does not happen in this case.

I found that the other test used the wrong number of threads, so I fixed it as well.

https://github.com/llvm/llvm-project/pull/148993

