[Mlir-commits] [mlir] [MLIR][Shard] Fix three bugs in ND mesh resharding in Partition pass (PR #189241)

Frank Schlimbach llvmlistbot at llvm.org
Mon Apr 13 08:49:20 PDT 2026


================
@@ -362,6 +363,130 @@ class MoveSplitAxisPattern : public ReshardingPattern {
   }
 };
 
+/// Move the last split axis of one tensor dimension to the front of another
+/// tensor dimension's split axes, e.g. [[0, 1], [2]] -> [[0], [1, 2]].
+class MoveLastSplitAxisPattern : public ReshardingPattern {
+  // Detect if the resharding moves the last grid axis of srcTensorDim to the
+  // front of another tensor dimension's split axes. If detected, returns
+  // (tgtTensorDim, movedGridAxis).
+  //
+  // Pattern: src[srcTensorDim] = [a1,...,a(n-1),an]  (n >= 2)
+  //          tgt[srcTensorDim] = [a1,...,a(n-1)]
+  //          src[tgtTensorDim] = [b1,...,bm]          (m >= 0)
+  //          tgt[tgtTensorDim] = [an, b1,...,bm]
+  static std::optional<std::tuple<int64_t, GridAxis>>
+  detect(const Sharding &srcSharding, const Sharding &tgtSharding,
+         int64_t srcTensorDim) {
+    if (static_cast<size_t>(srcTensorDim) >= srcSharding.getSplitAxes().size())
+      return std::nullopt;
+    auto srcAxes = srcSharding.getSplitAxes()[srcTensorDim].asArrayRef();
+    // Need at least 2 axes to move the last one.
+    if (srcAxes.size() < 2)
+      return std::nullopt;
+
+    // After the move the source tensor dim should lose its last axis.
+    if (static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
+      return std::nullopt;
----------------
fschlimb wrote:

This check only verifies that the split dim of the source is not past the last split dim in the target. If you want to check what the comment says, you have to do something like
```suggestion
    // The src dim must have splitAxes in the target sharding
    // (as we require at least 2 axes in src)
    if (static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
      return std::nullopt;
    // After the move the source tensor dim should lose its last axis.
    if (srcAxes.size() != tgtSharding.getSplitAxes()[srcTensorDim].size() + 1)
      return std::nullopt;
```
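
To make the intent concrete, here is a rough standalone sketch of the "loses its last axis" condition from the pattern comment (tgt[srcTensorDim] = [a1,...,a(n-1)]). It uses plain `std::vector` rather than the MLIR `Sharding` type; `SplitAxes` and `losesLastAxis` are hypothetical names for illustration only, not part of the MLIR API, and the `int16_t` element type is an assumption.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a sharding's per-tensor-dim split axes.
using SplitAxes = std::vector<std::vector<int16_t>>;

// Returns true if tgt[dim] is exactly src[dim] with its last axis dropped.
bool losesLastAxis(const SplitAxes &src, const SplitAxes &tgt, size_t dim) {
  // Both shardings must have an entry for this tensor dim.
  if (dim >= src.size() || dim >= tgt.size())
    return false;
  const auto &s = src[dim];
  const auto &t = tgt[dim];
  // Need at least 2 source axes, and the target must have exactly one fewer.
  if (s.size() < 2 || t.size() + 1 != s.size())
    return false;
  // The remaining axes must be unchanged and in the same order.
  for (size_t i = 0; i < t.size(); ++i)
    if (s[i] != t[i])
      return false;
  return true;
}
```

With src = [[0, 1], [2]] and tgt = [[0], [1, 2]], this holds for tensor dim 0 and fails for tensor dim 1 (which has only one source axis).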


https://github.com/llvm/llvm-project/pull/189241

