[Mlir-commits] [mlir] [mlir][shard, mpi] Allow more than one last axis to be "unsplit" (PR #180754)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 10 07:29:50 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

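Previously, the "unsplit last axis" resharding pattern could only remove a
single trailing split (grid) axis from a tensor dimension, e.g.
[[0, 1, 2]] -> [[0, 1]].
This PR allows several trailing split axes to be removed at once, e.g.
[[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [[]], lowered to a single
shard.all_gather over the removed grid axes.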

---
Full diff: https://github.com/llvm/llvm-project/pull/180754.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+77-58) 
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+12-1) 
- (modified) mlir/test/Dialect/Shard/partition.mlir (+24) 


``````````diff
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index e619c7073a8c4..8652d665e46bf 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -132,98 +132,117 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
   return std::nullopt;
 }
 
-// Detect if the resharding is of type e.g.
-// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis grid axis pair.
-static std::optional<std::tuple<int64_t, GridAxis>>
-detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
+// Detect if the resharding removes one or more trailing split axes along a
+// tensor dimension, e.g.
+// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [[]].
+// If detected, returns the (tensor dim, grid axes) pair for the first such
+// dimension, where the "grid axes" are the removed trailing split axes.
+static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
+detectUnsplitLastAxesInResharding(const Sharding &sourceSharding,
                                   const Sharding &targetSharding) {
-  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
-       ++tensorAxis) {
-    if (targetSharding.getSplitAxes().size() > tensorAxis) {
-      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
-          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+  for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size();
+       ++tensorDim) {
+    if (targetSharding.getSplitAxes().size() > tensorDim) {
+      // No match if the target sharding does not have fewer split axes than
+      // the source sharding along the current tensor dimension.
+      if (sourceSharding.getSplitAxes()[tensorDim].size() <=
+          targetSharding.getSplitAxes()[tensorDim].size())
         continue;
-      if (!llvm::equal(
-              llvm::make_range(
-                  sourceSharding.getSplitAxes()[tensorAxis]
-                      .asArrayRef()
-                      .begin(),
-                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
-                      1),
-              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+      // No match unless the target's split axes along this dimension are a
+      // prefix of the source's split axes along the same dimension.
+      if (!std::equal(
+              targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(),
+              targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(),
+              sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin()))
         continue;
     } else {
-      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
+      // Here the target dimension is replicated; there is nothing to do if the
+      // source dimension is also replicated.
+      if (sourceSharding.getSplitAxes()[tensorDim].size() == 0)
         continue;
     }
-    return std::make_tuple(
-        tensorAxis,
-        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+    // This is a match. Return the current tensor dimension and the trailing
+    // source grid axes along this dimension that the target sharding drops.
+    SmallVector<GridAxis> unsplitAxes;
+    size_t dimOff = tensorDim >= targetSharding.getSplitAxes().size()
+                        ? 0
+                        : targetSharding.getSplitAxes()[tensorDim].size();
+    for (auto a =
+             sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin() +
+             dimOff;
+         a != sourceSharding.getSplitAxes()[tensorDim].asArrayRef().end(); ++a)
+      unsplitAxes.push_back(*a);
+    return std::make_tuple(tensorDim, unsplitAxes);
   }
   return std::nullopt;
 }
 
-static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+// Return the resulting Sharding if the unsplit last axes resharding is applied.
+static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx,
                                                 const Sharding &sourceSharding,
-                                                int64_t splitTensorAxis) {
-  SmallVector<GridAxesAttr> targetShardingSplitAxes =
+                                                int64_t splitTensorDim,
+                                                size_t numUnsplitAxes) {
+  SmallVector<GridAxesAttr> resSplitAxes =
       llvm::to_vector(sourceSharding.getSplitAxes());
-  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
-         splitTensorAxis);
-  auto targetSplitAxes =
-      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
-
-  targetSplitAxes.pop_back();
-  targetShardingSplitAxes[splitTensorAxis] =
-      GridAxesAttr::get(ctx, targetSplitAxes);
-  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
+  assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
+  ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
+  assert(srcSplitAxes.size() >= numUnsplitAxes);
+  size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
+  SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
+                                     srcSplitAxes.begin() + numSplitAxes);
+  resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
+  return Sharding::get(sourceSharding.getGridAttr(), resSplitAxes);
 }
 
-static ShapedType allGatherResultShapeInUnsplitLastAxis(
-    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
-  targetShape[splitTensorAxis] =
-      gatherDimension(targetShape[splitTensorAxis], splitCount);
-  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+// Return the type of the local shard after all-gathering it along the
+// unsplitAxes grid axes in tensor dimension splitTensorDim.
+static ShapedType allGatherResultTypeInUnsplitLastAxes(
+    ShapedType sourceType, int64_t splitTensorDim, ArrayRef<int64_t> gridShape,
+    ArrayRef<GridAxis> unsplitAxes) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(sourceType.getShape());
+  for (GridAxis gridAxis : unsplitAxes)
+    targetShape[splitTensorDim] =
+        gatherDimension(targetShape[splitTensorDim], gridShape[gridAxis]);
+  return sourceType.cloneWith(targetShape, sourceType.getElementType());
 }
 
-static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
+// Perform the resharding for the unsplit last axes case.
+// This basically performs an all-gather along the unsplit grid axes.
+static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxesInResharding(
     ImplicitLocOpBuilder &builder, Sharding sourceSharding,
     ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
-    GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
+    GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
 
-  Sharding targetSharding = targetShardingInUnsplitLastAxis(
-      ctx, std::move(sourceSharding), splitTensorAxis);
-  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
-      sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
+  Sharding targetSharding = targetShardingInUnsplitLastAxes(
+      ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size());
+  ShapedType allGatherResultType = allGatherResultTypeInUnsplitLastAxes(
+      sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
   Value allGatherResult = AllGatherOp::create(
       builder,
-      RankedTensorType::get(allGatherResultShape.getShape(),
-                            allGatherResultShape.getElementType()),
-      grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
-      APInt(64, splitTensorAxis));
-  ShapedType targetShape =
+      RankedTensorType::get(allGatherResultType.getShape(),
+                            allGatherResultType.getElementType()),
+      grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim));
+  ShapedType targetType =
       shardShapedType(sourceUnshardedShape, grid, targetSharding);
   TypedValue<ShapedType> targetShard =
-      tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
+      tensor::CastOp::create(builder, targetType, allGatherResult).getResult();
   return {targetShard, targetSharding};
 }
 
 static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                                const Sharding &sourceSharding,
                                Sharding targetSharding,
                                ShapedType sourceUnshardedShape,
                                TypedValue<ShapedType> sourceShard) {
-  if (auto detectRes = detectUnsplitLastAxisInResharding(
+  if (auto detectRes = detectUnsplitLastAxesInResharding(
           sourceSharding, std::move(targetSharding))) {
-    auto [tensorAxis, gridAxis] = detectRes.value();
-    return unsplitLastAxisInResharding(builder, sourceSharding,
+    auto [tensorDim, gridAxes] = detectRes.value();
+    return unsplitLastAxesInResharding(builder, sourceSharding,
                                        sourceUnshardedShape, sourceShard, grid,
-                                       tensorAxis, gridAxis);
+                                       tensorDim, gridAxes);
   }
 
   return std::nullopt;
@@ -477,7 +496,7 @@ reshard(ImplicitLocOpBuilder &builder, GridOp grid,
                    trySplitLastAxisInResharding(builder, grid, sourceSharding,
                                                 targetSharding, sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
-    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+    } else if (auto tryRes = tryUnsplitLastAxesInResharding(
                    builder, grid, sourceSharding, targetSharding,
                    sourceUnshardedValue.getType(), sourceShard)) {
       std::tie(targetShard, actualTargetSharding) = tryRes.value();
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 6161c131c8f50..f3da09d05e3b8 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -5,12 +5,23 @@
 shard.grid @grid0(shape = 3x4x5)
 func.func @process_multi_index() -> (index, index, index) {
   // CHECK: mpi.comm_rank
-  // CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index 
+  // CHECK: [[v1:%.*]] = arith.index_cast
+  // CHECK: [[res:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index 
   %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
   // CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
+// CHECK-LABEL: func @process_multi_index_reorder
+func.func @process_multi_index_reorder() -> (index, index) {
+  // CHECK: mpi.comm_rank
+  // CHECK: [[v1:%.*]] = arith.index_cast
+  // CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
+  %0:2 = shard.process_multi_index on @grid0 axes = [2, 0] : index, index
+  // CHECK: return [[v2]]#2, [[v2]]#0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
 // CHECK-LABEL: func @process_linear_index
 func.func @process_linear_index() -> index {
   // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index 4c8271aefcafc..d5db8073fcf2e 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -5,6 +5,7 @@
 shard.grid @grid_1d(shape = 2)
 shard.grid @grid_1d_4(shape = 4)
 shard.grid @grid_2d_16(shape = 4x4)
+shard.grid @grid_4d(shape = 2x3x4x5)
 
 // CHECK-LABEL: func @return_sharding
 func.func @return_sharding(
@@ -52,6 +53,29 @@ func.func @sharding_triplet(
   return %sharded_1 : tensor<2xf32>
 }
 
+// CHECK-LABEL: func.func @unsplit_last_axes_some(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> {
+func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> {
+  %sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding
+  %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x48xi8>
+  %sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding
+  %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x48xi8>
+  // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [1, 2] gather_axis = 1 : tensor<6x2xi8> -> tensor<6x24xi8>
+  // CHECK: return [[vall_gather]] : tensor<6x24xi8>
+  return %in2_sharded : tensor<6x48xi8>
+}
+
+// CHECK-LABEL: func.func @unsplit_last_axes_all(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x48xi8>) -> tensor<48x48xi8> {
+func.func @unsplit_last_axes_all(%in2: tensor<48x48xi8>) -> tensor<48x48xi8> {
+  %sharding1 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding
+  %in2_replicated = shard.shard %in2 to %sharding1 : tensor<48x48xi8>
+  %sharding2 = shard.sharding @grid_4d split_axes = [[]] : !shard.sharding
+  %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<48x48xi8>
+  // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [0, 1, 2] gather_axis = 0 : tensor<2x48xi8> -> tensor<48x48xi8>
+  // CHECK: return [[vall_gather]] : tensor<48x48xi8>
+  return %in2_sharded : tensor<48x48xi8>
+}
 
 // CHECK-LABEL: func @move_split_axis
 func.func @move_split_axis(

``````````

</details>


https://github.com/llvm/llvm-project/pull/180754


More information about the Mlir-commits mailing list