<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/61386>61386</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            [MLIR] Tile and fuse does not play well with bufferization
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
            new issue
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          chelini
      </td>
    </tr>
</table>

<pre>
    Reproducer:
```
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>

func.func @mlp(%arg0: tensor<8x48x32x32xbf16>, %arg1: tensor<48x48x16x32x2xbf16>, %arg2: tensor<1536xbf16>, %arg3: tensor<8x48x32x32xbf16>, %arg4: tensor<48x48x16x32x2xbf16>, %arg5: tensor<1536xbf16>, %arg6: tensor<8x48x32x32xbf16>, %arg7: tensor<48x48x16x32x2xbf16>, %arg8: tensor<1536xbf16>, %arg9: tensor<8x48x32x32xbf16> ) -> tensor<8x48x32x32xbf16> {
  %cst = arith.constant 0.000000e+00 : bf16
 %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<8x48x32x32xbf16>, tensor<48x48x16x32x2xbf16>) outs(%arg3 : tensor<8x48x32x32xbf16>) {
 ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %mul = arith.mulf %in, %in_0 : bf16
      %add = arith.addf %out, %mul : bf16
 linalg.yield %add : bf16
  } -> tensor<8x48x32x32xbf16>
  %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<1536xbf16> into tensor<48x32xbf16>
  %2 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0, %expanded : tensor<8x48x32x32xbf16>, tensor<48x32xbf16>) outs(%arg3 : tensor<8x48x32x32xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %add = arith.addf %in, %in_0 : bf16
      linalg.yield %add : bf16
  } -> tensor<8x48x32x32xbf16>
 %3 = linalg.generic {__internal_linalg_transform__ = "fusion", indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<8x48x32x32xbf16>) outs(%arg3 : tensor<8x48x32x32xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      %max = arith.maxf %in, %cst : bf16
      linalg.yield %max : bf16
  } -> tensor<8x48x32x32xbf16>

  
  return %3 : tensor<8x48x32x32xbf16>
}
```

using `mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -cse`
We get the following IR:

```#map = affine_map<(d0) -> (10, -d0 + 8)>
#map1 = affine_map<(d0) -> (20, -d0 + 48)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map6 = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
module {
  func.func @mlp(%arg0: tensor<8x48x32x32xbf16>, %arg1: tensor<48x48x16x32x2xbf16>, %arg2: tensor<1536xbf16>, %arg3: tensor<8x48x32x32xbf16>, %arg4: tensor<48x48x16x32x2xbf16>, %arg5: tensor<1536xbf16>, %arg6: tensor<8x48x32x32xbf16>, %arg7: tensor<48x48x16x32x2xbf16>, %arg8: tensor<1536xbf16>, %arg9: tensor<8x48x32x32xbf16>) -> tensor<8x48x32x32xbf16> {
    %c10 = arith.constant 10 : index
    %c20 = arith.constant 20 : index
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c48 = arith.constant 48 : index
    %cst = arith.constant 0.000000e+00 : bf16
    %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<1536xbf16> into tensor<48x32xbf16>
    %0 = scf.for %arg10 = %c0 to %c8 step %c10 iter_args(%arg11 = %arg3) -> (tensor<8x48x32x32xbf16>) {
 %1 = affine.min #map(%arg10)
      %2 = scf.for %arg12 = %c0 to %c48 step %c20 iter_args(%arg13 = %arg11) -> (tensor<8x48x32x32xbf16>) {
 %3 = affine.min #map1(%arg12)
        %extracted_slice = tensor.extract_slice %arg0[%arg10, 0, 0, 0] [%1, 48, 32, 32] [1, 1, 1, 1] : tensor<8x48x32x32xbf16> to tensor<?x48x32x32xbf16>
 %extracted_slice_0 = tensor.extract_slice %arg1[%arg12, 0, 0, 0, 0] [%3, 48, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<48x48x16x32x2xbf16> to tensor<?x48x16x32x2xbf16>
        %extracted_slice_1 = tensor.extract_slice %arg3[%arg10, %arg12, 0, 0] [%1, %3, 32, 32] [1, 1, 1, 1] : tensor<8x48x32x32xbf16> to tensor<?x?x32x32xbf16>
        %4 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?x48x32x32xbf16>, tensor<?x48x16x32x2xbf16>) outs(%extracted_slice_1 : tensor<?x?x32x32xbf16>) {
        ^bb0(%in: bf16, %in_4: bf16, %out: bf16):
          %7 = arith.mulf %in, %in_4 : bf16
          %8 = arith.addf %out, %7 : bf16
          linalg.yield %8 : bf16
        } -> tensor<?x?x32x32xbf16>
        %extracted_slice_2 = tensor.extract_slice %expanded[%arg12, 0] [%3, 32] [1, 1] : tensor<48x32xbf16> to tensor<?x32xbf16>
        %5 = linalg.generic {indexing_maps = [#map5, #map6, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4, %extracted_slice_2 : tensor<?x?x32x32xbf16>, tensor<?x32xbf16>) outs(%extracted_slice_1 : tensor<?x?x32x32xbf16>) {
        ^bb0(%in: bf16, %in_4: bf16, %out: bf16):
          %7 = arith.addf %in, %in_4 : bf16
 linalg.yield %7 : bf16
        } -> tensor<?x?x32x32xbf16>
 %extracted_slice_3 = tensor.extract_slice %arg13[%arg10, %arg12, 0, 0] [%1, %3, 32, 32] [1, 1, 1, 1] : tensor<8x48x32x32xbf16> to tensor<?x?x32x32xbf16>
        %6 = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<?x?x32x32xbf16>) outs(%extracted_slice_3 : tensor<?x?x32x32xbf16>) attrs =  {__internal_linalg_transform__ = "tiled"} {
        ^bb0(%in: bf16, %out: bf16):
          %7 = arith.maxf %in, %cst : bf16
          linalg.yield %7 : bf16
        } -> tensor<?x?x32x32xbf16>
        %inserted_slice = tensor.insert_slice %6 into %arg13[%arg10, %arg12, 0, 0] [%1, %3, 32, 32] [1, 1, 1, 1] : tensor<?x?x32x32xbf16> into tensor<8x48x32x32xbf16>
        scf.yield %inserted_slice : tensor<8x48x32x32xbf16>
      }
      scf.yield %2 : tensor<8x48x32x32xbf16>
    }
    return %0 : tensor<8x48x32x32xbf16>
  }
}
```
</pre>
<img width="1px" height="1px" alt="" src="http://email.email.llvm.org/o/eJzsWl9v4joW_zTmxQqK7SSkDzy0ZSqNtPsyWmkfkUkc8Mo4yHam7X76K8dAQnBI0qG9qnRRMZCcPzk-v3Psn2eo1nwrGVuC-AnEqxmtzK5Uy2zHBJd8tinz9-UvdlBlXmVMAfIIwhUIH0ESHv_cT0z29AABWUFaFFyy9Z4eAHkGOM1DgJ9hjuoR1yOpx6ge43pMAH6AASA_YKOBW1L2PiA_2s7Qvby1ZRNYiLJUOf8NcUuBXLvH9w0WXUp1vZEJ3nps475Ioo_ZRh579VhUMpvbAYIo3IsDwCnAMVXbEJBHaJjUpQLkOX2L0jeC7d-mQIk1gp-hk0QXklEtihIr6hHGF8IoJsm1DBntOpriOh7hOhntejHFdTrC9cOAa3hO6E2hxZNLLLRmM20cXhQ3u3lWSm2oNDCch_WLAfwUhtA6rvWdIsBxWGsJLqnYzrdMMsUza5rLnL1xubXI07WMbUQ1MF0gda03XzGIV_YXN0xRU6q1eT-wliI-UEWFYAJg7LQ8VxTLq8zwUt68NNlSvAKLFeRStyDfYBqOAMJQ9h9gWZnGPBm2-dDKH4h_bDah0-bynCP3iFyuw86lsjLNlYdz74f1C-B4X4kWFvaVKJyhlsUuEk6qNM9bqjTPi5NDp-tMX6gesfPOmcgbE5fW7fwPAboFZvZ2oDJn7lGcytxdW-sdPbBTh6mhFT_V2UQ2zfEK9tYf5NKUF5n0-sbTC4I0ZRA1X8k9KuL2lQ6wT6huTd8kZJO7IdqC6c6g9iNzGNR3hifAMelByHrNpWFKUrF2N9dGUamLUu3Xa5d3jItKN81sHKa-Hkh4TLY_Ex-DHY6-tTscfbsEg1sMh7HgzHwUC2eN46diplLyhJDbs-EMLFb-XXs9VprLLQRJuBdcBeXBwMAwbQLDBZfboMZaQTMGyMpwwQK75ld7pgIq86CoNAtO7CCoTQU6K4KiVDDINDv7-i-DW2ag2TFYlEKUr9bpz18Noeg84G1C0exDUd2MgjyEAD_BdApPaIzgCyORx8qdt_sD3GbKdv8TuM0URvDn3Cb-RG6T3I3b7Mu8EqzdWv7hOd-F50ylOY7ooNBHdJBb_esV9VIee-Vxr7xXvFc69UmnfdKRVzzqlf8IqYN_7xYaNqxSZ8XcLjjHogqPeyA7xaY8zp427HBKq93grKnaNtsKhE46dRG1esF4coXj9lIz33MJTzz25MUuOp0dBvZFgK8jiFohYG8IpBUCQh-NgfhjQI0f3AniCAOjaGZYvtaCZ-wSDfWt041ji7QbytOcPMP2EB-3m3Hdhe1i_AwJPo7uJnIoaoYumDz13QYUIC839t2dWNbhUDSoiQZ3o7kMiTQhuT2oC6w3Lm9w3o7pi6_bVAeStkZDgZJO2jxBX2bvFPIn5M--vbz-HF40nWHjNhtqke1vc-bUSemZqF8j-nGgHNqcvQdPF-TMB6bHoZR1ydoYQh9N4XBHKCwGzqkiL407Kqc3T6oWvapdFpj2SF4zwXH47s44vlm-p4X6qlddNqdujXqaz42ivPG08fRqjJsSTJqv8ZcfU0Q9ZYRHQfy5d4q-Z_34jsSu6qeL_b4qmYZ9zzSRwcX5ey1ayR-VydfXRjwSpv1AJyMtUGOUC2X0IajhguX2oReriXUyeVEZezToWxjuUxzNg3GpmfLuxd2dpj4Sx7S-ulK8YXRIX_8_3LiXpUznGbyKeMSp6HmO2z8vrA4fTzdctG2mOaANx1pojmivzmpn-ZLkD-SBztgSJYt0ge17tlvmEQsJTnEepSgNw7zY5AxHhGSbHCdFHs34EoeYhAQRtECLGM9xhOIUFaxIWIzoYgOikO0pF3Mhfu_npdrOuNYVWyaIpMlM0A0Teukah2SvsL7pesBMLa1OsKm2GkSh4NroxorhRtT_o-Tf__r5y2b_P1wwSGUOi0ozmJdMQ1kaeBD0Hb4yIeArNzu4qYqCKf5_aje3s0qJ5c6Yg7bVh18Aftlys6s286zcA_xinR0_goMq_8cyA_BL_Yga4Jc6hL8CAAD__-DAdUc">