[Mlir-commits] [mlir] [mlir][linalg] Improve linalg.pack consumer fusion. (PR #148993)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Jul 17 11:17:16 PDT 2025
================
@@ -451,6 +451,54 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @fuse_pack_consumer_with_padding_semantic_into_scf_forall(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+ %0 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+ %extracted_slice_0 = tensor.extract_slice %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> to tensor<64x32xf32>
+ %2 = linalg.exp ins(%extracted_slice : tensor<64x32xf32>) outs(%extracted_slice_0 : tensor<64x32xf32>) -> tensor<64x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 32] [1, 1] : tensor<64x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %1 = tensor.empty() : tensor<23x32x3x16xf32>
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
----------------
hanhanW wrote:
Good catch! I think I had a fat finger; it should be `2` in this case. I updated the logic a bit to make sure that the fusion does not happen in this case.
I also found that the other test used the wrong number of threads, so I fixed it as well.
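As a quick sanity check on the shapes in the quoted test, the smallest result shape a `linalg.pack` tiling needs can be computed by ceil-dividing each tiled source dimension by its inner tile size and then appending the inner tiles. This is a minimal sketch that assumes no `outer_dims_perm` (the test uses none). The helper name `min_packed_shape` is hypothetical and not an MLIR API.

```python
def min_packed_shape(src_shape, inner_dims_pos, inner_tiles):
    """Smallest linalg.pack result shape for a tiling (hypothetical helper).

    Assumes no outer_dims_perm: each tiled source dim d becomes an outer dim
    ceil(d / tile), untiled dims pass through unchanged, and the inner tile
    sizes are appended at the end.
    """
    outer = list(src_shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        # Integer ceil division; a partial last tile needs padding_value.
        outer[pos] = (src_shape[pos] + tile - 1) // tile
    return outer + list(inner_tiles)


# Source tensor<64x32xf32>, inner_dims_pos = [0, 1], inner_tiles = [3, 16].
print(min_packed_shape([64, 32], [0, 1], [3, 16]))  # [22, 2, 3, 16]
```

For the `tensor<64x32xf32>` source this gives `22x2x3x16`, so the `tensor<23x32x3x16xf32>` destination in the quoted hunk is larger than the minimum in both outer dimensions.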
https://github.com/llvm/llvm-project/pull/148993