[Mlir-commits] [mlir] [mlir][mesh] Add resharding spmdization on a 1D device mesh (PR #76179)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jan 2 10:45:15 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/76179

>From 28111bed8dbc8bf057070baa56edfb530728dcb1 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 6 Dec 2023 07:58:53 -0800
Subject: [PATCH 1/6] [mlir][mesh] Add resharding spmdization on a 1D device
 mesh

The current implementation supports only sharding of tensor axes that have size
divisible by the mesh axis size.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  11 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |   3 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  48 +-
 .../Mesh/Transforms/ReshardingSpmdizationDoc  | 621 +++++++++++++++++
 .../Dialect/Mesh/Transforms/Spmdization.h     |  35 +
 mlir/include/mlir/Support/MathExtras.h        |  11 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 213 ++++--
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   3 +
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 639 ++++++++++++++++++
 mlir/test/Dialect/Mesh/invalid.mlir           |  96 +++
 mlir/test/Dialect/Mesh/ops.mlir               |  49 ++
 .../Dialect/Mesh/resharding-spmdization.mlir  | 154 +++++
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |   1 +
 .../Mesh/TestReshardingSpmdization.cpp        | 124 ++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 15 files changed, 1955 insertions(+), 55 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
 create mode 100644 mlir/test/Dialect/Mesh/resharding-spmdization.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3329fb4..a9d30dfbb9a76e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
+
+def Mesh_MeshAxis : I<16>; // 16-bit integer used for a mesh axis index.
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
 //===----------------------------------------------------------------------===//
 // Mesh Enums.
 //===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     }]>
   ];
 
+  let extraClassDeclaration = [{
+    bool operator==(::mlir::Attribute rhs) const;
+    bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+  }];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2eb0189b7..ce7d5d045122d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
 namespace mlir {
 namespace mesh {
 
+using MeshAxis = int16_t; // 0-based index of a mesh axis.
+using MeshAxesAttr = DenseI16ArrayAttr; // Attribute holding mesh axes.
+
 bool isReductionLoop(IteratorType iType);
 
 bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb97763ad..1ed54b6519e4d8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_TD
 
 include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
+// Returns the mesh size along each of `axes`, or along all mesh axes if empty.
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+    Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  ]> {
+  let summary = "Get the shape of the cluster.";
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs Variadic<Index>:$result);
+
+  let assemblyFormat = [{
+    $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
   let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Get the index of current device along specified mesh axis.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
       [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
-    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
   );
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
new file mode 100644
index 00000000000000..181f07177e0af9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[1, 0]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 |             |
++----+----+----+             |
+| 12 | 21 | 23 |             |
++----+----+----+             ↓
+
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+
+sharded on a 2x3 mesh
+
+sharding = [[0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
++----+----+----+             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[1]]
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             ↓
+
+All-gather along mesh axis 0.
+
+--------------------------------------------------------------
+
+Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+
+unsharded 2x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+
+sharded on a 2x3 mesh
+
+sharding = [[], [0, 1]]
+
+mesh contents:
+
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 |             |
+| 21 | 22 | 23 |             |
++----+----+----+             |
+| 14 | 15 | 16 |             |
+| 24 | 25 | 26 |             |
++----+----+----+             ↓
+
+Transform into
+sharding = [[], [0]]
+
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 |             |
+| 21 22 23 | 21 22 23 |             |
++----------+----------+             |
+| 14 15 16 | 14 15 16 |             |
+| 24 25 26 | 24 25 26 |             |
++----------+----------+             ↓
+
+all-gather along mesh axis 1
+
+--------------------------------------------------------------
+
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh contents:
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 |             |             |
+| 21 22 | 23 24 |             |             |
++-------+-------+             |             |
+| 15 16 | 17 18 |             |             |
+| 25 26 | 27 28 |             |             |
++-------+-------+             ↓             |
++-------+-------+                           |
+| 31 32 | 33 34 |                           |
+| 41 42 | 43 44 |                           |
++-------+-------+                           |
+| 35 36 | 37 38 |                           |
+| 45 46 | 47 48 |                           |
++-------+-------+                           ↓
+
+Transform into
+sharding = [[0], [2]]
+
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             |             |
+| 11 12 13 14 | 15 16 17 18 |             |             |
+| 21 22 23 24 | 25 26 27 28 |             |             |
++-------------+-------------+             ↓             |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           |
+| 31 32 33 34 | 35 36 37 38 |                           |
+| 41 42 43 44 | 45 46 47 48 |                           |
++-------------+-------------+                           ↓
+
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 2x3 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 |             |
+| 21 22 | 23 24 | 25 26 |             |
+| 31 32 | 33 34 | 35 36 |             |
++-------+-------+-------+             |
+| 41 42 | 43 44 | 45 46 |             |
+| 51 52 | 53 54 | 55 56 |             |
+| 61 62 | 63 64 | 65 66 |             |
++-------+-------+-------+             ↓
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 |             |
+| 21 22 23 | 41 42 43 | 61 62 63 |             |
++----------+----------+----------+             |
+| 14 15 16 | 34 35 36 | 54 55 56 |             |
+| 24 25 26 | 44 45 46 | 64 65 66 |             |
++----------+----------+----------+             ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+shard on 2x6 mesh
+
+sharding = [[0], [1]]
+
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
++----+----+----+----+----+----+             |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
++----+----+----+----+----+----+             ↓
+
+
+transform to
+sharding = [[1], [0]]
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 |             |
++----------+----------+             |
+| 21 22 23 | 24 25 26 |             |
++----------+----------+             |
+| 31 32 33 | 34 35 36 |             |
++==========+==========+             |
+| 41 42 43 | 44 45 46 |             |
++----------+----------+             |
+| 51 52 53 | 54 55 56 |             |
++----------+----------+             |
+| 61 62 63 | 64 65 66 |             |
++----------+----------+             ↓
+
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
+
+M x N mesh.
+K x L tensor t.
+d(m, n) the tensor on device (m, n).
+
+sharding = [[0], [1]]
+Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+
+substitute
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+
+For the inverse map we get
+t[i, j] -> d(
+    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+    i % (K ceildiv M), j % (L ceildiv N)
+]
+
+Check:
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+    13 floordiv (16 ceildiv 3),
+    17 floordiv (23 ceildiv 4)
+)[
+    13 % (16 ceildiv 3),
+    17 % (23 ceildiv 4)
+]
+= d(
+    13 floordiv 6,
+    17 floordiv 6
+)[
+    13 % 6,
+    17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+    2 * (16 ceildiv 3) + 1,
+    2 * (23 ceildiv 4) + 5
+]
+= t[
+    2 * 6 + 1,
+    2 * 6 + 5
+]
+= t[13, 17]
+
+
+sharding = [[1], [0]]
+Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+
+substitute
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+
+For the inverse map we get
+t[i, j] -> d(
+    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+    i % (K ceildiv N), j % (L ceildiv M)
+]
+
+Check:
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
+t[9, 19] = d(
+    19 floordiv (14 ceildiv 5),
+    9 floordiv (27 ceildiv 2)
+)[
+    9 % (27 ceildiv 2),
+    19 % (14 ceildiv 5)
+]
+= d(
+    19 floordiv 3,
+    9 floordiv 14
+)[
+    9 % 14,
+    19 % 3
+]
+= d(6, 0)[9, 1]
+= t[
+    0 * (27 ceildiv 2) + 9,
+    6 * (14 ceildiv 5) + 1
+]
+= t[
+    0 * 14 + 9,
+    6 * 3 + 1
+]
+= t[9, 19]
+
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
+d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
+    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
+)[
+    (m * (K ceildiv M) + k) % (K ceildiv N),
+    (n * (L ceildiv N) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed (m, n, p, q) what is the range of (k, l, u, v)?
+TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+
+Device placement on a 2x3 mesh
+11 12 13  <- devices
+21 22 23
+
+sharding [[0], [1]]
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 |               |
++----+----+----+               |
+| 21 | 22 | 23 |               |
++----+----+----+               ↓
+
+transform to
+sharding [[1], [0]]
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 |               |
++----+----+               |
+| 12 | 22 |               |
++----+----+               |
+| 13 | 23 |               |
++----+----+               ↓
+
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       11        |        12       |        13       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+|                 |                 |                 |
++                 +                 +                 +
+|       21        |        22       |        23       |
++                 +                 +                 +
+|                 |                 |                 |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          11              +               21         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          12              +               22         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+|                          |                          |
++          13              +               23         +
+|                          |                          |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     11  11      + 12  11 + 12  21 +     13  21      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+|     11  12      | 12  12 | 12  22 |     13  22      |
++-----------------+--------+--------+-----------------+
+|     21  12      | 22  12 | 22  22 |     23  22      |
++-----------------+--------+--------+-----------------+
+|                 |        |        |                 |
++     21  13      + 22  13 + 22  23 +     23  23      +
+|                 |        |        |                 |
++-----------------+--------+--------+-----------------+
+
+If S and T are the source and target shard sizes along some tensor axis,
+then the cut pattern repeats with a period of (S*T)/gcd(S, T), i.e. lcm(S, T).
+TODO
+
+--------------------------------------------------------------
+
+Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 3 mesh.
+
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+
+sharded on a 3 mesh
+
+sharding = [[0], []]
+
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 |             |
+| 21 22 23 24 25 26 |             |
++-------------------+             |
+| 31 32 33 34 35 36 |             |
+| 41 42 43 44 45 46 |             |
++-------------------+             |
+| 51 52 53 54 55 56 |             |
+| 61 62 63 64 65 66 |             |
++-------------------+             ↓
+
+transform to
+sharding = [[], [0]]
+
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+
+--------------------------------------------------------------
+
+Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
+
+unsharded 4x4 tensor
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+
+sharded on a 2x2x2 mesh
+
+sharding = [[0], [1, 2]]
+
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 |             |             |
+| 21 | 22 |             |             |
++----+----+             |             |
+| 13 | 14 |             |             |
+| 23 | 24 |             |             |
++----+----+             ↓             |
++----+----+                           |
+| 31 | 32 |                           |
+| 41 | 42 |                           |
++----+----+                           |
+| 33 | 34 |                           |
+| 43 | 44 |                           |
++----+----+                           ↓
+
+transform to
+sharding = [[0, 1], [2]]
+
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 |             |             |
++-------+-------+             |             |
+| 21 22 | 23 24 |             |             |
++-------+-------+             ↓             |
++-------+-------+                           |
+| 31 32 | 33 34 |                           |
++-------+-------+                           |
+| 41 42 | 43 44 |                           |
++-------+-------+                           ↓
+
+%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
+is not enough.
+
+Can be decomposed into
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
+
+--------------------------------------------------------------
+
+We can decompose each resharding into a sequence of basis reshardings.
+It is not communication efficient in terms of minimizing the data communicated
+between devices.
+An efficient approach would be more complicated to implement.
+Each device has to receive at most as much data as the size of its target
+sharding tensor.
+
+Basis:
+
+From replicate to split.
+[[]] -> [[1]]
+Extract slices without communication.
+
+From split to replicate.
+[[0]] -> [[]]
+[[0, 1]] -> [[1]]
+All-gather along mesh axis 0.
+
+Swap mesh axes order when assigned to the same tensor axis.
+[[0, 1]] -> [[1, 0]]
+Swap contents on devices with the same linear index.
+
+Move mesh axis to different tensor dimension.
+[[0], []] -> [[], [0]]
+All-to-all.
+
+Example decomposition of
+[[0], [1]] -> [[1], [0]]
+
+[[0], [1]] -> all-gather along mesh axis 1    ->
+[[0], []]  -> all-to-all along mesh axis 0    ->
+[[], [0]]  -> extract slice along mesh axis 1 ->
+[[1], [0]]
+
+Example decomposition of
+[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
+
+[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
+[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
+[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
+[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
+[[0], [1, 2], []]
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
new file mode 100644
index 00000000000000..f71bb9b262a380
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -0,0 +1,35 @@
+//===- Spmdization.h - Mesh Spmdization -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace mesh {
+
+// Return the sharded shape `shape` according to sharding `sharding`.
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+                           MeshShardingAttr sharding);
+
+// Insert resharding spmdization of the value `sourceShardValue`
+// from sharding `source` to sharding `target`.
+// `sourceShardValue` is the already sharded value according to `source`.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+                               ShardOp source, ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue);
+
+void reshardingRegisterDependentDialects(DialectRegistry &registry);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h
index 17a747393d26eb..22493c11a370d1 100644
--- a/mlir/include/mlir/Support/MathExtras.h
+++ b/mlir/include/mlir/Support/MathExtras.h
@@ -15,9 +15,20 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APInt.h"
+#include <type_traits>
 
 namespace mlir {
 
+// ceilDiv for unsigned integral types only. `rhs` must be non-zero.
+template <typename T, std::enable_if_t<std::is_integral_v<T> &&
+                                       std::is_unsigned_v<T>, bool> = true>
+T ceilDiv(T lhs, T rhs) {
+  assert(rhs != static_cast<T>(0));
+  T q = lhs / rhs;
+  T r = lhs % rhs;
+  return r == static_cast<T>(0) ? q : q + static_cast<T>(1);
+}
+
 /// Returns the result of MLIR's ceildiv operation on constants. The RHS is
 /// expected to be non-zero.
 inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d27675564c6464..de4f58d54e8ca5 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -9,8 +9,10 @@
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
@@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
   return vec;
 }
 
-using MeshAxis = int16_t;
-
 namespace {
 
 struct DimensionSize {
@@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    SymbolTableCollection &symbolTable) {
+  // Resolve the symbol from the symbol table nearest to `op`.
+  auto cluster =
+      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+  if (cluster) {
+    return cluster;
+  }
+  return op->emitError() << "Undefined required mesh symbol \""
+                         << meshSymbol.getValue() << "\".";
+}
+
+template <typename It>
+bool isUnique(It begin, It end) {
+  // Compares each element with its successor, so only adjacent duplicates are
+  // detected; the caller in this file passes a sorted range.
+  if (begin == end) {
+    return true;
+  }
+  It prev = begin;
+  for (It curr = std::next(begin); curr != end; ++curr) {
+    if (*prev == *curr) {
+      return false;
+    }
+    prev = curr;
+  }
+  return true;
+}
+
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+                                    ClusterOp mesh) {
+  // Duplicates become adjacent after sorting a copy of the axes.
+  SmallVector<MeshAxis> sortedAxes(axes.begin(), axes.end());
+  llvm::sort(sortedAxes);
+  if (!isUnique(sortedAxes.begin(), sortedAxes.end())) {
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+  }
+
+  const MeshAxis meshRank = mesh.getRank();
+  for (MeshAxis meshAxis : axes) {
+    if (meshAxis < 0 || meshAxis >= meshRank) {
+      return emitError(loc)
+             << "0-based mesh axis index " << meshAxis
+             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+             << "\" is of rank " << meshRank << ".";
+    }
+  }
+  return success();
+}
+
 bool mesh::isReductionLoop(IteratorType iType) {
   return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
 }
@@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.shard op
+// mesh.cluster_shape op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> mesh =
+      ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh) || failed(verifyMeshAxes(getLoc(), getAxes(), *mesh))) {
+    return failure();
+  }
+
+  // An empty `axes` list requires one result per mesh axis.
+  ArrayRef<MeshAxis> axes = getAxes();
+  size_t expectedCount = axes.empty() ? mesh->getRank() : axes.size();
+  size_t actualCount = getResult().size();
+  if (actualCount == expectedCount) {
+    return success();
+  }
+
+  return emitError() << "Unexpected number of results " << actualCount
+                     << ". Expected " << expectedCount << ".";
+}
+
+// Both builders create one index-typed result per queried mesh axis.
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           ClusterOp mesh) {
+  SmallVector<Type> resultTypes(mesh.getRank(), odsBuilder.getIndexType());
+  build(odsBuilder, odsState, resultTypes, mesh.getSymName(), MeshAxesAttr());
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+  SmallVector<Type> resultTypes(axes.size(), odsBuilder.getIndexType());
+  MeshAxesAttr axesAttr = MeshAxesAttr::get(odsBuilder.getContext(), axes);
+  build(odsBuilder, odsState, resultTypes, mesh, axesAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
 //===----------------------------------------------------------------------===//
 
 LogicalResult
@@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+bool MeshShardingAttr::operator==(Attribute rhs) const {
+  MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
+  return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+}
+
+bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
+  if (getCluster() != rhs.getCluster() ||
+      getPartialAxes() != rhs.getPartialAxes()) {
+    return false;
+  }
+
+  if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+    return false;
+  }
+
+  auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
+  if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
+                                    getSplitAxes().begin() + minSize),
+                   llvm::make_range(rhs.getSplitAxes().begin(),
+                                    rhs.getSplitAxes().begin() + minSize))) {
+    return false;
+  }
+
+  return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
+                                       getSplitAxes().end()),
+                      std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+         llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
+                                       rhs.getSplitAxes().end()),
+                      std::mem_fn(&DenseI32ArrayAttr::empty));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.process_index op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+    return failure();
+  }
+
+  size_t expectedResultsCount =
+      getAxes().empty() ? mesh->getRank() : getAxes().size();
+  if (getResult().size() != expectedResultsCount) {
+    return emitError() << "Unexpected number of results " << getResult().size()
+                       << ". Expected " << expectedResultsCount << ".";
+  }
+
+  return success();
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           ClusterOp mesh) {
+  build(odsBuilder, odsState,
+        SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+        mesh.getSymName(), MeshAxesAttr());
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+  build(odsBuilder, odsState,
+        SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+        MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
 //===----------------------------------------------------------------------===//
 // collective communication ops
 //===----------------------------------------------------------------------===//
@@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
   return success();
 }
 
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                    SymbolTableCollection &symbolTable) {
-  mesh::ClusterOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
-  if (!mesh) {
-    return op->emitError() << "Undefined required mesh symbol \""
-                           << meshSymbol.getValue() << "\".";
-  }
-
-  return mesh;
-}
-
-template <typename It>
-bool isUnique(It begin, It end) {
-  if (begin == end) {
-    return true;
-  }
-  It next = std::next(begin);
-  if (next == end) {
-    return true;
-  }
-  for (; next != end; ++next, ++begin) {
-    if (*begin == *next) {
-      return false;
-    }
-  }
-  return true;
-}
-
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
-                                    ClusterOp mesh) {
-  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
-  llvm::sort(sorted);
-  if (!isUnique(sorted.begin(), sorted.end())) {
-    return emitError(loc) << "Mesh axes contains duplicate elements.";
-  }
-
-  MeshAxis rank = mesh.getRank();
-  for (auto axis : axes) {
-    if (axis >= rank || axis < 0) {
-      return emitError(loc)
-             << "0-based mesh axis index " << axis
-             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
-             << "\" is of rank " << rank << ".";
-    }
-  }
-
-  return success();
-}
-
 template <typename Op>
 static FailureOr<ClusterOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 044b8672c8c60c..7a70c047ec9dce 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMeshTransforms
   Simplifications.cpp
   ShardingPropagation.cpp
+  Spmdization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,11 +12,13 @@ add_mlir_dialect_library(MLIRMeshTransforms
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRControlFlowDialect
   MLIRFuncDialect
   MLIRIR
   MLIRMeshDialect
   MLIRPass
   MLIRShardingInterface
   MLIRSupport
+  MLIRTensorDialect
   MLIRTosaShardingInterfaceImpl
 )
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
new file mode 100644
index 00000000000000..de8b3a98df998a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -0,0 +1,639 @@
+//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+
+namespace mlir {
+namespace mesh {
+
+int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  return dim * shardCount;
+}
+
+template <typename MeshShape, typename SplitAxes>
+int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
+  int64_t res = 1;
+  for (auto splitAxis : splitAxes) {
+    int64_t meshDimSize = meshShape[splitAxis];
+    if (ShapedType::isDynamic(meshDimSize)) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshDimSize;
+  }
+  return res;
+}
+
+// Compute the shape for the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x8x1
+// would result in a shape for each shard of ?x2x?.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+          typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+                       const SplitAxes &splitAxes, OutShape &outShape) {
+  std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+            llvm::adl_begin(outShape));
+  for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+    outShape[tensorAxis] =
+        shardDimension(inShape[tensorAxis],
+                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+  }
+}
+
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+                           MeshShardingAttr sharding) {
+  using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
+  SmallVector<Dim> resShapeArr(shape.getShape().size());
+  shardShape(shape.getShape(), mesh.canonicalDimSizes(),
+             sharding.getSplitAxes(), resShapeArr);
+  return shape.clone(resShapeArr);
+}
+
+template <typename SourceAxes, typename TargetAxes>
+static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
+                                     const TargetAxes &targetAxes) {
+  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
+    return sourceAxes.contains(targetAxis);
+  });
+}
+
+// Resolve the partial axes of the source that the target does not keep.
+// Returns the reduced value and its corresponding sharding.
+// Example:
+// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
+// targetSharding = <@mesh_1d, [[]]>
+// An all-reduce is applied on the source value and it is returned
+// with the sharding <@mesh_1d, [[0]]>.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+handlePartialAxesDuringResharding(OpBuilder &builder,
+                                  MeshShardingAttr sourceSharding,
+                                  MeshShardingAttr targetSharding,
+                                  TypedValue<ShapedType> sourceShard) {
+  if (sourceSharding.getPartialAxes().empty() &&
+      targetSharding.getPartialAxes().empty())
+    return {sourceShard, sourceSharding};
+  assert(targetSharding.getPartialAxes().empty() ||
+         (!sourceSharding.getPartialAxes().empty() &&
+          sourceSharding.getPartialType() == targetSharding.getPartialType()));
+  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
+  using AxisSet = llvm::SmallDenseSet<Axis>;
+  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
+                                       sourceSharding.getPartialAxes().end());
+  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
+                                       targetSharding.getPartialAxes().end());
+  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
+                                  targetShardingPartialAxesSet));
+  // Source partial axes missing from the target must be all-reduced.
+  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
+  llvm::copy_if(sourceShardingPartialAxesSet,
+                std::back_inserter(allReduceMeshAxes),
+                [&targetShardingPartialAxesSet](Axis a) {
+                  return !targetShardingPartialAxesSet.contains(a);
+                });
+  if (allReduceMeshAxes.empty())
+    return {sourceShard, sourceSharding};
+
+  builder.setInsertionPointAfterValue(sourceShard);
+  TypedValue<ShapedType> resultValue =
+      builder
+          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
+                               sourceSharding.getCluster().getLeafReference(),
+                               allReduceMeshAxes, sourceShard,
+                               sourceSharding.getPartialType())
+          .getResult().cast<TypedValue<ShapedType>>();
+
+  // Axes kept by the target remain partial; they go into the result sharding.
+  llvm::SmallVector<int32_t> remainingPartialAxes;
+  llvm::copy_if(sourceShardingPartialAxesSet,
+                std::back_inserter(remainingPartialAxes),
+                [&targetShardingPartialAxesSet](Axis a) {
+                  return targetShardingPartialAxesSet.contains(a);
+                });
+  MeshShardingAttr resultSharding =
+      MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+                            sourceSharding.getSplitAxes(), remainingPartialAxes,
+                            sourceSharding.getPartialType());
+  return {resultValue, resultSharding};
+}
+
+static MeshShardingAttr
+targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+                              int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
+         splitTensorAxis) {
+    targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+  }
+  auto targetSplitAxes =
+      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
+  targetSplitAxes.push_back(splitMeshAxis);
+  targetShardingSplitAxes[splitTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+  return MeshShardingAttr::get(
+      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
+                                             int64_t splitTensorAxis,
+                                             int64_t splitCount) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+  targetShape[splitTensorAxis] =
+      shardDimension(targetShape[splitTensorAxis], splitCount);
+  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+// Split a replicated tensor along a mesh axis.
+// e.g. [[0, 1]] -> [[0, 1, 2]].
+// Returns the spmdized target value with its sharding.
+//
+// The implementation extracts the tensor slice corresponding
+// to the current device.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+                          MeshShardingAttr sourceSharding,
+                          TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  builder.setInsertionPointAfterValue(sourceShard);
+
+  Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+  Value processIndexAlongAxis =
+      builder
+          .create<ProcessIndexOp>(mesh.getSymName(),
+                                  SmallVector<MeshAxis>({splitMeshAxis}))
+          .getResult()[0];
+
+  MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+      ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
+  ShapedType targetShape =
+      targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
+                                 mesh.canonicalDimSizes()[splitMeshAxis]);
+
+  Value meshAxisSize =
+      builder
+          .create<ClusterShapeOp>(mesh.getSymName(),
+                                  SmallVector<MeshAxis>({splitMeshAxis}))
+          .getResult()[0];
+
+  Value sourceAxisSize =
+      builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
+  Value sourceAxisSizeModMeshAxisSize =
+      builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
+  Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
+  builder.create<cf::AssertOp>(
+      isTargetShapeExactlyDivisible,
+      "Sharding a tensor with axis size that is not exactly divisible by the "
+      "mesh axis size is not supported.");
+  Value targetAxisSize =
+      builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
+  Value axisOffset =
+      builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
+  SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
+  staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
+  DenseI64ArrayAttr staticOffsetsAttr =
+      DenseI64ArrayAttr::get(ctx, staticOffsets);
+  SmallVector<Value> dynamicOffsets(1, axisOffset);
+
+  DenseI64ArrayAttr staticSizesAttr =
+      DenseI64ArrayAttr::get(ctx, targetShape.getShape());
+  SmallVector<Value> dynamicSizes;
+  for (int64_t i = 0; i < targetShape.getRank(); ++i) {
+    if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
+      if (i == splitTensorAxis) {
+        dynamicSizes.push_back(targetAxisSize);
+      } else {
+        Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
+        dynamicSizes.push_back(dimSize);
+      }
+    }
+  }
+
+  DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
+      ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
+  TypedValue<RankedTensorType> targetShard =
+      builder
+          .create<tensor::ExtractSliceOp>(
+              targetShape, sourceShard, dynamicOffsets, dynamicSizes,
+              SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
+              staticStridesAttr)
+          .getResult();
+  return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1]] -> [[0, 1, 2]].
+// If detected, returns the corresponding (tensor_axis, mesh_axis) pair.
+// Does not detect insertions like
+// [[0, 1]] -> [[0, 2, 1]].
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+                                MeshShardingAttr targetSharding) {
+  for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
+       ++tensorAxis) {
+    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
+      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
+          targetSharding.getSplitAxes()[tensorAxis].size()) {
+        continue;
+      }
+      if (!llvm::equal(
+              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
+              llvm::make_range(
+                  targetSharding.getSplitAxes()[tensorAxis]
+                      .asArrayRef()
+                      .begin(),
+                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
+                      1))) {
+        continue;
+      }
+    } else {
+      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
+        continue;
+      }
+    }
+    return std::make_tuple(
+        tensorAxis,
+        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+  }
+  return std::nullopt;
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                             MeshShardingAttr sourceSharding,
+                             MeshShardingAttr targetSharding,
+                             TypedValue<ShapedType> sourceShard) {
+  if (auto detectRes =
+          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
+    auto [tensorAxis, meshAxis] = detectRes.value();
+    return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
+                                     tensorAxis, meshAxis);
+  }
+
+  return std::nullopt;
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1, 2]] -> [[0, 1]].
+// If detected, returns the corresponding (tensor_axis, mesh_axis) pair.
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+                                  MeshShardingAttr targetSharding) {
+  for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
+       ++tensorAxis) {
+    if (targetSharding.getSplitAxes().size() > tensorAxis) {
+      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
+          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+        continue;
+      if (!llvm::equal(
+              llvm::make_range(
+                  sourceSharding.getSplitAxes()[tensorAxis]
+                      .asArrayRef()
+                      .begin(),
+                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
+                      1),
+              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+        continue;
+    } else {
+      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
+        continue;
+    }
+    return std::make_tuple(
+        tensorAxis,
+        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+  }
+  return std::nullopt;
+}
+
+static MeshShardingAttr
+targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+                                MeshShardingAttr sourceSharding,
+                                int64_t splitTensorAxis) {
+  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
+         splitTensorAxis);
+  auto targetSplitAxes =
+      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
+
+  targetSplitAxes.pop_back();
+  targetShardingSplitAxes[splitTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+  return MeshShardingAttr::get(
+      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType allGatherResultShapeInUnsplitLastAxis(
+    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+  targetShape[splitTensorAxis] =
+      unshardDimension(targetShape[splitTensorAxis], splitCount);
+  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+// Unsplit `splitTensorAxis` along `splitMeshAxis` using an all-gather.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+                            MeshShardingAttr sourceSharding,
+                            ShapedType sourceUnshardedShape,
+                            TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  builder.setInsertionPointAfterValue(sourceShard);
+  // The sharding's split axes are indexed by tensor axis, not by mesh axis.
+  MeshShardingAttr targetSharding =
+      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
+  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
+      sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
+      splitTensorAxis);
+  Value allGatherResult = builder.create<AllGatherOp>(
+      RankedTensorType::get(allGatherResultShape.getShape(),
+                            allGatherResultShape.getElementType()),
+      mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+      APInt(64, splitTensorAxis));
+  ShapedType targetShape =
+      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+  TypedValue<ShapedType> targetShard =
+      builder.create<tensor::CastOp>(targetShape, allGatherResult)
+          .getResult().cast<TypedValue<ShapedType>>();
+  return {targetShard, targetSharding};
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                               MeshShardingAttr sourceSharding,
+                               MeshShardingAttr targetSharding,
+                               ShapedType sourceUnshardedShape,
+                               TypedValue<ShapedType> sourceShard) {
+  if (auto detectRes =
+          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
+    auto [tensorAxis, meshAxis] = detectRes.value();
+    return unsplitLastAxisInResharding(builder, sourceSharding,
+                                       sourceUnshardedShape, sourceShard, mesh,
+                                       tensorAxis, meshAxis);
+  }
+
+  return std::nullopt;
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1], [2]] -> [[0], [1, 2]].
+// Only moving the last axis counts.
+// If detected, returns the corresponding (source_tensor_axis,
+// target_tensor_axis, mesh_axis) tuple.
+static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
+detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
+                                    MeshShardingAttr targetSharding) {
+  for (size_t sourceTensorAxis = 0;
+       sourceTensorAxis < sourceSharding.getSplitAxes().size();
+       ++sourceTensorAxis) {
+    for (size_t targetTensorAxis = 0;
+         targetTensorAxis < targetSharding.getSplitAxes().size();
+         ++targetTensorAxis) {
+      if (sourceTensorAxis == targetTensorAxis)
+        continue;
+      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
+          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
+          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
+              targetSharding.getSplitAxes()[targetTensorAxis]
+                  .asArrayRef()
+                  .back())
+        continue;
+      if (!llvm::equal(
+              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
+                                   .asArrayRef()
+                                   .begin(),
+                               sourceSharding.getSplitAxes()[sourceTensorAxis]
+                                       .asArrayRef()
+                                       .end() -
+                                   1),
+              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
+                                   .asArrayRef()
+                                   .begin(),
+                               targetSharding.getSplitAxes()[targetTensorAxis]
+                                       .asArrayRef()
+                                       .end() -
+                                   1)))
+        continue;
+      return std::make_tuple(
+          sourceTensorAxis, targetTensorAxis,
+          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
+    }
+  }
+  return std::nullopt;
+}
+
+static MeshShardingAttr
+targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+                             int64_t sourceTensorAxis,
+                             int64_t targetTensorAxis) {
+  SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+      llvm::to_vector(sourceSharding.getSplitAxes());
+  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
+         targetTensorAxis) {
+    targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+  }
+
+  auto sourceSplitAxes =
+      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
+  assert(!sourceSplitAxes.empty());
+  auto meshAxis = sourceSplitAxes.back();
+  sourceSplitAxes.pop_back();
+  targetShardingSplitAxes[sourceTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+
+  auto targetSplitAxes =
+      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
+  targetSplitAxes.push_back(meshAxis);
+  targetShardingSplitAxes[targetTensorAxis] =
+      DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+
+  return MeshShardingAttr::get(
+      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+      sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
+                                                    int64_t splitCount,
+                                                    int64_t sourceTensorAxis,
+                                                    int64_t targetTensorAxis) {
+  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+  targetShape[sourceTensorAxis] =
+      unshardDimension(targetShape[sourceTensorAxis], splitCount);
+  targetShape[targetTensorAxis] =
+      shardDimension(targetShape[targetTensorAxis], splitCount);
+  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                              MeshShardingAttr sourceSharding,
+                              ShapedType sourceUnshardedShape,
+                              TypedValue<ShapedType> sourceShard,
+                              int64_t sourceTensorAxis,
+                              int64_t targetTensorAxis, MeshAxis meshAxis) {
+  MLIRContext *ctx = builder.getContext();
+  builder.setInsertionPointAfterValue(sourceShard);
+
+  MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
+  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
+      sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
+      sourceTensorAxis, targetTensorAxis);
+  Value allToAllResult = builder.create<AllToAllOp>(
+      RankedTensorType::get(allToAllResultShape.getShape(),
+                            allToAllResultShape.getElementType()),
+      mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
+  ShapedType targetShape =
+      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+  TypedValue<ShapedType> targetShard =
+      builder.create<tensor::CastOp>(targetShape, allToAllResult)
+          .getResult()
+          .cast<TypedValue<ShapedType>>();
+  return {targetShard, targetSharding};
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                                 MeshShardingAttr sourceSharding,
+                                 MeshShardingAttr targetSharding,
+                                 ShapedType sourceUnshardedShape,
+                                 TypedValue<ShapedType> sourceShard) {
+  if (auto detectRes =
+          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
+    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
+    return moveLastSplitAxisInResharding(
+        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
+        sourceTensorAxis, targetTensorAxis, meshAxis);
+  }
+
+  return std::nullopt;
+}
+
+// Handles only resharding on a 1D mesh.
+// Currently the sharded tensor axes must be exactly divisible by the single
+// mesh axis size.
+static TypedValue<ShapedType>
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                MeshShardingAttr sourceSharding,
+                MeshShardingAttr targetSharding,
+                TypedValue<ShapedType> sourceUnshardedValue,
+                TypedValue<ShapedType> sourceShard) {
+  assert(sourceShard.getType() ==
+         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+  ShapedType targetShardType =
+      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+  assert(sourceShard.getType().getRank() == targetShardType.getRank());
+  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+
+  auto [reducedSourceShard, reducedSourceSharding] =
+      handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
+                                        sourceShard);
+
+  if (reducedSourceSharding == targetSharding) {
+    return reducedSourceShard;
+  }
+
+  TypedValue<ShapedType> targetShard;
+  MeshShardingAttr actualTargetSharding;
+  if (auto tryRes = tryMoveLastSplitAxisInResharding(
+          builder, mesh, reducedSourceSharding, targetSharding,
+          sourceUnshardedValue.getType(), reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else if (auto tryRes = trySplitLastAxisInResharding(
+                 builder, mesh, reducedSourceSharding, targetSharding,
+                 reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+                 builder, mesh, reducedSourceSharding, targetSharding,
+                 sourceUnshardedValue.getType(), reducedSourceShard)) {
+    std::tie(targetShard, actualTargetSharding) = tryRes.value();
+  } else {
+    assert(false && "Did not find any pattern to apply.");
+  }
+
+  assert(actualTargetSharding == targetSharding);
+  assert(targetShard.getType() == targetShardType);
+  return targetShard;
+}
+
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+                               MeshShardingAttr sourceSharding,
+                               MeshShardingAttr targetSharding,
+                               TypedValue<ShapedType> sourceUnshardedValue,
+                               TypedValue<ShapedType> sourceShard) {
+  // Resort to handling only 1D meshes since the general case is complicated if
+  // it needs to be communication efficient in terms of minimizing the data
+  // transferred between devices.
+  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+                         sourceUnshardedValue, sourceShard);
+}
+
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+                               ShardOp source, ShardOp target,
+                               TypedValue<ShapedType> sourceShardValue) {
+  assert(!source.getAnnotateForUsers());
+  assert(target.getAnnotateForUsers());
+  assert(source.getResult() == target.getOperand());
+  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
+  return reshard(
+      implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
+      source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+}
+
+void reshardingRegisterDependentDialects(DialectRegistry &registry) {
+  registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
+                  cf::ControlFlowDialect>();
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 03994f8f011e1f..3ee578a37235e2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,102 @@ func.func @mesh_axis_negtive_in_partial(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error at +1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+  // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.cluster_shape @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @cluster_shape_invalid_mesh_name() -> (index) {
+  // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+  // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+  // expected-error at +1 {{Mesh axes contains duplicate elements.}}
+  %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_wrong_number_of_results() -> (index, index) {
+  // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
+  %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+  // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
+  %0:2 = mesh.process_index on @mesh0 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @process_index_invalid_mesh_name() -> (index) {
+  // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+  return %0#0 : index
+}
+
+// -----
+
 func.func @all_reduce_invalid_mesh_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309d18f156..a7c3b3dbab9c13 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+  %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+  %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+  return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_index on @mesh0 : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+  %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+
 // CHECK-LABEL: func @all_reduce
 func.func @all_reduce(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
new file mode 100644
index 00000000000000..0ba0d76c09a74d
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
+mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+  %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+  // CHECK: return %[[ARG]]
+  return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+  %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
+  // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+  // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+  return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+  %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: %[[TWO:.*]] = arith.constant 2 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
+  // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+  // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+  // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
+  // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
+  // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+  // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+  return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_mesh
+func.func @move_split_axis_dynamic_mesh(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+  // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[RES]] : tensor<?x14xf32>
+  return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+  %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+  return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
+func.func @unshard_static_axis_on_dynamic_mesh_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+  // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[RES]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @partial_axis
+func.func @partial_axis(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+  %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+  // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
+  %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
+  %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+  // CHECK: return %[[ALL_REDUCE]] : tensor<10x14xf32>
+  return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 16b50bb878a074..f14d282857a1e0 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMeshTestSimplifications
+  TestReshardingSpmdization.cpp
   TestSimplifications.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
new file mode 100644
index 00000000000000..940010fd94888d
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -0,0 +1,124 @@
+//===- TestReshardingSpmdization.cpp - Test resharding spmdization --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+  using OpRewritePattern<ShardOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShardOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getAnnotateForUsers()) { // Anchor on the source shard op only.
+      return failure();
+    }
+
+    mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op, op.getShard().getCluster());
+
+    bool foundUser = false; // Set if some user is a resharding target.
+    for (auto user : op->getUsers()) {
+      if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
+        if (targetShardOp.getAnnotateForUsers() &&
+            mesh == symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+                        targetShardOp, targetShardOp.getShard().getCluster())) {
+          foundUser = true;
+          break;
+        }
+      }
+    }
+
+    if (!foundUser) { // Nothing to reshard: fail before creating any ops.
+      return failure();
+    }
+
+    for (auto user : op->getUsers()) { // Reshard for every target user.
+      auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
+      if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
+          symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+              targetShardOp, targetShardOp.getShard().getCluster()) != mesh) {
+        continue;
+      }
+
+      ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+      ShapedType sourceShardShape =
+          shardShapedType(op.getResult().getType(), mesh, op.getShard());
+      TypedValue<ShapedType> sourceShard = // Operand cast to the shard type.
+          builder
+              .create<UnrealizedConversionCastOp>(sourceShardShape,
+                                                  op.getOperand())
+              ->getResult(0)
+              .cast<TypedValue<ShapedType>>();
+      TypedValue<ShapedType> targetShard =
+          reshard(builder, mesh, op, targetShardOp, sourceShard);
+      Value newTargetUnsharded = // Cast back to the target op's result type.
+          builder
+              .create<UnrealizedConversionCastOp>(
+                  targetShardOp.getResult().getType(), targetShard)
+              ->getResult(0);
+      rewriter.replaceAllUsesWith(targetShardOp.getResult(),
+                                  newTargetUnsharded);
+    }
+
+    return success();
+  }
+
+private:
+  mutable SymbolTableCollection symbolTable;
+};
+
+struct TestMeshReshardingPass
+    : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+
+  void runOnOperation() override { // Greedily apply the resharding pattern.
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+    if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    reshardingRegisterDependentDialects(registry);
+    registry.insert<BuiltinDialect>(); // For unrealized_conversion_cast.
+  }
+  StringRef getArgument() const final { // Flag name used on the RUN line.
+    return "test-mesh-resharding-spmdization";
+  }
+  StringRef getDescription() const final {
+    return "Test Mesh dialect resharding spmdization.";
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestMeshReshardingSpmdizationPass() { // Called from mlir-opt.
+  PassRegistration<TestMeshReshardingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eedade691c6c39..bea2449f44dcde 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -119,6 +119,7 @@ void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
 void registerTestMeshSimplificationsPass();
+void registerTestMeshReshardingSpmdizationPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -237,6 +238,7 @@ void registerTestPasses() {
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
   mlir::test::registerTestMeshSimplificationsPass();
+  mlir::test::registerTestMeshReshardingSpmdizationPass();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();

>From 1e0f1eb0d12f60d12654401b1e6304e58c8d964a Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 21 Dec 2023 13:47:33 -0800
Subject: [PATCH 2/6] Remove unused ceilDiv function

---
 mlir/include/mlir/Support/MathExtras.h | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/mlir/include/mlir/Support/MathExtras.h b/mlir/include/mlir/Support/MathExtras.h
index 22493c11a370d1..17a747393d26eb 100644
--- a/mlir/include/mlir/Support/MathExtras.h
+++ b/mlir/include/mlir/Support/MathExtras.h
@@ -15,20 +15,9 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APInt.h"
-#include <type_traits>
 
 namespace mlir {
 
-// ceilDiv for unsigned integral.
-template <typename T, std::enable_if_t<std::is_integral_v<T> &&
-                                       !std::is_unsigned_v<T>> = true>
-T ceilDiv(T lhs, T rhs) {
-  assert(rhs != static_cast<T>(0));
-  T q = lhs / rhs;
-  T r = lhs % rhs;
-  return r == static_cast<T>(0) ? q : q + static_cast<T>(1);
-}
-
 /// Returns the result of MLIR's ceildiv operation on constants. The RHS is
 /// expected to be non-zero.
 inline int64_t ceilDiv(int64_t lhs, int64_t rhs) {

>From 95c5e52a5271222cb462c137edcc3419840ac6a2 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 21 Dec 2023 15:10:18 -0800
Subject: [PATCH 3/6] Remove mutable state from
 TestMeshReshardingRewritePattern

---
 mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 940010fd94888d..6fecbd48f15387 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -36,6 +36,7 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
       return failure();
     }
 
+    SymbolTableCollection symbolTable;
     mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
         op, op.getShard().getCluster());
 
@@ -85,9 +86,6 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
 
     return success();
   }
-
-private:
-  mutable SymbolTableCollection symbolTable;
 };
 
 struct TestMeshReshardingPass

>From 9e72161754aace136cf4f99fe340a8cee45d004e Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <sogartary at yahoo.com>
Date: Tue, 2 Jan 2024 09:28:23 -0800
Subject: [PATCH 4/6] Update
 mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc

Co-authored-by: Chengji Yao <yaochengji at hotmail.com>
---
 .../mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
index 181f07177e0af9..ba31e7e7ab28a1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -1,4 +1,4 @@
-Reshadring Spmdization Examples
+Resharding Spmdization Examples
 
 --------------------------------------------------------------
 

>From 517b50123acfa1ea502acea050a3b50ad87c8f94 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <sogartary at yahoo.com>
Date: Tue, 2 Jan 2024 09:29:15 -0800
Subject: [PATCH 5/6] Update
 mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc

Co-authored-by: Chengji Yao <yaochengji at hotmail.com>
---
 .../mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
index ba31e7e7ab28a1..9299bc68d75338 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
@@ -401,7 +401,7 @@ d2(
 
 We want to copy the data between devices in slices/tiles.
 What are the source/target tile coordinates?
-Fro a fixed (m, n, p, q) what is the range of (k, l, u, v)?
+For a fixed (m, n, p, q) what is the range of (k, l, u, v)?
 TODO
 
 --------------------------------------------------------------

>From 8b5fceb928400d5ca861991f20af16e312916b46 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 2 Jan 2024 10:44:47 -0800
Subject: [PATCH 6/6] Make resharding doc markdown

---
 .../Mesh/Transforms/ReshardingSpmdizationDoc  | 621 ------------------
 1 file changed, 621 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
deleted file mode 100644
index 9299bc68d75338..00000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc
+++ /dev/null
@@ -1,621 +0,0 @@
-Resharding Spmdization Examples
-
---------------------------------------------------------------
-
-Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
-
-unsharded 2x3 tensor
-11 12 13
-21 22 23
-
-sharded on a 2x3 mesh
-
-sharding = [[0, 1]]
-
-mesh contents:
-
-mesh axis 1
------------>
-+----+----+----+ mesh axis 0 |
-| 11 | 12 | 13 |             |
-+----+----+----+             |
-| 21 | 22 | 23 |             |
-+----+----+----+             ↓
-
-Transform into
-sharding = [[1, 0]]
-
-mesh axis 1
------------>
-+----+----+----+ mesh axis 0 |
-| 11 | 13 | 22 |             |
-+----+----+----+             |
-| 12 | 21 | 23 |             |
-+----+----+----+             ↓
-
-Swap contents on devices that have the same linear index in the 2 shardings.
-
---------------------------------------------------------------
-
-Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
-
-unsharded 2x3 tensor
-11 12 13
-21 22 23
-
-sharded on a 2x3 mesh
-
-sharding = [[0, 1]]
-
-mesh contents:
-
-mesh axis 1
------------>
-+----+----+----+ mesh axis 0 |
-| 11 | 12 | 13 |             |
-+----+----+----+             |
-| 21 | 22 | 23 |             |
-+----+----+----+             ↓
-
-Transform into
-sharding = [[1]]
-
-mesh axis 1
------------>
-+----+----+----+ mesh axis 0 |
-| 11 | 12 | 13 |             |
-| 21 | 22 | 23 |             |
-+----+----+----+             |
-| 11 | 12 | 13 |             |
-| 21 | 22 | 23 |             |
-+----+----+----+             ↓
-
-All-gather along mesh axis 0.
-
---------------------------------------------------------------
-
-Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
-
-unsharded 2x6 tensor
-11 12 13 14 15 16
-21 22 23 24 25 26
-
-sharded on a 2x3 mesh
-
-sharding = [[], [0, 1]]
-
-mesh contents:
-
-mesh axis 1
------------>
-+----+----+----+ mesh axis 0 |
-| 11 | 12 | 13 |             |
-| 21 | 22 | 23 |             |
-+----+----+----+             |
-| 14 | 15 | 16 |             |
-| 24 | 25 | 26 |             |
-+----+----+----+             ↓
-
-Transform into
-sharding = [[], [0]]
-
-mesh axis 1
------------>
-+----------+----------+ mesh axis 0 |
-| 11 12 13 | 11 12 13 |             |
-| 21 22 23 | 21 22 23 |             |
-+----------+----------+             |
-| 14 15 16 | 14 15 16 |             |
-| 24 25 26 | 24 25 26 |             |
-+----------+----------+             ↓
-
-all-gather along mesh axis 1
-
---------------------------------------------------------------
-
-Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
-
-unsharded 4x8 tensor
-11 12 13 14 15 16 17 18
-21 22 23 24 25 26 27 28
-31 32 33 34 35 36 37 38
-41 42 43 44 45 46 47 48
-
-sharded on a 2x2x2 mesh
-
-sharding = [[0], [1, 2]]
-
-mesh contents:
-
-mesh axis 2
------------>
-+-------+-------+ mesh axis 1 | mesh axis 0 |
-| 11 12 | 13 14 |             |             |
-| 21 22 | 23 24 |             |             |
-+-------+-------+             |             |
-| 15 16 | 17 18 |             |             |
-| 25 26 | 27 28 |             |             |
-+-------+-------+             ↓             |
-+-------+-------+                           |
-| 31 32 | 33 34 |                           |
-| 41 42 | 43 44 |                           |
-+-------+-------+                           |
-| 35 36 | 37 38 |                           |
-| 45 46 | 47 48 |                           |
-+-------+-------+                           ↓
-
-Transform into
-sharding = [[0], [2]]
-
-mesh axis 2
------------>
-+-------------+-------------+ mesh axis 1 | mesh axis 0 |
-| 11 12 13 14 | 15 16 17 18 |             |             |
-| 21 22 23 24 | 25 26 27 28 |             |             |
-+-------------+-------------+             |             |
-| 11 12 13 14 | 15 16 17 18 |             |             |
-| 21 22 23 24 | 25 26 27 28 |             |             |
-+-------------+-------------+             ↓             |
-+-------------+-------------+                           |
-| 31 32 33 34 | 35 36 37 38 |                           |
-| 41 42 43 44 | 45 46 47 48 |                           |
-+-------------+-------------+                           |
-| 31 32 33 34 | 35 36 37 38 |                           |
-| 41 42 43 44 | 45 46 47 48 |                           |
-+-------------+-------------+                           ↓
-
-Can't be done with just an all-gather along mesh axis 1.
-Can be handled by multiple resharding transformations
-[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
-
---------------------------------------------------------------
-
-Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
-
-unsharded 6x6 tensor
-11 12 13 14 15 16
-21 22 23 24 25 26
-31 32 33 34 35 36
-41 42 43 44 45 46
-51 52 53 54 55 56
-61 62 63 64 65 66
-
-sharded on a 2x3 mesh
-
-sharding = [[0], [1]]
-
-mesh axis 1
------------>
-+-------+-------+-------+ mesh axis 0 |
-| 11 12 | 13 14 | 15 16 |             |
-| 21 22 | 23 24 | 25 26 |             |
-| 31 32 | 33 34 | 35 36 |             |
-+-------+-------+-------+             |
-| 41 42 | 43 44 | 45 46 |             |
-| 51 52 | 53 54 | 55 56 |             |
-| 61 62 | 63 64 | 65 66 |             |
-+-------+-------+-------+             ↓
-
-transform to
-sharding = [[1], [0]]
-
-mesh axis 1
------------>
-+----------+----------+----------+ mesh axis 0 |
-| 11 12 13 | 31 32 33 | 51 52 53 |             |
-| 21 22 23 | 41 42 43 | 61 62 63 |             |
-+----------+----------+----------+             |
-| 14 15 16 | 34 35 36 | 54 55 56 |             |
-| 24 25 26 | 44 45 46 | 64 65 66 |             |
-+----------+----------+----------+             ↓
-
-mesh axis 0
------------>
-+----------+----------+ mesh axis 1 |
-| 11 12 13 | 14 15 16 |             |
-| 21 22 23 | 24 25 26 |             |
-+----------+----------+             |
-| 31 32 33 | 34 35 36 |             |
-| 41 42 43 | 44 45 46 |             |
-+----------+----------+             |
-| 51 52 53 | 54 55 56 |             |
-| 61 62 63 | 64 65 66 |             |
-+----------+----------+             ↓
-
-TODO
-
---------------------------------------------------------------
-
-Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
-
-unsharded 6x6 tensor
-11 12 13 14 15 16
-21 22 23 24 25 26
-31 32 33 34 35 36
-41 42 43 44 45 46
-51 52 53 54 55 56
-61 62 63 64 65 66
-
-shard on 2x6 mesh
-
-sharding = [[0], [1]]
-
-mesh axis 1
------------>
-+----+----+----+----+----+----+ mesh axis 0 |
-| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
-| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
-| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
-+----+----+----+----+----+----+             |
-| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
-| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
-| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
-+----+----+----+----+----+----+             ↓
-
-
-transform to
-sharding = [[1], [0]]
-
-mesh axis 0
------------>
-+----------+----------+ mesh axis 1 |
-| 11 12 13 | 14 15 16 |             |
-+----------+----------+             |
-| 21 22 23 | 24 25 26 |             |
-+----------+----------+             |
-| 31 32 33 | 34 35 36 |             |
-+==========+==========+             |
-| 41 42 43 | 44 45 46 |             |
-+----------+----------+             |
-| 51 52 53 | 54 55 56 |             |
-+----------+----------+             |
-| 61 62 63 | 64 65 66 |             |
-+----------+----------+             ↓
-
-TODO
-
---------------------------------------------------------------
-
-Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
-
-M x N mesh.
-K x L tensor t.
-d(m, n) the tensor on device (m, n).
-
-sharding = [[0], [1]]
-Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
-d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
-
-substitute
-i <- m * (K ceildiv M) + k
-j <- n * (L ceildiv N) + l
-
-m -> i floordiv (K ceildiv M)
-n -> j floordiv (L ceildiv N)
-k -> i % (K ceildiv M)
-l -> j % (L ceildiv N)
-
-For the inverse map we get
-t[i, j] -> d(
-    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
-)[
-    i % (K ceildiv M), j % (L ceildiv N)
-]
-
-Check:
-i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
-t[13, 17] = d(
-    13 floordiv (16 ceildiv 3),
-    17 floordiv (23 ceildiv 4)
-)[
-    13 % (16 ceildiv 3),
-    17 % (23 ceildiv 4)
-]
-= d(
-    13 floordiv 6,
-    17 floordiv 6
-)[
-    13 % 6,
-    17 % 6
-]
-= d(2, 2)[1, 5]
-= t[
-    2 * (16 ceildiv 3) + 1,
-    2 * (23 ceildiv 4) + 5
-]
-= t[
-    2 * 6 + 1,
-    2 * 6 + 5
-]
-= t[13, 17]
-
-
-sharding = [[1], [0]]
-Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
-d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
-
-substitute
-i <- n * (K ceildiv N) + k
-j <- m * (L ceildiv M) + l
-
-m -> j floordiv (L ceildiv M)
-n -> i floordiv (K ceildiv N)
-k -> i % (K ceildiv N)
-l -> j % (L ceildiv M)
-
-For the inverse map we get
-t[i, j] -> d(
-    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
-)[
-    i % (K ceildiv N), j % (L ceildiv M)
-]
-
-Check:
-i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
-t[9, 19] = d(
-    19 floordiv (14 ceildiv 5),
-    9 floordiv (27 ceildiv 2)
-)[
-    9 % (27 ceildiv 2),
-    19 % (14 ceildiv 5)
-]
-= d(
-    19 floordiv 3,
-    9 floordiv 14
-)[
-    9 % 14,
-    19 % 3
-]
-= d(6, 0)[9, 1]
-= t[
-    0 * (27 ceildiv 2) + 9,
-    6 * (14 ceildiv 5) + 1
-]
-= t[
-    0 * 14 + 9,
-    6 * 3 + 1
-]
-= t[9, 19]
-
-sharding = [[0], [1]]
-d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
-t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
-
-sharding = [[1], [0]]
-d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
-t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
-
-sharding [[0], [1]] -> [[1], [0]]
-d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
-d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
-d1(m, n)[k, l] ->
-t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
-d2(
-    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
-    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
-)[
-    (m * (K ceildiv M) + k) % (K ceildiv N),
-    (n * (L ceildiv N) + l) % (L ceildiv M)
-]
-= d2(p, q)[u, v]
-
-We want to copy the data between devices in slices/tiles.
-What are the source/target tile coordinates?
-For a fixed (m, n, p, q) what is the range of (k, l, u, v)?
-TODO
-
---------------------------------------------------------------
-
-Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
-
-Device placement on a 2x3 mesh
-11 12 13  <- devices
-21 22 23
-
-sharding [[0], [1]]
-tensor axis 1
------------>
-+----+----+----+ tensor axis 0 |
-| 11 | 12 | 13 |               |
-+----+----+----+               |
-| 21 | 22 | 23 |               |
-+----+----+----+               ↓
-
-transform to
-sharding [[1], [0]]
-tensor axis 1
------------>
-+----+----+ tensor axis 0 |
-| 11 | 21 |               |
-+----+----+               |
-| 12 | 22 |               |
-+----+----+               |
-| 13 | 23 |               |
-+----+----+               ↓
-
-+-----------------+--------+--------+-----------------+
-|                 |                 |                 |
-+                 +                 +                 +
-|       11        |        12       |        13       |
-+                 +                 +                 +
-|                 |                 |                 |
-+-----------------+--------+--------+-----------------+
-|                 |                 |                 |
-+                 +                 +                 +
-|       21        |        22       |        23       |
-+                 +                 +                 +
-|                 |                 |                 |
-+-----------------+--------+--------+-----------------+
-
-+-----------------+--------+--------+-----------------+
-|                          |                          |
-+          11              +               21         +
-|                          |                          |
-+-----------------+--------+--------+-----------------+
-|                          |                          |
-+          12              +               22         +
-|                          |                          |
-+-----------------+--------+--------+-----------------+
-|                          |                          |
-+          13              +               23         +
-|                          |                          |
-+-----------------+--------+--------+-----------------+
-
-+-----------------+--------+--------+-----------------+
-|                 |        |        |                 |
-+     11  11      + 12  11 + 12  21 +     13  21      +
-|                 |        |        |                 |
-+-----------------+--------+--------+-----------------+
-|     11  12      | 12  12 | 12  22 |     13  22      |
-+-----------------+--------+--------+-----------------+
-|     21  12      | 22  12 | 22  22 |     23  22      |
-+-----------------+--------+--------+-----------------+
-|                 |        |        |                 |
-+     21  13      + 22  13 + 22  23 +     23  23      +
-|                 |        |        |                 |
-+-----------------+--------+--------+-----------------+
-
-If S and T are the source and target shard sizes along some tensor axis,
-then the cut pattern repeats with a period of (S*T)/gcd(S, T), i.e. lcm(S, T).
-TODO
-
---------------------------------------------------------------
-
-Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 3 mesh.
-
-unsharded 6x6 tensor
-11 12 13 14 15 16
-21 22 23 24 25 26
-31 32 33 34 35 36
-41 42 43 44 45 46
-51 52 53 54 55 56
-61 62 63 64 65 66
-
-sharded on a 3 mesh
-
-sharding = [[0], []]
-
-+-------------------+ mesh axis 0 |
-| 11 12 13 14 15 16 |             |
-| 21 22 23 24 25 26 |             |
-+-------------------+             |
-| 31 32 33 34 35 36 |             |
-| 41 42 43 44 45 46 |             |
-+-------------------+             |
-| 51 52 53 54 55 56 |             |
-| 61 62 63 64 65 66 |             |
-+-------------------+             ↓
-
-transform to
-sharding = [[], [0]]
-
-mesh axis 0
------------>
-+-------+-------+-------+
-| 11 12 | 13 14 | 15 16 |
-| 21 22 | 23 24 | 25 26 |
-| 31 32 | 33 34 | 35 36 |
-| 41 42 | 43 44 | 45 46 |
-| 51 52 | 53 54 | 55 56 |
-| 61 62 | 63 64 | 65 66 |
-+-------+-------+-------+
-
-%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
-
---------------------------------------------------------------
-
-Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.
-
-unsharded 4x4 tensor
-11 12 13 14
-21 22 23 24
-31 32 33 34
-41 42 43 44
-
-sharded on a 2x2x2 mesh
-
-sharding = [[0], [1, 2]]
-
-mesh axis 2
------------>
-+----+----+ mesh axis 1 | mesh axis 0 |
-| 11 | 12 |             |             |
-| 21 | 22 |             |             |
-+----+----+             |             |
-| 13 | 14 |             |             |
-| 23 | 24 |             |             |
-+----+----+             ↓             |
-+----+----+                           |
-| 31 | 32 |                           |
-| 41 | 42 |                           |
-+----+----+                           |
-| 33 | 34 |                           |
-| 43 | 44 |                           |
-+----+----+                           ↓
-
-transform to
-sharding = [[0, 1], [2]]
-
-mesh axis 2
------------>
-+-------+-------+ mesh axis 1 | mesh axis 0 |
-| 11 12 | 13 14 |             |             |
-+-------+-------+             |             |
-| 21 22 | 23 24 |             |             |
-+-------+-------+             ↓             |
-+-------+-------+                           |
-| 31 32 | 33 34 |                           |
-+-------+-------+                           |
-| 41 42 | 43 44 |                           |
-+-------+-------+                           ↓
-
-%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
-is not enough.
-
-Can be decomposed into
-[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
-
---------------------------------------------------------------
-
-We can decompose each resharding into a sequence of basis reshardings.
-It is not communication efficient in terms of minimizing the data communicated
-between devices.
-An efficient approach would be more complicated to implement.
-Each device has to receive at most as much data as the size of its target
-sharding tensor.
-
-Basis:
-
-From replicate to split.
-[[]] -> [[1]]
-Extract slices without communication.
-
-From split to replicate.
-[[0]] -> [[]]
-[[0, 1]] -> [[1]]
-All-gather along mesh axis 0.
-
-Swap mesh axes order when assigned to the same tensor axis.
-[[0, 1]] -> [[1, 0]]
-Swap contents on devices with the same linear index.
-
-Move mesh axis to different tensor dimension.
-[[0], []] -> [[], [0]]
-All-to-all.
-
-Example decomposition of
-[[0], [1]] -> [[1], [0]]
-
-[[0], [1]] -> all-gather along mesh axis 1    ->
-[[0], []]  -> all-to-all along mesh axis 0    ->
-[[], [0]]  -> extract slice along mesh axis 1 ->
-[[1], [0]]
-
-Example decomposition of
-[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
-
-[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
-[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
-[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
-[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
-[[0], [1, 2], []]



More information about the Mlir-commits mailing list