[Mlir-commits] [mlir] [mlir][mesh] Add resharding spmdization on a 1D device mesh (PR #76179)
Chengji Yao
llvmlistbot at llvm.org
Mon Jan 1 12:54:15 PST 2024
@@ -0,0 +1,621 @@
+Resharding Spmdization Examples
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+sharded on a 2x3 mesh
+sharding = [[0, 1]]
+mesh contents:
+mesh axis 1
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+Transform into
+sharding = [[1, 0]]
+mesh axis 1
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 | |
++----+----+----+ |
+| 12 | 21 | 23 | |
++----+----+----+ ↓
+Swap contents on devices that have the same linear index in the 2 shardings.
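One way to read the two layouts above: each device holds the shard whose position equals the device's linear index, with the mesh axes linearized in the order the sharding lists them. The following is a minimal Python sketch under that assumption; the names `linear_index` and `layout` are illustrative and do not come from the patch.

```python
MESH_SHAPE = (2, 3)
# The six one-element shards of the 2x3 tensor, in row-major order.
SHARDS = [11, 12, 13, 21, 22, 23]


def linear_index(device, axis_order):
    """Linearize mesh coordinates; the first axis in axis_order is the slowest."""
    index = 0
    for axis in axis_order:
        index = index * MESH_SHAPE[axis] + device[axis]
    return index


def layout(axis_order):
    """Return the mesh contents as rows (mesh axis 0) by columns (mesh axis 1)."""
    return [
        [SHARDS[linear_index((m, n), axis_order)] for n in range(MESH_SHAPE[1])]
        for m in range(MESH_SHAPE[0])
    ]


before = layout([0, 1])  # sharding [[0, 1]]
after = layout([1, 0])   # sharding [[1, 0]]
```

Under this reading, `before` and `after` reproduce the two mesh diagrams, and the resharding amounts to moving each shard from the device with linear index `x` in the first ordering to the device with linear index `x` in the second.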
+Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.
+unsharded 2x3 tensor
+11 12 13
+21 22 23
+sharded on a 2x3 mesh
+sharding = [[0, 1]]
+mesh contents:
+mesh axis 1
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+Transform into
+sharding = [[1]]
+mesh axis 1
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+All-gather along mesh axis 0.
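The all-gather above can be simulated on nested lists. This is a sketch only: `all_gather_axis0` is a hypothetical helper that models the collective by giving every device the contents of all devices sharing its mesh axis 1 coordinate, ordered by mesh axis 0.

```python
def all_gather_axis0(mesh):
    """Simulate an all-gather along mesh axis 0 on a 2D mesh of scalar shards."""
    rows, cols = len(mesh), len(mesh[0])
    # Everything held in column n, ordered by the mesh axis 0 coordinate.
    gathered = [[mesh[m][n] for m in range(rows)] for n in range(cols)]
    # Every device in column n ends up with the same gathered list.
    return [[list(gathered[n]) for n in range(cols)] for _ in range(rows)]


source = [[11, 12, 13], [21, 22, 23]]  # sharding [[0, 1]]
result = all_gather_axis0(source)      # sharding [[1]]
```

Each entry of `result` is the pair of elements drawn inside one cell of the target diagram.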
+Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.
+unsharded 2x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+sharded on a 2x3 mesh
+sharding = [[], [0, 1]]
+mesh contents:
+mesh axis 1
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 14 | 15 | 16 | |
+| 24 | 25 | 26 | |
++----+----+----+ ↓
+Transform into
+sharding = [[], [0]]
+mesh axis 1
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 | 11 12 13 | |
+| 21 22 23 | 21 22 23 | 21 22 23 | |
++----------+----------+----------+ |
+| 14 15 16 | 14 15 16 | 14 15 16 | |
+| 24 25 26 | 24 25 26 | 24 25 26 | |
++----------+----------+----------+ ↓
+All-gather along mesh axis 1.
+Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.
+unsharded 4x8 tensor
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+sharded on a 2x2x2 mesh
+sharding = [[0], [1, 2]]
+mesh contents:
+mesh axis 2
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 | | |
+| 21 22 | 23 24 | | |
++-------+-------+ | |
+| 15 16 | 17 18 | | |
+| 25 26 | 27 28 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
+| 41 42 | 43 44 | |
++-------+-------+ |
+| 35 36 | 37 38 | |
+| 45 46 | 47 48 | |
++-------+-------+ ↓
+Transform into
+sharding = [[0], [2]]
+mesh axis 2
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ | |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ ↓ |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ ↓
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations:
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]
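The claim above can be checked with a small simulation. The sketch below assumes that a tensor axis sharded over several mesh axes is split into equal chunks, with the listed mesh axes linearized from slowest to fastest; `axis_range`, `shard`, and `gather_columns_axis1` are illustrative helpers, not functions from the patch.

```python
MESH_SHAPE = (2, 2, 2)
# The 4x8 example tensor: entry (r, c) is 10 * (r + 1) + (c + 1).
TENSOR = [[10 * r + c for c in range(1, 9)] for r in range(1, 5)]


def axis_range(size, mesh_axes, device):
    """Half-open index range of one tensor axis owned by `device`."""
    parts, index = 1, 0
    for axis in mesh_axes:
        parts *= MESH_SHAPE[axis]
        index = index * MESH_SHAPE[axis] + device[axis]
    chunk = -(-size // parts)  # ceildiv
    return index * chunk, min((index + 1) * chunk, size)


def shard(sharding, device):
    """Shard of TENSOR held by `device` under a 2-entry sharding."""
    r0, r1 = axis_range(len(TENSOR), sharding[0], device)
    c0, c1 = axis_range(len(TENSOR[0]), sharding[1], device)
    return [row[c0:c1] for row in TENSOR[r0:r1]]


def gather_columns_axis1(sharding, device):
    """Simulated all-gather along mesh axis 1, concatenating along tensor axis 1."""
    peers = [(device[0], a1, device[2]) for a1 in range(MESH_SHAPE[1])]
    blocks = [shard(sharding, peer) for peer in peers]
    return [sum((block[r] for block in blocks), []) for r in range(len(blocks[0]))]


device = (0, 1, 0)
target = shard([[0], [2]], device)
direct = gather_columns_axis1([[0], [1, 2]], device)    # gather without the swap
via_swap = gather_columns_axis1([[0], [2, 1]], device)  # gather after the swap
```

In this simulation `direct` interleaves the wrong column blocks, while `via_swap` equals `target`, which is consistent with routing the resharding through `[[0], [2, 1]]` first.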
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+sharded on a 2x3 mesh
+sharding = [[0], [1]]
+mesh axis 1
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 | |
+| 21 22 | 23 24 | 25 26 | |
+| 31 32 | 33 34 | 35 36 | |
++-------+-------+-------+ |
+| 41 42 | 43 44 | 45 46 | |
+| 51 52 | 53 54 | 55 56 | |
+| 61 62 | 63 64 | 65 66 | |
++-------+-------+-------+ ↓
+transform to
+sharding = [[1], [0]]
+mesh axis 1
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 | |
+| 21 22 23 | 41 42 43 | 61 62 63 | |
++----------+----------+----------+ |
+| 14 15 16 | 34 35 36 | 54 55 56 | |
+| 24 25 26 | 44 45 46 | 64 65 66 | |
++----------+----------+----------+ ↓
+mesh axis 0
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.
+unsharded 6x6 tensor
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+sharded on a 2x6 mesh
+sharding = [[0], [1]]
+mesh axis 1
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 | |
+| 21 | 22 | 23 ‖ 24 | 25 | 26 | |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 | |
++----+----+----+----+----+----+ |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 | |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 | |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 | |
++----+----+----+----+----+----+ ↓
+transform to
+sharding = [[1], [0]]
+mesh axis 0
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
++----------+----------+ |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
++==========+==========+ |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
++----------+----------+ |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+Reshard KxL tensor from [[0], [1]] to [[1], [0]] on MxN mesh.
+M x N mesh.
+K x L tensor t.
+d(m, n) the tensor on device (m, n).
+sharding = [[0], [1]]
+Tensor shard s on each device has size (K ceildiv M, L ceildiv N).
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+For the inverse map we get
+t[i, j] -> d(
+ i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+ i % (K ceildiv M), j % (L ceildiv N)
+]
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+ 13 floordiv (16 ceildiv 3),
+ 17 floordiv (23 ceildiv 4)
+)[
+ 13 % (16 ceildiv 3),
+ 17 % (23 ceildiv 4)
+]
+= d(
+ 13 floordiv 6,
+ 17 floordiv 6
+)[
+ 13 % 6,
+ 17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+ 2 * (16 ceildiv 3) + 1,
+ 2 * (23 ceildiv 4) + 5
+]
+= t[
+ 2 * 6 + 1,
+ 2 * 6 + 5
+]
+= t[13, 17]
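The forward and inverse maps for sharding [[0], [1]] can be written as two small functions. This is a sketch that follows the formulas above term by term; the function names are illustrative.

```python
def ceildiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def tensor_to_device(i, j, K, L, M, N):
    """t[i, j] -> ((m, n), (k, l)) for sharding [[0], [1]]."""
    rows, cols = ceildiv(K, M), ceildiv(L, N)
    return (i // rows, j // cols), (i % rows, j % cols)


def device_to_tensor(m, n, k, l, K, L, M, N):
    """d(m, n)[k, l] -> (i, j) for sharding [[0], [1]]."""
    rows, cols = ceildiv(K, M), ceildiv(L, N)
    return m * rows + k, n * cols + l


# The worked example: i = 13, j = 17, M = 3, N = 4, K = 16, L = 23.
device, local = tensor_to_device(13, 17, K=16, L=23, M=3, N=4)
back = device_to_tensor(*device, *local, K=16, L=23, M=3, N=4)
```

Here `device` is `(2, 2)`, `local` is `(1, 5)`, and `back` returns to `(13, 17)`, matching the hand computation.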
+sharding = [[1], [0]]
+Tensor shard s on each device has size (K ceildiv N, L ceildiv M).
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+For the inverse map we get
+t[i, j] -> d(
+ j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+ i % (K ceildiv N), j % (L ceildiv M)
+]
+i = 9, j = 13, M = 5, N = 2, K = 27, L = 14
+t[9, 13] = d(
+ 13 floordiv (14 ceildiv 5),
+ 9 floordiv (27 ceildiv 2)
+)[
+ 9 % (27 ceildiv 2),
+ 13 % (14 ceildiv 5)
+]
+= d(
+ 13 floordiv 3,
+ 9 floordiv 14
+)[
+ 9 % 14,
+ 13 % 3
+]
+= d(4, 0)[9, 1]
+= t[
+ 0 * (27 ceildiv 2) + 9,
+ 4 * (14 ceildiv 5) + 1
+]
+= t[
+ 0 * 14 + 9,
+ 4 * 3 + 1
+]
+= t[9, 13]
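The [[1], [0]] maps can be checked against the earlier 6x6 tensor on a 2x3 mesh. The sketch below assumes the entry naming used in the diagrams, where row i and column j (zero-based) hold the value 10 * (i + 1) + (j + 1); `device_shard` and `owner` are illustrative names.

```python
def ceildiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


M, N = 2, 3  # mesh shape
K, L = 6, 6  # tensor shape
t = [[10 * (i + 1) + (j + 1) for j in range(L)] for i in range(K)]


def device_shard(m, n):
    """Shard on device (m, n) for sharding [[1], [0]], via the forward map."""
    rows, cols = ceildiv(K, N), ceildiv(L, M)
    return [[t[n * rows + k][m * cols + l] for l in range(cols)] for k in range(rows)]


def owner(i, j):
    """Inverse map: t[i, j] -> ((m, n), (k, l)) for sharding [[1], [0]]."""
    rows, cols = ceildiv(K, N), ceildiv(L, M)
    return (j // cols, i // rows), (i % rows, j % cols)
```

For instance, `device_shard(0, 1)` gives the rows `31 32 33` and `41 42 43`, which is the cell drawn at mesh position (0, 1) in the 2x3 diagram for sharding [[1], [0]].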
+sharding = [[0], [1]]
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+sharding = [[1], [0]]
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+sharding [[0], [1]] -> [[1], [0]]
+d1(m, n) the tensor on device (m, n) for sharding [[0], [1]].
+d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+ (n * (L ceildiv N) + l) floordiv (L ceildiv M),
+ (m * (K ceildiv M) + k) floordiv (K ceildiv N)
+)[
+ (m * (K ceildiv M) + k) % (K ceildiv N),
+ (n * (L ceildiv N) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
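The source-to-target map can be obtained by composing the two per-sharding maps stated earlier: first the [[0], [1]] map from d1(m, n)[k, l] to t[i, j], then the [[1], [0]] map from t[i, j] to d2(p, q)[u, v]. The sketch below does exactly that. The `destinations` helper is a brute-force enumeration added for illustration; it is not a closed-form answer to the tile-range question.

```python
def ceildiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def d1_to_d2(m, n, k, l, K, L, M, N):
    """Map d1(m, n)[k, l] to ((p, q), (u, v)) by going through t[i, j]."""
    i = m * ceildiv(K, M) + k  # [[0], [1]]: device to tensor
    j = n * ceildiv(L, N) + l
    p, v = divmod(j, ceildiv(L, M))  # [[1], [0]]: tensor to device
    q, u = divmod(i, ceildiv(K, N))
    return (p, q), (u, v)


def destinations(m, n, K, L, M, N):
    """Group the local indices (k, l) of source device (m, n) by target device."""
    tiles = {}
    for k in range(ceildiv(K, M)):
        for l in range(ceildiv(L, N)):
            if m * ceildiv(K, M) + k >= K or n * ceildiv(L, N) + l >= L:
                continue  # padding beyond the tensor bounds
            target, _ = d1_to_d2(m, n, k, l, K, L, M, N)
            tiles.setdefault(target, []).append((k, l))
    return tiles


# 6x6 tensor on a 2x3 mesh.
example = d1_to_d2(0, 1, 0, 0, K=6, L=6, M=2, N=3)
tiles = destinations(0, 0, K=6, L=6, M=2, N=3)
```

For the 6x6 tensor on the 2x3 mesh, source device (0, 0) sends its first two rows to target device (0, 0) and its last row to target device (0, 1), so the data to copy does form rectangular tiles in this case.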
+We want to copy the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+Fro a fixed (m, n, p, q) what is the range of (k, l, u, v)?
yaochengji wrote:
For a fixed (m, n, p, q) what is the range of (k, l, u, v)?