[Mlir-commits] [mlir] [mlir][shard, mpi] Allow more than one last axis to be "unsplit" (PR #180754)

Frank Schlimbach llvmlistbot at llvm.org
Tue Feb 10 07:42:33 PST 2026


================
@@ -132,98 +132,117 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
   return std::nullopt;
 }
 
-// Detect if the resharding is of type e.g.
-// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis grid axis pair.
-static std::optional<std::tuple<int64_t, GridAxis>>
-detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
+// Detect if the resharding removes trailing split axes along a tensor
+// dimension, e.g.
+// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [0] or [[0, 1, 2]] -> [].
+// If detected, returns the corresponding (tensor dim, grid axes) pair, where
+// the "grid axes" are the removed trailing split axes.
+static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
+detectUnsplitLastAxesInResharding(const Sharding &sourceSharding,
                                   const Sharding &targetSharding) {
-  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
-       ++tensorAxis) {
-    if (targetSharding.getSplitAxes().size() > tensorAxis) {
-      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
-          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+  for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size();
+       ++tensorDim) {
+    if (targetSharding.getSplitAxes().size() > tensorDim) {
+      // No match if the target sharding does not have fewer split axes than
+      // the source sharding along the current tensor dimension.
+      if (sourceSharding.getSplitAxes()[tensorDim].size() <=
+          targetSharding.getSplitAxes()[tensorDim].size())
         continue;
-      if (!llvm::equal(
-              llvm::make_range(
-                  sourceSharding.getSplitAxes()[tensorAxis]
-                      .asArrayRef()
-                      .begin(),
-                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
-                      1),
-              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+      // No match if the split axes of the target sharding are different from
+      // the first split axes of the source sharding.
+      if (!std::equal(
+              targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(),
+              targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(),
+              sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin()))
----------------
fschlimb wrote:

Yes, it does, just above.
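
For readers following the hunk without the full file, the two guards it shows (the size comparison, then the prefix comparison) can be sketched in isolation. This is only an illustration under assumptions, not the PR's code: plain `std::vector<int>` stands in for the split axes of a single tensor dimension, and `removedTrailingAxes` is a hypothetical name that does not exist in MLIR.

```cpp
#include <algorithm>
#include <optional>
#include <vector>

// Hypothetical helper, not part of the PR: each vector holds the grid axes
// that split one tensor dimension, in order.
// Returns the trailing axes of `source` that are absent from `target`, or
// std::nullopt if `target` is not a strictly shorter prefix of `source`.
std::optional<std::vector<int>>
removedTrailingAxes(const std::vector<int> &source,
                    const std::vector<int> &target) {
  // Guard 1: the target must have strictly fewer axes than the source.
  // This also keeps the three-iterator std::equal below within bounds,
  // because `source` is then known to be longer than `target`.
  if (source.size() <= target.size())
    return std::nullopt;
  // Guard 2: the target axes must match the leading axes of the source.
  if (!std::equal(target.begin(), target.end(), source.begin()))
    return std::nullopt;
  // Whatever follows the shared prefix is the set of axes being unsplit.
  return std::vector<int>(source.begin() + target.size(), source.end());
}
```

With this sketch, a source of `{0, 1, 2}` and a target of `{0, 1}` yields `{2}`, an empty target yields all three axes, and a target that is not a prefix (or is not shorter) yields no match.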

https://github.com/llvm/llvm-project/pull/180754