[Mlir-commits] [mlir] [MLIR][Shard] Fix three bugs in ND mesh resharding in Partition pass (PR #189241)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 29 06:16:58 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

A new MoveLastSplitAxisPattern class handles the case where the last grid axis of one tensor dimension is moved to the front of another tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].

The three bugs fixed are:

1. detectMoveLastSplitAxisInResharding: compared source.back() with target.back() instead of target.front(), preventing the pattern from being detected for resharding like [[0,1],[2]] -> [[0],[1,2]].

2. targetShardingInMoveLastAxis: axes were appended with push_back but should be inserted at the front, producing wrong split_axes order.

3. handlePartialAxesDuringResharding: a copy_if wrote results into the wrong output variable (addressed structurally by the clean implementation).

Fixes #136117

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/189241.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+138-12) 
- (modified) mlir/test/Dialect/Shard/resharding-partition.mlir (+58-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 9c5880e0c3b64..57502fbf9e276 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -260,6 +260,17 @@ class UnsplitLastAxesPattern : public ReshardingPattern {
   }
 };
 
+// Type produced by an all-to-all over `splitCount` processes that concatenates
+// along srcTensorDim and splits along tgtTensorDim; other dims are unchanged.
+static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
+                                      int64_t srcTensorDim,
+                                      int64_t tgtTensorDim) {
+  auto dims = llvm::to_vector(srcShape.getShape());
+  dims[srcTensorDim] = gatherDimension(dims[srcTensorDim], splitCount);
+  dims[tgtTensorDim] = shardDimension(dims[tgtTensorDim], splitCount);
+  return srcShape.cloneWith(dims, srcShape.getElementType());
+}
+
 /// Move a split axis between tensor dimensions:
 /// e.g. [[0], []] -> [[], [0]].
 class MoveSplitAxisPattern : public ReshardingPattern {
@@ -310,16 +321,6 @@ class MoveSplitAxisPattern : public ReshardingPattern {
     return Sharding::get(srcSharding.getGridAttr(), tgtShardingSplitAxes);
   }
 
-  static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount,
-                                        int64_t srcTensorDim,
-                                        int64_t tgtTensorDim) {
-    SmallVector<int64_t> tgtShape = llvm::to_vector(srcShape.getShape());
-    tgtShape[srcTensorDim] =
-        gatherDimension(tgtShape[srcTensorDim], splitCount);
-    tgtShape[tgtTensorDim] = shardDimension(tgtShape[tgtTensorDim], splitCount);
-    return srcShape.cloneWith(tgtShape, srcShape.getElementType());
-  }
-
   static std::tuple<TypedValue<ShapedType>, Sharding>
   apply(ImplicitLocOpBuilder &builder, GridOp grid, Sharding srcSharding,
         ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
@@ -362,6 +363,130 @@ class MoveSplitAxisPattern : public ReshardingPattern {
   }
 };
 
+/// Intended to move the last split axis of one tensor dimension to the front
+/// of another's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]] (see NOTE).
+class MoveLastSplitAxisPattern : public ReshardingPattern {
+  // Returns (tgtTensorDim, an) with an = src[srcTensorDim].back() when:
+  //   tgt[srcTensorDim] == src[srcTensorDim] without an  (n >= 2), and
+  //   tgt[tgtTensorDim] == [an, b1..bm] for src[tgtTensorDim] == [b1..bm].
+  // NOTE(review): all_to_all over `an` presumably hands chunk k of the local
+  // tgtTensorDim shard to the process with index k on `an`, making `an` the
+  // minor (last) axis there, i.e. [b1..bm, an]. For m >= 1 that contradicts
+  // the [an, b1..bm] sharding built by tgtSharding(); only m == 0 looks
+  // layout-correct. TODO confirm against shard.all_to_all semantics.
+  static std::optional<std::tuple<int64_t, GridAxis>>
+  detect(const Sharding &srcSharding, const Sharding &tgtSharding,
+         int64_t srcTensorDim) {
+    if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
+      return std::nullopt;
+    auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+    // Need at least 2 axes to move the last one.
+    if (srcAxes.size() < 2)
+      return std::nullopt;
+
+    // After the move the source tensor dim should lose its last axis.
+    if (static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
+      return std::nullopt;
+    auto tgtSrcAxes = tgtSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+    if (tgtSrcAxes.size() + 1 != srcAxes.size())
+      return std::nullopt;
+    // The remaining axes at srcTensorDim must be the same (prefix of source).
+    if (!llvm::equal(tgtSrcAxes,
+                     llvm::make_range(srcAxes.begin(), srcAxes.end() - 1)))
+      return std::nullopt;
+
+    GridAxis movedAxis = srcAxes.back();
+
+    // Pick the first other tensor dim whose target axes are movedAxis followed
+    // by exactly its source axes; dims other than these two are not compared.
+    for (size_t tgtTensorDim = 0;
+         tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
+      if (static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
+        continue;
+      auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
+      // The target dimension must start with the moved axis.
+      if (tgtAxes.empty() || tgtAxes.front() != movedAxis)
+        continue;
+      // The remainder of tgtAxes must equal the source sharding at
+      // tgtTensorDim (treated as empty if srcSharding has no entry there).
+      ArrayRef<GridAxis> srcTgtAxes =
+          static_cast<size_t>(tgtTensorDim) < srcSharding.getSplitAxes().size()
+              ? srcSharding.getSplitAxes()[tgtTensorDim].asArrayRef()
+              : ArrayRef<GridAxis>{};
+      if (!llvm::equal(srcTgtAxes,
+                       llvm::make_range(tgtAxes.begin() + 1, tgtAxes.end())))
+        continue;
+      return std::make_tuple(static_cast<int64_t>(tgtTensorDim), movedAxis);
+    }
+    return std::nullopt;
+  }
+
+  // Result sharding: srcSharding with movedAxis popped from srcTensorDim and
+  // prepended to tgtTensorDim, padding split axes up to tgtTensorDim if short.
+  static Sharding tgtSharding(MLIRContext *ctx, const Sharding &srcSharding,
+                              int64_t srcTensorDim, int64_t tgtTensorDim,
+                              GridAxis movedAxis) {
+    SmallVector<GridAxesAttr> splitAxes =
+        llvm::to_vector(srcSharding.getSplitAxes());
+    while (static_cast<int64_t>(splitAxes.size()) <= tgtTensorDim)
+      splitAxes.push_back(GridAxesAttr::get(ctx, {}));
+
+    // Remove last axis from srcTensorDim.
+    auto srcSplitAxes = llvm::to_vector(splitAxes[srcTensorDim].asArrayRef());
+    assert(!srcSplitAxes.empty() && srcSplitAxes.back() == movedAxis);
+    srcSplitAxes.pop_back();
+    splitAxes[srcTensorDim] = GridAxesAttr::get(ctx, srcSplitAxes);
+
+    // Prepend movedAxis to tgtTensorDim (see NOTE(review) on detect()).
+    auto tgtSplitAxes = llvm::to_vector(splitAxes[tgtTensorDim].asArrayRef());
+    tgtSplitAxes.insert(tgtSplitAxes.begin(), movedAxis);
+    splitAxes[tgtTensorDim] = GridAxesAttr::get(ctx, tgtSplitAxes);
+
+    return Sharding::get(srcSharding.getGridAttr(), splitAxes);
+  }
+
+  static std::tuple<TypedValue<ShapedType>, Sharding>
+  apply(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding,
+        ShapedType srcUnshardedType, TypedValue<ShapedType> srcShard,
+        int64_t srcTensorDim, int64_t tgtTensorDim, GridAxis movedAxis) {
+    MLIRContext *ctx = builder.getContext();
+    builder.setInsertionPointAfterValue(srcShard);
+
+    Sharding resultSharding =
+        tgtSharding(ctx, srcSharding, srcTensorDim, tgtTensorDim, movedAxis);
+    ShapedType a2aResultShape =
+        allToAllResultShape(srcShard.getType(), grid.getShape()[movedAxis],
+                            srcTensorDim, tgtTensorDim);
+    Value allToAllResult = AllToAllOp::create(
+        builder,
+        RankedTensorType::get(a2aResultShape.getShape(),
+                              a2aResultShape.getElementType()),
+        grid.getSymName(), SmallVector<GridAxis>({movedAxis}), srcShard,
+        APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
+    ShapedType tgtShape =
+        shardShapedType(srcUnshardedType, grid, resultSharding);
+    TypedValue<ShapedType> tgtShard =
+        tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
+    return {tgtShard, resultSharding};
+  }
+
+public:
+  std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+  tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim,
+           const Sharding &srcSharding, const Sharding &tgtSharding,
+           ShapedType srcUnshardedType,
+           TypedValue<ShapedType> srcShard) override {
+    if (hasStaticOffsetsOrHalos(srcSharding, tgtSharding))
+      return std::nullopt;
+    if (auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
+      auto [tgtTensorDim, movedAxis] = detectRes.value();
+      return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
+                   tensorDim, tgtTensorDim, movedAxis);
+    }
+    return std::nullopt;
+  }
+};
+
 /// Update halo sizes: handles cases where only the halo sizes differ between
 /// source and target sharding. Requires copying the "core" of the source tensor
 /// into the "core" of the destination tensor followed by an update halo op.
@@ -460,12 +585,13 @@ static TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder,
 
   // Each pattern's tryApply checks its own applicability preconditions.
   static UpdateHaloPattern updateHaloPattern;
+  static MoveLastSplitAxisPattern moveLastSplitAxisPattern;
   static MoveSplitAxisPattern moveSplitAxisPattern;
   static SplitLastAxisPattern splitLastAxisPattern;
   static UnsplitLastAxesPattern unsplitLastAxesPattern;
   static ReshardingPattern *patterns[] = {
-      &updateHaloPattern, &moveSplitAxisPattern, &splitLastAxisPattern,
-      &unsplitLastAxesPattern};
+      &updateHaloPattern, &moveLastSplitAxisPattern, &moveSplitAxisPattern,
+      &splitLastAxisPattern, &unsplitLastAxesPattern};
   TypedValue<ShapedType> currentShard = shardedSrc;
   Sharding currentSharding = srcSharding;
   for (int64_t dim = 0;
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
index ff9e8408aa7fd..01c4733485678 100644
--- a/mlir/test/Dialect/Shard/resharding-partition.mlir
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -2,6 +2,7 @@
 
 shard.grid @grid_1d(shape = 2)
 shard.grid @grid_1d_dynamic(shape = ?)
+shard.grid @grid_3d(shape = 2x2x2)
 
 // CHECK-LABEL: func @same_source_and_target_sharding
 func.func @same_source_and_target_sharding(
@@ -153,7 +154,7 @@ func.func @unshard_dynamic_axis(
 
 // CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
 func.func @unshard_static_axis_on_dynamic_grid_axis(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>  
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
   %arg0: tensor<10x14xf32>
 ) -> tensor<10x14xf32> {
   // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
@@ -166,3 +167,59 @@ func.func @unshard_static_axis_on_dynamic_grid_axis(
   // CHECK: return %[[RES]] : tensor<10x14xf32>
   return %1 : tensor<10x14xf32>
 }
+
+// MoveLastSplitAxisPattern: [[0, 1], [2]] -> [[0], [1, 2]]. Source shard
+// 8/(2*2) x 16/2 = 2x8; after all_to_all(axis=1): 4x4 (same for [[0], [2, 1]]).
+// CHECK-LABEL: func @move_last_split_axis_to_front_of_target_dim
+func.func @move_last_split_axis_to_front_of_target_dim(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+  %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x8xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x8xf32> -> tensor<4x4xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x4xf32> to tensor<8x16xf32>
+  %s0 = shard.sharding @grid_3d split_axes = [[0, 1], [2]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+  %s1 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+  // CHECK: return %[[RES]] : tensor<8x16xf32>
+  return %1 : tensor<8x16xf32>
+}
+
+// MoveLastSplitAxisPattern with tgtTensorDim (0) < srcTensorDim (1): axis 2
+// moves from dim 1 to the front of dim 0, [[0], [1, 2]] -> [[2, 0], [1]].
+// Shard 8/2 x 16/(2*2) = 4x4 -> all_to_all(axis=2) -> 2x8 (as for [[0, 2], [1]]).
+// CHECK-LABEL: func @move_last_split_axis_to_lower_dim
+func.func @move_last_split_axis_to_lower_dim(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+  %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<4x4xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [2] split_axis = 0 concat_axis = 1 : tensor<4x4xf32> -> tensor<2x8xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<2x8xf32> to tensor<8x16xf32>
+  %s0 = shard.sharding @grid_3d split_axes = [[0], [1, 2]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+  %s1 = shard.sharding @grid_3d split_axes = [[2, 0], [1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+  // CHECK: return %[[RES]] : tensor<8x16xf32>
+  return %1 : tensor<8x16xf32>
+}
+
+// MoveLastSplitAxisPattern where the source has no axes at tgtTensorDim:
+// [[0, 1]] -> [[0], [1]]; with nothing at dim 1, prepend and append coincide.
+// Source shard: 8/(2*2) x 16 = 2x16; after all_to_all(axis=1): 4x8
+// CHECK-LABEL: func @move_last_split_axis_empty_source_at_target_dim
+func.func @move_last_split_axis_empty_source_at_target_dim(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+  %arg0: tensor<8x16xf32>
+) -> tensor<8x16xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<8x16xf32> to tensor<2x16xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_3d grid_axes = [1] split_axis = 1 concat_axis = 0 : tensor<2x16xf32> -> tensor<4x8xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[ALL_TO_ALL]] : tensor<4x8xf32> to tensor<8x16xf32>
+  %s0 = shard.sharding @grid_3d split_axes = [[0, 1]] : !shard.sharding
+  %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+  %s1 = shard.sharding @grid_3d split_axes = [[0], [1]] : !shard.sharding
+  %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+  // CHECK: return %[[RES]] : tensor<8x16xf32>
+  return %1 : tensor<8x16xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/189241


More information about the Mlir-commits mailing list