<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/58398>58398</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            [mlir] tileConsumerAndFuseProducerGreedilyUsingSCFForOp incorrectly fuses on reduction dim
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
            new issue
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          pzread
      </td>
    </tr>
</table>

<pre>
    When using `scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp` to tile and fuse `linalg.matmul` with `linalg.fill`, it doesn't keep `linalg.fill` outside the reduction loop, which results in incorrect result.

This can be reproduced with https://github.com/pzread/llvm-project/commit/9ea44bdee2f0b47289f80061248c88223c0c99ab. In the example below, we try to tile and fuse `linalg.matmul` on all 3 dims:
```mlir
func.func @gemm_fill_fusion_reduction(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion_reduction"}
      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %gemm : tensor<?x?xf32>
}
```
Currently it outputs:
```mlir
#map0 = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>                                                                                         
#map1 = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
#map2 = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
module {
  func.func @gemm_fill_fusion_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c20 = arith.constant 20 : index
    %c30 = arith.constant 30 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
    %dim_1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_2 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %dim_3 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
    %1 = scf.for %arg2 = %c0 to %dim_1 step %c10 iter_args(%arg3 = %0) -> (tensor<?x?xf32>) {
      %2 = affine.min #map0(%arg2)[%c10, %dim_1]
      %3 = scf.for %arg4 = %c0 to %dim_3 step %c20 iter_args(%arg5 = %arg3) -> (tensor<?x?xf32>) {
        %4 = affine.min #map1(%arg4)[%c20, %dim_3]
        %5 = scf.for %arg6 = %c0 to %dim_2 step %c30 iter_args(%arg7 = %arg5) -> (tensor<?x?xf32>) {
          %6 = affine.min #map2(%arg6)[%c30, %dim_2]
          %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg6] [%2, %6] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_4 = tensor.extract_slice %arg1[%arg6, %arg4] [%6, %4] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %extracted_slice_5 = tensor.extract_slice %arg7[%arg2, %arg4] [%2, %4] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
          %7 = linalg.fill ins(%cst : f32) outs(%extracted_slice_5 : tensor<?x?xf32>) -> tensor<?x?xf32>
          %8 = linalg.matmul {__internal_linalg_transform__ = "tiled"} ins(%extracted_slice, %extracted_slice_4 : tensor<?x?xf32>, tensor<?
x?xf32>) outs(%7 : tensor<?x?xf32>) -> tensor<?x?xf32>
          %inserted_slice = tensor.insert_slice %8 into %arg7[%arg2, %arg4] [%2, %4] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
          scf.yield %inserted_slice : tensor<?x?xf32>
        }
        scf.yield %5 : tensor<?x?xf32>
      }
      scf.yield %3 : tensor<?x?xf32>
    }
    return %1 : tensor<?x?xf32>
  }
}
```
The `linalg.fill` should be at the second level of loop `%3 = scf.for %arg4 = %c0 to %dim_3 step %c20` instead of the inner-most reduction loop, so it won't overwrite the reduction result.
</pre>
<img width="1px" height="1px" alt="" src="http://email.email.llvm.org/o/eJy9WFtzszYQ_TX4RRMPCDDw4IfEqTt9amf6dfrIYLTYakF4JJFLf31XkjEkwY6dy-chjtHl6JyVdlfSpmXPy793IEinuNgSb-GrsvLCW3w0r2HVCtU1IG8FW3cK_pAt60qQv0oAxuvnv0ynP1frdSt_32NfoltiupFCMFJhBwNYc1HU23lT6KarTaNHrnejiorXptijK8I1YS0o4dFEk38B9m-bkbbTijMgegdEAtLRvBWkbtu9QXjc8XKH5aqrtSJc4FO2UkKpD4Vzz7_3_Fv3_WPHFSkLQTYGa-_UMUdwp_VeGUvQNT5bLOo287Jt8GX_n4SC4Y-6fmhusNs_iI-vWNtw8yODIoo2DIBW_iZKaJpVqe8vAhqlZZpSGpZ-mWXFZk5-E1YIPBXNHu22gbp9tDpQoHy-zJ4ov6hrEhLGG8vYKUST2qepuXRFVSfKufkiXuRvoWlyY9QcgdGE-dGYHk09Ghdy6xNEIxqEaqUXrrxw_WT-qpB64S-GpWsWnG-WkRv8f6qeeMmdY0cMXGnGvCeFRHujtYXShdDEEeGCwdO4bTDVNphuq_Qk8NxBGy5DY-ZIOMZztOpBqH_QXJ43zAgomAYKeqDzphuAuOB6DAXNXj-7eWI9KxYYW1-GZ-bd4o2cC22mHKQzlrMKQhqPcxUHGh-e7BEDs_zGDNxyNqshz7nQILE0d3W5loVQVSubPLddPErfrlnqJfc9vvkc1Ywm7v3FujojbYw-2ORgyk_bRILupBiZ5p15PMo9Orp7XXUY7oSun000RZp7w_RMUPBo2BT7g9tVFReQ47sZlKZmbWVefKesAVXgxfdOC9YFtuzG-Aq9M3XYEmu-6zMmG1xNlk6SHWPSqzHDM5gNZhIM26PYdnXsvSz0flHkvS72Xhd9D60nwYOT6HSyPT3ZPpxsH55sfzIf2A_glPpvUoOL6bz5iuzQY-Uncs11KcKi-acyBG_6FGHGuzBLHAmeyGEflEvfQbtcrkELv8x4TiVufueYaA4w9JBtjEbciR3tobTZmbpFzTFV5dh2SDZh38s_eiDWnXHSF35oyYyD0bzhJivYIH0cg7qw5DiMJtcEqFdY4YSwaFJYOAijU8LivpcR-TFtllE0rS44DhQN6uhYXfhKnUWLJ_QtJvXRQV84pS8Z6Ys_qs9yWkwrpMehFoPCcKyQvlFo8eAJt0ClBparmpfwws9dVV9xcCYL7dZJnyoWJne5ir60L7LuYnPbGW-xZ5H3HOkU4zx6j3Nw5LwYOEcD5740-nmc4_c4J1N2jt7a-Rs5Jx_YxU_p_PT2dcQp_cS-3px2mdvMD0peET5YdWqJXbyzd7Rf6RxMlHyxSVAKyEn_dTXDskpRtgtZ37bA7AAXcjdx9ZlDzSZVXJBarQGSV0HtBer55TdOZi9hXoCEFyb6McRw4rrsGH48cJ04ef3YwcR9ldq1HXLcACm0vetRgNtNRmp4gJq0lb26IvYC7KO52gyDc6OhYAbQDMKFAHnTtEpPXJKp1pwNH1t3z9Y-gHzEbfDrK7XDfdkMlsFiEWepH8fRjC1DloVZMdNc17DENWdPkrjirr0uHG7m8Khq7rYUsYP2w6PKWSfr5Zl7OHP99vYWjivVAfrxOk7DLJ3tllkVJ2W6KahP0ywIWJWElLE4KzYsKSANZnWxgVotrU9RAY_EQpgoFN_P-BL70cAPEj-iNIzmZUyrxaLMqgVD1CzC8xw0Ba_nhse8lduZXFpKmw43FxGuB6XVUFkoxbcCrO0MftHpXSuX7k5xZkdeWub_A2n7DmY">