[Mlir-commits] [mlir] 1a8fb88 - [mlir][mesh] Add resharding spmdization on a 1D device mesh (#76179)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 2 15:50:12 PST 2024


Author: Boian Petkantchin
Date: 2024-01-02T15:50:07-08:00
New Revision: 1a8fb887197caf709710bedf88ce95ffb0605c56

URL: https://github.com/llvm/llvm-project/commit/1a8fb887197caf709710bedf88ce95ffb0605c56
DIFF: https://github.com/llvm/llvm-project/commit/1a8fb887197caf709710bedf88ce95ffb0605c56.diff

LOG: [mlir][mesh] Add resharding spmdization on a 1D device mesh (#76179)

The current implementation supports only sharding of tensor axes that
have size divisible by the mesh axis size.

Added: 
    mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
    mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
    mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
    mlir/test/Dialect/Mesh/resharding-spmdization.mlir
    mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
    mlir/test/Dialect/Mesh/invalid.mlir
    mlir/test/Dialect/Mesh/ops.mlir
    mlir/test/lib/Dialect/Mesh/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
+
+// A single mesh axis index, stored as a 16-bit integer (mirrors
+// `mlir::mesh::MeshAxis = int16_t` on the C++ side).
+def Mesh_MeshAxis : I<16>;
+// A list of mesh axis indices, stored as a dense i16 array attribute
+// (`mlir::mesh::MeshAxesAttr = DenseI16ArrayAttr` on the C++ side).
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
 //===----------------------------------------------------------------------===//
 // Mesh Enums.
 //===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     }]>
   ];
 
+  let extraClassDeclaration = [{
+    bool operator==(::mlir::Attribute rhs) const;
+    bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+  }];
+
   let genVerifyDecl = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
 namespace mlir {
 namespace mesh {
 
+// Index of a mesh axis. Ops referencing mesh axes require them to be
+// unique and in [0, rank of the mesh) (see `verifyMeshAxes` in MeshOps.cpp).
+using MeshAxis = int16_t;
+// Attribute type holding a list of mesh axis indices.
+using MeshAxesAttr = DenseI16ArrayAttr;
+
 bool isReductionLoop(IteratorType iType);
 
 bool areReductionAndPartialMatch(IteratorType iType, Partial partial);

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_TD
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the shape of the cluster.";
+  let description = [{
+    Returns the size of the requested mesh axes of `mesh`, one `index` result
+    per axis in `axes`. If `axes` is empty, one result per mesh axis is
+    returned.
+    The `axes` must be unique, non-negative and less than the total number of
+    mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+
+  let results = (outs
+    Variadic<Index>:$result
+  );
+
+  let assemblyFormat = [{
+    $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
   let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the index of current device along specified mesh axis.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    Returns one `index` result per axis in `axes`; if `axes` is empty, one
+    result per mesh axis is returned.
+    The `axes` must be unique, non-negative and less than the total number of
+    mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
       [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
   );
 }
 

diff  --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
new file mode 100644
index 00000000000000..6368931cf6e075
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
@@ -0,0 +1,683 @@
+# Resharding Spmdization Examples
+
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1, 0]]` on a `2x3` mesh.
+
+unsharded `2x3` tensor
+```
+11 12 13
+21 22 23
+```
+
+sharded on a `2x3` mesh
+
+sharding = `[[0, 1]]`
+
+mesh contents:
+
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+```
+
+Transform into
+sharding = `[[1, 0]]`
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 |             |
++----+----+----+             |
+| 12 | 21 | 23 |             |
++----+----+----+             ↓
+```
+Algorithm:
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.
+
+unsharded `2x3` tensor
+```
+11 12 13
+21 22 23
+```
+
+sharded on a `2x3` mesh
+
+sharding = `[[0, 1]]`
+
+mesh contents:
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+```
+
+Transform into
+sharding = `[[1]]`
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+```
+Algorithm:
+All-gather along mesh axis 0.
+
+--------------------------------------------------------------
+
+Reshard `2x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.
+
+unsharded `2x6` tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+```
+
+sharded on a `2x3` mesh
+
+sharding = `[[], [0, 1]]`
+
+mesh contents:
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 14 | 15 | 16 |             |
+| 24 | 25 | 26 |             |
++----+----+----+             ↓
+```
+Transform into
+sharding = `[[], [0]]`
+```
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 |             |
+| 21 22 23 | 21 22 23 |             |
++----------+----------+             |
+| 14 15 16 | 14 15 16 |             |
+| 24 25 26 | 24 25 26 |             |
++----------+----------+             ↓
+```
+Algorithm:
+All-gather along mesh axis 1.
+
+--------------------------------------------------------------
+
+Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.
+
+unsharded `4x8` tensor
+```
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+```
+sharded on a `2x2x2` mesh
+
+sharding = `[[0], [1, 2]]`
+
+mesh contents:
+```
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 |             |             |
+| 21 22 | 23 24 |             |             |
++-------+-------+             |             |
+| 15 16 | 17 18 |             |             |
+| 25 26 | 27 28 |             |             |
++-------+-------+             ↓             |
++-------+-------+                           |
+| 31 32 | 33 34 |                           |
+| 41 42 | 43 44 |                           |
++-------+-------+                           |
+| 35 36 | 37 38 |                           |
+| 45 46 | 47 48 |                           |
++-------+-------+                           ↓
+```
+Transform into
+sharding = `[[0], [2]]`
+```
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             |             |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             ↓             |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           ↓
+```
+Algorithm:
+
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations
+`[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`
+
+--------------------------------------------------------------
+
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+
+unsharded `6x6` tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+```
+sharded on a `2x3` mesh
+
+sharding = `[[0], [1]]`
+```
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 |             |
+| 21 22 | 23 24 | 25 26 |             |
+| 31 32 | 33 34 | 35 36 |             |
++-------+-------+-------+             |
+| 41 42 | 43 44 | 45 46 |             |
+| 51 52 | 53 54 | 55 56 |             |
+| 61 62 | 63 64 | 65 66 |             |
++-------+-------+-------+             ↓
+```
+transform to
+sharding = `[[1], [0]]`
+```
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 |             |
+| 21 22 23 | 41 42 43 | 61 62 63 |             |
++----------+----------+----------+             |
+| 14 15 16 | 34 35 36 | 54 55 56 |             |
+| 24 25 26 | 44 45 46 | 64 65 66 |             |
++----------+----------+----------+             ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+```
+Algorithm: TODO
+
+--------------------------------------------------------------
+
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.
+
+unsharded 6x6 tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+```
+shard on `2x6` mesh
+
+sharding = `[[0], [1]]`
+```
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
++----+----+----+----+----+----+             |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
++----+----+----+----+----+----+             ↓
+```
+transform to
+sharding = `[[1], [0]]`
+```
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
++----------+----------+             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
++==========+==========+             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
++----------+----------+             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+```
+Algorithm: TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` mesh.
+
+`M x N` mesh.
+`K x L` tensor `t`.
+`d(m, n)` the tensor on device `(m, n)`.
+
+sharding = `[[0], [1]]`
+Tensor shard s on each device has size `(K ceildiv M, L ceildiv N)`.
+```
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+```
+substitute
+```
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+```
+```
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+```
+For the inverse map we get
+```
+t[i, j] -> d(
+    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+    i % (K ceildiv M), j % (L ceildiv N)
+]
+```
+Check:
+```
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+    13 floordiv (16 ceildiv 3),
+    17 floordiv (23 ceildiv 4)
+)[
+    13 % (16 ceildiv 3),
+    17 % (23 ceildiv 4)
+]
+= d(
+    13 floordiv 6,
+    17 floordiv 6
+)[
+    13 % 6,
+    17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+    2 * (16 ceildiv 3) + 1,
+    2 * (23 ceildiv 4) + 5
+]
+= t[
+    2 * 6 + 1,
+    2 * 6 + 5
+]
+= t[13, 17]
+```
+
+sharding = `[[1], [0]]`
+Tensor shard s on each device has size `(K ceildiv N, L ceildiv M)`.
+```
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+```
+substitute
+```
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+```
+```
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+```
+For the inverse map we get
+```
+t[i, j] -> d(
+    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+    i % (K ceildiv N), j % (L ceildiv M)
+]
+```
+Check:
+```
+i = 9, j = 10, M = 5, N = 2, K = 27, L = 14
+t[9, 10] = d(
+    10 floordiv (14 ceildiv 5),
+    9 floordiv (27 ceildiv 2)
+)[
+    9 % (27 ceildiv 2),
+    10 % (14 ceildiv 5)
+]
+= d(
+    10 floordiv 3,
+    9 floordiv 14
+)[
+    9 % 14,
+    10 % 3
+]
+= d(3, 0)[9, 1]
+= t[
+    0 * (27 ceildiv 2) + 9,
+    3 * (14 ceildiv 5) + 1
+]
+= t[
+    0 * 14 + 9,
+    3 * 3 + 1
+]
+= t[9, 10]
+```
+sharding = `[[0], [1]]`
+```
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+```
+sharding = `[[1], [0]]`
+```
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+```
+sharding `[[0], [1]] -> [[1], [0]]`
+`d1(m, n)` is the tensor on device `(m, n)` for sharding `[[0], [1]]`.
+`d2(m, n)` is the tensor on device `(m, n)` for sharding `[[1], [0]]`.
+```
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+    (m * (L ceildiv M) + l) floordiv (L ceildiv M),
+    (n * (K ceildiv N) + k) floordiv (K ceildiv N)
+)[
+    (n * (K ceildiv N) + k) % (K ceildiv N),
+    (m * (L ceildiv M) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+```
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed `(m, n, p, q)` what is the range of `(k, l, u, v)`?
+TODO
+
+--------------------------------------------------------------
+
+Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+
+Device placement on a `2x3` mesh
+```
+11 12 13  <- devices
+21 22 23
+```
+sharding `[[0], [1]]`
+```
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 |               |
++----+----+----+               |
+| 21 | 22 | 23 |               |
++----+----+----+               ↓
+```
+transform to
+sharding `[[1], [0]]`
+```
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 |               |
++----+----+               |
+| 12 | 22 |               |
++----+----+               |
+| 13 | 23 |               |
++----+----+               ↓
+```
+```
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       11        |        12       |        13       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       21        |        22       |        23       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          11              +               21         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          12              +               22         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          13              +               23         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     11  11      + 12  11 + 12  21 +     13  21      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+|     11  12      | 12  12 | 12  22 |     13  22      |
++-----------------+--------+--------+-----------------+
+|     21  12      | 22  12 | 22  22 |     23  22      |
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     21  13      + 22  13 + 22  23 +     23  23      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+```
+If `S` and `T` are the source and target shard sizes along some tensor axis,
+then the cut pattern repeats with a period of `(S*T)/gcd(S, T)` (i.e. `lcm(S, T)`).
+TODO
+
+--------------------------------------------------------------
+
+Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.
+
+unsharded `6x6` tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+```
+sharded on a `3` mesh
+
+sharding = `[[0], []]`
+```
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 |             |
+| 21 22 23 24 25 26 |             |
++-------------------+             |
+| 31 32 33 34 35 36 |             |
+| 41 42 43 44 45 46 |             |
++-------------------+             |
+| 51 52 53 54 55 56 |             |
+| 61 62 63 64 65 66 |             |
++-------------------+             ↓
+```
+transform to
+sharding = `[[], [0]]`
+```
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+```
+Algorithm:
+```mlir
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+```
+--------------------------------------------------------------
+
+Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.
+
+unsharded `4x4` tensor
+```
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+```
+sharded on a `2x2x2` mesh
+
+sharding = `[[0], [1, 2]]`
+```
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 |             |             |
+| 21 | 22 |             |             |
++----+----+             |             |
+| 13 | 14 |             |             |
+| 23 | 24 |             |             |
++----+----+             ↓             |
++----+----+                           |
+| 31 | 32 |                           |
+| 41 | 42 |                           |
++----+----+                           |
+| 33 | 34 |                           |
+| 43 | 44 |                           |
++----+----+                           ↓
+```
+transform to
+sharding = `[[0, 1], [2]]`
+```
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 |             |             |
++-------+-------+             |             |
+| 21 22 | 23 24 |             |             |
++-------+-------+             ↓             |
++-------+-------+                           |
+| 31 32 | 33 34 |                           |
++-------+-------+                           |
+| 41 42 | 43 44 |                           |
++-------+-------+                           ↓
+```
+Algorithm:
+```mlir
+%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
+```
+is not enough.
+
+Can be decomposed into
+```
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
+```
+
+## Decomposition into basis of reshardings
+
+We can decompose each resharding into a sequence of basis reshardings.
+It is not communication efficient in terms of minimizing the data communicated
+between devices.
+An efficient approach would be more complicated to implement.
+Each device has to receive at most as much data as the size of its target
+sharding tensor.
+
+--------------------------------------------------------------
+
+Basis:
+
+*   From replicate to split.
+    ```
+    [[]] -> [[1]]
+    ```
+    Extract slices without communication.
+
+*   From split to replicate.
+    ```
+    [[0]] -> [[]]
+    [[0, 1]] -> [[1]]
+    ```
+    All-gather along mesh axis 0.
+
+*   Swap mesh axes order when assigned to the same tensor axis.
+    ```
+    [[0, 1]] -> [[1, 0]]
+    ```
+    Swap contents on devices with the same linear index.
+
+*   Move mesh axis to different tensor dimension.
+    ```
+    [[0], []] -> [[], [0]]
+    ```
+    All-to-all.
+
+--------------------------------------------------------------
+
+Example decomposition of
+```
+[[0], [1]] -> [[1], [0]]
+```
+into
+```
+[[0], [1]] -> all-gather along mesh axis 1    ->
+[[0], []]  -> all-to-all along mesh axis 0    ->
+[[], [0]]  -> extract slice along mesh axis 1 ->
+[[1], [0]]
+```
+
+--------------------------------------------------------------
+
+Example decomposition of
+```
+[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
+```
+into
+```
+[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
+[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
+[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
+[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
+[[0], [1, 2], []]
+```

diff  --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
new file mode 100644
index 00000000000000..f71bb9b262a380
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -0,0 +1,35 @@
+//===- Spmdization.h - Mesh Spmdization -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace mesh {
+
+// Return the type `shape` sharded according to sharding `sharding` on the
+// device mesh `mesh`.
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+                           MeshShardingAttr sharding);
+
+// Insert resharding spmdization of the value `sourceShardValue`
+// from sharding `source` to sharding `target`.
+// `sourceShardValue` is the already sharded value according to `source`.
+// Currently only tensor axes whose size is divisible by the corresponding
+// mesh axis size are supported.
+// NOTE(review): the result is presumably the value sharded according to
+// `target` — confirm against Spmdization.cpp.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+                               ShardOp source, ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue);
+
+// NOTE(review): presumably registers in `registry` the dialects whose ops
+// the resharding may create — confirm against Spmdization.cpp.
+void reshardingRegisterDependentDialects(DialectRegistry &registry);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H

diff  --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d27675564c6464..de4f58d54e8ca5 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -9,8 +9,10 @@
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
@@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
   return vec;
 }
 
-using MeshAxis = int16_t;
-
 namespace {
 
 struct DimensionSize {
@@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    SymbolTableCollection &symbolTable) {
+  // Resolve `meshSymbol` starting from the closest symbol table around `op`.
+  if (auto found = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+          op, meshSymbol))
+    return found;
+
+  // Unresolved: report the error on `op` and propagate the failure.
+  return op->emitError() << "Undefined required mesh symbol \""
+                         << meshSymbol.getValue() << "\".";
+}
+
+// Returns true if no two adjacent elements of [begin, end) compare equal.
+// This detects all duplicates only when the range is sorted, which is how
+// `verifyMeshAxes` uses it. File-local, hence `static` (internal linkage).
+template <typename It>
+static bool isUnique(It begin, It end) {
+  if (begin == end)
+    return true;
+  for (It next = std::next(begin); next != end; ++begin, ++next) {
+    if (*begin == *next)
+      return false;
+  }
+  return true;
+}
+
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+                                    ClusterOp mesh) {
+  // Duplicate check: sort a copy so that equal axes end up adjacent.
+  SmallVector<MeshAxis> sortedAxes(axes.begin(), axes.end());
+  llvm::sort(sortedAxes);
+  if (!isUnique(sortedAxes.begin(), sortedAxes.end()))
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+
+  // Range check: each axis must name an existing dimension of `mesh`.
+  const MeshAxis rank = mesh.getRank();
+  auto outOfBounds = llvm::find_if(
+      axes, [rank](MeshAxis axis) { return axis >= rank || axis < 0; });
+  if (outOfBounds == axes.end())
+    return success();
+
+  return emitError(loc) << "0-based mesh axis index " << *outOfBounds
+                        << " is out of bounds. The referenced mesh \""
+                        << mesh.getSymName() << "\" is of rank " << rank
+                        << ".";
+}
+
 bool mesh::isReductionLoop(IteratorType iType) {
   return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
 }
@@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.shard op
+// mesh.cluster_shape op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // The mesh must resolve and the requested axes must be valid for it.
+  FailureOr<ClusterOp> mesh =
+      ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh) || failed(verifyMeshAxes(getLoc(), getAxes(), *mesh)))
+    return failure();
+
+  // One result per requested axis; no axes means every axis of the mesh.
+  ArrayRef<MeshAxis> axes = getAxes();
+  size_t expectedResultsCount = axes.empty() ? mesh->getRank() : axes.size();
+  size_t actualResultsCount = getResult().size();
+  if (actualResultsCount == expectedResultsCount)
+    return success();
+
+  return emitError() << "Unexpected number of results " << actualResultsCount
+                     << ". Expected " << expectedResultsCount << ".";
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           ClusterOp mesh) {
+  // Query all axes of `mesh`: one index result per mesh axis and a null
+  // (unspecified) axes attribute.
+  SmallVector<Type> resultTypes(mesh.getRank(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resultTypes, mesh.getSymName(), MeshAxesAttr());
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+  // One index result per explicitly requested axis.
+  // NOTE(review): an empty `axes` produces zero results here, while the
+  // verifier then expects one result per mesh axis; use the `ClusterOp`
+  // overload to query all axes.
+  SmallVector<Type> resultTypes(axes.size(), odsBuilder.getIndexType());
+  MeshAxesAttr axesAttr = MeshAxesAttr::get(odsBuilder.getContext(), axes);
+  build(odsBuilder, odsState, resultTypes, mesh, axesAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
 //===----------------------------------------------------------------------===//
 
 LogicalResult
@@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+// Returns true iff `rhs` is a MeshShardingAttr that compares equal to this one
+// under the MeshShardingAttr overload. Any other attribute kind, or a null
+// attribute, compares unequal.
+bool MeshShardingAttr::operator==(Attribute rhs) const {
+  // `dyn_cast_if_present` tolerates a null `rhs` instead of asserting, and
+  // replaces the deprecated member-function cast.
+  auto rhsAsMeshShardingAttr =
+      llvm::dyn_cast_if_present<MeshShardingAttr>(rhs);
+  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+}
+
+// Semantic equality of two shardings. The cluster and the partial axes must
+// match exactly. The partial type is compared only when there are partial
+// axes. Split axes are compared over their common prefix, and any trailing
+// entries present on only one side must be empty (unsplit tensor axes), so
+// e.g. [[0]] equals [[0], []].
+bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
+  if (getCluster() != rhs.getCluster())
+    return false;
+  if (getPartialAxes() != rhs.getPartialAxes())
+    return false;
+  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType())
+    return false;
+
+  auto lhsSplit = getSplitAxes();
+  auto rhsSplit = rhs.getSplitAxes();
+  size_t common = std::min(lhsSplit.size(), rhsSplit.size());
+  for (size_t i = 0; i < common; ++i) {
+    if (lhsSplit[i] != rhsSplit[i])
+      return false;
+  }
+
+  // True iff every split-axes entry at index >= `common` is empty.
+  auto tailIsEmpty = [common](auto splitAxes) {
+    for (size_t i = common; i < splitAxes.size(); ++i) {
+      if (!splitAxes[i].empty())
+        return false;
+    }
+    return true;
+  };
+  return tailIsEmpty(lhsSplit) && tailIsEmpty(rhsSplit);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.process_index op
+//===----------------------------------------------------------------------===//
+
+// Verifies that the referenced mesh exists, that the requested axes are valid
+// for it, and that the op yields exactly one result per queried axis (one per
+// mesh dimension when no axes are given).
+LogicalResult
+ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // `getMesh` emits its own diagnostic when the symbol cannot be resolved.
+  FailureOr<ClusterOp> maybeMesh =
+      ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(maybeMesh))
+    return failure();
+  ClusterOp cluster = *maybeMesh;
+
+  if (failed(verifyMeshAxes(getLoc(), getAxes(), cluster)))
+    return failure();
+
+  // An empty axes list means "all mesh axes".
+  size_t expectedResultsCount =
+      getAxes().empty() ? cluster.getRank() : getAxes().size();
+  size_t actualResultsCount = getResult().size();
+  if (actualResultsCount == expectedResultsCount)
+    return success();
+
+  return emitError() << "Unexpected number of results " << actualResultsCount
+                     << ". Expected " << expectedResultsCount << ".";
+}
+
+// Builds a query of the process index along every dimension of `mesh`: one
+// `index` result per mesh dimension and a null axes attribute.
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           ClusterOp mesh) {
+  Type indexType = odsBuilder.getIndexType();
+  SmallVector<Type> resultTypes(mesh.getRank(), indexType);
+  build(odsBuilder, odsState, resultTypes, mesh.getSymName(), MeshAxesAttr());
+}
+
+// Builds a query of the process index along the given `axes` of the mesh
+// named `mesh`: one `index` result per requested axis.
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+  MeshAxesAttr axesAttr = MeshAxesAttr::get(odsBuilder.getContext(), axes);
+  SmallVector<Type> resultTypes(axes.size(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resultTypes, mesh, axesAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
   return success();
 }
 
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                    SymbolTableCollection &symbolTable) {
-  mesh::ClusterOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
-  if (!mesh) {
-    return op->emitError() << "Undefined required mesh symbol \""
-                           << meshSymbol.getValue() << "\".";
-  }
-
-  return mesh;
-}
-
-template <typename It>
-bool isUnique(It begin, It end) {
-  if (begin == end) {
-    return true;
-  }
-  It next = std::next(begin);
-  if (next == end) {
-    return true;
-  }
-  for (; next != end; ++next, ++begin) {
-    if (*begin == *next) {
-      return false;
-    }
-  }
-  return true;
-}
-
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
-                                    ClusterOp mesh) {
-  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
-  llvm::sort(sorted);
-  if (!isUnique(sorted.begin(), sorted.end())) {
-    return emitError(loc) << "Mesh axes contains duplicate elements.";
-  }
-
-  MeshAxis rank = mesh.getRank();
-  for (auto axis : axes) {
-    if (axis >= rank || axis < 0) {
-      return emitError(loc)
-             << "0-based mesh axis index " << axis
-             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
-             << "\" is of rank " << rank << ".";
-    }
-  }
-
-  return success();
-}
-
 template <typename Op>
 static FailureOr<ClusterOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 044b8672c8c60c..7a70c047ec9dce 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMeshTransforms
   Simplifications.cpp
   ShardingPropagation.cpp
+  Spmdization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,11 +12,13 @@ add_mlir_dialect_library(MLIRMeshTransforms
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRControlFlowDialect
   MLIRFuncDialect
   MLIRIR
   MLIRMeshDialect
   MLIRPass
   MLIRShardingInterface
   MLIRSupport
+  MLIRTensorDialect
   MLIRTosaShardingInterfaceImpl
 )

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
new file mode 100644
index 00000000000000..de8b3a98df998a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -0,0 +1,639 @@
+//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+
+namespace mlir {
+namespace mesh {
+
+// Size of one shard when a dimension of size `dim` is split evenly into
+// `shardCount` shards. A dynamic input gives a dynamic result. Static sizes
+// are asserted to divide evenly.
+int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  const bool anyDynamic =
+      ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount);
+  if (!anyDynamic) {
+    assert(dim % shardCount == 0);
+    return ceilDiv(dim, shardCount);
+  }
+  return ShapedType::kDynamic;
+}
+
+// Inverse of shardDimension: the full size of a dimension whose shards each
+// have size `dim`. A dynamic input gives a dynamic result.
+int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  const bool anyDynamic =
+      ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount);
+  return anyDynamic ? ShapedType::kDynamic : dim * shardCount;
+}
+
+// Number of shards a tensor axis is split into: the product of the sizes of
+// the mesh axes in `splitAxes`. Returns ShapedType::kDynamic as soon as one of
+// those mesh axes has a dynamic size.
+template <typename MeshShape, typename SplitAxes>
+int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
+  int64_t product = 1;
+  for (auto meshAxis : splitAxes) {
+    const int64_t axisSize = meshShape[meshAxis];
+    if (!ShapedType::isDynamic(axisSize)) {
+      product *= axisSize;
+      continue;
+    }
+    return ShapedType::kDynamic;
+  }
+  return product;
+}
+
+// Compute the shape for the tensor on each device in the mesh.
+// `outShape` must have room for at least as many elements as `inShape`. It is
+// first filled with `inShape`, then each tensor axis that has an entry in
+// `splitAxes` is divided by the product of the sizes of the listed mesh axes.
+// Tensor axes without an entry in `splitAxes` are left unsharded. Static sizes
+// must divide evenly (see shardDimension).
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1
+// would result in a shape for each shard of ?x2x?.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+          typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+                       const SplitAxes &splitAxes, OutShape &outShape) {
+  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+            llvm::adl_begin(outShape));
+  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+    outShape[tensorAxis] =
+        shardDimension(inShape[tensorAxis],
+                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+  }
+}
+
+// Returns `shape` with every dimension replaced by its per-device shard size
+// under `sharding` on `mesh`. The element type is preserved.
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+                           MeshShardingAttr sharding) {
+  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
+  ArrayRef<int64_t> unshardedDims = shape.getShape();
+  SmallVector<Dim> shardedDims(unshardedDims.size());
+  shardShape(unshardedDims, mesh.canonicalDimSizes(), sharding.getSplitAxes(),
+             shardedDims);
+  return shape.clone(shardedDims);
+}
+
+// True iff every axis in `targetAxes` is also contained in `sourceAxes`, i.e.
+// the target keeps only a subset of the source's partial axes.
+template <typename SourceAxes, typename TargetAxes>
+static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
+                                     const TargetAxes &targetAxes) {
+  for (const auto &targetAxis : targetAxes) {
+    if (!sourceAxes.contains(targetAxis))
+      return false;
+  }
+  return true;
+}
+
+// Return the reduced value and its corresponding sharding.
+// Source partial axes that are not partial in the target are all-reduced
+// (with the source's partial type). Partial axes shared by source and target
+// stay partial in the returned sharding. If there is nothing to reduce, the
+// inputs are returned unchanged.
+// Example:
+// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
+// targetSharding = <@mesh_1d, [[]]>
+// Then will apply all-reduce on the source value
+// and return it with the sharding <@mesh_1d, [[0]]>.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+handlePartialAxesDuringResharding(OpBuilder &builder,
+                                  MeshShardingAttr sourceSharding,
+                                  MeshShardingAttr targetSharding,
+                                  TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding.getPartialAxes().empty() &&
+      targetSharding.getPartialAxes().empty()) {
+    return {sourceShard, sourceSharding};
+  }
+  // The target may only keep partial axes that the source already has, and
+  // only with the same partial type.
+  assert(targetSharding.getPartialAxes().empty() ||
+         (!sourceSharding.getPartialAxes().empty() &&
+          sourceSharding.getPartialType() == targetSharding.getPartialType()));
+  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
+  using AxisSet = llvm::SmallDenseSet<Axis>;
+  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
+                                       targetSharding.getPartialAxes().end());
+  assert(arePartialAxesCompatible(
+      AxisSet(sourceSharding.getPartialAxes().begin(),
+              sourceSharding.getPartialAxes().end()),
+      targetShardingPartialAxesSet));
+
+  // Partition the source partial axes into those reduced now (not partial in
+  // the target) and those that stay partial. Iterate the source ArrayRef
+  // rather than a hash set so both lists keep the source's axis order.
+  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
+  llvm::SmallVector<int32_t> remainingPartialAxes;
+  for (Axis a : sourceSharding.getPartialAxes()) {
+    if (targetShardingPartialAxesSet.contains(a))
+      remainingPartialAxes.push_back(a);
+    else
+      allReduceMeshAxes.push_back(a);
+  }
+  if (allReduceMeshAxes.empty()) {
+    return {sourceShard, sourceSharding};
+  }
+
+  builder.setInsertionPointAfterValue(sourceShard);
+  TypedValue<ShapedType> resultValue =
+      builder
+          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
+                               sourceSharding.getCluster().getLeafReference(),
+                               allReduceMeshAxes, sourceShard,
+                               sourceSharding.getPartialType())
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
+
+  // Only the axes that are still partial in the target remain partial.
+  MeshShardingAttr resultSharding =
+      MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+                            sourceSharding.getSplitAxes(), remainingPartialAxes,
+                            sourceSharding.getPartialType());
+  return {resultValue, resultSharding};
+}
+
+// Returns `sourceSharding` with `splitMeshAxis` appended to the split axes of
+// tensor axis `splitTensorAxis`. If the split-axes list is too short to have
+// an entry for `splitTensorAxis`, it is first padded with empty entries.
+static MeshShardingAttr
+targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+                              int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  SmallVector<DenseI32ArrayAttr> splitAxesPerTensorAxis =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  if (static_cast<int64_t>(splitAxesPerTensorAxis.size()) <= splitTensorAxis)
+    splitAxesPerTensorAxis.resize(splitTensorAxis + 1,
+                                  DenseI32ArrayAttr::get(ctx, {}));
+
+  auto meshAxes =
+      llvm::to_vector(splitAxesPerTensorAxis[splitTensorAxis].asArrayRef());
+  meshAxes.push_back(splitMeshAxis);
+  splitAxesPerTensorAxis[splitTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, meshAxes);
+
+  return MeshShardingAttr::get(ctx, sourceSharding.getCluster(),
+                               splitAxesPerTensorAxis,
+                               sourceSharding.getPartialAxes(),
+                               sourceSharding.getPartialType());
+}
+
+// Returns `sourceShape` with dimension `splitTensorAxis` divided into
+// `splitCount` shards. Other dimensions and the element type are unchanged.
+static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
+                                             int64_t splitTensorAxis,
+                                             int64_t splitCount) {
+  SmallVector<int64_t> dims(sourceShape.getShape().begin(),
+                            sourceShape.getShape().end());
+  int64_t &splitDim = dims[splitTensorAxis];
+  splitDim = shardDimension(splitDim, splitCount);
+  return sourceShape.cloneWith(dims, sourceShape.getElementType());
+}
+
+// Split a tensor axis along one more mesh axis (appended last).
+// e.g. [[0, 1]] -> [[0, 1, 2]].
+// Returns the spmdized target value with its sharding.
+//
+// The implementation is to extract the tensor slice corresponding
+// to the current device: at runtime the split axis size is asserted to be
+// divisible by the mesh axis size, and the slice starts at
+// `(sourceAxisSize / meshAxisSize) * processIndexAlongAxis`.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+                          MeshShardingAttr sourceSharding,
+                          TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  // All ops below are inserted right after the definition of `sourceShard`.
+  builder.setInsertionPointAfterValue(sourceShard);
+
+  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+  // This device's index along the mesh axis being split over.
+  Value processIndexAlongAxis =
+      builder
+          .create<ProcessIndexOp>(mesh.getSymName(),
+                                  SmallVector<MeshAxis>({splitMeshAxis}))
+          .getResult()[0];
+
+  MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+      ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
+  ShapedType targetShape =
+      targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
+                                 mesh.canonicalDimSizes()[splitMeshAxis]);
+
+  // Runtime size of the mesh axis (may be dynamic in the static mesh shape).
+  Value meshAxisSize =
+      builder
+          .create<ClusterShapeOp>(mesh.getSymName(),
+                                  SmallVector<MeshAxis>({splitMeshAxis}))
+          .getResult()[0];
+
+  // Emit a runtime check that the split axis divides evenly; uneven sharding
+  // is not supported yet.
+  Value sourceAxisSize =
+      builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
+  Value sourceAxisSizeModMeshAxisSize =
+      builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
+  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
+  builder.create<cf::AssertOp>(
+      isTargetShapeExactlyDivisible,
+      "Sharding a tensor with axis size that is not exactly divisible by the "
+      "mesh axis size is not supported.");
+  // Slice offset along the split axis: shard size times this device's index.
+  Value targetAxisSize =
+      builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
+  Value axisOffset =
+      builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
+  // Offsets are 0 everywhere except the split axis, whose offset is the single
+  // dynamic value `axisOffset`.
+  SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
+  staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
+  DenseI64ArrayAttr staticOffsetsAttr =
+      DenseI64ArrayAttr::get(ctx, staticOffsets);
+  SmallVector<Value> dynamicOffsets(1, axisOffset);
+
+  // Sizes come from the static target shape. Each dynamic dimension needs a
+  // runtime value: the computed shard size for the split axis, or the source
+  // dimension size otherwise.
+  DenseI64ArrayAttr staticSizesAttr =
+      DenseI64ArrayAttr::get(ctx, targetShape.getShape());
+  SmallVector<Value> dynamicSizes;
+  for (int64_t i = 0; i < targetShape.getRank(); ++i) {
+    if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
+      if (i == splitTensorAxis) {
+        dynamicSizes.push_back(targetAxisSize);
+      } else {
+        Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
+        dynamicSizes.push_back(dimSize);
+      }
+    }
+  }
+
+  // Unit strides in every dimension.
+  DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
+      ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
+  TypedValue<RankedTensorType> targetShard =
+      builder
+          .create<tensor::ExtractSliceOp>(
+              targetShape, sourceShard, dynamicOffsets, dynamicSizes,
+              SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
+              staticStridesAttr)
+          .getResult();
+  return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1]] -> [[0, 1, 2]],
+// i.e. some tensor axis gains exactly one extra mesh axis at the end of its
+// split-axes list. A tensor axis with no source entry counts as unsplit.
+// If detected, returns the (tensor axis, appended mesh axis) pair for the
+// lowest matching tensor axis.
+// Does not detect insertions like
+// [[0, 1]] -> [[0, 2, 1]].
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+                                MeshShardingAttr targetSharding) {
+  auto sourceSplitAxes = sourceSharding.getSplitAxes();
+  auto targetSplitAxes = targetSharding.getSplitAxes();
+  for (size_t tensorAxis = 0; tensorAxis < targetSplitAxes.size();
+       ++tensorAxis) {
+    ArrayRef<int32_t> targetAxes = targetSplitAxes[tensorAxis].asArrayRef();
+    ArrayRef<int32_t> sourceAxes;
+    if (tensorAxis < sourceSplitAxes.size())
+      sourceAxes = sourceSplitAxes[tensorAxis].asArrayRef();
+
+    if (targetAxes.size() != sourceAxes.size() + 1)
+      continue;
+    if (targetAxes.drop_back() != sourceAxes)
+      continue;
+    return std::make_tuple(tensorAxis, targetAxes.back());
+  }
+  return std::nullopt;
+}
+
+// Emits a split-last-axis resharding if `sourceSharding` -> `targetSharding`
+// matches that pattern. Returns std::nullopt (emitting nothing) otherwise.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                             MeshShardingAttr sourceSharding,
+                             MeshShardingAttr targetSharding,
+                             TypedValue<ShapedType> sourceShard) {
+  auto detected =
+      detectSplitLastAxisInResharding(sourceSharding, targetSharding);
+  if (!detected)
+    return std::nullopt;
+
+  int64_t tensorAxis = std::get<0>(*detected);
+  MeshAxis meshAxis = std::get<1>(*detected);
+  return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
+                                   tensorAxis, meshAxis);
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1, 2]] -> [[0, 1]],
+// i.e. some tensor axis loses exactly the last mesh axis of its split-axes
+// list. A tensor axis with no target entry counts as unsplit.
+// If detected, returns the (tensor axis, removed mesh axis) pair for the
+// lowest matching tensor axis.
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+                                  MeshShardingAttr targetSharding) {
+  auto sourceSplitAxes = sourceSharding.getSplitAxes();
+  auto targetSplitAxes = targetSharding.getSplitAxes();
+  for (size_t tensorAxis = 0; tensorAxis < sourceSplitAxes.size();
+       ++tensorAxis) {
+    ArrayRef<int32_t> sourceAxes = sourceSplitAxes[tensorAxis].asArrayRef();
+    ArrayRef<int32_t> targetAxes;
+    if (tensorAxis < targetSplitAxes.size())
+      targetAxes = targetSplitAxes[tensorAxis].asArrayRef();
+
+    bool dropsOnlyLastAxis = sourceAxes.size() == targetAxes.size() + 1 &&
+                             sourceAxes.drop_back() == targetAxes;
+    if (dropsOnlyLastAxis)
+      return std::make_tuple(tensorAxis, sourceAxes.back());
+  }
+  return std::nullopt;
+}
+
+// Returns `sourceSharding` with the last mesh axis removed from the split
+// axes of tensor axis `splitTensorAxis`. That tensor axis must have an entry
+// in the split-axes list.
+static MeshShardingAttr
+targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+                                MeshShardingAttr sourceSharding,
+                                int64_t splitTensorAxis) {
+  SmallVector<DenseI32ArrayAttr> splitAxesPerTensorAxis =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  assert(static_cast<int64_t>(splitAxesPerTensorAxis.size()) >
+         splitTensorAxis);
+
+  ArrayRef<int32_t> currentAxes =
+      splitAxesPerTensorAxis[splitTensorAxis].asArrayRef();
+  splitAxesPerTensorAxis[splitTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, currentAxes.drop_back());
+
+  return MeshShardingAttr::get(ctx, sourceSharding.getCluster(),
+                               splitAxesPerTensorAxis,
+                               sourceSharding.getPartialAxes(),
+                               sourceSharding.getPartialType());
+}
+
+// Shape produced by all-gathering `sourceShape` along `splitTensorAxis` over
+// `splitCount` shards. That dimension is multiplied by `splitCount`.
+static ShapedType allGatherResultShapeInUnsplitLastAxis(
+    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
+  SmallVector<int64_t> dims(sourceShape.getShape().begin(),
+                            sourceShape.getShape().end());
+  dims[splitTensorAxis] = unshardDimension(dims[splitTensorAxis], splitCount);
+  return sourceShape.cloneWith(dims, sourceShape.getElementType());
+}
+
+// Unsplit the last mesh axis of a sharded tensor axis.
+// e.g. [[0, 1, 2]] -> [[0, 1]].
+// Returns the spmdized target value with its sharding.
+//
+// The shard is all-gathered along `splitTensorAxis` over `splitMeshAxis`, and
+// the result is then `tensor.cast` to the shard type that `targetSharding`
+// implies for `sourceUnshardedShape`.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+                            MeshShardingAttr sourceSharding,
+                            ShapedType sourceUnshardedShape,
+                            TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  builder.setInsertionPointAfterValue(sourceShard);
+
+  // The helper indexes the per-tensor-axis split list, so it must receive the
+  // tensor axis being unsplit, not the mesh axis.
+  MeshShardingAttr targetSharding =
+      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
+  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
+      sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
+      splitTensorAxis);
+  Value allGatherResult = builder.create<AllGatherOp>(
+      RankedTensorType::get(allGatherResultShape.getShape(),
+                            allGatherResultShape.getElementType()),
+      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+      APInt(64, splitTensorAxis));
+  ShapedType targetShape =
+      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+  TypedValue<ShapedType> targetShard =
+      builder.create<tensor::CastOp>(targetShape, allGatherResult)
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
+  return {targetShard, targetSharding};
+}
+
+// Emits an unsplit-last-axis resharding if `sourceSharding` ->
+// `targetSharding` matches that pattern. Returns std::nullopt (emitting
+// nothing) otherwise.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                               MeshShardingAttr sourceSharding,
+                               MeshShardingAttr targetSharding,
+                               ShapedType sourceUnshardedShape,
+                               TypedValue<ShapedType> sourceShard) {
+  auto detected =
+      detectUnsplitLastAxisInResharding(sourceSharding, targetSharding);
+  if (!detected)
+    return std::nullopt;
+
+  int64_t tensorAxis = std::get<0>(*detected);
+  MeshAxis meshAxis = std::get<1>(*detected);
+  return unsplitLastAxisInResharding(builder, sourceSharding,
+                                     sourceUnshardedShape, sourceShard, mesh,
+                                     tensorAxis, meshAxis);
+}
+
+// Detect a resharding that moves the last split mesh axis of one tensor axis
+// to the end of another tensor axis, e.g. on a 1D mesh
+// [[0], []] -> [[], [0]].
+// What is matched: for two distinct tensor axes S (source) and T (target),
+// the source entry for S and the target entry for T are both non-empty, end
+// in the same mesh axis, and agree on everything before that last axis.
+// NOTE(review): on multi-dimensional meshes this does not match moves such as
+// [[0, 1], [2]] -> [[0], [2, 1]]; only 1D meshes are handled by reshard() today.
+// If detected, returns the (source_tensor_axis, target_tensor_axis,
+// mesh_axis) tuple for the first match in (S, T) lexicographic order.
+static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
+detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
+                                    MeshShardingAttr targetSharding) {
+  auto sourceSplitAxes = sourceSharding.getSplitAxes();
+  auto targetSplitAxes = targetSharding.getSplitAxes();
+  for (size_t sourceTensorAxis = 0; sourceTensorAxis < sourceSplitAxes.size();
+       ++sourceTensorAxis) {
+    ArrayRef<int32_t> sourceAxes =
+        sourceSplitAxes[sourceTensorAxis].asArrayRef();
+    // An unsplit source tensor axis has nothing to move.
+    if (sourceAxes.empty())
+      continue;
+    for (size_t targetTensorAxis = 0;
+         targetTensorAxis < targetSplitAxes.size(); ++targetTensorAxis) {
+      if (targetTensorAxis == sourceTensorAxis)
+        continue;
+      ArrayRef<int32_t> targetAxes =
+          targetSplitAxes[targetTensorAxis].asArrayRef();
+      bool sameLastAxis =
+          !targetAxes.empty() && targetAxes.back() == sourceAxes.back();
+      if (sameLastAxis && sourceAxes.drop_back() == targetAxes.drop_back())
+        return std::make_tuple(sourceTensorAxis, targetTensorAxis,
+                               sourceAxes.back());
+    }
+  }
+  return std::nullopt;
+}
+
+// Returns `sourceSharding` with the last mesh axis of tensor axis
+// `sourceTensorAxis` removed and appended to the split axes of
+// `targetTensorAxis`. The split-axes list is first padded with empty entries
+// if it has no entry for `targetTensorAxis`.
+static MeshShardingAttr
+targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+                             int64_t sourceTensorAxis,
+                             int64_t targetTensorAxis) {
+  SmallVector<DenseI32ArrayAttr> splitAxesPerTensorAxis =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  if (static_cast<int64_t>(splitAxesPerTensorAxis.size()) <= targetTensorAxis)
+    splitAxesPerTensorAxis.resize(targetTensorAxis + 1,
+                                  DenseI32ArrayAttr::get(ctx, {}));
+
+  // Pop the moved mesh axis off the source tensor axis.
+  ArrayRef<int32_t> sourceAxes =
+      splitAxesPerTensorAxis[sourceTensorAxis].asArrayRef();
+  assert(!sourceAxes.empty());
+  int32_t movedMeshAxis = sourceAxes.back();
+  splitAxesPerTensorAxis[sourceTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, sourceAxes.drop_back());
+
+  // Append it to the target tensor axis.
+  auto targetAxes =
+      llvm::to_vector(splitAxesPerTensorAxis[targetTensorAxis].asArrayRef());
+  targetAxes.push_back(movedMeshAxis);
+  splitAxesPerTensorAxis[targetTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, targetAxes);
+
+  return MeshShardingAttr::get(ctx, sourceSharding.getCluster(),
+                               splitAxesPerTensorAxis,
+                               sourceSharding.getPartialAxes(),
+                               sourceSharding.getPartialType());
+}
+
+// Shape produced by an all-to-all that moves a `splitCount`-way split from
+// `sourceTensorAxis` (which grows by `splitCount`) to `targetTensorAxis`
+// (which shrinks by `splitCount`).
+static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
+                                                    int64_t splitCount,
+                                                    int64_t sourceTensorAxis,
+                                                    int64_t targetTensorAxis) {
+  SmallVector<int64_t> dims(sourceShape.getShape().begin(),
+                            sourceShape.getShape().end());
+  int64_t &grown = dims[sourceTensorAxis];
+  grown = unshardDimension(grown, splitCount);
+  int64_t &shrunk = dims[targetTensorAxis];
+  shrunk = shardDimension(shrunk, splitCount);
+  return sourceShape.cloneWith(dims, sourceShape.getElementType());
+}
+
+// Move the last split mesh axis of `sourceTensorAxis` to `targetTensorAxis`.
+// Returns the spmdized target value with its sharding.
+//
+// Emits a single all-to-all over `meshAxis` (gathering along
+// `sourceTensorAxis` and splitting along `targetTensorAxis`), then a
+// `tensor.cast` to the shard type that the target sharding implies for
+// `sourceUnshardedShape`.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                              MeshShardingAttr sourceSharding,
+                              ShapedType sourceUnshardedShape,
+                              TypedValue<ShapedType> sourceShard,
+                              int64_t sourceTensorAxis,
+                              int64_t targetTensorAxis, MeshAxis meshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  builder.setInsertionPointAfterValue(sourceShard);
+
+  MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
+  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
+      sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
+      sourceTensorAxis, targetTensorAxis);
+  // NOTE(review): the two APInt operands are passed as (targetTensorAxis,
+  // sourceTensorAxis); confirm this matches AllToAllOp's split/concat axis
+  // order.
+  Value allToAllResult = builder.create<AllToAllOp>(
+      RankedTensorType::get(allToAllResultShape.getShape(),
+                            allToAllResultShape.getElementType()),
+      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
+  ShapedType targetShape =
+      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+  TypedValue<ShapedType> targetShard =
+      builder.create<tensor::CastOp>(targetShape, allToAllResult)
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
+  return {targetShard, targetSharding};
+}
+
+// Emits a move-last-split-axis resharding if `sourceSharding` ->
+// `targetSharding` matches that pattern. Returns std::nullopt (emitting
+// nothing) otherwise.
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                                 MeshShardingAttr sourceSharding,
+                                 MeshShardingAttr targetSharding,
+                                 ShapedType sourceUnshardedShape,
+                                 TypedValue<ShapedType> sourceShard) {
+  auto detected =
+      detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding);
+  if (!detected)
+    return std::nullopt;
+
+  int64_t sourceTensorAxis = std::get<0>(*detected);
+  int64_t targetTensorAxis = std::get<1>(*detected);
+  MeshAxis meshAxis = std::get<2>(*detected);
+  return moveLastSplitAxisInResharding(
+      builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
+      sourceTensorAxis, targetTensorAxis, meshAxis);
+}
+
+// Handles only resharding on a 1D mesh.
+// Currently the sharded tensor axes must be exactly divisible by the single
+// mesh axis size.
+//
+// Partial axes are resolved first (see handlePartialAxesDuringResharding).
+// If the sharding still differs from the target, the move-last-split-axis,
+// split-last-axis and unsplit-last-axis patterns are tried in that order.
+// Reaching a resharding that matches none of them is a programming error.
+static TypedValue<ShapedType>
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                MeshShardingAttr sourceSharding,
+                MeshShardingAttr targetSharding,
+                TypedValue<ShapedType> sourceUnshardedValue,
+                TypedValue<ShapedType> sourceShard) {
+  assert(sourceShard.getType() ==
+         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+  // Only referenced from assertions.
+  [[maybe_unused]] ShapedType targetShardType =
+      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+  assert(sourceShard.getType().getRank() == targetShardType.getRank());
+  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+
+  auto [reducedSourceShard, reducedSourceSharding] =
+      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
+                                        sourceShard);
+
+  if (reducedSourceSharding == targetSharding) {
+    return reducedSourceShard;
+  }
+
+  TypedValue<ShapedType> targetShard;
+  MeshShardingAttr actualTargetSharding;
+  if (auto tryRes = tryMoveLastSplitAxisInResharding(
+          builder, mesh, reducedSourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else if (auto tryRes = trySplitLastAxisInResharding(
+                 builder, mesh, reducedSourceSharding, targetSharding,
+                 reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+                 builder, mesh, reducedSourceSharding, targetSharding,
+                 sourceUnshardedValue.getType(), reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else {
+    // Unlike `assert(false)`, this does not fall through in release builds and
+    // return a null value.
+    llvm_unreachable("Did not find any pattern to apply.");
+  }
+
+  assert(actualTargetSharding == targetSharding);
+  assert(targetShard.getType() == targetShardType);
+  return targetShard;
+}
+
+// Reshards `sourceShard` (the local shard of `sourceUnshardedValue` under
+// `sourceSharding`) so that it matches `targetSharding` on `mesh`, and returns
+// the new local shard.
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                               MeshShardingAttr sourceSharding,
+                               MeshShardingAttr targetSharding,
+                               TypedValue<ShapedType> sourceUnshardedValue,
+                               TypedValue<ShapedType> sourceShard) {
+  // Only 1D meshes are handled for now. The general case is complicated if
+  // it must also minimize the data transferred between devices.
+  TypedValue<ShapedType> result =
+      reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+                      sourceUnshardedValue, sourceShard);
+  return result;
+}
+
+// Convenience overload driven by a pair of `mesh.shard` ops. `source`
+// annotates the producer and `target` (whose operand must be `source`'s
+// result) annotates the users. Created ops use `target`'s location.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+                               ShardOp source, ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue) {
+  assert(!source.getAnnotateForUsers());
+  assert(target.getAnnotateForUsers());
+  assert(source.getResult() == target.getOperand());
+
+  auto unshardedSource = source.getSrc().cast<TypedValue<ShapedType>>();
+  ImplicitLocOpBuilder locBuilder(target->getLoc(), builder);
+  return reshard(locBuilder, mesh, source.getShard(), target.getShard(),
+                 unshardedSource, sourceShardValue);
+}
+
+// Registers the dialects whose ops the resharding code in this file creates:
+// arith, mesh, tensor and cf.
+void reshardingRegisterDependentDialects(DialectRegistry &registry) {
+  registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+  registry.insert<tensor::TensorDialect, cf::ControlFlowDialect>();
+}
+
+} // namespace mesh
+} // namespace mlir

diff  --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 03994f8f011e1f..3ee578a37235e2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,102 @@ func.func @mesh_axis_negtive_in_partial(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.cluster_shape @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @cluster_shape_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_wrong_number_of_results() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.process_index on @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @process_index_invalid_mesh_name() -> (index) {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
 func.func @all_reduce_invalid_mesh_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}

diff  --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309d18f156..a7c3b3dbab9c13 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+  // CHECK: %[[SHAPE:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+  %shape:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+  // CHECK: return %[[SHAPE]]#0, %[[SHAPE]]#1 : index, index
+  return %shape#0, %shape#1 : index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+  // CHECK: %[[SHAPE:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+  %shape:3 = mesh.cluster_shape @mesh0 : index, index, index
+  // CHECK: return %[[SHAPE]]#0, %[[SHAPE]]#1, %[[SHAPE]]#2 : index, index, index
+  return %shape#0, %shape#1, %shape#2 : index, index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+  // CHECK: %[[SHAPE:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+  %shape:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[SHAPE]]#0, %[[SHAPE]]#1, %[[SHAPE]]#2 : index, index, index
+  return %shape#0, %shape#1, %shape#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+  // CHECK: %[[IDX:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+  %idx:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+  // CHECK: return %[[IDX]]#0, %[[IDX]]#1 : index, index
+  return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+  // CHECK: %[[IDX:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+  %idx:3 = mesh.process_index on @mesh0 : index, index, index
+  // CHECK: return %[[IDX]]#0, %[[IDX]]#1, %[[IDX]]#2 : index, index, index
+  return %idx#0, %idx#1, %idx#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+  // CHECK: %[[IDX:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+  %idx:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[IDX]]#0, %[[IDX]]#1, %[[IDX]]#2 : index, index, index
+  return %idx#0, %idx#1, %idx#2 : index, index, index
+}
+
+
 // CHECK-LABEL: func @all_reduce
 func.func @all_reduce(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>

diff  --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
new file mode 100644
index 00000000000000..0ba0d76c09a74d
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
+mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+  // CHECK-SAME: %[[IN:.*]]: tensor<2xf32>
+  %input: tensor<2xf32>
+) -> tensor<2xf32> {
+  %src = mesh.shard %input to <@mesh_1d, [[]]> : tensor<2xf32>
+  %dst = mesh.shard %src to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[IN]]
+  return %dst : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+  // CHECK-SAME: %[[IN:.*]]: tensor<3x14xf32>
+  %input: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[DIM_SIZE:.*]] = arith.constant 14 : index
+  // CHECK: %[[PROC:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[NUM_PROCS:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[REM:.*]] = arith.remui %[[DIM_SIZE]], %[[NUM_PROCS]] : index
+  // CHECK: %[[IS_DIVISIBLE:.*]] = arith.cmpi eq, %[[REM]], %[[C0]] : index
+  // CHECK: cf.assert %[[IS_DIVISIBLE]]
+  // CHECK: %[[SHARD_SIZE:.*]] = arith.divui %[[DIM_SIZE]], %[[NUM_PROCS]] : index
+  // CHECK: %[[SHARD_OFFSET:.*]] = arith.muli %[[SHARD_SIZE]], %[[PROC]] : index
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, %[[SHARD_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
+  // CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %[[SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  %src = mesh.shard %input to <@mesh_1d, [[]]> : tensor<3x14xf32>
+  %dst = mesh.shard %src to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+  // CHECK: return %[[OUT]] : tensor<3x14xf32>
+  return %dst : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+  // CHECK-SAME: %[[IN:.*]]: tensor<?x3x?xf32>
+  %input: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[PROC:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[NUM_PROCS:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[DIM0_SIZE:.*]] = tensor.dim %[[IN]], %[[C0]] : tensor<?x3x?xf32>
+  // CHECK: %[[REM:.*]] = arith.remui %[[DIM0_SIZE]], %[[NUM_PROCS]] : index
+  // CHECK: %[[IS_DIVISIBLE:.*]] = arith.cmpi eq, %[[REM]], %[[C0]] : index
+  // CHECK: cf.assert %[[IS_DIVISIBLE]]
+  // CHECK: %[[SHARD_SIZE:.*]] = arith.divui %[[DIM0_SIZE]], %[[NUM_PROCS]] : index
+  // CHECK: %[[SHARD_OFFSET:.*]] = arith.muli %[[SHARD_SIZE]], %[[PROC]] : index
+  // CHECK: %[[DIM2_SIZE:.*]] = tensor.dim %[[IN]], %[[C2]] : tensor<?x3x?xf32>
+  // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][%[[SHARD_OFFSET]], 0, 0]
+  // CHECK-SAME: [%[[SHARD_SIZE]], 3, %[[DIM2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+  %src = mesh.shard %input to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
+  %dst = mesh.shard %src to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+  // CHECK: return %[[SLICE]] : tensor<?x3x?xf32>
+  return %dst : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[IN:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[DST_SHARD:.*]] = mesh.all_to_all %[[SRC_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+  // CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %[[DST_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+  %src = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %dst = mesh.shard %src to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[OUT]] : tensor<10x14xf32>
+  return %dst : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_mesh
+func.func @move_split_axis_dynamic_mesh(
+  // CHECK-SAME: %[[IN:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[EXCHANGED:.*]] = mesh.all_to_all %[[SRC_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+  // CHECK: %[[DST_SHARD:.*]] = tensor.cast %[[EXCHANGED]] : tensor<?x?xf32> to tensor<10x?xf32>
+  // CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %[[DST_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+  %src = mesh.shard %input to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %dst = mesh.shard %src to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[OUT]] : tensor<10x14xf32>
+  return %dst : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+  // CHECK-SAME: %[[IN:.*]]: tensor<?x14xf32>
+  %input: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[DST_SHARD:.*]] = mesh.all_to_all %[[IN]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+  // CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %[[DST_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+  %src = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %dst = mesh.shard %src to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[OUT]] : tensor<?x14xf32>
+  return %dst : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+  // CHECK-SAME: %[[IN:.*]]: tensor<10x14xf32>
+  %input: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SRC_SHARD:.*]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[GATHERED:.*]] = mesh.all_gather %[[SRC_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+  %src = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %dst = mesh.shard %src to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[GATHERED]] : tensor<10x14xf32>
+  return %dst : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+  // CHECK-SAME: %[[IN:.*]]: tensor<?x14xf32>
+  %input: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[GATHERED:.*]] = mesh.all_gather %[[IN]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  %src = mesh.shard %input to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %dst = mesh.shard %src to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[GATHERED]] : tensor<?x14xf32>
+  return %dst : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
+func.func @unshard_static_axis_on_dynamic_mesh_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @partial_axis
+func.func @partial_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 16b50bb878a074..f14d282857a1e0 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMeshTestSimplifications
+  TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
new file mode 100644
index 00000000000000..6fecbd48f15387
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -0,0 +1,122 @@
+//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+/// Test-only driver for `reshard`. Matches a plain mesh.shard `op` (one
+/// without annotate_for_users) and, for every annotate_for_users mesh.shard
+/// user on the same mesh, reshards the value between the two annotations.
+/// The unsharded input is reinterpreted as the source shard, and the target
+/// annotation is replaced by the result cast back to the unsharded type with
+/// unrealized_conversion_cast, so the emitted IR can be checked by FileCheck.
+struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+  using OpRewritePattern<ShardOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getAnnotateForUsers())
+      return failure();
+
+    SymbolTableCollection symbolTable;
+    mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op, op.getShard().getCluster());
+    // Bail out if the cluster symbol does not resolve to a mesh.cluster.
+    if (!mesh)
+      return failure();
+
+    // Collect the matching users once, so matching and rewriting share a
+    // single predicate.
+    SmallVector<ShardOp> targetShardOps;
+    for (Operation *user : op->getUsers()) {
+      auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
+      if (targetShardOp && targetShardOp.getAnnotateForUsers() &&
+          symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+              targetShardOp, targetShardOp.getShard().getCluster()) == mesh)
+        targetShardOps.push_back(targetShardOp);
+    }
+    if (targetShardOps.empty())
+      return failure();
+
+    for (ShardOp targetShardOp : targetShardOps) {
+      ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+      // Reinterpret the unsharded input as the source shard.
+      ShapedType sourceShardShape =
+          shardShapedType(op.getResult().getType(), mesh, op.getShard());
+      TypedValue<ShapedType> sourceShard =
+          builder
+              .create<UnrealizedConversionCastOp>(sourceShardShape,
+                                                  op.getOperand())
+              ->getResult(0)
+              .cast<TypedValue<ShapedType>>();
+      TypedValue<ShapedType> targetShard =
+          reshard(builder, mesh, op, targetShardOp, sourceShard);
+      // Cast back to the unsharded type so existing users stay well-typed.
+      Value newTargetUnsharded =
+          builder
+              .create<UnrealizedConversionCastOp>(
+                  targetShardOp.getResult().getType(), targetShard)
+              ->getResult(0);
+      // replaceOp also erases the target, so `op` cannot rematch against it.
+      rewriter.replaceOp(targetShardOp, newTargetUnsharded);
+    }
+
+    return success();
+  }
+};
+
+/// Runs TestMeshReshardingRewritePattern greedily over the whole module.
+struct TestMeshReshardingPass
+    : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+    if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    reshardingRegisterDependentDialects(registry);
+    registry.insert<BuiltinDialect>(); // For unrealized_conversion_cast.
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-resharding-spmdization";
+  }
+  StringRef getDescription() const final {
+    return "Test Mesh dialect resharding spmdization.";
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers TestMeshReshardingPass under `test-mesh-resharding-spmdization`;
+/// called from registerTestPasses() in mlir-opt.cpp.
+void registerTestMeshReshardingSpmdizationPass() {
+  PassRegistration<TestMeshReshardingPass>();
+}
+} // namespace mlir::test

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dc4121dc46bb9b..f7a5b3183b50b1 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -119,6 +119,7 @@ void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMeshSimplificationsPass();
+void registerTestMeshReshardingSpmdizationPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -237,6 +238,7 @@ void registerTestPasses() {
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
   mlir::test::registerTestMeshSimplificationsPass();
+  mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();


        


More information about the Mlir-commits mailing list