[Mlir-commits] [mlir] [mlir][shard, mpi] Allow more than one last axis to be "unsplit" (PR #180754)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Feb 10 07:42:33 PST 2026
================
@@ -132,98 +132,117 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
return std::nullopt;
}
-// Detect if the resharding is of type e.g.
-// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis grid axis pair.
-static std::optional<std::tuple<int64_t, GridAxis>>
-detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
+// Detect if the resharding removes trailing split axes along a tensor
+// dimension, e.g.
+// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [].
+// If detected, returns the corresponding (tensor dim, grid axes) pair, where
+// the "grid axes" are the removed trailing split axes.
+static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
+detectUnsplitLastAxesInResharding(const Sharding &sourceSharding,
const Sharding &targetSharding) {
- for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
- ++tensorAxis) {
- if (targetSharding.getSplitAxes().size() > tensorAxis) {
- if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
- targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+ for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size();
+ ++tensorDim) {
+ if (targetSharding.getSplitAxes().size() > tensorDim) {
+ // No match if the target sharding does not have less split axes than the
+ // source sharding along the current tensor dimension.
+ if (sourceSharding.getSplitAxes()[tensorDim].size() <=
+ targetSharding.getSplitAxes()[tensorDim].size())
continue;
- if (!llvm::equal(
- llvm::make_range(
- sourceSharding.getSplitAxes()[tensorAxis]
- .asArrayRef()
- .begin(),
- sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
- 1),
- targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+ // No match if the split axes of the target sharding are different from
+ // the first split axes of the source sharding.
+ if (!std::equal(
+ targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(),
+ targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(),
+ sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin()))
----------------
fschlimb wrote:
Yes it does just above.
https://github.com/llvm/llvm-project/pull/180754
More information about the Mlir-commits
mailing list