<table border="1" cellspacing="0" cellpadding="8">
<tr>
<th>Issue</th>
<td>
<a href=https://github.com/llvm/llvm-project/issues/136117>136117</a>
</td>
</tr>
<tr>
<th>Summary</th>
<td>
Issues enabling ND mesh resharding in Spmdization pass: incorrect axis comparison and resharding assertion failures
</td>
</tr>
<tr>
<th>Labels</th>
<td>
new issue
</td>
</tr>
<tr>
<th>Assignees</th>
<td>
</td>
</tr>
<tr>
<th>Reporter</th>
<td>
zhangdianchen
</td>
</tr>
</table>
<pre>
When trying to enable ND mesh resharding in the Spmdization pass of MLIR, I encountered several issues that cause incorrect behavior or assertion failures. Below is a detailed breakdown:
### 1. Incorrect detection logic in detectMoveLastSplitAxisInResharding
**Problem Reproduction:**
When executing the following resharding sequence:
```
%sharding1 = mesh.sharding @mesh_3d split_axes = [[0, 1], [2]] : !mesh.sharding
%in1_sharded1 = mesh.shard %in1 to %sharding1 : tensor<8x16xi8>
%sharding2 = mesh.sharding @mesh_3d split_axes = [[0], [1, 2]] : !mesh.sharding
%in1_sharded2 = mesh.shard %in1_sharded1 to %sharding2 annotate_for_users : tensor<8x16xi8>
```
The pass is expected to detect a valid last-axis movement and insert a mesh.all_to_all operation. Instead, it crashes with the following assertion:
`mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `targetShard && "Did not find any pattern to apply."' failed.`
**Root Cause:**
In detectMoveLastSplitAxisInResharding, the logic checks:
```
if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().back())
continue;
```
In the example [[0, 1], [2]] -> [[0], [1, 2]], this compares:
- source.split_axes[0][1] = 1 vs. target.split_axes[1][1] = 2, which differ, so the valid move is rejected
It should instead compare source.back() with target.front():
`sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().front()`
**Additional incorrect check:**
```
if (!llvm::equal(
llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin(),
targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end() - 1)))
continue;
```
This compares the wrong slices. In the example, it ends up comparing [0] (the source without its last axis) and [1] (the target without its last axis). Instead, it should skip the first element of the target and compare:
```
if (llvm::equal(
llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().begin(),
sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().end() - 1),
llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().begin() + 1,
targetSharding.getSplitAxes()[targetTensorAxis].asArrayRef().end())))
continue;
```
This now compares [0] with [2], which is correct.
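The slices named above can be checked with a small standalone model. This is only a sketch: it uses plain `std::vector<int>` rather than the MLIR sharding types, and the helper names (`lastMatchesLast`, `lastMatchesFirst`, `dropLast`, `dropFirst`) are hypothetical, chosen for illustration and not taken from Spmdization.cpp.

```cpp
#include <cassert>
#include <vector>

// Grid axes that split one tensor dimension, e.g. {0, 1}.
// Assumption: plain ints stand in for the real grid-axis type.
using Axes = std::vector<int>;

// Current first check: last grid axis of the source vs. last of the target.
bool lastMatchesLast(const Axes &src, const Axes &tgt) {
  return !src.empty() && !tgt.empty() && src.back() == tgt.back();
}

// Proposed first check: last grid axis of the source vs. first of the target.
bool lastMatchesFirst(const Axes &src, const Axes &tgt) {
  return !src.empty() && !tgt.empty() && src.back() == tgt.front();
}

// Slice [begin, end - 1): everything except the last grid axis.
Axes dropLast(const Axes &axes) {
  assert(!axes.empty() && "need at least one axis");
  return Axes(axes.begin(), axes.end() - 1);
}

// Slice [begin + 1, end): everything except the first grid axis.
Axes dropFirst(const Axes &axes) {
  assert(!axes.empty() && "need at least one axis");
  return Axes(axes.begin() + 1, axes.end());
}
```

For the example, the source is `{0, 1}` (tensor axis 0) and the target is `{1, 2}` (tensor axis 1). `lastMatchesLast` rejects the pair while `lastMatchesFirst` accepts it, and the slices are `dropLast(source) = {0}`, `dropLast(target) = {1}`, and `dropFirst(target) = {2}`, matching the values quoted in this section.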
### 2. Incorrect ShardingTarget construction in targetShardingInMoveLastAxis
Getting past the above assertion leads to another failure:
`mlir::TypedValue<mlir::ShapedType> mlir::sharding::reshardOn1DGrid(...): Assertion `actualTargetSharding == targetSharding' failed.`
Root cause: in targetShardingInMoveLastAxis, the targetShardingSplitAxes are incorrectly ordered.
Current result:
```
actualTargetSharding: split_axes = [[0], [2, 1]]
targetSharding: split_axes = [[0], [1, 2]]
```
Fix: Instead of:
`targetSplitAxes.push_back(gridAxis);`
Use:
```
targetSplitAxes.insert(targetSplitAxes.begin(), gridAxis);
```
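The ordering difference can be reproduced outside MLIR with a minimal sketch. It assumes split axes are modeled as nested `std::vector<int>`; the function `moveLastAxis` and its `prepend` flag are hypothetical and only mirror the push_back vs. insert-at-front choice described above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One inner vector of grid axes per tensor dimension.
using SplitAxes = std::vector<std::vector<int>>;

// Moves the last grid axis of `srcTensorAxis` onto `tgtTensorAxis`.
// prepend == false models the current push_back behavior;
// prepend == true models the proposed insert at begin().
SplitAxes moveLastAxis(SplitAxes splitAxes, std::size_t srcTensorAxis,
                       std::size_t tgtTensorAxis, bool prepend) {
  assert(srcTensorAxis < splitAxes.size() && tgtTensorAxis < splitAxes.size());
  assert(!splitAxes[srcTensorAxis].empty() && "source has no axis to move");

  int gridAxis = splitAxes[srcTensorAxis].back();
  splitAxes[srcTensorAxis].pop_back();

  std::vector<int> &target = splitAxes[tgtTensorAxis];
  if (prepend)
    target.insert(target.begin(), gridAxis);
  else
    target.push_back(gridAxis);
  return splitAxes;
}
```

Starting from `[[0, 1], [2]]` and moving from tensor axis 0 to tensor axis 1, appending yields `[[0], [2, 1]]` (the actualTargetSharding reported above), while prepending yields `[[0], [1, 2]]` (the expected targetSharding).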
### 3. Bug in handlePartialAxesDuringResharding
In the following snippet:
```
llvm::SmallVector<GridAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
std::back_inserter(allReduceGridAxes),
[&targetShardingPartialAxesSet](Axis a) {
return targetShardingPartialAxesSet.contains(a);
});
```
It should be writing to remainingPartialAxes, not allReduceGridAxes. Corrected version:
```
llvm::copy_if(sourceShardingPartialAxesSet,
std::back_inserter(remainingPartialAxes),
[&targetShardingPartialAxesSet](Axis a) {
return targetShardingPartialAxesSet.contains(a);
});
```
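The intended effect of the corrected call can be sketched with the standard library alone. This assumes that partial axes present in the source but absent from the target are the ones to all-reduce, and that those present in both remain partial; the struct `PartialAxesSplit` and the function `splitPartialAxes` are hypothetical names, and `std::set<int>` stands in for the real axis-set type.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <vector>

struct PartialAxesSplit {
  std::vector<int> allReduceAxes;        // partial in source only
  std::vector<int> remainingPartialAxes; // partial in both source and target
};

PartialAxesSplit splitPartialAxes(const std::set<int> &source,
                                  const std::set<int> &target) {
  PartialAxesSplit result;
  // Assumption: axes that stop being partial are collected for all-reduce.
  std::copy_if(source.begin(), source.end(),
               std::back_inserter(result.allReduceAxes),
               [&target](int a) { return target.count(a) == 0; });
  // Corrected destination: axes still partial go to remainingPartialAxes.
  std::copy_if(source.begin(), source.end(),
               std::back_inserter(result.remainingPartialAxes),
               [&target](int a) { return target.count(a) != 0; });
  // Every source axis lands in exactly one of the two lists.
  assert(result.allReduceAxes.size() + result.remainingPartialAxes.size() ==
         source.size());
  return result;
}
```

With source partial axes `{0, 1, 2}` and target partial axes `{1, 3}`, this gives `allReduceAxes = {0, 2}` and `remainingPartialAxes = {1}`. Writing the second `copy_if` into the all-reduce list instead would leave `remainingPartialAxes` empty and add axis 1 to the all-reduce list.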
Please let me know if a patch is desired. I'm happy to contribute a PR for these changes.
</pre>