<table border="1" cellspacing="0" cellpadding="8">
    <tr>
        <th>Issue</th>
        <td>
            <a href=https://github.com/llvm/llvm-project/issues/136117>136117</a>
        </td>
    </tr>

    <tr>
        <th>Summary</th>
        <td>
            Issues enabling ND mesh resharding in Spmdization pass: incorrect axis comparison and resharding assertion failures
        </td>
    </tr>

    <tr>
      <th>Labels</th>
      <td>
            new issue
      </td>
    </tr>

    <tr>
      <th>Assignees</th>
      <td>
      </td>
    </tr>

    <tr>
      <th>Reporter</th>
      <td>
          zhangdianchen
      </td>
    </tr>
</table>

<pre>
When trying to enable ND mesh resharding in the Spmdization pass of MLIR, I encountered several issues that cause incorrect behavior or assertion failures. Below is a detailed breakdown:
### 1. Incorrect detection logic in detectMoveLastSplitAxisInResharding
**Problem Reproduction:**
When executing the following resharding sequence:
```
%sharding1 = mesh.sharding @mesh_3d split_axes = [[0, 1], [2]] : !mesh.sharding
%in1_sharded1 = mesh.shard %in1 to %sharding1 : tensor<8x16xi8>
%sharding2 = mesh.sharding @mesh_3d split_axes = [[0], [1, 2]] : !mesh.sharding
%in1_sharded2 = mesh.shard %in1_sharded1 to %sharding2 annotate_for_users : tensor<8x16xi8>
```

The pass is expected to detect a valid last-axis movement and insert a mesh.all_to_all operation. However, instead it crashes with the following assertion:
``mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `targetShard && "Did not find any pattern to apply."' failed.``
**Root Cause:**
In detectMoveLastSplitAxisInResharding, the logic checks:
```
if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
    targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
    sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
        targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().back())
  continue;
```

In the example [[0, 1], [2]] -> [[0], [1, 2]], with sourceTensorAxis = 0 and targetTensorAxis = 1, this compares:
- source.split_axes[0][1] = 1 vs. target.split_axes[1][1] = 2, which do not match, so the candidate is skipped.

It should instead compare source.back() with target.front():
`sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().front()`
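
To make the two comparisons concrete, here is a minimal standalone sketch in plain C++ (it does not use the MLIR API). `AxisList`, `backEqualsBack` and `backEqualsFront` are hypothetical names introduced only for illustration; each list stands in for one split_axes entry.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one split_axes entry: the grid axes that shard
// a single tensor dimension, e.g. {0, 1}.
using AxisList = std::vector<int16_t>;

// Existing check: last source grid axis against last target grid axis.
bool backEqualsBack(const AxisList &source, const AxisList &target) {
  assert(!source.empty() && !target.empty() && "empty lists are filtered out first");
  return source.back() == target.back();
}

// Check proposed in this report: last source grid axis against first target
// grid axis.
bool backEqualsFront(const AxisList &source, const AxisList &target) {
  assert(!source.empty() && !target.empty() && "empty lists are filtered out first");
  return source.back() == target.front();
}
```

For source [0, 1] and target [1, 2], `backEqualsBack` compares 1 with 2, while `backEqualsFront` compares 1 with 1.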
**Additional incorrect check:**
```
if (!llvm::equal(
      llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
                       sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
      llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin(),
                       targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end() - 1)))
  continue;
```

This compares the wrong slices. In the example, it ends up comparing [0] (the source without its last axis) with [1] (the target without its last axis). Instead, it should skip the first axis of the target and compare:
```
if (llvm::equal(
      llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
                       sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
      llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin() + 1,
                       targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end())))
  continue;
```
This now compares [0] with [2], which is correct.
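
The two slice comparisons can be reproduced outside MLIR with a small sketch, using `std::equal` on iterator ranges in place of `llvm::equal` over `llvm::make_range`. The names `AxisList`, `prefixesEqual` and `sourcePrefixEqualsTargetTail` are hypothetical. The sketch only shows which ranges are compared; it makes no claim about how the result should feed the `continue`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one split_axes entry.
using AxisList = std::vector<int16_t>;

// Original form: source without its last axis vs. target without its last axis.
bool prefixesEqual(const AxisList &source, const AxisList &target) {
  assert(!source.empty() && !target.empty());
  return std::equal(source.begin(), source.end() - 1,
                    target.begin(), target.end() - 1);
}

// Proposed form: source without its last axis vs. target without its first axis.
bool sourcePrefixEqualsTargetTail(const AxisList &source, const AxisList &target) {
  assert(!source.empty() && !target.empty());
  return std::equal(source.begin(), source.end() - 1,
                    target.begin() + 1, target.end());
}
```

With source [0, 1] and target [1, 2], `prefixesEqual` looks at [0] and [1], and `sourcePrefixEqualsTargetTail` looks at [0] and [2].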
### 2. Incorrect ShardingTarget construction in targetShardingInMoveLastAxis
Getting past the above assertion leads to another failure:
``mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `actualTargetSharding == targetSharding' failed.``

**Root Cause:** In targetShardingInMoveLastAxis, the targetShardingSplitAxes are built in the wrong order, because the moved grid axis is appended to the end of the target list instead of being placed at its front.

Current result:
```
actualTargetSharding: split_axes = [[0], [2, 1]]
targetSharding:       split_axes = [[0], [1, 2]]
```

**Fix:** Instead of:
`targetSplitAxes.push_back(gridAxis);`
Use:
```
targetSplitAxes.insert(targetSplitAxes.begin(), gridAxis);
```
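
The ordering difference can be seen in a standalone sketch. It assumes, based on the output shown above, that the target list for tensor axis 1 holds [2] before the move and that the moved grid axis is 1. `AxisList`, `appendAxis` and `prependAxis` are hypothetical helpers, not MLIR functions.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one split_axes entry.
using AxisList = std::vector<int16_t>;

// Current behavior: the moved grid axis goes to the end of the list.
AxisList appendAxis(AxisList axes, int16_t gridAxis) {
  axes.push_back(gridAxis);
  return axes;
}

// Proposed behavior: the moved grid axis goes to the front of the list.
AxisList prependAxis(AxisList axes, int16_t gridAxis) {
  axes.insert(axes.begin(), gridAxis);
  return axes;
}
```

Under these assumptions, `appendAxis({2}, 1)` yields [2, 1] and `prependAxis({2}, 1)` yields [1, 2], matching the actual and expected split_axes for tensor axis 1.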
### 3. Bug in handlePartialAxesDuringResharding
In the following snippet:
```
llvm::SmallVector<GridAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
              std::back_inserter(allReduceGridAxes),
              [&targetShardingPartialAxesSet](Axis a) {
                return targetShardingPartialAxesSet.contains(a);
              });
```
The destination should be remainingPartialAxes, not allReduceGridAxes; as written, remainingPartialAxes is never populated by this call. Corrected version:
```
llvm::copy_if(sourceShardingPartialAxesSet,
              std::back_inserter(remainingPartialAxes),
              [&targetShardingPartialAxesSet](Axis a) {
                return targetShardingPartialAxesSet.contains(a);
              });
```
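
A minimal sketch of the intended behavior, using the standard library instead of the LLVM containers: the axes that remain partial are those present in both the source and the target partial-axes sets. The `GridAxis` alias and the `remainingPartialAxes` function are hypothetical here, and `std::set` is only an assumed stand-in for the real set type.

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <set>
#include <vector>

// Hypothetical alias; the real type lives in the MLIR sources.
using GridAxis = int16_t;

// Collects the source partial axes that are also partial in the target,
// writing them to a dedicated output vector.
std::vector<GridAxis> remainingPartialAxes(const std::set<GridAxis> &sourcePartial,
                                           const std::set<GridAxis> &targetPartial) {
  std::vector<GridAxis> remaining;
  std::copy_if(sourcePartial.begin(), sourcePartial.end(),
               std::back_inserter(remaining),
               [&targetPartial](GridAxis a) { return targetPartial.count(a) != 0; });
  return remaining;
}
```

For source partial axes {0, 1, 2} and target partial axes {1, 2, 3}, this returns [1, 2].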
Please let me know if a patch is desired; I'm happy to contribute a PR for these changes.
</pre>