[Mlir-commits] [mlir] [MLIR][Shard] Fix three bugs in ND mesh resharding in Partition pass (PR #189241)
Mehdi Amini
llvmlistbot at llvm.org
Sun Mar 29 06:16:23 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189241
A new MoveLastSplitAxisPattern class handles the case where the last grid axis of one tensor dimension is moved to the front of another tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
The three bugs fixed are:
1. detectMoveLastSplitAxisInResharding: compared source.back() with target.back() instead of target.front(), preventing the pattern from being detected for resharding like [[0,1],[2]] -> [[0],[1,2]].
2. targetShardingInMoveLastAxis: axes were appended with push_back but should be inserted at the front, producing wrong split_axes order.
3. handlePartialAxesDuringResharding: a copy_if wrote results into the wrong output variable (addressed structurally by the clean implementation).
Fixes #136117
Assisted-by: Claude Code
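The detection and rewrite rules described above can be sketched outside of MLIR. The following is a minimal illustration only, not the code in Partition.cpp: `SplitAxes`, `axesAt`, `detectMoveLast`, and `moveLast` are hypothetical names, and plain `std::vector` stands in for `Sharding` and `GridAxesAttr`. It shows the two points the bug list calls out: the moved axis is matched against the front of the target dimension's axes, and it is inserted at the front rather than appended.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the MLIR types: one list of grid axes per
// tensor dimension, e.g. {{0, 1}, {2}} for split_axes = [[0, 1], [2]].
using GridAxis = int16_t;
using SplitAxes = std::vector<std::vector<GridAxis>>;

// Axes that split `dim`, or an empty list when the sharding has no entry.
static std::vector<GridAxis> axesAt(const SplitAxes &s, size_t dim) {
  return dim < s.size() ? s[dim] : std::vector<GridAxis>{};
}

// Returns (tgtDim, movedAxis) if `tgt` equals `src` with the last axis of
// `srcDim` moved to the FRONT of another dimension's axes.
std::optional<std::pair<size_t, GridAxis>>
detectMoveLast(const SplitAxes &src, const SplitAxes &tgt, size_t srcDim) {
  std::vector<GridAxis> srcAxes = axesAt(src, srcDim);
  // Need at least two axes so that one remains after the move.
  if (srcAxes.size() < 2 || srcDim >= tgt.size())
    return std::nullopt;
  GridAxis moved = srcAxes.back();
  srcAxes.pop_back();
  // The source dimension must keep exactly its remaining prefix.
  if (tgt[srcDim] != srcAxes)
    return std::nullopt;
  for (size_t d = 0; d < tgt.size(); ++d) {
    // Compare against front(), not back(): the moved axis leads the list.
    if (d == srcDim || tgt[d].empty() || tgt[d].front() != moved)
      continue;
    std::vector<GridAxis> rest(tgt[d].begin() + 1, tgt[d].end());
    if (rest == axesAt(src, d))
      return std::make_pair(d, moved);
  }
  return std::nullopt;
}

// Applies the move: pop the last axis of srcDim, prepend it to tgtDim.
SplitAxes moveLast(SplitAxes s, size_t srcDim, size_t tgtDim) {
  assert(srcDim < s.size() && !s[srcDim].empty());
  if (s.size() <= tgtDim)
    s.resize(tgtDim + 1);
  GridAxis moved = s[srcDim].back();
  s[srcDim].pop_back();
  s[tgtDim].insert(s[tgtDim].begin(), moved); // prepend, not push_back
  return s;
}
```

With these definitions, `detectMoveLast({{0, 1}, {2}}, {{0}, {1, 2}}, 0)` reports target dimension 1 and moved axis 1, while a target of `{{0}, {2, 1}}` is rejected because the moved axis would sit at the back.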
From b1f165a119127419420adb4d661f7e45efc4e029 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 17:08:10 -0700
Subject: [PATCH] [MLIR][Shard] Fix three bugs in ND mesh resharding in
Partition pass
Fix three bugs in the Shard dialect Partition pass when handling ND mesh
resharding. The bugs were originally reported for the old Mesh dialect's
Spmdization pass (GitHub issue #136117) and correspond to the same logic
in the Shard dialect.
A new MoveLastSplitAxisPattern class handles the case where the last grid
axis of one tensor dimension is moved to the front of another tensor
dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
The three bugs fixed are:
1. detectMoveLastSplitAxisInResharding: compared source.back() with
target.back() instead of target.front(), preventing the pattern from
being detected for resharding like [[0,1],[2]] -> [[0],[1,2]].
2. targetShardingInMoveLastAxis: axes were appended with push_back but
should be inserted at the front, producing wrong split_axes order.
3. handlePartialAxesDuringResharding: a copy_if wrote results into the
wrong output variable (addressed structurally by the clean
implementation).
Fixes #136117
Assisted-by: Claude Code
---
.../Dialect/Shard/Transforms/Partition.cpp | 150 ++++++++++++++++--
.../Dialect/Shard/resharding-partition.mlir | 59 ++++++-
2 files changed, 196 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 9c5880e0c3b64..57502fbf9e276 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -260,6 +260,17 @@ class UnsplitLastAxesPattern : public ReshardingPattern {
}
};
+// Compute the result shape of an all-to-all that gathers along srcTensorDim
+// and scatters along tgtTensorDim with the given split count.
+static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
+ int64_t srcTensorDim,
+ int64_t tgtTensorDim) {
+ SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
+ tgtShape[srcTensorDim] = gatherDimension(tgtShape[srcTensorDim], splitCount);
+ tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
+ return srcShape.cloneWith(tgtShape, srcShape.getElementType());
+}
+
/// Move a split axis between tensor dimensions:
/// e.g. [[0], []] -> [[], [0]].
class MoveSplitAxisPattern : public ReshardingPattern {
@@ -310,16 +321,6 @@ class MoveSplitAxisPattern : public ReshardingPattern {
return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
}
- static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
- int64_t srcTensorDim,
- int64_t tgtTensorDim) {
- SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
- tgtShape[srcTensorDim] =
- gatherDimension(tgtShape[srcTensorDim], splitCount);
- tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
- return srcShape.cloneWith(tgtShape, srcShape.getElementType());
- }
-
static std::tuple<TypedValue<ShapedType>, Sharding>
apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
@@ -362,6 +363,130 @@ class MoveSplitAxisPattern : public ReshardingPattern {
}
};
+/// Move the last split axis of one tensor dimension to the front of another
+/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
+class MoveLastSplitAxisPattern : public ReshardingPattern {
+ // Detect if the resharding moves the last grid axis of srcTensorDim to the
+ // front of another tensor dimension's split axes. If detected, returns
+ // (tgtTensorDim, movedGridAxis).
+ //
+ // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an] (n >= 2)
+ // tgt[srcTensorDim] = [a1,...,a(n-1)]
+ // src[tgtTensorDim] = [b1,...,bm] (m >= 0)
+ // tgt[tgtTensorDim] = [an, b1,...,bm]
+ static std::optional<std::tuple<int64_t, GridAxis>>
+ detect(const Sharding &srcSharding, const Sharding &tgtSharding,
+ int64_t srcTensorDim) {
+ if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
+ return std::nullopt;
+ auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+ // Need at least 2 axes to move the last one.
+ if (srcAxes.size() < 2)
+ return std::nullopt;
+
+ // After the move the source tensor dim should lose its last axis.
+ if (static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
+ return std::nullopt;
+ auto tgtSrcAxes = tgtSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+ if (tgtSrcAxes.size() + 1 != srcAxes.size())
+ return std::nullopt;
+ // The remaining axes at srcTensorDim must be the same (prefix of source).
+ if (!llvm::equal(tgtSrcAxes,
+ llvm::make_range(srcAxes.begin(), srcAxes.end() - 1)))
+ return std::nullopt;
+
+ GridAxis movedAxis = srcAxes.back();
+
+ // Find a target tensor dimension whose split axes start with movedAxis
+ // and whose remaining axes match the source sharding at that dimension.
+ for (size_t tgtTensorDim = 0;
+ tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
+ if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
+ continue;
+ auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
+ // The target dimension must start with the moved axis.
+ if (tgtAxes.empty() || tgtAxes.front() != movedAxis)
+ continue;
+ // The remainder of tgtAxes must equal the source sharding at
+ // tgtTensorDim.
+ ArrayRef<GridAxis> srcTgtAxes =
+ static_cast<size_t>(tgtTensorDim) < srcSharding.getSplitAxes().size()
+ ? srcSharding.getSplitAxes()[tgtTensorDim].asArrayRef()
+ : ArrayRef<GridAxis>{};
+ if (!llvm::equal(srcTgtAxes,
+ llvm::make_range(tgtAxes.begin() + 1, tgtAxes.end())))
+ continue;
+ return std::make_tuple(static_cast<int64_t>(tgtTensorDim), movedAxis);
+ }
+ return std::nullopt;
+ }
+
+ // Compute the result sharding after moving movedAxis from srcTensorDim
+ // to the front of tgtTensorDim.
+ static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
+ int64_t srcTensorDim, int64_t tgtTensorDim,
+ GridAxis movedAxis) {
+ SmallVector<GridAxesAttr> splitAxes =
+ llvm::to_vector(srcSharding.getSplitAxes());
+ while (static_cast<int64_t>(splitAxes.size()) <= tgtTensorDim)
+ splitAxes.push_back(GridAxesAttr::get(ctx, {}));
+
+ // Remove last axis from srcTensorDim.
+ auto srcSplitAxes = llvm::to_vector(splitAxes[srcTensorDim].asArrayRef());
+ assert(!srcSplitAxes.empty() && srcSplitAxes.back() == movedAxis);
+ srcSplitAxes.pop_back();
+ splitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
+
+ // Prepend movedAxis to tgtTensorDim.
+ auto tgtSplitAxes = llvm::to_vector(splitAxes[tgtTensorDim].asArrayRef());
+ tgtSplitAxes.insert(tgtSplitAxes.begin(), movedAxis);
+ splitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
+
+ return Sharding::get(srcSharding.getGridAttr(), splitAxes);
+ }
+
+ static std::tuple<TypedValue<ShapedType>, Sharding>
+ apply(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding,
+ ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
+ int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis movedAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(srcShard);
+
+ Sharding resultSharding =
+ tgtSharding(ctx, srcSharding, srcTensorDim, tgtTensorDim, movedAxis);
+ ShapedType a2aResultShape =
+ allToAllResultShape(srcShard.getType(), grid.getShape()[movedAxis],
+ srcTensorDim, tgtTensorDim);
+ Value allToAllResult = AllToAllOp::create(
+ builder,
+ RankedTensorType::get(a2aResultShape.getShape(),
+ a2aResultShape.getElementType()),
+ grid.getSymName(), SmallVector<GridAxis>({movedAxis}), srcShard,
+ APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
+ ShapedType tgtShape =
+ shardShapedType(srcUnshardedType, grid, resultSharding);
+ TypedValue<ShapedType> tgtShard =
+ tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
+ return {tgtShard, resultSharding};
+ }
+
+public:
+ std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+ tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
+ const Sharding &srcSharding, const Sharding &tgtSharding,
+ ShapedType srcUnshardedType,
+ TypedValue<ShapedType> srcShard) override {
+ if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
+ return std::nullopt;
+ if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
+ auto [tgtTensorDim, movedAxis] = detectRes.value();
+ return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
+ tensorDim, tgtTensorDim, movedAxis);
+ }
+ return std::nullopt;
+ }
+};
+
/// Update halo sizes: handles cases where only the halo sizes differ between
/// source and target sharding. Requires copying the "core" of the source tensor
/// into the "core" of the destination tensor followed by an update halo op.
@@ -460,12 +585,13 @@ static TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder,
// Each pattern's tryApply checks its own applicability preconditions.
static UpdateHaloPattern updateHaloPattern;
+ static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
static MoveSplitAxisPattern moveSplitAxisPattern;
static SplitLastAxisPattern splitLastAxisPattern;
static UnsplitLastAxesPattern unsplitLastAxesPattern;
static ReshardingPattern *patterns[] = {
- &updateHaloPattern, &moveSplitAxisPattern, &splitLastAxisPattern,
- &unsplitLastAxesPattern};
+ &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern,
+ &splitLastAxisPattern, &unsplitLastAxesPattern};
TypedValue<ShapedType> currentShard = shardedSrc;
Sharding currentSharding = srcSharding;
for (int64_t dim = 0;
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
index ff9e8408aa7fd..01c4733485678 100644
--- a/mlir/test/Dialect/Shard/resharding-partition.mlir
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -2,6 +2,7 @@
shard.grid @grid_1d(shape = 2)
shard.grid @grid_1d_dynamic(shape = ?)
+shard.grid @grid_3d(shape = 2x2x2)
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
@@ -153,7 +154,7 @@ func.func @unshard_dynamic_axis(
// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
func.func @unshard_static_axis_on_dynamic_grid_axis(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
%arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
@@ -166,3 +167,59 @@ func.func @unshard_static_axis_on_dynamic_grid_axis(
// CHECK: return %[[RES]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
+
+// MoveLastSplitAxisPattern: [[0, 1], [2]] -> [[0], [1, 2]]
+// Source shard: 8/(2*2) x 16/2 = 2x8; after all_to_all(axis=1): 4x4
+// CHECK-LABEL: func @move_last_split_axis_to_front_of_target_dim
+func.func @move_last_split_axis_to_front_of_target_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+ %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x8xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x8xf32> -> tensor<4x4xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x4xf32> to tensor<8x16xf32>
+ %s0 = shard.sharding @grid_3d split_axes = [[0, 1], [2]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<8x16xf32>
+ return %1 : tensor<8x16xf32>
+}
+
+// MoveLastSplitAxisPattern with tgtTensorDim < srcTensorDim:
+// [[0], [1, 2]] -> [[2, 0], [1]] (axis 2 moved from dim 1 to front of dim 0)
+// Source shard: 8/2 x 16/(2*2) = 4x4; after all_to_all(axis=2): 2x8
+// CHECK-LABEL: func @move_last_split_axis_to_lower_dim
+func.func @move_last_split_axis_to_lower_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+ %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<4x4xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [2] split_axis = 0 concat_axis = 1 : tensor<4x4xf32> -> tensor<2x8xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<2x8xf32> to tensor<8x16xf32>
+ %s0 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_3d split_axes = [[2, 0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<8x16xf32>
+ return %1 : tensor<8x16xf32>
+}
+
+// MoveLastSplitAxisPattern where source has no axes at tgtTensorDim:
+// [[0, 1]] -> [[0], [1]] (tgtTensorDim has empty source)
+// Source shard: 8/(2*2) x 16 = 2x16; after all_to_all(axis=1): 4x8
+// CHECK-LABEL: func @move_last_split_axis_empty_source_at_target_dim
+func.func @move_last_split_axis_empty_source_at_target_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+ %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x16xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x16xf32> -> tensor<4x8xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x8xf32> to tensor<8x16xf32>
+ %s0 = shard.sharding @grid_3d split_axes = [[0, 1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_3d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<8x16xf32>
+ return %1 : tensor<8x16xf32>
+}
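
The shard shapes quoted in the test comments above (for example 2x8 becoming 4x4) can be checked with a small stand-alone calculation. This is a sketch under stated assumptions, not the MLIR helpers themselves: `shardShape` and `allToAllShape` are hypothetical names, all sizes are static, and every dimension divides evenly by the grid axis sizes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Per-device shard shape: each tensor dimension is divided by the size of
// every grid axis that splits it. Assumes static, evenly divisible sizes.
Shape shardShape(const Shape &tensor, const Shape &grid,
                 const std::vector<std::vector<int>> &splitAxes) {
  Shape result = tensor;
  for (size_t d = 0; d < splitAxes.size() && d < result.size(); ++d) {
    for (int axis : splitAxes[d]) {
      assert(result[d] % grid[axis] == 0);
      result[d] /= grid[axis];
    }
  }
  return result;
}

// Shape after an all-to-all over one grid axis of size `splitCount`:
// the gathered dimension grows, the scattered dimension shrinks.
Shape allToAllShape(Shape shard, int64_t splitCount, size_t gatherDim,
                    size_t scatterDim) {
  assert(shard[scatterDim] % splitCount == 0);
  shard[gatherDim] *= splitCount;
  shard[scatterDim] /= splitCount;
  return shard;
}
```

For the first test case, an 8x16 tensor on the 2x2x2 grid with split axes [[0, 1], [2]] gives a 2x8 shard; gathering along dimension 0 and scattering along dimension 1 over grid axis 1 gives 4x4, which matches the shard shape computed directly from [[0], [1, 2]].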