<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/57205>57205</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            Allow `tileConsumerAndFuseProducers` must return the value of the fused producer as well
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
            mlir:linalg
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          MaheshRavishankar
      </td>
    </tr>
</table>

<pre>
    I hit this before, but hit this again, so filing an issue for it. Consider the following code

```
func.func @tile_and_fuse_err(%10 : tensor<12xf32>, %11 : tensor<12x12x12x12x12xf32>, %12 : tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>)
{
  %cst = arith.constant 1.42 : f32
  %cst_1 = arith.constant 1.45 : f32
  %cst_0 = arith.constant 1.3 : f32
  %cst_2 = arith.constant 0.0 : f32
  %13 = linalg.init_tensor [12] : tensor<12xf32>
  %14 = linalg.fill ins(%cst_2 : f32) outs(%13 : tensor<12xf32>) -> tensor<12xf32>
  %15 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4, d0)>,
                         affine_map<(d0, d1, d2, d3, d4) -> (d0)>, affine_map<(d0, d1, d2, d3, d4) -> (d0)>],
        iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]}
        ins(%11, %12 : tensor<12x12x12x12x12xf32>, tensor<12xf32>) outs(%14 : tensor<12xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
      %19 = arith.subf %arg1, %arg2 : f32
      %20 = arith.mulf %19, %19 : f32
      %21 = arith.addf %arg3, %20 : f32
      linalg.yield %21 : f32
    } -> tensor<12xf32>
  %16 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%15 : tensor<12xf32>) {
    ^bb0(%arg1: f32):
      %19 = arith.divf %arg1, %cst_1 : f32
      %20 = arith.addf %19, %cst_0 : f32
      linalg.yield %20 : f32
    } -> tensor<12xf32>
  %17 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%16 : tensor<12xf32>) {
    ^bb0(%arg1: f32):
      %19 = math.sqrt %arg1 : f32
      linalg.yield %19 : f32
    } -> tensor<12xf32>
  %18 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
      {__internal_linalg_transform__ = "tensor_fuse_err"}
      ins(%10 : tensor<12xf32>) outs(%17 : tensor<12xf32>)  {
    ^bb0(%arg1: f32, %arg2: f32):
      %19 = arith.subf %arg1, %arg2 : f32
      %20 = arith.mulf %19, %cst : f32
      %21 = arith.subf %arg1, %20 : f32
      linalg.yield %21 : f32
    } -> tensor<12xf32>
  return %16, %17, %18 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>
}
```

When using tile and fuse through use of `tileConsumerAndFuseProducer` I get the following code (had to use [this](https://github.com/MaheshRavishankar/llvm-project/commit/ad27f961ed323cf2d16731cdff86683f35c41c11)) patch to actually create the repro)

```
#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4, d0)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
#map2 = affine_map<(d0) -> (d0)>
module {
  func.func @tile_and_fuse_err(%arg0: tensor<12xf32>, %arg1: tensor<12x12x12x12x12xf32>, %arg2: tensor<12xf32>) -> (tensor<12xf32>, tensor<12xf32>, tensor<12xf32>) {
    %c3 = arith.constant 3 : index
    %c12 = arith.constant 12 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 1.420000e+00 : f32
    %cst_0 = arith.constant 1.450000e+00 : f32
    %cst_1 = arith.constant 1.300000e+00 : f32
    %cst_2 = arith.constant 0.000000e+00 : f32
    %0 = linalg.init_tensor [12] : tensor<12xf32>
    %1 = linalg.fill ins(%cst_2 : f32) outs(%0 : tensor<12xf32>) -> tensor<12xf32>
    %2 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%arg1, %arg2 : tensor<12x12x12x12x12xf32>, tensor<12xf32>) outs(%1 : tensor<12xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
      %6 = arith.subf %arg3, %arg4 : f32
      %7 = arith.mulf %6, %6 : f32
      %8 = arith.addf %arg5, %7 : f32
      linalg.yield %8 : f32
    } -> tensor<12xf32>
    %3 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel"]} outs(%2 : tensor<12xf32>) {
    ^bb0(%arg3: f32):
      %6 = arith.divf %arg3, %cst_0 : f32
      %7 = arith.addf %6, %cst_1 : f32
      linalg.yield %7 : f32
    } -> tensor<12xf32>
    %4 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel"]} outs(%3 : tensor<12xf32>) {
    ^bb0(%arg3: f32):
      %6 = math.sqrt %arg3 : f32
      linalg.yield %6 : f32
    } -> tensor<12xf32>
    %5 = scf.for %arg3 = %c0 to %c12 step %c3 iter_args(%arg4 = %0) -> (tensor<12xf32>) {
      %6 = tensor.extract_slice %arg0[%arg3] [3] [1] : tensor<12xf32> to tensor<3xf32>
      %7 = tensor.extract_slice %arg4[%arg3] [3] [1] : tensor<12xf32> to tensor<3xf32>
      %8 = linalg.fill ins(%cst_2 : f32) outs(%7 : tensor<3xf32>) -> tensor<3xf32>
      %9 = tensor.extract_slice %arg1[0, 0, 0, 0, %arg3] [12, 12, 12, 12, 3] [1, 1, 1, 1, 1] : tensor<12x12x12x12x12xf32> to tensor<12x12x12x12x3xf32>
      %10 = tensor.extract_slice %arg2[%arg3] [3] [1] : tensor<12xf32> to tensor<3xf32>
      %11 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%9, %10 : tensor<12x12x12x12x3xf32>, tensor<3xf32>) outs(%8 : tensor<3xf32>) {
      ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
        %16 = arith.subf %arg5, %arg6 : f32
        %17 = arith.mulf %16, %16 : f32
        %18 = arith.addf %arg7, %17 : f32
        linalg.yield %18 : f32
      } -> tensor<3xf32>
      %12 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel"]} outs(%11 : tensor<3xf32>) {
      ^bb0(%arg5: f32):
        %16 = arith.divf %arg5, %cst_0 : f32
        %17 = arith.addf %16, %cst_1 : f32
        linalg.yield %17 : f32
      } -> tensor<3xf32>
      %13 = linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel"]} outs(%12 : tensor<3xf32>) {
      ^bb0(%arg5: f32):
        %16 = math.sqrt %arg5 : f32
        linalg.yield %16 : f32
      } -> tensor<3xf32>
      %14 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel"]} ins(%6 : tensor<3xf32>) outs(%13 : tensor<3xf32>) {
      ^bb0(%arg5: f32, %arg6: f32):
        %16 = arith.subf %arg5, %arg6 : f32
        %17 = arith.mulf %16, %cst : f32
        %18 = arith.subf %arg5, %17 : f32
        linalg.yield %18 : f32
      } -> tensor<3xf32>
      %15 = tensor.insert_slice %14 into %arg4[%arg3] [3] [1] : tensor<3xf32> into tensor<12xf32>
      scf.yield %15 : tensor<12xf32>
    }
    return %3, %4, %5 : tensor<12xf32>, tensor<12xf32>, tensor<12xf32>
  }
}
```
Problem here is the first three generic ops get fused with the last one (through tile + fuse) but their original use in the return persists. This negates all advantage of tile + fuse.


</pre>
<img width="1px" height="1px" alt="" src="http://email.email.llvm.org/o/eJzNWt1vozgQ_2vIi7URmK_wkId2u5X24aTV6aR7jBwwie8I5Izpbv_7GxtIHLADTdpqo4oQM_bY8_mbodsqe11_R3smkNizGm1pXnHq4K9o24jzMNkRVsrRukI5K1i5Q6RErK4bimACYmKJvlZlzTLKYYYcLIrqp6RLq4w67pPjPnTXyO3-1M-8KdOlvCAncAUr6IaU2SZvarqhnDt45eDQc5HjPyBBy7rijv_Vw79yHzv-N7kj-dwbPdf_Lmmxda0EfYFvIFqZOc0fTbqjxo_tDZKc01oA6ydEOBP7ZQrSEqQUyFsG7Zbk5AvqjWehDy30rpnet5BjE7m7dA3knq9oQfGk2C1ZycSmPTdywkcPO-GTTaraGoG-BlhRgVhZtxrut9PyBU1Ujegeef6Uwq5zDXWuO1pSzsDWzpqRH1Zm9BcY6-ZAjrWih2ORPGcllUOwNOwlc6W6M09dsbr66hropmN-DldXmoUyGZ2z4XMj3_P6964QPo02yQTlRFR8I16P9CQhB-Mj4aQoaAG3rYNhTrMmFawq7xqCPcRPgz30xuJ5dl82-73ZeDQbC67Y2IWtOOG37dZtZxG-804mqzYEI3g04p_NGu70M0nOieaEdbPNUb_wecWhP3ZTse7uh6bI2_V62SS2aXpUIVmWn7bZTsQj_5efzn9eGS2y0yoDKtDYHIeM3tMhbeY7z2KVlel2EN5tB1M6ztjLUMd9qJ9Ucq-tk5L7oD-tLgPVTHXFv7G6og9U14FIj_yPi15bc-Rs8rqZcl59sJwteeEdNHIhv_hxs2ElzIVzbNrTbAQnZQ0w8bDZtOtg3MpCw3l4sNA52tvB30UIj6-Q3RzDPzNit_hwMmQbWH5oyOZUNLxsI3efW-L-ZnUFmM8e7bDyk7E-aK9_72mJmlrWE7JGgOIjQ9J2oNbgVbPbI3lfgVwiVUPIUqQ5UP5QZs_w5AevAFpQDk_Rd7SjwlCiSA_YkwyJSq0Fpi4LH-ULq70Qx1paAn6Gvx1oodkCYj7Ajz_Intb7P8kLq_ek_JeAJT8Xxcvhy5FX_9BUwE8gPDB5QzIc50nk0czHfprjzIti30uzPF9F0crP_TANvFSim0Qa7ZGIdC_3Q1LRgMe9opRTIqjaPKfA4FxoGCWHffD0zuQ-BstqfLx7-BjWw9b1bPMOoGQwDc3V59SW4Ebu1eKyjxCzqss-eHxweTmIZxA7fFMt11ZOKmlcUnvG0q-D0wZ6Y2Hp2qjtda4LH-rgR9cER67WsEE4Y66lXvbdGXNtxfDUXPfOwrjNKTdVxtcS40RQb9OKBXEYUUYXTVo7Vw6v374BMrx7lXiWlSkLv0NxeCPI9Ee4IhiNhNeQRmTO-r62oAUwxAac0SfvyDJnZS4Lw25aPAdirG5CGIq9_3ZjxPfUDtfagPO0OqExrdDrNWYt1wYq68UfafNmlR8GJc2Vf_DJ8r_W1btT_sPKbdT_NInO4BZzRdd2GOs0X8pW-JnnU5c7AcJ1SbcW9Nilaym0DRCeQ1fQT3GnsEIyrAtPR2-pl_QXlFyp2NQFSynqUY7USSvCUGmo__bsWUru_TTqjw6vWe4VzsFHcb4omWenzUGh6FuzpoVrMnVeONWjSpSDy6UMPJUGxtezbOTQ8GKQ1zirXchOf2w5kedOHQl_lAo97-1x5_cGIacG7BiejfWgww_fiD5WdmsdBoHLQBmOwEY0GomtwVRv1Y4QSKitaYqteuNw0Ow4tRCuTLQgkVPTwQhFDL04AxoxBXWbZd4Aj-_KiMO3mLeoelqNGiwJp2DJWI-nDvQkNDHow6i3-fr4bIQ4fLv03voYwpTRW12LGM0AfrYYbwJ6pyh7g0RPkTGyy9P-uvd9At5nhzdLL3cc3wxsPyu-hXrSByVRruV8MBNWtsD1jQiuZ9fOvwadkcLM5_NYX77piPz849yY7uuroPu2v8V7a1daY2lrUP_g1bagB7SnnCJWt61lxmvZZOaUot7FKvAq2XqW7ccM_QTlK9KCAGVVqv5z38xWLW4HPypaafjyf3GAmHFUcbaTdqCa1KzsesFKEEfKa1aLeon-kv-yU9IdEeCa4I2IZC-kFGSneuQXqy_1BvKCrr0ocuMVDuJoka39LPETshBMFHT9ILvlEw32WnbYDw2cqNuT3N4LKZqWsZSMOv2xo0ekRj9pUSwaXqyvdNllR33cWFf_fgRh4zmMsRsu9uskxmlCvSiAAitI4yx1My8M_chLs22ydYNFQba0qNdtpDoUDDT-0LpVG64WbI1djN2Vt3JdHIfBEuMkoTFJVtRbpRGmTuDSA2HFUu5lWfHdgq_VtrYNlHKBWygNnB6Suma7klLFEtYnjdhXfD16a7BQZ1mrg_wP260PEQ">