<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/82375>82375</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            [mlir][mesh] Sharding propagation does not produce a complete sharding annotation
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
            mlir
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          sogartar
      </td>
    </tr>
</table>

<pre>
    I had not noticed that sharding propagation does not produce a `mesh.shard` for each result and a `mesh.shard` with `annotate_for_users` for each use. See [test](https://github.com/llvm/llvm-project/blob/135529aab0ebe4915143f376d94f8aba6ec71e4e/mlir/test/Dialect/Mesh/sharding-propagation.mlir#L21).

Running the pass on
```mlir
mesh.mesh @mesh_2d(shape = 2x4)

func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
  return %1 : tensor<8x16xf32>
}
```

results in

```mlir
mesh.mesh @mesh_2d(shape = 2x4)
func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %2 = mesh.shard %1 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
  return %2 : tensor<8x16xf32>
}
```

I expected it to result in
```mlir
mesh.mesh @mesh_2d(shape = 2x4)
func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
  %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
  %2 = tosa.sigmoid %1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
  %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
  %4 = mesh.shard %3 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
  return %4 : tensor<8x16xf32>
}
```

I implemented the spmdization pass under the assumption that the sharding annotations are complete at this point.
It may be easy to handle the cases with missing shardings, but the canonical form should have no missing annotations.
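
As a rough illustration of what "complete" means here, the sketch below checks the invariant on a toy IR where ops are plain records. It is not MLIR code: `Op` and `find_missing_annotations` are hypothetical names, and the rules encode one reading of this issue. Every function argument and every result of a non-shard op is consumed only by a `mesh.shard` without `annotate_for_users`, and every operand of a non-shard op (including `func.return`) comes from a `mesh.shard` with `annotate_for_users`.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    # Hypothetical toy op record, not an MLIR API.
    name: str                                  # e.g. "mesh.shard", "tosa.sigmoid", "func.return"
    operands: list[str]
    results: list[str] = field(default_factory=list)
    annotate_for_users: bool = False           # only meaningful for "mesh.shard"


def find_missing_annotations(args, ops):
    """Return descriptions of values whose sharding annotation is missing."""
    defining = {}   # value name -> op that defines it
    users = {}      # value name -> ops that use it
    for op in ops:
        for r in op.results:
            defining[r] = op
        for v in op.operands:
            users.setdefault(v, []).append(op)

    problems = []
    # Rule 1: arguments and results of non-shard ops must have at least one use,
    # and every use must be a mesh.shard WITHOUT annotate_for_users.
    produced = list(args) + [r for op in ops if op.name != "mesh.shard" for r in op.results]
    for v in produced:
        uses = users.get(v, [])
        if not uses or any(u.name != "mesh.shard" or u.annotate_for_users for u in uses):
            problems.append(f"{v}: missing result sharding")

    # Rule 2: each operand of a non-shard op must be defined by a mesh.shard
    # WITH annotate_for_users (one such annotation per use).
    for op in ops:
        if op.name == "mesh.shard":
            continue
        for v in op.operands:
            d = defining.get(v)
            if d is None or d.name != "mesh.shard" or not d.annotate_for_users:
                problems.append(f"{v} used by {op.name}: missing user sharding")
    return problems


# The current pass output from this issue, transcribed into the toy IR.
actual = [
    Op("mesh.shard", ["%arg0"], ["%0"], annotate_for_users=True),
    Op("tosa.sigmoid", ["%0"], ["%1"]),
    Op("mesh.shard", ["%1"], ["%2"]),
    Op("func.return", ["%2"]),
]

# The expected output from this issue, transcribed into the toy IR.
expected = [
    Op("mesh.shard", ["%arg0"], ["%0"]),
    Op("mesh.shard", ["%0"], ["%1"], annotate_for_users=True),
    Op("tosa.sigmoid", ["%1"], ["%2"]),
    Op("mesh.shard", ["%2"], ["%3"]),
    Op("mesh.shard", ["%3"], ["%4"], annotate_for_users=True),
    Op("func.return", ["%4"]),
]

print(find_missing_annotations(["%arg0"], actual))
print(find_missing_annotations(["%arg0"], expected))
```

Under these assumptions the current output is flagged twice: `%arg0` has only a user-side annotation, and the value returned by the function has only a result-side annotation. The expected output passes both rules.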
</pre>