[Mlir-commits] [mlir] [MLIR][Shard] Fix three bugs in ND mesh resharding in Partition pass (PR #189241)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 29 06:16:58 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
A new MoveLastSplitAxisPattern class handles the case where the last grid axis of one tensor dimension is moved to the front of another tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
The three bugs fixed are:
1. detectMoveLastSplitAxisInResharding: compared source.back() with target.back() instead of target.front(), preventing the pattern from being detected for resharding like [[0,1],[2]] -> [[0],[1,2]].
2. targetShardingInMoveLastAxis: axes were appended with push_back instead of being inserted at the front, which produced the wrong split_axes order.
3. handlePartialAxesDuringResharding: a copy_if wrote its results into the wrong output variable (resolved structurally by the new, cleaner implementation).
Fixes #136117
Assisted-by: Claude Code
---
Full diff: https://github.com/llvm/llvm-project/pull/189241.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+138-12)
- (modified) mlir/test/Dialect/Shard/resharding-partition.mlir (+58-1)
``````````diff
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 9c5880e0c3b64..57502fbf9e276 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -260,6 +260,17 @@ class UnsplitLastAxesPattern : public ReshardingPattern {
}
};
+// Compute the result shape of an all-to-all that gathers along srcTensorDim
+// and scatters along tgtTensorDim with the given split count.
+static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
+ int64_t srcTensorDim,
+ int64_t tgtTensorDim) {
+ SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
+ tgtShape[srcTensorDim] = gatherDimension(tgtShape[srcTensorDim], splitCount);
+ tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
+ return srcShape.cloneWith(tgtShape, srcShape.getElementType());
+}
+
/// Move a split axis between tensor dimensions:
/// e.g. [[0], []] -> [[], [0]].
class MoveSplitAxisPattern : public ReshardingPattern {
@@ -310,16 +321,6 @@ class MoveSplitAxisPattern : public ReshardingPattern {
return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
}
- static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
- int64_t srcTensorDim,
- int64_t tgtTensorDim) {
- SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
- tgtShape[srcTensorDim] =
- gatherDimension(tgtShape[srcTensorDim], splitCount);
- tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
- return srcShape.cloneWith(tgtShape, srcShape.getElementType());
- }
-
static std::tuple<TypedValue<ShapedType>, Sharding>
apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
@@ -362,6 +363,130 @@ class MoveSplitAxisPattern : public ReshardingPattern {
}
};
+/// Move the last split axis of one tensor dimension to the front of another
+/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
+class MoveLastSplitAxisPattern : public ReshardingPattern {
+ // Detect if the resharding moves the last grid axis of srcTensorDim to the
+ // front of another tensor dimension's split axes. If detected, returns
+ // (tgtTensorDim, movedGridAxis).
+ //
+ // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an] (n >= 2)
+ // tgt[srcTensorDim] = [a1,...,a(n-1)]
+ // src[tgtTensorDim] = [b1,...,bm] (m >= 0)
+ // tgt[tgtTensorDim] = [an, b1,...,bm]
+ static std::optional<std::tuple<int64_t, GridAxis>>
+ detect(const Sharding &srcSharding, const Sharding &tgtSharding,
+ int64_t srcTensorDim) {
+ if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
+ return std::nullopt;
+ auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+ // Need at least 2 axes to move the last one.
+ if (srcAxes.size() < 2)
+ return std::nullopt;
+
+ // After the move the source tensor dim should lose its last axis.
+ if (static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
+ return std::nullopt;
+ auto tgtSrcAxes = tgtSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+ if (tgtSrcAxes.size() + 1 != srcAxes.size())
+ return std::nullopt;
+ // The remaining axes at srcTensorDim must be the same (prefix of source).
+ if (!llvm::equal(tgtSrcAxes,
+ llvm::make_range(srcAxes.begin(), srcAxes.end() - 1)))
+ return std::nullopt;
+
+ GridAxis movedAxis = srcAxes.back();
+
+ // Find a target tensor dimension whose split axes start with movedAxis
+ // and whose remaining axes match the source sharding at that dimension.
+ for (size_t tgtTensorDim = 0;
+ tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
+ if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
+ continue;
+ auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
+ // The target dimension must start with the moved axis.
+ if (tgtAxes.empty() || tgtAxes.front() != movedAxis)
+ continue;
+ // The remainder of tgtAxes must equal the source sharding at
+ // tgtTensorDim.
+ ArrayRef<GridAxis> srcTgtAxes =
+ static_cast<size_t>(tgtTensorDim) < srcSharding.getSplitAxes().size()
+ ? srcSharding.getSplitAxes()[tgtTensorDim].asArrayRef()
+ : ArrayRef<GridAxis>{};
+ if (!llvm::equal(srcTgtAxes,
+ llvm::make_range(tgtAxes.begin() + 1, tgtAxes.end())))
+ continue;
+ return std::make_tuple(static_cast<int64_t>(tgtTensorDim), movedAxis);
+ }
+ return std::nullopt;
+ }
+
+ // Compute the result sharding after moving movedAxis from srcTensorDim
+ // to the front of tgtTensorDim.
+ static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
+ int64_t srcTensorDim, int64_t tgtTensorDim,
+ GridAxis movedAxis) {
+ SmallVector<GridAxesAttr> splitAxes =
+ llvm::to_vector(srcSharding.getSplitAxes());
+ while (static_cast<int64_t>(splitAxes.size()) <= tgtTensorDim)
+ splitAxes.push_back(GridAxesAttr::get(ctx, {}));
+
+ // Remove last axis from srcTensorDim.
+ auto srcSplitAxes = llvm::to_vector(splitAxes[srcTensorDim].asArrayRef());
+ assert(!srcSplitAxes.empty() && srcSplitAxes.back() == movedAxis);
+ srcSplitAxes.pop_back();
+ splitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
+
+ // Prepend movedAxis to tgtTensorDim.
+ auto tgtSplitAxes = llvm::to_vector(splitAxes[tgtTensorDim].asArrayRef());
+ tgtSplitAxes.insert(tgtSplitAxes.begin(), movedAxis);
+ splitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
+
+ return Sharding::get(srcSharding.getGridAttr(), splitAxes);
+ }
+
+ static std::tuple<TypedValue<ShapedType>, Sharding>
+ apply(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding,
+ ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
+ int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis movedAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(srcShard);
+
+ Sharding resultSharding =
+ tgtSharding(ctx, srcSharding, srcTensorDim, tgtTensorDim, movedAxis);
+ ShapedType a2aResultShape =
+ allToAllResultShape(srcShard.getType(), grid.getShape()[movedAxis],
+ srcTensorDim, tgtTensorDim);
+ Value allToAllResult = AllToAllOp::create(
+ builder,
+ RankedTensorType::get(a2aResultShape.getShape(),
+ a2aResultShape.getElementType()),
+ grid.getSymName(), SmallVector<GridAxis>({movedAxis}), srcShard,
+ APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
+ ShapedType tgtShape =
+ shardShapedType(srcUnshardedType, grid, resultSharding);
+ TypedValue<ShapedType> tgtShard =
+ tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
+ return {tgtShard, resultSharding};
+ }
+
+public:
+ std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+ tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
+ const Sharding &srcSharding, const Sharding &tgtSharding,
+ ShapedType srcUnshardedType,
+ TypedValue<ShapedType> srcShard) override {
+ if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
+ return std::nullopt;
+ if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
+ auto [tgtTensorDim, movedAxis] = detectRes.value();
+ return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
+ tensorDim, tgtTensorDim, movedAxis);
+ }
+ return std::nullopt;
+ }
+};
+
/// Update halo sizes: handles cases where only the halo sizes differ between
/// source and target sharding. Requires copying the "core" of the source tensor
/// into the "core" of the destination tensor followed by an update halo op.
@@ -460,12 +585,13 @@ static TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder,
// Each pattern's tryApply checks its own applicability preconditions.
static UpdateHaloPattern updateHaloPattern;
+ static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
static MoveSplitAxisPattern moveSplitAxisPattern;
static SplitLastAxisPattern splitLastAxisPattern;
static UnsplitLastAxesPattern unsplitLastAxesPattern;
static ReshardingPattern *patterns[] = {
- &updateHaloPattern, &moveSplitAxisPattern, &splitLastAxisPattern,
- &unsplitLastAxesPattern};
+ &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern,
+ &splitLastAxisPattern, &unsplitLastAxesPattern};
TypedValue<ShapedType> currentShard = shardedSrc;
Sharding currentSharding = srcSharding;
for (int64_t dim = 0;
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
index ff9e8408aa7fd..01c4733485678 100644
--- a/mlir/test/Dialect/Shard/resharding-partition.mlir
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -2,6 +2,7 @@
shard.grid @grid_1d(shape = 2)
shard.grid @grid_1d_dynamic(shape = ?)
+shard.grid @grid_3d(shape = 2x2x2)
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
@@ -153,7 +154,7 @@ func.func @unshard_dynamic_axis(
// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
func.func @unshard_static_axis_on_dynamic_grid_axis(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
%arg0: tensor<10x14xf32>
) -> tensor<10x14xf32> {
// CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
@@ -166,3 +167,59 @@ func.func @unshard_static_axis_on_dynamic_grid_axis(
// CHECK: return %[[RES]] : tensor<10x14xf32>
return %1 : tensor<10x14xf32>
}
+
+// MoveLastSplitAxisPattern: [[0, 1], [2]] -> [[0], [1, 2]]
+// Source shard: 8/(2*2) x 16/2 = 2x8; after all_to_all(axis=1): 4x4
+// CHECK-LABEL: func @move_last_split_axis_to_front_of_target_dim
+func.func @move_last_split_axis_to_front_of_target_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+ %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x8xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x8xf32> -> tensor<4x4xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x4xf32> to tensor<8x16xf32>
+ %s0 = shard.sharding @grid_3d split_axes = [[0, 1], [2]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<8x16xf32>
+ return %1 : tensor<8x16xf32>
+}
+
+// MoveLastSplitAxisPattern with tgtTensorDim < srcTensorDim:
+// [[0], [1, 2]] -> [[2, 0], [1]] (axis 2 moved from dim 1 to front of dim 0)
+// Source shard: 8/2 x 16/(2*2) = 4x4; after all_to_all(axis=2): 2x8
+// CHECK-LABEL: func @move_last_split_axis_to_lower_dim
+func.func @move_last_split_axis_to_lower_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+ %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<4x4xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [2] split_axis = 0 concat_axis = 1 : tensor<4x4xf32> -> tensor<2x8xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<2x8xf32> to tensor<8x16xf32>
+ %s0 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_3d split_axes = [[2, 0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<8x16xf32>
+ return %1 : tensor<8x16xf32>
+}
+
+// MoveLastSplitAxisPattern where source has no axes at tgtTensorDim:
+// [[0, 1]] -> [[0], [1]] (tgtTensorDim has empty source)
+// Source shard: 8/(2*2) x 16 = 2x16; after all_to_all(axis=1): 4x8
+// CHECK-LABEL: func @move_last_split_axis_empty_source_at_target_dim
+func.func @move_last_split_axis_empty_source_at_target_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+ %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x16xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x16xf32> -> tensor<4x8xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x8xf32> to tensor<8x16xf32>
+ %s0 = shard.sharding @grid_3d split_axes = [[0, 1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_3d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<8x16xf32>
+ return %1 : tensor<8x16xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/189241
More information about the Mlir-commits
mailing list