[Mlir-commits] [mlir] [mlir][shard, mpi] Allow more than one last axis to be "unsplit" (PR #180754)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 10 07:29:50 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
The unsplit-last-axis resharding pattern previously allowed only a single trailing grid axis to be "unsplit" (all-gathered).
This PR allows multiple trailing grid axes to be "unsplit" with a single all-gather.
---
Full diff: https://github.com/llvm/llvm-project/pull/180754.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+77-58)
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+12-1)
- (modified) mlir/test/Dialect/Shard/partition.mlir (+24)
``````````diff
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index e619c7073a8c4..8652d665e46bf 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -132,98 +132,117 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
return std::nullopt;
}
-// Detect if the resharding is of type e.g.
-// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis grid axis pair.
-static std::optional<std::tuple<int64_t, GridAxis>>
-detectUnsplitLastAxisInResharding(const Sharding &sourceSharding,
+// Detect if the resharding removes trailing split axes along a tensor
+// dimension, e.g.
+// [[0, 1, 2]] -> [[0, 1]], [[0, 1, 2]] -> [[0]] or [[0, 1, 2]] -> [].
+// If detected, returns the corresponding (tensor dim, grid axes) pair, where
+// the "grid axes" are the removed trailing split axes.
+static std::optional<std::tuple<int64_t, SmallVector<GridAxis>>>
+detectUnsplitLastAxesInResharding(const Sharding &sourceSharding,
const Sharding &targetSharding) {
- for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
- ++tensorAxis) {
- if (targetSharding.getSplitAxes().size() > tensorAxis) {
- if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
- targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+ for (size_t tensorDim = 0; tensorDim < sourceSharding.getSplitAxes().size();
+ ++tensorDim) {
+ if (targetSharding.getSplitAxes().size() > tensorDim) {
+      // No match if the target sharding does not have fewer split axes than the
+      // source sharding along the current tensor dimension.
+ if (sourceSharding.getSplitAxes()[tensorDim].size() <=
+ targetSharding.getSplitAxes()[tensorDim].size())
continue;
- if (!llvm::equal(
- llvm::make_range(
- sourceSharding.getSplitAxes()[tensorAxis]
- .asArrayRef()
- .begin(),
- sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
- 1),
- targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+      // No match if the split axes of the target sharding are not a prefix of
+      // the split axes of the source sharding along this dimension.
+ if (!std::equal(
+ targetSharding.getSplitAxes()[tensorDim].asArrayRef().begin(),
+ targetSharding.getSplitAxes()[tensorDim].asArrayRef().end(),
+ sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin()))
continue;
} else {
- if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
+ // Here the target dimension is replicated; there is nothing to do if the
+ // source dimension is also replicated.
+ if (sourceSharding.getSplitAxes()[tensorDim].size() == 0)
continue;
}
- return std::make_tuple(
- tensorAxis,
- sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+    // This is a match. Return the current tensor dimension and the trailing
+    // grid axes of the source sharding that are dropped along this dimension.
+ SmallVector<GridAxis> unsplitAxes;
+ size_t dimOff = tensorDim >= targetSharding.getSplitAxes().size()
+ ? 0
+ : targetSharding.getSplitAxes()[tensorDim].size();
+ for (auto a =
+ sourceSharding.getSplitAxes()[tensorDim].asArrayRef().begin() +
+ dimOff;
+ a != sourceSharding.getSplitAxes()[tensorDim].asArrayRef().end(); ++a)
+ unsplitAxes.push_back(*a);
+ return std::make_tuple(tensorDim, unsplitAxes);
}
return std::nullopt;
}
-static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+// Return the resulting Sharding if the unsplit last axes resharding is applied.
+static Sharding targetShardingInUnsplitLastAxes(MLIRContext *ctx,
const Sharding &sourceSharding,
- int64_t splitTensorAxis) {
- SmallVector<GridAxesAttr> targetShardingSplitAxes =
+ int64_t splitTensorDim,
+ size_t numUnsplitAxes) {
+ SmallVector<GridAxesAttr> resSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
- assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
- splitTensorAxis);
- auto targetSplitAxes =
- llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
-
- targetSplitAxes.pop_back();
- targetShardingSplitAxes[splitTensorAxis] =
- GridAxesAttr::get(ctx, targetSplitAxes);
- return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
+ assert(static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
+ ArrayRef<GridAxis> srcSplitAxes = resSplitAxes[splitTensorDim].asArrayRef();
+ assert(srcSplitAxes.size() >= numUnsplitAxes);
+ size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
+ SmallVector<GridAxis> newSplitAxes(srcSplitAxes.begin(),
+ srcSplitAxes.begin() + numSplitAxes);
+ resSplitAxes[splitTensorDim] = GridAxesAttr::get(ctx, newSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), resSplitAxes);
}
-static ShapedType allGatherResultShapeInUnsplitLastAxis(
- ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
- SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
- targetShape[splitTensorAxis] =
- gatherDimension(targetShape[splitTensorAxis], splitCount);
- return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+// Return the resulting tensor type after applying the unsplit last axes
+// resharding.
+static ShapedType allGatherResultTypeInUnsplitLastAxes(
+ ShapedType sourceType, int64_t splitTensorDim, ArrayRef<int64_t> gridShape,
+ ArrayRef<GridAxis> unsplitAxes) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceType.getShape());
+ for (GridAxis gridAxis : unsplitAxes)
+ targetShape[splitTensorDim] =
+ gatherDimension(targetShape[splitTensorDim], gridShape[gridAxis]);
+ return sourceType.cloneWith(targetShape, sourceType.getElementType());
}
-static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
+// Perform the resharding for the unsplit last axes case.
+// This emits a single all-gather over all of the unsplit grid axes.
+static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxesInResharding(
ImplicitLocOpBuilder &builder, Sharding sourceSharding,
ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
- GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
+ GridOp grid, int64_t splitTensorDim, ArrayRef<GridAxis> unsplitAxes) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- Sharding targetSharding = targetShardingInUnsplitLastAxis(
- ctx, std::move(sourceSharding), splitTensorAxis);
- ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
- sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
+ Sharding targetSharding = targetShardingInUnsplitLastAxes(
+ ctx, std::move(sourceSharding), splitTensorDim, unsplitAxes.size());
+ ShapedType allGatherResultType = allGatherResultTypeInUnsplitLastAxes(
+ sourceShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
Value allGatherResult = AllGatherOp::create(
builder,
- RankedTensorType::get(allGatherResultShape.getShape(),
- allGatherResultShape.getElementType()),
- grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
- APInt(64, splitTensorAxis));
- ShapedType targetShape =
+ RankedTensorType::get(allGatherResultType.getShape(),
+ allGatherResultType.getElementType()),
+ grid.getSymName(), unsplitAxes, sourceShard, APInt(64, splitTensorDim));
+ ShapedType targetType =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
TypedValue<ShapedType> targetShard =
- tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
+ tensor::CastOp::create(builder, targetType, allGatherResult).getResult();
return {targetShard, targetSharding};
}
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+tryUnsplitLastAxesInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
const Sharding &sourceSharding,
Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
- if (auto detectRes = detectUnsplitLastAxisInResharding(
+ if (auto detectRes = detectUnsplitLastAxesInResharding(
sourceSharding, std::move(targetSharding))) {
- auto [tensorAxis, gridAxis] = detectRes.value();
- return unsplitLastAxisInResharding(builder, sourceSharding,
+ auto [tensorDim, gridAxes] = detectRes.value();
+ return unsplitLastAxesInResharding(builder, sourceSharding,
sourceUnshardedShape, sourceShard, grid,
- tensorAxis, gridAxis);
+ tensorDim, gridAxes);
}
return std::nullopt;
@@ -477,7 +496,7 @@ reshard(ImplicitLocOpBuilder &builder, GridOp grid,
trySplitLastAxisInResharding(builder, grid, sourceSharding,
targetSharding, sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
- } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+ } else if (auto tryRes = tryUnsplitLastAxesInResharding(
builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index 6161c131c8f50..f3da09d05e3b8 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -5,12 +5,23 @@
shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK: mpi.comm_rank
- // CHECK: [[res:%.*]]:3 = affine.delinearize_index %1 into (3, 4, 5) : index, index, index
+ // CHECK: [[v1:%.*]] = arith.index_cast
+ // CHECK: [[res:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
%0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return [[res]]#0, [[res]]#1, [[res]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
+// CHECK-LABEL: func @process_multi_index_reorder
+func.func @process_multi_index_reorder() -> (index, index) {
+ // CHECK: mpi.comm_rank
+ // CHECK: [[v1:%.*]] = arith.index_cast
+ // CHECK: [[v2:%.*]]:3 = affine.delinearize_index [[v1]] into (3, 4, 5) : index, index, index
+ %0:2 = shard.process_multi_index on @grid0 axes = [2, 0] : index, index
+ // CHECK: return [[v2]]#2, [[v2]]#0 : index, index
+ return %0#0, %0#1 : index, index
+}
+
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index 4c8271aefcafc..d5db8073fcf2e 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -5,6 +5,7 @@
shard.grid @grid_1d(shape = 2)
shard.grid @grid_1d_4(shape = 4)
shard.grid @grid_2d_16(shape = 4x4)
+shard.grid @grid_4d(shape = 2x3x4x5)
// CHECK-LABEL: func @return_sharding
func.func @return_sharding(
@@ -52,6 +53,29 @@ func.func @sharding_triplet(
return %sharded_1 : tensor<2xf32>
}
+// CHECK-LABEL: func.func @unsplit_last_axes_some(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x2xi8>) -> tensor<6x24xi8> {
+func.func @unsplit_last_axes_some( %in2: tensor<6x48xi8>) -> tensor<6x48xi8> {
+ %sharding1 = shard.sharding @grid_4d split_axes = [[], [0,1,2]] : !shard.sharding
+ %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x48xi8>
+ %sharding2 = shard.sharding @grid_4d split_axes = [[], [0]] : !shard.sharding
+ %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x48xi8>
+ // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [1, 2] gather_axis = 1 : tensor<6x2xi8> -> tensor<6x24xi8>
+ // CHECK: return [[vall_gather]] : tensor<6x24xi8>
+ return %in2_sharded : tensor<6x48xi8>
+}
+
+// CHECK-LABEL: func.func @unsplit_last_axes_all(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x48xi8>) -> tensor<48x48xi8> {
+func.func @unsplit_last_axes_all(%in2: tensor<48x48xi8>) -> tensor<48x48xi8> {
+ %sharding1 = shard.sharding @grid_4d split_axes = [[0,1,2]] : !shard.sharding
+ %in2_replicated = shard.shard %in2 to %sharding1 : tensor<48x48xi8>
+ %sharding2 = shard.sharding @grid_4d split_axes = [[]] : !shard.sharding
+ %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<48x48xi8>
+ // CHECK: [[vall_gather:%.*]] = shard.all_gather [[varg0]] on @grid_4d grid_axes = [0, 1, 2] gather_axis = 0 : tensor<2x48xi8> -> tensor<48x48xi8>
+ // CHECK: return [[vall_gather]] : tensor<48x48xi8>
+ return %in2_sharded : tensor<48x48xi8>
+}
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
``````````
</details>
https://github.com/llvm/llvm-project/pull/180754
More information about the Mlir-commits
mailing list